diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 6827da6a..f3eb8e94 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -247,6 +247,9 @@ class SACPolicy(
                     + target_param.data * (1.0 - self.config.critic_target_update_weight)
                 )
 
+    def update_temperature(self):
+        self.temperature = self.log_alpha.exp().item()
+
     def compute_loss_critic(
         self,
         observations,
@@ -257,7 +260,6 @@ class SACPolicy(
         observation_features: Tensor | None = None,
         next_observation_features: Tensor | None = None,
     ) -> Tensor:
-        self.temperature = self.log_alpha.exp().item()
         with torch.no_grad():
             next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
 
@@ -319,8 +321,6 @@ class SACPolicy(
         observations,
         observation_features: Tensor | None = None,
     ) -> Tensor:
-        self.temperature = self.log_alpha.exp().item()
-
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
 
         # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 244a6a47..2334d2e0 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -489,6 +489,8 @@ def add_actor_information_and_train(
             training_infos["temperature_grad_norm"] = temp_grad_norm
             training_infos["temperature"] = policy.temperature
 
+        policy.update_temperature()
+
         # Check if it's time to push updated policy to actors
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)