diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 87170d20..821cb93f 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -19,6 +19,7 @@
 from collections import deque
 from copy import deepcopy
+import math
 from typing import Callable, Optional, Sequence, Tuple

 import einops
@@ -100,7 +101,12 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        actions, _ = self.actor(batch['observations'])
+        """Select an action for inference/evaluation."""
+        distribution = self.actor(batch)
+        # Use the distribution mode for deterministic evaluation
+        actions = distribution.mode()  # or distribution.sample() for stochastic actions
+        actions = self.unnormalize_outputs({"action": actions})["action"]
+        return actions

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
@@ -126,7 +132,10 @@ class SACPolicy(
         # calculate critics loss
         # 1- compute actions from policy
-        action_preds, log_probs = self.actor(observations)
+        distribution = self.actor(observations)
+        action_preds = distribution.sample()
+        log_probs = distribution.log_prob(action_preds)
+        action_preds = torch.clamp(action_preds, -1, +1)
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
         # subsample critics to prevent overfitting if use high UTD (update to date)
@@ -146,31 +155,46 @@
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = (
-            (
-                F.mse_loss(
-                    q_preds,
-                    einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
-                    reduction="none",
-                ).sum(0)  # sum over ensemble
-                # `q_preds_ensemble` depends on the first observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
-                * ~batch["action_is_pad"]
-                # q_targets depends on the reward and the next observations.
-                * ~batch["next.reward_is_pad"]
-                * ~batch["observation.state_is_pad"][1:]
-            )
-            .sum(0)
-            .mean()
-        )
+        # critics_loss = (
+        #     (
+        #         F.mse_loss(
+        #             q_preds,
+        #             einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
+        #             reduction="none",
+        #         ).sum(0)  # sum over ensemble
+        #         # `q_preds_ensemble` depends on the first observation and the actions.
+        #         * ~batch["observation.state_is_pad"][0]
+        #         * ~batch["action_is_pad"]
+        #         # q_targets depends on the reward and the next observations.
+        #         * ~batch["next.reward_is_pad"]
+        #         * ~batch["observation.state_is_pad"][1:]
+        #     )
+        #     .sum(0)
+        #     .mean()
+        # )
+        # 4- Calculate loss
+        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
+        critics_loss = F.mse_loss(
+            q_preds,  # shape: [num_critics, batch_size]
+            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
+            reduction="none",
+        ).sum(0).mean()

         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor(observations)
+        distribution = self.actor(observations)
+        actions = distribution.sample()
+        log_probs = distribution.log_prob(actions)
+        actions = torch.clamp(actions, -1, +1)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -309,8 +333,8 @@ class Policy(nn.Module):
         network: nn.Module,
         action_dim: int,
         std_parameterization: str = "exp",
-        std_min: float = 1e-5,
-        std_max: float = 10.0,
+        std_min: float = 0.05,
+        std_max: float = 2.0,
         tanh_squash_distribution: bool = False,
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
@@ -372,6 +396,7 @@ class Policy(nn.Module):
             obs_enc = self.encoder(observations, train=train)
         else:
             obs_enc = observations
+
         # Get network outputs
         outputs = self.network(obs_enc)
         means = self.mean_layer(outputs)
@@ -380,18 +405,22 @@ class Policy(nn.Module):
         if self.fixed_std is None:
             if self.std_parameterization == "exp":
                 log_stds = self.std_layer(outputs)
+                # Clamp log_stds to prevent too large or too small values
+                log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
                 stds = torch.exp(log_stds)
             elif self.std_parameterization == "softplus":
                 stds = torch.nn.functional.softplus(self.std_layer(outputs))
+                stds = torch.clamp(stds, self.std_min, self.std_max)
             elif self.std_parameterization == "uniform":
-                stds = torch.exp(self.log_stds).expand_as(means)
+                log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
+                stds = torch.exp(log_stds).expand_as(means)
             else:
                 raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
         else:
             assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)

-        # Clip standard deviations and scale with temperature
+        # Scale with temperature
         temperature = torch.tensor(temperature, device=self.device)
         stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
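
For reference, a minimal standalone sketch of the sample / log_prob / clamp pattern the patch uses in both the critic and actor updates. A plain torch.distributions.Normal stands in for the actor's own distribution class, which this diff does not show; shapes and values are illustrative only:

import torch
from torch.distributions import Normal

# Hypothetical shapes: batch of 4 observations, 2-dimensional actions.
means = torch.zeros(4, 2)
stds = torch.full((4, 2), 0.5)

distribution = Normal(means, stds)
actions = distribution.sample()                     # (4, 2)
log_probs = distribution.log_prob(actions).sum(-1)  # (4,) after summing over action dims
actions = torch.clamp(actions, -1, +1)              # clamp to the action bounds

Note that the clamp is applied after log_prob, so the log-probability refers to the unclamped sample; with a tanh-squashed distribution the clamp would be effectively redundant.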
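
A self-contained sanity check of the new critic loss, using toy tensors rather than the lerobot batch: the per-sample TD target is repeated across the critic ensemble with einops.repeat so the element-wise MSE matches q_preds, then summed over critics and averaged over the batch:

import einops
import torch
import torch.nn.functional as F

num_critics, batch_size = 2, 8
q_preds = torch.randn(num_critics, batch_size)   # ensemble predictions: [num_critics, batch_size]
td_target = torch.randn(batch_size)              # one target per sample: [batch_size]

critics_loss = F.mse_loss(
    q_preds,
    einops.repeat(td_target, "b -> e b", e=num_critics),  # -> [num_critics, batch_size]
    reduction="none",
).sum(0).mean()

assert critics_loss.ndim == 0  # scalar loss

Unlike the removed version, this variant assumes q_preds has no temporal dimension and drops the padding masks, which is consistent with the [num_critics, batch_size] shape noted in the patch.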
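
Finally, a small sketch of why clamping log_stds with math.log(self.std_min) / math.log(self.std_max) is consistent with the existing std-space clamp: exp is monotonic, so the two orderings yield the same bounds, while clamping before the exp also avoids overflowing on large raw network outputs. The bounds below reuse the patch's new defaults; the input values are made up:

import math
import torch

std_min, std_max = 0.05, 2.0
log_stds = torch.tensor([-10.0, -1.0, 0.0, 5.0])

# Clamp in log-space, then exponentiate (what the patch does).
stds_log_clamped = torch.exp(torch.clamp(log_stds, math.log(std_min), math.log(std_max)))
# Exponentiate, then clamp in std-space (what the tail of the method still does).
stds_std_clamped = torch.clamp(torch.exp(log_stds), std_min, std_max)

assert torch.allclose(stds_log_clamped, stds_std_clamped)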