change the tanh distribution to match hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
parent
bda2053106
commit
eac79a006d
|
@ -51,8 +51,8 @@ class ActorNetworkConfig:
|
||||||
@dataclass
|
@dataclass
|
||||||
class PolicyConfig:
|
class PolicyConfig:
|
||||||
use_tanh_squash: bool = True
|
use_tanh_squash: bool = True
|
||||||
log_std_min: int = -5
|
log_std_min: int = 1e-5
|
||||||
log_std_max: int = 2
|
log_std_max: int = 10.0
|
||||||
init_final: float = 0.05
|
init_final: float = 0.05
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.distributions import MultivariateNormal, TransformedDistribution, TanhTransform, Transform
|
||||||
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
@ -927,29 +928,20 @@ class Policy(nn.Module):
|
||||||
# Compute standard deviations
|
# Compute standard deviations
|
||||||
if self.fixed_std is None:
|
if self.fixed_std is None:
|
||||||
log_std = self.std_layer(outputs)
|
log_std = self.std_layer(outputs)
|
||||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
std = torch.exp(log_std) # Match JAX "exp"
|
||||||
|
std = torch.clamp(std, self.log_std_min, self.log_std_max) # Match JAX default clip
|
||||||
if self.use_tanh_squash:
|
|
||||||
log_std = torch.tanh(log_std)
|
|
||||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
|
||||||
else:
|
|
||||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
|
||||||
else:
|
else:
|
||||||
log_std = self.fixed_std.expand_as(means)
|
log_std = self.fixed_std.expand_as(means)
|
||||||
|
|
||||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
# Build transformed distribution
|
||||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
|
||||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
|
||||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
|
||||||
|
|
||||||
if self.use_tanh_squash:
|
# Sample actions (reparameterized)
|
||||||
actions = torch.tanh(x_t)
|
actions = dist.rsample()
|
||||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
|
||||||
else:
|
# Compute log_probs
|
||||||
actions = x_t # No Tanh; raw Gaussian sample
|
log_probs = dist.log_prob(actions)
|
||||||
|
|
||||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
|
||||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
|
||||||
return actions, log_probs, means
|
return actions, log_probs, means
|
||||||
|
|
||||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||||
|
@ -1090,6 +1082,68 @@ class SpatialLearnedEmbeddings(nn.Module):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class RescaleFromTanh(Transform):
|
||||||
|
def __init__(self, low: float = -1, high: float = 1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.low = low
|
||||||
|
|
||||||
|
self.high = high
|
||||||
|
|
||||||
|
def _call(self, x):
|
||||||
|
# Rescale from (-1, 1) to (low, high)
|
||||||
|
|
||||||
|
return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
|
||||||
|
|
||||||
|
def _inverse(self, y):
|
||||||
|
# Rescale from (low, high) back to (-1, 1)
|
||||||
|
|
||||||
|
return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
|
||||||
|
|
||||||
|
def log_abs_det_jacobian(self, x, y):
|
||||||
|
# log|d(rescale)/dx| = sum(log(0.5 * (high - low)))
|
||||||
|
|
||||||
|
scale = 0.5 * (self.high - self.low)
|
||||||
|
|
||||||
|
return torch.sum(torch.log(scale), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class TanhMultivariateNormalDiag(TransformedDistribution):
|
||||||
|
def __init__(self, loc, scale_diag, low=None, high=None):
|
||||||
|
base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag))
|
||||||
|
|
||||||
|
transforms = [TanhTransform(cache_size=1)]
|
||||||
|
|
||||||
|
if low is not None and high is not None:
|
||||||
|
low = torch.as_tensor(low)
|
||||||
|
|
||||||
|
high = torch.as_tensor(high)
|
||||||
|
|
||||||
|
transforms.insert(0, RescaleFromTanh(low, high))
|
||||||
|
|
||||||
|
super().__init__(base_dist, transforms)
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
# Mode is mean of base distribution, passed through transforms
|
||||||
|
|
||||||
|
x = self.base_dist.mean
|
||||||
|
|
||||||
|
for transform in self.transforms:
|
||||||
|
x = transform(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def stddev(self):
|
||||||
|
std = self.base_dist.stddev
|
||||||
|
|
||||||
|
x = std
|
||||||
|
|
||||||
|
for transform in self.transforms:
|
||||||
|
x = transform(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||||
converted_params = {}
|
converted_params = {}
|
||||||
for outer_key, inner_dict in normalization_params.items():
|
for outer_key, inner_dict in normalization_params.items():
|
||||||
|
|
Loading…
Reference in New Issue