From 4a1c26d9ee7108d86e60feee82b3d20ced026642 Mon Sep 17 00:00:00 2001 From: s1lent4gnt <kmeftah.khalil@gmail.com> Date: Mon, 31 Mar 2025 17:35:59 +0200 Subject: [PATCH 01/28] Add grasp critic - Implemented grasp critic to evaluate gripper actions - Added corresponding config parameters for tuning --- .../common/policies/sac/configuration_sac.py | 8 ++ lerobot/common/policies/sac/modeling_sac.py | 124 ++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 906a3bed..e47185da 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -42,6 +42,14 @@ class CriticNetworkConfig: final_activation: str | None = None +@dataclass +class GraspCriticNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + final_activation: str | None = None + output_dim: int = 3 + + @dataclass class ActorNetworkConfig: hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index f7866714..3589ad25 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -112,6 +112,26 @@ class SACPolicy( self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target) + + # Create grasp critic + self.grasp_critic = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + **config.grasp_critic_network_kwargs, + ) + + # Create target grasp critic + self.grasp_critic_target = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + **config.grasp_critic_network_kwargs, + ) + + self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) + + self.grasp_critic = torch.compile(self.grasp_critic) + self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), @@ -176,6 +196,21 @@ class SACPolicy( q_values = critics(observations, actions, observation_features) return q_values + def grasp_critic_forward(self, observations, use_target=False, observation_features=None): + """Forward pass through a grasp critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the grasp critic network + """ + grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic + q_values = grasp_critic(observations, observation_features) + return q_values + def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], @@ -246,6 +281,18 @@ class SACPolicy( + target_param.data * (1.0 - self.config.critic_target_update_weight) ) + def update_grasp_target_networks(self): + """Update grasp target networks with exponential moving average""" + for target_param, param in zip( + self.grasp_critic_target.parameters(), + self.grasp_critic.parameters(), + strict=False, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + def update_temperature(self): self.temperature = 
self.log_alpha.exp().item() @@ -307,6 +354,32 @@ class SACPolicy( ).sum() return critics_loss + def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None): + + batch_size = rewards.shape[0] + grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2] + + with torch.no_grad(): + next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) + best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) + + target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True) + target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action] + + # Get the grasp penalty from complementary_info + grasp_penalty = torch.zeros_like(rewards) + if complementary_info is not None and "grasp_penalty" in complementary_info: + grasp_penalty = complementary_info["grasp_penalty"] + + grasp_rewards = rewards + grasp_penalty + target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q + + predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False) + predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions] + + grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean") + return grasp_critic_loss + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: """Compute the temperature loss""" # calculate temperature loss @@ -509,6 +582,57 @@ class CriticEnsemble(nn.Module): return q_values +class GraspCritic(nn.Module): + def __init__( + self, + encoder: Optional[nn.Module], + network: nn.Module, + output_dim: int = 3, + init_final: Optional[float] = None, + encoder_is_shared: bool = False, + ): + super().__init__() + self.encoder = encoder + self.network = network + self.output_dim = output_dim + + # Find the last Linear layer's output dimension + for layer in reversed(network.net): + if isinstance(layer, nn.Linear): + out_features = layer.out_features + break + + self.parameters_to_optimize += list(self.network.parameters()) + + if self.encoder is not None and not encoder_is_shared: + self.parameters_to_optimize += list(self.encoder.parameters()) + + self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + self.parameters_to_optimize += list(self.output_layer.parameters()) + + def forward( + self, + observations: torch.Tensor, + observation_features: torch.Tensor | None = None + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device + observations = {k: v.to(device) for k, v in observations.items()} + # Encode observations if encoder exists + obs_enc = ( + observation_features + if observation_features is not None + else (observations if self.encoder is None else self.encoder(observations)) + ) + return self.output_layer(self.network(obs_enc)) + + class Policy(nn.Module): def __init__( self, From 007fee923071a860ff05bf1d7b536375ed6dea5f Mon Sep 17 00:00:00 2001 From: s1lent4gnt <kmeftah.khalil@gmail.com> Date: Mon, 31 Mar 2025 17:36:35 +0200 Subject: [PATCH 02/28] Add complementary info in the replay buffer - Added complementary info in the add method 
- Added complementary info in the sample method --- lerobot/scripts/server/buffer.py | 37 ++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 776ad9ec..1fbc8803 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -244,6 +244,11 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + + # Initialize complementary_info storage + self.complementary_info_keys = [] + self.complementary_info_storage = {} + self.initialized = True def __len__(self): @@ -277,6 +282,30 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated + # Store complementary info if provided + if complementary_info is not None: + # Initialize storage for new keys on first encounter + for key, value in complementary_info.items(): + if key not in self.complementary_info_keys: + self.complementary_info_keys.append(key) + if isinstance(value, torch.Tensor): + shape = value.shape if value.ndim > 0 else (1,) + self.complementary_info_storage[key] = torch.zeros( + (self.capacity, *shape), + dtype=value.dtype, + device=self.storage_device + ) + + # Store the value + if key in self.complementary_info_storage: + if isinstance(value, torch.Tensor): + self.complementary_info_storage[key][self.position] = value + else: + # For non-tensor values (like grasp_penalty) + self.complementary_info_storage[key][self.position] = torch.tensor( + value, device=self.storage_device + ) + self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -335,6 +364,13 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() + # Add complementary_info to batch if it exists + batch_complementary_info = {} + if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys: + for key in self.complementary_info_keys: + if key in self.complementary_info_storage: + batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device) + return BatchTransition( state=batch_state, action=batch_actions, @@ -342,6 +378,7 @@ class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, + complementary_info=batch_complementary_info if batch_complementary_info else None, ) @classmethod From 7452f9baaa57d94b1b738ad94acdb23abe57c6d7 Mon Sep 17 00:00:00 2001 From: s1lent4gnt <kmeftah.khalil@gmail.com> Date: Mon, 31 Mar 2025 17:38:16 +0200 Subject: [PATCH 03/28] Add gripper penalty wrapper --- lerobot/scripts/server/gym_manipulator.py | 24 +++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 92e8dcbc..26ed1991 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1069,6 +1069,29 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention +class GripperPenaltyWrapper(gym.Wrapper): + def __init__(self, env, penalty=-0.05): + super().__init__(env) + self.penalty = penalty + self.last_gripper_pos = None + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper + return obs, 
info + + def step(self, action): + observation, reward, terminated, truncated, info = self.env.step(action) + + if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9): + info["grasp_penalty"] = self.penalty + else: + info["grasp_penalty"] = 0.0 + + self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper + return observation, reward, terminated, truncated, info + + def make_robot_env(cfg) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. @@ -1144,6 +1167,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) + env= GripperPenaltyWrapper(env=env) return env From 2c1e5fa28b67e05e932ebba2e83ec75c73db3c34 Mon Sep 17 00:00:00 2001 From: s1lent4gnt <kmeftah.khalil@gmail.com> Date: Mon, 31 Mar 2025 17:40:00 +0200 Subject: [PATCH 04/28] Add get_gripper_action method to GamepadController --- .../server/end_effector_control_utils.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 3bd927b4..f272426d 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -311,6 +311,31 @@ class GamepadController(InputController): except pygame.error: logging.error("Error reading gamepad. Is it still connected?") return 0.0, 0.0, 0.0 + + def get_gripper_action(self): + """ + Get gripper action using L3/R3 buttons. + Press left stick (L3) to open the gripper. + Press right stick (R3) to close the gripper. + """ + import pygame + + try: + # Check if buttons are pressed + l3_pressed = self.joystick.get_button(9) + r3_pressed = self.joystick.get_button(10) + + # Determine action based on button presses + if r3_pressed: + return 1.0 # Close gripper + elif l3_pressed: + return -1.0 # Open gripper + else: + return 0.0 # No change + + except pygame.error: + logging.error(f"Error reading gamepad. 
Is it still connected?") + return 0.0 class GamepadControllerHID(InputController): From c774bbe5222e376dd142b491fe9c206902a83a99 Mon Sep 17 00:00:00 2001 From: s1lent4gnt <kmeftah.khalil@gmail.com> Date: Mon, 31 Mar 2025 18:06:21 +0200 Subject: [PATCH 05/28] Add grasp critic to the training loop - Integrated the grasp critic gradient update to the training loop in learner_server - Added Adam optimizer and configured grasp critic learning rate in configuration_sac - Added target critics networks update after the critics gradient step --- .../common/policies/sac/configuration_sac.py | 1 + lerobot/common/policies/sac/modeling_sac.py | 19 ++++++++-- lerobot/scripts/server/learner_server.py | 35 +++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index e47185da..b1ce30b6 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -167,6 +167,7 @@ class SACConfig(PreTrainedConfig): num_critics: int = 2 num_subsample_critics: int | None = None critic_lr: float = 3e-4 + grasp_critic_lr: float = 3e-4 actor_lr: float = 3e-4 temperature_lr: float = 3e-4 critic_target_update_weight: float = 0.005 diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 3589ad25..bd74c65b 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -214,7 +214,7 @@ class SACPolicy( def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature"] = "critic", + model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic", ) -> dict[str, Tensor]: """Compute the loss for the given model @@ -227,7 +227,7 @@ class SACPolicy( - done: Done mask tensor - observation_feature: Optional pre-computed observation features - next_observation_feature: Optional pre-computed next observation features - model: Which model to compute the loss for ("actor", "critic", or "temperature") + model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature") Returns: The computed loss tensor @@ -254,6 +254,21 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) + + if model == "grasp_critic": + # Extract grasp_critic-specific components + complementary_info: dict[str, Tensor] = batch["complementary_info"] + + return self.compute_loss_grasp_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + complementary_info=complementary_info, + ) if model == "actor": return self.compute_loss_actor( diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 98d2dbd8..f79e8d57 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -375,6 +375,7 @@ def add_actor_information_and_train( observations = batch["state"] next_observations = batch["next_state"] done = batch["done"] + complementary_info = batch["complementary_info"] check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( @@ -390,6 +391,7 @@ def add_actor_information_and_train( 
"done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, + "complementary_info": complementary_info, } # Use the forward method for critic loss @@ -404,7 +406,20 @@ def add_actor_information_and_train( optimizers["critic"].step() + # Add gripper critic optimization + loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + + # clip gradients + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ) + + optimizers["grasp_critic"].step() + policy.update_target_networks() + policy.update_grasp_target_networks() batch = replay_buffer.sample(batch_size=batch_size) @@ -435,6 +450,7 @@ def add_actor_information_and_train( "done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, + "complementary_info": complementary_info, } # Use the forward method for critic loss @@ -453,6 +469,22 @@ def add_actor_information_and_train( training_infos["loss_critic"] = loss_critic.item() training_infos["critic_grad_norm"] = critic_grad_norm + # Add gripper critic optimization + loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + + # clip gradients + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ) + + optimizers["grasp_critic"].step() + + # Add training info for the grasp critic + training_infos["loss_grasp_critic"] = loss_grasp_critic.item() + training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm + if optimization_step % policy_update_freq == 0: for _ in range(policy_update_freq): # Use the forward method for actor loss @@ -495,6 +527,7 @@ def add_actor_information_and_train( last_time_policy_pushed = time.time() policy.update_target_networks() + policy.update_grasp_target_networks() # Log training metrics at specified intervals if optimization_step % log_freq == 0: @@ -729,11 +762,13 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { "actor": optimizer_actor, "critic": optimizer_critic, + "grasp_critic": optimizer_grasp_critic, "temperature": optimizer_temperature, } return optimizers, lr_scheduler From 7983baf4fcbabe77e2d946c2fe3f441e6581ebfc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 16:10:00 +0000 Subject: [PATCH 06/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/policies/sac/modeling_sac.py | 23 ++++++++++++------- lerobot/scripts/server/buffer.py | 12 ++++------ .../server/end_effector_control_utils.py | 10 ++++---- lerobot/scripts/server/gym_manipulator.py | 10 ++++---- lerobot/scripts/server/learner_server.py | 6 +++-- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 
bd74c65b..95ea3928 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -198,7 +198,7 @@ class SACPolicy( def grasp_critic_forward(self, observations, use_target=False, observation_features=None): """Forward pass through a grasp critic network - + Args: observations: Dictionary of observations use_target: If True, use target critics, otherwise use ensemble critics @@ -254,7 +254,7 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) - + if model == "grasp_critic": # Extract grasp_critic-specific components complementary_info: dict[str, Tensor] = batch["complementary_info"] @@ -307,7 +307,7 @@ class SACPolicy( param.data * self.config.critic_target_update_weight + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - + def update_temperature(self): self.temperature = self.log_alpha.exp().item() @@ -369,8 +369,17 @@ class SACPolicy( ).sum() return critics_loss - def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None): - + def compute_loss_grasp_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features=None, + next_observation_features=None, + complementary_info=None, + ): batch_size = rewards.shape[0] grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2] @@ -632,9 +641,7 @@ class GraspCritic(nn.Module): self.parameters_to_optimize += list(self.output_layer.parameters()) def forward( - self, - observations: torch.Tensor, - observation_features: torch.Tensor | None = None + self, observations: torch.Tensor, observation_features: torch.Tensor | None = None ) -> torch.Tensor: device = get_device_from_parameters(self) # Move each tensor in observations to device diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 1fbc8803..bf65b1ec 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -248,7 +248,7 @@ class ReplayBuffer: # Initialize complementary_info storage self.complementary_info_keys = [] self.complementary_info_storage = {} - + self.initialized = True def __len__(self): @@ -291,11 +291,9 @@ class ReplayBuffer: if isinstance(value, torch.Tensor): shape = value.shape if value.ndim > 0 else (1,) self.complementary_info_storage[key] = torch.zeros( - (self.capacity, *shape), - dtype=value.dtype, - device=self.storage_device + (self.capacity, *shape), dtype=value.dtype, device=self.storage_device ) - + # Store the value if key in self.complementary_info_storage: if isinstance(value, torch.Tensor): @@ -304,7 +302,7 @@ class ReplayBuffer: # For non-tensor values (like grasp_penalty) self.complementary_info_storage[key][self.position] = torch.tensor( value, device=self.storage_device - ) + ) self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -366,7 +364,7 @@ class ReplayBuffer: # Add complementary_info to batch if it exists batch_complementary_info = {} - if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys: + if hasattr(self, "complementary_info_keys") and self.complementary_info_keys: for key in self.complementary_info_keys: if key in self.complementary_info_storage: batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device) diff --git 
a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index f272426d..3b2cfa90 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -311,7 +311,7 @@ class GamepadController(InputController): except pygame.error: logging.error("Error reading gamepad. Is it still connected?") return 0.0, 0.0, 0.0 - + def get_gripper_action(self): """ Get gripper action using L3/R3 buttons. @@ -319,12 +319,12 @@ class GamepadController(InputController): Press right stick (R3) to close the gripper. """ import pygame - + try: # Check if buttons are pressed l3_pressed = self.joystick.get_button(9) r3_pressed = self.joystick.get_button(10) - + # Determine action based on button presses if r3_pressed: return 1.0 # Close gripper @@ -332,9 +332,9 @@ class GamepadController(InputController): return -1.0 # Open gripper else: return 0.0 # No change - + except pygame.error: - logging.error(f"Error reading gamepad. Is it still connected?") + logging.error("Error reading gamepad. Is it still connected?") return 0.0 diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 26ed1991..ac3bbb0a 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1077,18 +1077,20 @@ class GripperPenaltyWrapper(gym.Wrapper): def reset(self, **kwargs): obs, info = self.env.reset(**kwargs) - self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper + self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper return obs, info def step(self, action): observation, reward, terminated, truncated, info = self.env.step(action) - if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9): + if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or ( + action[-1] > 0.5 and self.last_gripper_pos < 0.9 + ): info["grasp_penalty"] = self.penalty else: info["grasp_penalty"] = 0.0 - self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper + self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper return observation, reward, terminated, truncated, info @@ -1167,7 +1169,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) - env= GripperPenaltyWrapper(env=env) + env = GripperPenaltyWrapper(env=env) return env diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index f79e8d57..0f760dc5 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -417,7 +417,7 @@ def add_actor_information_and_train( ) optimizers["grasp_critic"].step() - + policy.update_target_networks() policy.update_grasp_target_networks() @@ -762,7 +762,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr) + optimizer_grasp_critic = torch.optim.Adam( + params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr + ) 
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { From fe2ff516a8d354d33283163137a4a266145caa7d Mon Sep 17 00:00:00 2001 From: Michel Aractingi <michel.aractingi@huggingface.co> Date: Tue, 1 Apr 2025 11:08:15 +0200 Subject: [PATCH 07/28] Added Gripper quantization wrapper and grasp penalty removed complementary info from buffer and learner server removed get_gripper_action function added gripper parameters to `common/envs/configs.py` --- lerobot/common/envs/configs.py | 3 + lerobot/scripts/server/buffer.py | 34 ------ .../server/end_effector_control_utils.py | 25 ----- lerobot/scripts/server/gym_manipulator.py | 100 +++++++++++------- lerobot/scripts/server/learner_server.py | 3 - 5 files changed, 66 insertions(+), 99 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 825fa162..440512c3 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -203,6 +203,9 @@ class EnvWrapperConfig: joint_masking_action_space: Optional[Any] = None ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False + gripper_quantization_threshold: float = 0.8 + gripper_penalty: float = 0.0 + open_gripper_on_reset: bool = False @EnvConfig.register_subclass(name="gym_manipulator") diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index bf65b1ec..2af3995e 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -245,10 +245,6 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) - # Initialize complementary_info storage - self.complementary_info_keys = [] - self.complementary_info_storage = {} - self.initialized = True def __len__(self): @@ -282,28 +278,6 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated - # Store complementary info if provided - if complementary_info is not None: - # Initialize storage for new keys on first encounter - for key, value in complementary_info.items(): - if key not in self.complementary_info_keys: - self.complementary_info_keys.append(key) - if isinstance(value, torch.Tensor): - shape = value.shape if value.ndim > 0 else (1,) - self.complementary_info_storage[key] = torch.zeros( - (self.capacity, *shape), dtype=value.dtype, device=self.storage_device - ) - - # Store the value - if key in self.complementary_info_storage: - if isinstance(value, torch.Tensor): - self.complementary_info_storage[key][self.position] = value - else: - # For non-tensor values (like grasp_penalty) - self.complementary_info_storage[key][self.position] = torch.tensor( - value, device=self.storage_device - ) - self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -362,13 +336,6 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() - # Add complementary_info to batch if it exists - batch_complementary_info = {} - if hasattr(self, "complementary_info_keys") and self.complementary_info_keys: - for key in self.complementary_info_keys: - if key in self.complementary_info_storage: - batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device) - return BatchTransition( state=batch_state, action=batch_actions, @@ -376,7 +343,6 @@ 
class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, - complementary_info=batch_complementary_info if batch_complementary_info else None, ) @classmethod diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 3b2cfa90..3bd927b4 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -312,31 +312,6 @@ class GamepadController(InputController): logging.error("Error reading gamepad. Is it still connected?") return 0.0, 0.0, 0.0 - def get_gripper_action(self): - """ - Get gripper action using L3/R3 buttons. - Press left stick (L3) to open the gripper. - Press right stick (R3) to close the gripper. - """ - import pygame - - try: - # Check if buttons are pressed - l3_pressed = self.joystick.get_button(9) - r3_pressed = self.joystick.get_button(10) - - # Determine action based on button presses - if r3_pressed: - return 1.0 # Close gripper - elif l3_pressed: - return -1.0 # Open gripper - else: - return 0.0 # No change - - except pygame.error: - logging.error("Error reading gamepad. Is it still connected?") - return 0.0 - class GamepadControllerHID(InputController): """Generate motion deltas from gamepad input using HIDAPI.""" diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index ac3bbb0a..3aa75466 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper): return observation +class GripperPenaltyWrapper(gym.RewardWrapper): + def __init__(self, env, penalty: float = -0.1): + super().__init__(env) + self.penalty = penalty + self.last_gripper_state = None + + def reward(self, reward, action): + gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND + + if isinstance(action, tuple): + action = action[0] + action_normalized = action[-1] / MAX_GRIPPER_COMMAND + + gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( + gripper_state_normalized > 0.9 and action_normalized < 0.1 + ) + breakpoint() + + return reward + self.penalty * gripper_penalty_bool + + def step(self, action): + self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + obs, reward, terminated, truncated, info = self.env.step(action) + reward = self.reward(reward, action) + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + self.last_gripper_state = None + return super().reset(**kwargs) + + +class GripperQuantizationWrapper(gym.ActionWrapper): + def __init__(self, env, quantization_threshold: float = 0.2): + super().__init__(env) + self.quantization_threshold = quantization_threshold + + def action(self, action): + is_intervention = False + if isinstance(action, tuple): + action, is_intervention = action + + gripper_command = action[-1] + # Quantize gripper command to -1, 0 or 1 + if gripper_command < -self.quantization_threshold: + gripper_command = -MAX_GRIPPER_COMMAND + elif gripper_command > self.quantization_threshold: + gripper_command = MAX_GRIPPER_COMMAND + else: + gripper_command = 0.0 + + gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) + action[-1] = gripper_action.item() + return action, is_intervention + + class 
EEActionWrapper(gym.ActionWrapper): def __init__(self, env, ee_action_space_params=None, use_gripper=False): super().__init__(env) @@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper): fk_func=self.fk_function, ) if self.use_gripper: - # Quantize gripper command to -1, 0 or 1 - if gripper_command < -0.2: - gripper_command = -1.0 - elif gripper_command > 0.2: - gripper_command = 1.0 - else: - gripper_command = 0.0 - - gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] - gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) - target_joint_pos[-1] = gripper_action + target_joint_pos[-1] = gripper_command return target_joint_pos, is_intervention @@ -1069,31 +1115,6 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention -class GripperPenaltyWrapper(gym.Wrapper): - def __init__(self, env, penalty=-0.05): - super().__init__(env) - self.penalty = penalty - self.last_gripper_pos = None - - def reset(self, **kwargs): - obs, info = self.env.reset(**kwargs) - self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper - return obs, info - - def step(self, action): - observation, reward, terminated, truncated, info = self.env.step(action) - - if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or ( - action[-1] > 0.5 and self.last_gripper_pos < 0.9 - ): - info["grasp_penalty"] = self.penalty - else: - info["grasp_penalty"] = 0.0 - - self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper - return observation, reward, terminated, truncated, info - - def make_robot_env(cfg) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. @@ -1143,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: # Add reward computation and control wrappers # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + if cfg.wrapper.use_gripper: + env = GripperQuantizationWrapper( + env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold + ) + # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty) + if cfg.wrapper.ee_action_space_params is not None: env = EEActionWrapper( env=env, @@ -1169,7 +1196,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) - env = GripperPenaltyWrapper(env=env) return env diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 0f760dc5..15de2cb7 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -375,7 +375,6 @@ def add_actor_information_and_train( observations = batch["state"] next_observations = batch["next_state"] done = batch["done"] - complementary_info = batch["complementary_info"] check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( @@ -391,7 +390,6 @@ def add_actor_information_and_train( "done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, - "complementary_info": complementary_info, } # Use the forward method for critic loss 
@@ -450,7 +448,6 @@ def add_actor_information_and_train( "done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, - "complementary_info": complementary_info, } # Use the forward method for critic loss From 6a215f47ddf7e3235736f60cefb4ffb42166406c Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Tue, 1 Apr 2025 09:30:32 +0000 Subject: [PATCH 08/28] Refactor SAC configuration and policy to support discrete actions - Removed GraspCriticNetworkConfig class and integrated its parameters into SACConfig. - Added num_discrete_actions parameter to SACConfig for better action handling. - Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions. - Enhanced grasp critic forward pass to handle discrete actions and compute losses accordingly. --- .../common/policies/sac/configuration_sac.py | 9 +- lerobot/common/policies/sac/modeling_sac.py | 103 +++++++++++------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index b1ce30b6..66d9aa45 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -42,12 +42,6 @@ class CriticNetworkConfig: final_activation: str | None = None -@dataclass -class GraspCriticNetworkConfig: - hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) - activate_final: bool = True - final_activation: str | None = None - output_dim: int = 3 @dataclass @@ -152,6 +146,7 @@ class SACConfig(PreTrainedConfig): freeze_vision_encoder: bool = True image_encoder_hidden_dim: int = 32 shared_encoder: bool = True + num_discrete_actions: int | None = None # Training parameter online_steps: int = 1000000 @@ -182,7 +177,7 @@ class SACConfig(PreTrainedConfig): critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) - + grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 95ea3928..dd156918 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,6 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters +DISCRETE_DIMENSION_INDEX = -1 class SACPolicy( PreTrainedPolicy, @@ -113,24 +114,30 @@ class SACPolicy( self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target) - # Create grasp critic - self.grasp_critic = GraspCritic( - encoder=encoder_critic, - input_dim=encoder_critic.output_dim, - **config.grasp_critic_network_kwargs, - ) + self.grasp_critic = None + self.grasp_critic_target = None - # Create target grasp critic - self.grasp_critic_target = GraspCritic( - encoder=encoder_critic, - input_dim=encoder_critic.output_dim, - **config.grasp_critic_network_kwargs, - ) + if config.num_discrete_actions is not None: + # Create grasp critic + 
self.grasp_critic = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + output_dim=config.num_discrete_actions, + **asdict(config.grasp_critic_network_kwargs), + ) - self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) + # Create target grasp critic + self.grasp_critic_target = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + output_dim=config.num_discrete_actions, + **asdict(config.grasp_critic_network_kwargs), + ) - self.grasp_critic = torch.compile(self.grasp_critic) - self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) + + self.grasp_critic = torch.compile(self.grasp_critic) + self.grasp_critic_target = torch.compile(self.grasp_critic_target) self.actor = Policy( encoder=encoder_actor, @@ -173,6 +180,12 @@ class SACPolicy( """Select action for inference/evaluation""" actions, _, _ = self.actor(batch) actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.num_discrete_actions is not None: + discrete_action_value = self.grasp_critic(batch) + discrete_action = torch.argmax(discrete_action_value, dim=-1) + actions = torch.cat([actions, discrete_action], dim=-1) + return actions def critic_forward( @@ -192,11 +205,12 @@ class SACPolicy( Returns: Tensor of Q-values from all critics """ + critics = self.critic_target if use_target else self.critic_ensemble q_values = critics(observations, actions, observation_features) return q_values - def grasp_critic_forward(self, observations, use_target=False, observation_features=None): + def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor: """Forward pass through a grasp critic network Args: @@ -256,9 +270,6 @@ class SACPolicy( ) if model == "grasp_critic": - # Extract grasp_critic-specific components - complementary_info: dict[str, Tensor] = batch["complementary_info"] - return self.compute_loss_grasp_critic( observations=observations, actions=actions, @@ -267,7 +278,6 @@ class SACPolicy( done=done, observation_features=observation_features, next_observation_features=next_observation_features, - complementary_info=complementary_info, ) if model == "actor": @@ -349,6 +359,12 @@ class SACPolicy( td_target = rewards + (1 - done) * self.config.discount * min_q # 3- compute predicted qs + if self.config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] + q_preds = self.critic_forward( observations=observations, actions=actions, @@ -378,30 +394,43 @@ class SACPolicy( done, observation_features=None, next_observation_features=None, - complementary_info=None, ): - batch_size = rewards.shape[0] - grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2] + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:] + actions = actions.long() with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) best_next_grasp_action = 
torch.argmax(next_grasp_qs, dim=-1) + + # Get target Q-values from target network + target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True) + + # Use gather to select Q-values for best actions + target_next_grasp_q = torch.gather( + target_next_grasp_qs, + dim=1, + index=best_next_grasp_action.unsqueeze(-1) + ).squeeze(-1) - target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True) - target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action] + # Compute target Q-value with Bellman equation + target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q - # Get the grasp penalty from complementary_info - grasp_penalty = torch.zeros_like(rewards) - if complementary_info is not None and "grasp_penalty" in complementary_info: - grasp_penalty = complementary_info["grasp_penalty"] + # Get predicted Q-values for current observations + predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False) + + # Use gather to select Q-values for taken actions + predicted_grasp_q = torch.gather( + predicted_grasp_qs, + dim=1, + index=actions.unsqueeze(-1) + ).squeeze(-1) - grasp_rewards = rewards + grasp_penalty - target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q - - predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False) - predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions] - - grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean") + # Compute MSE loss between predicted and target Q-values + grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) return grasp_critic_loss def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: @@ -611,7 +640,7 @@ class GraspCritic(nn.Module): self, encoder: Optional[nn.Module], network: nn.Module, - output_dim: int = 3, + output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that init_final: Optional[float] = None, encoder_is_shared: bool = False, ): From 306c735172a6ffd0d1f1fa0866f29abd8d2c597c Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Tue, 1 Apr 2025 11:42:28 +0000 Subject: [PATCH 09/28] Refactor SAC policy and training loop to enhance discrete action support - Updated SACPolicy to conditionally compute losses for grasp critic based on num_discrete_actions. - Simplified forward method to return loss outputs as a dictionary for better clarity. - Adjusted learner_server to handle both main and grasp critic losses during training. - Ensured optimizers are created conditionally for grasp critic based on configuration settings. --- .../common/policies/sac/configuration_sac.py | 2 +- lerobot/common/policies/sac/modeling_sac.py | 54 ++++---- lerobot/scripts/server/learner_server.py | 120 +++++++++--------- 3 files changed, 86 insertions(+), 90 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 66d9aa45..ae38b1c5 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -87,6 +87,7 @@ class SACConfig(PreTrainedConfig): freeze_vision_encoder: Whether to freeze the vision encoder during training. image_encoder_hidden_dim: Hidden dimension size for the image encoder. 
shared_encoder: Whether to use a shared encoder for actor and critic. + num_discrete_actions: Number of discrete actions, eg for gripper actions. concurrency: Configuration for concurrency settings. actor_learner: Configuration for actor-learner architecture. online_steps: Number of steps for online training. @@ -162,7 +163,6 @@ class SACConfig(PreTrainedConfig): num_critics: int = 2 num_subsample_critics: int | None = None critic_lr: float = 3e-4 - grasp_critic_lr: float = 3e-4 actor_lr: float = 3e-4 temperature_lr: float = 3e-4 critic_target_update_weight: float = 0.005 diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index dd156918..d0e8b25d 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -228,7 +228,7 @@ class SACPolicy( def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic", + model: Literal["actor", "critic", "temperature"] = "critic", ) -> dict[str, Tensor]: """Compute the loss for the given model @@ -246,7 +246,6 @@ class SACPolicy( Returns: The computed loss tensor """ - # TODO: (maractingi, azouitine) Respect the function signature we output tensors # Extract common components from batch actions: Tensor = batch["action"] observations: dict[str, Tensor] = batch["state"] @@ -259,7 +258,7 @@ class SACPolicy( done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") - return self.compute_loss_critic( + loss_critic = self.compute_loss_critic( observations=observations, actions=actions, rewards=rewards, @@ -268,29 +267,28 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) + if self.config.num_discrete_actions is not None: + loss_grasp_critic = self.compute_loss_grasp_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + ) + return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} - if model == "grasp_critic": - return self.compute_loss_grasp_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) if model == "actor": - return self.compute_loss_actor( + return {"loss_actor": self.compute_loss_actor( observations=observations, observation_features=observation_features, - ) + )} if model == "temperature": - return self.compute_loss_temperature( + return {"loss_temperature": self.compute_loss_temperature( observations=observations, observation_features=observation_features, - ) + )} raise ValueError(f"Unknown model type: {model}") @@ -305,18 +303,16 @@ class SACPolicy( param.data * self.config.critic_target_update_weight + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - - def update_grasp_target_networks(self): - """Update grasp target networks with exponential moving average""" - for target_param, param in zip( - self.grasp_critic_target.parameters(), - self.grasp_critic.parameters(), - strict=False, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) + if self.config.num_discrete_actions is not None: + for target_param, param in zip( + self.grasp_critic_target.parameters(), + 
self.grasp_critic.parameters(), + strict=False, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) def update_temperature(self): self.temperature = self.log_alpha.exp().item() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 15de2cb7..627a1a17 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -392,32 +392,30 @@ def add_actor_information_and_train( "next_observation_feature": next_observation_features, } - # Use the forward method for critic loss - loss_critic = policy.forward(forward_batch, model="critic") + # Use the forward method for critic loss (includes both main critic and grasp critic) + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() loss_critic.backward() - - # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ) - optimizers["critic"].step() - # Add gripper critic optimization - loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") - optimizers["grasp_critic"].zero_grad() - loss_grasp_critic.backward() - - # clip gradients - grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value - ) - - optimizers["grasp_critic"].step() + # Grasp critic optimization (if available) + if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"): + loss_grasp_critic = critic_output["loss_grasp_critic"] + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["grasp_critic"].step() + # Update target networks policy.update_target_networks() - policy.update_grasp_target_networks() batch = replay_buffer.sample(batch_size=batch_size) @@ -450,81 +448,80 @@ def add_actor_information_and_train( "next_observation_feature": next_observation_features, } - # Use the forward method for critic loss - loss_critic = policy.forward(forward_batch, model="critic") + # Use the forward method for critic loss (includes both main critic and grasp critic) + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() loss_critic.backward() - - # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ).item() - optimizers["critic"].step() - training_infos = {} - training_infos["loss_critic"] = loss_critic.item() - training_infos["critic_grad_norm"] = critic_grad_norm + # Initialize training info dictionary + training_infos = { + "loss_critic": loss_critic.item(), + "critic_grad_norm": critic_grad_norm, + } - # Add gripper critic optimization - loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") - optimizers["grasp_critic"].zero_grad() - loss_grasp_critic.backward() - - # clip gradients - grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value - ) - - optimizers["grasp_critic"].step() - - # Add training info for the grasp critic - 
training_infos["loss_grasp_critic"] = loss_grasp_critic.item() - training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm + # Grasp critic optimization (if available) + if "loss_grasp_critic" in critic_output: + loss_grasp_critic = critic_output["loss_grasp_critic"] + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["grasp_critic"].step() + + # Add grasp critic info to training info + training_infos["loss_grasp_critic"] = loss_grasp_critic.item() + training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm + # Actor and temperature optimization (at specified frequency) if optimization_step % policy_update_freq == 0: for _ in range(policy_update_freq): - # Use the forward method for actor loss - loss_actor = policy.forward(forward_batch, model="actor") - + # Actor optimization + actor_output = policy.forward(forward_batch, model="actor") + loss_actor = actor_output["loss_actor"] optimizers["actor"].zero_grad() loss_actor.backward() - - # clip gradients actor_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() - optimizers["actor"].step() - + + # Add actor info to training info training_infos["loss_actor"] = loss_actor.item() training_infos["actor_grad_norm"] = actor_grad_norm - # Temperature optimization using forward method - loss_temperature = policy.forward(forward_batch, model="temperature") + # Temperature optimization + temperature_output = policy.forward(forward_batch, model="temperature") + loss_temperature = temperature_output["loss_temperature"] optimizers["temperature"].zero_grad() loss_temperature.backward() - - # clip gradients temp_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=[policy.log_alpha], max_norm=clip_grad_norm_value ).item() - optimizers["temperature"].step() - + + # Add temperature info to training info training_infos["loss_temperature"] = loss_temperature.item() training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature"] = policy.temperature + # Update temperature policy.update_temperature() - # Check if it's time to push updated policy to actors + # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) last_time_policy_pushed = time.time() + # Update target networks policy.update_target_networks() - policy.update_grasp_target_networks() # Log training metrics at specified intervals if optimization_step % log_freq == 0: @@ -727,7 +724,7 @@ def save_training_checkpoint( logging.info("Resume training") -def make_optimizers_and_scheduler(cfg, policy: nn.Module): +def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): """ Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. 
@@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - optimizer_grasp_critic = torch.optim.Adam( - params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr - ) + + if cfg.policy.num_discrete_actions is not None: + optimizer_grasp_critic = torch.optim.Adam( + params=policy.grasp_critic.parameters(), lr=policy.critic_lr + ) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { "actor": optimizer_actor, "critic": optimizer_critic, - "grasp_critic": optimizer_grasp_critic, "temperature": optimizer_temperature, } + if cfg.policy.num_discrete_actions is not None: + optimizers["grasp_critic"] = optimizer_grasp_critic return optimizers, lr_scheduler From 451a7b01db13d3e7dc4a111a1f1513f1a1eba396 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Tue, 1 Apr 2025 14:22:08 +0000 Subject: [PATCH 10/28] Add mock gripper support and enhance SAC policy action handling - Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation. - Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions. - Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup. - Refactored action handling in the training loop to accommodate the changes in action dimensions. --- lerobot/common/envs/configs.py | 1 + lerobot/common/policies/sac/modeling_sac.py | 18 +++-- .../scripts/server/maniskill_manipulator.py | 71 ++++++++++++------- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 440512c3..a6eda93b 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig): robot: str = "so100" # This is a hack to make the robot config work video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) wrapper: WrapperConfig = field(default_factory=WrapperConfig) + mock_gripper: bool = False features: dict[str, PolicyFeature] = field( default_factory=lambda: { "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index d0e8b25d..0c3d76d2 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters -DISCRETE_DIMENSION_INDEX = -1 +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension class SACPolicy( PreTrainedPolicy, @@ -82,7 +82,7 @@ class SACPolicy( # Create a list of critic heads critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -97,7 +97,7 @@ class SACPolicy( # Create target critic heads as deepcopies of the original critic heads target_critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + 
continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -117,7 +117,10 @@ class SACPolicy( self.grasp_critic = None self.grasp_critic_target = None + continuous_action_dim = config.output_features["action"].shape[0] if config.num_discrete_actions is not None: + + continuous_action_dim -= 1 # Create grasp critic self.grasp_critic = GraspCritic( encoder=encoder_critic, @@ -139,15 +142,16 @@ class SACPolicy( self.grasp_critic = torch.compile(self.grasp_critic) self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), - action_dim=config.output_features["action"].shape[0], + action_dim=continuous_action_dim, encoder_is_shared=config.shared_encoder, **asdict(config.policy_kwargs), ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2) + config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise @@ -275,7 +279,9 @@ class SACPolicy( next_observations=next_observations, done=done, ) - return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} + return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} + + return {"loss_critic": loss_critic} if model == "actor": diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index e10b8766..f4a89888 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from lerobot.common.envs.configs import ManiskillEnvConfig from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.modeling_sac import SACPolicy + def preprocess_maniskill_observation( @@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper): self.current_step = 0 return super().reset(seed=seed, options=options) +class ManiskillMockGripperWrapper(gym.Wrapper): + def __init__(self, env, nb_discrete_actions: int = 3): + super().__init__(env) + new_shape = env.action_space[0].shape[0] + 1 + new_low = np.concatenate([env.action_space[0].low, [0]]) + new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]]) + action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,)) + self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1])) + + def step(self, action): + action_agent, telop_action = action + real_action = action_agent[:-1] + final_action = (real_action, telop_action) + obs, reward, terminated, truncated, info = self.env.step(final_action) + return obs, reward, terminated, truncated, info def make_maniskill( cfg: ManiskillEnvConfig, @@ -197,40 +216,42 @@ def make_maniskill( env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control + if cfg.mock_gripper: + env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3) return env -@parser.wrap() -def main(cfg: ManiskillEnvConfig): - """Main function to run the ManiSkill environment.""" - # Create the 
ManiSkill environment - env = make_maniskill(cfg, n_envs=1) +# @parser.wrap() +# def main(cfg: TrainPipelineConfig): +# """Main function to run the ManiSkill environment.""" +# # Create the ManiSkill environment +# env = make_maniskill(cfg.env, n_envs=1) - # Reset the environment - obs, info = env.reset() +# # Reset the environment +# obs, info = env.reset() - # Run a simple interaction loop - sum_reward = 0 - for i in range(100): - # Sample a random action - action = env.action_space.sample() +# # Run a simple interaction loop +# sum_reward = 0 +# for i in range(100): +# # Sample a random action +# action = env.action_space.sample() - # Step the environment - start_time = time.perf_counter() - obs, reward, terminated, truncated, info = env.step(action) - step_time = time.perf_counter() - start_time - sum_reward += reward - # Log information +# # Step the environment +# start_time = time.perf_counter() +# obs, reward, terminated, truncated, info = env.step(action) +# step_time = time.perf_counter() - start_time +# sum_reward += reward +# # Log information - # Reset if episode terminated - if terminated or truncated: - logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") - sum_reward = 0 - obs, info = env.reset() +# # Reset if episode terminated +# if terminated or truncated: +# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") +# sum_reward = 0 +# obs, info = env.reset() - # Close the environment - env.close() +# # Close the environment +# env.close() # if __name__ == "__main__": From 699d374d895f8facdd2e5b66e379508e2466c97f Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Tue, 1 Apr 2025 15:43:29 +0000 Subject: [PATCH 11/28] Refactor SACPolicy for improved readability and action dimension handling - Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines. - Consolidated continuous action dimension calculation to enhance clarity and maintainability. - Simplified loss return statements in the forward method to improve code structure. - Ensured grasp critic parameters are included conditionally based on configuration settings. 
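
A minimal sketch of the consolidated action-dimension handling described above, with hypothetical shapes (a 7-dim stored action and 3 discrete gripper actions) chosen only for illustration:

    import torch

    # The buffer stores the full action: [continuous dims..., discrete gripper index].
    full_action_dim = 7
    num_discrete_actions = 3

    continuous_action_dim = full_action_dim
    if num_discrete_actions is not None:
        continuous_action_dim -= 1  # actor and critic ensemble only see the continuous part

    actions = torch.randn(4, full_action_dim)
    actions[:, -1] = torch.randint(0, num_discrete_actions, (4,)).float()

    continuous_part = actions[:, :continuous_action_dim]  # -> Policy / CriticEnsemble
    gripper_index = actions[:, -1].long()                 # -> GraspCritic (DQN-style head)
    print(continuous_part.shape, gripper_index.shape)     # torch.Size([4, 6]) torch.Size([4])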
--- lerobot/common/policies/sac/modeling_sac.py | 59 +++++++++++---------- lerobot/scripts/server/learner_server.py | 14 ++--- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 0c3d76d2..41ff7d8c 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,7 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters -DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension + class SACPolicy( PreTrainedPolicy, @@ -50,6 +51,10 @@ class SACPolicy( config.validate_features() self.config = config + continuous_action_dim = config.output_features["action"].shape[0] + if config.num_discrete_actions is not None: + continuous_action_dim -= 1 + if config.dataset_stats is not None: input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) self.normalize_inputs = Normalize( @@ -117,10 +122,7 @@ class SACPolicy( self.grasp_critic = None self.grasp_critic_target = None - continuous_action_dim = config.output_features["action"].shape[0] if config.num_discrete_actions is not None: - - continuous_action_dim -= 1 # Create grasp critic self.grasp_critic = GraspCritic( encoder=encoder_critic, @@ -142,7 +144,6 @@ class SACPolicy( self.grasp_critic = torch.compile(self.grasp_critic) self.grasp_critic_target = torch.compile(self.grasp_critic_target) - self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), @@ -162,11 +163,14 @@ class SACPolicy( self.temperature = self.log_alpha.exp().item() def get_optim_params(self) -> dict: - return { + optim_params = { "actor": self.actor.parameters_to_optimize, "critic": self.critic_ensemble.parameters_to_optimize, "temperature": self.log_alpha, } + if self.config.num_discrete_actions is not None: + optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize + return optim_params def reset(self): """Reset the policy""" @@ -262,7 +266,7 @@ class SACPolicy( done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") - loss_critic = self.compute_loss_critic( + loss_critic = self.compute_loss_critic( observations=observations, actions=actions, rewards=rewards, @@ -283,18 +287,21 @@ class SACPolicy( return {"loss_critic": loss_critic} - if model == "actor": - return {"loss_actor": self.compute_loss_actor( - observations=observations, - observation_features=observation_features, - )} + return { + "loss_actor": self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + } if model == "temperature": - return {"loss_temperature": self.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - )} + return { + "loss_temperature": self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + } raise ValueError(f"Unknown model type: {model}") @@ -366,7 +373,7 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - + q_preds = 
self.critic_forward( observations=observations, actions=actions, @@ -407,15 +414,13 @@ class SACPolicy( # For DQN, select actions using online network, evaluate with target network next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) - + # Get target Q-values from target network target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True) - + # Use gather to select Q-values for best actions target_next_grasp_q = torch.gather( - target_next_grasp_qs, - dim=1, - index=best_next_grasp_action.unsqueeze(-1) + target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1) ).squeeze(-1) # Compute target Q-value with Bellman equation @@ -423,13 +428,9 @@ class SACPolicy( # Get predicted Q-values for current observations predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False) - + # Use gather to select Q-values for taken actions - predicted_grasp_q = torch.gather( - predicted_grasp_qs, - dim=1, - index=actions.unsqueeze(-1) - ).squeeze(-1) + predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1) # Compute MSE loss between predicted and target Q-values grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) @@ -642,7 +643,7 @@ class GraspCritic(nn.Module): self, encoder: Optional[nn.Module], network: nn.Module, - output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that + output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that init_final: Optional[float] = None, encoder_is_shared: bool = False, ): diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 627a1a17..c57f83fc 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -394,7 +394,7 @@ def add_actor_information_and_train( # Use the forward method for critic loss (includes both main critic and grasp critic) critic_output = policy.forward(forward_batch, model="critic") - + # Main critic optimization loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() @@ -405,7 +405,7 @@ def add_actor_information_and_train( optimizers["critic"].step() # Grasp critic optimization (if available) - if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"): + if "loss_grasp_critic" in critic_output: loss_grasp_critic = critic_output["loss_grasp_critic"] optimizers["grasp_critic"].zero_grad() loss_grasp_critic.backward() @@ -450,7 +450,7 @@ def add_actor_information_and_train( # Use the forward method for critic loss (includes both main critic and grasp critic) critic_output = policy.forward(forward_batch, model="critic") - + # Main critic optimization loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() @@ -475,7 +475,7 @@ def add_actor_information_and_train( parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value ).item() optimizers["grasp_critic"].step() - + # Add grasp critic info to training info training_infos["loss_grasp_critic"] = loss_grasp_critic.item() training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm @@ -492,7 +492,7 @@ def add_actor_information_and_train( parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() optimizers["actor"].step() - + # Add actor info to training info training_infos["loss_actor"] = loss_actor.item() 
training_infos["actor_grad_norm"] = actor_grad_norm @@ -506,7 +506,7 @@ def add_actor_information_and_train( parameters=[policy.log_alpha], max_norm=clip_grad_norm_value ).item() optimizers["temperature"].step() - + # Add temperature info to training info training_infos["loss_temperature"] = loss_temperature.item() training_infos["temperature_grad_norm"] = temp_grad_norm @@ -756,7 +756,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - + if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( params=policy.grasp_critic.parameters(), lr=policy.critic_lr From 0ed7ff142cd45992469cf2aee7a4f36f8571bb92 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Wed, 2 Apr 2025 15:50:39 +0000 Subject: [PATCH 12/28] Enhance SACPolicy and learner server for improved grasp critic integration - Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions. - Refactored the forward method to handle grasp critic model selection and loss computation more clearly. - Adjusted learner server to utilize optimized parameters for grasp critic during training. - Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs. --- lerobot/common/policies/sac/modeling_sac.py | 95 +++++++++++-------- lerobot/scripts/server/learner_server.py | 16 ++-- .../scripts/server/maniskill_manipulator.py | 11 ++- 3 files changed, 72 insertions(+), 50 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 41ff7d8c..af624592 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -52,8 +52,6 @@ class SACPolicy( self.config = config continuous_action_dim = config.output_features["action"].shape[0] - if config.num_discrete_actions is not None: - continuous_action_dim -= 1 if config.dataset_stats is not None: input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) @@ -191,7 +189,7 @@ class SACPolicy( if self.config.num_discrete_actions is not None: discrete_action_value = self.grasp_critic(batch) - discrete_action = torch.argmax(discrete_action_value, dim=-1) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) actions = torch.cat([actions, discrete_action], dim=-1) return actions @@ -236,7 +234,7 @@ class SACPolicy( def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature"] = "critic", + model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic", ) -> dict[str, Tensor]: """Compute the loss for the given model @@ -275,18 +273,25 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) - if self.config.num_discrete_actions is not None: - loss_grasp_critic = self.compute_loss_grasp_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - ) - return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} return {"loss_critic": loss_critic} + if model == "grasp_critic" and self.config.num_discrete_actions is not None: + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = 
batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + loss_grasp_critic = self.compute_loss_grasp_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + return {"loss_grasp_critic": loss_grasp_critic} if model == "actor": return { "loss_actor": self.compute_loss_actor( @@ -373,7 +378,6 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - q_preds = self.critic_forward( observations=observations, actions=actions, @@ -407,30 +411,38 @@ class SACPolicy( # NOTE: We only want to keep the discrete action part # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward - actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:] - actions = actions.long() + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = actions_discrete.long() with torch.no_grad(): # For DQN, select actions using online network, evaluate with target network - next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) - best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) + next_grasp_qs = self.grasp_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True) # Get target Q-values from target network - target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True) + target_next_grasp_qs = self.grasp_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) # Use gather to select Q-values for best actions target_next_grasp_q = torch.gather( - target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1) + target_next_grasp_qs, dim=1, index=best_next_grasp_action ).squeeze(-1) # Compute target Q-value with Bellman equation target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q # Get predicted Q-values for current observations - predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False) + predicted_grasp_qs = self.grasp_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) # Use gather to select Q-values for taken actions - predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1) + predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1) # Compute MSE loss between predicted and target Q-values grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) @@ -642,49 +654,52 @@ class GraspCritic(nn.Module): def __init__( self, encoder: Optional[nn.Module], - network: nn.Module, - output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that + input_dim: int, + hidden_dims: list[int], + output_dim: int = 3, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: Optional[float] = None, init_final: Optional[float] = None, - encoder_is_shared: 
bool = False, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, ): super().__init__() self.encoder = encoder - self.network = network self.output_dim = output_dim - # Find the last Linear layer's output dimension - for layer in reversed(network.net): - if isinstance(layer, nn.Linear): - out_features = layer.out_features - break + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) - self.parameters_to_optimize += list(self.network.parameters()) - - if self.encoder is not None and not encoder_is_shared: - self.parameters_to_optimize += list(self.encoder.parameters()) - - self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) if init_final is not None: nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final) else: orthogonal_init()(self.output_layer.weight) + self.parameters_to_optimize = [] + self.parameters_to_optimize += list(self.net.parameters()) self.parameters_to_optimize += list(self.output_layer.parameters()) def forward( self, observations: torch.Tensor, observation_features: torch.Tensor | None = None ) -> torch.Tensor: device = get_device_from_parameters(self) - # Move each tensor in observations to device + # Move each tensor in observations to device by cloning first to avoid inplace operations observations = {k: v.to(device) for k, v in observations.items()} # Encode observations if encoder exists obs_enc = ( - observation_features + observation_features.to(device) if observation_features is not None else (observations if self.encoder is None else self.encoder(observations)) ) - return self.output_layer(self.network(obs_enc)) + return self.output_layer(self.net(obs_enc)) class Policy(nn.Module): diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index c57f83fc..ce9a1b41 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -405,12 +405,13 @@ def add_actor_information_and_train( optimizers["critic"].step() # Grasp critic optimization (if available) - if "loss_grasp_critic" in critic_output: - loss_grasp_critic = critic_output["loss_grasp_critic"] + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="grasp_critic") + loss_grasp_critic = discrete_critic_output["loss_grasp_critic"] optimizers["grasp_critic"].zero_grad() loss_grasp_critic.backward() grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value ) optimizers["grasp_critic"].step() @@ -467,12 +468,13 @@ def add_actor_information_and_train( } # Grasp critic optimization (if available) - if "loss_grasp_critic" in critic_output: - loss_grasp_critic = critic_output["loss_grasp_critic"] + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="grasp_critic") + loss_grasp_critic = discrete_critic_output["loss_grasp_critic"] optimizers["grasp_critic"].zero_grad() loss_grasp_critic.backward() grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - 
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() optimizers["grasp_critic"].step() @@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( - params=policy.grasp_critic.parameters(), lr=policy.critic_lr + params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr ) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index f4a89888..b5c181c1 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -16,7 +16,6 @@ from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.sac.modeling_sac import SACPolicy - def preprocess_maniskill_observation( observations: dict[str, np.ndarray], ) -> dict[str, torch.Tensor]: @@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper): self.current_step = 0 return super().reset(seed=seed, options=options) + class ManiskillMockGripperWrapper(gym.Wrapper): def __init__(self, env, nb_discrete_actions: int = 3): super().__init__(env) @@ -166,11 +166,16 @@ class ManiskillMockGripperWrapper(gym.Wrapper): self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1])) def step(self, action): - action_agent, telop_action = action + if isinstance(action, tuple): + action_agent, telop_action = action + else: + telop_action = 0 + action_agent = action real_action = action_agent[:-1] final_action = (real_action, telop_action) obs, reward, terminated, truncated, info = self.env.step(final_action) - return obs, reward, terminated, truncated, info + return obs, reward, terminated, truncated, info + def make_maniskill( cfg: ManiskillEnvConfig, From 51f1625c2004cdb8adf50c5c862a1c778d746c97 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Thu, 3 Apr 2025 07:44:46 +0000 Subject: [PATCH 13/28] Enhance SACPolicy to support shared encoder and optimize action selection - Cached encoder output in select_action method to reduce redundant computations. - Updated action selection and grasp critic calls to utilize cached encoder features when available. 
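
A minimal sketch of the caching pattern, using hypothetical stand-in modules (the policy caches the shared encoder's output once in select_action and passes it to both the actor and the grasp critic):

    import torch
    import torch.nn as nn

    # Stand-ins for the shared encoder and the two heads; shapes are illustrative only.
    shared_encoder = nn.Sequential(nn.Linear(10, 64), nn.ReLU())
    actor_head = nn.Linear(64, 6)   # continuous action head
    grasp_head = nn.Linear(64, 3)   # discrete gripper head (3 actions)

    obs = torch.randn(1, 10)

    with torch.no_grad():
        cached_features = shared_encoder(obs)           # encoded once
        continuous_action = actor_head(cached_features)
        gripper_logits = grasp_head(cached_features)    # reuses the cached features
        gripper_action = torch.argmax(gripper_logits, dim=-1, keepdim=True)

    full_action = torch.cat([continuous_action, gripper_action.float()], dim=-1)
    print(full_action.shape)  # torch.Size([1, 7])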
--- lerobot/common/policies/sac/modeling_sac.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index af624592..2246bf8c 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -81,6 +81,7 @@ class SACPolicy( else: encoder_critic = SACObservationEncoder(config, self.normalize_inputs) encoder_actor = SACObservationEncoder(config, self.normalize_inputs) + self.shared_encoder = config.shared_encoder # Create a list of critic heads critic_heads = [ @@ -184,11 +185,16 @@ class SACPolicy( @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select action for inference/evaluation""" - actions, _, _ = self.actor(batch) + # We cached the encoder output to avoid recomputing it + observations_features = None + if self.shared_encoder and self.actor.encoder is not None: + observations_features = self.actor.encoder(batch) + + actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] if self.config.num_discrete_actions is not None: - discrete_action_value = self.grasp_critic(batch) + discrete_action_value = self.grasp_critic(batch, observations_features) discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) actions = torch.cat([actions, discrete_action], dim=-1) From 38a8dbd9c9a0fe5bf9fb0244be068869ea0bc1c1 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Thu, 3 Apr 2025 14:23:50 +0000 Subject: [PATCH 14/28] Enhance SAC configuration and replay buffer with asynchronous prefetching support - Added async_prefetch parameter to SACConfig for improved buffer management. - Implemented get_iterator method in ReplayBuffer to support asynchronous prefetching of batches. - Updated learner_server to utilize the new iterator for online and offline sampling, enhancing training efficiency. --- .../common/policies/sac/configuration_sac.py | 4 +- lerobot/scripts/server/buffer.py | 576 ++++-------------- lerobot/scripts/server/learner_server.py | 34 +- 3 files changed, 132 insertions(+), 482 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index ae38b1c5..3d01f47c 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -42,8 +42,6 @@ class CriticNetworkConfig: final_activation: str | None = None - - @dataclass class ActorNetworkConfig: hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) @@ -94,6 +92,7 @@ class SACConfig(PreTrainedConfig): online_env_seed: Seed for the online environment. online_buffer_capacity: Capacity of the online replay buffer. offline_buffer_capacity: Capacity of the offline replay buffer. + async_prefetch: Whether to use asynchronous prefetching for the buffers. online_step_before_learning: Number of steps before learning starts. policy_update_freq: Frequency of policy updates. discount: Discount factor for the SAC algorithm. 
@@ -154,6 +153,7 @@ class SACConfig(PreTrainedConfig): online_env_seed: int = 10000 online_buffer_capacity: int = 100000 offline_buffer_capacity: int = 100000 + async_prefetch: bool = False online_step_before_learning: int = 100 policy_update_freq: int = 1 diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 2af3995e..c8f85372 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -345,6 +345,109 @@ class ReplayBuffer: truncated=batch_truncateds, ) + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """ + Creates an infinite iterator that yields batches of transitions. + Will automatically restart when internal iterator is exhausted. + + Args: + batch_size (int): Size of batches to sample + async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True) + queue_size (int): Number of batches to prefetch (default: 2) + + Yields: + BatchTransition: Batched transitions + """ + while True: # Create an infinite loop + if async_prefetch: + # Get the standard iterator + iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size) + else: + iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size) + + # Yield all items from the iterator + try: + yield from iterator + except StopIteration: + # Just continue the outer loop to create a new iterator + pass + + def _get_async_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates an iterator that prefetches batches in a background thread. + + Args: + queue_size (int): Number of batches to prefetch (default: 2) + batch_size (int): Size of batches to sample (default: 128) + + Yields: + BatchTransition: Prefetched batch transitions + """ + import threading + import queue + + # Use thread-safe queue + data_queue = queue.Queue(maxsize=queue_size) + running = [True] # Use list to allow modification in nested function + + def prefetch_worker(): + while running[0]: + try: + # Sample data and add to queue + data = self.sample(batch_size) + data_queue.put(data, block=True, timeout=0.5) + except queue.Full: + continue + except Exception as e: + print(f"Prefetch error: {e}") + break + + # Start prefetching thread + thread = threading.Thread(target=prefetch_worker, daemon=True) + thread.start() + + try: + while running[0]: + try: + yield data_queue.get(block=True, timeout=0.5) + except queue.Empty: + if not thread.is_alive(): + break + finally: + # Clean up + running[0] = False + thread.join(timeout=1.0) + + def _get_naive_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates a simple non-threaded iterator that yields batches. 
+ + Args: + batch_size (int): Size of batches to sample + queue_size (int): Number of initial batches to prefetch + + Yields: + BatchTransition: Batch transitions + """ + import collections + + queue = collections.deque() + + def enqueue(n): + for _ in range(n): + data = self.sample(batch_size) + queue.append(data) + + enqueue(queue_size) + while queue: + yield queue.popleft() + enqueue(1) + @classmethod def from_lerobot_dataset( cls, @@ -710,475 +813,4 @@ def concatenate_batch_transitions( if __name__ == "__main__": - from tempfile import TemporaryDirectory - - # ===== Test 1: Create and use a synthetic ReplayBuffer ===== - print("Testing synthetic ReplayBuffer...") - - # Create sample data dimensions - batch_size = 32 - state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)} - action_dim = (6,) - - # Create a buffer - buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=list(state_dims.keys()), - use_drq=True, - storage_device="cpu", - ) - - # Add some random transitions - for i in range(100): - # Create dummy transition data - state = { - "observation.image": torch.rand(1, 3, 84, 84), - "observation.state": torch.rand(1, 10), - } - action = torch.rand(1, 6) - reward = 0.5 - next_state = { - "observation.image": torch.rand(1, 3, 84, 84), - "observation.state": torch.rand(1, 10), - } - done = False if i < 99 else True - truncated = False - - buffer.add( - state=state, - action=action, - reward=reward, - next_state=next_state, - done=done, - truncated=truncated, - ) - - # Test sampling - batch = buffer.sample(batch_size) - print(f"Buffer size: {len(buffer)}") - print( - f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}" - ) - print(f"Sampled batch action shape: {batch['action'].shape}") - print(f"Sampled batch reward shape: {batch['reward'].shape}") - print(f"Sampled batch done shape: {batch['done'].shape}") - print(f"Sampled batch truncated shape: {batch['truncated'].shape}") - - # ===== Test for state-action-reward alignment ===== - print("\nTesting state-action-reward alignment...") - - # Create a buffer with controlled transitions where we know the relationships - aligned_buffer = ReplayBuffer( - capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu" - ) - - # Create transitions with known relationships - # - Each state has a unique signature value - # - Action is 2x the state signature - # - Reward is 3x the state signature - # - Next state is signature + 0.01 (unless at episode end) - for i in range(100): - # Create a state with a signature value that encodes the transition number - signature = float(i) / 100.0 - state = {"state_value": torch.tensor([[signature]]).float()} - - # Action is 2x the signature - action = torch.tensor([[2.0 * signature]]).float() - - # Reward is 3x the signature - reward = 3.0 * signature - - # Next state is signature + 0.01, unless end of episode - # End episode every 10 steps - is_end = (i + 1) % 10 == 0 - - if is_end: - # At episode boundaries, next_state repeats current state (as per your implementation) - next_state = {"state_value": torch.tensor([[signature]]).float()} - done = True - else: - # Within episodes, next_state has signature + 0.01 - next_signature = float(i + 1) / 100.0 - next_state = {"state_value": torch.tensor([[next_signature]]).float()} - done = False - - aligned_buffer.add(state, action, reward, next_state, done, False) - - # Sample from this buffer - aligned_batch = aligned_buffer.sample(50) - - # Verify 
alignments in sampled batch - correct_relationships = 0 - total_checks = 0 - - # For each transition in the batch - for i in range(50): - # Extract signature from state - state_sig = aligned_batch["state"]["state_value"][i].item() - - # Check action is 2x signature (within reasonable precision) - action_val = aligned_batch["action"][i].item() - action_check = abs(action_val - 2.0 * state_sig) < 1e-4 - - # Check reward is 3x signature (within reasonable precision) - reward_val = aligned_batch["reward"][i].item() - reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4 - - # Check next_state relationship matches our pattern - next_state_sig = aligned_batch["next_state"]["state_value"][i].item() - is_done = aligned_batch["done"][i].item() > 0.5 - - # Calculate expected next_state value based on done flag - if is_done: - # For episodes that end, next_state should equal state - next_state_check = abs(next_state_sig - state_sig) < 1e-4 - else: - # For continuing episodes, check if next_state is approximately state + 0.01 - # We need to be careful because we don't know the original index - # So we check if the increment is roughly 0.01 - next_state_check = ( - abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 - ) - - # Count correct relationships - if action_check: - correct_relationships += 1 - if reward_check: - correct_relationships += 1 - if next_state_check: - correct_relationships += 1 - - total_checks += 3 - - alignment_accuracy = 100.0 * correct_relationships / total_checks - print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%") - if alignment_accuracy > 99.0: - print("✅ All relationships verified! Buffer maintains correct temporal relationships.") - else: - print("⚠️ Some relationships don't match expected patterns. 
Buffer may have alignment issues.") - - # Print some debug information about failures - print("\nDebug information for failed checks:") - for i in range(5): # Print first 5 transitions for debugging - state_sig = aligned_batch["state"]["state_value"][i].item() - action_val = aligned_batch["action"][i].item() - reward_val = aligned_batch["reward"][i].item() - next_state_sig = aligned_batch["next_state"]["state_value"][i].item() - is_done = aligned_batch["done"][i].item() > 0.5 - - print(f"Transition {i}:") - print(f" State: {state_sig:.6f}") - print(f" Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})") - print(f" Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})") - print(f" Done: {is_done}") - print(f" Next state: {next_state_sig:.6f}") - - # Calculate expected next state - if is_done: - expected_next = state_sig - else: - # This approximation might not be perfect - state_idx = round(state_sig * 100) - expected_next = (state_idx + 1) / 100.0 - - print(f" Expected next state: {expected_next:.6f}") - print() - - # ===== Test 2: Convert to LeRobotDataset and back ===== - with TemporaryDirectory() as temp_dir: - print("\nTesting conversion to LeRobotDataset and back...") - # Convert buffer to dataset - repo_id = "test/replay_buffer_conversion" - # Create a subdirectory to avoid the "directory exists" error - dataset_dir = os.path.join(temp_dir, "dataset1") - dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir) - - print(f"Dataset created with {len(dataset)} frames") - print(f"Dataset features: {list(dataset.features.keys())}") - - # Check a random sample from the dataset - sample = dataset[0] - print( - f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}" - ) - - # Convert dataset back to buffer - reconverted_buffer = ReplayBuffer.from_lerobot_dataset( - dataset, state_keys=list(state_dims.keys()), device="cpu" - ) - - print(f"Reconverted buffer size: {len(reconverted_buffer)}") - - # Sample from the reconverted buffer - reconverted_batch = reconverted_buffer.sample(batch_size) - print( - f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}" - ) - - # Verify consistency before and after conversion - original_states = batch["state"]["observation.image"].mean().item() - reconverted_states = reconverted_batch["state"]["observation.image"].mean().item() - print(f"Original buffer state mean: {original_states:.4f}") - print(f"Reconverted buffer state mean: {reconverted_states:.4f}") - - if abs(original_states - reconverted_states) < 1.0: - print("Values are reasonably similar - conversion works as expected") - else: - print("WARNING: Significant difference between original and reconverted values") - - print("\nAll previous tests completed!") - - # ===== Test for memory optimization ===== - print("\n===== Testing Memory Optimization =====") - - # Create two buffers, one with memory optimization and one without - standard_buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=["observation.image", "observation.state"], - storage_device="cpu", - optimize_memory=False, - use_drq=True, - ) - - optimized_buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=["observation.image", "observation.state"], - storage_device="cpu", - optimize_memory=True, - use_drq=True, - ) - - # Generate sample data with larger state dimensions for better memory impact - print("Generating test data...") - num_episodes = 10 - 
steps_per_episode = 50 - total_steps = num_episodes * steps_per_episode - - for episode in range(num_episodes): - for step in range(steps_per_episode): - # Index in the overall sequence - i = episode * steps_per_episode + step - - # Create state with identifiable values - img = torch.ones((3, 84, 84)) * (i / total_steps) - state_vec = torch.ones((10,)) * (i / total_steps) - - state = { - "observation.image": img.unsqueeze(0), - "observation.state": state_vec.unsqueeze(0), - } - - # Create next state (i+1 or same as current if last in episode) - is_last_step = step == steps_per_episode - 1 - - if is_last_step: - # At episode end, next state = current state - next_img = img.clone() - next_state_vec = state_vec.clone() - done = True - truncated = False - else: - # Within episode, next state has incremented value - next_val = (i + 1) / total_steps - next_img = torch.ones((3, 84, 84)) * next_val - next_state_vec = torch.ones((10,)) * next_val - done = False - truncated = False - - next_state = { - "observation.image": next_img.unsqueeze(0), - "observation.state": next_state_vec.unsqueeze(0), - } - - # Action and reward - action = torch.tensor([[i / total_steps]]) - reward = float(i / total_steps) - - # Add to both buffers - standard_buffer.add(state, action, reward, next_state, done, truncated) - optimized_buffer.add(state, action, reward, next_state, done, truncated) - - # Verify episode boundaries with our simplified approach - print("\nVerifying simplified memory optimization...") - - # Test with a new buffer with a small sequence - test_buffer = ReplayBuffer( - capacity=20, - device="cpu", - state_keys=["value"], - storage_device="cpu", - optimize_memory=True, - use_drq=False, - ) - - # Add a simple sequence with known episode boundaries - for i in range(20): - val = float(i) - state = {"value": torch.tensor([[val]]).float()} - next_val = float(i + 1) if i % 5 != 4 else val # Episode ends every 5 steps - next_state = {"value": torch.tensor([[next_val]]).float()} - - # Set done=True at every 5th step - done = (i % 5) == 4 - action = torch.tensor([[0.0]]) - reward = 1.0 - truncated = False - - test_buffer.add(state, action, reward, next_state, done, truncated) - - # Get sequential batch for verification - sequential_batch_size = test_buffer.size - all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device) - - # Get state tensors - batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)} - - # Get next_state using memory-optimized approach (simply index+1) - next_indices = (all_indices + 1) % test_buffer.capacity - batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)} - - # Get other tensors - batch_dones = test_buffer.dones[all_indices].to(test_buffer.device) - - # Print sequential values - print("State, Next State, Done (Sequential values with simplified optimization):") - state_values = batch_state["value"].squeeze().tolist() - next_values = batch_next_state["value"].squeeze().tolist() - done_flags = batch_dones.tolist() - - # Print all values - for i in range(len(state_values)): - print(f" {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}") - - # Explain the memory optimization tradeoff - print("\nWith simplified memory optimization:") - print("- We always use the next state in the buffer (index+1) as next_state") - print("- For terminal states, this means using the first state of the next episode") - print("- This is a common tradeoff in RL implementations for memory 
efficiency") - print("- Since we track done flags, the algorithm can handle these transitions correctly") - - # Test random sampling - print("\nVerifying random sampling with simplified memory optimization...") - random_samples = test_buffer.sample(20) # Sample all transitions - - # Extract values - random_state_values = random_samples["state"]["value"].squeeze().tolist() - random_next_values = random_samples["next_state"]["value"].squeeze().tolist() - random_done_flags = random_samples["done"].bool().tolist() - - # Print a few samples - print("Random samples - State, Next State, Done (First 10):") - for i in range(10): - print(f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}") - - # Calculate memory savings - # Assume optimized_buffer and standard_buffer have already been initialized and filled - std_mem = ( - sum( - standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size() - for key in standard_buffer.states - ) - * 2 - ) - opt_mem = sum( - optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size() - for key in optimized_buffer.states - ) - - savings_percent = (std_mem - opt_mem) / std_mem * 100 - - print("\nMemory optimization result:") - print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB") - print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB") - print(f"- Memory savings for state tensors: {savings_percent:.1f}%") - - print("\nAll memory optimization tests completed!") - - # # ===== Test real dataset conversion ===== - # print("\n===== Testing Real LeRobotDataset Conversion =====") - # try: - # # Try to use a real dataset if available - # dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small" - # dataset = LeRobotDataset(repo_id=dataset_name) - - # # Print available keys to debug - # sample = dataset[0] - # print("Available keys in dataset:", list(sample.keys())) - - # # Check for required keys - # if "action" not in sample or "next.reward" not in sample: - # print("Dataset missing essential keys. 
Cannot convert.") - # raise ValueError("Missing required keys in dataset") - - # # Auto-detect appropriate state keys - # image_keys = [] - # state_keys = [] - # for k, v in sample.items(): - # # Skip metadata keys and action/reward keys - # if k in { - # "index", - # "episode_index", - # "frame_index", - # "timestamp", - # "task_index", - # "action", - # "next.reward", - # "next.done", - # }: - # continue - - # # Infer key type from tensor shape - # if isinstance(v, torch.Tensor): - # if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1): - # # Likely an image (channels, height, width) - # image_keys.append(k) - # else: - # # Likely state or other vector - # state_keys.append(k) - - # print(f"Detected image keys: {image_keys}") - # print(f"Detected state keys: {state_keys}") - - # if not image_keys and not state_keys: - # print("No usable keys found in dataset, skipping further tests") - # raise ValueError("No usable keys found in dataset") - - # # Test with standard and memory-optimized buffers - # for optimize_memory in [False, True]: - # buffer_type = "Standard" if not optimize_memory else "Memory-optimized" - # print(f"\nTesting {buffer_type} buffer with real dataset...") - - # # Convert to ReplayBuffer with detected keys - # replay_buffer = ReplayBuffer.from_lerobot_dataset( - # lerobot_dataset=dataset, - # state_keys=image_keys + state_keys, - # device="cpu", - # optimize_memory=optimize_memory, - # ) - # print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}") - - # # Test sampling - # real_batch = replay_buffer.sample(32) - # print(f"Sampled batch from real dataset ({buffer_type}), state shapes:") - # for key in real_batch["state"]: - # print(f" {key}: {real_batch['state'][key].shape}") - - # # Convert back to LeRobotDataset - # with TemporaryDirectory() as temp_dir: - # dataset_name = f"test/real_dataset_converted_{buffer_type}" - # replay_buffer_converted = replay_buffer.to_lerobot_dataset( - # repo_id=dataset_name, - # root=os.path.join(temp_dir, f"dataset_{buffer_type}"), - # ) - # print( - # f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames" - # ) - - # except Exception as e: - # print(f"Real dataset test failed: {e}") - # print("This is expected if running offline or if the dataset is not available.") - - # print("\nAll tests completed!") + pass # All test code is currently commented out diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index ce9a1b41..08baa6ea 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -269,6 +269,7 @@ def add_actor_information_and_train( policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps + async_prefetch = cfg.policy.async_prefetch # Initialize logging for multiprocessing if not use_threads(cfg): @@ -326,6 +327,9 @@ def add_actor_information_and_train( if cfg.dataset is not None: dataset_repo_id = cfg.dataset.repo_id + # Initialize iterators + online_iterator = None + offline_iterator = None # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER while True: # Exit the training loop if shutdown is requested @@ -359,16 +363,29 @@ def add_actor_information_and_train( if len(replay_buffer) < online_step_before_learning: continue + if online_iterator is None: + logging.debug("[LEARNER] Initializing online replay buffer iterator") + online_iterator = replay_buffer.get_iterator( + 
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + if offline_replay_buffer is not None and offline_iterator is None: + logging.debug("[LEARNER] Initializing offline replay buffer iterator") + offline_iterator = offline_replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + logging.debug("[LEARNER] Starting optimization loop") time_for_one_optimization_step = time.time() for _ in range(utd_ratio - 1): - batch = replay_buffer.sample(batch_size=batch_size) + # Sample from the iterators + batch = next(online_iterator) - if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) actions = batch["action"] rewards = batch["reward"] @@ -418,10 +435,11 @@ def add_actor_information_and_train( # Update target networks policy.update_target_networks() - batch = replay_buffer.sample(batch_size=batch_size) + # Sample for the last update in the UTD ratio + batch = next(online_iterator) if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) + batch_offline = next(offline_iterator) batch = concatenate_batch_transitions( left_batch_transitions=batch, right_batch_transition=batch_offline ) From e86fe66dbd392fbf087e0ac74f8e23c1262c272b Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Thu, 3 Apr 2025 16:05:29 +0000 Subject: [PATCH 15/28] fix indentation issue --- lerobot/scripts/server/learner_server.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 08baa6ea..65b1d9b8 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -381,11 +381,11 @@ def add_actor_information_and_train( # Sample from the iterators batch = next(online_iterator) - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) actions = batch["action"] rewards = batch["reward"] @@ -435,8 +435,8 @@ def add_actor_information_and_train( # Update target networks policy.update_target_networks() - # Sample for the last update in the UTD ratio - batch = next(online_iterator) + # Sample for the last update in the UTD ratio + batch = next(online_iterator) if dataset_repo_id is not None: batch_offline = next(offline_iterator) From 037ecae9e0b89d07dbc3da30e177cd5d7c27ed70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Apr 2025 07:59:22 +0000 Subject: [PATCH 16/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/scripts/server/buffer.py | 3 +-- lerobot/scripts/server/maniskill_manipulator.py | 6 ------ 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index c8f85372..8947f6d9 
100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -15,7 +15,6 @@ # limitations under the License. import functools import io -import os import pickle from typing import Any, Callable, Optional, Sequence, TypedDict @@ -388,8 +387,8 @@ class ReplayBuffer: Yields: BatchTransition: Prefetched batch transitions """ - import threading import queue + import threading # Use thread-safe queue data_queue = queue.Queue(maxsize=queue_size) diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index b5c181c1..03a7ec10 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -1,5 +1,3 @@ -import logging -import time from typing import Any import einops @@ -10,10 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from lerobot.common.envs.configs import ManiskillEnvConfig -from lerobot.configs import parser -from lerobot.configs.train import TrainPipelineConfig -from lerobot.common.policies.sac.configuration_sac import SACConfig -from lerobot.common.policies.sac.modeling_sac import SACPolicy def preprocess_maniskill_observation( From 7741526ce42f2017db33841895d36e2c95ac875c Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Fri, 4 Apr 2025 14:29:38 +0000 Subject: [PATCH 17/28] fix caching --- lerobot/common/policies/sac/modeling_sac.py | 224 ++++++++++---------- lerobot/scripts/server/learner_server.py | 12 +- 2 files changed, 117 insertions(+), 119 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 2246bf8c..b5bfb36e 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -187,8 +187,8 @@ class SACPolicy( """Select action for inference/evaluation""" # We cached the encoder output to avoid recomputing it observations_features = None - if self.shared_encoder and self.actor.encoder is not None: - observations_features = self.actor.encoder(batch) + if self.shared_encoder: + observations_features = self.actor.encoder.get_image_features(batch) actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -484,6 +484,109 @@ class SACPolicy( return actor_loss +class SACObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: SACConfig, input_normalizer: nn.Module): + """ + Creates encoders for pixel and/or state modalities. 
+ """ + super().__init__() + self.config = config + self.input_normalization = input_normalizer + self.has_pretrained_vision_encoder = False + self.parameters_to_optimize = [] + + self.aggregation_size: int = 0 + if any("observation.image" in key for key in config.input_features): + self.camera_number = config.camera_number + + if self.config.vision_encoder_name is not None: + self.image_enc_layers = PretrainedImageEncoder(config) + self.has_pretrained_vision_encoder = True + else: + self.image_enc_layers = DefaultImageEncoder(config) + + self.aggregation_size += config.latent_dim * self.camera_number + + if config.freeze_vision_encoder: + freeze_image_encoder(self.image_enc_layers) + else: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] + + if "observation.state" in config.input_features: + self.state_enc_layers = nn.Sequential( + nn.Linear( + in_features=config.input_features["observation.state"].shape[0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), + nn.Tanh(), + ) + self.aggregation_size += config.latent_dim + + self.parameters_to_optimize += list(self.state_enc_layers.parameters()) + + if "observation.environment_state" in config.input_features: + self.env_state_enc_layers = nn.Sequential( + nn.Linear( + in_features=config.input_features["observation.environment_state"].shape[0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), + nn.Tanh(), + ) + self.aggregation_size += config.latent_dim + self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) + + self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) + self.parameters_to_optimize += list(self.aggregation_layer.parameters()) + + def forward( + self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None + ) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + obs_dict = self.input_normalization(obs_dict) + if len(self.all_image_keys) > 0 and vision_encoder_cache is None: + vision_encoder_cache = self.get_image_features(obs_dict) + feat.append(vision_encoder_cache) + + if vision_encoder_cache is not None: + feat.append(vision_encoder_cache) + + if "observation.environment_state" in self.config.input_features: + feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + if "observation.state" in self.config.input_features: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + + features = torch.cat(tensors=feat, dim=-1) + features = self.aggregation_layer(features) + + return features + + def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: + # [N*B, C, H, W] + if len(self.all_image_keys) > 0: + # Batch all images along the batch dimension, then encode them. 
+ images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) + images_batched = self.image_enc_layers(images_batched) + embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) + embeddings_image = torch.cat(embeddings_chunks, dim=-1) + return embeddings_image + return None + + @property + def output_dim(self) -> int: + """Returns the dimension of the encoder output""" + return self.config.latent_dim + + class MLP(nn.Module): def __init__( self, @@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, ensemble: List[CriticHead], output_normalization: nn.Module, init_final: Optional[float] = None, @@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module): actions = self.output_normalization(actions)["action"] actions = actions.to(device) - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, observation_features) inputs = torch.cat([obs_enc, actions], dim=-1) @@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module): class GraspCritic(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: nn.Module, input_dim: int, hidden_dims: list[int], output_dim: int = 3, @@ -699,19 +798,14 @@ class GraspCritic(nn.Module): device = get_device_from_parameters(self) # Move each tensor in observations to device by cloning first to avoid inplace operations observations = {k: v.to(device) for k, v in observations.items()} - # Encode observations if encoder exists - obs_enc = ( - observation_features.to(device) - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) return self.output_layer(self.net(obs_enc)) class Policy(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, network: nn.Module, action_dim: int, log_std_min: float = -5, @@ -722,7 +816,7 @@ class Policy(nn.Module): encoder_is_shared: bool = False, ): super().__init__() - self.encoder = encoder + self.encoder: SACObservationEncoder = encoder self.network = network self.action_dim = action_dim self.log_std_min = log_std_min @@ -765,11 +859,7 @@ class Policy(nn.Module): observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Encode observations if encoder exists - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) # Get network outputs outputs = self.network(obs_enc) @@ -813,96 +903,6 @@ class Policy(nn.Module): return observations -class SACObservationEncoder(nn.Module): - """Encode image and/or state vector observations.""" - - def __init__(self, config: SACConfig, input_normalizer: nn.Module): - """ - Creates encoders for pixel and/or state modalities. 
- """ - super().__init__() - self.config = config - self.input_normalization = input_normalizer - self.has_pretrained_vision_encoder = False - self.parameters_to_optimize = [] - - self.aggregation_size: int = 0 - if any("observation.image" in key for key in config.input_features): - self.camera_number = config.camera_number - - if self.config.vision_encoder_name is not None: - self.image_enc_layers = PretrainedImageEncoder(config) - self.has_pretrained_vision_encoder = True - else: - self.image_enc_layers = DefaultImageEncoder(config) - - self.aggregation_size += config.latent_dim * self.camera_number - - if config.freeze_vision_encoder: - freeze_image_encoder(self.image_enc_layers) - else: - self.parameters_to_optimize += list(self.image_enc_layers.parameters()) - self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] - - if "observation.state" in config.input_features: - self.state_enc_layers = nn.Sequential( - nn.Linear( - in_features=config.input_features["observation.state"].shape[0], - out_features=config.latent_dim, - ), - nn.LayerNorm(normalized_shape=config.latent_dim), - nn.Tanh(), - ) - self.aggregation_size += config.latent_dim - - self.parameters_to_optimize += list(self.state_enc_layers.parameters()) - - if "observation.environment_state" in config.input_features: - self.env_state_enc_layers = nn.Sequential( - nn.Linear( - in_features=config.input_features["observation.environment_state"].shape[0], - out_features=config.latent_dim, - ), - nn.LayerNorm(normalized_shape=config.latent_dim), - nn.Tanh(), - ) - self.aggregation_size += config.latent_dim - self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) - - self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) - self.parameters_to_optimize += list(self.aggregation_layer.parameters()) - - def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: - """Encode the image and/or state vector. - - Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken - over all features. - """ - feat = [] - obs_dict = self.input_normalization(obs_dict) - # Batch all images along the batch dimension, then encode them. 
- if len(self.all_image_keys) > 0: - images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0) - images_batched = self.image_enc_layers(images_batched) - embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) - feat.extend(embeddings_chunks) - - if "observation.environment_state" in self.config.input_features: - feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) - if "observation.state" in self.config.input_features: - feat.append(self.state_enc_layers(obs_dict["observation.state"])) - - features = torch.cat(tensors=feat, dim=-1) - features = self.aggregation_layer(features) - - return features - - @property - def output_dim(self) -> int: - """Returns the dimension of the encoder output""" - return self.config.latent_dim - - class DefaultImageEncoder(nn.Module): def __init__(self, config: SACConfig): super().__init__() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 65b1d9b8..37586fe9 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): params=policy.actor.parameters_to_optimize, lr=cfg.policy.actor_lr, ) - optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + optimizer_critic = torch.optim.Adam( + params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr + ) if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( @@ -1024,12 +1026,8 @@ def get_observation_features( return None, None with torch.no_grad(): - observation_features = ( - policy.actor.encoder(observations) if policy.actor.encoder is not None else None - ) - next_observation_features = ( - policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None - ) + observation_features = policy.actor.encoder.get_image_features(observations) + next_observation_features = policy.actor.encoder.get_image_features(next_observations) return observation_features, next_observation_features From 4621f4e4f37faf8ce94b467132a23da7b4a5c839 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Mon, 7 Apr 2025 08:23:49 +0000 Subject: [PATCH 18/28] Handle gripper penalty --- lerobot/common/policies/sac/modeling_sac.py | 11 +- lerobot/scripts/server/buffer.py | 167 ++++++++++++++++---- lerobot/scripts/server/learner_server.py | 2 +- 3 files changed, 147 insertions(+), 33 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b5bfb36e..e3d3765e 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -288,6 +288,7 @@ class SACPolicy( next_observations: dict[str, Tensor] = batch["next_state"] done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") loss_grasp_critic = self.compute_loss_grasp_critic( observations=observations, actions=actions, @@ -296,6 +297,7 @@ class SACPolicy( done=done, observation_features=observation_features, next_observation_features=next_observation_features, + complementary_info=complementary_info, ) return {"loss_grasp_critic": loss_grasp_critic} if model == "actor": @@ -413,6 +415,7 @@ class SACPolicy( done, observation_features=None, next_observation_features=None, + 
complementary_info=None, ): # NOTE: We only want to keep the discrete action part # In the buffer we have the full action space (continuous + discrete) @@ -420,6 +423,9 @@ class SACPolicy( actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() actions_discrete = actions_discrete.long() + if complementary_info is not None: + gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty") + with torch.no_grad(): # For DQN, select actions using online network, evaluate with target network next_grasp_qs = self.grasp_critic_forward( @@ -440,7 +446,10 @@ class SACPolicy( ).squeeze(-1) # Compute target Q-value with Bellman equation - target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q + rewards_gripper = rewards + if gripper_penalties is not None: + rewards_gripper = rewards - gripper_penalties + target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q # Get predicted Q-values for current observations predicted_grasp_qs = self.grasp_critic_forward( diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 8947f6d9..92ad7dc7 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -32,7 +32,7 @@ class Transition(TypedDict): next_state: dict[str, torch.Tensor] done: bool truncated: bool - complementary_info: dict[str, Any] = None + complementary_info: dict[str, torch.Tensor | float | int] | None = None class BatchTransition(TypedDict): @@ -42,41 +42,43 @@ class BatchTransition(TypedDict): next_state: dict[str, torch.Tensor] done: torch.Tensor truncated: torch.Tensor + complementary_info: dict[str, torch.Tensor | float | int] | None = None def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: - # Move state tensors to CPU device = torch.device(device) + non_blocking = device.type == "cuda" + + # Move state tensors to device transition["state"] = { - key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items() + key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items() } - # Move action to CPU - transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda") + # Move action to device + transition["action"] = transition["action"].to(device, non_blocking=non_blocking) - # No need to move reward or done, as they are float and bool - - # No need to move reward or done, as they are float and bool + # Move reward and done if they are tensors if isinstance(transition["reward"], torch.Tensor): - transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda") + transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) if isinstance(transition["done"], torch.Tensor): - transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda") + transition["done"] = transition["done"].to(device, non_blocking=non_blocking) if isinstance(transition["truncated"], torch.Tensor): - transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda") + transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) - # Move next_state tensors to CPU + # Move next_state tensors to device transition["next_state"] = { - key: val.to(device, non_blocking=device.type == "cuda") - for key, val in transition["next_state"].items() + key: val.to(device, non_blocking=non_blocking) for key, val in 
transition["next_state"].items() } - # If complementary_info is present, move its tensors to CPU - # if transition["complementary_info"] is not None: - # transition["complementary_info"] = { - # key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items() - # } + # Move complementary_info tensors if present + if transition.get("complementary_info") is not None: + transition["complementary_info"] = { + key: val.to(device, non_blocking=non_blocking) + for key, val in transition["complementary_info"].items() + } + return transition @@ -216,7 +218,12 @@ class ReplayBuffer: self.image_augmentation_function = torch.compile(base_function) self.use_drq = use_drq - def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor): + def _initialize_storage( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + complementary_info: Optional[dict[str, torch.Tensor]] = None, + ): """Initialize the storage tensors based on the first transition.""" # Determine shapes from the first transition state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} @@ -244,6 +251,26 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + # Initialize storage for complementary_info + self.has_complementary_info = complementary_info is not None + self.complementary_info_keys = [] + self.complementary_info = {} + + if self.has_complementary_info: + self.complementary_info_keys = list(complementary_info.keys()) + # Pre-allocate tensors for each key in complementary_info + for key, value in complementary_info.items(): + if isinstance(value, torch.Tensor): + value_shape = value.squeeze(0).shape + self.complementary_info[key] = torch.empty( + (self.capacity, *value_shape), device=self.storage_device + ) + elif isinstance(value, (int, float)): + # Handle scalar values similar to reward + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + else: + raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") + self.initialized = True def __len__(self): @@ -262,7 +289,7 @@ class ReplayBuffer: """Saves a transition, ensuring tensors are stored on the designated storage device.""" # Initialize storage if this is the first transition if not self.initialized: - self._initialize_storage(state=state, action=action) + self._initialize_storage(state=state, action=action, complementary_info=complementary_info) # Store the transition in pre-allocated tensors for key in self.states: @@ -277,6 +304,17 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated + # Handle complementary_info if provided and storage is initialized + if complementary_info is not None and self.has_complementary_info: + # Store the complementary_info + for key in self.complementary_info_keys: + if key in complementary_info: + value = complementary_info[key] + if isinstance(value, torch.Tensor): + self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) + elif isinstance(value, (int, float)): + self.complementary_info[key][self.position] = value + self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -335,6 +373,13 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() + # Sample 
complementary_info if available + batch_complementary_info = None + if self.has_complementary_info: + batch_complementary_info = {} + for key in self.complementary_info_keys: + batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) + return BatchTransition( state=batch_state, action=batch_actions, @@ -342,6 +387,7 @@ class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, + complementary_info=batch_complementary_info, ) def get_iterator( @@ -518,7 +564,19 @@ class ReplayBuffer: if action_delta is not None: first_action = first_action / action_delta - replay_buffer._initialize_storage(state=first_state, action=first_action) + # Get complementary info if available + first_complementary_info = None + if ( + "complementary_info" in first_transition + and first_transition["complementary_info"] is not None + ): + first_complementary_info = { + k: v.to(device) for k, v in first_transition["complementary_info"].items() + } + + replay_buffer._initialize_storage( + state=first_state, action=first_action, complementary_info=first_complementary_info + ) # Fill the buffer with all transitions for data in list_transition: @@ -546,6 +604,7 @@ class ReplayBuffer: next_state=data["next_state"], done=data["done"], truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset + complementary_info=data.get("complementary_info", None), ) return replay_buffer @@ -587,6 +646,13 @@ class ReplayBuffer: f_info = guess_feature_info(t=sample_val, name=key) features[key] = f_info + # Add complementary_info keys if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + sample_val = self.complementary_info[key][0] + f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") + features[f"complementary_info.{key}"] = f_info + # Create an empty LeRobotDataset lerobot_dataset = LeRobotDataset.create( repo_id=repo_id, @@ -620,6 +686,11 @@ class ReplayBuffer: frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + # Add complementary_info if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu() + # Add task field which is required by LeRobotDataset frame_dict["task"] = task_name @@ -686,6 +757,10 @@ class ReplayBuffer: sample = dataset[0] has_done_key = "next.done" in sample + # Check for complementary_info keys + complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")] + has_complementary_info = len(complementary_info_keys) > 0 + # If not, we need to infer it from episode boundaries if not has_done_key: print("'next.done' key not found in dataset. Inferring from episode boundaries...") @@ -735,6 +810,16 @@ class ReplayBuffer: next_state_data[key] = val.unsqueeze(0) # Add batch dimension next_state = next_state_data + # ----- 5) Complementary info (if available) ----- + complementary_info = None + if has_complementary_info: + complementary_info = {} + for key in complementary_info_keys: + # Strip the "complementary_info." 
prefix to get the actual key + clean_key = key[len("complementary_info.") :] + val = current_sample[key] + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + # ----- Construct the Transition ----- transition = Transition( state=current_state, @@ -743,6 +828,7 @@ class ReplayBuffer: next_state=next_state, done=done, truncated=truncated, + complementary_info=complementary_info, ) transitions.append(transition) @@ -775,32 +861,33 @@ def concatenate_batch_transitions( left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition ) -> BatchTransition: """NOTE: Be careful it change the left_batch_transitions in place""" + # Concatenate state fields left_batch_transitions["state"] = { key: torch.cat( - [ - left_batch_transitions["state"][key], - right_batch_transition["state"][key], - ], + [left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0, ) for key in left_batch_transitions["state"] } + + # Concatenate basic fields left_batch_transitions["action"] = torch.cat( [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 ) left_batch_transitions["reward"] = torch.cat( [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 ) + + # Concatenate next_state fields left_batch_transitions["next_state"] = { key: torch.cat( - [ - left_batch_transitions["next_state"][key], - right_batch_transition["next_state"][key], - ], + [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0, ) for key in left_batch_transitions["next_state"] } + + # Concatenate done and truncated fields left_batch_transitions["done"] = torch.cat( [left_batch_transitions["done"], right_batch_transition["done"]], dim=0 ) @@ -808,8 +895,26 @@ def concatenate_batch_transitions( [left_batch_transitions["truncated"], right_batch_transition["truncated"]], dim=0, ) + + # Handle complementary_info + left_info = left_batch_transitions.get("complementary_info") + right_info = right_batch_transition.get("complementary_info") + + # Only process if right_info exists + if right_info is not None: + # Initialize left complementary_info if needed + if left_info is None: + left_batch_transitions["complementary_info"] = right_info + else: + # Concatenate each field + for key in right_info: + if key in left_info: + left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0) + else: + left_info[key] = right_info[key] + return left_batch_transitions if __name__ == "__main__": - pass # All test code is currently commented out + pass diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 37586fe9..5489d6dc 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +# !/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. # All rights reserved. 
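Note on the gripper-penalty flow introduced in the patch above: the grasp critic treats the gripper as a small discrete-action Q-learning problem, selecting the next discrete gripper action with the online head, evaluating it with the target head, and folding the environment's gripper penalty (carried through the replay buffer's complementary_info) into the reward before the Bellman backup. The standalone sketch below mirrors that computation with toy tensors; it follows the sign convention of the later "fix sign issue" patch in this series (non-positive penalties added to the reward). The tiny linear Q-heads, the batch layout, and the discount of 0.99 are assumptions made for illustration only, not code from the repository.

    # Illustrative sketch only (not part of any patch): toy Double-DQN-style grasp-critic
    # target with the gripper penalty carried in complementary_info["gripper_penalty"].
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    obs_dim, n_grasp_actions, batch_size, discount = 8, 3, 4, 0.99
    q_online = nn.Linear(obs_dim, n_grasp_actions)   # stand-in for the online grasp critic
    q_target = nn.Linear(obs_dim, n_grasp_actions)   # stand-in for the target grasp critic
    q_target.load_state_dict(q_online.state_dict())

    # A sampled batch in the same spirit as BatchTransition; penalty values are <= 0.
    batch = {
        "state": torch.randn(batch_size, obs_dim),
        "action": torch.randint(0, n_grasp_actions, (batch_size,)),  # discrete gripper action
        "reward": torch.rand(batch_size),
        "next_state": torch.randn(batch_size, obs_dim),
        "done": torch.zeros(batch_size),
        "complementary_info": {"gripper_penalty": -0.1 * torch.randint(0, 2, (batch_size,))},
    }

    with torch.no_grad():
        # Double-DQN: pick the next action with the online head, evaluate it with the target head.
        best_next = q_online(batch["next_state"]).argmax(dim=-1)
        next_q = q_target(batch["next_state"]).gather(1, best_next.unsqueeze(-1)).squeeze(-1)
        # Fold the (non-positive) gripper penalty into the reward before the backup.
        reward_with_penalty = batch["reward"] + batch["complementary_info"]["gripper_penalty"]
        td_target = reward_with_penalty + (1.0 - batch["done"]) * discount * next_q

    predicted = q_online(batch["state"]).gather(1, batch["action"].unsqueeze(-1)).squeeze(-1)
    loss = F.mse_loss(predicted, td_target)
    loss.backward()  # gradients flow only through the online grasp head

In the actual patches the same quantities arrive via ReplayBuffer.sample(), which returns the penalty under the batch's "complementary_info" key alongside state, action, reward, next_state and done.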
From 6c103906532c542c2ba763b5300afd2abc25dda8 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Mon, 7 Apr 2025 14:48:42 +0000 Subject: [PATCH 19/28] Refactor complementary_info handling in ReplayBuffer --- lerobot/scripts/server/buffer.py | 132 ++++++++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 12 deletions(-) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 92ad7dc7..185412fd 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -74,11 +74,15 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr # Move complementary_info tensors if present if transition.get("complementary_info") is not None: - transition["complementary_info"] = { - key: val.to(device, non_blocking=non_blocking) - for key, val in transition["complementary_info"].items() - } - + for key, val in transition["complementary_info"].items(): + if isinstance(val, torch.Tensor): + transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) + elif isinstance(val, (int, float, bool)): + transition["complementary_info"][key] = torch.tensor( + val, device=device, non_blocking=non_blocking + ) + else: + raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") return transition @@ -650,6 +654,8 @@ class ReplayBuffer: if self.has_complementary_info: for key in self.complementary_info_keys: sample_val = self.complementary_info[key][0] + if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0: + sample_val = sample_val.unsqueeze(0) f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") features[f"complementary_info.{key}"] = f_info @@ -689,7 +695,15 @@ class ReplayBuffer: # Add complementary_info if available if self.has_complementary_info: for key in self.complementary_info_keys: - frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu() + val = self.complementary_info[key][actual_idx] + # Convert tensors to CPU + if isinstance(val, torch.Tensor): + if val.ndim == 0: + val = val.unsqueeze(0) + frame_dict[f"complementary_info.{key}"] = val.cpu() + # Non-tensor values can be used directly + else: + frame_dict[f"complementary_info.{key}"] = val # Add task field which is required by LeRobotDataset frame_dict["task"] = task_name @@ -758,7 +772,7 @@ class ReplayBuffer: has_done_key = "next.done" in sample # Check for complementary_info keys - complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")] + complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] has_complementary_info = len(complementary_info_keys) > 0 # If not, we need to infer it from episode boundaries @@ -818,7 +832,13 @@ class ReplayBuffer: # Strip the "complementary_info." 
prefix to get the actual key clean_key = key[len("complementary_info.") :] val = current_sample[key] - complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + # Handle tensor and non-tensor values differently + if isinstance(val, torch.Tensor): + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + else: + # TODO: (azouitine) Check if it's necessary to convert to tensor + # For non-tensor values, use directly + complementary_info[clean_key] = val # ----- Construct the Transition ----- transition = Transition( @@ -836,12 +856,13 @@ class ReplayBuffer: # Utility function to guess shapes/dtypes from a tensor -def guess_feature_info(t: torch.Tensor, name: str): +def guess_feature_info(t, name: str): """ - Return a dictionary with the 'dtype' and 'shape' for a given tensor or array. + Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. - Otherwise default to 'float32' for numeric. You can customize as needed. + Otherwise default to appropriate dtype for numeric. """ + shape = tuple(t.shape) # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image' if len(shape) == 3 and shape[0] in [1, 3]: @@ -917,4 +938,91 @@ def concatenate_batch_transitions( if __name__ == "__main__": - pass + + def test_load_dataset_with_complementary_info(): + """ + Test loading a dataset with complementary_info into a ReplayBuffer. + The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains + gripper_penalty values in complementary_info. + """ + import time + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + print("Loading dataset with complementary info...") + # Load a small subset of the dataset (first episode) + dataset = LeRobotDataset( + repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty", + ) + + print(f"Dataset loaded with {len(dataset)} frames") + print(f"Dataset features: {list(dataset.features.keys())}") + + # Check if dataset has complementary_info.gripper_penalty + sample = dataset[0] + complementary_info_keys = [key for key in sample if key.startswith("complementary_info")] + print(f"Complementary info keys: {complementary_info_keys}") + + if "complementary_info.gripper_penalty" in sample: + print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}") + + # Extract state keys for the buffer + state_keys = [] + for key in sample: + if key.startswith("observation"): + state_keys.append(key) + + print(f"Using state keys: {state_keys}") + + # Create a replay buffer from the dataset + start_time = time.time() + buffer = ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False + ) + load_time = time.time() - start_time + print(f"Loaded dataset into buffer in {load_time:.2f} seconds") + print(f"Buffer size: {len(buffer)}") + + # Check if complementary_info was transferred correctly + print("Sampling from buffer to check complementary_info...") + batch = buffer.sample(batch_size=4) + + if batch["complementary_info"] is not None: + print("Complementary info in batch:") + for key, value in batch["complementary_info"].items(): + print(f" {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}") + if key == "gripper_penalty": + print(f" Sample gripper_penalty values: {value[:5]}") + else: + print("No complementary_info found in batch") + + # Now convert the buffer back to a LeRobotDataset + 
print("\nConverting buffer back to LeRobotDataset...") + start_time = time.time() + new_dataset = buffer.to_lerobot_dataset( + repo_id="test_dataset_from_buffer", + fps=dataset.fps, + root="./test_dataset_from_buffer", + task_name="test_conversion", + ) + convert_time = time.time() - start_time + print(f"Converted buffer to dataset in {convert_time:.2f} seconds") + print(f"New dataset size: {len(new_dataset)} frames") + + # Check if complementary_info was preserved + new_sample = new_dataset[0] + new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")] + print(f"New dataset complementary info keys: {new_complementary_info_keys}") + + if "complementary_info.gripper_penalty" in new_sample: + print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}") + + # Compare original and new datasets + print("\nComparing original and new datasets:") + print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}") + print(f"Original features: {list(dataset.features.keys())}") + print(f"New features: {list(new_dataset.features.keys())}") + + return buffer, dataset, new_dataset + + # Run the test + test_load_dataset_with_complementary_info() From 632b2b46c1f7dca5338b87b1d20c3306c1719ff6 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Mon, 7 Apr 2025 15:44:06 +0000 Subject: [PATCH 20/28] fix sign issue --- lerobot/common/policies/sac/modeling_sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index e3d3765e..9b909813 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -448,7 +448,7 @@ class SACPolicy( # Compute target Q-value with Bellman equation rewards_gripper = rewards if gripper_penalties is not None: - rewards_gripper = rewards - gripper_penalties + rewards_gripper = rewards + gripper_penalties target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q # Get predicted Q-values for current observations From a7be613ee8329d3f3a2784d314ad599605b72c4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Apr 2025 15:48:39 +0000 Subject: [PATCH 21/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/scripts/server/buffer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 185412fd..8db1a82c 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -946,6 +946,7 @@ if __name__ == "__main__": gripper_penalty values in complementary_info. 
""" import time + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset print("Loading dataset with complementary info...") From a8135629b4a22141fcae2ea86c050b49e4374d19 Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Tue, 8 Apr 2025 08:50:02 +0000 Subject: [PATCH 22/28] Add rounding for safety --- lerobot/common/policies/sac/modeling_sac.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9b909813..e3d83d36 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -421,6 +421,7 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) actions_discrete = actions_discrete.long() if complementary_info is not None: From d948b95d22b6c9ae65a691c8bd16255a3929b03b Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Wed, 9 Apr 2025 13:20:51 +0000 Subject: [PATCH 23/28] fix caching and dataset stats is optional --- .../common/policies/sac/configuration_sac.py | 2 +- lerobot/common/policies/sac/modeling_sac.py | 32 +++++++++++-------- lerobot/scripts/server/learner_server.py | 4 +-- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 3d01f47c..684ac17f 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig): } ) - dataset_stats: dict[str, dict[str, list[float]]] = field( + dataset_stats: dict[str, dict[str, list[float]]] | None = field( default_factory=lambda: { "observation.image": { "mean": [0.485, 0.456, 0.406], diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index e3d83d36..34707dc4 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -63,16 +63,21 @@ class SACPolicy( else: self.normalize_inputs = nn.Identity() - output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) - # HACK: This is hacky and should be removed - dataset_stats = dataset_stats or output_normalization_params - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) + if config.dataset_stats is not None: + output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) + + # HACK: This is hacky and should be removed + dataset_stats = dataset_stats or output_normalization_params + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + else: + self.normalize_targets = nn.Identity() + self.unnormalize_outputs = nn.Identity() # NOTE: For images the encoder should be shared between the actor and critic if config.shared_encoder: @@ -188,7 +193,7 @@ class SACPolicy( # We cached the encoder output to avoid recomputing it observations_features = None if 
self.shared_encoder: - observations_features = self.actor.encoder.get_image_features(batch) + observations_features = self.actor.encoder.get_image_features(batch, normalize=True) actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -564,8 +569,7 @@ class SACObservationEncoder(nn.Module): feat = [] obs_dict = self.input_normalization(obs_dict) if len(self.all_image_keys) > 0 and vision_encoder_cache is None: - vision_encoder_cache = self.get_image_features(obs_dict) - feat.append(vision_encoder_cache) + vision_encoder_cache = self.get_image_features(obs_dict, normalize=False) if vision_encoder_cache is not None: feat.append(vision_encoder_cache) @@ -580,8 +584,10 @@ class SACObservationEncoder(nn.Module): return features - def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: + def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor: # [N*B, C, H, W] + if normalize: + batch = self.input_normalization(batch) if len(self.all_image_keys) > 0: # Batch all images along the batch dimension, then encode them. images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 5489d6dc..707547a1 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -1026,8 +1026,8 @@ def get_observation_features( return None, None with torch.no_grad(): - observation_features = policy.actor.encoder.get_image_features(observations) - next_observation_features = policy.actor.encoder.get_image_features(next_observations) + observation_features = policy.actor.encoder.get_image_features(observations, normalize=True) + next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True) return observation_features, next_observation_features From e7edf2a8d8d6ca58fbfc7c93c8ae95446e5f5866 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 13:51:31 +0000 Subject: [PATCH 24/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/policies/sac/modeling_sac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 34707dc4..b51f9b8f 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -63,7 +63,6 @@ class SACPolicy( else: self.normalize_inputs = nn.Identity() - if config.dataset_stats is not None: output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) From 5428ab96f5400ef298ac89daddb72e524f556fae Mon Sep 17 00:00:00 2001 From: Michel Aractingi <michel.aractingi@gmail.com> Date: Wed, 9 Apr 2025 17:04:43 +0200 Subject: [PATCH 25/28] General fixes in code, removed delta action, fixed grasp penalty, added logic to put gripper reward in info --- lerobot/common/envs/configs.py | 8 +- lerobot/common/policies/sac/modeling_sac.py | 1 + lerobot/common/robot_devices/control_utils.py | 3 - lerobot/scripts/control_robot.py | 2 - lerobot/scripts/server/buffer.py | 12 +- lerobot/scripts/server/gym_manipulator.py | 110 +++++++++++------- lerobot/scripts/server/learner_server.py | 4 +- 7 files changed, 75 insertions(+), 65 deletions(-) diff --git a/lerobot/common/envs/configs.py 
b/lerobot/common/envs/configs.py index a6eda93b..02911332 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -171,7 +171,6 @@ class VideoRecordConfig: class WrapperConfig: """Configuration for environment wrappers.""" - delta_action: float | None = None joint_masking_action_space: list[bool] | None = None @@ -191,7 +190,6 @@ class EnvWrapperConfig: """Configuration for environment wrappers.""" display_cameras: bool = False - delta_action: float = 0.1 use_relative_joint_positions: bool = True add_joint_velocity_to_observation: bool = False add_ee_pose_to_observation: bool = False @@ -203,11 +201,13 @@ class EnvWrapperConfig: joint_masking_action_space: Optional[Any] = None ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False - gripper_quantization_threshold: float = 0.8 - gripper_penalty: float = 0.0 + gripper_quantization_threshold: float | None = 0.8 + gripper_penalty: float = 0.0 + gripper_penalty_in_reward: bool = False open_gripper_on_reset: bool = False + @EnvConfig.register_subclass(name="gym_manipulator") @dataclass class HILSerlRobotEnvConfig(EnvConfig): diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b51f9b8f..b8827a1b 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -428,6 +428,7 @@ class SACPolicy( actions_discrete = torch.round(actions_discrete) actions_discrete = actions_discrete.long() + gripper_penalties: Tensor | None = None if complementary_info is not None: gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty") diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index c834e9e9..170a35de 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -221,7 +221,6 @@ def record_episode( events=events, policy=policy, fps=fps, - # record_delta_actions=record_delta_actions, teleoperate=policy is None, single_task=single_task, ) @@ -267,8 +266,6 @@ def control_loop( if teleoperate: observation, action = robot.teleop_step(record_data=True) - # if record_delta_actions: - # action["action"] = action["action"] - current_joint_positions else: observation = robot.capture_observation() diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 1013001a..658371a1 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -363,8 +363,6 @@ def replay( start_episode_t = time.perf_counter() action = actions[idx]["action"] - # if replay_delta_actions: - # action = action + current_joint_positions robot.send_action(action) dt_s = time.perf_counter() - start_episode_t diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 8db1a82c..92e03d33 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr if isinstance(val, torch.Tensor): transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) elif isinstance(val, (int, float, bool)): - transition["complementary_info"][key] = torch.tensor( - val, device=device, non_blocking=non_blocking - ) + transition["complementary_info"][key] = torch.tensor(val, device=device) else: raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") return transition @@ -505,7 +503,6 @@ class 
ReplayBuffer: state_keys: Optional[Sequence[str]] = None, capacity: Optional[int] = None, action_mask: Optional[Sequence[int]] = None, - action_delta: Optional[float] = None, image_augmentation_function: Optional[Callable] = None, use_drq: bool = True, storage_device: str = "cpu", @@ -520,7 +517,6 @@ class ReplayBuffer: state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`. capacity (Optional[int]): Buffer capacity. If None, uses dataset length. action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep. - action_delta (Optional[float]): Factor to divide actions by. image_augmentation_function (Optional[Callable]): Function for image augmentation. If None, uses default random shift with pad=4. use_drq (bool): Whether to use DrQ image augmentation when sampling. @@ -565,9 +561,6 @@ class ReplayBuffer: else: first_action = first_action[:, action_mask] - if action_delta is not None: - first_action = first_action / action_delta - # Get complementary info if available first_complementary_info = None if ( @@ -598,9 +591,6 @@ class ReplayBuffer: else: action = action[:, action_mask] - if action_delta is not None: - action = action / action_delta - replay_buffer.add( state=data["state"], action=action, diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 3aa75466..44bbcf9b 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env): self, robot, use_delta_action_space: bool = True, - delta: float | None = None, display_cameras: bool = False, ): """ @@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env): robot: The robot interface object used to connect and interact with the physical robot. use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute joint positions are used. - delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between - 0 and 1 when using a delta action space. display_cameras (bool): If True, the robot's camera feeds will be displayed during execution. 
""" super().__init__() @@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env): self.current_step = 0 self.episode_data = None - self.delta = delta self.use_delta_action_space = use_delta_action_space self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") @@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper): self.device = device def step(self, action): - observation, _, terminated, truncated, info = self.env.step(action) + observation, reward, terminated, truncated, info = self.env.step(action) images = [ observation[key].to(self.device, non_blocking=self.device.type == "cuda") for key in observation @@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper): ] start_time = time.perf_counter() with torch.inference_mode(): - reward = ( + success = ( self.reward_classifier.predict_reward(images, threshold=0.8) if self.reward_classifier is not None else 0.0 ) info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time) - if reward == 1.0: + if success == 1.0: terminated = True + reward = 1.0 + return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): @@ -720,19 +718,31 @@ class ResetWrapper(gym.Wrapper): env: HILSerlRobotEnv, reset_pose: np.ndarray | None = None, reset_time_s: float = 5, + open_gripper_on_reset: bool = False ): super().__init__(env) self.reset_time_s = reset_time_s self.reset_pose = reset_pose self.robot = self.unwrapped.robot + self.open_gripper_on_reset = open_gripper_on_reset def reset(self, *, seed=None, options=None): + + if self.reset_pose is not None: start_time = time.perf_counter() log_say("Reset the environment.", play_sounds=True) reset_follower_position(self.robot, self.reset_pose) busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) log_say("Reset the environment done.", play_sounds=True) + if self.open_gripper_on_reset: + current_joint_pos = self.robot.follower_arms["main"].read("Present_Position") + current_joint_pos[-1] = MAX_GRIPPER_COMMAND + self.robot.send_action(torch.from_numpy(current_joint_pos)) + busy_wait(0.1) + current_joint_pos[-1] = 0.0 + self.robot.send_action(torch.from_numpy(current_joint_pos)) + busy_wait(0.2) else: log_say( f"Manually reset the environment for {self.reset_time_s} seconds.", @@ -762,37 +772,48 @@ class BatchCompitableWrapper(gym.ObservationWrapper): class GripperPenaltyWrapper(gym.RewardWrapper): - def __init__(self, env, penalty: float = -0.1): + def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True): super().__init__(env) self.penalty = penalty + self.gripper_penalty_in_reward = gripper_penalty_in_reward self.last_gripper_state = None + def reward(self, reward, action): gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND - if isinstance(action, tuple): - action = action[0] - action_normalized = action[-1] / MAX_GRIPPER_COMMAND + action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND - gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( - gripper_state_normalized > 0.9 and action_normalized < 0.1 + gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and action_normalized < -0.5 ) - breakpoint() - return reward + self.penalty * gripper_penalty_bool + return reward + self.penalty * int(gripper_penalty_bool) def step(self, action): self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + if isinstance(action, tuple): + 
gripper_action = action[0][-1] + else: + gripper_action = action[-1] obs, reward, terminated, truncated, info = self.env.step(action) - reward = self.reward(reward, action) + gripper_penalty = self.reward(reward, gripper_action) + + if self.gripper_penalty_in_reward: + reward += gripper_penalty + else: + info["gripper_penalty"] = gripper_penalty + return obs, reward, terminated, truncated, info def reset(self, **kwargs): self.last_gripper_state = None - return super().reset(**kwargs) + obs, info = super().reset(**kwargs) + if self.gripper_penalty_in_reward: + info["gripper_penalty"] = 0.0 + return obs, info - -class GripperQuantizationWrapper(gym.ActionWrapper): +class GripperActionWrapper(gym.ActionWrapper): def __init__(self, env, quantization_threshold: float = 0.2): super().__init__(env) self.quantization_threshold = quantization_threshold @@ -801,16 +822,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper): is_intervention = False if isinstance(action, tuple): action, is_intervention = action - gripper_command = action[-1] - # Quantize gripper command to -1, 0 or 1 - if gripper_command < -self.quantization_threshold: - gripper_command = -MAX_GRIPPER_COMMAND - elif gripper_command > self.quantization_threshold: - gripper_command = MAX_GRIPPER_COMMAND - else: - gripper_command = 0.0 + # Gripper actions are between 0, 2 + # we want to quantize them to -1, 0 or 1 + gripper_command = gripper_command - 1.0 + + if self.quantization_threshold is not None: + # Quantize gripper command to -1, 0 or 1 + gripper_command = ( + np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 + ) + gripper_command = gripper_command * MAX_GRIPPER_COMMAND gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) action[-1] = gripper_action.item() @@ -836,10 +859,12 @@ class EEActionWrapper(gym.ActionWrapper): ] ) if self.use_gripper: - action_space_bounds = np.concatenate([action_space_bounds, [1.0]]) + # gripper actions open at 2.0, and closed at 0.0 + min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]]) + max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]]) ee_action_space = gym.spaces.Box( - low=-action_space_bounds, - high=action_space_bounds, + low=min_action_space_bounds, + high=max_action_space_bounds, shape=(3 + int(self.use_gripper),), dtype=np.float32, ) @@ -997,11 +1022,11 @@ class GamepadControlWrapper(gym.Wrapper): if self.use_gripper: gripper_command = self.controller.gripper_command() if gripper_command == "open": - gamepad_action = np.concatenate([gamepad_action, [1.0]]) + gamepad_action = np.concatenate([gamepad_action, [2.0]]) elif gripper_command == "close": - gamepad_action = np.concatenate([gamepad_action, [-1.0]]) - else: gamepad_action = np.concatenate([gamepad_action, [0.0]]) + else: + gamepad_action = np.concatenate([gamepad_action, [1.0]]) # Check episode ending buttons # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None @@ -1141,7 +1166,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env = HILSerlRobotEnv( robot=robot, display_cameras=cfg.wrapper.display_cameras, - delta=cfg.wrapper.delta_action, use_delta_action_space=cfg.wrapper.use_relative_joint_positions and cfg.wrapper.ee_action_space_params is None, ) @@ -1165,10 +1189,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: # env = RewardWrapper(env=env, reward_classifier=reward_classifier, 
device=cfg.device) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) if cfg.wrapper.use_gripper: - env = GripperQuantizationWrapper( + env = GripperActionWrapper( env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold ) - # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty) + if cfg.wrapper.gripper_penalty is not None: + env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward) if cfg.wrapper.ee_action_space_params is not None: env = EEActionWrapper( @@ -1176,6 +1201,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: ee_action_space_params=cfg.wrapper.ee_action_space_params, use_gripper=cfg.wrapper.use_gripper, ) + if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad: # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) env = GamepadControlWrapper( @@ -1192,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env=env, reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_time_s=cfg.wrapper.reset_time_s, + open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset ) if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) @@ -1341,11 +1368,10 @@ def record_dataset(env, policy, cfg): dataset.push_to_hub() -def replay_episode(env, repo_id, root=None, episode=0): +def replay_episode(env, cfg): from lerobot.common.datasets.lerobot_dataset import LeRobotDataset - local_files_only = root is not None - dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only) + dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) env.reset() actions = dataset.hf_dataset.select_columns("action") @@ -1353,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0): for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() - action = actions[idx]["action"][:4] + action = actions[idx]["action"] env.step((action, False)) # env.step((action / env.unwrapped.delta, False)) @@ -1384,9 +1410,7 @@ def main(cfg: EnvConfig): if cfg.mode == "replay": replay_episode( env, - cfg.replay_repo_id, - root=cfg.dataset_root, - episode=cfg.replay_episode, + cfg=cfg, ) exit() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 707547a1..e4bcc620 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -406,7 +406,8 @@ def add_actor_information_and_train( "next_state": next_observations, "done": done, "observation_feature": observation_features, - "next_observation_feature": next_observation_features, + "next_observation_feature": next_observation_features, + "complementary_info": batch["complementary_info"], } # Use the forward method for critic loss (includes both main critic and grasp critic) @@ -992,7 +993,6 @@ def initialize_offline_replay_buffer( device=device, state_keys=cfg.policy.input_features.keys(), action_mask=active_action_dims, - action_delta=cfg.env.wrapper.delta_action, storage_device=storage_device, optimize_memory=True, capacity=cfg.policy.offline_buffer_capacity, From ba09f44eb7f312c7f12287a4d56ae8d722633fbf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 
15:05:17 +0000 Subject: [PATCH 26/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/envs/configs.py | 3 +-- lerobot/scripts/server/gym_manipulator.py | 22 +++++++++++----------- lerobot/scripts/server/learner_server.py | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 02911332..77220c3c 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -202,12 +202,11 @@ class EnvWrapperConfig: ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False gripper_quantization_threshold: float | None = 0.8 - gripper_penalty: float = 0.0 + gripper_penalty: float = 0.0 gripper_penalty_in_reward: bool = False open_gripper_on_reset: bool = False - @EnvConfig.register_subclass(name="gym_manipulator") @dataclass class HILSerlRobotEnvConfig(EnvConfig): diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 44bbcf9b..6a8e848a 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -718,7 +718,7 @@ class ResetWrapper(gym.Wrapper): env: HILSerlRobotEnv, reset_pose: np.ndarray | None = None, reset_time_s: float = 5, - open_gripper_on_reset: bool = False + open_gripper_on_reset: bool = False, ): super().__init__(env) self.reset_time_s = reset_time_s @@ -727,8 +727,6 @@ class ResetWrapper(gym.Wrapper): self.open_gripper_on_reset = open_gripper_on_reset def reset(self, *, seed=None, options=None): - - if self.reset_pose is not None: start_time = time.perf_counter() log_say("Reset the environment.", play_sounds=True) @@ -777,12 +775,11 @@ class GripperPenaltyWrapper(gym.RewardWrapper): self.penalty = penalty self.gripper_penalty_in_reward = gripper_penalty_in_reward self.last_gripper_state = None - def reward(self, reward, action): gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND - action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND + action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( gripper_state_normalized > 0.75 and action_normalized < -0.5 @@ -803,7 +800,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper): reward += gripper_penalty else: info["gripper_penalty"] = gripper_penalty - + return obs, reward, terminated, truncated, info def reset(self, **kwargs): @@ -813,6 +810,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper): info["gripper_penalty"] = 0.0 return obs, info + class GripperActionWrapper(gym.ActionWrapper): def __init__(self, env, quantization_threshold: float = 0.2): super().__init__(env) @@ -1189,11 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) if cfg.wrapper.use_gripper: - env = GripperActionWrapper( - env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold - ) + env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold) if cfg.wrapper.gripper_penalty is not None: - env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward) + env = GripperPenaltyWrapper( + env=env, + penalty=cfg.wrapper.gripper_penalty, + 
gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward, + ) if cfg.wrapper.ee_action_space_params is not None: env = EEActionWrapper( @@ -1218,7 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: env=env, reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_time_s=cfg.wrapper.reset_time_s, - open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset + open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset, ) if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index e4bcc620..5b39d0d3 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -406,7 +406,7 @@ def add_actor_information_and_train( "next_state": next_observations, "done": done, "observation_feature": observation_features, - "next_observation_feature": next_observation_features, + "next_observation_feature": next_observation_features, "complementary_info": batch["complementary_info"], } From 854bfb4ff8548c2276e70de0852acd3d620b948e Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Fri, 11 Apr 2025 11:50:46 +0000 Subject: [PATCH 27/28] fix encoder training --- lerobot/common/policies/sac/modeling_sac.py | 32 +++++++++++++-------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b8827a1b..9ffdf154 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -525,9 +525,10 @@ class SACObservationEncoder(nn.Module): self.aggregation_size += config.latent_dim * self.camera_number if config.freeze_vision_encoder: - freeze_image_encoder(self.image_enc_layers) - else: - self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + freeze_image_encoder(self.image_enc_layers.image_enc_layers) + + self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize + self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] if "observation.state" in config.input_features: @@ -958,23 +959,25 @@ class DefaultImageEncoder(nn.Module): dummy_batch = torch.zeros(1, *config.input_features[image_key].shape) with torch.inference_mode(): self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:] - self.image_enc_layers.extend( - nn.Sequential( - nn.Flatten(), - nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), - nn.LayerNorm(config.latent_dim), - nn.Tanh(), - ) + self.image_enc_proj = nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), ) + self.parameters_to_optimize = [] + if not config.freeze_vision_encoder: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.parameters_to_optimize += list(self.image_enc_proj.parameters()) + def forward(self, x): - return self.image_enc_layers(x) + return self.image_enc_proj(self.image_enc_layers(x)) class PretrainedImageEncoder(nn.Module): def __init__(self, config: SACConfig): super().__init__() - self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) self.image_enc_proj = nn.Sequential( nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), @@ -982,6 +985,11 @@ class 
PretrainedImageEncoder(nn.Module): nn.Tanh(), ) + self.parameters_to_optimize = [] + if not config.freeze_vision_encoder: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.parameters_to_optimize += list(self.image_enc_proj.parameters()) + def _load_pretrained_vision_encoder(self, config: SACConfig): """Set up CNN encoder""" from transformers import AutoModel From 320a1a92a39e3cf91c5cd9b018d78119da326c0c Mon Sep 17 00:00:00 2001 From: AdilZouitine <adilzouitinegm@gmail.com> Date: Mon, 14 Apr 2025 14:00:57 +0000 Subject: [PATCH 28/28] Refactor modeling_sac and parameter handling for clarity and reusability. Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com> --- lerobot/common/policies/sac/modeling_sac.py | 62 ++++++++------------- lerobot/scripts/server/learner_server.py | 50 +++++++++++++++-- 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9ffdf154..05937240 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -167,8 +167,12 @@ class SACPolicy( def get_optim_params(self) -> dict: optim_params = { - "actor": self.actor.parameters_to_optimize, - "critic": self.critic_ensemble.parameters_to_optimize, + "actor": [ + p + for n, p in self.actor.named_parameters() + if not n.startswith("encoder") or not self.shared_encoder + ], + "critic": self.critic_ensemble.parameters(), "temperature": self.log_alpha, } if self.config.num_discrete_actions is not None: @@ -451,11 +455,11 @@ class SACPolicy( target_next_grasp_qs, dim=1, index=best_next_grasp_action ).squeeze(-1) - # Compute target Q-value with Bellman equation - rewards_gripper = rewards - if gripper_penalties is not None: - rewards_gripper = rewards + gripper_penalties - target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q + # Compute target Q-value with Bellman equation + rewards_gripper = rewards + if gripper_penalties is not None: + rewards_gripper = rewards + gripper_penalties + target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q # Get predicted Q-values for current observations predicted_grasp_qs = self.grasp_critic_forward( @@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module): self.config = config self.input_normalization = input_normalizer self.has_pretrained_vision_encoder = False - self.parameters_to_optimize = [] self.aggregation_size: int = 0 if any("observation.image" in key for key in config.input_features): @@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module): if config.freeze_vision_encoder: freeze_image_encoder(self.image_enc_layers.image_enc_layers) - self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize - self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] if "observation.state" in config.input_features: @@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module): ) self.aggregation_size += config.latent_dim - self.parameters_to_optimize += list(self.state_enc_layers.parameters()) - if "observation.environment_state" in config.input_features: self.env_state_enc_layers = nn.Sequential( nn.Linear( @@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module): nn.Tanh(), ) self.aggregation_size += config.latent_dim - self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, 
out_features=config.latent_dim) - self.parameters_to_optimize += list(self.aggregation_layer.parameters()) def forward( self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None @@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module): self.output_normalization = output_normalization self.critics = nn.ModuleList(ensemble) - self.parameters_to_optimize = [] - # Handle the case where a part of the encoder if frozen - if self.encoder is not None: - self.parameters_to_optimize += list(self.encoder.parameters_to_optimize) - self.parameters_to_optimize += list(self.critics.parameters()) - def forward( self, observations: dict[str, torch.Tensor], @@ -805,10 +796,6 @@ class GraspCritic(nn.Module): else: orthogonal_init()(self.output_layer.weight) - self.parameters_to_optimize = [] - self.parameters_to_optimize += list(self.net.parameters()) - self.parameters_to_optimize += list(self.output_layer.parameters()) - def forward( self, observations: torch.Tensor, observation_features: torch.Tensor | None = None ) -> torch.Tensor: @@ -840,12 +827,8 @@ class Policy(nn.Module): self.log_std_max = log_std_max self.fixed_std = fixed_std self.use_tanh_squash = use_tanh_squash - self.parameters_to_optimize = [] + self.encoder_is_shared = encoder_is_shared - self.parameters_to_optimize += list(self.network.parameters()) - - if self.encoder is not None and not encoder_is_shared: - self.parameters_to_optimize += list(self.encoder.parameters()) # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): @@ -859,7 +842,6 @@ class Policy(nn.Module): else: orthogonal_init()(self.mean_layer.weight) - self.parameters_to_optimize += list(self.mean_layer.parameters()) # Standard deviation layer or parameter if fixed_std is None: self.std_layer = nn.Linear(out_features, action_dim) @@ -868,7 +850,6 @@ class Policy(nn.Module): nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) - self.parameters_to_optimize += list(self.std_layer.parameters()) def forward( self, @@ -877,6 +858,8 @@ class Policy(nn.Module): ) -> Tuple[torch.Tensor, torch.Tensor]: # Encode observations if encoder exists obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) + if self.encoder_is_shared: + obs_enc = obs_enc.detach() # Get network outputs outputs = self.network(obs_enc) @@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module): nn.Tanh(), ) - self.parameters_to_optimize = [] - if not config.freeze_vision_encoder: - self.parameters_to_optimize += list(self.image_enc_layers.parameters()) - self.parameters_to_optimize += list(self.image_enc_proj.parameters()) + self.freeze_image_encoder = config.freeze_vision_encoder def forward(self, x): - return self.image_enc_proj(self.image_enc_layers(x)) + x = self.image_enc_layers(x) + if self.freeze_image_encoder: + x = x.detach() + return self.image_enc_proj(x) class PretrainedImageEncoder(nn.Module): @@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module): nn.Tanh(), ) - self.parameters_to_optimize = [] - if not config.freeze_vision_encoder: - self.parameters_to_optimize += list(self.image_enc_layers.parameters()) - self.parameters_to_optimize += list(self.image_enc_proj.parameters()) + self.freeze_image_encoder = config.freeze_vision_encoder def _load_pretrained_vision_encoder(self, config: SACConfig): """Set up CNN encoder""" @@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module): # TODO: (maractingi, azouitine) check the forward 
pass of the pretrained model # doesn't reach the classifier layer because we don't need it enc_feat = self.image_enc_layers(x).pooler_output + if self.freeze_image_encoder: + enc_feat = enc_feat.detach() enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1)) return enc_feat diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 5b39d0d3..a8a858bf 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -510,7 +510,7 @@ def add_actor_information_and_train( optimizers["actor"].zero_grad() loss_actor.backward() actor_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value + parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value ).item() optimizers["actor"].step() @@ -773,12 +773,14 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): """ optimizer_actor = torch.optim.Adam( # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor - params=policy.actor.parameters_to_optimize, + params=[ + p + for n, p in policy.actor.named_parameters() + if not n.startswith("encoder") or not policy.config.shared_encoder + ], lr=cfg.policy.actor_lr, ) - optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr - ) + optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( @@ -1089,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): parameters_queue.put(state_bytes) +def check_weight_gradients(module: nn.Module) -> dict[str, bool]: + """ + Checks whether each parameter in the module has a gradient. + + Args: + module (nn.Module): A PyTorch module whose parameters will be inspected. + + Returns: + dict[str, bool]: A dictionary where each key is the parameter name and the value is + True if the parameter has an associated gradient (i.e. .grad is not None), + otherwise False. + """ + grad_status = {} + for name, param in module.named_parameters(): + grad_status[name] = param.grad is not None + return grad_status + + +def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]: + """ + Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary. + + Args: + actor (nn.Module): The actor model. + grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate + whether each parameter has a gradient. + + Returns: + dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status. + """ + # Get actor parameter names as a set. + model_param_names = {name for name, _ in model.named_parameters()} + + # Intersect parameter names between actor and grad_status. + overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names} + return overlapping + + def process_interaction_message( message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None ):
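A short sketch of the parameter-handling pattern introduced above: when the encoder is shared with the critic, the actor optimizer simply skips every parameter whose name starts with "encoder", so the shared trunk is only updated through the critic loss (the actor additionally detaches the encoded features in its forward pass). Names below follow the diff; treat this as an illustration rather than the exact training code:

import torch
from torch import nn

def build_actor_optimizer(actor: nn.Module, lr: float, shared_encoder: bool) -> torch.optim.Adam:
    # keep a parameter unless it belongs to the encoder AND the encoder is shared
    params = [
        p
        for n, p in actor.named_parameters()
        if not n.startswith("encoder") or not shared_encoder
    ]
    return torch.optim.Adam(params=params, lr=lr)

After a backward pass, check_weight_gradients(actor) as defined above can be used to report which of these parameters actually received gradients.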