From 0eef49a0f62a8fba76c79fec52abbbbdef03ae73 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Thu, 20 Mar 2025 12:55:22 +0000
Subject: [PATCH] Initialize log_alpha with the logarithm of temperature_init
 in SACPolicy

- Updated the SACPolicy class to set log_alpha using the logarithm of the
  initial temperature value from the configuration.
---
 lerobot/common/policies/sac/modeling_sac.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 3e99dbd4..266c306d 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -17,6 +17,8 @@
 
 # TODO: (1) better device management
 
+from copy import deepcopy
+import math
 from typing import Callable, Optional, Tuple, Union, Dict, List
 from pathlib import Path
 
@@ -138,7 +140,9 @@ class SACPolicy(
         # TODO (azouitine): Handle the case where the temperature parameter is fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
         # it triggers "can't optimize a non-leaf Tensor"
-        self.log_alpha = nn.Parameter(torch.tensor([0.0]))
+
+        temperature_init = config.temperature_init
+        self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()
 
     def _save_pretrained(self, save_directory):
@@ -636,9 +640,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
 
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
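
As a quick sanity check outside the patch, the following minimal sketch shows
why log_alpha is seeded with log(temperature_init) rather than 0.0: since
temperature is recovered as exp(log_alpha), the old initialization silently
started at 1.0 no matter what the config said. Here temperature_init = 0.1 is
a hypothetical stand-in for config.temperature_init.

    import math

    import torch
    import torch.nn as nn

    temperature_init = 0.1  # hypothetical stand-in for config.temperature_init

    # Old initialization: log_alpha = 0.0, so temperature = exp(0.0) = 1.0,
    # ignoring the configured initial temperature.
    old_log_alpha = nn.Parameter(torch.tensor([0.0]))
    assert old_log_alpha.exp().item() == 1.0

    # New initialization: log_alpha = log(temperature_init), so
    # temperature = exp(log(temperature_init)) = temperature_init.
    new_log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
    assert abs(new_log_alpha.exp().item() - temperature_init) < 1e-6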