diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 281ffe2e..0e6f8fda 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -138,6 +138,7 @@ class SACPolicy(
                 encoder=encoder_critic,
                 input_dim=encoder_critic.output_dim,
                 output_dim=config.num_discrete_actions,
+                softmax_temperature=1.0,
                 **asdict(config.grasp_critic_network_kwargs),
             )
 
@@ -198,7 +199,7 @@ class SACPolicy(
 
         if self.config.num_discrete_actions is not None:
             _, discrete_action_distribution = self.grasp_critic(batch, observations_features)
-            discrete_action = discrete_action_distribution.sample()
+            discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
@@ -435,7 +436,7 @@ class SACPolicy(
             next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
                 next_observations, use_target=False, observation_features=next_observation_features
             )
-            best_next_grasp_action = next_grasp_distribution.sample()
+            best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
 
             # Get target Q-values from target network
             target_next_grasp_qs, _ = self.grasp_critic_forward(
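
Below is a minimal, hypothetical sketch of the pattern these hunks touch, not the actual SACPolicy code: it assumes the grasp critic returns per-action Q-values and builds a torch.distributions.Categorical from a temperature-scaled softmax over them. The tensor names (grasp_qs, target_grasp_qs, continuous_actions) and the final gather step are illustrative stand-ins for what the file does around these lines.

    # Hypothetical sketch of the discrete grasp-action handling the diff changes.
    import torch
    from torch.distributions import Categorical

    batch_size, num_discrete_actions, continuous_dim = 4, 3, 6
    softmax_temperature = 1.0  # the value the first hunk passes to the grasp critic

    # Stand-ins for critic outputs and the continuous policy action (illustrative only).
    grasp_qs = torch.randn(batch_size, num_discrete_actions)          # online grasp critic
    target_grasp_qs = torch.randn(batch_size, num_discrete_actions)   # target grasp critic
    continuous_actions = torch.randn(batch_size, continuous_dim)

    # Distribution over discrete grasp actions from temperature-scaled Q-values.
    discrete_action_distribution = Categorical(logits=grasp_qs / softmax_temperature)

    # select_action path (second hunk): sample() gives an int64 tensor of shape (B,);
    # unsqueeze(-1) makes it (B, 1) and float() lets it concatenate with the
    # float continuous actions.
    discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
    actions = torch.cat([continuous_actions, discrete_action], dim=-1)  # (B, continuous_dim + 1)

    # Critic-target path (third hunk): keep the sampled index as int64 with shape
    # (B, 1) so it can index the target Q-values along the action dimension.
    best_next_grasp_action = discrete_action_distribution.sample().unsqueeze(-1)
    target_q_for_sampled_action = target_grasp_qs.gather(-1, best_next_grasp_action).squeeze(-1)

    print(actions.shape, target_q_for_sampled_action.shape)  # (4, 7), (4,)

The reshape and dtype handling is the point of the second and third hunks: Categorical.sample() returns a rank-1 integer tensor, so without unsqueeze(-1) the concatenation with the (B, action_dim) continuous actions would fail, and without the float() cast torch.cat would mix dtypes; the target-side sample stays integer because it is used as an index rather than as part of the action vector.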