diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 0e886f90..9bac1c3e 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -51,8 +51,8 @@ class ActorNetworkConfig:
 @dataclass
 class PolicyConfig:
     use_tanh_squash: bool = True
-    log_std_min: int = -5
-    log_std_max: int = 2
+    log_std_min: float = 1e-5
+    log_std_max: float = 10.0
     init_final: float = 0.05

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index ed965393..d0f83325 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -27,6 +27,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
+from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution, constraints

 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -927,29 +928,20 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
-
-            if self.use_tanh_squash:
-                log_std = torch.tanh(log_std)
-                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
-            else:
-                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+            std = torch.exp(log_std)  # Match the JAX reference: std = exp(log_std)
+            # log_std_min/log_std_max now bound std directly, matching the JAX default clip
+            std = torch.clamp(std, self.log_std_min, self.log_std_max)
         else:
-            log_std = self.fixed_std.expand_as(means)
+            std = self.fixed_std.expand_as(means)

-        # uses tanh activation function to squash the action to be in the range of [-1, 1]
-        normal = torch.distributions.Normal(means, torch.exp(log_std))
-        x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
-        log_probs = normal.log_prob(x_t)  # Base log probability before Tanh
+        # Build the tanh-squashed distribution
+        dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)

-        if self.use_tanh_squash:
-            actions = torch.tanh(x_t)
-            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
-        else:
-            actions = x_t  # No Tanh; raw Gaussian sample
+        # Sample actions with the reparameterization trick
+        actions = dist.rsample()

-        log_probs = log_probs.sum(-1)  # Sum over action dimensions
-        means = torch.tanh(means) if self.use_tanh_squash else means
+        # log_prob already includes the tanh Jacobian correction and the sum over
+        # action dimensions (MultivariateNormal has event_dim=1)
+        log_probs = dist.log_prob(actions)
+
+        # Deterministic head: mode of the transformed distribution (tanh of the mean)
+        means = dist.mode()

         return actions, log_probs, means

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
@@ -1090,6 +1082,68 @@ class SpatialLearnedEmbeddings(nn.Module):
         return output


+class RescaleFromTanh(Transform):
+    """Affine map from the tanh range (-1, 1) to an action range (low, high)."""
+
+    domain = constraints.interval(-1.0, 1.0)
+    codomain = constraints.real
+    bijective = True
+    sign = +1
+
+    def __init__(self, low=-1.0, high=1.0):
+        super().__init__(cache_size=1)
+        self.low = torch.as_tensor(low)
+        self.high = torch.as_tensor(high)
+
+    def _call(self, x):
+        # Rescale from (-1, 1) to (low, high)
+        return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
+
+    def _inverse(self, y):
+        # Rescale from (low, high) back to (-1, 1)
+        return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
+
+    def log_abs_det_jacobian(self, x, y):
+        # Elementwise log|dy/dx| = log(0.5 * (high - low)); TransformedDistribution
+        # sums over the event dimension itself, so no explicit sum here
+        scale = 0.5 * (self.high - self.low)
+        return torch.log(scale).expand_as(x)
+
+
+class TanhMultivariateNormalDiag(TransformedDistribution):
+    def __init__(self, loc, scale_diag, low=None, high=None):
+        # scale_tril = diag(std) is the Cholesky factor of a diagonal covariance
+        base_dist = MultivariateNormal(loc, scale_tril=torch.diag_embed(scale_diag))
+
+        # Transforms are applied in list order: tanh first, then the optional rescale
+        transforms: list[Transform] = [TanhTransform(cache_size=1)]
+        if low is not None and high is not None:
+            transforms.append(RescaleFromTanh(low, high))
+
+        super().__init__(base_dist, transforms)
+
+    def mode(self):
+        # Mode of the base Gaussian, pushed through the transforms
+        x = self.base_dist.mean
+        for transform in self.transforms:
+            x = transform(x)
+        return x
+
+    def stddev(self):
+        # Heuristic scale: the base stddev pushed through the transforms
+        # (not the exact stddev of the transformed distribution)
+        x = self.base_dist.stddev
+        for transform in self.transforms:
+            x = transform(x)
+        return x
+
+
 def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
     converted_params = {}
     for outer_key, inner_dict in normalization_params.items():
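
A quick way to validate the new head (not part of the patch): the
TransformedDistribution log_prob should match the manual tanh correction that
the removed code computed by hand, and dist.mode() should reproduce the old
torch.tanh(means). A minimal sketch, assuming a recent PyTorch and that
TanhMultivariateNormalDiag is importable as below; shapes, seed, and tolerance
are arbitrary:

    import torch
    from torch.distributions import Normal

    from lerobot.common.policies.sac.modeling_sac import TanhMultivariateNormalDiag

    torch.manual_seed(0)
    means = torch.randn(4, 3)
    std = torch.rand(4, 3) * 0.5 + 0.1

    dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
    actions = dist.rsample()            # reparameterized sample in (-1, 1)
    log_probs = dist.log_prob(actions)  # shape (4,), tanh Jacobian included

    # Manual computation mirroring the removed code path
    normal = Normal(means, std)
    x_t = torch.atanh(actions.clamp(-1 + 1e-6, 1 - 1e-6))
    manual = (normal.log_prob(x_t) - torch.log(1 - actions.pow(2) + 1e-6)).sum(-1)

    assert torch.allclose(log_probs, manual, atol=1e-3)
    assert torch.allclose(dist.mode(), torch.tanh(means))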