Fix CUDA graph break
This commit is contained in:
parent
66c3672738
commit
8494634d48
|
@ -247,6 +247,9 @@ class SACPolicy(
|
||||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_temperature(self):
|
||||||
|
self.temperature = self.log_alpha.exp().item()
|
||||||
|
|
||||||
def compute_loss_critic(
|
def compute_loss_critic(
|
||||||
self,
|
self,
|
||||||
observations,
|
observations,
|
||||||
|
@ -257,7 +260,6 @@ class SACPolicy(
|
||||||
observation_features: Tensor | None = None,
|
observation_features: Tensor | None = None,
|
||||||
next_observation_features: Tensor | None = None,
|
next_observation_features: Tensor | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
self.temperature = self.log_alpha.exp().item()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||||
|
|
||||||
|
@ -319,8 +321,6 @@ class SACPolicy(
|
||||||
observations,
|
observations,
|
||||||
observation_features: Tensor | None = None,
|
observation_features: Tensor | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
self.temperature = self.log_alpha.exp().item()
|
|
||||||
|
|
||||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||||
|
|
||||||
# TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way
|
# TODO: (maractingi, azouitine) This is too slow, we should find a way to do this in a more efficient way
|
||||||
|
|
|
@ -489,6 +489,8 @@ def add_actor_information_and_train(
|
||||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||||
training_infos["temperature"] = policy.temperature
|
training_infos["temperature"] = policy.temperature
|
||||||
|
|
||||||
|
policy.update_temperature()
|
||||||
|
|
||||||
# Check if it's time to push updated policy to actors
|
# Check if it's time to push updated policy to actors
|
||||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||||
|
|
Loading…
Reference in New Issue