fix bug
This commit is contained in:
parent
9f6f508edb
commit
ab2c2d39fb
|
@ -138,6 +138,7 @@ class SACPolicy(
|
|||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
output_dim=config.num_discrete_actions,
|
||||
softmax_temperature=1.0,
|
||||
**asdict(config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
|
@ -198,7 +199,7 @@ class SACPolicy(
|
|||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
_, discrete_action_distribution = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = discrete_action_distribution.sample()
|
||||
discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
|
@ -435,7 +436,7 @@ class SACPolicy(
|
|||
next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_grasp_action = next_grasp_distribution.sample()
|
||||
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_grasp_qs, _ = self.grasp_critic_forward(
|
||||
|
|
Loading…
Reference in New Issue