Initialize log_alpha with the logarithm of temperature_init in SACPolicy
- Updated the SACPolicy class to set log_alpha using the logarithm of the initial temperature value from the configuration.
This commit is contained in:
parent
2d5effeeba
commit
0eef49a0f6
|
@ -17,6 +17,8 @@
|
||||||
|
|
||||||
# TODO: (1) better device management
|
# TODO: (1) better device management
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
import math
|
||||||
from typing import Callable, Optional, Tuple, Union, Dict, List
|
from typing import Callable, Optional, Tuple, Union, Dict, List
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -138,7 +140,9 @@ class SACPolicy(
|
||||||
# TODO (azouitine): Handle the case where the temperature parameter is fixed
|
# TODO (azouitine): Handle the case where the temperature parameter is fixed
|
||||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||||
# it triggers "can't optimize a non-leaf Tensor"
|
# it triggers "can't optimize a non-leaf Tensor"
|
||||||
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
|
|
||||||
|
temperature_init = config.temperature_init
|
||||||
|
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||||
self.temperature = self.log_alpha.exp().item()
|
self.temperature = self.log_alpha.exp().item()
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory):
|
def _save_pretrained(self, save_directory):
|
||||||
|
@ -636,9 +640,9 @@ class Policy(nn.Module):
|
||||||
# Compute standard deviations
|
# Compute standard deviations
|
||||||
if self.fixed_std is None:
|
if self.fixed_std is None:
|
||||||
log_std = self.std_layer(outputs)
|
log_std = self.std_layer(outputs)
|
||||||
assert not torch.isnan(
|
assert not torch.isnan(log_std).any(), (
|
||||||
log_std
|
"[ERROR] log_std became NaN after std_layer!"
|
||||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
)
|
||||||
|
|
||||||
if self.use_tanh_squash:
|
if self.use_tanh_squash:
|
||||||
log_std = torch.tanh(log_std)
|
log_std = torch.tanh(log_std)
|
||||||
|
|
Loading…
Reference in New Issue