From 1f23ef78891197e9fe22c3993b722b31f3a746da Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Mon, 17 Mar 2025 10:50:28 +0000
Subject: [PATCH] Enhance SAC configuration and policy with gradient clipping
 and temperature management

- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping
- Updated SACPolicy to store the temperature as an instance variable so losses and logging reuse the same value
- Modified loss calculations in SACPolicy to use the instance temperature
- Enhanced MLP and CriticHead to support a customizable final activation function
- Implemented gradient clipping in the learner server for the critic, actor, and temperature updates
- Added tracking of gradient norms to the training info dictionary
---
 .../common/policies/sac/configuration_sac.py |  2 ++
 lerobot/common/policies/sac/modeling_sac.py  | 34 +++++++++++++++++++++++++---------
 lerobot/scripts/server/learner_server.py     | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index b834896e..61e08df4 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -84,10 +84,12 @@ class SACConfig:
     latent_dim: int = 256
     target_entropy: float | None = None
     use_backup_entropy: bool = True
+    grad_clip_norm: float = 40.0
     critic_network_kwargs: dict[str, Any] = field(
         default_factory=lambda: {
             "hidden_dims": [256, 256],
             "activate_final": True,
+            "final_activation": None,
         }
     )
     actor_network_kwargs: dict[str, Any] = field(
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index afbbc945..2c4bad5f 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -330,7 +330,7 @@ class SACPolicy(
         observation_features: Tensor | None = None,
         next_observation_features: Tensor | None = None,
     ) -> Tensor:
-        temperature = self.log_alpha.exp().item()
+        self.temperature = self.log_alpha.exp().item()
         with torch.no_grad():
             next_action_preds, next_log_probs, _ = self.actor(
                 next_observations, next_observation_features
@@ -358,7 +358,7 @@ class SACPolicy(
             # critics subsample size
             min_q, _ = q_targets.min(dim=0)  # Get values from min operation
             if self.config.use_backup_entropy:
-                min_q = min_q - (temperature * next_log_probs)
+                min_q = min_q - (self.temperature * next_log_probs)
 
             td_target = rewards + (1 - done) * self.config.discount * min_q
 
@@ -398,7 +398,7 @@ class SACPolicy(
     def compute_loss_actor(
         self, observations, observation_features: Tensor | None = None
     ) -> Tensor:
-        temperature = self.log_alpha.exp().item()
+        self.temperature = self.log_alpha.exp().item()
 
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
 
@@ -413,7 +413,7 @@ class SACPolicy(
         )
         min_q_preds = q_preds.min(dim=0)[0]
 
-        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
+        actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
         return actor_loss
 
 
@@ -425,6 +425,7 @@ class MLP(nn.Module):
         activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
         activate_final: bool = False,
         dropout_rate: Optional[float] = None,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.activate_final = activate_final
@@ -451,11 +452,24 @@ class MLP(nn.Module):
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
                 layers.append(nn.LayerNorm(hidden_dims[i]))
-                layers.append(
-                    activations
-                    if isinstance(activations, nn.Module)
-                    else getattr(nn, activations)()
-                )
+
+                # If we're at the final layer and a final activation is specified, use it
+                if (
+                    i + 1 == len(hidden_dims)
+                    and activate_final
+                    and final_activation is not None
+                ):
+                    layers.append(
+                        final_activation
+                        if isinstance(final_activation, nn.Module)
+                        else getattr(nn, final_activation)()
+                    )
+                else:
+                    layers.append(
+                        activations
+                        if isinstance(activations, nn.Module)
+                        else getattr(nn, activations)()
+                    )
 
         self.net = nn.Sequential(*layers)
 
@@ -516,6 +530,7 @@ class CriticHead(nn.Module):
         activate_final: bool = False,
         dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.net = MLP(
@@ -524,6 +539,7 @@ class CriticHead(nn.Module):
             activations=activations,
             activate_final=activate_final,
             dropout_rate=dropout_rate,
+            final_activation=final_activation,
         )
         self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
         if init_final is not None:
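To make the new final-activation branch concrete: with `activate_final=True` and `final_activation` set, the patched MLP keeps the regular activation after every hidden layer and swaps in the final one only on the last layer, resolving strings through `getattr(nn, ...)`. Below is a self-contained sketch of that dispatch; `resolve_activation` is a hypothetical helper mirroring the hunk's inline conditional, and the dimensions are illustrative, not taken from the patch.

    import torch
    from torch import nn

    def resolve_activation(act: nn.Module | str) -> nn.Module:
        # Same dispatch as the hunk above: nn.Module instances pass through
        # unchanged, strings are looked up on torch.nn ("Tanh" -> nn.Tanh()).
        return act if isinstance(act, nn.Module) else getattr(nn, act)()

    # The stack the patched MLP would build for hidden_dims=[256, 256],
    # activate_final=True, final_activation="Tanh": SiLU after the hidden
    # layer, Tanh after the final one.
    net = nn.Sequential(
        nn.Linear(32, 256),
        nn.LayerNorm(256),
        resolve_activation("SiLU"),  # hidden layers keep `activations`
        nn.Linear(256, 256),
        nn.LayerNorm(256),
        resolve_activation("Tanh"),  # last layer uses `final_activation`
    )
    out = net(torch.randn(8, 32))  # outputs squashed into (-1, 1) by the Tanh

A bounded activation like Tanh on the critic trunk's last layer is one common use for this knob; leaving `final_activation=None`, the new default in `critic_network_kwargs`, preserves the previous behavior.
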
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 7bd4aee0..580eed1a 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -390,6 +390,10 @@ def add_actor_information_and_train(
         if cfg.resume
         else None,
     )
+
+    # Read the gradient clipping threshold from the training config
+    clip_grad_norm_value = cfg.training.grad_clip_norm
+
     # compile policy
     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
@@ -507,6 +511,12 @@ def add_actor_information_and_train(
                )
            optimizers["critic"].zero_grad()
            loss_critic.backward()
+
+            # clip gradients (clip_grad_norm_ returns the total norm before clipping)
+            critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                policy.critic_ensemble.parameters(), clip_grad_norm_value
+            ).item()
+
            optimizers["critic"].step()
 
            batch = replay_buffer.sample(batch_size)
@@ -541,10 +551,17 @@ def add_actor_information_and_train(
            )
        optimizers["critic"].zero_grad()
        loss_critic.backward()
+
+        # clip gradients (clip_grad_norm_ returns the total norm before clipping)
+        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.critic_ensemble.parameters(), clip_grad_norm_value
+        ).item()
+
        optimizers["critic"].step()
 
        training_infos = {}
        training_infos["loss_critic"] = loss_critic.item()
+        training_infos["critic_grad_norm"] = critic_grad_norm
 
        if optimization_step % cfg.training.policy_update_freq == 0:
            for _ in range(cfg.training.policy_update_freq):
@@ -555,19 +572,35 @@ def add_actor_information_and_train(
                )
                optimizers["actor"].zero_grad()
                loss_actor.backward()
+
+                # clip gradients
+                actor_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    policy.actor.parameters_to_optimize, clip_grad_norm_value
+                ).item()
+
                optimizers["actor"].step()
 
                training_infos["loss_actor"] = loss_actor.item()
+                training_infos["actor_grad_norm"] = actor_grad_norm
+
                # Temperature optimization
                loss_temperature = policy.compute_loss_temperature(
                    observations=observations,
                    observation_features=observation_features,
                )
                optimizers["temperature"].zero_grad()
                loss_temperature.backward()
+
+                # clip gradients
+                temp_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    [policy.log_alpha], clip_grad_norm_value
+                ).item()
+
                optimizers["temperature"].step()
 
                training_infos["loss_temperature"] = loss_temperature.item()
+                training_infos["temperature_grad_norm"] = temp_grad_norm
+                training_infos["temperature"] = policy.temperature
 
            if (
                time.time() - last_time_policy_pushed
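
The learner-side changes all follow the standard PyTorch recipe: `backward()` to populate gradients, `clip_grad_norm_` to rescale them in place, then the optimizer step. `torch.nn.utils.clip_grad_norm_` returns the total norm computed before clipping, which is what the new `training_infos` entries record. Below is a self-contained sketch of that recipe, with a placeholder model and data standing in for the SAC networks and replay batch.

    import torch
    from torch import nn

    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    x, y = torch.randn(16, 4), torch.randn(16, 1)

    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients in place so their total L2 norm is at most 40.0 (the
    # new grad_clip_norm default); the returned value is the pre-clip norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=40.0).item()
    optimizer.step()
    print(f"pre-clip grad norm: {grad_norm:.4f}")

Logging the pre-clip norm is useful precisely because it shows how often the 40.0 ceiling actually binds: a norm that stays far below the threshold means clipping is effectively a no-op for that update.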