From 306c735172a6ffd0d1f1fa0866f29abd8d2c597c Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Tue, 1 Apr 2025 11:42:28 +0000
Subject: [PATCH] Refactor SAC policy and training loop to enhance discrete action support

- Updated SACPolicy to conditionally compute the grasp critic loss based on num_discrete_actions.
- Simplified the forward method to return losses as a dictionary for clarity.
- Adjusted learner_server to handle both the main critic and grasp critic losses during training.
- Ensured the grasp critic optimizer is created conditionally based on the configuration.
---
 .../common/policies/sac/configuration_sac.py |   2 +-
 lerobot/common/policies/sac/modeling_sac.py  |  54 ++++----
 lerobot/scripts/server/learner_server.py     | 120 +++++++++---------
 3 files changed, 86 insertions(+), 90 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 66d9aa45..ae38b1c5 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -87,6 +87,7 @@ class SACConfig(PreTrainedConfig):
         freeze_vision_encoder: Whether to freeze the vision encoder during training.
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
+        num_discrete_actions: Number of discrete actions, e.g., for gripper actions.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.
@@ -162,7 +163,6 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
-    grasp_critic_lr: float = 3e-4
    actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index dd156918..d0e8b25d 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -228,7 +228,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -246,7 +246,6 @@ class SACPolicy(
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]
         observation_features: Tensor = batch.get("observation_feature")
@@ -259,7 +258,7 @@ class SACPolicy(
         done: Tensor = batch["done"]
         next_observation_features: Tensor = batch.get("next_observation_feature")
 
         if model == "critic":
-            return self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+            if self.config.num_discrete_actions is not None:
+                loss_grasp_critic = self.compute_loss_grasp_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
+            return {"loss_critic": loss_critic}
 
-        if model == "grasp_critic":
-            return self.compute_loss_grasp_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
 
         if model == "actor":
-            return self.compute_loss_actor(
+            return {"loss_actor": self.compute_loss_actor(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         if model == "temperature":
-            return self.compute_loss_temperature(
+            return {"loss_temperature": self.compute_loss_temperature(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         raise ValueError(f"Unknown model type: {model}")
 
@@ -305,18 +303,16 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
-
-    def update_grasp_target_networks(self):
-        """Update grasp target networks with exponential moving average"""
-        for target_param, param in zip(
-            self.grasp_critic_target.parameters(),
-            self.grasp_critic.parameters(),
-            strict=False,
-        ):
-            target_param.data.copy_(
-                param.data * self.config.critic_target_update_weight
-                + target_param.data * (1.0 - self.config.critic_target_update_weight)
-            )
+        if self.config.num_discrete_actions is not None:
+            for target_param, param in zip(
+                self.grasp_critic_target.parameters(),
+                self.grasp_critic.parameters(),
+                strict=False,
+            ):
+                target_param.data.copy_(
+                    param.data * self.config.critic_target_update_weight
+                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
+                )
 
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 15de2cb7..627a1a17 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -392,32 +392,30 @@ def add_actor_information_and_train(
                 "next_observation_feature": next_observation_features,
             }
 
-            # Use the forward method for critic loss
-            loss_critic = policy.forward(forward_batch, model="critic")
+            # Use the forward method for critic loss (includes both main critic and grasp critic)
+            critic_output = policy.forward(forward_batch, model="critic")
+
+            # Main critic optimization
+            loss_critic = critic_output["loss_critic"]
             optimizers["critic"].zero_grad()
             loss_critic.backward()
-
-            # clip gradients
             critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
             )
-
             optimizers["critic"].step()
 
-            # Add gripper critic optimization
-            loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
-            optimizers["grasp_critic"].zero_grad()
-            loss_grasp_critic.backward()
-
-            # clip gradients
-            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-            )
-
-            optimizers["grasp_critic"].step()
+            # Grasp critic optimization (if available)
+            if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+                loss_grasp_critic = critic_output["loss_grasp_critic"]
+                optimizers["grasp_critic"].zero_grad()
+                loss_grasp_critic.backward()
+                grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                )
+                optimizers["grasp_critic"].step()
 
+            # Update target networks
             policy.update_target_networks()
-            policy.update_grasp_target_networks()
 
         batch = replay_buffer.sample(batch_size=batch_size)
 
@@ -450,81 +448,80 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }
 
-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
+
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()
-
-        # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()
-
         optimizers["critic"].step()
 
-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-        training_infos["critic_grad_norm"] = critic_grad_norm
+        # Initialize training info dictionary
+        training_infos = {
+            "loss_critic": loss_critic.item(),
+            "critic_grad_norm": critic_grad_norm,
+        }
 
-        # Add gripper critic optimization
-        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
-        optimizers["grasp_critic"].zero_grad()
-        loss_grasp_critic.backward()
-
-        # clip gradients
-        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-        )
-
-        optimizers["grasp_critic"].step()
-
-        # Add training info for the grasp critic
-        training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
-        training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+        # Grasp critic optimization (if available)
+        if "loss_grasp_critic" in critic_output:
+            loss_grasp_critic = critic_output["loss_grasp_critic"]
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            ).item()
+            optimizers["grasp_critic"].step()
+
+            # Add grasp critic info to training info
+            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+
+        # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                # Use the forward method for actor loss
-                loss_actor = policy.forward(forward_batch, model="actor")
-
+                # Actor optimization
+                actor_output = policy.forward(forward_batch, model="actor")
+                loss_actor = actor_output["loss_actor"]
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
-
-                # clip gradients
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
-
                 optimizers["actor"].step()
-
+
+                # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm
 
-                # Temperature optimization using forward method
-                loss_temperature = policy.forward(forward_batch, model="temperature")
+                # Temperature optimization
+                temperature_output = policy.forward(forward_batch, model="temperature")
+                loss_temperature = temperature_output["loss_temperature"]
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
-
-                # clip gradients
                 temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()
-
                 optimizers["temperature"].step()
-
+
+                # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature
 
+                # Update temperature
                 policy.update_temperature()
 
-        # Check if it's time to push updated policy to actors
+        # Push policy to actors if needed
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()
 
+        # Update target networks
         policy.update_target_networks()
-        policy.update_grasp_target_networks()
 
         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:
@@ -727,7 +724,7 @@ def save_training_checkpoint(
     logging.info("Resume training")
 
 
-def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
@@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     optimizer_actor = torch.optim.Adam(
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    optimizer_grasp_critic = torch.optim.Adam(
-        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
-    )
+
+    if cfg.policy.num_discrete_actions is not None:
+        optimizer_grasp_critic = torch.optim.Adam(
+            params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
+        )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
-        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
+    if cfg.policy.num_discrete_actions is not None:
+        optimizers["grasp_critic"] = optimizer_grasp_critic
     return optimizers, lr_scheduler
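For readers following the new control flow: forward() now returns a dictionary of losses whose keys depend on num_discrete_actions, and the learner steps the grasp-critic optimizer only when the matching key is present. The sketch below reproduces that pattern in isolation; TinySACLikePolicy, critic_update, and the placeholder losses are hypothetical stand-ins rather than lerobot code, and only the control flow mirrors the patch.

```python
# Hypothetical, minimal sketch of the loss-dictionary pattern; not lerobot code.
import torch
from torch import nn


class TinySACLikePolicy(nn.Module):
    """Stand-in policy: forward(model="critic") returns a dict of losses, and the
    grasp-critic entry exists only when discrete actions are configured."""

    def __init__(self, num_discrete_actions: int | None = None):
        super().__init__()
        self.num_discrete_actions = num_discrete_actions
        self.critic = nn.Linear(4, 1)
        self.grasp_critic = nn.Linear(4, num_discrete_actions) if num_discrete_actions else None

    def forward(self, batch: dict[str, torch.Tensor], model: str = "critic") -> dict[str, torch.Tensor]:
        obs = batch["state"]
        if model == "critic":
            # Placeholder losses standing in for the real TD / discrete-critic objectives.
            loss_critic = self.critic(obs).pow(2).mean()
            if self.num_discrete_actions is not None:
                loss_grasp_critic = self.grasp_critic(obs).pow(2).mean()
                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
            return {"loss_critic": loss_critic}
        raise ValueError(f"Unknown model type: {model}")


def critic_update(policy: TinySACLikePolicy, optimizers: dict, batch: dict, clip: float = 10.0) -> None:
    """One critic step: always update the main critic, and update the grasp critic
    only if its loss was actually produced by forward()."""
    critic_output = policy(batch, model="critic")

    loss_critic = critic_output["loss_critic"]
    optimizers["critic"].zero_grad()
    loss_critic.backward()
    torch.nn.utils.clip_grad_norm_(policy.critic.parameters(), max_norm=clip)
    optimizers["critic"].step()

    if "loss_grasp_critic" in critic_output:
        loss_grasp_critic = critic_output["loss_grasp_critic"]
        optimizers["grasp_critic"].zero_grad()
        loss_grasp_critic.backward()
        torch.nn.utils.clip_grad_norm_(policy.grasp_critic.parameters(), max_norm=clip)
        optimizers["grasp_critic"].step()


if __name__ == "__main__":
    policy = TinySACLikePolicy(num_discrete_actions=3)
    optimizers = {"critic": torch.optim.Adam(policy.critic.parameters(), lr=3e-4)}
    if policy.num_discrete_actions is not None:
        optimizers["grasp_critic"] = torch.optim.Adam(policy.grasp_critic.parameters(), lr=3e-4)
    critic_update(policy, optimizers, {"state": torch.randn(8, 4)})
```

Keying the optimizer step off the returned dictionary keeps vanilla continuous-action SAC untouched: when num_discrete_actions is None, neither the grasp-critic loss nor its optimizer exists.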
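The target-network update that now conditionally covers the grasp critic is a standard Polyak/exponential-moving-average soft update, with critic_target_update_weight playing the role of tau. A self-contained sketch under that assumption (soft_update is a hypothetical helper, not a lerobot function):

```python
# Hypothetical sketch of the Polyak / EMA target update used by update_target_networks().
import torch
from torch import nn


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """target <- tau * source + (1 - tau) * target, applied parameter-wise."""
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters(), strict=True):
            target_param.data.mul_(1.0 - tau).add_(param.data, alpha=tau)


critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
critic_target.load_state_dict(critic.state_dict())  # start the target at the online weights

# Corresponds to critic_target_update_weight = 0.005 in SACConfig
soft_update(critic_target, critic, tau=0.005)
```

Folding the grasp-critic EMA into update_target_networks(), guarded by num_discrete_actions, is what lets learner_server call a single update method and drop update_grasp_target_networks() entirely.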