Handle gripper penalty

AdilZouitine 2025-04-07 08:23:49 +00:00
parent 7741526ce4
commit 4621f4e4f3
3 changed files with 147 additions and 33 deletions

View File

@@ -288,6 +288,7 @@ class SACPolicy(
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_grasp_critic = self.compute_loss_grasp_critic(
observations=observations,
actions=actions,
@@ -296,6 +297,7 @@ class SACPolicy(
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_grasp_critic": loss_grasp_critic}
if model == "actor":
@@ -413,6 +415,7 @@ class SACPolicy(
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
@@ -420,6 +423,9 @@ class SACPolicy(
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = actions_discrete.long()
gripper_penalties: Tensor | None = None
if complementary_info is not None:
gripper_penalties = complementary_info.get("gripper_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_grasp_qs = self.grasp_critic_forward(
@@ -440,7 +446,10 @@
).squeeze(-1)
# Compute target Q-value with Bellman equation
target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
rewards_gripper = rewards
if gripper_penalties is not None:
rewards_gripper = rewards - gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
# Get predicted Q-values for current observations
predicted_grasp_qs = self.grasp_critic_forward(
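
For context, a minimal sketch of the penalty-adjusted target this hunk introduces; the helper name, signature, and tensor shapes are illustrative, not part of the commit:

    import torch

    def grasp_critic_target(
        rewards: torch.Tensor,                   # (B,) environment rewards
        gripper_penalties: torch.Tensor | None,  # (B,) penalty from complementary_info, or None
        done: torch.Tensor,                      # (B,) 1.0 where the episode ended
        target_next_grasp_q: torch.Tensor,       # (B,) target-network Q at the online argmax action
        discount: float,
    ) -> torch.Tensor:
        # Subtract the gripper penalty from the reward before bootstrapping,
        # mirroring compute_loss_grasp_critic above.
        rewards_gripper = rewards if gripper_penalties is None else rewards - gripper_penalties
        return rewards_gripper + (1 - done) * discount * target_next_grasp_q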

View File

@@ -32,7 +32,7 @@ class Transition(TypedDict):
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, Any] = None
complementary_info: dict[str, torch.Tensor | float | int] | None = None
class BatchTransition(TypedDict):
@@ -42,41 +42,43 @@ class BatchTransition(TypedDict):
next_state: dict[str, torch.Tensor]
done: torch.Tensor
truncated: torch.Tensor
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
# Move state tensors to CPU
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to CPU
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# No need to move reward or done, as they are float and bool
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to CPU
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["next_state"].items()
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# If complementary_info is present, move its tensors to CPU
# if transition["complementary_info"] is not None:
# transition["complementary_info"] = {
# key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
# }
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
transition["complementary_info"] = {
key: val.to(device, non_blocking=non_blocking)
for key, val in transition["complementary_info"].items()
}
return transition
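
For illustration, a hedged sketch of a transition carrying a gripper_penalty entry being moved to a device; the state/action keys, shapes, and reward value are placeholders inferred from context, not taken from the commit:

    import torch

    transition = Transition(
        state={"observation.state": torch.zeros(1, 6)},
        action=torch.zeros(1, 4),
        reward=0.0,
        next_state={"observation.state": torch.zeros(1, 6)},
        done=False,
        truncated=False,
        complementary_info={"gripper_penalty": torch.tensor([0.1])},
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    transition = move_transition_to_device(transition, device=device)
    # complementary_info tensors now travel with the rest of the transition.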
@@ -216,7 +218,12 @@ class ReplayBuffer:
self.image_augmentation_function = torch.compile(base_function)
self.use_drq = use_drq
def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
def _initialize_storage(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Initialize the storage tensors based on the first transition."""
# Determine shapes from the first transition
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
@@ -244,6 +251,26 @@ class ReplayBuffer:
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
# Initialize storage for complementary_info
self.has_complementary_info = complementary_info is not None
self.complementary_info_keys = []
self.complementary_info = {}
if self.has_complementary_info:
self.complementary_info_keys = list(complementary_info.keys())
# Pre-allocate tensors for each key in complementary_info
for key, value in complementary_info.items():
if isinstance(value, torch.Tensor):
value_shape = value.squeeze(0).shape
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
elif isinstance(value, (int, float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
self.initialized = True
def __len__(self):
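
A standalone sketch of the pre-allocation rule used here, assuming one real key (gripper_penalty) and one hypothetical tensor-valued key; tensors keep their per-sample shape, plain numbers get a 1-D buffer like rewards:

    import torch

    capacity = 4
    first_complementary_info = {
        "gripper_penalty": 0.05,             # scalar -> storage shape (capacity,)
        "grasp_success": torch.zeros(1, 1),  # hypothetical tensor key -> (capacity, 1)
    }
    storage = {}
    for key, value in first_complementary_info.items():
        if isinstance(value, torch.Tensor):
            storage[key] = torch.empty((capacity, *value.squeeze(0).shape))
        elif isinstance(value, (int, float)):
            storage[key] = torch.empty((capacity,))
        else:
            raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")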
@@ -262,7 +289,7 @@
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition
if not self.initialized:
self._initialize_storage(state=state, action=action)
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors
for key in self.states:
@@ -277,6 +304,17 @@
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int, float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
@@ -335,6 +373,13 @@
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition(
state=batch_state,
action=batch_actions,
@@ -342,6 +387,7 @@
next_state=batch_next_state,
done=batch_dones,
truncated=batch_truncateds,
complementary_info=batch_complementary_info,
)
def get_iterator(
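
On the consumer side, the sampled batch now exposes the penalty under "complementary_info"; a small sketch of the defensive access pattern used by the policy (the batch literal and batch size are illustrative):

    import torch

    batch_size = 8
    batch = {"complementary_info": {"gripper_penalty": torch.full((batch_size,), 0.05)}}

    complementary_info = batch.get("complementary_info")  # may be None for buffers without it
    gripper_penalties = None
    if complementary_info is not None:
        gripper_penalties = complementary_info.get("gripper_penalty")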
@@ -518,7 +564,19 @@ class ReplayBuffer:
if action_delta is not None:
first_action = first_action / action_delta
replay_buffer._initialize_storage(state=first_state, action=first_action)
# Get complementary info if available
first_complementary_info = None
if (
"complementary_info" in first_transition
and first_transition["complementary_info"] is not None
):
first_complementary_info = {
k: v.to(device) for k, v in first_transition["complementary_info"].items()
}
replay_buffer._initialize_storage(
state=first_state, action=first_action, complementary_info=first_complementary_info
)
# Fill the buffer with all transitions
for data in list_transition:
@@ -546,6 +604,7 @@
next_state=data["next_state"],
done=data["done"],
truncated=False,  # NOTE: Truncation is not supported yet in the lerobot dataset
complementary_info=data.get("complementary_info", None),
)
return replay_buffer
@@ -587,6 +646,13 @@ class ReplayBuffer:
f_info = guess_feature_info(t=sample_val, name=key)
features[key] = f_info
# Add complementary_info keys if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
sample_val = self.complementary_info[key][0]
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
features[f"complementary_info.{key}"] = f_info
# Create an empty LeRobotDataset
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
@@ -620,6 +686,11 @@ class ReplayBuffer:
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
# Add complementary_info if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
# Add task field which is required by LeRobotDataset
frame_dict["task"] = task_name
@@ -686,6 +757,10 @@ class ReplayBuffer:
sample = dataset[0]
has_done_key = "next.done" in sample
# Check for complementary_info keys
complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
has_complementary_info = len(complementary_info_keys) > 0
# If not, we need to infer it from episode boundaries
if not has_done_key:
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
@@ -735,6 +810,16 @@
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
next_state = next_state_data
# ----- 5) Complementary info (if available) -----
complementary_info = None
if has_complementary_info:
complementary_info = {}
for key in complementary_info_keys:
# Strip the "complementary_info." prefix to get the actual key
clean_key = key[len("complementary_info.") :]
val = current_sample[key]
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
# ----- Construct the Transition -----
transition = Transition(
state=current_state,
@@ -743,6 +828,7 @@
next_state=next_state,
done=done,
truncated=truncated,
complementary_info=complementary_info,
)
transitions.append(transition)
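
The dataset round-trip rests on a simple key convention; a short sketch of it, with the dict literals standing in for the real frame data:

    import torch

    # Saving: buffer keys are flattened under a "complementary_info." prefix.
    buffer_entry = {"gripper_penalty": torch.tensor(0.05)}
    frame_dict = {f"complementary_info.{k}": v for k, v in buffer_entry.items()}

    # Loading: prefixed keys are detected and the prefix stripped back off.
    prefix = "complementary_info."
    complementary_info = {
        k[len(prefix):]: v.unsqueeze(0)  # re-add the batch dimension, as above
        for k, v in frame_dict.items()
        if k.startswith(prefix)
    }
    assert "gripper_penalty" in complementary_info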
@@ -775,32 +861,33 @@ def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""NOTE: Be careful, this modifies left_batch_transitions in place."""
# Concatenate state fields
left_batch_transitions["state"] = {
key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
[left_batch_transitions["state"][key], right_batch_transition["state"][key]],
dim=0,
)
for key in left_batch_transitions["state"]
}
# Concatenate basic fields
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
)
# Concatenate next_state fields
left_batch_transitions["next_state"] = {
key: torch.cat(
[
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
# Concatenate done and truncated fields
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
@@ -808,8 +895,26 @@ def concatenate_batch_transitions(
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
dim=0,
)
# Handle complementary_info
left_info = left_batch_transitions.get("complementary_info")
right_info = right_batch_transition.get("complementary_info")
# Only process if right_info exists
if right_info is not None:
# Initialize left complementary_info if needed
if left_info is None:
left_batch_transitions["complementary_info"] = right_info
else:
# Concatenate each field
for key in right_info:
if key in left_info:
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
else:
left_info[key] = right_info[key]
return left_batch_transitions
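
A hedged usage sketch of the new complementary_info handling when merging two sampled batches; the toy batches below only assume the fields visible in this file, with placeholder shapes:

    import torch

    def toy_batch(n: int, penalty: float) -> BatchTransition:
        # Minimal, illustrative BatchTransition; keys and shapes are placeholders.
        return BatchTransition(
            state={"observation.state": torch.zeros(n, 6)},
            action=torch.zeros(n, 4),
            reward=torch.zeros(n),
            next_state={"observation.state": torch.zeros(n, 6)},
            done=torch.zeros(n),
            truncated=torch.zeros(n),
            complementary_info={"gripper_penalty": torch.full((n,), penalty)},
        )

    merged = concatenate_batch_transitions(toy_batch(2, 0.0), toy_batch(3, 0.1))
    assert merged["complementary_info"]["gripper_penalty"].shape == (5,)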
if __name__ == "__main__":
pass # All test code is currently commented out
pass

View File

@@ -1,4 +1,4 @@
#!/usr/bin/env python
# !/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.