Refactor complementary_info handling in ReplayBuffer

AdilZouitine 2025-04-07 14:48:42 +00:00
parent 4621f4e4f3
commit 6c10390653
1 changed file with 120 additions and 12 deletions


@@ -74,11 +74,15 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
     # Move complementary_info tensors if present
     if transition.get("complementary_info") is not None:
-        transition["complementary_info"] = {
-            key: val.to(device, non_blocking=non_blocking)
-            for key, val in transition["complementary_info"].items()
-        }
+        for key, val in transition["complementary_info"].items():
+            if isinstance(val, torch.Tensor):
+                transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
+            elif isinstance(val, (int, float, bool)):
+                # torch.tensor() does not accept non_blocking; construct directly on the target device
+                transition["complementary_info"][key] = torch.tensor(val, device=device)
+            else:
+                raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
     return transition
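
For illustration, the same tensor-or-scalar normalization in isolation; a minimal sketch assuming only torch (gripper_penalty appears in the test below, is_intervention is a hypothetical key):

    import torch

    info = {"gripper_penalty": torch.tensor(-0.1), "is_intervention": True}
    for key, val in info.items():
        if isinstance(val, torch.Tensor):
            info[key] = val.to("cpu")                     # tensors are moved
        elif isinstance(val, (int, float, bool)):
            info[key] = torch.tensor(val, device="cpu")   # scalars are wrapped
    assert all(isinstance(v, torch.Tensor) for v in info.values())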
@@ -650,6 +654,8 @@ class ReplayBuffer:
         if self.has_complementary_info:
             for key in self.complementary_info_keys:
                 sample_val = self.complementary_info[key][0]
+                if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
+                    sample_val = sample_val.unsqueeze(0)
                 f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
                 features[f"complementary_info.{key}"] = f_info
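
The 0-dim guard matters because a scalar tensor reports an empty shape, which would give the inferred feature an empty shape tuple; a quick self-contained check (only torch assumed):

    import torch

    scalar = torch.tensor(-0.1)   # 0-dim tensor: tuple(scalar.shape) == ()
    assert scalar.ndim == 0
    assert tuple(scalar.unsqueeze(0).shape) == (1,)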
@@ -689,7 +695,15 @@
             # Add complementary_info if available
             if self.has_complementary_info:
                 for key in self.complementary_info_keys:
-                    frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
+                    val = self.complementary_info[key][actual_idx]
+                    # Convert tensors to CPU
+                    if isinstance(val, torch.Tensor):
+                        if val.ndim == 0:
+                            val = val.unsqueeze(0)
+                        frame_dict[f"complementary_info.{key}"] = val.cpu()
+                    # Non-tensor values can be used directly
+                    else:
+                        frame_dict[f"complementary_info.{key}"] = val
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
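
The same guard reappears at read time because indexing a 1-D buffer tensor at a single frame drops the dimension; a minimal check with hypothetical stored values:

    import torch

    stored = torch.tensor([-0.1, -0.1, 0.0])  # e.g. per-frame gripper_penalty storage
    val = stored[2]                           # indexing yields a 0-dim tensor
    assert val.ndim == 0
    assert tuple(val.unsqueeze(0).cpu().shape) == (1,)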
@@ -758,7 +772,7 @@
         has_done_key = "next.done" in sample
         # Check for complementary_info keys
-        complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
+        complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
         has_complementary_info = len(complementary_info_keys) > 0
         # If not, we need to infer it from episode boundaries
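
The dropped .keys() call is safe because iterating a dict yields its keys directly:

    sample = {"observation.state": 0, "complementary_info.gripper_penalty": 1}
    assert list(sample) == list(sample.keys())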
@@ -818,7 +832,13 @@
                 # Strip the "complementary_info." prefix to get the actual key
                 clean_key = key[len("complementary_info.") :]
                 val = current_sample[key]
-                complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+                # Handle tensor and non-tensor values differently
+                if isinstance(val, torch.Tensor):
+                    complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+                else:
+                    # TODO: (azouitine) Check if it's necessary to convert to tensor
+                    # For non-tensor values, use directly
+                    complementary_info[clean_key] = val
         # ----- Construct the Transition -----
         transition = Transition(
@@ -836,12 +856,13 @@
 # Utility function to guess shapes/dtypes from a tensor
-def guess_feature_info(t: torch.Tensor, name: str):
+def guess_feature_info(t, name: str):
     """
-    Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
+    Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
     If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
-    Otherwise default to 'float32' for numeric. You can customize as needed.
+    Otherwise default to an appropriate dtype for numeric values.
     """
     shape = tuple(t.shape)
     # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
     if len(shape) == 3 and shape[0] in [1, 3]:
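
The rest of the updated function falls outside this hunk; a plausible scalar-aware sketch, under the assumption that plain numeric values map to a (1,)-shaped float32 feature (the helper name and return format are illustrative, not the commit's exact body):

    import torch

    def guess_feature_info_sketch(t, name: str):
        # Hypothetical variant: accept tensors or plain Python scalars
        if isinstance(t, torch.Tensor):
            shape = tuple(t.shape)
        else:
            shape = (1,)  # assume int/float/bool scalars become a single float32 value
        if len(shape) == 3 and shape[0] in (1, 3):
            return {"dtype": "image", "shape": shape}
        return {"dtype": "float32", "shape": shape}

    assert guess_feature_info_sketch(torch.zeros(3, 64, 64), "obs")["dtype"] == "image"
    assert guess_feature_info_sketch(-0.1, "gripper_penalty")["shape"] == (1,)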
@@ -917,4 +938,91 @@ def concatenate_batch_transitions(
 if __name__ == "__main__":
     pass
+
+
+def test_load_dataset_with_complementary_info():
+    """
+    Test loading a dataset with complementary_info into a ReplayBuffer.
+
+    The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains
+    gripper_penalty values in complementary_info.
+    """
+    import time
+
+    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+    print("Loading dataset with complementary info...")
+    # Load a small subset of the dataset (first episode)
+    dataset = LeRobotDataset(
+        repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
+    )
+    print(f"Dataset loaded with {len(dataset)} frames")
+    print(f"Dataset features: {list(dataset.features.keys())}")
+
+    # Check if dataset has complementary_info.gripper_penalty
+    sample = dataset[0]
+    complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
+    print(f"Complementary info keys: {complementary_info_keys}")
+    if "complementary_info.gripper_penalty" in sample:
+        print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")
+
+    # Extract state keys for the buffer
+    state_keys = []
+    for key in sample:
+        if key.startswith("observation"):
+            state_keys.append(key)
+    print(f"Using state keys: {state_keys}")
+
+    # Create a replay buffer from the dataset
+    start_time = time.time()
+    buffer = ReplayBuffer.from_lerobot_dataset(
+        lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False
+    )
+    load_time = time.time() - start_time
+    print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
+    print(f"Buffer size: {len(buffer)}")
+
+    # Check if complementary_info was transferred correctly
+    print("Sampling from buffer to check complementary_info...")
+    batch = buffer.sample(batch_size=4)
+    if batch["complementary_info"] is not None:
+        print("Complementary info in batch:")
+        for key, value in batch["complementary_info"].items():
+            print(f"  {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
+            if key == "gripper_penalty":
+                print(f"  Sample gripper_penalty values: {value[:5]}")
+    else:
+        print("No complementary_info found in batch")
+
+    # Now convert the buffer back to a LeRobotDataset
+    print("\nConverting buffer back to LeRobotDataset...")
+    start_time = time.time()
+    new_dataset = buffer.to_lerobot_dataset(
+        repo_id="test_dataset_from_buffer",
+        fps=dataset.fps,
+        root="./test_dataset_from_buffer",
+        task_name="test_conversion",
+    )
+    convert_time = time.time() - start_time
+    print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
+    print(f"New dataset size: {len(new_dataset)} frames")
+
+    # Check if complementary_info was preserved
+    new_sample = new_dataset[0]
+    new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
+    print(f"New dataset complementary info keys: {new_complementary_info_keys}")
+    if "complementary_info.gripper_penalty" in new_sample:
+        print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")
+
+    # Compare original and new datasets
+    print("\nComparing original and new datasets:")
+    print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
+    print(f"Original features: {list(dataset.features.keys())}")
+    print(f"New features: {list(new_dataset.features.keys())}")
+
+    return buffer, dataset, new_dataset
+
+
+# Run the test
+test_load_dataset_with_complementary_info()
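
Note that the final call sits at module level, so the test also runs whenever the file is imported; a minimal alternative (not part of this commit) would route it through the existing guard:

    if __name__ == "__main__":
        test_load_dataset_with_complementary_info()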