diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index b1ce30b6..66d9aa45 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -42,12 +42,6 @@ class CriticNetworkConfig:
     final_activation: str | None = None


-@dataclass
-class GraspCriticNetworkConfig:
-    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
-    activate_final: bool = True
-    final_activation: str | None = None
-    output_dim: int = 3


 @dataclass
@@ -152,6 +146,7 @@ class SACConfig(PreTrainedConfig):
     freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
+    num_discrete_actions: int | None = None

     # Training parameter
     online_steps: int = 1000000
@@ -182,7 +177,7 @@ class SACConfig(PreTrainedConfig):
     critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
     policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
-
+    grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
     concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 95ea3928..dd156918 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -33,6 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters

+DISCRETE_DIMENSION_INDEX = -1

 class SACPolicy(
     PreTrainedPolicy,
@@ -113,24 +114,30 @@ class SACPolicy(
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)

-        # Create grasp critic
-        self.grasp_critic = GraspCritic(
-            encoder=encoder_critic,
-            input_dim=encoder_critic.output_dim,
-            **config.grasp_critic_network_kwargs,
-        )
+        self.grasp_critic = None
+        self.grasp_critic_target = None

-        # Create target grasp critic
-        self.grasp_critic_target = GraspCritic(
-            encoder=encoder_critic,
-            input_dim=encoder_critic.output_dim,
-            **config.grasp_critic_network_kwargs,
-        )
+        if config.num_discrete_actions is not None:
+            # Create grasp critic
+            self.grasp_critic = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )

-        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+            # Create target grasp critic
+            self.grasp_critic_target = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )

-        self.grasp_critic = torch.compile(self.grasp_critic)
-        self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+            self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+            self.grasp_critic = torch.compile(self.grasp_critic)
+            self.grasp_critic_target = torch.compile(self.grasp_critic_target)

         self.actor = Policy(
             encoder=encoder_actor,
@@ -173,6 +180,12 @@ class SACPolicy(
         """Select action for inference/evaluation"""
         actions, _, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
+
+        if self.config.num_discrete_actions is not None:
+            discrete_action_value = self.grasp_critic(batch)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
+            actions = torch.cat([actions, discrete_action], dim=-1)
+
         return actions

     def critic_forward(
@@ -192,11 +205,12 @@ class SACPolicy(
         Returns:
             Tensor of Q-values from all critics
         """
+
         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)
         return q_values

-    def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
+    def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
         """Forward pass through a grasp critic network

         Args:
@@ -256,9 +270,6 @@ class SACPolicy(
             )

         if model == "grasp_critic":
-            # Extract grasp_critic-specific components
-            complementary_info: dict[str, Tensor] = batch["complementary_info"]
-
             return self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
@@ -267,7 +278,6 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
-                complementary_info=complementary_info,
             )

         if model == "actor":
@@ -349,6 +359,12 @@ class SACPolicy(
             td_target = rewards + (1 - done) * self.config.discount * min_q

         # 3- compute predicted qs
+        if self.config.num_discrete_actions is not None:
+            # NOTE: We only want to keep the continuous action part
+            # In the buffer we have the full action space (continuous + discrete)
+            # We need to split them before concatenating them in the critic forward
+            actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
+
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -378,30 +394,43 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
-        complementary_info=None,
     ):
-        batch_size = rewards.shape[0]
-        grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
+        # NOTE: We only want to keep the discrete action part
+        # In the buffer we have the full action space (continuous + discrete)
+        # We need to split it off before passing the actions to the grasp critic
+        actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
+        actions = actions.long()

         with torch.no_grad():
+            # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
             best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+
+            # Get target Q-values from target network
+            target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
+
+            # Use gather to select Q-values for best actions
+            target_next_grasp_q = torch.gather(
+                target_next_grasp_qs,
+                dim=1,
+                index=best_next_grasp_action.unsqueeze(-1)
+            ).squeeze(-1)

-            target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True)
-            target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]
+            # Compute target Q-value with Bellman equation
+            target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q

-            # Get the grasp penalty from complementary_info
-            grasp_penalty = torch.zeros_like(rewards)
-            if complementary_info is not None and "grasp_penalty" in complementary_info:
-                grasp_penalty = complementary_info["grasp_penalty"]
+        # Get predicted Q-values for current observations
+        predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
+
+        # Use gather to select Q-values for taken actions
+        predicted_grasp_q = torch.gather(
+            predicted_grasp_qs,
+            dim=1,
+            index=actions
+        ).squeeze(-1)

-            grasp_rewards = rewards + grasp_penalty
-            target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q
-
-        predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False)
-        predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions]
-
-        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean")
+        # Compute MSE loss between predicted and target Q-values
+        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
         return grasp_critic_loss

     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
@@ -611,7 +640,7 @@ class GraspCritic(nn.Module):
         self,
         encoder: Optional[nn.Module],
         network: nn.Module,
-        output_dim: int = 3,
+        output_dim: int = 3,  # TODO (azouitine): rename to something like num_discrete_actions
         init_final: Optional[float] = None,
         encoder_is_shared: bool = False,
     ):
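
Note: the new compute_loss_grasp_critic is a standard Double-DQN-style update (action selection with the online network, evaluation with the target network, gather over the taken discrete actions). The following is a minimal standalone sketch of that update, not part of the patch; it uses random stand-in tensors in place of the real grasp_critic_forward outputs, and the batch size, action count, and variable names are illustrative assumptions only.

import torch
import torch.nn.functional as F

batch_size, num_discrete_actions, discount = 4, 3, 0.99

# Stand-ins for grasp_critic_forward() outputs (illustrative only).
next_qs_online = torch.randn(batch_size, num_discrete_actions)                    # online net, next observations
next_qs_target = torch.randn(batch_size, num_discrete_actions)                    # target net, next observations
predicted_qs = torch.randn(batch_size, num_discrete_actions, requires_grad=True)  # online net, current observations

rewards = torch.rand(batch_size)
done = torch.zeros(batch_size)
taken_actions = torch.randint(0, num_discrete_actions, (batch_size, 1))           # discrete actions from the buffer

with torch.no_grad():
    # Select the best next action with the online network, evaluate it with the target network.
    best_next_action = torch.argmax(next_qs_online, dim=-1, keepdim=True)
    target_next_q = next_qs_target.gather(1, best_next_action).squeeze(-1)
    # Bellman target, mirroring the patch above.
    target_q = rewards + (1 - done) * discount * target_next_q

# Pick out the Q-values of the actions actually taken and regress them onto the target.
predicted_q = predicted_qs.gather(1, taken_actions).squeeze(-1)
loss = F.mse_loss(predicted_q, target_q)
loss.backward()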