diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 92ad7dc7..185412fd 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -74,11 +74,15 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr # Move complementary_info tensors if present if transition.get("complementary_info") is not None: - transition["complementary_info"] = { - key: val.to(device, non_blocking=non_blocking) - for key, val in transition["complementary_info"].items() - } - + for key, val in transition["complementary_info"].items(): + if isinstance(val, torch.Tensor): + transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) + elif isinstance(val, (int, float, bool)): + transition["complementary_info"][key] = torch.tensor( + val, device=device, non_blocking=non_blocking + ) + else: + raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") return transition @@ -650,6 +654,8 @@ class ReplayBuffer: if self.has_complementary_info: for key in self.complementary_info_keys: sample_val = self.complementary_info[key][0] + if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0: + sample_val = sample_val.unsqueeze(0) f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") features[f"complementary_info.{key}"] = f_info @@ -689,7 +695,15 @@ class ReplayBuffer: # Add complementary_info if available if self.has_complementary_info: for key in self.complementary_info_keys: - frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu() + val = self.complementary_info[key][actual_idx] + # Convert tensors to CPU + if isinstance(val, torch.Tensor): + if val.ndim == 0: + val = val.unsqueeze(0) + frame_dict[f"complementary_info.{key}"] = val.cpu() + # Non-tensor values can be used directly + else: + frame_dict[f"complementary_info.{key}"] = val # Add task field which is required by LeRobotDataset frame_dict["task"] = task_name @@ -758,7 +772,7 @@ class ReplayBuffer: has_done_key = "next.done" in sample # Check for complementary_info keys - complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")] + complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] has_complementary_info = len(complementary_info_keys) > 0 # If not, we need to infer it from episode boundaries @@ -818,7 +832,13 @@ class ReplayBuffer: # Strip the "complementary_info." prefix to get the actual key clean_key = key[len("complementary_info.") :] val = current_sample[key] - complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + # Handle tensor and non-tensor values differently + if isinstance(val, torch.Tensor): + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + else: + # TODO: (azouitine) Check if it's necessary to convert to tensor + # For non-tensor values, use directly + complementary_info[clean_key] = val # ----- Construct the Transition ----- transition = Transition( @@ -836,12 +856,13 @@ class ReplayBuffer: # Utility function to guess shapes/dtypes from a tensor -def guess_feature_info(t: torch.Tensor, name: str): +def guess_feature_info(t, name: str): """ - Return a dictionary with the 'dtype' and 'shape' for a given tensor or array. + Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. - Otherwise default to 'float32' for numeric. You can customize as needed. + Otherwise default to appropriate dtype for numeric. """ + shape = tuple(t.shape) # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image' if len(shape) == 3 and shape[0] in [1, 3]: @@ -917,4 +938,91 @@ def concatenate_batch_transitions( if __name__ == "__main__": - pass + + def test_load_dataset_with_complementary_info(): + """ + Test loading a dataset with complementary_info into a ReplayBuffer. + The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains + gripper_penalty values in complementary_info. + """ + import time + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + print("Loading dataset with complementary info...") + # Load a small subset of the dataset (first episode) + dataset = LeRobotDataset( + repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty", + ) + + print(f"Dataset loaded with {len(dataset)} frames") + print(f"Dataset features: {list(dataset.features.keys())}") + + # Check if dataset has complementary_info.gripper_penalty + sample = dataset[0] + complementary_info_keys = [key for key in sample if key.startswith("complementary_info")] + print(f"Complementary info keys: {complementary_info_keys}") + + if "complementary_info.gripper_penalty" in sample: + print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}") + + # Extract state keys for the buffer + state_keys = [] + for key in sample: + if key.startswith("observation"): + state_keys.append(key) + + print(f"Using state keys: {state_keys}") + + # Create a replay buffer from the dataset + start_time = time.time() + buffer = ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False + ) + load_time = time.time() - start_time + print(f"Loaded dataset into buffer in {load_time:.2f} seconds") + print(f"Buffer size: {len(buffer)}") + + # Check if complementary_info was transferred correctly + print("Sampling from buffer to check complementary_info...") + batch = buffer.sample(batch_size=4) + + if batch["complementary_info"] is not None: + print("Complementary info in batch:") + for key, value in batch["complementary_info"].items(): + print(f" {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}") + if key == "gripper_penalty": + print(f" Sample gripper_penalty values: {value[:5]}") + else: + print("No complementary_info found in batch") + + # Now convert the buffer back to a LeRobotDataset + print("\nConverting buffer back to LeRobotDataset...") + start_time = time.time() + new_dataset = buffer.to_lerobot_dataset( + repo_id="test_dataset_from_buffer", + fps=dataset.fps, + root="./test_dataset_from_buffer", + task_name="test_conversion", + ) + convert_time = time.time() - start_time + print(f"Converted buffer to dataset in {convert_time:.2f} seconds") + print(f"New dataset size: {len(new_dataset)} frames") + + # Check if complementary_info was preserved + new_sample = new_dataset[0] + new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")] + print(f"New dataset complementary info keys: {new_complementary_info_keys}") + + if "complementary_info.gripper_penalty" in new_sample: + print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}") + + # Compare original and new datasets + print("\nComparing original and new datasets:") + print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}") + print(f"Original features: {list(dataset.features.keys())}") + print(f"New features: {list(new_dataset.features.keys())}") + + return buffer, dataset, new_dataset + + # Run the test + test_load_dataset_with_complementary_info()