From 699d374d895f8facdd2e5b66e379508e2466c97f Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 1 Apr 2025 15:43:29 +0000 Subject: [PATCH] Refactor SACPolicy for improved readability and action dimension handling - Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines. - Consolidated continuous action dimension calculation to enhance clarity and maintainability. - Simplified loss return statements in the forward method to improve code structure. - Ensured grasp critic parameters are included conditionally based on configuration settings. --- lerobot/common/policies/sac/modeling_sac.py | 59 +++++++++++---------- lerobot/scripts/server/learner_server.py | 14 ++--- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 0c3d76d2..41ff7d8c 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,7 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters -DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension + class SACPolicy( PreTrainedPolicy, @@ -50,6 +51,10 @@ class SACPolicy( config.validate_features() self.config = config + continuous_action_dim = config.output_features["action"].shape[0] + if config.num_discrete_actions is not None: + continuous_action_dim -= 1 + if config.dataset_stats is not None: input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) self.normalize_inputs = Normalize( @@ -117,10 +122,7 @@ class SACPolicy( self.grasp_critic = None self.grasp_critic_target = None - continuous_action_dim = config.output_features["action"].shape[0] if config.num_discrete_actions is not None: - - continuous_action_dim -= 1 # Create grasp critic self.grasp_critic = GraspCritic( encoder=encoder_critic, @@ -142,7 +144,6 @@ class SACPolicy( self.grasp_critic = torch.compile(self.grasp_critic) self.grasp_critic_target = torch.compile(self.grasp_critic_target) - self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), @@ -162,11 +163,14 @@ class SACPolicy( self.temperature = self.log_alpha.exp().item() def get_optim_params(self) -> dict: - return { + optim_params = { "actor": self.actor.parameters_to_optimize, "critic": self.critic_ensemble.parameters_to_optimize, "temperature": self.log_alpha, } + if self.config.num_discrete_actions is not None: + optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize + return optim_params def reset(self): """Reset the policy""" @@ -262,7 +266,7 @@ class SACPolicy( done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") - loss_critic = self.compute_loss_critic( + loss_critic = self.compute_loss_critic( observations=observations, actions=actions, rewards=rewards, @@ -283,18 +287,21 @@ class SACPolicy( return {"loss_critic": loss_critic} - if model == "actor": - return {"loss_actor": self.compute_loss_actor( - observations=observations, - observation_features=observation_features, - )} + return { + "loss_actor": self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + } if model == 
"temperature": - return {"loss_temperature": self.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - )} + return { + "loss_temperature": self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + } raise ValueError(f"Unknown model type: {model}") @@ -366,7 +373,7 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - + q_preds = self.critic_forward( observations=observations, actions=actions, @@ -407,15 +414,13 @@ class SACPolicy( # For DQN, select actions using online network, evaluate with target network next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) - + # Get target Q-values from target network target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True) - + # Use gather to select Q-values for best actions target_next_grasp_q = torch.gather( - target_next_grasp_qs, - dim=1, - index=best_next_grasp_action.unsqueeze(-1) + target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1) ).squeeze(-1) # Compute target Q-value with Bellman equation @@ -423,13 +428,9 @@ class SACPolicy( # Get predicted Q-values for current observations predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False) - + # Use gather to select Q-values for taken actions - predicted_grasp_q = torch.gather( - predicted_grasp_qs, - dim=1, - index=actions.unsqueeze(-1) - ).squeeze(-1) + predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1) # Compute MSE loss between predicted and target Q-values grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) @@ -642,7 +643,7 @@ class GraspCritic(nn.Module): self, encoder: Optional[nn.Module], network: nn.Module, - output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that + output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that init_final: Optional[float] = None, encoder_is_shared: bool = False, ): diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 627a1a17..c57f83fc 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -394,7 +394,7 @@ def add_actor_information_and_train( # Use the forward method for critic loss (includes both main critic and grasp critic) critic_output = policy.forward(forward_batch, model="critic") - + # Main critic optimization loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() @@ -405,7 +405,7 @@ def add_actor_information_and_train( optimizers["critic"].step() # Grasp critic optimization (if available) - if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"): + if "loss_grasp_critic" in critic_output: loss_grasp_critic = critic_output["loss_grasp_critic"] optimizers["grasp_critic"].zero_grad() loss_grasp_critic.backward() @@ -450,7 +450,7 @@ def add_actor_information_and_train( # Use the forward method for critic loss (includes both main critic and grasp critic) critic_output = policy.forward(forward_batch, model="critic") - + # Main critic optimization loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() @@ -475,7 +475,7 
@@ def add_actor_information_and_train( parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value ).item() optimizers["grasp_critic"].step() - + # Add grasp critic info to training info training_infos["loss_grasp_critic"] = loss_grasp_critic.item() training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm @@ -492,7 +492,7 @@ def add_actor_information_and_train( parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() optimizers["actor"].step() - + # Add actor info to training info training_infos["loss_actor"] = loss_actor.item() training_infos["actor_grad_norm"] = actor_grad_norm @@ -506,7 +506,7 @@ def add_actor_information_and_train( parameters=[policy.log_alpha], max_norm=clip_grad_norm_value ).item() optimizers["temperature"].step() - + # Add temperature info to training info training_infos["loss_temperature"] = loss_temperature.item() training_infos["temperature_grad_norm"] = temp_grad_norm @@ -756,7 +756,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - + if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( params=policy.grasp_critic.parameters(), lr=policy.critic_lr
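
The two structural changes described in the commit message, stripping the gripper dimension from the continuous action space once during __init__ and registering the grasp-critic parameter group only when discrete actions are configured, reduce to the pattern below. This is a condensed, self-contained sketch rather than the real SACPolicy: DummyConfig, DummyPolicy, action_dim, and the toy nn.Linear networks are stand-ins invented for illustration, while the num_discrete_actions flag and the shape of get_optim_params mirror the diff.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class DummyConfig:
    # Full action dimension (continuous joints + optional discrete gripper command).
    action_dim: int = 7
    # None means "no grasp critic"; an int is the number of discrete gripper actions.
    num_discrete_actions: Optional[int] = 3


class DummyPolicy(nn.Module):
    def __init__(self, config: DummyConfig):
        super().__init__()
        self.config = config

        # Consolidated continuous-action-dimension handling: the discrete (last)
        # dimension is stripped once, up front, before any network is built.
        continuous_action_dim = config.action_dim
        if config.num_discrete_actions is not None:
            continuous_action_dim -= 1

        self.actor = nn.Linear(16, continuous_action_dim)
        self.critic_ensemble = nn.Linear(16 + continuous_action_dim, 1)
        self.log_alpha = nn.Parameter(torch.zeros(1))

        self.grasp_critic = None
        if config.num_discrete_actions is not None:
            self.grasp_critic = nn.Linear(16, config.num_discrete_actions)

    def get_optim_params(self) -> dict:
        # One parameter group per optimizer; the grasp-critic group exists only
        # when discrete actions are configured, matching the conditional optimizer
        # construction in make_optimizers_and_scheduler.
        optim_params = {
            "actor": self.actor.parameters(),
            "critic": self.critic_ensemble.parameters(),
            "temperature": [self.log_alpha],
        }
        if self.config.num_discrete_actions is not None:
            optim_params["grasp_critic"] = self.grasp_critic.parameters()
        return optim_params


policy = DummyPolicy(DummyConfig())
optimizers = {
    name: torch.optim.Adam(params, lr=3e-4)
    for name, params in policy.get_optim_params().items()
}
print(sorted(optimizers))  # ['actor', 'critic', 'grasp_critic', 'temperature']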
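
The grasp-critic loss in compute_loss_grasp_critic follows a double-DQN-style update: as the in-code comment notes, next actions are selected by argmax over the online grasp critic, evaluated with the target grasp critic, and the predicted Q-values are regressed onto the one-step Bellman target with an MSE loss. Below is a minimal standalone restatement of that computation; the function name and the flat (batch, num_discrete_actions) Q-value tensors are assumptions made for the sketch, not the actual method signature, which takes observations and runs the critics internally.

import torch
import torch.nn.functional as F


def grasp_critic_loss(
    grasp_qs: torch.Tensor,              # online Q-values for current obs, (B, num_discrete_actions)
    next_grasp_qs: torch.Tensor,         # online Q-values for next obs, same shape
    target_next_grasp_qs: torch.Tensor,  # target-network Q-values for next obs, same shape
    actions: torch.Tensor,               # discrete gripper actions taken, (B,), dtype torch.long
    rewards: torch.Tensor,               # (B,)
    done: torch.Tensor,                  # (B,), 1.0 where the episode terminated
    gamma: float = 0.99,
) -> torch.Tensor:
    # Select the next action with the online network, evaluate it with the target network.
    best_next_action = torch.argmax(next_grasp_qs, dim=-1)
    target_next_q = torch.gather(
        target_next_grasp_qs, dim=1, index=best_next_action.unsqueeze(-1)
    ).squeeze(-1)

    # One-step Bellman target, masked where the episode is done.
    target_q = rewards + (1.0 - done) * gamma * target_next_q

    # Q-values of the actions that were actually taken.
    predicted_q = torch.gather(grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)

    return F.mse_loss(predicted_q, target_q.detach())


# Toy shapes only, to show the expected tensor layout.
B, A = 4, 3
loss = grasp_critic_loss(
    grasp_qs=torch.randn(B, A),
    next_grasp_qs=torch.randn(B, A),
    target_next_grasp_qs=torch.randn(B, A),
    actions=torch.randint(0, A, (B,)),
    rewards=torch.rand(B),
    done=torch.zeros(B),
)
print(loss.item())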