Refactor complementary_info handling in ReplayBuffer
This commit is contained in:
parent
4621f4e4f3
commit
6c10390653
|
@ -74,11 +74,15 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
|||
|
||||
# Move complementary_info tensors if present
|
||||
if transition.get("complementary_info") is not None:
|
||||
transition["complementary_info"] = {
|
||||
key: val.to(device, non_blocking=non_blocking)
|
||||
for key, val in transition["complementary_info"].items()
|
||||
}
|
||||
|
||||
# Normalize every complementary_info entry to a tensor on the target device.
# Re-assigning existing keys while iterating .items() is safe because the
# dictionary never changes size.
for key, val in transition["complementary_info"].items():
    if isinstance(val, torch.Tensor):
        # Tensors are moved; non_blocking is a Tensor.to() option.
        transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
    elif isinstance(val, (int, float, bool)):
        # torch.tensor() does not accept a `non_blocking` keyword (passing it
        # raises TypeError), so Python scalars are created directly on the device.
        transition["complementary_info"][key] = torch.tensor(val, device=device)
    else:
        # Anything else cannot be placed on a device; fail loudly with the key name.
        raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
|
||||
|
||||
|
@ -650,6 +654,8 @@ class ReplayBuffer:
|
|||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
sample_val = self.complementary_info[key][0]
|
||||
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
|
||||
sample_val = sample_val.unsqueeze(0)
|
||||
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
|
||||
features[f"complementary_info.{key}"] = f_info
|
||||
|
||||
|
@ -689,7 +695,15 @@ class ReplayBuffer:
|
|||
# Add complementary_info if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
|
||||
val = self.complementary_info[key][actual_idx]
|
||||
# Convert tensors to CPU
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.ndim == 0:
|
||||
val = val.unsqueeze(0)
|
||||
frame_dict[f"complementary_info.{key}"] = val.cpu()
|
||||
# Non-tensor values can be used directly
|
||||
else:
|
||||
frame_dict[f"complementary_info.{key}"] = val
|
||||
|
||||
# Add task field which is required by LeRobotDataset
|
||||
frame_dict["task"] = task_name
|
||||
|
@ -758,7 +772,7 @@ class ReplayBuffer:
|
|||
has_done_key = "next.done" in sample
|
||||
|
||||
# Check for complementary_info keys
|
||||
complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
|
||||
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
|
||||
has_complementary_info = len(complementary_info_keys) > 0
|
||||
|
||||
# If not, we need to infer it from episode boundaries
|
||||
|
@ -818,7 +832,13 @@ class ReplayBuffer:
|
|||
# Strip the "complementary_info." prefix to get the actual key
|
||||
clean_key = key[len("complementary_info.") :]
|
||||
val = current_sample[key]
|
||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||
# Handle tensor and non-tensor values differently
|
||||
if isinstance(val, torch.Tensor):
|
||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||
else:
|
||||
# TODO: (azouitine) Check if it's necessary to convert to tensor
|
||||
# For non-tensor values, use directly
|
||||
complementary_info[clean_key] = val
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
|
@ -836,12 +856,13 @@ class ReplayBuffer:
|
|||
|
||||
|
||||
# Utility function to guess shapes/dtypes from a tensor
|
||||
def guess_feature_info(t: torch.Tensor, name: str):
|
||||
def guess_feature_info(t, name: str):
|
||||
"""
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
||||
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
|
||||
Otherwise default to 'float32' for numeric. You can customize as needed.
|
||||
Otherwise default to appropriate dtype for numeric.
|
||||
"""
|
||||
|
||||
shape = tuple(t.shape)
|
||||
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
|
||||
if len(shape) == 3 and shape[0] in [1, 3]:
|
||||
|
@ -917,4 +938,91 @@ def concatenate_batch_transitions(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
def test_load_dataset_with_complementary_info():
    """
    Round-trip a dataset carrying complementary_info through a ReplayBuffer.

    Loads 'aractingi/pick_lift_cube_two_cameras_gripper_penalty', builds a
    ReplayBuffer from it, samples a batch to inspect complementary_info,
    converts the buffer back into a LeRobotDataset, and prints diagnostics
    at every step.

    Returns:
        The tuple (buffer, dataset, new_dataset).
    """
    import time

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    gripper_key = "complementary_info.gripper_penalty"

    print("Loading dataset with complementary info...")
    # No episode filter is passed here, so the dataset is loaded as-is.
    dataset = LeRobotDataset(
        repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
    )

    print(f"Dataset loaded with {len(dataset)} frames")
    print(f"Dataset features: {list(dataset.features.keys())}")

    # Inspect the first frame for complementary_info entries.
    sample = dataset[0]
    info_keys = [name for name in sample if name.startswith("complementary_info")]
    print(f"Complementary info keys: {info_keys}")

    if gripper_key in sample:
        print(f"Found gripper_penalty: {sample[gripper_key]}")

    # Every observation.* entry of the frame is used as a state key.
    state_keys = [name for name in sample if name.startswith("observation")]
    print(f"Using state keys: {state_keys}")

    # Dataset -> ReplayBuffer, timed.
    t_begin = time.time()
    buffer = ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        state_keys=state_keys,
        use_drq=True,
        optimize_memory=False,
    )
    load_time = time.time() - t_begin
    print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
    print(f"Buffer size: {len(buffer)}")

    # Sample a small batch and report what complementary_info it carries.
    print("Sampling from buffer to check complementary_info...")
    batch = buffer.sample(batch_size=4)

    if batch["complementary_info"] is None:
        print("No complementary_info found in batch")
    else:
        print("Complementary info in batch:")
        for key, value in batch["complementary_info"].items():
            shape_repr = value.shape if hasattr(value, "shape") else "N/A"
            print(f" {key}: {type(value)}, shape: {shape_repr}")
            if key == "gripper_penalty":
                print(f" Sample gripper_penalty values: {value[:5]}")

    # ReplayBuffer -> LeRobotDataset, timed.
    print("\nConverting buffer back to LeRobotDataset...")
    t_begin = time.time()
    new_dataset = buffer.to_lerobot_dataset(
        repo_id="test_dataset_from_buffer",
        fps=dataset.fps,
        root="./test_dataset_from_buffer",
        task_name="test_conversion",
    )
    convert_time = time.time() - t_begin
    print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
    print(f"New dataset size: {len(new_dataset)} frames")

    # Confirm complementary_info survived the round trip.
    new_sample = new_dataset[0]
    new_info_keys = [name for name in new_sample if name.startswith("complementary_info")]
    print(f"New dataset complementary info keys: {new_info_keys}")

    if gripper_key in new_sample:
        print(f"Found gripper_penalty in new dataset: {new_sample[gripper_key]}")

    # Side-by-side summary of the original and regenerated datasets.
    print("\nComparing original and new datasets:")
    print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
    print(f"Original features: {list(dataset.features.keys())}")
    print(f"New features: {list(new_dataset.features.keys())}")

    return buffer, dataset, new_dataset
|
||||
|
||||
# Run the test
|
||||
test_load_dataset_with_complementary_info()
|
||||
|
|
Loading…
Reference in New Issue