From ab2c2d39fba0822c2df5f730fd965510e834c27e Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Tue, 8 Apr 2025 09:31:29 +0000
Subject: [PATCH] fix bug

---
 lerobot/common/policies/sac/modeling_sac.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 281ffe2e..0e6f8fda 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -138,6 +138,7 @@ class SACPolicy(
             encoder=encoder_critic,
             input_dim=encoder_critic.output_dim,
             output_dim=config.num_discrete_actions,
+            softmax_temperature=1.0,
             **asdict(config.grasp_critic_network_kwargs),
         )
 
@@ -198,7 +199,7 @@ class SACPolicy(
 
         if self.config.num_discrete_actions is not None:
             _, discrete_action_distribution = self.grasp_critic(batch, observations_features)
-            discrete_action = discrete_action_distribution.sample()
+            discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
@@ -435,7 +436,7 @@ class SACPolicy(
             next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
                 next_observations, use_target=False, observation_features=next_observation_features
             )
-            best_next_grasp_action = next_grasp_distribution.sample()
+            best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
 
             # Get target Q-values from target network
             target_next_grasp_qs, _ = self.grasp_critic_forward(
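
Note (not part of the patch): the second and third hunks appear to fix a shape/dtype mismatch when concatenating the sampled discrete grasp action onto the continuous action tensor. The sketch below illustrates that mismatch under the assumption that the grasp critic's distribution is a torch.distributions.Categorical and that actions is a float tensor of shape (batch, action_dim); all names and sizes here are hypothetical.

    # Minimal sketch of why .unsqueeze(-1).float() is needed before torch.cat.
    # Assumption: the distribution is Categorical, as suggested by the .sample() calls.
    import torch
    from torch.distributions import Categorical

    batch_size, action_dim, num_discrete_actions = 4, 6, 3
    continuous_actions = torch.randn(batch_size, action_dim)   # float32, shape (B, action_dim)
    logits = torch.randn(batch_size, num_discrete_actions)
    discrete_dist = Categorical(logits=logits)

    sampled = discrete_dist.sample()                            # int64, shape (B,)
    # torch.cat requires matching ranks and dtypes, so add a trailing
    # dimension and cast to float before concatenating:
    discrete_action = sampled.unsqueeze(-1).float()             # float32, shape (B, 1)
    actions = torch.cat([continuous_actions, discrete_action], dim=-1)  # (B, action_dim + 1)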