This commit is contained in:
AdilZouitine 2025-04-08 09:31:29 +00:00
parent 9f6f508edb
commit ab2c2d39fb
1 changed files with 3 additions and 2 deletions

View File

@ -138,6 +138,7 @@ class SACPolicy(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
softmax_temperature=1.0,
**asdict(config.grasp_critic_network_kwargs),
)
@ -198,7 +199,7 @@ class SACPolicy(
if self.config.num_discrete_actions is not None:
_, discrete_action_distribution = self.grasp_critic(batch, observations_features)
discrete_action = discrete_action_distribution.sample()
discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
@ -435,7 +436,7 @@ class SACPolicy(
next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_grasp_action = next_grasp_distribution.sample()
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
# Get target Q-values from target network
target_next_grasp_qs, _ = self.grasp_critic_forward(