Fix cuda graph break

This commit is contained in:
AdilZouitine 2025-03-31 07:59:56 +00:00
parent 66c3672738
commit 8494634d48
2 changed files with 5 additions and 3 deletions
lerobot
common/policies/sac
scripts/server

View File

@ -247,6 +247,9 @@ class SACPolicy(
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
@ -257,7 +260,6 @@ class SACPolicy(
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
@ -319,8 +321,6 @@ class SACPolicy(
observations,
observation_features: Tensor | None = None,
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
# TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way

View File

@ -489,6 +489,8 @@ def add_actor_information_and_train(
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
policy.update_temperature()
# Check if it's time to push updated policy to actors
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)