Fix cuda graph break

This commit is contained in:
AdilZouitine 2025-03-31 07:59:56 +00:00
parent 66c3672738
commit 8494634d48
2 changed files with 5 additions and 3 deletions

View File

@ -247,6 +247,9 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight) + target_param.data * (1.0 - self.config.critic_target_update_weight)
) )
# Refresh the cached scalar temperature from self.log_alpha.
# Computes exp(self.log_alpha) and stores it on self.temperature as a plain
# Python number (via .item()) rather than as a tensor. Returns nothing.
# NOTE(review): .item() presumably forces a device-to-host sync when log_alpha
# lives on a CUDA device, which looks like the reason this refresh is done in a
# dedicated method instead of inside the loss computations — confirm.
# NOTE(review): self.temperature only changes when this method is called, so
# callers must invoke it after log_alpha is updated — verify against the
# training loop.
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
def compute_loss_critic( def compute_loss_critic(
self, self,
observations, observations,
@ -257,7 +260,6 @@ class SACPolicy(
observation_features: Tensor | None = None, observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None, next_observation_features: Tensor | None = None,
) -> Tensor: ) -> Tensor:
self.temperature = self.log_alpha.exp().item()
with torch.no_grad(): with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
@ -319,8 +321,6 @@ class SACPolicy(
observations, observations,
observation_features: Tensor | None = None, observation_features: Tensor | None = None,
) -> Tensor: ) -> Tensor:
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features) actions_pi, log_probs, _ = self.actor(observations, observation_features)
# TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way # TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way

View File

@ -489,6 +489,8 @@ def add_actor_information_and_train(
training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature training_infos["temperature"] = policy.temperature
policy.update_temperature()
# Check if it's time to push updated policy to actors # Check if it's time to push updated policy to actors
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)