From 4621f4e4f37faf8ce94b467132a23da7b4a5c839 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Mon, 7 Apr 2025 08:23:49 +0000
Subject: [PATCH] Handle gripper penalty

---
 lerobot/common/policies/sac/modeling_sac.py |  11 +-
 lerobot/scripts/server/buffer.py            | 167 ++++++++++++++++----
 lerobot/scripts/server/learner_server.py    |   2 +-
 3 files changed, 147 insertions(+), 33 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index b5bfb36e..e3d3765e 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -288,6 +288,7 @@ class SACPolicy(
             next_observations: dict[str, Tensor] = batch["next_state"]
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
+            complementary_info = batch.get("complementary_info")
             loss_grasp_critic = self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
@@ -296,6 +297,7 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
             )
             return {"loss_grasp_critic": loss_grasp_critic}
         if model == "actor":
@@ -413,6 +415,7 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
+        complementary_info=None,
     ):
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
@@ -420,6 +423,9 @@ class SACPolicy(
         actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
         actions_discrete = actions_discrete.long()
 
+        gripper_penalties: Tensor | None = None
+        if complementary_info is not None:
+            gripper_penalties = complementary_info.get("gripper_penalty")
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(
@@ -440,7 +446,10 @@ class SACPolicy(
             ).squeeze(-1)
 
             # Compute target Q-value with Bellman equation
-            target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
+            rewards_gripper = rewards
+            if gripper_penalties is not None:
+                rewards_gripper = rewards - gripper_penalties
+            target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 8947f6d9..92ad7dc7 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -32,7 +32,7 @@ class Transition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: bool
     truncated: bool
-    complementary_info: dict[str, Any] = None
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None


 class BatchTransition(TypedDict):
@@ -42,41 +42,43 @@ class BatchTransition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: torch.Tensor
     truncated: torch.Tensor
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None


 def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
-    # Move state tensors to CPU
     device = torch.device(device)
+    non_blocking = device.type == "cuda"
+
+    # Move state tensors to device
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
     }
-    # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
transition["action"].to(device, non_blocking=device.type == "cuda") + # Move action to device + transition["action"] = transition["action"].to(device, non_blocking=non_blocking) - # No need to move reward or done, as they are float and bool - - # No need to move reward or done, as they are float and bool + # Move reward and done if they are tensors if isinstance(transition["reward"], torch.Tensor): - transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda") + transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) if isinstance(transition["done"], torch.Tensor): - transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda") + transition["done"] = transition["done"].to(device, non_blocking=non_blocking) if isinstance(transition["truncated"], torch.Tensor): - transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda") + transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) - # Move next_state tensors to CPU + # Move next_state tensors to device transition["next_state"] = { - key: val.to(device, non_blocking=device.type == "cuda") - for key, val in transition["next_state"].items() + key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items() } - # If complementary_info is present, move its tensors to CPU - # if transition["complementary_info"] is not None: - # transition["complementary_info"] = { - # key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items() - # } + # Move complementary_info tensors if present + if transition.get("complementary_info") is not None: + transition["complementary_info"] = { + key: val.to(device, non_blocking=non_blocking) + for key, val in transition["complementary_info"].items() + } + return transition @@ -216,7 +218,12 @@ class ReplayBuffer: self.image_augmentation_function = torch.compile(base_function) self.use_drq = use_drq - def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor): + def _initialize_storage( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + complementary_info: Optional[dict[str, torch.Tensor]] = None, + ): """Initialize the storage tensors based on the first transition.""" # Determine shapes from the first transition state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} @@ -244,6 +251,26 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + # Initialize storage for complementary_info + self.has_complementary_info = complementary_info is not None + self.complementary_info_keys = [] + self.complementary_info = {} + + if self.has_complementary_info: + self.complementary_info_keys = list(complementary_info.keys()) + # Pre-allocate tensors for each key in complementary_info + for key, value in complementary_info.items(): + if isinstance(value, torch.Tensor): + value_shape = value.squeeze(0).shape + self.complementary_info[key] = torch.empty( + (self.capacity, *value_shape), device=self.storage_device + ) + elif isinstance(value, (int, float)): + # Handle scalar values similar to reward + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + else: + raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") + self.initialized = True def 
@@ -262,7 +289,7 @@ class ReplayBuffer:
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Initialize storage if this is the first transition
         if not self.initialized:
-            self._initialize_storage(state=state, action=action)
+            self._initialize_storage(state=state, action=action, complementary_info=complementary_info)

         # Store the transition in pre-allocated tensors
         for key in self.states:
@@ -277,6 +304,17 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated

+        # Handle complementary_info if provided and storage is initialized
+        if complementary_info is not None and self.has_complementary_info:
+            # Store the complementary_info
+            for key in self.complementary_info_keys:
+                if key in complementary_info:
+                    value = complementary_info[key]
+                    if isinstance(value, torch.Tensor):
+                        self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
+                    elif isinstance(value, (int, float)):
+                        self.complementary_info[key][self.position] = value
+
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
@@ -335,6 +373,13 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()

+        # Sample complementary_info if available
+        batch_complementary_info = None
+        if self.has_complementary_info:
+            batch_complementary_info = {}
+            for key in self.complementary_info_keys:
+                batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
+
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -342,6 +387,7 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
+            complementary_info=batch_complementary_info,
         )

     def get_iterator(
@@ -518,7 +564,19 @@ class ReplayBuffer:
         if action_delta is not None:
             first_action = first_action / action_delta

-        replay_buffer._initialize_storage(state=first_state, action=first_action)
+        # Get complementary info if available
+        first_complementary_info = None
+        if (
+            "complementary_info" in first_transition
+            and first_transition["complementary_info"] is not None
+        ):
+            first_complementary_info = {
+                k: v.to(device) for k, v in first_transition["complementary_info"].items()
+            }
+
+        replay_buffer._initialize_storage(
+            state=first_state, action=first_action, complementary_info=first_complementary_info
+        )

         # Fill the buffer with all transitions
         for data in list_transition:
@@ -546,6 +604,7 @@ class ReplayBuffer:
                 next_state=data["next_state"],
                 done=data["done"],
                 truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
+                complementary_info=data.get("complementary_info", None),
             )

         return replay_buffer
@@ -587,6 +646,13 @@ class ReplayBuffer:
             f_info = guess_feature_info(t=sample_val, name=key)
             features[key] = f_info

+        # Add complementary_info keys if available
+        if self.has_complementary_info:
+            for key in self.complementary_info_keys:
+                sample_val = self.complementary_info[key][0]
+                f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
+                features[f"complementary_info.{key}"] = f_info
+
         # Create an empty LeRobotDataset
         lerobot_dataset = LeRobotDataset.create(
             repo_id=repo_id,
@@ -620,6 +686,11 @@ class ReplayBuffer:
             frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
             frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()

+            # Add complementary_info if available
+            if self.has_complementary_info:
+                for key in self.complementary_info_keys:
+                    frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
+
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
@@ -686,6 +757,10 @@ class ReplayBuffer:
         sample = dataset[0]
         has_done_key = "next.done" in sample

+        # Check for complementary_info keys
+        complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
+        has_complementary_info = len(complementary_info_keys) > 0
+
         # If not, we need to infer it from episode boundaries
         if not has_done_key:
             print("'next.done' key not found in dataset. Inferring from episode boundaries...")
@@ -735,6 +810,16 @@ class ReplayBuffer:
                 next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
             next_state = next_state_data

+        # ----- 5) Complementary info (if available) -----
+        complementary_info = None
+        if has_complementary_info:
+            complementary_info = {}
+            for key in complementary_info_keys:
+                # Strip the "complementary_info." prefix to get the actual key
+                clean_key = key[len("complementary_info.") :]
+                val = current_sample[key]
+                complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+
         # ----- Construct the Transition -----
         transition = Transition(
             state=current_state,
@@ -743,6 +828,7 @@ class ReplayBuffer:
             next_state=next_state,
             done=done,
             truncated=truncated,
+            complementary_info=complementary_info,
         )

         transitions.append(transition)
@@ -775,32 +861,33 @@ def concatenate_batch_transitions(
     left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
 ) -> BatchTransition:
     """NOTE: Be careful it change the left_batch_transitions in place"""
+    # Concatenate state fields
     left_batch_transitions["state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["state"][key],
-                right_batch_transition["state"][key],
-            ],
+            [left_batch_transitions["state"][key], right_batch_transition["state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["state"]
     }
+
+    # Concatenate basic fields
     left_batch_transitions["action"] = torch.cat(
         [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
     )
     left_batch_transitions["reward"] = torch.cat(
         [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
     )
+
+    # Concatenate next_state fields
     left_batch_transitions["next_state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["next_state"][key],
-                right_batch_transition["next_state"][key],
-            ],
+            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["next_state"]
     }
+
+    # Concatenate done and truncated fields
     left_batch_transitions["done"] = torch.cat(
         [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
     )
@@ -808,8 +895,26 @@ def concatenate_batch_transitions(
         [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
         dim=0,
     )
+
+    # Handle complementary_info
+    left_info = left_batch_transitions.get("complementary_info")
+    right_info = right_batch_transition.get("complementary_info")
+
+    # Only process if right_info exists
+    if right_info is not None:
+        # Initialize left complementary_info if needed
+        if left_info is None:
+            left_batch_transitions["complementary_info"] = right_info
+        else:
+            # Concatenate each field
+            for key in right_info:
+                if key in left_info:
+                    left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
+                else:
+                    left_info[key] = right_info[key]
+
     return left_batch_transitions


 if __name__ == "__main__":
-    pass  # All test code is currently commented out
+    pass
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 37586fe9..5489d6dc 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+# !/usr/bin/env python
 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.
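
For illustration only (not part of the patch): a minimal, self-contained sketch of the penalized Bellman target that compute_loss_grasp_critic builds above. The batch values and the discount factor are made-up assumptions; only the penalty-subtraction logic mirrors the patch.

import torch

# Hypothetical batch values (shapes and numbers are illustrative).
discount = 0.99
rewards = torch.tensor([0.0, 1.0, 0.0, 1.0])
done = torch.tensor([0.0, 0.0, 1.0, 0.0])
target_next_grasp_q = torch.tensor([0.5, 0.7, 0.2, 0.9])

# complementary_info may be absent, in which case no penalty is applied.
complementary_info = {"gripper_penalty": torch.tensor([0.1, 0.0, 0.1, 0.0])}
gripper_penalties = None
if complementary_info is not None:
    gripper_penalties = complementary_info.get("gripper_penalty")

# Subtract the gripper penalty from the reward before bootstrapping.
rewards_gripper = rewards
if gripper_penalties is not None:
    rewards_gripper = rewards - gripper_penalties

target_grasp_q = rewards_gripper + (1 - done) * discount * target_next_grasp_q
print(target_grasp_q)  # approximately [0.395, 1.693, -0.100, 1.891]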
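
Likewise, a hedged usage sketch for the buffer side: Transition and move_transition_to_device come from the patched buffer.py, but the state key, tensor shapes, and penalty value here are stand-ins, and the snippet assumes the lerobot repository is importable.

import torch

from lerobot.scripts.server.buffer import Transition, move_transition_to_device

# Hypothetical single-step transition carrying a gripper penalty.
transition = Transition(
    state={"observation.state": torch.zeros(1, 10)},
    action=torch.zeros(1, 4),
    reward=1.0,
    next_state={"observation.state": torch.zeros(1, 10)},
    done=False,
    truncated=False,
    complementary_info={"gripper_penalty": torch.tensor([0.05])},
)

# With this patch, complementary_info tensors travel with the rest of the transition
# when it is moved between devices, stored in the ReplayBuffer, and sampled back out.
transition = move_transition_to_device(transition, device="cpu")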