change the tanh distribution to match HIL-SERL

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
AdilZouitine 2025-04-15 08:31:14 +00:00
parent bda2053106
commit eac79a006d
2 changed files with 74 additions and 20 deletions


@@ -51,8 +51,8 @@ class ActorNetworkConfig:
 @dataclass
 class PolicyConfig:
     use_tanh_squash: bool = True
-    log_std_min: int = -5
-    log_std_max: int = 2
+    log_std_min: float = 1e-5
+    log_std_max: float = 10.0
     init_final: float = 0.05
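With this change, `log_std_min` / `log_std_max` bound the standard deviation itself (after `exp`) rather than the raw log-std, which is why the defaults move from (-5, 2) to (1e-5, 10.0). A minimal sketch of the intended clipping; the variable names and values here are illustrative, not from the repo:

```python
import torch

# New PolicyConfig defaults (std-space bounds, not log-space)
std_min, std_max = 1e-5, 10.0

log_std = torch.tensor([-20.0, 0.0, 5.0])  # raw output of the std head
std = torch.exp(log_std)                   # exp first, as in the JAX reference
std = torch.clamp(std, std_min, std_max)   # then clip in std space
print(std)  # tensor([1.0000e-05, 1.0000e+00, 1.0000e+01])
```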


@@ -27,6 +27,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
+from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution, constraints
 
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -927,29 +928,20 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
-            if self.use_tanh_squash:
-                log_std = torch.tanh(log_std)
-                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
-            else:
-                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+            std = torch.exp(log_std)  # match the JAX reference: std = exp(log_std)
+            std = torch.clamp(std, self.log_std_min, self.log_std_max)  # match the JAX default clip
         else:
-            log_std = self.fixed_std.expand_as(means)
+            std = self.fixed_std.expand_as(means)
 
-        # uses tanh activation function to squash the action to be in the range of [-1, 1]
-        normal = torch.distributions.Normal(means, torch.exp(log_std))
-        x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
-        log_probs = normal.log_prob(x_t)  # Base log probability before Tanh
-        if self.use_tanh_squash:
-            actions = torch.tanh(x_t)
-            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
-        else:
-            actions = x_t  # No Tanh; raw Gaussian sample
-        log_probs = log_probs.sum(-1)  # Sum over action dimensions
-        means = torch.tanh(means) if self.use_tanh_squash else means
+        # Build the tanh-squashed distribution
+        dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
+
+        # Sample actions with the reparameterization trick
+        actions = dist.rsample()
+
+        # log_prob already includes the tanh Jacobian correction, summed over action dims
+        log_probs = dist.log_prob(actions)
 
         return actions, log_probs, means
 
     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
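The `TransformedDistribution` machinery adds exactly the `log(1 - tanh(x)^2)` Jacobian term the old branch subtracted by hand. A quick sanity check of that equivalence, built only from stock `torch.distributions` pieces (the seed and shapes are illustrative):

```python
import torch
from torch.distributions import Independent, Normal, TanhTransform, TransformedDistribution

torch.manual_seed(0)
means, std = torch.zeros(2, 3), torch.full((2, 3), 0.5)

# Diagonal Gaussian squashed by tanh, assembled from stock components
dist = TransformedDistribution(Independent(Normal(means, std), 1), [TanhTransform(cache_size=1)])

x_t = Normal(means, std).rsample()
actions = torch.tanh(x_t)

# Old-style manual correction, summed over action dimensions
manual = Normal(means, std).log_prob(x_t).sum(-1) - torch.log(1 - actions.pow(2) + 1e-6).sum(-1)

print(torch.allclose(dist.log_prob(actions), manual, atol=1e-4))  # True, up to the 1e-6 epsilon
```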
@@ -1090,6 +1082,68 @@ class SpatialLearnedEmbeddings(nn.Module):
         return output
 
 
+class RescaleFromTanh(Transform):
+    """Affine map from the tanh range (-1, 1) to (low, high)."""
+
+    domain = constraints.interval(-1.0, 1.0)
+    codomain = constraints.real
+    bijective = True
+    sign = +1
+
+    def __init__(self, low=-1.0, high=1.0):
+        super().__init__()
+        self.low = torch.as_tensor(low)
+        self.high = torch.as_tensor(high)
+
+    def _call(self, x):
+        # Rescale from (-1, 1) to (low, high)
+        return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
+
+    def _inverse(self, y):
+        # Rescale from (low, high) back to (-1, 1)
+        return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
+
+    def log_abs_det_jacobian(self, x, y):
+        # Elementwise log|dy/dx| = log(0.5 * (high - low));
+        # TransformedDistribution sums this over the event dimensions.
+        scale = 0.5 * (self.high - self.low)
+        return torch.log(scale).expand_as(x)
+
+
+class TanhMultivariateNormalDiag(TransformedDistribution):
+    def __init__(self, loc, scale_diag, low=None, high=None):
+        # Diagonal Gaussian; scale_tril holds the standard deviations
+        base_dist = MultivariateNormal(loc, scale_tril=torch.diag_embed(scale_diag))
+        # Squash to (-1, 1) first, then optionally rescale to (low, high)
+        transforms = [TanhTransform(cache_size=1)]
+        if low is not None and high is not None:
+            transforms.append(RescaleFromTanh(low, high))
+        super().__init__(base_dist, transforms)
+
+    def mode(self):
+        # Mode is the mean of the base distribution, pushed through the transforms
+        x = self.base_dist.mean
+        for transform in self.transforms:
+            x = transform(x)
+        return x
+
+    def stddev(self):
+        # Heuristic: push the base stddev through the transforms
+        x = self.base_dist.stddev
+        for transform in self.transforms:
+            x = transform(x)
+        return x
 
 def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
     converted_params = {}
     for outer_key, inner_dict in normalization_params.items():
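For reference, a hedged usage sketch of the new distribution, assuming the patched `TanhMultivariateNormalDiag` above is importable from this module; shapes and bounds are illustrative:

```python
import torch

means = torch.zeros(4, 3)                 # batch of 4, action dim 3
std = torch.full((4, 3), 0.5)

# Unbounded case: actions are squashed into (-1, 1)
dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
actions = dist.rsample()                  # reparameterized, shape (4, 3)
log_probs = dist.log_prob(actions)        # Jacobian-corrected, shape (4,)
assert actions.abs().max() < 1.0

# Bounded case: low/high rescale the squashed actions into [0, 1]
bounded = TanhMultivariateNormalDiag(
    loc=means, scale_diag=std, low=torch.zeros(3), high=torch.ones(3)
)
print(bounded.mode())                     # tanh(0) rescaled -> 0.5 per dimension
```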