From 4a1c26d9ee7108d86e60feee82b3d20ced026642 Mon Sep 17 00:00:00 2001
From: s1lent4gnt
Date: Mon, 31 Mar 2025 17:35:59 +0200
Subject: [PATCH] Add grasp critic

- Implemented grasp critic to evaluate gripper actions
- Added corresponding config parameters for tuning
---
 .../common/policies/sac/configuration_sac.py |   8 ++
 lerobot/common/policies/sac/modeling_sac.py  | 150 +++++++++++++++++++++
 2 files changed, 158 insertions(+)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 906a3bed..e47185da 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -42,6 +42,14 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
+@dataclass
+class GraspCriticNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+    final_activation: str | None = None
+    output_dim: int = 3
+
+
 @dataclass
 class ActorNetworkConfig:
     hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index f7866714..3589ad25 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -112,6 +112,25 @@ class SACPolicy(
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)
 
+        # Create grasp critic
+        self.grasp_critic = GraspCritic(
+            encoder=encoder_critic,
+            input_dim=encoder_critic.output_dim,
+            **asdict(config.grasp_critic_network_kwargs),
+        )
+
+        # Create target grasp critic
+        self.grasp_critic_target = GraspCritic(
+            encoder=encoder_critic,
+            input_dim=encoder_critic.output_dim,
+            **asdict(config.grasp_critic_network_kwargs),
+        )
+
+        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+        self.grasp_critic = torch.compile(self.grasp_critic)
+        self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
@@ -176,6 +195,21 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
         return q_values
 
+    def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
+        """Forward pass through the grasp critic network.
+
+        Args:
+            observations: Dictionary of observations
+            use_target: If True, use the target grasp critic, otherwise use the online grasp critic
+            observation_features: Optional pre-computed observation features to avoid recomputing the encoder output
+
+        Returns:
+            Tensor of Q-values from the grasp critic network
+        """
+        grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
+        q_values = grasp_critic(observations, observation_features)
+        return q_values
+
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
@@ -246,6 +280,18 @@ class SACPolicy(
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
 
+    def update_grasp_target_networks(self):
+        """Update the grasp target network with an exponential moving average"""
+        for target_param, param in zip(
+            self.grasp_critic_target.parameters(),
+            self.grasp_critic.parameters(),
+            strict=False,
+        ):
+            target_param.data.copy_(
+                param.data * self.config.critic_target_update_weight
+                + target_param.data * (1.0 - self.config.critic_target_update_weight)
+            )
+
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
 
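[Reviewer note] update_grasp_target_networks reuses critic_target_update_weight, so the grasp head follows the same Polyak rule as the critic ensemble: target <- w * online + (1 - w) * target. Below is a minimal sketch of how a learner step could combine the new loss with this update; the optimizer, its hyperparameters, and the batch keys are illustrative assumptions, not part of this patch:

    import torch

    # Hypothetical learner step; `policy` is a SACPolicy, `batch` a replay sample.
    grasp_optimizer = torch.optim.Adam(policy.grasp_critic.parameters_to_optimize, lr=3e-4)

    loss = policy.compute_loss_grasp_critic(
        observations=batch["state"],
        actions=batch["action"],
        rewards=batch["reward"],
        next_observations=batch["next_state"],
        done=batch["done"],
        complementary_info=batch.get("complementary_info"),  # may carry "grasp_penalty"
    )
    grasp_optimizer.zero_grad()
    loss.backward()
    grasp_optimizer.step()
    policy.update_grasp_target_networks()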
@@ -307,6 +353,49 @@ class SACPolicy(
         ).sum()
         return critics_loss
 
+    def compute_loss_grasp_critic(
+        self,
+        observations,
+        actions,
+        rewards,
+        next_observations,
+        done,
+        observation_features=None,
+        next_observation_features=None,
+        complementary_info=None,
+    ):
+        batch_size = rewards.shape[0]
+        grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
+
+        with torch.no_grad():
+            # Double Q-learning: select the best next grasp action with the online network ...
+            next_grasp_qs = self.grasp_critic_forward(
+                next_observations, use_target=False, observation_features=next_observation_features
+            )
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+
+            # ... and evaluate it with the target network to reduce overestimation
+            target_next_grasp_qs = self.grasp_critic_forward(
+                next_observations, use_target=True, observation_features=next_observation_features
+            )
+            target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]
+
+            # Get the grasp penalty from complementary_info
+            grasp_penalty = torch.zeros_like(rewards)
+            if complementary_info is not None and "grasp_penalty" in complementary_info:
+                grasp_penalty = complementary_info["grasp_penalty"]
+
+            grasp_rewards = rewards + grasp_penalty
+            target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q
+
+        predicted_grasp_qs = self.grasp_critic_forward(
+            observations, use_target=False, observation_features=observation_features
+        )
+        predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions]
+
+        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean")
+        return grasp_critic_loss
+
     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
@@ -509,6 +598,67 @@ class CriticEnsemble(nn.Module):
         return q_values
 
 
+class GraspCritic(nn.Module):
+    def __init__(
+        self,
+        encoder: Optional[nn.Module],
+        input_dim: int,
+        hidden_dims: list[int],
+        output_dim: int = 3,
+        activate_final: bool = True,
+        final_activation: str | None = None,
+        init_final: Optional[float] = None,
+        encoder_is_shared: bool = False,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.output_dim = output_dim
+        # Build the Q-network over encoded features, mirroring the other heads
+        self.network = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activate_final=activate_final,
+            final_activation=final_activation,
+        )
+
+        # Find the last Linear layer's output dimension to size the output head
+        for layer in reversed(self.network.net):
+            if isinstance(layer, nn.Linear):
+                out_features = layer.out_features
+                break
+
+        self.parameters_to_optimize = []
+        self.parameters_to_optimize += list(self.network.parameters())
+
+        if self.encoder is not None and not encoder_is_shared:
+            self.parameters_to_optimize += list(self.encoder.parameters())
+
+        self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
+        if init_final is not None:
+            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
+            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
+        else:
+            orthogonal_init()(self.output_layer.weight)
+
+        self.parameters_to_optimize += list(self.output_layer.parameters())
+
+    def forward(
+        self,
+        observations: dict[str, torch.Tensor],
+        observation_features: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        device = get_device_from_parameters(self)
+        # Move each tensor in observations to the model's device
+        observations = {k: v.to(device) for k, v in observations.items()}
+        # Use pre-computed features when provided, otherwise encode the observations
+        obs_enc = (
+            observation_features
+            if observation_features is not None
+            else (observations if self.encoder is None else self.encoder(observations))
+        )
+        return self.output_layer(self.network(obs_enc))
+
+
 class Policy(nn.Module):
     def __init__(
         self,
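[Reviewer note] At inference time the grasp head acts as a small DQN over the three discrete gripper actions. A hedged sketch of recovering the {-1, 0, 1} gripper command from the 3-way argmax, inverting the mapping used in compute_loss_grasp_critic (this helper is illustrative, not part of the patch):

    import torch

    # Hypothetical gripper selection from the grasp critic's Q-values;
    # `policy` is a SACPolicy and `observations` a dict of batched tensors.
    with torch.no_grad():
        grasp_qs = policy.grasp_critic_forward(observations, use_target=False)  # shape (B, 3)
        gripper_command = torch.argmax(grasp_qs, dim=-1) - 1  # {0, 1, 2} -> {-1, 0, 1}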