change the tanh distribution to match hil serl

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
AdilZouitine 2025-04-15 08:31:14 +00:00
parent bda2053106
commit eac79a006d
2 changed files with 74 additions and 20 deletions

View File

@ -51,8 +51,8 @@ class ActorNetworkConfig:
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: int = -5
log_std_max: int = 2
log_std_min: int = 1e-5
log_std_max: int = 10.0
init_final: float = 0.05

View File

@ -27,6 +27,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@ -927,29 +928,20 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
std = torch.exp(log_std) # Match JAX "exp"
std = torch.clamp(std, self.log_std_min, self.log_std_max) # Match JAX default clip
else:
log_std = self.fixed_std.expand_as(means)
# uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
# Build transformed distribution
dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
# Sample actions (reparameterized)
actions = dist.rsample()
# Compute log_probs
log_probs = dist.log_prob(actions)
log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
@ -1090,6 +1082,68 @@ class SpatialLearnedEmbeddings(nn.Module):
return output
class RescaleFromTanh(Transform):
def __init__(self, low: float = -1, high: float = 1):
super().__init__()
self.low = low
self.high = high
def _call(self, x):
# Rescale from (-1, 1) to (low, high)
return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
def _inverse(self, y):
# Rescale from (low, high) back to (-1, 1)
return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
def log_abs_det_jacobian(self, x, y):
# log|d(rescale)/dx| = sum(log(0.5 * (high - low)))
scale = 0.5 * (self.high - self.low)
return torch.sum(torch.log(scale), dim=-1)
class TanhMultivariateNormalDiag(TransformedDistribution):
def __init__(self, loc, scale_diag, low=None, high=None):
base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag))
transforms = [TanhTransform(cache_size=1)]
if low is not None and high is not None:
low = torch.as_tensor(low)
high = torch.as_tensor(high)
transforms.insert(0, RescaleFromTanh(low, high))
super().__init__(base_dist, transforms)
def mode(self):
# Mode is mean of base distribution, passed through transforms
x = self.base_dist.mean
for transform in self.transforms:
x = transform(x)
return x
def stddev(self):
std = self.base_dist.stddev
x = std
for transform in self.transforms:
x = transform(x)
return x
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():