From 9386892f8e9a8d33487d9860094f09aa9bd9a5c1 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Mon, 14 Apr 2025 14:00:57 +0000
Subject: [PATCH] Refactor modeling_sac and parameter handling for clarity and reusability.

Co-authored-by: s1lent4gnt
---
 lerobot/common/policies/sac/modeling_sac.py | 62 ++++++++-------------
 lerobot/scripts/server/learner_server.py    | 50 +++++++++++++++--
 2 files changed, 67 insertions(+), 45 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 9ffdf154..05937240 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -167,8 +167,12 @@ class SACPolicy(
 
     def get_optim_params(self) -> dict:
         optim_params = {
-            "actor": self.actor.parameters_to_optimize,
-            "critic": self.critic_ensemble.parameters_to_optimize,
+            "actor": [
+                p
+                for n, p in self.actor.named_parameters()
+                if not n.startswith("encoder") or not self.shared_encoder
+            ],
+            "critic": self.critic_ensemble.parameters(),
             "temperature": self.log_alpha,
         }
         if self.config.num_discrete_actions is not None:
@@ -451,11 +455,11 @@ class SACPolicy(
                 target_next_grasp_qs, dim=1, index=best_next_grasp_action
             ).squeeze(-1)
 
-        # Compute target Q-value with Bellman equation
-        rewards_gripper = rewards
-        if gripper_penalties is not None:
-            rewards_gripper = rewards + gripper_penalties
-        target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
+            # Compute target Q-value with Bellman equation
+            rewards_gripper = rewards
+            if gripper_penalties is not None:
+                rewards_gripper = rewards + gripper_penalties
+            target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
@@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module):
         self.config = config
         self.input_normalization = input_normalizer
         self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []
 
         self.aggregation_size: int = 0
         if any("observation.image" in key for key in config.input_features):
@@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module):
             if config.freeze_vision_encoder:
                 freeze_image_encoder(self.image_enc_layers.image_enc_layers)
 
-            self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
-
             self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
 
         if "observation.state" in config.input_features:
@@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module):
             )
             self.aggregation_size += config.latent_dim
 
-            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
         if "observation.environment_state" in config.input_features:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module):
                 nn.Tanh(),
             )
             self.aggregation_size += config.latent_dim
-            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
 
         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
 
     def forward(
         self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
@@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module):
         self.output_normalization = output_normalization
         self.critics = nn.ModuleList(ensemble)
 
-        self.parameters_to_optimize = []
-        # Handle the case where a part of the encoder if frozen
-        if self.encoder is not None:
-            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
-        self.parameters_to_optimize += list(self.critics.parameters())
-
     def forward(
         self,
         observations: dict[str, torch.Tensor],
@@ -805,10 +796,6 @@ class GraspCritic(nn.Module):
         else:
             orthogonal_init()(self.output_layer.weight)
 
-        self.parameters_to_optimize = []
-        self.parameters_to_optimize += list(self.net.parameters())
-        self.parameters_to_optimize += list(self.output_layer.parameters())
-
     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
@@ -840,12 +827,8 @@ class Policy(nn.Module):
         self.log_std_max = log_std_max
         self.fixed_std = fixed_std
         self.use_tanh_squash = use_tanh_squash
-        self.parameters_to_optimize = []
+        self.encoder_is_shared = encoder_is_shared
 
-        self.parameters_to_optimize += list(self.network.parameters())
-
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
@@ -859,7 +842,6 @@ class Policy(nn.Module):
         else:
             orthogonal_init()(self.mean_layer.weight)
 
-        self.parameters_to_optimize += list(self.mean_layer.parameters())
         # Standard deviation layer or parameter
         if fixed_std is None:
             self.std_layer = nn.Linear(out_features, action_dim)
@@ -868,7 +850,6 @@ class Policy(nn.Module):
                 nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
             else:
                 orthogonal_init()(self.std_layer.weight)
-            self.parameters_to_optimize += list(self.std_layer.parameters())
 
     def forward(
         self,
@@ -877,6 +858,8 @@ class Policy(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
+        if self.encoder_is_shared:
+            obs_enc = obs_enc.detach()
 
         # Get network outputs
         outputs = self.network(obs_enc)
@@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder
 
     def forward(self, x):
-        return self.image_enc_proj(self.image_enc_layers(x))
+        x = self.image_enc_layers(x)
+        if self.freeze_image_encoder:
+            x = x.detach()
+        return self.image_enc_proj(x)
 
 
 class PretrainedImageEncoder(nn.Module):
@@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder
 
     def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
@@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
         # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
         # doesn't reach the classifier layer because we don't need it
         enc_feat = self.image_enc_layers(x).pooler_output
+        if self.freeze_image_encoder:
+            enc_feat = enc_feat.detach()
         enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
         return enc_feat
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 5b39d0d3..a8a858bf 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -510,7 +510,7 @@ def add_actor_information_and_train(
         optimizers["actor"].zero_grad()
         loss_actor.backward()
         actor_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
+            parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
         ).item()
         optimizers["actor"].step()
 
@@ -773,12 +773,14 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     optimizer_actor = torch.optim.Adam(
         # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
-        params=policy.actor.parameters_to_optimize,
+        params=[
+            p
+            for n, p in policy.actor.named_parameters()
+            if not n.startswith("encoder") or not policy.config.shared_encoder
+        ],
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
@@ -1089,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
     parameters_queue.put(state_bytes)
 
 
+def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
+    """
+    Checks whether each parameter in the module has a gradient.
+
+    Args:
+        module (nn.Module): A PyTorch module whose parameters will be inspected.
+
+    Returns:
+        dict[str, bool]: A dictionary where each key is the parameter name and the value is
+            True if the parameter has an associated gradient (i.e. .grad is not None),
+            otherwise False.
+    """
+    grad_status = {}
+    for name, param in module.named_parameters():
+        grad_status[name] = param.grad is not None
+    return grad_status
+
+
+def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
+    """
+    Returns a dictionary of parameters (from the given model) that also exist in the grad_status dictionary.
+
+    Args:
+        model (nn.Module): The model whose parameter names are intersected with grad_status.
+        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
+            whether each parameter has a gradient.
+
+    Returns:
+        dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
+    """
+    # Get the model's parameter names as a set.
+    model_param_names = {name for name, _ in model.named_parameters()}
+
+    # Intersect parameter names between the model and grad_status.
+    overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
+    return overlapping
+
+
 def process_interaction_message(
     message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
 ):
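
Below is a minimal, self-contained sketch (not part of the patch and not lerobot code) of the pattern the diff applies in SACPolicy.get_optim_params and make_optimizers_and_scheduler: when the actor shares its observation encoder with the critic, the actor optimizer is built only from parameters whose names do not start with "encoder", and the shared encoder's features are detached in the actor's forward pass so the actor loss cannot update them. The names TinyActor and shared_encoder are hypothetical stand-ins.

import torch
import torch.nn as nn

class TinyActor(nn.Module):
    """Hypothetical stand-in for a policy head on top of a (possibly shared) encoder."""

    def __init__(self, encoder: nn.Module, encoder_is_shared: bool):
        super().__init__()
        self.encoder = encoder  # registered under the "encoder." prefix in named_parameters()
        self.head = nn.Linear(16, 4)
        self.encoder_is_shared = encoder_is_shared

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.encoder(x)
        if self.encoder_is_shared:
            # Same trick as the patched Policy.forward: cut the graph so the
            # actor loss does not backpropagate into a shared encoder.
            feats = feats.detach()
        return self.head(feats)

shared_encoder = True
actor = TinyActor(encoder=nn.Linear(8, 16), encoder_is_shared=shared_encoder)

# Mirror of the patched optimizer construction: drop "encoder.*" parameters
# from the actor optimizer when the encoder is shared with the critic.
actor_params = [
    p
    for n, p in actor.named_parameters()
    if not n.startswith("encoder") or not shared_encoder
]
optimizer_actor = torch.optim.Adam(actor_params, lr=3e-4)

loss = actor(torch.randn(2, 8)).pow(2).mean()
loss.backward()

# After backward(), the encoder parameters have no gradient while the head parameters do;
# the check_weight_gradients helper added by this patch performs the same inspection.
print({name: p.grad is not None for name, p in actor.named_parameters()})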