diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 52c564a6..a324294c 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -53,13 +53,13 @@ class SACConfig: critic_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, - } + } actor_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, - } + } policy_kwargs = { "use_tanh_squash": True, "log_std_min": -5, "log_std_max": 2, - } + } diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 1e7fd92b..9df2c859 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -19,7 +19,6 @@ from collections import deque from copy import deepcopy -import math from typing import Callable, Optional, Sequence, Tuple import einops @@ -125,9 +124,9 @@ class SACPolicy( # perform image augmentation - # reward bias from HIL-SERL code base + # reward bias from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch - + # calculate critics loss # 1- compute actions from policy action_preds, log_probs = self.actor(next_observations) @@ -137,21 +136,23 @@ class SACPolicy( # subsample critics to prevent overfitting if use high UTD (update to date) if self.config.num_subsample_critics is not None: indices = torch.randperm(self.config.num_critics) - indices = indices[:self.config.num_subsample_critics] + indices = indices[: self.config.num_subsample_critics] q_targets = q_targets[indices] # critics subsample size min_q = q_targets.min(dim=0) # compute td target - td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term + td_target = ( + rewards + self.config.discount * min_q + ) # + self.config.discount * self.temperature() * log_probs # add entropy term # 3- compute predicted qs q_preds = self.critic_ensemble(observations, actions) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. - #critics_loss = ( + # critics_loss = ( # ( # F.mse_loss( # q_preds, @@ -167,14 +168,20 @@ class SACPolicy( # ) # .sum(0) # .mean() - #) + # ) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. - critics_loss = F.mse_loss( - q_preds, # shape: [num_critics, batch_size] - einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape - reduction="none" - ).sum(0).mean() + critics_loss = ( + F.mse_loss( + q_preds, # shape: [num_critics, batch_size] + einops.repeat( + td_target, "b -> e b", e=q_preds.shape[0] + ), # expand td_target to match q_preds shape + reduction="none", + ) + .sum(0) + .mean() + ) # calculate actors loss # 1- temperature @@ -229,10 +236,10 @@ class MLP(nn.Module): super().__init__() self.activate_final = activate_final layers = [] - + for i, size in enumerate(hidden_dims): - layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size)) - + layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size)) + if i + 1 < len(hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) @@ -255,20 +262,20 @@ class Critic(nn.Module): encoder: Optional[nn.Module], network: nn.Module, init_final: Optional[float] = None, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) self.encoder = encoder self.network = network self.init_final = init_final - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Output layer if init_final is not None: self.output_layer = nn.Linear(out_features, 1) @@ -305,7 +312,7 @@ class Policy(nn.Module): fixed_std: Optional[torch.Tensor] = None, init_final: Optional[float] = None, use_tanh_squash: bool = False, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) @@ -316,13 +323,13 @@ class Policy(nn.Module): self.log_std_max = log_std_max self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.use_tanh_squash = use_tanh_squash - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Mean layer self.mean_layer = nn.Linear(out_features, action_dim) if init_final is not None: @@ -339,21 +346,16 @@ class Policy(nn.Module): nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) - + self.to(self.device) def forward( self, observations: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # Encode observations if encoder exists - if self.encoder is not None: - with torch.set_grad_enabled(train): - obs_enc = self.encoder(observations, train=train) - else: - obs_enc = observations - + obs_enc = observations if self.encoder is not None else self.encoder(observations) + # Get network outputs outputs = self.network(obs_enc) means = self.mean_layer(outputs) @@ -369,15 +371,15 @@ class Policy(nn.Module): # uses tahn activation function to squash the action to be in the range of [-1, 1] normal = torch.distributions.Normal(means, stds) - x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) + x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) log_probs = normal.log_prob(x_t) if self.use_tanh_squash: actions = torch.tanh(x_t) log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) - log_probs = log_probs.sum(-1) # sum over action dim + log_probs = log_probs.sum(-1) # sum over action dim return actions, log_probs - + def get_features(self, observations: torch.Tensor) -> torch.Tensor: """Get encoded features from observations""" observations = observations.to(self.device) @@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s return nn.ModuleList(critics).to(device) -def orthogonal_init(): - return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) - - # borrowed from tdmpc def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor.