From 5b4adc00bb3da018cf10cbde6e120fd5e890c179 Mon Sep 17 00:00:00 2001
From: KeWang1017
Date: Sun, 29 Dec 2024 12:30:39 +0000
Subject: [PATCH] Refactor SAC configuration and policy for improved action sampling and stability

- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
---
 .../common/policies/sac/configuration_sac.py |  27 +-
 lerobot/common/policies/sac/modeling_sac.py  | 233 +++-----------
 2 files changed, 43 insertions(+), 217 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 7a4bd364..52c564a6 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -53,30 +53,13 @@ class SACConfig:
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
-    }
+    }
     actor_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
-    }
+    }
     policy_kwargs = {
-        "tanh_squash_distribution": True,
-        "std_parameterization": "softplus",
-        "std_min": 0.005,
-        "std_max": 5.0,
+        "use_tanh_squash": True,
+        "log_std_min": -5,
+        "log_std_max": 2,
     }
-    )
-    output_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "action": [4],
-        }
-    )
-
-    state_encoder_hidden_dim: int = 256
-    latent_dim: int = 256
-    network_hidden_dims: int = 256
-
-    # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
-    output_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {"action": "min_max"},
-    )
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 806cb767..1e7fd92b 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -102,9 +102,7 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        distribution = self.actor(batch)
-        # Sample from the distribution and return just the actions
-        actions = distribution.mode()  # or distribution.sample() for stochastic actions
+        actions, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions
 
@@ -129,12 +127,11 @@ class SACPolicy(
         # reward bias from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
-
+
         # calculate critics loss
         # 1- compute actions from policy
-        distribution = self.actor(observations)
-        action_preds = distribution.sample()
-        action_preds = torch.clamp(action_preds, -1, +1)
+        action_preds, log_probs = self.actor(next_observations)
+
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
         # subsample critics to prevent overfitting if use high UTD (update to date)
@@ -147,7 +144,7 @@ class SACPolicy(
         min_q = q_targets.min(dim=0)
 
         # compute td target
-        td_target = rewards + self.discount * min_q
+        td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
 
         # 3- compute predicted qs
         q_preds = self.critic_ensemble(observations, actions)
@@ -178,18 +175,12 @@ class SACPolicy(
             einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
             reduction="none"
         ).sum(0).mean()
-        # breakpoint()
 
         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
-        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        distribution = self.actor(observations)
-        actions = distribution.rsample()
-        log_probs = distribution.log_prob(actions).sum(-1)
-        # breakpoint()
-        actions = torch.clamp(actions, -1, +1)
+        actions, log_probs = self.actor(observations)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -264,15 +255,13 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        activate_final: bool = False,
-        device: str = "cuda",
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.init_final = init_final
-        self.activate_final = activate_final
 
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
@@ -304,22 +293,6 @@ class Critic(nn.Module):
         value = self.output_layer(x)
         return value.squeeze(-1)
 
-    def q_value_ensemble(
-        self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
-    ) -> torch.Tensor:
-        observations = observations.to(self.device)
-        actions = actions.to(self.device)
-
-        if len(actions.shape) == 3:  # [batch_size, num_actions, action_dim]
-            batch_size, num_actions = actions.shape[:2]
-            obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
-            obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
-            actions_flat = actions.reshape(-1, actions.shape[-1])
-            q_values = self(obs_flat, actions_flat, train)
-            return q_values.reshape(batch_size, num_actions)
-        else:
-            return self(observations, actions, train)
-
 
 class Policy(nn.Module):
     def __init__(
@@ -327,26 +300,22 @@ class Policy(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         action_dim: int,
-        std_parameterization: str = "exp",
-        std_min: float = 0.05,
-        std_max: float = 2.0,
-        tanh_squash_distribution: bool = False,
+        log_std_min: float = -5,
+        log_std_max: float = 2,
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
-        activate_final: bool = False,
-        device: str = "cuda",
+        use_tanh_squash: bool = False,
+        device: str = "cuda"
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.action_dim = action_dim
-        self.std_parameterization = std_parameterization
-        self.std_min = std_min
-        self.std_max = std_max
-        self.tanh_squash_distribution = tanh_squash_distribution
+        self.log_std_min = log_std_min
+        self.log_std_max = log_std_max
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
-        self.activate_final = activate_final
+        self.use_tanh_squash = use_tanh_squash
 
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
@@ -364,27 +333,20 @@ class Policy(nn.Module):
 
         # Standard deviation layer or parameter
         if fixed_std is None:
-            if std_parameterization == "uniform":
-                self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
+            self.std_layer = nn.Linear(out_features, action_dim)
+            if init_final is not None:
+                nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
+                nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
             else:
-                self.std_layer = nn.Linear(out_features, action_dim)
-                if init_final is not None:
-                    nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
-                    nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
-                else:
-                    orthogonal_init()(self.std_layer.weight)
-
+                orthogonal_init()(self.std_layer.weight)
+
         self.to(self.device)
 
     def forward(
         self,
         observations: torch.Tensor,
-        temperature: float = 1.0,
-        train: bool = False,
-        non_squash_distribution: bool = False,
-    ) -> torch.distributions.Distribution:
-        self.train(train)
-
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
         # Encode observations if encoder exists
         if self.encoder is not None:
             with torch.set_grad_enabled(train):
@@ -398,41 +360,24 @@ class Policy(nn.Module):
 
         # Compute standard deviations
         if self.fixed_std is None:
-            if self.std_parameterization == "exp":
-                log_stds = self.std_layer(outputs)
-                # Clamp log_stds to prevent too large or small values
-                log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
-                stds = torch.exp(log_stds)
-            elif self.std_parameterization == "softplus":
-                stds = torch.nn.functional.softplus(self.std_layer(outputs))
-                stds = torch.clamp(stds, self.std_min, self.std_max)
-            elif self.std_parameterization == "uniform":
-                log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
-                stds = torch.exp(log_stds).expand_as(means)
-            else:
-                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
+            log_std = self.std_layer(outputs)
+            if self.use_tanh_squash:
+                log_std = torch.tanh(log_std)
+                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
         else:
-            assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)
 
-        # Scale with temperature
-        temperature = torch.tensor(temperature, device=self.device)
-        stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
-
-        # Create distribution
-        if self.tanh_squash_distribution and not non_squash_distribution:
-            distribution = TanhMultivariateNormalDiag(
-                loc=means,
-                scale_diag=stds,
-            )
-        else:
-            distribution = torch.distributions.Normal(
-                loc=means,
-                scale=stds,
-            )
-
-        return distribution
+        # uses tahn activation function to squash the action to be in the range of [-1, 1]
+        normal = torch.distributions.Normal(means, stds)
+        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
+        log_probs = normal.log_prob(x_t)
+        if self.use_tanh_squash:
+            actions = torch.tanh(x_t)
+            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
+        log_probs = log_probs.sum(-1)  # sum over action dim
+        return actions, log_probs
+
 
     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
         observations = observations.to(self.device)
@@ -552,110 +497,8 @@ class LagrangeMultiplier(nn.Module):
         return multiplier * diff
 
 
-# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
-# 1. The base distribution is a diagonal multivariate normal distribution
-# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
-# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
-# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
-class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
-    DEFAULT_SAMPLE_SHAPE = torch.Size()
-
-    def __init__(
-        self,
-        loc: torch.Tensor,
-        scale_diag: torch.Tensor,
-        low: Optional[torch.Tensor] = None,
-        high: Optional[torch.Tensor] = None,
-    ):
-        # Create base normal distribution
-        base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
-
-        # Create list of transforms
-        transforms = []
-
-        # Add tanh transform
-        transforms.append(torch.distributions.transforms.TanhTransform())
-
-        # Add rescaling transform if bounds are provided
-        if low is not None and high is not None:
-            transforms.append(
-                torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
-            )
-
-        # Initialize parent class
-        super().__init__(base_distribution=base_distribution, transforms=transforms)
-
-        # Store parameters
-        self.loc = loc
-        self.scale_diag = scale_diag
-        self.low = low
-        self.high = high
-
-    def mode(self) -> torch.Tensor:
-        """Get the mode of the transformed distribution"""
-        # The mode of a normal distribution is its mean
-        mode = self.loc
-        # Apply transforms
-        for transform in self.transforms:
-            mode = transform(mode)
-
-        return mode
-
-    def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
-        """
-        Reparameterized sample from the distribution
-        """
-        # Sample from base distributionrsample
-        x = self.base_dist.rsample(sample_shape)
-
-        # Apply transforms
-        for transform in self.transforms:
-            x = transform(x)
-
-        return x
-
-    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
-        """
-        Compute log probability of a value
-        Includes the log det jacobian for the transforms
-        """
-        # Initialize log prob
-        log_prob = torch.zeros_like(value)
-
-        # Inverse transforms to get back to normal distribution
-        q = value
-        for transform in reversed(self.transforms):
-            q_prev = transform.inv(q)  # Get the pre-transform value
-            log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q)  # Sum over action dimensions
-            q = q_prev
-
-        # Add base distribution log prob
-        log_prob = log_prob + self.base_dist.log_prob(q)  # Sum over action dimensions
-
-        return log_prob
-
-    def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Sample from the distribution and compute log probability
-        """
-        x = self.rsample(sample_shape)
-        log_prob = self.log_prob(x)
-        return x, log_prob
-
-    # def entropy(self) -> torch.Tensor:
-    #     """
-    #     Compute entropy of the distribution
-    #     """
-    #     # Start with base distribution entropy
-    #     entropy = self.base_dist.entropy().sum(-1)
-
-    #     # Add log det jacobian for each transform
-    #     x = self.rsample()
-    #     for transform in self.transforms:
-    #         entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
-    #         x = transform(x)
-
-    #     return entropy
+def orthogonal_init():
+    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
 
 
 def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
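
For reference, the sampling scheme this patch switches to is the standard tanh-squashed, reparameterized Gaussian used in SAC-style actors. The snippet below is a minimal standalone sketch of that scheme, not the repository code: the function name and tensor shapes are illustrative, and it assumes the clamped log_std is mapped to a positive scale via exp() before the Normal is built, a step the hunk above does not show.

import torch


def sample_tanh_gaussian(means, log_std, log_std_min=-5.0, log_std_max=2.0):
    """Reparameterized sample from a tanh-squashed Gaussian, with the
    change-of-variables correction applied to the log-probability."""
    # Clamp log_std, then map it to a positive scale (assumed step, not shown in the patch).
    log_std = torch.clamp(log_std, log_std_min, log_std_max)
    stds = log_std.exp()

    normal = torch.distributions.Normal(means, stds)
    x_t = normal.rsample()            # mean + std * N(0, 1), differentiable w.r.t. means/stds
    log_probs = normal.log_prob(x_t)  # per-dimension Gaussian log-prob

    actions = torch.tanh(x_t)         # squash actions into [-1, 1]
    # Correction for the squash: log|d tanh(x)/dx| = log(1 - tanh(x)^2); the epsilon avoids log(0).
    log_probs -= torch.log(1 - actions.pow(2) + 1e-6)
    return actions, log_probs.sum(-1)  # sum over the action dimension


if __name__ == "__main__":
    means = torch.zeros(8, 4)    # batch of 8, 4-dim actions (illustrative shapes)
    log_std = torch.zeros(8, 4)
    actions, log_probs = sample_tanh_gaussian(means, log_std)
    print(actions.shape, log_probs.shape)  # torch.Size([8, 4]) torch.Size([8])

Sampling with rsample() keeps the action differentiable with respect to the distribution parameters (the reparameterization trick), and the log(1 - tanh(x)^2) term is the same change-of-variables correction the removed TanhMultivariateNormalDiag class previously obtained through transform log-det-Jacobians.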