change the tanh distribution to match hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
parent bda2053106
commit eac79a006d
@@ -51,8 +51,8 @@ class ActorNetworkConfig:
 @dataclass
 class PolicyConfig:
     use_tanh_squash: bool = True
-    log_std_min: int = -5
-    log_std_max: int = 2
+    log_std_min: float = 1e-5
+    log_std_max: float = 10.0
     init_final: float = 0.05
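Note that this is more than a change of defaults: after this commit, log_std_min and log_std_max bound the standard deviation itself (after exponentiation), not the log-standard-deviation, which is how hil-serl's exp-then-clip parameterization behaves (see the forward-pass hunk below). A minimal standalone sketch of the difference, with illustrative values:

    import torch

    log_std = torch.tensor([-8.0, 0.0, 4.0])

    # Old behavior: clamp in log-space, then exponentiate.
    old_std = torch.exp(torch.clamp(log_std, -5, 2))
    # -> tensor([0.0067, 1.0000, 7.3891])

    # New behavior (matches hil-serl): exponentiate, then clip the std directly.
    new_std = torch.clamp(torch.exp(log_std), 1e-5, 10.0)
    # -> tensor([3.3546e-04, 1.0000e+00, 1.0000e+01])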
@@ -27,6 +27,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor
+from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform
 
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -927,29 +928,20 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
             assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
-
-            if self.use_tanh_squash:
-                log_std = torch.tanh(log_std)
-                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
-            else:
-                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+            std = torch.exp(log_std)  # Match JAX "exp"
+            std = torch.clamp(std, self.log_std_min, self.log_std_max)  # Match JAX default clip
         else:
             log_std = self.fixed_std.expand_as(means)
 
-        # uses tanh activation function to squash the action to be in the range of [-1, 1]
-        normal = torch.distributions.Normal(means, torch.exp(log_std))
-        x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
-        log_probs = normal.log_prob(x_t)  # Base log probability before Tanh
-
-        if self.use_tanh_squash:
-            actions = torch.tanh(x_t)
-            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
-        else:
-            actions = x_t  # No Tanh; raw Gaussian sample
-
-        log_probs = log_probs.sum(-1)  # Sum over action dimensions
-        means = torch.tanh(means) if self.use_tanh_squash else means
+        # Build transformed distribution
+        dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
+
+        # Sample actions (reparameterized)
+        actions = dist.rsample()
+
+        # Compute log_probs
+        log_probs = dist.log_prob(actions)
         return actions, log_probs, means
 
     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
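The removed code applied the tanh change-of-variables correction by hand (log_probs -= log(1 - tanh(x)^2 + 1e-6)); TransformedDistribution folds the same Jacobian term into log_prob through TanhTransform, so the new TanhMultivariateNormalDiag (defined in the next hunk) needs no manual adjustment. A standalone check of that equivalence, using only stock torch.distributions pieces:

    import torch
    from torch.distributions import Independent, Normal, TanhTransform, TransformedDistribution

    torch.manual_seed(0)
    means, std = torch.zeros(6), torch.ones(6)

    base = Independent(Normal(means, std), 1)
    dist = TransformedDistribution(base, [TanhTransform(cache_size=1)])

    x_t = base.rsample()
    action = torch.tanh(x_t)

    # Manual correction, as in the removed code:
    manual = (Normal(means, std).log_prob(x_t) - torch.log(1 - action.pow(2) + 1e-6)).sum(-1)

    # Same quantity via the transformed distribution:
    auto = dist.log_prob(action)

    # Loose tolerance: the removed code's 1e-6 epsilon perturbs the Jacobian slightly.
    assert torch.allclose(manual, auto, atol=1e-2)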
@@ -1090,6 +1082,68 @@ class SpatialLearnedEmbeddings(nn.Module):
         return output
 
 
+class RescaleFromTanh(Transform):
+    def __init__(self, low: float = -1, high: float = 1):
+        super().__init__()
+        self.low = low
+        self.high = high
+
+    def _call(self, x):
+        # Rescale from (-1, 1) to (low, high)
+        return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
+
+    def _inverse(self, y):
+        # Rescale from (low, high) back to (-1, 1)
+        return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
+
+    def log_abs_det_jacobian(self, x, y):
+        # log|d(rescale)/dx| = sum(log(0.5 * (high - low)))
+        scale = 0.5 * (self.high - self.low)
+        return torch.sum(torch.log(scale), dim=-1)
+
+
+class TanhMultivariateNormalDiag(TransformedDistribution):
+    def __init__(self, loc, scale_diag, low=None, high=None):
+        base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag))
+        transforms = [TanhTransform(cache_size=1)]
+        if low is not None and high is not None:
+            low = torch.as_tensor(low)
+            high = torch.as_tensor(high)
+            transforms.insert(0, RescaleFromTanh(low, high))
+        super().__init__(base_dist, transforms)
+
+    def mode(self):
+        # Mode is mean of base distribution, passed through transforms
+        x = self.base_dist.mean
+        for transform in self.transforms:
+            x = transform(x)
+        return x
+
+    def stddev(self):
+        std = self.base_dist.stddev
+        x = std
+        for transform in self.transforms:
+            x = transform(x)
+        return x
+
+
 def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
     converted_params = {}
     for outer_key, inner_dict in normalization_params.items():
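For reference, a small usage sketch of the class added above; shapes and values are illustrative:

    import torch

    means = torch.zeros(4, 6)            # batch of 4, action dim 6
    std = torch.full((4, 6), 0.5)

    dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)

    actions = dist.rsample()             # differentiable samples, squashed by tanh
    log_probs = dist.log_prob(actions)   # shape (4,); Jacobian and sum over action dims included
    greedy = dist.mode()                 # tanh of the base mean, for deterministic evaluation

Because the base distribution is a MultivariateNormal, log_prob already returns one value per batch element, which is why the per-dimension log_probs.sum(-1) disappears from the forward pass above.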