Handle gripper penalty

parent 7741526ce4
commit 4621f4e4f3
@@ -288,6 +288,7 @@ class SACPolicy(
             next_observations: dict[str, Tensor] = batch["next_state"]
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
+            complementary_info = batch.get("complementary_info")
             loss_grasp_critic = self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
@@ -296,6 +297,7 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
             )
             return {"loss_grasp_critic": loss_grasp_critic}
         if model == "actor":
@@ -413,6 +415,7 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
+        complementary_info=None,
     ):
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
@@ -420,6 +423,9 @@ class SACPolicy(
         actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
         actions_discrete = actions_discrete.long()

+        if complementary_info is not None:
+            gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
+
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(
@@ -440,7 +446,10 @@ class SACPolicy(
         ).squeeze(-1)

         # Compute target Q-value with Bellman equation
-        target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
+        rewards_gripper = rewards
+        if gripper_penalties is not None:
+            rewards_gripper = rewards - gripper_penalties
+        target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q

         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
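Note on the hunk above: `gripper_penalties` is only bound inside the `if complementary_info is not None:` branch, so a batch that carries no complementary info would raise a NameError at the later `if gripper_penalties is not None:` check. Below is a minimal standalone sketch of the penalty-adjusted target with the name bound unconditionally; the function name, `discount`, and the tensor arguments are illustrative, not the repository's API.

import torch

def grasp_critic_target(
    rewards: torch.Tensor,
    done: torch.Tensor,
    target_next_grasp_q: torch.Tensor,
    discount: float,
    complementary_info: dict[str, torch.Tensor] | None = None,
) -> torch.Tensor:
    # Bind the penalty first so the check below cannot hit an unbound name
    # when the batch has no complementary_info.
    gripper_penalties = None
    if complementary_info is not None:
        gripper_penalties = complementary_info.get("gripper_penalty")

    rewards_gripper = rewards
    if gripper_penalties is not None:
        rewards_gripper = rewards - gripper_penalties

    # Bellman target: y = (r - penalty) + (1 - done) * gamma * max_a' Q_target(s', a')
    return rewards_gripper + (1 - done) * discount * target_next_grasp_q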
@@ -32,7 +32,7 @@ class Transition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: bool
     truncated: bool
-    complementary_info: dict[str, Any] = None
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None


 class BatchTransition(TypedDict):
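For reference, a transition carrying a per-step gripper penalty under the new annotation could look like this (a sketch assuming the `Transition` TypedDict from this module; the keys and values are made up for illustration):

import torch

transition = Transition(
    state={"observation.state": torch.zeros(1, 7)},
    action=torch.zeros(1, 4),
    reward=0.0,
    next_state={"observation.state": torch.zeros(1, 7)},
    done=False,
    truncated=False,
    complementary_info={"gripper_penalty": torch.tensor([0.05])},
)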
@@ -42,41 +42,43 @@ class BatchTransition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: torch.Tensor
     truncated: torch.Tensor
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None


 def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
-    # Move state tensors to CPU
     device = torch.device(device)
+    non_blocking = device.type == "cuda"
+
+    # Move state tensors to device
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
     }

-    # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
+    # Move action to device
+    transition["action"] = transition["action"].to(device, non_blocking=non_blocking)

-    # No need to move reward or done, as they are float and bool
-
-    # No need to move reward or done, as they are float and bool
+    # Move reward and done if they are tensors
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+        transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)

     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
+        transition["done"] = transition["done"].to(device, non_blocking=non_blocking)

     if isinstance(transition["truncated"], torch.Tensor):
-        transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
+        transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)

-    # Move next_state tensors to CPU
+    # Move next_state tensors to device
     transition["next_state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda")
-        for key, val in transition["next_state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
     }

-    # If complementary_info is present, move its tensors to CPU
-    # if transition["complementary_info"] is not None:
-    #     transition["complementary_info"] = {
-    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
-    #     }
+    # Move complementary_info tensors if present
+    if transition.get("complementary_info") is not None:
+        transition["complementary_info"] = {
+            key: val.to(device, non_blocking=non_blocking)
+            for key, val in transition["complementary_info"].items()
+        }

     return transition
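A short usage sketch, reusing the illustrative transition above; with this commit the complementary_info tensors follow the rest of the transition to the target device instead of being skipped:

transition = move_transition_to_device(transition, device="cuda")  # assumes a CUDA device is available
# state, action, next_state, and complementary_info tensors are now on CUDA;
# scalar reward/done fields are left as plain Python values.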
@@ -216,7 +218,12 @@ class ReplayBuffer:
         self.image_augmentation_function = torch.compile(base_function)
         self.use_drq = use_drq

-    def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
+    def _initialize_storage(
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        complementary_info: Optional[dict[str, torch.Tensor]] = None,
+    ):
         """Initialize the storage tensors based on the first transition."""
         # Determine shapes from the first transition
         state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
@@ -244,6 +251,26 @@ class ReplayBuffer:
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)

+        # Initialize storage for complementary_info
+        self.has_complementary_info = complementary_info is not None
+        self.complementary_info_keys = []
+        self.complementary_info = {}
+
+        if self.has_complementary_info:
+            self.complementary_info_keys = list(complementary_info.keys())
+            # Pre-allocate tensors for each key in complementary_info
+            for key, value in complementary_info.items():
+                if isinstance(value, torch.Tensor):
+                    value_shape = value.squeeze(0).shape
+                    self.complementary_info[key] = torch.empty(
+                        (self.capacity, *value_shape), device=self.storage_device
+                    )
+                elif isinstance(value, (int, float)):
+                    # Handle scalar values similar to reward
+                    self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
+                else:
+                    raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
+
         self.initialized = True

     def __len__(self):
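One detail worth noting in the pre-allocation above: `torch.empty` is called without a `dtype`, so complementary_info storage defaults to float32 and values of other dtypes are cast when copied in. A small standalone demonstration (names are illustrative; passing `dtype=value.dtype` at allocation time would preserve the original dtype):

import torch

capacity = 8
value = torch.tensor([True])                            # e.g. a boolean flag in complementary_info
buf = torch.empty((capacity, *value.squeeze(0).shape))  # dtype defaults to float32
buf[0].copy_(value.squeeze(0))                          # bool is silently cast to float32 here
print(buf.dtype)                                        # torch.float32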
@@ -262,7 +289,7 @@ class ReplayBuffer:
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Initialize storage if this is the first transition
         if not self.initialized:
-            self._initialize_storage(state=state, action=action)
+            self._initialize_storage(state=state, action=action, complementary_info=complementary_info)

         # Store the transition in pre-allocated tensors
         for key in self.states:
@@ -277,6 +304,17 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated

+        # Handle complementary_info if provided and storage is initialized
+        if complementary_info is not None and self.has_complementary_info:
+            # Store the complementary_info
+            for key in self.complementary_info_keys:
+                if key in complementary_info:
+                    value = complementary_info[key]
+                    if isinstance(value, torch.Tensor):
+                        self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
+                    elif isinstance(value, (int, float)):
+                        self.complementary_info[key][self.position] = value
+
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
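A corner case of the write path above: slots are reused ring-buffer style, and a key missing from a later `add()` call leaves the previous occupant's value in place, since there is no else branch that clears the slot. A standalone illustration:

import torch

capacity = 2
penalty_storage = torch.empty((capacity,))
penalty_storage[0] = 0.05   # first pass, slot 0
penalty_storage[1] = 0.10   # first pass, slot 1
# The buffer wraps around; the next add() reuses slot 0, but if its
# complementary_info lacks "gripper_penalty", nothing is written there:
print(penalty_storage[0])   # tensor(0.0500), a stale value from the first pass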
@@ -335,6 +373,13 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()

+        # Sample complementary_info if available
+        batch_complementary_info = None
+        if self.has_complementary_info:
+            batch_complementary_info = {}
+            for key in self.complementary_info_keys:
+                batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
+
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -342,6 +387,7 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
+            complementary_info=batch_complementary_info,
         )

     def get_iterator(
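On the consumer side, a sampled batch now carries the penalty through to the grasp-critic hunks in SACPolicy above. A usage sketch (the buffer instance, batch size, and key names are assumptions; `BatchTransition` is a plain dict at runtime, so `.get` works):

batch = replay_buffer.sample(batch_size=256)
complementary_info = batch.get("complementary_info")
if complementary_info is not None:
    gripper_penalty = complementary_info["gripper_penalty"]  # shape: (256,)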
@@ -518,7 +564,19 @@ class ReplayBuffer:
         if action_delta is not None:
             first_action = first_action / action_delta

-        replay_buffer._initialize_storage(state=first_state, action=first_action)
+        # Get complementary info if available
+        first_complementary_info = None
+        if (
+            "complementary_info" in first_transition
+            and first_transition["complementary_info"] is not None
+        ):
+            first_complementary_info = {
+                k: v.to(device) for k, v in first_transition["complementary_info"].items()
+            }
+
+        replay_buffer._initialize_storage(
+            state=first_state, action=first_action, complementary_info=first_complementary_info
+        )

         # Fill the buffer with all transitions
         for data in list_transition:
@@ -546,6 +604,7 @@ class ReplayBuffer:
                 next_state=data["next_state"],
                 done=data["done"],
                 truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
+                complementary_info=data.get("complementary_info", None),
             )

         return replay_buffer
@@ -587,6 +646,13 @@ class ReplayBuffer:
             f_info = guess_feature_info(t=sample_val, name=key)
             features[key] = f_info

+        # Add complementary_info keys if available
+        if self.has_complementary_info:
+            for key in self.complementary_info_keys:
+                sample_val = self.complementary_info[key][0]
+                f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
+                features[f"complementary_info.{key}"] = f_info
+
         # Create an empty LeRobotDataset
         lerobot_dataset = LeRobotDataset.create(
             repo_id=repo_id,
@@ -620,6 +686,11 @@ class ReplayBuffer:
             frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
             frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()

+            # Add complementary_info if available
+            if self.has_complementary_info:
+                for key in self.complementary_info_keys:
+                    frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
+
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
@@ -686,6 +757,10 @@ class ReplayBuffer:
         sample = dataset[0]
         has_done_key = "next.done" in sample

+        # Check for complementary_info keys
+        complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
+        has_complementary_info = len(complementary_info_keys) > 0
+
         # If not, we need to infer it from episode boundaries
         if not has_done_key:
             print("'next.done' key not found in dataset. Inferring from episode boundaries...")
@@ -735,6 +810,16 @@ class ReplayBuffer:
                 next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
             next_state = next_state_data

+            # ----- 5) Complementary info (if available) -----
+            complementary_info = None
+            if has_complementary_info:
+                complementary_info = {}
+                for key in complementary_info_keys:
+                    # Strip the "complementary_info." prefix to get the actual key
+                    clean_key = key[len("complementary_info.") :]
+                    val = current_sample[key]
+                    complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+
             # ----- Construct the Transition -----
             transition = Transition(
                 state=current_state,
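The dataset round-trip above relies on a simple naming convention: buffer keys are flattened into LeRobotDataset columns under a `complementary_info.` prefix and recovered by stripping it. A minimal sketch:

PREFIX = "complementary_info."

flat_key = PREFIX + "gripper_penalty"   # column name written to LeRobotDataset
clean_key = flat_key[len(PREFIX):]      # key recovered when rebuilding the buffer
assert clean_key == "gripper_penalty"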
@@ -743,6 +828,7 @@ class ReplayBuffer:
                 next_state=next_state,
                 done=done,
                 truncated=truncated,
+                complementary_info=complementary_info,
             )
             transitions.append(transition)
@@ -775,32 +861,33 @@ def concatenate_batch_transitions(
     left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
 ) -> BatchTransition:
     """NOTE: Be careful it change the left_batch_transitions in place"""
+    # Concatenate state fields
     left_batch_transitions["state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["state"][key],
-                right_batch_transition["state"][key],
-            ],
+            [left_batch_transitions["state"][key], right_batch_transition["state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["state"]
     }

+    # Concatenate basic fields
     left_batch_transitions["action"] = torch.cat(
         [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
     )
     left_batch_transitions["reward"] = torch.cat(
         [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
     )

+    # Concatenate next_state fields
     left_batch_transitions["next_state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["next_state"][key],
-                right_batch_transition["next_state"][key],
-            ],
+            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["next_state"]
     }

+    # Concatenate done and truncated fields
     left_batch_transitions["done"] = torch.cat(
         [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
     )
@@ -808,8 +895,26 @@ def concatenate_batch_transitions(
         [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
         dim=0,
     )

+    # Handle complementary_info
+    left_info = left_batch_transitions.get("complementary_info")
+    right_info = right_batch_transition.get("complementary_info")
+
+    # Only process if right_info exists
+    if right_info is not None:
+        # Initialize left complementary_info if needed
+        if left_info is None:
+            left_batch_transitions["complementary_info"] = right_info
+        else:
+            # Concatenate each field
+            for key in right_info:
+                if key in left_info:
+                    left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
+                else:
+                    left_info[key] = right_info[key]
+
     return left_batch_transitions


 if __name__ == "__main__":
-    pass  # All test code is currently commented out
+    pass
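One asymmetry to be aware of in the merge above: when only the right batch carries complementary_info (or a key exists on only one side), the dict is adopted as-is, so its batch dimension no longer matches the other concatenated fields. A standalone illustration with made-up shapes:

import torch

left = {"action": torch.zeros(4, 2), "complementary_info": None}
right = {"action": torch.zeros(4, 2), "complementary_info": {"gripper_penalty": torch.zeros(4)}}

left["action"] = torch.cat([left["action"], right["action"]], dim=0)  # shape (8, 2)
if left["complementary_info"] is None:
    left["complementary_info"] = right["complementary_info"]          # still shape (4,)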
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+# !/usr/bin/env python

 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.