From 71715c3914cb890ae6beb70de371d27fd513eeb1 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 23 Apr 2025 11:42:21 +0200 Subject: [PATCH] fix hf_dataset.set_transform(hf_transform_to_torch) --- lerobot/common/datasets/lerobot_dataset.py | 4 ++-- tests/datasets/test_delta_timestamps.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 7c61e486..dddd11ab 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -736,7 +736,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for key in self.meta.video_keys: if query_indices is not None and key in query_indices: timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] - query_timestamps[key] = timestamps.tolist() + query_timestamps[key] = torch.stack(timestamps).tolist() else: query_timestamps[key] = [current_ts] @@ -744,7 +744,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: return { - key: self.hf_dataset.select(q_idx)[key] + key: torch.stack(self.hf_dataset.select(q_idx)[key]) for key, q_idx in query_indices.items() if key not in self.meta.video_keys } diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index d2ca7d13..35014642 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -56,8 +56,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n def synced_timestamps_factory(hf_dataset_factory): def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: hf_dataset = hf_dataset_factory(fps=fps) - timestamps = hf_dataset["timestamp"].numpy() - episode_indices = hf_dataset["episode_index"].numpy() + timestamps = torch.stack(hf_dataset["timestamp"]).numpy() + episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() episode_data_index = calculate_episode_data_index(hf_dataset) return timestamps, episode_indices, episode_data_index