From a54baceabb4a6a5620938afa9894c67c3cc7b20a Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Wed, 2 Apr 2025 15:50:39 +0000
Subject: [PATCH] Enhance SACPolicy and learner server for improved grasp critic integration

- Updated SACPolicy to compute the grasp critic loss only when discrete actions are configured.
- Refactored the forward method to dispatch on a dedicated "grasp_critic" model type, making grasp critic model selection and loss computation clearer.
- Adjusted the learner server to optimize the grasp critic through its parameters_to_optimize list during training.
- Improved action handling in ManiskillMockGripperWrapper to accept both tuple and single action inputs.
---
 lerobot/common/policies/sac/modeling_sac.py     | 95 +++++++++++--------
 lerobot/scripts/server/learner_server.py        | 16 ++--
 .../scripts/server/maniskill_manipulator.py     | 11 ++-
 3 files changed, 72 insertions(+), 50 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 41ff7d8c..af624592 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -52,8 +52,6 @@ class SACPolicy(
         self.config = config
 
         continuous_action_dim = config.output_features["action"].shape[0]
-        if config.num_discrete_actions is not None:
-            continuous_action_dim -= 1
 
         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
@@ -191,7 +189,7 @@
 
         if self.config.num_discrete_actions is not None:
             discrete_action_value = self.grasp_critic(batch)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
@@ -236,7 +234,7 @@ def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
@@ -275,18 +273,25 @@
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
-            if self.config.num_discrete_actions is not None:
-                loss_grasp_critic = self.compute_loss_grasp_critic(
-                    observations=observations,
-                    actions=actions,
-                    rewards=rewards,
-                    next_observations=next_observations,
-                    done=done,
-                )
-                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
             return {"loss_critic": loss_critic}
 
+        if model == "grasp_critic" and self.config.num_discrete_actions is not None:
+            # Extract critic-specific components
+            rewards: Tensor = batch["reward"]
+            next_observations: dict[str, Tensor] = batch["next_state"]
+            done: Tensor = batch["done"]
+            next_observation_features: Tensor = batch.get("next_observation_feature")
+            loss_grasp_critic = self.compute_loss_grasp_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+            return {"loss_grasp_critic": loss_grasp_critic}
+
         if model == "actor":
             return {
                 "loss_actor": self.compute_loss_actor(
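
The hunk above turns the loss computation into a pure dispatch on the `model` argument, with the grasp-critic branch pulling its own reward and next-state tensors out of the batch. A minimal training-step sketch, not part of the patch; `policy`, `optimizers`, `forward_batch`, and `clip_grad_norm_value` are assumed to follow the learner-server conventions shown later in this patch:

import torch

def training_step(policy, optimizers, forward_batch, clip_grad_norm_value=10.0):
    # Continuous-action critic update.
    loss_critic = policy.forward(forward_batch, model="critic")["loss_critic"]
    optimizers["critic"].zero_grad()
    loss_critic.backward()
    optimizers["critic"].step()

    # Discrete grasp-critic update, only when discrete actions are configured.
    if policy.config.num_discrete_actions is not None:
        loss_grasp = policy.forward(forward_batch, model="grasp_critic")["loss_grasp_critic"]
        optimizers["grasp_critic"].zero_grad()
        loss_grasp.backward()
        torch.nn.utils.clip_grad_norm_(
            policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
        )
        optimizers["grasp_critic"].step()

    # Actor and temperature losses use the same dispatch ("actor", "temperature").
    loss_actor = policy.forward(forward_batch, model="actor")["loss_actor"]
    optimizers["actor"].zero_grad()
    loss_actor.backward()
    optimizers["actor"].step()
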
@@ -373,7 +378,6 @@
             # In the buffer we have the full action space (continuous + discrete)
             # We need to split them before concatenating them in the critic forward
             actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
-
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -407,30 +411,38 @@
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
         # We need to split them before concatenating them in the critic forward
-        actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
-        actions = actions.long()
+        actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
+        actions_discrete = actions_discrete.long()
 
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+            next_grasp_qs = self.grasp_critic_forward(
+                next_observations, use_target=False, observation_features=next_observation_features
+            )
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
 
             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
+            target_next_grasp_qs = self.grasp_critic_forward(
+                observations=next_observations,
+                use_target=True,
+                observation_features=next_observation_features,
+            )
 
             # Use gather to select Q-values for best actions
             target_next_grasp_q = torch.gather(
-                target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action
             ).squeeze(-1)
 
             # Compute target Q-value with Bellman equation
             target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
+        predicted_grasp_qs = self.grasp_critic_forward(
+            observations=observations, use_target=False, observation_features=observation_features
+        )
 
         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
 
         # Compute MSE loss between predicted and target Q-values
         grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
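
The grasp-critic loss above is a double-DQN-style update: the online network picks the best next discrete action (keepdim=True keeps the (B, 1) shape that torch.gather expects), the target network evaluates that action, and the result feeds a one-step Bellman target. A self-contained shape sketch with random stand-in Q-values, illustrative only and not code from the repository:

import torch
import torch.nn.functional as F

B, K = 4, 3        # batch size, number of discrete grasp actions
discount = 0.99

next_online_qs = torch.randn(B, K)   # stand-in for grasp_critic_forward(next_obs, use_target=False)
next_target_qs = torch.randn(B, K)   # stand-in for grasp_critic_forward(next_obs, use_target=True)
rewards = torch.randn(B)
done = torch.zeros(B)
actions_discrete = torch.randint(0, K, (B, 1))  # discrete actions as stored in the buffer

best_next_action = torch.argmax(next_online_qs, dim=-1, keepdim=True)                    # (B, 1)
target_next_q = torch.gather(next_target_qs, dim=1, index=best_next_action).squeeze(-1)  # (B,)
target_q = rewards + (1 - done) * discount * target_next_q

predicted_qs = torch.randn(B, K)     # stand-in for grasp_critic_forward(obs, use_target=False)
predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1)
loss = F.mse_loss(predicted_q, target_q)
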
@@ -642,49 +654,52 @@ class GraspCritic(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network: nn.Module,
-        output_dim: int = 3,  # TODO (azouitine): rename it number of discret acitons smth like that
+        input_dim: int,
+        hidden_dims: list[int],
+        output_dim: int = 3,
+        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
+        activate_final: bool = False,
+        dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
-        encoder_is_shared: bool = False,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.network = network
         self.output_dim = output_dim
 
-        # Find the last Linear layer's output dimension
-        for layer in reversed(network.net):
-            if isinstance(layer, nn.Linear):
-                out_features = layer.out_features
-                break
+        self.net = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            activate_final=activate_final,
+            dropout_rate=dropout_rate,
+            final_activation=final_activation,
+        )
 
-        self.parameters_to_optimize += list(self.network.parameters())
-
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
-
-        self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
+        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
         if init_final is not None:
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.output_layer.weight)
 
+        self.parameters_to_optimize = []
+        self.parameters_to_optimize += list(self.net.parameters())
         self.parameters_to_optimize += list(self.output_layer.parameters())
 
     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device
+        # Move each tensor in observations to device by cloning first to avoid inplace operations
        observations = {k: v.to(device) for k, v in observations.items()}
         # Encode observations if encoder exists
         obs_enc = (
-            observation_features
+            observation_features.to(device)
             if observation_features is not None
             else (observations if self.encoder is None else self.encoder(observations))
         )
-        return self.output_layer(self.network(obs_enc))
+        return self.output_layer(self.net(obs_enc))
 
 
 class Policy(nn.Module):
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index c57f83fc..ce9a1b41 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -405,12 +405,13 @@ def add_actor_information_and_train(
         optimizers["critic"].step()
 
         # Grasp critic optimization (if available)
-        if "loss_grasp_critic" in critic_output:
-            loss_grasp_critic = critic_output["loss_grasp_critic"]
+        if policy.config.num_discrete_actions is not None:
+            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
             optimizers["grasp_critic"].zero_grad()
             loss_grasp_critic.backward()
             grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
             )
             optimizers["grasp_critic"].step()
 
@@ -467,12 +468,13 @@
             }
 
             # Grasp critic optimization (if available)
-            if "loss_grasp_critic" in critic_output:
-                loss_grasp_critic = critic_output["loss_grasp_critic"]
+            if policy.config.num_discrete_actions is not None:
+                discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+                loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
                 grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["grasp_critic"].step()
 
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
+            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
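
With GraspCritic rewritten to own its MLP trunk and to expose parameters_to_optimize, the learner server builds the grasp-critic optimizer from that list rather than from .parameters(). A rough construction sketch against the patched module; the encoder, sizes, and learning rate are made-up values, not defaults from the repository:

import torch
from lerobot.common.policies.sac.modeling_sac import GraspCritic

grasp_critic = GraspCritic(
    encoder=None,            # or a shared observation encoder module
    input_dim=256,           # dimensionality of the encoded observation
    hidden_dims=[256, 256],
    output_dim=3,            # number of discrete grasp actions
)
# parameters_to_optimize covers the MLP trunk and the output layer only,
# so a shared encoder is left untouched by this optimizer.
optimizer_grasp_critic = torch.optim.Adam(
    params=grasp_critic.parameters_to_optimize, lr=3e-4
)
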
diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py
index f4a89888..b5c181c1 100644
--- a/lerobot/scripts/server/maniskill_manipulator.py
+++ b/lerobot/scripts/server/maniskill_manipulator.py
@@ -16,7 +16,6 @@
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
-
 def preprocess_maniskill_observation(
     observations: dict[str, np.ndarray],
 ) -> dict[str, torch.Tensor]:
@@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper):
         self.current_step = 0
         return super().reset(seed=seed, options=options)
 
+
 class ManiskillMockGripperWrapper(gym.Wrapper):
     def __init__(self, env, nb_discrete_actions: int = 3):
         super().__init__(env)
@@ -166,11 +166,16 @@ class ManiskillMockGripperWrapper(gym.Wrapper):
         self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
 
     def step(self, action):
-        action_agent, telop_action = action
+        if isinstance(action, tuple):
+            action_agent, telop_action = action
+        else:
+            telop_action = 0
+            action_agent = action
         real_action = action_agent[:-1]
         final_action = (real_action, telop_action)
         obs, reward, terminated, truncated, info = self.env.step(final_action)
-        return obs, reward, terminated, truncated, info
+        return obs, reward, terminated, truncated, info
+
 
 def make_maniskill(
     cfg: ManiskillEnvConfig,