From 3a2308d86f2090e1b9e00b81533fa70ae4d220b2 Mon Sep 17 00:00:00 2001
From: s1lent4gnt
Date: Mon, 31 Mar 2025 18:06:21 +0200
Subject: [PATCH] Add grasp critic to the training loop

- Integrated the grasp critic gradient update into the training loop in learner_server
- Added an Adam optimizer for the grasp critic, with its learning rate configurable via grasp_critic_lr in configuration_sac
- Added the grasp target critic network update after the critic gradient steps
---
 .../common/policies/sac/configuration_sac.py |  1 +
 lerobot/common/policies/sac/modeling_sac.py  | 19 ++++++++--
 lerobot/scripts/server/learner_server.py     | 35 +++++++++++++++++++
 3 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index e47185da..b1ce30b6 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -167,6 +167,7 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
+    grasp_critic_lr: float = 3e-4
     actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 3589ad25..bd74c65b 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -214,7 +214,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -227,7 +227,7 @@
                 - done: Done mask tensor
                 - observation_feature: Optional pre-computed observation features
                 - next_observation_feature: Optional pre-computed next observation features
-            model: Which model to compute the loss for ("actor", "critic", or "temperature")
+            model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
 
         Returns:
             The computed loss tensor
@@ -254,6 +254,21 @@
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+
+        if model == "grasp_critic":
+            # Extract grasp_critic-specific components
+            complementary_info: dict[str, Tensor] = batch["complementary_info"]
+
+            return self.compute_loss_grasp_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
+            )
 
         if model == "actor":
             return self.compute_loss_actor(
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 98d2dbd8..f79e8d57 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -375,6 +375,7 @@ def add_actor_information_and_train(
             observations = batch["state"]
             next_observations = batch["next_state"]
             done = batch["done"]
+            complementary_info = batch["complementary_info"]
 
             check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
             observation_features, next_observation_features = get_observation_features(
@@ -390,6 +391,7 @@
                 "done": done,
                 "observation_feature": observation_features,
                 "next_observation_feature": next_observation_features,
+                "complementary_info": complementary_info,
             }
 
             # Use the forward method for critic loss
@@ -404,7 +406,20 @@
 
             optimizers["critic"].step()
 
+            # Grasp critic optimization
+            loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+
+            # clip gradients
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            )
+
+            optimizers["grasp_critic"].step()
+
             policy.update_target_networks()
+            policy.update_grasp_target_networks()
 
             batch = replay_buffer.sample(batch_size=batch_size)
 
@@ -435,6 +450,7 @@
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
+            "complementary_info": complementary_info,
         }
 
         # Use the forward method for critic loss
@@ -453,6 +469,22 @@
 
         training_infos["loss_critic"] = loss_critic.item()
         training_infos["critic_grad_norm"] = critic_grad_norm
+        # Grasp critic optimization
+        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+        optimizers["grasp_critic"].zero_grad()
+        loss_grasp_critic.backward()
+
+        # clip gradients
+        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+        )
+
+        optimizers["grasp_critic"].step()
+
+        # Log training info for the grasp critic
+        training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+        training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
                 # Use the forward method for actor loss
@@ -495,6 +527,7 @@
                     last_time_policy_pushed = time.time()
 
         policy.update_target_networks()
+        policy.update_grasp_target_networks()
 
         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:
@@ -729,11 +762,13 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
    )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
+        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
     return optimizers, lr_scheduler