diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index a5df2533..9e157a26 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -129,7 +129,7 @@ class SACPolicy( encoder=encoder_critic, input_dim=encoder_critic.output_dim, output_dim=config.num_discrete_actions, - softmax_temperature=1.0, + softmax_temperature=.15, **asdict(config.grasp_critic_network_kwargs), ) @@ -138,7 +138,7 @@ class SACPolicy( encoder=encoder_critic, input_dim=encoder_critic.output_dim, output_dim=config.num_discrete_actions, - softmax_temperature=1.0, + softmax_temperature=0.15, **asdict(config.grasp_critic_network_kwargs), ) @@ -786,6 +786,7 @@ class GraspCritic(nn.Module): super().__init__() self.encoder = encoder self.output_dim = output_dim + self.softmax_temperature = softmax_temperature self.net = MLP( input_dim=input_dim,