From 4a1c26d9ee7108d86e60feee82b3d20ced026642 Mon Sep 17 00:00:00 2001 From: s1lent4gnt Date: Mon, 31 Mar 2025 17:35:59 +0200 Subject: [PATCH 01/22] Add grasp critic - Implemented grasp critic to evaluate gripper actions - Added corresponding config parameters for tuning --- .../common/policies/sac/configuration_sac.py | 8 ++ lerobot/common/policies/sac/modeling_sac.py | 124 ++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 906a3bed..e47185da 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -42,6 +42,14 @@ class CriticNetworkConfig: final_activation: str | None = None +@dataclass +class GraspCriticNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + final_activation: str | None = None + output_dim: int = 3 + + @dataclass class ActorNetworkConfig: hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index f7866714..3589ad25 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -112,6 +112,26 @@ class SACPolicy( self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target) + + # Create grasp critic + self.grasp_critic = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + **config.grasp_critic_network_kwargs, + ) + + # Create target grasp critic + self.grasp_critic_target = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + **config.grasp_critic_network_kwargs, + ) + + self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) + + self.grasp_critic = torch.compile(self.grasp_critic) + self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), @@ -176,6 +196,21 @@ class SACPolicy( q_values = critics(observations, actions, observation_features) return q_values + def grasp_critic_forward(self, observations, use_target=False, observation_features=None): + """Forward pass through a grasp critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the grasp critic network + """ + grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic + q_values = grasp_critic(observations, observation_features) + return q_values + def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], @@ -246,6 +281,18 @@ class SACPolicy( + target_param.data * (1.0 - self.config.critic_target_update_weight) ) + def update_grasp_target_networks(self): + """Update grasp target networks with exponential moving average""" + for target_param, param in zip( + self.grasp_critic_target.parameters(), + self.grasp_critic.parameters(), + strict=False, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + def update_temperature(self): self.temperature = self.log_alpha.exp().item() @@ -307,6 +354,32 @@ class SACPolicy( ).sum() return critics_loss + def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None): + + batch_size = rewards.shape[0] + grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2] + + with torch.no_grad(): + next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) + best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) + + target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True) + target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action] + + # Get the grasp penalty from complementary_info + grasp_penalty = torch.zeros_like(rewards) + if complementary_info is not None and "grasp_penalty" in complementary_info: + grasp_penalty = complementary_info["grasp_penalty"] + + grasp_rewards = rewards + grasp_penalty + target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q + + predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False) + predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions] + + grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean") + return grasp_critic_loss + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: """Compute the temperature loss""" # calculate temperature loss @@ -509,6 +582,57 @@ class CriticEnsemble(nn.Module): return q_values +class GraspCritic(nn.Module): + def __init__( + self, + encoder: Optional[nn.Module], + network: nn.Module, + output_dim: int = 3, + init_final: Optional[float] = None, + encoder_is_shared: bool = False, + ): + super().__init__() + self.encoder = encoder + self.network = network + self.output_dim = output_dim + + # Find the last Linear layer's output dimension + for layer in reversed(network.net): + if isinstance(layer, nn.Linear): + out_features = layer.out_features + break + + self.parameters_to_optimize += list(self.network.parameters()) + + if self.encoder is not None and not encoder_is_shared: + self.parameters_to_optimize += list(self.encoder.parameters()) + + self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + self.parameters_to_optimize += list(self.output_layer.parameters()) + + def forward( + self, + observations: torch.Tensor, + observation_features: torch.Tensor | None = None + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device + observations = {k: v.to(device) for k, v in observations.items()} + # Encode observations if encoder exists + obs_enc = ( + observation_features + if observation_features is not None + else (observations if self.encoder is None else self.encoder(observations)) + ) + return self.output_layer(self.network(obs_enc)) + + class Policy(nn.Module): def __init__( self, From 007fee923071a860ff05bf1d7b536375ed6dea5f Mon Sep 17 00:00:00 2001 From: s1lent4gnt Date: Mon, 31 Mar 2025 17:36:35 +0200 Subject: [PATCH 02/22] Add complementary info in the replay buffer - Added complementary info in the add method - Added complementary info in the sample method --- lerobot/scripts/server/buffer.py | 37 ++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 776ad9ec..1fbc8803 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -244,6 +244,11 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + + # Initialize complementary_info storage + self.complementary_info_keys = [] + self.complementary_info_storage = {} + self.initialized = True def __len__(self): @@ -277,6 +282,30 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated + # Store complementary info if provided + if complementary_info is not None: + # Initialize storage for new keys on first encounter + for key, value in complementary_info.items(): + if key not in self.complementary_info_keys: + self.complementary_info_keys.append(key) + if isinstance(value, torch.Tensor): + shape = value.shape if value.ndim > 0 else (1,) + self.complementary_info_storage[key] = torch.zeros( + (self.capacity, *shape), + dtype=value.dtype, + device=self.storage_device + ) + + # Store the value + if key in self.complementary_info_storage: + if isinstance(value, torch.Tensor): + self.complementary_info_storage[key][self.position] = value + else: + # For non-tensor values (like grasp_penalty) + self.complementary_info_storage[key][self.position] = torch.tensor( + value, device=self.storage_device + ) + self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -335,6 +364,13 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() + # Add complementary_info to batch if it exists + batch_complementary_info = {} + if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys: + for key in self.complementary_info_keys: + if key in self.complementary_info_storage: + batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device) + return BatchTransition( state=batch_state, action=batch_actions, @@ -342,6 +378,7 @@ class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, + complementary_info=batch_complementary_info if batch_complementary_info else None, ) @classmethod From 7452f9baaa57d94b1b738ad94acdb23abe57c6d7 Mon Sep 17 00:00:00 2001 From: s1lent4gnt Date: Mon, 31 Mar 2025 17:38:16 +0200 Subject: [PATCH 03/22] Add gripper penalty wrapper --- lerobot/scripts/server/gym_manipulator.py | 24 +++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 92e8dcbc..26ed1991 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1069,6 +1069,29 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention +class GripperPenaltyWrapper(gym.Wrapper): + def __init__(self, env, penalty=-0.05): + super().__init__(env) + self.penalty = penalty + self.last_gripper_pos = None + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper + return obs, info + + def step(self, action): + observation, reward, terminated, truncated, info = self.env.step(action) + + if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9): + info["grasp_penalty"] = self.penalty + else: + info["grasp_penalty"] = 0.0 + + self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper + return observation, reward, terminated, truncated, info + + def make_robot_env(cfg) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. @@ -1144,6 +1167,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) + env= GripperPenaltyWrapper(env=env) return env From 2c1e5fa28b67e05e932ebba2e83ec75c73db3c34 Mon Sep 17 00:00:00 2001 From: s1lent4gnt Date: Mon, 31 Mar 2025 17:40:00 +0200 Subject: [PATCH 04/22] Add get_gripper_action method to GamepadController --- .../server/end_effector_control_utils.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 3bd927b4..f272426d 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -311,6 +311,31 @@ class GamepadController(InputController): except pygame.error: logging.error("Error reading gamepad. Is it still connected?") return 0.0, 0.0, 0.0 + + def get_gripper_action(self): + """ + Get gripper action using L3/R3 buttons. + Press left stick (L3) to open the gripper. + Press right stick (R3) to close the gripper. + """ + import pygame + + try: + # Check if buttons are pressed + l3_pressed = self.joystick.get_button(9) + r3_pressed = self.joystick.get_button(10) + + # Determine action based on button presses + if r3_pressed: + return 1.0 # Close gripper + elif l3_pressed: + return -1.0 # Open gripper + else: + return 0.0 # No change + + except pygame.error: + logging.error(f"Error reading gamepad. Is it still connected?") + return 0.0 class GamepadControllerHID(InputController): From c774bbe5222e376dd142b491fe9c206902a83a99 Mon Sep 17 00:00:00 2001 From: s1lent4gnt Date: Mon, 31 Mar 2025 18:06:21 +0200 Subject: [PATCH 05/22] Add grasp critic to the training loop - Integrated the grasp critic gradient update to the training loop in learner_server - Added Adam optimizer and configured grasp critic learning rate in configuration_sac - Added target critics networks update after the critics gradient step --- .../common/policies/sac/configuration_sac.py | 1 + lerobot/common/policies/sac/modeling_sac.py | 19 ++++++++-- lerobot/scripts/server/learner_server.py | 35 +++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index e47185da..b1ce30b6 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -167,6 +167,7 @@ class SACConfig(PreTrainedConfig): num_critics: int = 2 num_subsample_critics: int | None = None critic_lr: float = 3e-4 + grasp_critic_lr: float = 3e-4 actor_lr: float = 3e-4 temperature_lr: float = 3e-4 critic_target_update_weight: float = 0.005 diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 3589ad25..bd74c65b 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -214,7 +214,7 @@ class SACPolicy( def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature"] = "critic", + model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic", ) -> dict[str, Tensor]: """Compute the loss for the given model @@ -227,7 +227,7 @@ class SACPolicy( - done: Done mask tensor - observation_feature: Optional pre-computed observation features - next_observation_feature: Optional pre-computed next observation features - model: Which model to compute the loss for ("actor", "critic", or "temperature") + model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature") Returns: The computed loss tensor @@ -254,6 +254,21 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) + + if model == "grasp_critic": + # Extract grasp_critic-specific components + complementary_info: dict[str, Tensor] = batch["complementary_info"] + + return self.compute_loss_grasp_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + complementary_info=complementary_info, + ) if model == "actor": return self.compute_loss_actor( diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 98d2dbd8..f79e8d57 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -375,6 +375,7 @@ def add_actor_information_and_train( observations = batch["state"] next_observations = batch["next_state"] done = batch["done"] + complementary_info = batch["complementary_info"] check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( @@ -390,6 +391,7 @@ def add_actor_information_and_train( "done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, + "complementary_info": complementary_info, } # Use the forward method for critic loss @@ -404,7 +406,20 @@ def add_actor_information_and_train( optimizers["critic"].step() + # Add gripper critic optimization + loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + + # clip gradients + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ) + + optimizers["grasp_critic"].step() + policy.update_target_networks() + policy.update_grasp_target_networks() batch = replay_buffer.sample(batch_size=batch_size) @@ -435,6 +450,7 @@ def add_actor_information_and_train( "done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, + "complementary_info": complementary_info, } # Use the forward method for critic loss @@ -453,6 +469,22 @@ def add_actor_information_and_train( training_infos["loss_critic"] = loss_critic.item() training_infos["critic_grad_norm"] = critic_grad_norm + # Add gripper critic optimization + loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + + # clip gradients + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ) + + optimizers["grasp_critic"].step() + + # Add training info for the grasp critic + training_infos["loss_grasp_critic"] = loss_grasp_critic.item() + training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm + if optimization_step % policy_update_freq == 0: for _ in range(policy_update_freq): # Use the forward method for actor loss @@ -495,6 +527,7 @@ def add_actor_information_and_train( last_time_policy_pushed = time.time() policy.update_target_networks() + policy.update_grasp_target_networks() # Log training metrics at specified intervals if optimization_step % log_freq == 0: @@ -729,11 +762,13 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { "actor": optimizer_actor, "critic": optimizer_critic, + "grasp_critic": optimizer_grasp_critic, "temperature": optimizer_temperature, } return optimizers, lr_scheduler From 7983baf4fcbabe77e2d946c2fe3f441e6581ebfc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 16:10:00 +0000 Subject: [PATCH 06/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/policies/sac/modeling_sac.py | 23 ++++++++++++------- lerobot/scripts/server/buffer.py | 12 ++++------ .../server/end_effector_control_utils.py | 10 ++++---- lerobot/scripts/server/gym_manipulator.py | 10 ++++---- lerobot/scripts/server/learner_server.py | 6 +++-- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index bd74c65b..95ea3928 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -198,7 +198,7 @@ class SACPolicy( def grasp_critic_forward(self, observations, use_target=False, observation_features=None): """Forward pass through a grasp critic network - + Args: observations: Dictionary of observations use_target: If True, use target critics, otherwise use ensemble critics @@ -254,7 +254,7 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) - + if model == "grasp_critic": # Extract grasp_critic-specific components complementary_info: dict[str, Tensor] = batch["complementary_info"] @@ -307,7 +307,7 @@ class SACPolicy( param.data * self.config.critic_target_update_weight + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - + def update_temperature(self): self.temperature = self.log_alpha.exp().item() @@ -369,8 +369,17 @@ class SACPolicy( ).sum() return critics_loss - def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None): - + def compute_loss_grasp_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features=None, + next_observation_features=None, + complementary_info=None, + ): batch_size = rewards.shape[0] grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2] @@ -632,9 +641,7 @@ class GraspCritic(nn.Module): self.parameters_to_optimize += list(self.output_layer.parameters()) def forward( - self, - observations: torch.Tensor, - observation_features: torch.Tensor | None = None + self, observations: torch.Tensor, observation_features: torch.Tensor | None = None ) -> torch.Tensor: device = get_device_from_parameters(self) # Move each tensor in observations to device diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 1fbc8803..bf65b1ec 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -248,7 +248,7 @@ class ReplayBuffer: # Initialize complementary_info storage self.complementary_info_keys = [] self.complementary_info_storage = {} - + self.initialized = True def __len__(self): @@ -291,11 +291,9 @@ class ReplayBuffer: if isinstance(value, torch.Tensor): shape = value.shape if value.ndim > 0 else (1,) self.complementary_info_storage[key] = torch.zeros( - (self.capacity, *shape), - dtype=value.dtype, - device=self.storage_device + (self.capacity, *shape), dtype=value.dtype, device=self.storage_device ) - + # Store the value if key in self.complementary_info_storage: if isinstance(value, torch.Tensor): @@ -304,7 +302,7 @@ class ReplayBuffer: # For non-tensor values (like grasp_penalty) self.complementary_info_storage[key][self.position] = torch.tensor( value, device=self.storage_device - ) + ) self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -366,7 +364,7 @@ class ReplayBuffer: # Add complementary_info to batch if it exists batch_complementary_info = {} - if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys: + if hasattr(self, "complementary_info_keys") and self.complementary_info_keys: for key in self.complementary_info_keys: if key in self.complementary_info_storage: batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device) diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index f272426d..3b2cfa90 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -311,7 +311,7 @@ class GamepadController(InputController): except pygame.error: logging.error("Error reading gamepad. Is it still connected?") return 0.0, 0.0, 0.0 - + def get_gripper_action(self): """ Get gripper action using L3/R3 buttons. @@ -319,12 +319,12 @@ class GamepadController(InputController): Press right stick (R3) to close the gripper. """ import pygame - + try: # Check if buttons are pressed l3_pressed = self.joystick.get_button(9) r3_pressed = self.joystick.get_button(10) - + # Determine action based on button presses if r3_pressed: return 1.0 # Close gripper @@ -332,9 +332,9 @@ class GamepadController(InputController): return -1.0 # Open gripper else: return 0.0 # No change - + except pygame.error: - logging.error(f"Error reading gamepad. Is it still connected?") + logging.error("Error reading gamepad. Is it still connected?") return 0.0 diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 26ed1991..ac3bbb0a 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1077,18 +1077,20 @@ class GripperPenaltyWrapper(gym.Wrapper): def reset(self, **kwargs): obs, info = self.env.reset(**kwargs) - self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper + self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper return obs, info def step(self, action): observation, reward, terminated, truncated, info = self.env.step(action) - if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9): + if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or ( + action[-1] > 0.5 and self.last_gripper_pos < 0.9 + ): info["grasp_penalty"] = self.penalty else: info["grasp_penalty"] = 0.0 - self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper + self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper return observation, reward, terminated, truncated, info @@ -1167,7 +1169,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) - env= GripperPenaltyWrapper(env=env) + env = GripperPenaltyWrapper(env=env) return env diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index f79e8d57..0f760dc5 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -417,7 +417,7 @@ def add_actor_information_and_train( ) optimizers["grasp_critic"].step() - + policy.update_target_networks() policy.update_grasp_target_networks() @@ -762,7 +762,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr) + optimizer_grasp_critic = torch.optim.Adam( + params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr + ) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { From fe2ff516a8d354d33283163137a4a266145caa7d Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 1 Apr 2025 11:08:15 +0200 Subject: [PATCH 07/22] Added Gripper quantization wrapper and grasp penalty removed complementary info from buffer and learner server removed get_gripper_action function added gripper parameters to `common/envs/configs.py` --- lerobot/common/envs/configs.py | 3 + lerobot/scripts/server/buffer.py | 34 ------ .../server/end_effector_control_utils.py | 25 ----- lerobot/scripts/server/gym_manipulator.py | 100 +++++++++++------- lerobot/scripts/server/learner_server.py | 3 - 5 files changed, 66 insertions(+), 99 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 825fa162..440512c3 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -203,6 +203,9 @@ class EnvWrapperConfig: joint_masking_action_space: Optional[Any] = None ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False + gripper_quantization_threshold: float = 0.8 + gripper_penalty: float = 0.0 + open_gripper_on_reset: bool = False @EnvConfig.register_subclass(name="gym_manipulator") diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index bf65b1ec..2af3995e 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -245,10 +245,6 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) - # Initialize complementary_info storage - self.complementary_info_keys = [] - self.complementary_info_storage = {} - self.initialized = True def __len__(self): @@ -282,28 +278,6 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated - # Store complementary info if provided - if complementary_info is not None: - # Initialize storage for new keys on first encounter - for key, value in complementary_info.items(): - if key not in self.complementary_info_keys: - self.complementary_info_keys.append(key) - if isinstance(value, torch.Tensor): - shape = value.shape if value.ndim > 0 else (1,) - self.complementary_info_storage[key] = torch.zeros( - (self.capacity, *shape), dtype=value.dtype, device=self.storage_device - ) - - # Store the value - if key in self.complementary_info_storage: - if isinstance(value, torch.Tensor): - self.complementary_info_storage[key][self.position] = value - else: - # For non-tensor values (like grasp_penalty) - self.complementary_info_storage[key][self.position] = torch.tensor( - value, device=self.storage_device - ) - self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -362,13 +336,6 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() - # Add complementary_info to batch if it exists - batch_complementary_info = {} - if hasattr(self, "complementary_info_keys") and self.complementary_info_keys: - for key in self.complementary_info_keys: - if key in self.complementary_info_storage: - batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device) - return BatchTransition( state=batch_state, action=batch_actions, @@ -376,7 +343,6 @@ class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, - complementary_info=batch_complementary_info if batch_complementary_info else None, ) @classmethod diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 3b2cfa90..3bd927b4 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -312,31 +312,6 @@ class GamepadController(InputController): logging.error("Error reading gamepad. Is it still connected?") return 0.0, 0.0, 0.0 - def get_gripper_action(self): - """ - Get gripper action using L3/R3 buttons. - Press left stick (L3) to open the gripper. - Press right stick (R3) to close the gripper. - """ - import pygame - - try: - # Check if buttons are pressed - l3_pressed = self.joystick.get_button(9) - r3_pressed = self.joystick.get_button(10) - - # Determine action based on button presses - if r3_pressed: - return 1.0 # Close gripper - elif l3_pressed: - return -1.0 # Open gripper - else: - return 0.0 # No change - - except pygame.error: - logging.error("Error reading gamepad. Is it still connected?") - return 0.0 - class GamepadControllerHID(InputController): """Generate motion deltas from gamepad input using HIDAPI.""" diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index ac3bbb0a..3aa75466 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper): return observation +class GripperPenaltyWrapper(gym.RewardWrapper): + def __init__(self, env, penalty: float = -0.1): + super().__init__(env) + self.penalty = penalty + self.last_gripper_state = None + + def reward(self, reward, action): + gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND + + if isinstance(action, tuple): + action = action[0] + action_normalized = action[-1] / MAX_GRIPPER_COMMAND + + gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( + gripper_state_normalized > 0.9 and action_normalized < 0.1 + ) + breakpoint() + + return reward + self.penalty * gripper_penalty_bool + + def step(self, action): + self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + obs, reward, terminated, truncated, info = self.env.step(action) + reward = self.reward(reward, action) + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + self.last_gripper_state = None + return super().reset(**kwargs) + + +class GripperQuantizationWrapper(gym.ActionWrapper): + def __init__(self, env, quantization_threshold: float = 0.2): + super().__init__(env) + self.quantization_threshold = quantization_threshold + + def action(self, action): + is_intervention = False + if isinstance(action, tuple): + action, is_intervention = action + + gripper_command = action[-1] + # Quantize gripper command to -1, 0 or 1 + if gripper_command < -self.quantization_threshold: + gripper_command = -MAX_GRIPPER_COMMAND + elif gripper_command > self.quantization_threshold: + gripper_command = MAX_GRIPPER_COMMAND + else: + gripper_command = 0.0 + + gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) + action[-1] = gripper_action.item() + return action, is_intervention + + class EEActionWrapper(gym.ActionWrapper): def __init__(self, env, ee_action_space_params=None, use_gripper=False): super().__init__(env) @@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper): fk_func=self.fk_function, ) if self.use_gripper: - # Quantize gripper command to -1, 0 or 1 - if gripper_command < -0.2: - gripper_command = -1.0 - elif gripper_command > 0.2: - gripper_command = 1.0 - else: - gripper_command = 0.0 - - gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] - gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) - target_joint_pos[-1] = gripper_action + target_joint_pos[-1] = gripper_command return target_joint_pos, is_intervention @@ -1069,31 +1115,6 @@ class ActionScaleWrapper(gym.ActionWrapper): return action * self.scale_vector, is_intervention -class GripperPenaltyWrapper(gym.Wrapper): - def __init__(self, env, penalty=-0.05): - super().__init__(env) - self.penalty = penalty - self.last_gripper_pos = None - - def reset(self, **kwargs): - obs, info = self.env.reset(**kwargs) - self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper - return obs, info - - def step(self, action): - observation, reward, terminated, truncated, info = self.env.step(action) - - if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or ( - action[-1] > 0.5 and self.last_gripper_pos < 0.9 - ): - info["grasp_penalty"] = self.penalty - else: - info["grasp_penalty"] = 0.0 - - self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper - return observation, reward, terminated, truncated, info - - def make_robot_env(cfg) -> gym.vector.VectorEnv: """ Factory function to create a vectorized robot environment. @@ -1143,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: # Add reward computation and control wrappers # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + if cfg.wrapper.use_gripper: + env = GripperQuantizationWrapper( + env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold + ) + # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty) + if cfg.wrapper.ee_action_space_params is not None: env = EEActionWrapper( env=env, @@ -1169,7 +1196,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) - env = GripperPenaltyWrapper(env=env) return env diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 0f760dc5..15de2cb7 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -375,7 +375,6 @@ def add_actor_information_and_train( observations = batch["state"] next_observations = batch["next_state"] done = batch["done"] - complementary_info = batch["complementary_info"] check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( @@ -391,7 +390,6 @@ def add_actor_information_and_train( "done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, - "complementary_info": complementary_info, } # Use the forward method for critic loss @@ -450,7 +448,6 @@ def add_actor_information_and_train( "done": done, "observation_feature": observation_features, "next_observation_feature": next_observation_features, - "complementary_info": complementary_info, } # Use the forward method for critic loss From 6a215f47ddf7e3235736f60cefb4ffb42166406c Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 1 Apr 2025 09:30:32 +0000 Subject: [PATCH 08/22] Refactor SAC configuration and policy to support discrete actions - Removed GraspCriticNetworkConfig class and integrated its parameters into SACConfig. - Added num_discrete_actions parameter to SACConfig for better action handling. - Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions. - Enhanced grasp critic forward pass to handle discrete actions and compute losses accordingly. --- .../common/policies/sac/configuration_sac.py | 9 +- lerobot/common/policies/sac/modeling_sac.py | 103 +++++++++++------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index b1ce30b6..66d9aa45 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -42,12 +42,6 @@ class CriticNetworkConfig: final_activation: str | None = None -@dataclass -class GraspCriticNetworkConfig: - hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) - activate_final: bool = True - final_activation: str | None = None - output_dim: int = 3 @dataclass @@ -152,6 +146,7 @@ class SACConfig(PreTrainedConfig): freeze_vision_encoder: bool = True image_encoder_hidden_dim: int = 32 shared_encoder: bool = True + num_discrete_actions: int | None = None # Training parameter online_steps: int = 1000000 @@ -182,7 +177,7 @@ class SACConfig(PreTrainedConfig): critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) - + grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 95ea3928..dd156918 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,6 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters +DISCRETE_DIMENSION_INDEX = -1 class SACPolicy( PreTrainedPolicy, @@ -113,24 +114,30 @@ class SACPolicy( self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target) - # Create grasp critic - self.grasp_critic = GraspCritic( - encoder=encoder_critic, - input_dim=encoder_critic.output_dim, - **config.grasp_critic_network_kwargs, - ) + self.grasp_critic = None + self.grasp_critic_target = None - # Create target grasp critic - self.grasp_critic_target = GraspCritic( - encoder=encoder_critic, - input_dim=encoder_critic.output_dim, - **config.grasp_critic_network_kwargs, - ) + if config.num_discrete_actions is not None: + # Create grasp critic + self.grasp_critic = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + output_dim=config.num_discrete_actions, + **asdict(config.grasp_critic_network_kwargs), + ) - self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) + # Create target grasp critic + self.grasp_critic_target = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + output_dim=config.num_discrete_actions, + **asdict(config.grasp_critic_network_kwargs), + ) - self.grasp_critic = torch.compile(self.grasp_critic) - self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) + + self.grasp_critic = torch.compile(self.grasp_critic) + self.grasp_critic_target = torch.compile(self.grasp_critic_target) self.actor = Policy( encoder=encoder_actor, @@ -173,6 +180,12 @@ class SACPolicy( """Select action for inference/evaluation""" actions, _, _ = self.actor(batch) actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.num_discrete_actions is not None: + discrete_action_value = self.grasp_critic(batch) + discrete_action = torch.argmax(discrete_action_value, dim=-1) + actions = torch.cat([actions, discrete_action], dim=-1) + return actions def critic_forward( @@ -192,11 +205,12 @@ class SACPolicy( Returns: Tensor of Q-values from all critics """ + critics = self.critic_target if use_target else self.critic_ensemble q_values = critics(observations, actions, observation_features) return q_values - def grasp_critic_forward(self, observations, use_target=False, observation_features=None): + def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor: """Forward pass through a grasp critic network Args: @@ -256,9 +270,6 @@ class SACPolicy( ) if model == "grasp_critic": - # Extract grasp_critic-specific components - complementary_info: dict[str, Tensor] = batch["complementary_info"] - return self.compute_loss_grasp_critic( observations=observations, actions=actions, @@ -267,7 +278,6 @@ class SACPolicy( done=done, observation_features=observation_features, next_observation_features=next_observation_features, - complementary_info=complementary_info, ) if model == "actor": @@ -349,6 +359,12 @@ class SACPolicy( td_target = rewards + (1 - done) * self.config.discount * min_q # 3- compute predicted qs + if self.config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] + q_preds = self.critic_forward( observations=observations, actions=actions, @@ -378,30 +394,43 @@ class SACPolicy( done, observation_features=None, next_observation_features=None, - complementary_info=None, ): - batch_size = rewards.shape[0] - grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2) # Map [-1, 0, 1] -> [0, 1, 2] + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:] + actions = actions.long() with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) + + # Get target Q-values from target network + target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True) + + # Use gather to select Q-values for best actions + target_next_grasp_q = torch.gather( + target_next_grasp_qs, + dim=1, + index=best_next_grasp_action.unsqueeze(-1) + ).squeeze(-1) - target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True) - target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action] + # Compute target Q-value with Bellman equation + target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q - # Get the grasp penalty from complementary_info - grasp_penalty = torch.zeros_like(rewards) - if complementary_info is not None and "grasp_penalty" in complementary_info: - grasp_penalty = complementary_info["grasp_penalty"] + # Get predicted Q-values for current observations + predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False) + + # Use gather to select Q-values for taken actions + predicted_grasp_q = torch.gather( + predicted_grasp_qs, + dim=1, + index=actions.unsqueeze(-1) + ).squeeze(-1) - grasp_rewards = rewards + grasp_penalty - target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q - - predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False) - predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions] - - grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean") + # Compute MSE loss between predicted and target Q-values + grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) return grasp_critic_loss def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: @@ -611,7 +640,7 @@ class GraspCritic(nn.Module): self, encoder: Optional[nn.Module], network: nn.Module, - output_dim: int = 3, + output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that init_final: Optional[float] = None, encoder_is_shared: bool = False, ): From 306c735172a6ffd0d1f1fa0866f29abd8d2c597c Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 1 Apr 2025 11:42:28 +0000 Subject: [PATCH 09/22] Refactor SAC policy and training loop to enhance discrete action support - Updated SACPolicy to conditionally compute losses for grasp critic based on num_discrete_actions. - Simplified forward method to return loss outputs as a dictionary for better clarity. - Adjusted learner_server to handle both main and grasp critic losses during training. - Ensured optimizers are created conditionally for grasp critic based on configuration settings. --- .../common/policies/sac/configuration_sac.py | 2 +- lerobot/common/policies/sac/modeling_sac.py | 54 ++++---- lerobot/scripts/server/learner_server.py | 120 +++++++++--------- 3 files changed, 86 insertions(+), 90 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 66d9aa45..ae38b1c5 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -87,6 +87,7 @@ class SACConfig(PreTrainedConfig): freeze_vision_encoder: Whether to freeze the vision encoder during training. image_encoder_hidden_dim: Hidden dimension size for the image encoder. shared_encoder: Whether to use a shared encoder for actor and critic. + num_discrete_actions: Number of discrete actions, eg for gripper actions. concurrency: Configuration for concurrency settings. actor_learner: Configuration for actor-learner architecture. online_steps: Number of steps for online training. @@ -162,7 +163,6 @@ class SACConfig(PreTrainedConfig): num_critics: int = 2 num_subsample_critics: int | None = None critic_lr: float = 3e-4 - grasp_critic_lr: float = 3e-4 actor_lr: float = 3e-4 temperature_lr: float = 3e-4 critic_target_update_weight: float = 0.005 diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index dd156918..d0e8b25d 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -228,7 +228,7 @@ class SACPolicy( def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic", + model: Literal["actor", "critic", "temperature"] = "critic", ) -> dict[str, Tensor]: """Compute the loss for the given model @@ -246,7 +246,6 @@ class SACPolicy( Returns: The computed loss tensor """ - # TODO: (maractingi, azouitine) Respect the function signature we output tensors # Extract common components from batch actions: Tensor = batch["action"] observations: dict[str, Tensor] = batch["state"] @@ -259,7 +258,7 @@ class SACPolicy( done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") - return self.compute_loss_critic( + loss_critic = self.compute_loss_critic( observations=observations, actions=actions, rewards=rewards, @@ -268,29 +267,28 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) + if self.config.num_discrete_actions is not None: + loss_grasp_critic = self.compute_loss_grasp_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + ) + return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} - if model == "grasp_critic": - return self.compute_loss_grasp_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - observation_features=observation_features, - next_observation_features=next_observation_features, - ) if model == "actor": - return self.compute_loss_actor( + return {"loss_actor": self.compute_loss_actor( observations=observations, observation_features=observation_features, - ) + )} if model == "temperature": - return self.compute_loss_temperature( + return {"loss_temperature": self.compute_loss_temperature( observations=observations, observation_features=observation_features, - ) + )} raise ValueError(f"Unknown model type: {model}") @@ -305,18 +303,16 @@ class SACPolicy( param.data * self.config.critic_target_update_weight + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - - def update_grasp_target_networks(self): - """Update grasp target networks with exponential moving average""" - for target_param, param in zip( - self.grasp_critic_target.parameters(), - self.grasp_critic.parameters(), - strict=False, - ): - target_param.data.copy_( - param.data * self.config.critic_target_update_weight - + target_param.data * (1.0 - self.config.critic_target_update_weight) - ) + if self.config.num_discrete_actions is not None: + for target_param, param in zip( + self.grasp_critic_target.parameters(), + self.grasp_critic.parameters(), + strict=False, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) def update_temperature(self): self.temperature = self.log_alpha.exp().item() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 15de2cb7..627a1a17 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -392,32 +392,30 @@ def add_actor_information_and_train( "next_observation_feature": next_observation_features, } - # Use the forward method for critic loss - loss_critic = policy.forward(forward_batch, model="critic") + # Use the forward method for critic loss (includes both main critic and grasp critic) + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() loss_critic.backward() - - # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ) - optimizers["critic"].step() - # Add gripper critic optimization - loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") - optimizers["grasp_critic"].zero_grad() - loss_grasp_critic.backward() - - # clip gradients - grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value - ) - - optimizers["grasp_critic"].step() + # Grasp critic optimization (if available) + if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"): + loss_grasp_critic = critic_output["loss_grasp_critic"] + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["grasp_critic"].step() + # Update target networks policy.update_target_networks() - policy.update_grasp_target_networks() batch = replay_buffer.sample(batch_size=batch_size) @@ -450,81 +448,80 @@ def add_actor_information_and_train( "next_observation_feature": next_observation_features, } - # Use the forward method for critic loss - loss_critic = policy.forward(forward_batch, model="critic") + # Use the forward method for critic loss (includes both main critic and grasp critic) + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() loss_critic.backward() - - # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ).item() - optimizers["critic"].step() - training_infos = {} - training_infos["loss_critic"] = loss_critic.item() - training_infos["critic_grad_norm"] = critic_grad_norm + # Initialize training info dictionary + training_infos = { + "loss_critic": loss_critic.item(), + "critic_grad_norm": critic_grad_norm, + } - # Add gripper critic optimization - loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic") - optimizers["grasp_critic"].zero_grad() - loss_grasp_critic.backward() - - # clip gradients - grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value - ) - - optimizers["grasp_critic"].step() - - # Add training info for the grasp critic - training_infos["loss_grasp_critic"] = loss_grasp_critic.item() - training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm + # Grasp critic optimization (if available) + if "loss_grasp_critic" in critic_output: + loss_grasp_critic = critic_output["loss_grasp_critic"] + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["grasp_critic"].step() + + # Add grasp critic info to training info + training_infos["loss_grasp_critic"] = loss_grasp_critic.item() + training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm + # Actor and temperature optimization (at specified frequency) if optimization_step % policy_update_freq == 0: for _ in range(policy_update_freq): - # Use the forward method for actor loss - loss_actor = policy.forward(forward_batch, model="actor") - + # Actor optimization + actor_output = policy.forward(forward_batch, model="actor") + loss_actor = actor_output["loss_actor"] optimizers["actor"].zero_grad() loss_actor.backward() - - # clip gradients actor_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() - optimizers["actor"].step() - + + # Add actor info to training info training_infos["loss_actor"] = loss_actor.item() training_infos["actor_grad_norm"] = actor_grad_norm - # Temperature optimization using forward method - loss_temperature = policy.forward(forward_batch, model="temperature") + # Temperature optimization + temperature_output = policy.forward(forward_batch, model="temperature") + loss_temperature = temperature_output["loss_temperature"] optimizers["temperature"].zero_grad() loss_temperature.backward() - - # clip gradients temp_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=[policy.log_alpha], max_norm=clip_grad_norm_value ).item() - optimizers["temperature"].step() - + + # Add temperature info to training info training_infos["loss_temperature"] = loss_temperature.item() training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature"] = policy.temperature + # Update temperature policy.update_temperature() - # Check if it's time to push updated policy to actors + # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) last_time_policy_pushed = time.time() + # Update target networks policy.update_target_networks() - policy.update_grasp_target_networks() # Log training metrics at specified intervals if optimization_step % log_freq == 0: @@ -727,7 +724,7 @@ def save_training_checkpoint( logging.info("Resume training") -def make_optimizers_and_scheduler(cfg, policy: nn.Module): +def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): """ Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. @@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - optimizer_grasp_critic = torch.optim.Adam( - params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr - ) + + if cfg.policy.num_discrete_actions is not None: + optimizer_grasp_critic = torch.optim.Adam( + params=policy.grasp_critic.parameters(), lr=policy.critic_lr + ) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { "actor": optimizer_actor, "critic": optimizer_critic, - "grasp_critic": optimizer_grasp_critic, "temperature": optimizer_temperature, } + if cfg.policy.num_discrete_actions is not None: + optimizers["grasp_critic"] = optimizer_grasp_critic return optimizers, lr_scheduler From 451a7b01db13d3e7dc4a111a1f1513f1a1eba396 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 1 Apr 2025 14:22:08 +0000 Subject: [PATCH 10/22] Add mock gripper support and enhance SAC policy action handling - Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation. - Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions. - Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup. - Refactored action handling in the training loop to accommodate the changes in action dimensions. --- lerobot/common/envs/configs.py | 1 + lerobot/common/policies/sac/modeling_sac.py | 18 +++-- .../scripts/server/maniskill_manipulator.py | 71 ++++++++++++------- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 440512c3..a6eda93b 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig): robot: str = "so100" # This is a hack to make the robot config work video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) wrapper: WrapperConfig = field(default_factory=WrapperConfig) + mock_gripper: bool = False features: dict[str, PolicyFeature] = field( default_factory=lambda: { "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index d0e8b25d..0c3d76d2 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters -DISCRETE_DIMENSION_INDEX = -1 +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension class SACPolicy( PreTrainedPolicy, @@ -82,7 +82,7 @@ class SACPolicy( # Create a list of critic heads critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -97,7 +97,7 @@ class SACPolicy( # Create target critic heads as deepcopies of the original critic heads target_critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -117,7 +117,10 @@ class SACPolicy( self.grasp_critic = None self.grasp_critic_target = None + continuous_action_dim = config.output_features["action"].shape[0] if config.num_discrete_actions is not None: + + continuous_action_dim -= 1 # Create grasp critic self.grasp_critic = GraspCritic( encoder=encoder_critic, @@ -139,15 +142,16 @@ class SACPolicy( self.grasp_critic = torch.compile(self.grasp_critic) self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), - action_dim=config.output_features["action"].shape[0], + action_dim=continuous_action_dim, encoder_is_shared=config.shared_encoder, **asdict(config.policy_kwargs), ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2) + config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise @@ -275,7 +279,9 @@ class SACPolicy( next_observations=next_observations, done=done, ) - return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} + return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} + + return {"loss_critic": loss_critic} if model == "actor": diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index e10b8766..f4a89888 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from lerobot.common.envs.configs import ManiskillEnvConfig from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.modeling_sac import SACPolicy + def preprocess_maniskill_observation( @@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper): self.current_step = 0 return super().reset(seed=seed, options=options) +class ManiskillMockGripperWrapper(gym.Wrapper): + def __init__(self, env, nb_discrete_actions: int = 3): + super().__init__(env) + new_shape = env.action_space[0].shape[0] + 1 + new_low = np.concatenate([env.action_space[0].low, [0]]) + new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]]) + action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,)) + self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1])) + + def step(self, action): + action_agent, telop_action = action + real_action = action_agent[:-1] + final_action = (real_action, telop_action) + obs, reward, terminated, truncated, info = self.env.step(final_action) + return obs, reward, terminated, truncated, info def make_maniskill( cfg: ManiskillEnvConfig, @@ -197,40 +216,42 @@ def make_maniskill( env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control + if cfg.mock_gripper: + env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3) return env -@parser.wrap() -def main(cfg: ManiskillEnvConfig): - """Main function to run the ManiSkill environment.""" - # Create the ManiSkill environment - env = make_maniskill(cfg, n_envs=1) +# @parser.wrap() +# def main(cfg: TrainPipelineConfig): +# """Main function to run the ManiSkill environment.""" +# # Create the ManiSkill environment +# env = make_maniskill(cfg.env, n_envs=1) - # Reset the environment - obs, info = env.reset() +# # Reset the environment +# obs, info = env.reset() - # Run a simple interaction loop - sum_reward = 0 - for i in range(100): - # Sample a random action - action = env.action_space.sample() +# # Run a simple interaction loop +# sum_reward = 0 +# for i in range(100): +# # Sample a random action +# action = env.action_space.sample() - # Step the environment - start_time = time.perf_counter() - obs, reward, terminated, truncated, info = env.step(action) - step_time = time.perf_counter() - start_time - sum_reward += reward - # Log information +# # Step the environment +# start_time = time.perf_counter() +# obs, reward, terminated, truncated, info = env.step(action) +# step_time = time.perf_counter() - start_time +# sum_reward += reward +# # Log information - # Reset if episode terminated - if terminated or truncated: - logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") - sum_reward = 0 - obs, info = env.reset() +# # Reset if episode terminated +# if terminated or truncated: +# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") +# sum_reward = 0 +# obs, info = env.reset() - # Close the environment - env.close() +# # Close the environment +# env.close() # if __name__ == "__main__": From 699d374d895f8facdd2e5b66e379508e2466c97f Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 1 Apr 2025 15:43:29 +0000 Subject: [PATCH 11/22] Refactor SACPolicy for improved readability and action dimension handling - Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines. - Consolidated continuous action dimension calculation to enhance clarity and maintainability. - Simplified loss return statements in the forward method to improve code structure. - Ensured grasp critic parameters are included conditionally based on configuration settings. --- lerobot/common/policies/sac/modeling_sac.py | 59 +++++++++++---------- lerobot/scripts/server/learner_server.py | 14 ++--- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 0c3d76d2..41ff7d8c 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,7 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters -DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension + class SACPolicy( PreTrainedPolicy, @@ -50,6 +51,10 @@ class SACPolicy( config.validate_features() self.config = config + continuous_action_dim = config.output_features["action"].shape[0] + if config.num_discrete_actions is not None: + continuous_action_dim -= 1 + if config.dataset_stats is not None: input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) self.normalize_inputs = Normalize( @@ -117,10 +122,7 @@ class SACPolicy( self.grasp_critic = None self.grasp_critic_target = None - continuous_action_dim = config.output_features["action"].shape[0] if config.num_discrete_actions is not None: - - continuous_action_dim -= 1 # Create grasp critic self.grasp_critic = GraspCritic( encoder=encoder_critic, @@ -142,7 +144,6 @@ class SACPolicy( self.grasp_critic = torch.compile(self.grasp_critic) self.grasp_critic_target = torch.compile(self.grasp_critic_target) - self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), @@ -162,11 +163,14 @@ class SACPolicy( self.temperature = self.log_alpha.exp().item() def get_optim_params(self) -> dict: - return { + optim_params = { "actor": self.actor.parameters_to_optimize, "critic": self.critic_ensemble.parameters_to_optimize, "temperature": self.log_alpha, } + if self.config.num_discrete_actions is not None: + optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize + return optim_params def reset(self): """Reset the policy""" @@ -262,7 +266,7 @@ class SACPolicy( done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") - loss_critic = self.compute_loss_critic( + loss_critic = self.compute_loss_critic( observations=observations, actions=actions, rewards=rewards, @@ -283,18 +287,21 @@ class SACPolicy( return {"loss_critic": loss_critic} - if model == "actor": - return {"loss_actor": self.compute_loss_actor( - observations=observations, - observation_features=observation_features, - )} + return { + "loss_actor": self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + } if model == "temperature": - return {"loss_temperature": self.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - )} + return { + "loss_temperature": self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + } raise ValueError(f"Unknown model type: {model}") @@ -366,7 +373,7 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - + q_preds = self.critic_forward( observations=observations, actions=actions, @@ -407,15 +414,13 @@ class SACPolicy( # For DQN, select actions using online network, evaluate with target network next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) - + # Get target Q-values from target network target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True) - + # Use gather to select Q-values for best actions target_next_grasp_q = torch.gather( - target_next_grasp_qs, - dim=1, - index=best_next_grasp_action.unsqueeze(-1) + target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1) ).squeeze(-1) # Compute target Q-value with Bellman equation @@ -423,13 +428,9 @@ class SACPolicy( # Get predicted Q-values for current observations predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False) - + # Use gather to select Q-values for taken actions - predicted_grasp_q = torch.gather( - predicted_grasp_qs, - dim=1, - index=actions.unsqueeze(-1) - ).squeeze(-1) + predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1) # Compute MSE loss between predicted and target Q-values grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) @@ -642,7 +643,7 @@ class GraspCritic(nn.Module): self, encoder: Optional[nn.Module], network: nn.Module, - output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that + output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that init_final: Optional[float] = None, encoder_is_shared: bool = False, ): diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 627a1a17..c57f83fc 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -394,7 +394,7 @@ def add_actor_information_and_train( # Use the forward method for critic loss (includes both main critic and grasp critic) critic_output = policy.forward(forward_batch, model="critic") - + # Main critic optimization loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() @@ -405,7 +405,7 @@ def add_actor_information_and_train( optimizers["critic"].step() # Grasp critic optimization (if available) - if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"): + if "loss_grasp_critic" in critic_output: loss_grasp_critic = critic_output["loss_grasp_critic"] optimizers["grasp_critic"].zero_grad() loss_grasp_critic.backward() @@ -450,7 +450,7 @@ def add_actor_information_and_train( # Use the forward method for critic loss (includes both main critic and grasp critic) critic_output = policy.forward(forward_batch, model="critic") - + # Main critic optimization loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() @@ -475,7 +475,7 @@ def add_actor_information_and_train( parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value ).item() optimizers["grasp_critic"].step() - + # Add grasp critic info to training info training_infos["loss_grasp_critic"] = loss_grasp_critic.item() training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm @@ -492,7 +492,7 @@ def add_actor_information_and_train( parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() optimizers["actor"].step() - + # Add actor info to training info training_infos["loss_actor"] = loss_actor.item() training_infos["actor_grad_norm"] = actor_grad_norm @@ -506,7 +506,7 @@ def add_actor_information_and_train( parameters=[policy.log_alpha], max_norm=clip_grad_norm_value ).item() optimizers["temperature"].step() - + # Add temperature info to training info training_infos["loss_temperature"] = loss_temperature.item() training_infos["temperature_grad_norm"] = temp_grad_norm @@ -756,7 +756,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): lr=cfg.policy.actor_lr, ) optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) - + if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( params=policy.grasp_critic.parameters(), lr=policy.critic_lr From 0ed7ff142cd45992469cf2aee7a4f36f8571bb92 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 2 Apr 2025 15:50:39 +0000 Subject: [PATCH 12/22] Enhance SACPolicy and learner server for improved grasp critic integration - Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions. - Refactored the forward method to handle grasp critic model selection and loss computation more clearly. - Adjusted learner server to utilize optimized parameters for grasp critic during training. - Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs. --- lerobot/common/policies/sac/modeling_sac.py | 95 +++++++++++-------- lerobot/scripts/server/learner_server.py | 16 ++-- .../scripts/server/maniskill_manipulator.py | 11 ++- 3 files changed, 72 insertions(+), 50 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 41ff7d8c..af624592 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -52,8 +52,6 @@ class SACPolicy( self.config = config continuous_action_dim = config.output_features["action"].shape[0] - if config.num_discrete_actions is not None: - continuous_action_dim -= 1 if config.dataset_stats is not None: input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) @@ -191,7 +189,7 @@ class SACPolicy( if self.config.num_discrete_actions is not None: discrete_action_value = self.grasp_critic(batch) - discrete_action = torch.argmax(discrete_action_value, dim=-1) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) actions = torch.cat([actions, discrete_action], dim=-1) return actions @@ -236,7 +234,7 @@ class SACPolicy( def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature"] = "critic", + model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic", ) -> dict[str, Tensor]: """Compute the loss for the given model @@ -275,18 +273,25 @@ class SACPolicy( observation_features=observation_features, next_observation_features=next_observation_features, ) - if self.config.num_discrete_actions is not None: - loss_grasp_critic = self.compute_loss_grasp_critic( - observations=observations, - actions=actions, - rewards=rewards, - next_observations=next_observations, - done=done, - ) - return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic} return {"loss_critic": loss_critic} + if model == "grasp_critic" and self.config.num_discrete_actions is not None: + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + loss_grasp_critic = self.compute_loss_grasp_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + return {"loss_grasp_critic": loss_grasp_critic} if model == "actor": return { "loss_actor": self.compute_loss_actor( @@ -373,7 +378,6 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] - q_preds = self.critic_forward( observations=observations, actions=actions, @@ -407,30 +411,38 @@ class SACPolicy( # NOTE: We only want to keep the discrete action part # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward - actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:] - actions = actions.long() + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = actions_discrete.long() with torch.no_grad(): # For DQN, select actions using online network, evaluate with target network - next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False) - best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1) + next_grasp_qs = self.grasp_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True) # Get target Q-values from target network - target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True) + target_next_grasp_qs = self.grasp_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) # Use gather to select Q-values for best actions target_next_grasp_q = torch.gather( - target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1) + target_next_grasp_qs, dim=1, index=best_next_grasp_action ).squeeze(-1) # Compute target Q-value with Bellman equation target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q # Get predicted Q-values for current observations - predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False) + predicted_grasp_qs = self.grasp_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) # Use gather to select Q-values for taken actions - predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1) + predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1) # Compute MSE loss between predicted and target Q-values grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) @@ -642,49 +654,52 @@ class GraspCritic(nn.Module): def __init__( self, encoder: Optional[nn.Module], - network: nn.Module, - output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that + input_dim: int, + hidden_dims: list[int], + output_dim: int = 3, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: Optional[float] = None, init_final: Optional[float] = None, - encoder_is_shared: bool = False, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, ): super().__init__() self.encoder = encoder - self.network = network self.output_dim = output_dim - # Find the last Linear layer's output dimension - for layer in reversed(network.net): - if isinstance(layer, nn.Linear): - out_features = layer.out_features - break + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) - self.parameters_to_optimize += list(self.network.parameters()) - - if self.encoder is not None and not encoder_is_shared: - self.parameters_to_optimize += list(self.encoder.parameters()) - - self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) if init_final is not None: nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final) else: orthogonal_init()(self.output_layer.weight) + self.parameters_to_optimize = [] + self.parameters_to_optimize += list(self.net.parameters()) self.parameters_to_optimize += list(self.output_layer.parameters()) def forward( self, observations: torch.Tensor, observation_features: torch.Tensor | None = None ) -> torch.Tensor: device = get_device_from_parameters(self) - # Move each tensor in observations to device + # Move each tensor in observations to device by cloning first to avoid inplace operations observations = {k: v.to(device) for k, v in observations.items()} # Encode observations if encoder exists obs_enc = ( - observation_features + observation_features.to(device) if observation_features is not None else (observations if self.encoder is None else self.encoder(observations)) ) - return self.output_layer(self.network(obs_enc)) + return self.output_layer(self.net(obs_enc)) class Policy(nn.Module): diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index c57f83fc..ce9a1b41 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -405,12 +405,13 @@ def add_actor_information_and_train( optimizers["critic"].step() # Grasp critic optimization (if available) - if "loss_grasp_critic" in critic_output: - loss_grasp_critic = critic_output["loss_grasp_critic"] + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="grasp_critic") + loss_grasp_critic = discrete_critic_output["loss_grasp_critic"] optimizers["grasp_critic"].zero_grad() loss_grasp_critic.backward() grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value ) optimizers["grasp_critic"].step() @@ -467,12 +468,13 @@ def add_actor_information_and_train( } # Grasp critic optimization (if available) - if "loss_grasp_critic" in critic_output: - loss_grasp_critic = critic_output["loss_grasp_critic"] + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="grasp_critic") + loss_grasp_critic = discrete_critic_output["loss_grasp_critic"] optimizers["grasp_critic"].zero_grad() loss_grasp_critic.backward() grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value + parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() optimizers["grasp_critic"].step() @@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( - params=policy.grasp_critic.parameters(), lr=policy.critic_lr + params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr ) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index f4a89888..b5c181c1 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -16,7 +16,6 @@ from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.sac.modeling_sac import SACPolicy - def preprocess_maniskill_observation( observations: dict[str, np.ndarray], ) -> dict[str, torch.Tensor]: @@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper): self.current_step = 0 return super().reset(seed=seed, options=options) + class ManiskillMockGripperWrapper(gym.Wrapper): def __init__(self, env, nb_discrete_actions: int = 3): super().__init__(env) @@ -166,11 +166,16 @@ class ManiskillMockGripperWrapper(gym.Wrapper): self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1])) def step(self, action): - action_agent, telop_action = action + if isinstance(action, tuple): + action_agent, telop_action = action + else: + telop_action = 0 + action_agent = action real_action = action_agent[:-1] final_action = (real_action, telop_action) obs, reward, terminated, truncated, info = self.env.step(final_action) - return obs, reward, terminated, truncated, info + return obs, reward, terminated, truncated, info + def make_maniskill( cfg: ManiskillEnvConfig, From 51f1625c2004cdb8adf50c5c862a1c778d746c97 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 3 Apr 2025 07:44:46 +0000 Subject: [PATCH 13/22] Enhance SACPolicy to support shared encoder and optimize action selection - Cached encoder output in select_action method to reduce redundant computations. - Updated action selection and grasp critic calls to utilize cached encoder features when available. --- lerobot/common/policies/sac/modeling_sac.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index af624592..2246bf8c 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -81,6 +81,7 @@ class SACPolicy( else: encoder_critic = SACObservationEncoder(config, self.normalize_inputs) encoder_actor = SACObservationEncoder(config, self.normalize_inputs) + self.shared_encoder = config.shared_encoder # Create a list of critic heads critic_heads = [ @@ -184,11 +185,16 @@ class SACPolicy( @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select action for inference/evaluation""" - actions, _, _ = self.actor(batch) + # We cached the encoder output to avoid recomputing it + observations_features = None + if self.shared_encoder and self.actor.encoder is not None: + observations_features = self.actor.encoder(batch) + + actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] if self.config.num_discrete_actions is not None: - discrete_action_value = self.grasp_critic(batch) + discrete_action_value = self.grasp_critic(batch, observations_features) discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) actions = torch.cat([actions, discrete_action], dim=-1) From 38a8dbd9c9a0fe5bf9fb0244be068869ea0bc1c1 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 3 Apr 2025 14:23:50 +0000 Subject: [PATCH 14/22] Enhance SAC configuration and replay buffer with asynchronous prefetching support - Added async_prefetch parameter to SACConfig for improved buffer management. - Implemented get_iterator method in ReplayBuffer to support asynchronous prefetching of batches. - Updated learner_server to utilize the new iterator for online and offline sampling, enhancing training efficiency. --- .../common/policies/sac/configuration_sac.py | 4 +- lerobot/scripts/server/buffer.py | 576 ++++-------------- lerobot/scripts/server/learner_server.py | 34 +- 3 files changed, 132 insertions(+), 482 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index ae38b1c5..3d01f47c 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -42,8 +42,6 @@ class CriticNetworkConfig: final_activation: str | None = None - - @dataclass class ActorNetworkConfig: hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) @@ -94,6 +92,7 @@ class SACConfig(PreTrainedConfig): online_env_seed: Seed for the online environment. online_buffer_capacity: Capacity of the online replay buffer. offline_buffer_capacity: Capacity of the offline replay buffer. + async_prefetch: Whether to use asynchronous prefetching for the buffers. online_step_before_learning: Number of steps before learning starts. policy_update_freq: Frequency of policy updates. discount: Discount factor for the SAC algorithm. @@ -154,6 +153,7 @@ class SACConfig(PreTrainedConfig): online_env_seed: int = 10000 online_buffer_capacity: int = 100000 offline_buffer_capacity: int = 100000 + async_prefetch: bool = False online_step_before_learning: int = 100 policy_update_freq: int = 1 diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 2af3995e..c8f85372 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -345,6 +345,109 @@ class ReplayBuffer: truncated=batch_truncateds, ) + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """ + Creates an infinite iterator that yields batches of transitions. + Will automatically restart when internal iterator is exhausted. + + Args: + batch_size (int): Size of batches to sample + async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True) + queue_size (int): Number of batches to prefetch (default: 2) + + Yields: + BatchTransition: Batched transitions + """ + while True: # Create an infinite loop + if async_prefetch: + # Get the standard iterator + iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size) + else: + iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size) + + # Yield all items from the iterator + try: + yield from iterator + except StopIteration: + # Just continue the outer loop to create a new iterator + pass + + def _get_async_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates an iterator that prefetches batches in a background thread. + + Args: + queue_size (int): Number of batches to prefetch (default: 2) + batch_size (int): Size of batches to sample (default: 128) + + Yields: + BatchTransition: Prefetched batch transitions + """ + import threading + import queue + + # Use thread-safe queue + data_queue = queue.Queue(maxsize=queue_size) + running = [True] # Use list to allow modification in nested function + + def prefetch_worker(): + while running[0]: + try: + # Sample data and add to queue + data = self.sample(batch_size) + data_queue.put(data, block=True, timeout=0.5) + except queue.Full: + continue + except Exception as e: + print(f"Prefetch error: {e}") + break + + # Start prefetching thread + thread = threading.Thread(target=prefetch_worker, daemon=True) + thread.start() + + try: + while running[0]: + try: + yield data_queue.get(block=True, timeout=0.5) + except queue.Empty: + if not thread.is_alive(): + break + finally: + # Clean up + running[0] = False + thread.join(timeout=1.0) + + def _get_naive_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates a simple non-threaded iterator that yields batches. + + Args: + batch_size (int): Size of batches to sample + queue_size (int): Number of initial batches to prefetch + + Yields: + BatchTransition: Batch transitions + """ + import collections + + queue = collections.deque() + + def enqueue(n): + for _ in range(n): + data = self.sample(batch_size) + queue.append(data) + + enqueue(queue_size) + while queue: + yield queue.popleft() + enqueue(1) + @classmethod def from_lerobot_dataset( cls, @@ -710,475 +813,4 @@ def concatenate_batch_transitions( if __name__ == "__main__": - from tempfile import TemporaryDirectory - - # ===== Test 1: Create and use a synthetic ReplayBuffer ===== - print("Testing synthetic ReplayBuffer...") - - # Create sample data dimensions - batch_size = 32 - state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)} - action_dim = (6,) - - # Create a buffer - buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=list(state_dims.keys()), - use_drq=True, - storage_device="cpu", - ) - - # Add some random transitions - for i in range(100): - # Create dummy transition data - state = { - "observation.image": torch.rand(1, 3, 84, 84), - "observation.state": torch.rand(1, 10), - } - action = torch.rand(1, 6) - reward = 0.5 - next_state = { - "observation.image": torch.rand(1, 3, 84, 84), - "observation.state": torch.rand(1, 10), - } - done = False if i < 99 else True - truncated = False - - buffer.add( - state=state, - action=action, - reward=reward, - next_state=next_state, - done=done, - truncated=truncated, - ) - - # Test sampling - batch = buffer.sample(batch_size) - print(f"Buffer size: {len(buffer)}") - print( - f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}" - ) - print(f"Sampled batch action shape: {batch['action'].shape}") - print(f"Sampled batch reward shape: {batch['reward'].shape}") - print(f"Sampled batch done shape: {batch['done'].shape}") - print(f"Sampled batch truncated shape: {batch['truncated'].shape}") - - # ===== Test for state-action-reward alignment ===== - print("\nTesting state-action-reward alignment...") - - # Create a buffer with controlled transitions where we know the relationships - aligned_buffer = ReplayBuffer( - capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu" - ) - - # Create transitions with known relationships - # - Each state has a unique signature value - # - Action is 2x the state signature - # - Reward is 3x the state signature - # - Next state is signature + 0.01 (unless at episode end) - for i in range(100): - # Create a state with a signature value that encodes the transition number - signature = float(i) / 100.0 - state = {"state_value": torch.tensor([[signature]]).float()} - - # Action is 2x the signature - action = torch.tensor([[2.0 * signature]]).float() - - # Reward is 3x the signature - reward = 3.0 * signature - - # Next state is signature + 0.01, unless end of episode - # End episode every 10 steps - is_end = (i + 1) % 10 == 0 - - if is_end: - # At episode boundaries, next_state repeats current state (as per your implementation) - next_state = {"state_value": torch.tensor([[signature]]).float()} - done = True - else: - # Within episodes, next_state has signature + 0.01 - next_signature = float(i + 1) / 100.0 - next_state = {"state_value": torch.tensor([[next_signature]]).float()} - done = False - - aligned_buffer.add(state, action, reward, next_state, done, False) - - # Sample from this buffer - aligned_batch = aligned_buffer.sample(50) - - # Verify alignments in sampled batch - correct_relationships = 0 - total_checks = 0 - - # For each transition in the batch - for i in range(50): - # Extract signature from state - state_sig = aligned_batch["state"]["state_value"][i].item() - - # Check action is 2x signature (within reasonable precision) - action_val = aligned_batch["action"][i].item() - action_check = abs(action_val - 2.0 * state_sig) < 1e-4 - - # Check reward is 3x signature (within reasonable precision) - reward_val = aligned_batch["reward"][i].item() - reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4 - - # Check next_state relationship matches our pattern - next_state_sig = aligned_batch["next_state"]["state_value"][i].item() - is_done = aligned_batch["done"][i].item() > 0.5 - - # Calculate expected next_state value based on done flag - if is_done: - # For episodes that end, next_state should equal state - next_state_check = abs(next_state_sig - state_sig) < 1e-4 - else: - # For continuing episodes, check if next_state is approximately state + 0.01 - # We need to be careful because we don't know the original index - # So we check if the increment is roughly 0.01 - next_state_check = ( - abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 - ) - - # Count correct relationships - if action_check: - correct_relationships += 1 - if reward_check: - correct_relationships += 1 - if next_state_check: - correct_relationships += 1 - - total_checks += 3 - - alignment_accuracy = 100.0 * correct_relationships / total_checks - print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%") - if alignment_accuracy > 99.0: - print("✅ All relationships verified! Buffer maintains correct temporal relationships.") - else: - print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.") - - # Print some debug information about failures - print("\nDebug information for failed checks:") - for i in range(5): # Print first 5 transitions for debugging - state_sig = aligned_batch["state"]["state_value"][i].item() - action_val = aligned_batch["action"][i].item() - reward_val = aligned_batch["reward"][i].item() - next_state_sig = aligned_batch["next_state"]["state_value"][i].item() - is_done = aligned_batch["done"][i].item() > 0.5 - - print(f"Transition {i}:") - print(f" State: {state_sig:.6f}") - print(f" Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})") - print(f" Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})") - print(f" Done: {is_done}") - print(f" Next state: {next_state_sig:.6f}") - - # Calculate expected next state - if is_done: - expected_next = state_sig - else: - # This approximation might not be perfect - state_idx = round(state_sig * 100) - expected_next = (state_idx + 1) / 100.0 - - print(f" Expected next state: {expected_next:.6f}") - print() - - # ===== Test 2: Convert to LeRobotDataset and back ===== - with TemporaryDirectory() as temp_dir: - print("\nTesting conversion to LeRobotDataset and back...") - # Convert buffer to dataset - repo_id = "test/replay_buffer_conversion" - # Create a subdirectory to avoid the "directory exists" error - dataset_dir = os.path.join(temp_dir, "dataset1") - dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir) - - print(f"Dataset created with {len(dataset)} frames") - print(f"Dataset features: {list(dataset.features.keys())}") - - # Check a random sample from the dataset - sample = dataset[0] - print( - f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}" - ) - - # Convert dataset back to buffer - reconverted_buffer = ReplayBuffer.from_lerobot_dataset( - dataset, state_keys=list(state_dims.keys()), device="cpu" - ) - - print(f"Reconverted buffer size: {len(reconverted_buffer)}") - - # Sample from the reconverted buffer - reconverted_batch = reconverted_buffer.sample(batch_size) - print( - f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}" - ) - - # Verify consistency before and after conversion - original_states = batch["state"]["observation.image"].mean().item() - reconverted_states = reconverted_batch["state"]["observation.image"].mean().item() - print(f"Original buffer state mean: {original_states:.4f}") - print(f"Reconverted buffer state mean: {reconverted_states:.4f}") - - if abs(original_states - reconverted_states) < 1.0: - print("Values are reasonably similar - conversion works as expected") - else: - print("WARNING: Significant difference between original and reconverted values") - - print("\nAll previous tests completed!") - - # ===== Test for memory optimization ===== - print("\n===== Testing Memory Optimization =====") - - # Create two buffers, one with memory optimization and one without - standard_buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=["observation.image", "observation.state"], - storage_device="cpu", - optimize_memory=False, - use_drq=True, - ) - - optimized_buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=["observation.image", "observation.state"], - storage_device="cpu", - optimize_memory=True, - use_drq=True, - ) - - # Generate sample data with larger state dimensions for better memory impact - print("Generating test data...") - num_episodes = 10 - steps_per_episode = 50 - total_steps = num_episodes * steps_per_episode - - for episode in range(num_episodes): - for step in range(steps_per_episode): - # Index in the overall sequence - i = episode * steps_per_episode + step - - # Create state with identifiable values - img = torch.ones((3, 84, 84)) * (i / total_steps) - state_vec = torch.ones((10,)) * (i / total_steps) - - state = { - "observation.image": img.unsqueeze(0), - "observation.state": state_vec.unsqueeze(0), - } - - # Create next state (i+1 or same as current if last in episode) - is_last_step = step == steps_per_episode - 1 - - if is_last_step: - # At episode end, next state = current state - next_img = img.clone() - next_state_vec = state_vec.clone() - done = True - truncated = False - else: - # Within episode, next state has incremented value - next_val = (i + 1) / total_steps - next_img = torch.ones((3, 84, 84)) * next_val - next_state_vec = torch.ones((10,)) * next_val - done = False - truncated = False - - next_state = { - "observation.image": next_img.unsqueeze(0), - "observation.state": next_state_vec.unsqueeze(0), - } - - # Action and reward - action = torch.tensor([[i / total_steps]]) - reward = float(i / total_steps) - - # Add to both buffers - standard_buffer.add(state, action, reward, next_state, done, truncated) - optimized_buffer.add(state, action, reward, next_state, done, truncated) - - # Verify episode boundaries with our simplified approach - print("\nVerifying simplified memory optimization...") - - # Test with a new buffer with a small sequence - test_buffer = ReplayBuffer( - capacity=20, - device="cpu", - state_keys=["value"], - storage_device="cpu", - optimize_memory=True, - use_drq=False, - ) - - # Add a simple sequence with known episode boundaries - for i in range(20): - val = float(i) - state = {"value": torch.tensor([[val]]).float()} - next_val = float(i + 1) if i % 5 != 4 else val # Episode ends every 5 steps - next_state = {"value": torch.tensor([[next_val]]).float()} - - # Set done=True at every 5th step - done = (i % 5) == 4 - action = torch.tensor([[0.0]]) - reward = 1.0 - truncated = False - - test_buffer.add(state, action, reward, next_state, done, truncated) - - # Get sequential batch for verification - sequential_batch_size = test_buffer.size - all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device) - - # Get state tensors - batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)} - - # Get next_state using memory-optimized approach (simply index+1) - next_indices = (all_indices + 1) % test_buffer.capacity - batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)} - - # Get other tensors - batch_dones = test_buffer.dones[all_indices].to(test_buffer.device) - - # Print sequential values - print("State, Next State, Done (Sequential values with simplified optimization):") - state_values = batch_state["value"].squeeze().tolist() - next_values = batch_next_state["value"].squeeze().tolist() - done_flags = batch_dones.tolist() - - # Print all values - for i in range(len(state_values)): - print(f" {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}") - - # Explain the memory optimization tradeoff - print("\nWith simplified memory optimization:") - print("- We always use the next state in the buffer (index+1) as next_state") - print("- For terminal states, this means using the first state of the next episode") - print("- This is a common tradeoff in RL implementations for memory efficiency") - print("- Since we track done flags, the algorithm can handle these transitions correctly") - - # Test random sampling - print("\nVerifying random sampling with simplified memory optimization...") - random_samples = test_buffer.sample(20) # Sample all transitions - - # Extract values - random_state_values = random_samples["state"]["value"].squeeze().tolist() - random_next_values = random_samples["next_state"]["value"].squeeze().tolist() - random_done_flags = random_samples["done"].bool().tolist() - - # Print a few samples - print("Random samples - State, Next State, Done (First 10):") - for i in range(10): - print(f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}") - - # Calculate memory savings - # Assume optimized_buffer and standard_buffer have already been initialized and filled - std_mem = ( - sum( - standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size() - for key in standard_buffer.states - ) - * 2 - ) - opt_mem = sum( - optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size() - for key in optimized_buffer.states - ) - - savings_percent = (std_mem - opt_mem) / std_mem * 100 - - print("\nMemory optimization result:") - print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB") - print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB") - print(f"- Memory savings for state tensors: {savings_percent:.1f}%") - - print("\nAll memory optimization tests completed!") - - # # ===== Test real dataset conversion ===== - # print("\n===== Testing Real LeRobotDataset Conversion =====") - # try: - # # Try to use a real dataset if available - # dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small" - # dataset = LeRobotDataset(repo_id=dataset_name) - - # # Print available keys to debug - # sample = dataset[0] - # print("Available keys in dataset:", list(sample.keys())) - - # # Check for required keys - # if "action" not in sample or "next.reward" not in sample: - # print("Dataset missing essential keys. Cannot convert.") - # raise ValueError("Missing required keys in dataset") - - # # Auto-detect appropriate state keys - # image_keys = [] - # state_keys = [] - # for k, v in sample.items(): - # # Skip metadata keys and action/reward keys - # if k in { - # "index", - # "episode_index", - # "frame_index", - # "timestamp", - # "task_index", - # "action", - # "next.reward", - # "next.done", - # }: - # continue - - # # Infer key type from tensor shape - # if isinstance(v, torch.Tensor): - # if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1): - # # Likely an image (channels, height, width) - # image_keys.append(k) - # else: - # # Likely state or other vector - # state_keys.append(k) - - # print(f"Detected image keys: {image_keys}") - # print(f"Detected state keys: {state_keys}") - - # if not image_keys and not state_keys: - # print("No usable keys found in dataset, skipping further tests") - # raise ValueError("No usable keys found in dataset") - - # # Test with standard and memory-optimized buffers - # for optimize_memory in [False, True]: - # buffer_type = "Standard" if not optimize_memory else "Memory-optimized" - # print(f"\nTesting {buffer_type} buffer with real dataset...") - - # # Convert to ReplayBuffer with detected keys - # replay_buffer = ReplayBuffer.from_lerobot_dataset( - # lerobot_dataset=dataset, - # state_keys=image_keys + state_keys, - # device="cpu", - # optimize_memory=optimize_memory, - # ) - # print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}") - - # # Test sampling - # real_batch = replay_buffer.sample(32) - # print(f"Sampled batch from real dataset ({buffer_type}), state shapes:") - # for key in real_batch["state"]: - # print(f" {key}: {real_batch['state'][key].shape}") - - # # Convert back to LeRobotDataset - # with TemporaryDirectory() as temp_dir: - # dataset_name = f"test/real_dataset_converted_{buffer_type}" - # replay_buffer_converted = replay_buffer.to_lerobot_dataset( - # repo_id=dataset_name, - # root=os.path.join(temp_dir, f"dataset_{buffer_type}"), - # ) - # print( - # f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames" - # ) - - # except Exception as e: - # print(f"Real dataset test failed: {e}") - # print("This is expected if running offline or if the dataset is not available.") - - # print("\nAll tests completed!") + pass # All test code is currently commented out diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index ce9a1b41..08baa6ea 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -269,6 +269,7 @@ def add_actor_information_and_train( policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps + async_prefetch = cfg.policy.async_prefetch # Initialize logging for multiprocessing if not use_threads(cfg): @@ -326,6 +327,9 @@ def add_actor_information_and_train( if cfg.dataset is not None: dataset_repo_id = cfg.dataset.repo_id + # Initialize iterators + online_iterator = None + offline_iterator = None # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER while True: # Exit the training loop if shutdown is requested @@ -359,16 +363,29 @@ def add_actor_information_and_train( if len(replay_buffer) < online_step_before_learning: continue + if online_iterator is None: + logging.debug("[LEARNER] Initializing online replay buffer iterator") + online_iterator = replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + if offline_replay_buffer is not None and offline_iterator is None: + logging.debug("[LEARNER] Initializing offline replay buffer iterator") + offline_iterator = offline_replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + logging.debug("[LEARNER] Starting optimization loop") time_for_one_optimization_step = time.time() for _ in range(utd_ratio - 1): - batch = replay_buffer.sample(batch_size=batch_size) + # Sample from the iterators + batch = next(online_iterator) - if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) actions = batch["action"] rewards = batch["reward"] @@ -418,10 +435,11 @@ def add_actor_information_and_train( # Update target networks policy.update_target_networks() - batch = replay_buffer.sample(batch_size=batch_size) + # Sample for the last update in the UTD ratio + batch = next(online_iterator) if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) + batch_offline = next(offline_iterator) batch = concatenate_batch_transitions( left_batch_transitions=batch, right_batch_transition=batch_offline ) From e86fe66dbd392fbf087e0ac74f8e23c1262c272b Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 3 Apr 2025 16:05:29 +0000 Subject: [PATCH 15/22] fix indentation issue --- lerobot/scripts/server/learner_server.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 08baa6ea..65b1d9b8 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -381,11 +381,11 @@ def add_actor_information_and_train( # Sample from the iterators batch = next(online_iterator) - if dataset_repo_id is not None: - batch_offline = next(offline_iterator) - batch = concatenate_batch_transitions( - left_batch_transitions=batch, right_batch_transition=batch_offline - ) + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) actions = batch["action"] rewards = batch["reward"] @@ -435,8 +435,8 @@ def add_actor_information_and_train( # Update target networks policy.update_target_networks() - # Sample for the last update in the UTD ratio - batch = next(online_iterator) + # Sample for the last update in the UTD ratio + batch = next(online_iterator) if dataset_repo_id is not None: batch_offline = next(offline_iterator) From 037ecae9e0b89d07dbc3da30e177cd5d7c27ed70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Apr 2025 07:59:22 +0000 Subject: [PATCH 16/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/scripts/server/buffer.py | 3 +-- lerobot/scripts/server/maniskill_manipulator.py | 6 ------ 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index c8f85372..8947f6d9 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -15,7 +15,6 @@ # limitations under the License. import functools import io -import os import pickle from typing import Any, Callable, Optional, Sequence, TypedDict @@ -388,8 +387,8 @@ class ReplayBuffer: Yields: BatchTransition: Prefetched batch transitions """ - import threading import queue + import threading # Use thread-safe queue data_queue = queue.Queue(maxsize=queue_size) diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index b5c181c1..03a7ec10 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -1,5 +1,3 @@ -import logging -import time from typing import Any import einops @@ -10,10 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from lerobot.common.envs.configs import ManiskillEnvConfig -from lerobot.configs import parser -from lerobot.configs.train import TrainPipelineConfig -from lerobot.common.policies.sac.configuration_sac import SACConfig -from lerobot.common.policies.sac.modeling_sac import SACPolicy def preprocess_maniskill_observation( From 7741526ce42f2017db33841895d36e2c95ac875c Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Fri, 4 Apr 2025 14:29:38 +0000 Subject: [PATCH 17/22] fix caching --- lerobot/common/policies/sac/modeling_sac.py | 224 ++++++++++---------- lerobot/scripts/server/learner_server.py | 12 +- 2 files changed, 117 insertions(+), 119 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 2246bf8c..b5bfb36e 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -187,8 +187,8 @@ class SACPolicy( """Select action for inference/evaluation""" # We cached the encoder output to avoid recomputing it observations_features = None - if self.shared_encoder and self.actor.encoder is not None: - observations_features = self.actor.encoder(batch) + if self.shared_encoder: + observations_features = self.actor.encoder.get_image_features(batch) actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -484,6 +484,109 @@ class SACPolicy( return actor_loss +class SACObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: SACConfig, input_normalizer: nn.Module): + """ + Creates encoders for pixel and/or state modalities. + """ + super().__init__() + self.config = config + self.input_normalization = input_normalizer + self.has_pretrained_vision_encoder = False + self.parameters_to_optimize = [] + + self.aggregation_size: int = 0 + if any("observation.image" in key for key in config.input_features): + self.camera_number = config.camera_number + + if self.config.vision_encoder_name is not None: + self.image_enc_layers = PretrainedImageEncoder(config) + self.has_pretrained_vision_encoder = True + else: + self.image_enc_layers = DefaultImageEncoder(config) + + self.aggregation_size += config.latent_dim * self.camera_number + + if config.freeze_vision_encoder: + freeze_image_encoder(self.image_enc_layers) + else: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] + + if "observation.state" in config.input_features: + self.state_enc_layers = nn.Sequential( + nn.Linear( + in_features=config.input_features["observation.state"].shape[0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), + nn.Tanh(), + ) + self.aggregation_size += config.latent_dim + + self.parameters_to_optimize += list(self.state_enc_layers.parameters()) + + if "observation.environment_state" in config.input_features: + self.env_state_enc_layers = nn.Sequential( + nn.Linear( + in_features=config.input_features["observation.environment_state"].shape[0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), + nn.Tanh(), + ) + self.aggregation_size += config.latent_dim + self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) + + self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) + self.parameters_to_optimize += list(self.aggregation_layer.parameters()) + + def forward( + self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None + ) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + obs_dict = self.input_normalization(obs_dict) + if len(self.all_image_keys) > 0 and vision_encoder_cache is None: + vision_encoder_cache = self.get_image_features(obs_dict) + feat.append(vision_encoder_cache) + + if vision_encoder_cache is not None: + feat.append(vision_encoder_cache) + + if "observation.environment_state" in self.config.input_features: + feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + if "observation.state" in self.config.input_features: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + + features = torch.cat(tensors=feat, dim=-1) + features = self.aggregation_layer(features) + + return features + + def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: + # [N*B, C, H, W] + if len(self.all_image_keys) > 0: + # Batch all images along the batch dimension, then encode them. + images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) + images_batched = self.image_enc_layers(images_batched) + embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) + embeddings_image = torch.cat(embeddings_chunks, dim=-1) + return embeddings_image + return None + + @property + def output_dim(self) -> int: + """Returns the dimension of the encoder output""" + return self.config.latent_dim + + class MLP(nn.Module): def __init__( self, @@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, ensemble: List[CriticHead], output_normalization: nn.Module, init_final: Optional[float] = None, @@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module): actions = self.output_normalization(actions)["action"] actions = actions.to(device) - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, observation_features) inputs = torch.cat([obs_enc, actions], dim=-1) @@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module): class GraspCritic(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: nn.Module, input_dim: int, hidden_dims: list[int], output_dim: int = 3, @@ -699,19 +798,14 @@ class GraspCritic(nn.Module): device = get_device_from_parameters(self) # Move each tensor in observations to device by cloning first to avoid inplace operations observations = {k: v.to(device) for k, v in observations.items()} - # Encode observations if encoder exists - obs_enc = ( - observation_features.to(device) - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) return self.output_layer(self.net(obs_enc)) class Policy(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, network: nn.Module, action_dim: int, log_std_min: float = -5, @@ -722,7 +816,7 @@ class Policy(nn.Module): encoder_is_shared: bool = False, ): super().__init__() - self.encoder = encoder + self.encoder: SACObservationEncoder = encoder self.network = network self.action_dim = action_dim self.log_std_min = log_std_min @@ -765,11 +859,7 @@ class Policy(nn.Module): observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Encode observations if encoder exists - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) # Get network outputs outputs = self.network(obs_enc) @@ -813,96 +903,6 @@ class Policy(nn.Module): return observations -class SACObservationEncoder(nn.Module): - """Encode image and/or state vector observations.""" - - def __init__(self, config: SACConfig, input_normalizer: nn.Module): - """ - Creates encoders for pixel and/or state modalities. - """ - super().__init__() - self.config = config - self.input_normalization = input_normalizer - self.has_pretrained_vision_encoder = False - self.parameters_to_optimize = [] - - self.aggregation_size: int = 0 - if any("observation.image" in key for key in config.input_features): - self.camera_number = config.camera_number - - if self.config.vision_encoder_name is not None: - self.image_enc_layers = PretrainedImageEncoder(config) - self.has_pretrained_vision_encoder = True - else: - self.image_enc_layers = DefaultImageEncoder(config) - - self.aggregation_size += config.latent_dim * self.camera_number - - if config.freeze_vision_encoder: - freeze_image_encoder(self.image_enc_layers) - else: - self.parameters_to_optimize += list(self.image_enc_layers.parameters()) - self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] - - if "observation.state" in config.input_features: - self.state_enc_layers = nn.Sequential( - nn.Linear( - in_features=config.input_features["observation.state"].shape[0], - out_features=config.latent_dim, - ), - nn.LayerNorm(normalized_shape=config.latent_dim), - nn.Tanh(), - ) - self.aggregation_size += config.latent_dim - - self.parameters_to_optimize += list(self.state_enc_layers.parameters()) - - if "observation.environment_state" in config.input_features: - self.env_state_enc_layers = nn.Sequential( - nn.Linear( - in_features=config.input_features["observation.environment_state"].shape[0], - out_features=config.latent_dim, - ), - nn.LayerNorm(normalized_shape=config.latent_dim), - nn.Tanh(), - ) - self.aggregation_size += config.latent_dim - self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) - - self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) - self.parameters_to_optimize += list(self.aggregation_layer.parameters()) - - def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: - """Encode the image and/or state vector. - - Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken - over all features. - """ - feat = [] - obs_dict = self.input_normalization(obs_dict) - # Batch all images along the batch dimension, then encode them. - if len(self.all_image_keys) > 0: - images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0) - images_batched = self.image_enc_layers(images_batched) - embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) - feat.extend(embeddings_chunks) - - if "observation.environment_state" in self.config.input_features: - feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) - if "observation.state" in self.config.input_features: - feat.append(self.state_enc_layers(obs_dict["observation.state"])) - - features = torch.cat(tensors=feat, dim=-1) - features = self.aggregation_layer(features) - - return features - - @property - def output_dim(self) -> int: - """Returns the dimension of the encoder output""" - return self.config.latent_dim - - class DefaultImageEncoder(nn.Module): def __init__(self, config: SACConfig): super().__init__() diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 65b1d9b8..37586fe9 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): params=policy.actor.parameters_to_optimize, lr=cfg.policy.actor_lr, ) - optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + optimizer_critic = torch.optim.Adam( + params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr + ) if cfg.policy.num_discrete_actions is not None: optimizer_grasp_critic = torch.optim.Adam( @@ -1024,12 +1026,8 @@ def get_observation_features( return None, None with torch.no_grad(): - observation_features = ( - policy.actor.encoder(observations) if policy.actor.encoder is not None else None - ) - next_observation_features = ( - policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None - ) + observation_features = policy.actor.encoder.get_image_features(observations) + next_observation_features = policy.actor.encoder.get_image_features(next_observations) return observation_features, next_observation_features From 4621f4e4f37faf8ce94b467132a23da7b4a5c839 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Mon, 7 Apr 2025 08:23:49 +0000 Subject: [PATCH 18/22] Handle gripper penalty --- lerobot/common/policies/sac/modeling_sac.py | 11 +- lerobot/scripts/server/buffer.py | 167 ++++++++++++++++---- lerobot/scripts/server/learner_server.py | 2 +- 3 files changed, 147 insertions(+), 33 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b5bfb36e..e3d3765e 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -288,6 +288,7 @@ class SACPolicy( next_observations: dict[str, Tensor] = batch["next_state"] done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") loss_grasp_critic = self.compute_loss_grasp_critic( observations=observations, actions=actions, @@ -296,6 +297,7 @@ class SACPolicy( done=done, observation_features=observation_features, next_observation_features=next_observation_features, + complementary_info=complementary_info, ) return {"loss_grasp_critic": loss_grasp_critic} if model == "actor": @@ -413,6 +415,7 @@ class SACPolicy( done, observation_features=None, next_observation_features=None, + complementary_info=None, ): # NOTE: We only want to keep the discrete action part # In the buffer we have the full action space (continuous + discrete) @@ -420,6 +423,9 @@ class SACPolicy( actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() actions_discrete = actions_discrete.long() + if complementary_info is not None: + gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty") + with torch.no_grad(): # For DQN, select actions using online network, evaluate with target network next_grasp_qs = self.grasp_critic_forward( @@ -440,7 +446,10 @@ class SACPolicy( ).squeeze(-1) # Compute target Q-value with Bellman equation - target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q + rewards_gripper = rewards + if gripper_penalties is not None: + rewards_gripper = rewards - gripper_penalties + target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q # Get predicted Q-values for current observations predicted_grasp_qs = self.grasp_critic_forward( diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 8947f6d9..92ad7dc7 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -32,7 +32,7 @@ class Transition(TypedDict): next_state: dict[str, torch.Tensor] done: bool truncated: bool - complementary_info: dict[str, Any] = None + complementary_info: dict[str, torch.Tensor | float | int] | None = None class BatchTransition(TypedDict): @@ -42,41 +42,43 @@ class BatchTransition(TypedDict): next_state: dict[str, torch.Tensor] done: torch.Tensor truncated: torch.Tensor + complementary_info: dict[str, torch.Tensor | float | int] | None = None def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: - # Move state tensors to CPU device = torch.device(device) + non_blocking = device.type == "cuda" + + # Move state tensors to device transition["state"] = { - key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items() + key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items() } - # Move action to CPU - transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda") + # Move action to device + transition["action"] = transition["action"].to(device, non_blocking=non_blocking) - # No need to move reward or done, as they are float and bool - - # No need to move reward or done, as they are float and bool + # Move reward and done if they are tensors if isinstance(transition["reward"], torch.Tensor): - transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda") + transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) if isinstance(transition["done"], torch.Tensor): - transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda") + transition["done"] = transition["done"].to(device, non_blocking=non_blocking) if isinstance(transition["truncated"], torch.Tensor): - transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda") + transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) - # Move next_state tensors to CPU + # Move next_state tensors to device transition["next_state"] = { - key: val.to(device, non_blocking=device.type == "cuda") - for key, val in transition["next_state"].items() + key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items() } - # If complementary_info is present, move its tensors to CPU - # if transition["complementary_info"] is not None: - # transition["complementary_info"] = { - # key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items() - # } + # Move complementary_info tensors if present + if transition.get("complementary_info") is not None: + transition["complementary_info"] = { + key: val.to(device, non_blocking=non_blocking) + for key, val in transition["complementary_info"].items() + } + return transition @@ -216,7 +218,12 @@ class ReplayBuffer: self.image_augmentation_function = torch.compile(base_function) self.use_drq = use_drq - def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor): + def _initialize_storage( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + complementary_info: Optional[dict[str, torch.Tensor]] = None, + ): """Initialize the storage tensors based on the first transition.""" # Determine shapes from the first transition state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} @@ -244,6 +251,26 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + # Initialize storage for complementary_info + self.has_complementary_info = complementary_info is not None + self.complementary_info_keys = [] + self.complementary_info = {} + + if self.has_complementary_info: + self.complementary_info_keys = list(complementary_info.keys()) + # Pre-allocate tensors for each key in complementary_info + for key, value in complementary_info.items(): + if isinstance(value, torch.Tensor): + value_shape = value.squeeze(0).shape + self.complementary_info[key] = torch.empty( + (self.capacity, *value_shape), device=self.storage_device + ) + elif isinstance(value, (int, float)): + # Handle scalar values similar to reward + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + else: + raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") + self.initialized = True def __len__(self): @@ -262,7 +289,7 @@ class ReplayBuffer: """Saves a transition, ensuring tensors are stored on the designated storage device.""" # Initialize storage if this is the first transition if not self.initialized: - self._initialize_storage(state=state, action=action) + self._initialize_storage(state=state, action=action, complementary_info=complementary_info) # Store the transition in pre-allocated tensors for key in self.states: @@ -277,6 +304,17 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated + # Handle complementary_info if provided and storage is initialized + if complementary_info is not None and self.has_complementary_info: + # Store the complementary_info + for key in self.complementary_info_keys: + if key in complementary_info: + value = complementary_info[key] + if isinstance(value, torch.Tensor): + self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) + elif isinstance(value, (int, float)): + self.complementary_info[key][self.position] = value + self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -335,6 +373,13 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() + # Sample complementary_info if available + batch_complementary_info = None + if self.has_complementary_info: + batch_complementary_info = {} + for key in self.complementary_info_keys: + batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) + return BatchTransition( state=batch_state, action=batch_actions, @@ -342,6 +387,7 @@ class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, + complementary_info=batch_complementary_info, ) def get_iterator( @@ -518,7 +564,19 @@ class ReplayBuffer: if action_delta is not None: first_action = first_action / action_delta - replay_buffer._initialize_storage(state=first_state, action=first_action) + # Get complementary info if available + first_complementary_info = None + if ( + "complementary_info" in first_transition + and first_transition["complementary_info"] is not None + ): + first_complementary_info = { + k: v.to(device) for k, v in first_transition["complementary_info"].items() + } + + replay_buffer._initialize_storage( + state=first_state, action=first_action, complementary_info=first_complementary_info + ) # Fill the buffer with all transitions for data in list_transition: @@ -546,6 +604,7 @@ class ReplayBuffer: next_state=data["next_state"], done=data["done"], truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset + complementary_info=data.get("complementary_info", None), ) return replay_buffer @@ -587,6 +646,13 @@ class ReplayBuffer: f_info = guess_feature_info(t=sample_val, name=key) features[key] = f_info + # Add complementary_info keys if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + sample_val = self.complementary_info[key][0] + f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") + features[f"complementary_info.{key}"] = f_info + # Create an empty LeRobotDataset lerobot_dataset = LeRobotDataset.create( repo_id=repo_id, @@ -620,6 +686,11 @@ class ReplayBuffer: frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + # Add complementary_info if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu() + # Add task field which is required by LeRobotDataset frame_dict["task"] = task_name @@ -686,6 +757,10 @@ class ReplayBuffer: sample = dataset[0] has_done_key = "next.done" in sample + # Check for complementary_info keys + complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")] + has_complementary_info = len(complementary_info_keys) > 0 + # If not, we need to infer it from episode boundaries if not has_done_key: print("'next.done' key not found in dataset. Inferring from episode boundaries...") @@ -735,6 +810,16 @@ class ReplayBuffer: next_state_data[key] = val.unsqueeze(0) # Add batch dimension next_state = next_state_data + # ----- 5) Complementary info (if available) ----- + complementary_info = None + if has_complementary_info: + complementary_info = {} + for key in complementary_info_keys: + # Strip the "complementary_info." prefix to get the actual key + clean_key = key[len("complementary_info.") :] + val = current_sample[key] + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + # ----- Construct the Transition ----- transition = Transition( state=current_state, @@ -743,6 +828,7 @@ class ReplayBuffer: next_state=next_state, done=done, truncated=truncated, + complementary_info=complementary_info, ) transitions.append(transition) @@ -775,32 +861,33 @@ def concatenate_batch_transitions( left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition ) -> BatchTransition: """NOTE: Be careful it change the left_batch_transitions in place""" + # Concatenate state fields left_batch_transitions["state"] = { key: torch.cat( - [ - left_batch_transitions["state"][key], - right_batch_transition["state"][key], - ], + [left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0, ) for key in left_batch_transitions["state"] } + + # Concatenate basic fields left_batch_transitions["action"] = torch.cat( [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 ) left_batch_transitions["reward"] = torch.cat( [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 ) + + # Concatenate next_state fields left_batch_transitions["next_state"] = { key: torch.cat( - [ - left_batch_transitions["next_state"][key], - right_batch_transition["next_state"][key], - ], + [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0, ) for key in left_batch_transitions["next_state"] } + + # Concatenate done and truncated fields left_batch_transitions["done"] = torch.cat( [left_batch_transitions["done"], right_batch_transition["done"]], dim=0 ) @@ -808,8 +895,26 @@ def concatenate_batch_transitions( [left_batch_transitions["truncated"], right_batch_transition["truncated"]], dim=0, ) + + # Handle complementary_info + left_info = left_batch_transitions.get("complementary_info") + right_info = right_batch_transition.get("complementary_info") + + # Only process if right_info exists + if right_info is not None: + # Initialize left complementary_info if needed + if left_info is None: + left_batch_transitions["complementary_info"] = right_info + else: + # Concatenate each field + for key in right_info: + if key in left_info: + left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0) + else: + left_info[key] = right_info[key] + return left_batch_transitions if __name__ == "__main__": - pass # All test code is currently commented out + pass diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 37586fe9..5489d6dc 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +# !/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. # All rights reserved. From 6c103906532c542c2ba763b5300afd2abc25dda8 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Mon, 7 Apr 2025 14:48:42 +0000 Subject: [PATCH 19/22] Refactor complementary_info handling in ReplayBuffer --- lerobot/scripts/server/buffer.py | 132 ++++++++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 12 deletions(-) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 92ad7dc7..185412fd 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -74,11 +74,15 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr # Move complementary_info tensors if present if transition.get("complementary_info") is not None: - transition["complementary_info"] = { - key: val.to(device, non_blocking=non_blocking) - for key, val in transition["complementary_info"].items() - } - + for key, val in transition["complementary_info"].items(): + if isinstance(val, torch.Tensor): + transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) + elif isinstance(val, (int, float, bool)): + transition["complementary_info"][key] = torch.tensor( + val, device=device, non_blocking=non_blocking + ) + else: + raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") return transition @@ -650,6 +654,8 @@ class ReplayBuffer: if self.has_complementary_info: for key in self.complementary_info_keys: sample_val = self.complementary_info[key][0] + if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0: + sample_val = sample_val.unsqueeze(0) f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") features[f"complementary_info.{key}"] = f_info @@ -689,7 +695,15 @@ class ReplayBuffer: # Add complementary_info if available if self.has_complementary_info: for key in self.complementary_info_keys: - frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu() + val = self.complementary_info[key][actual_idx] + # Convert tensors to CPU + if isinstance(val, torch.Tensor): + if val.ndim == 0: + val = val.unsqueeze(0) + frame_dict[f"complementary_info.{key}"] = val.cpu() + # Non-tensor values can be used directly + else: + frame_dict[f"complementary_info.{key}"] = val # Add task field which is required by LeRobotDataset frame_dict["task"] = task_name @@ -758,7 +772,7 @@ class ReplayBuffer: has_done_key = "next.done" in sample # Check for complementary_info keys - complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")] + complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] has_complementary_info = len(complementary_info_keys) > 0 # If not, we need to infer it from episode boundaries @@ -818,7 +832,13 @@ class ReplayBuffer: # Strip the "complementary_info." prefix to get the actual key clean_key = key[len("complementary_info.") :] val = current_sample[key] - complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + # Handle tensor and non-tensor values differently + if isinstance(val, torch.Tensor): + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + else: + # TODO: (azouitine) Check if it's necessary to convert to tensor + # For non-tensor values, use directly + complementary_info[clean_key] = val # ----- Construct the Transition ----- transition = Transition( @@ -836,12 +856,13 @@ class ReplayBuffer: # Utility function to guess shapes/dtypes from a tensor -def guess_feature_info(t: torch.Tensor, name: str): +def guess_feature_info(t, name: str): """ - Return a dictionary with the 'dtype' and 'shape' for a given tensor or array. + Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. - Otherwise default to 'float32' for numeric. You can customize as needed. + Otherwise default to appropriate dtype for numeric. """ + shape = tuple(t.shape) # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image' if len(shape) == 3 and shape[0] in [1, 3]: @@ -917,4 +938,91 @@ def concatenate_batch_transitions( if __name__ == "__main__": - pass + + def test_load_dataset_with_complementary_info(): + """ + Test loading a dataset with complementary_info into a ReplayBuffer. + The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains + gripper_penalty values in complementary_info. + """ + import time + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + print("Loading dataset with complementary info...") + # Load a small subset of the dataset (first episode) + dataset = LeRobotDataset( + repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty", + ) + + print(f"Dataset loaded with {len(dataset)} frames") + print(f"Dataset features: {list(dataset.features.keys())}") + + # Check if dataset has complementary_info.gripper_penalty + sample = dataset[0] + complementary_info_keys = [key for key in sample if key.startswith("complementary_info")] + print(f"Complementary info keys: {complementary_info_keys}") + + if "complementary_info.gripper_penalty" in sample: + print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}") + + # Extract state keys for the buffer + state_keys = [] + for key in sample: + if key.startswith("observation"): + state_keys.append(key) + + print(f"Using state keys: {state_keys}") + + # Create a replay buffer from the dataset + start_time = time.time() + buffer = ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False + ) + load_time = time.time() - start_time + print(f"Loaded dataset into buffer in {load_time:.2f} seconds") + print(f"Buffer size: {len(buffer)}") + + # Check if complementary_info was transferred correctly + print("Sampling from buffer to check complementary_info...") + batch = buffer.sample(batch_size=4) + + if batch["complementary_info"] is not None: + print("Complementary info in batch:") + for key, value in batch["complementary_info"].items(): + print(f" {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}") + if key == "gripper_penalty": + print(f" Sample gripper_penalty values: {value[:5]}") + else: + print("No complementary_info found in batch") + + # Now convert the buffer back to a LeRobotDataset + print("\nConverting buffer back to LeRobotDataset...") + start_time = time.time() + new_dataset = buffer.to_lerobot_dataset( + repo_id="test_dataset_from_buffer", + fps=dataset.fps, + root="./test_dataset_from_buffer", + task_name="test_conversion", + ) + convert_time = time.time() - start_time + print(f"Converted buffer to dataset in {convert_time:.2f} seconds") + print(f"New dataset size: {len(new_dataset)} frames") + + # Check if complementary_info was preserved + new_sample = new_dataset[0] + new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")] + print(f"New dataset complementary info keys: {new_complementary_info_keys}") + + if "complementary_info.gripper_penalty" in new_sample: + print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}") + + # Compare original and new datasets + print("\nComparing original and new datasets:") + print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}") + print(f"Original features: {list(dataset.features.keys())}") + print(f"New features: {list(new_dataset.features.keys())}") + + return buffer, dataset, new_dataset + + # Run the test + test_load_dataset_with_complementary_info() From 632b2b46c1f7dca5338b87b1d20c3306c1719ff6 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Mon, 7 Apr 2025 15:44:06 +0000 Subject: [PATCH 20/22] fix sign issue --- lerobot/common/policies/sac/modeling_sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index e3d3765e..9b909813 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -448,7 +448,7 @@ class SACPolicy( # Compute target Q-value with Bellman equation rewards_gripper = rewards if gripper_penalties is not None: - rewards_gripper = rewards - gripper_penalties + rewards_gripper = rewards + gripper_penalties target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q # Get predicted Q-values for current observations From a7be613ee8329d3f3a2784d314ad599605b72c4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Apr 2025 15:48:39 +0000 Subject: [PATCH 21/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/scripts/server/buffer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 185412fd..8db1a82c 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -946,6 +946,7 @@ if __name__ == "__main__": gripper_penalty values in complementary_info. """ import time + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset print("Loading dataset with complementary info...") From a8135629b4a22141fcae2ea86c050b49e4374d19 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 8 Apr 2025 08:50:02 +0000 Subject: [PATCH 22/22] Add rounding for safety --- lerobot/common/policies/sac/modeling_sac.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9b909813..e3d83d36 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -421,6 +421,7 @@ class SACPolicy( # In the buffer we have the full action space (continuous + discrete) # We need to split them before concatenating them in the critic forward actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) actions_discrete = actions_discrete.long() if complementary_info is not None: