Initialize log_alpha with the logarithm of temperature_init in SACPolicy

- Updated the SACPolicy class to set log_alpha using the logarithm of the initial temperature value from the configuration.
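
As a minimal, self-contained sketch (the class name and default value below are illustrative, not code from this repository), the change means the policy now starts at the configured temperature instead of exp(0.0) = 1.0, while still optimizing the temperature in log space so it stays positive:

import math

import torch
import torch.nn as nn


class TemperatureSketch(nn.Module):
    # Illustrative stand-in for the temperature handling in SACPolicy; the real
    # class reads temperature_init from its config object.
    def __init__(self, temperature_init: float = 0.1):
        super().__init__()
        # Optimizing log(alpha) rather than alpha keeps the temperature positive.
        self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
        # exp(log(temperature_init)) == temperature_init, so the starting
        # temperature matches the configured value instead of exp(0.0) = 1.0.
        self.temperature = self.log_alpha.exp().item()


sketch = TemperatureSketch(temperature_init=0.1)
print(sketch.temperature)  # ~0.1
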
AdilZouitine 2025-03-20 12:55:22 +00:00
parent 2d5effeeba
commit 0eef49a0f6
1 changed file with 8 additions and 4 deletions


@@ -17,6 +17,8 @@
 # TODO: (1) better device management
+from copy import deepcopy
+import math
 from typing import Callable, Optional, Tuple, Union, Dict, List
 from pathlib import Path
@@ -138,7 +140,9 @@ class SACPolicy(
         # TODO (azouitine): Handle the case where the temparameter is a fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
         # it triggers "can't optimize a non-leaf Tensor"
-        self.log_alpha = nn.Parameter(torch.tensor([0.0]))
+        temperature_init = config.temperature_init
+        self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()

     def _save_pretrained(self, save_directory):
@@ -636,9 +640,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)