From 9d6886dd086a1ed469e7b49ed86a17999aad7fab Mon Sep 17 00:00:00 2001
From: Remi
Date: Fri, 14 Feb 2025 14:22:22 +0100
Subject: [PATCH] Add frame level task (#693)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
---
 examples/port_datasets/pusht_zarr.py          |   3 +-
 lerobot/common/datasets/lerobot_dataset.py    | 117 +++++++++++-------
 lerobot/common/datasets/utils.py              |   4 +-
 lerobot/common/robot_devices/control_utils.py |   8 +-
 lerobot/scripts/control_robot.py              |   5 +-
 tests/test_datasets.py                        |  18 +++
 6 files changed, 105 insertions(+), 50 deletions(-)

diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py
index 1506f427..e9015d2c 100644
--- a/examples/port_datasets/pusht_zarr.py
+++ b/examples/port_datasets/pusht_zarr.py
@@ -180,6 +180,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
                     # Shift reward and success by +1 until the last item of the episode
                     "next.reward": reward[i + (frame_idx < num_frames - 1)],
                     "next.success": success[i + (frame_idx < num_frames - 1)],
+                    "task": PUSHT_TASK,
                 }
 
                 frame["observation.state"] = torch.from_numpy(agent_pos[i])
@@ -191,7 +192,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
 
                 dataset.add_frame(frame)
 
-        dataset.save_episode(task=PUSHT_TASK)
+        dataset.save_episode()
 
     dataset.consolidate()
 
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index c1a4a6d5..877703d7 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -87,7 +87,7 @@ class LeRobotDatasetMetadata:
         self.pull_from_repo(allow_patterns="meta/")
         self.info = load_info(self.root)
         self.stats = load_stats(self.root)
-        self.tasks = load_tasks(self.root)
+        self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
 
     def pull_from_repo(
@@ -202,31 +202,35 @@ class LeRobotDatasetMetadata:
         """Max number of episodes per chunk."""
         return self.info["chunks_size"]
 
-    @property
-    def task_to_task_index(self) -> dict:
-        return {task: task_idx for task_idx, task in self.tasks.items()}
-
-    def get_task_index(self, task: str) -> int:
+    def get_task_index(self, task: str) -> int | None:
         """
         Given a task in natural language, returns its task_index if the task already exists in the dataset,
-        otherwise creates a new task_index.
+        otherwise returns None.
         """
-        task_index = self.task_to_task_index.get(task, None)
-        return task_index if task_index is not None else self.total_tasks
+        return self.task_to_task_index.get(task, None)
 
-    def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
+    def add_task(self, task: str):
+        """
+        Given a task in natural language, add it to the dictionary of tasks.
+        """
+        if task in self.task_to_task_index:
+            raise ValueError(f"The task '{task}' already exists and can't be added twice.")
+
+        task_index = self.info["total_tasks"]
+        self.task_to_task_index[task] = task_index
+        self.tasks[task_index] = task
+        self.info["total_tasks"] += 1
+
+        task_dict = {
+            "task_index": task_index,
+            "task": task,
+        }
+        append_jsonlines(task_dict, self.root / TASKS_PATH)
+
+    def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length
 
-        if task_index not in self.tasks:
-            self.info["total_tasks"] += 1
-            self.tasks[task_index] = task
-            task_dict = {
-                "task_index": task_index,
-                "task": task,
-            }
-            append_jsonlines(task_dict, self.root / TASKS_PATH)
-
         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
             self.info["total_chunks"] += 1
@@ -237,7 +241,7 @@ class LeRobotDatasetMetadata:
 
         episode_dict = {
             "episode_index": episode_index,
-            "tasks": [task],
+            "tasks": episode_tasks,
             "length": episode_length,
         }
         self.episodes.append(episode_dict)
@@ -313,7 +317,8 @@ class LeRobotDatasetMetadata:
 
         features = {**features, **DEFAULT_FEATURES}
 
-        obj.tasks, obj.stats, obj.episodes = {}, {}, []
+        obj.tasks, obj.task_to_task_index = {}, {}
+        obj.stats, obj.episodes = {}, []
         obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
@@ -691,10 +696,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     def create_episode_buffer(self, episode_index: int | None = None) -> dict:
         current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
-        return {
-            "size": 0,
-            **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
-        }
+        ep_buffer = {}
+        # size and task are special cases that are not in self.features
+        ep_buffer["size"] = 0
+        ep_buffer["task"] = []
+        for key in self.features:
+            ep_buffer[key] = current_ep_idx if key == "episode_index" else []
+        return ep_buffer
 
     def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
         fpath = DEFAULT_IMAGE_PATH.format(
@@ -718,6 +726,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
        # check the dtype and shape matches, etc.
+        if "task" not in frame:
+            raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionary.")
 
         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()
@@ -728,13 +738,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer["timestamp"].append(timestamp)
 
         for key in frame:
-            if key not in self.features:
-                raise ValueError(key)
+            if key == "task":
+                # Note: we associate the task in natural language with its task index during `save_episode`
+                self.episode_buffer["task"].append(frame["task"])
+                continue
 
-            if self.features[key]["dtype"] not in ["image", "video"]:
-                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
-                self.episode_buffer[key].append(item)
-            elif self.features[key]["dtype"] in ["image", "video"]:
+            if key not in self.features:
+                raise ValueError(
+                    f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
+                )
+
+            if self.features[key]["dtype"] in ["image", "video"]:
                 img_path = self._get_image_file_path(
                     episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
                 )
@@ -742,10 +756,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 if frame_index == 0:
                     img_path.parent.mkdir(parents=True, exist_ok=True)
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
+            else:
+                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
+                self.episode_buffer[key].append(item)
 
         self.episode_buffer["size"] += 1
 
-    def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
+    def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
         """
         This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@@ -758,7 +775,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if not episode_data:
             episode_buffer = self.episode_buffer
 
+        # size and task are special cases that won't be added to hf_dataset
        episode_length = episode_buffer.pop("size")
+        tasks = episode_buffer.pop("task")
+        episode_tasks = list(set(tasks))
+
         episode_index = episode_buffer["episode_index"]
         if episode_index != self.meta.total_episodes:
             # TODO(aliberts): Add option to use existing episode_index
@@ -772,21 +793,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "You must add one or several frames with `add_frame` before calling `add_episode`."
             )
 
-        task_index = self.meta.get_task_index(task)
-
         if not set(episode_buffer.keys()) == set(self.features):
-            raise ValueError()
+            raise ValueError(
+                f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
+            )
+
+        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
+        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
+
+        # Add new tasks to the tasks dictionary
+        for task in episode_tasks:
+            task_index = self.meta.get_task_index(task)
+            if task_index is None:
+                self.meta.add_task(task)
+
+        # Given tasks in natural language, find their corresponding task indices
+        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
 
         for key, ft in self.features.items():
-            if key == "index":
-                episode_buffer[key] = np.arange(
-                    self.meta.total_frames, self.meta.total_frames + episode_length
-                )
-            elif key == "episode_index":
-                episode_buffer[key] = np.full((episode_length,), episode_index)
-            elif key == "task_index":
-                episode_buffer[key] = np.full((episode_length,), task_index)
-            elif ft["dtype"] in ["image", "video"]:
+            # index, episode_index, task_index are already processed above, and image and video
+            # are processed separately by storing the image path and frame info as metadata
+            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                 continue
             elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
                 episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
@@ -798,7 +825,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)
 
-        self.meta.save_episode(episode_index, episode_length, task, task_index)
+        self.meta.save_episode(episode_index, episode_length, episode_tasks)
 
         if encode_videos and len(self.meta.video_keys) > 0:
             video_paths = self.encode_episode_videos(episode_index)
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 612bac39..e6ec169e 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -170,7 +170,9 @@ def load_stats(local_dir: Path) -> dict:
 
 def load_tasks(local_dir: Path) -> dict:
     tasks = load_jsonlines(local_dir / TASKS_PATH)
-    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
+    return tasks, task_to_task_index
 
 
 def load_episodes(local_dir: Path) -> dict:
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 9368b89d..5dcafa69 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -183,6 +183,7 @@ def record_episode(
     device,
     use_amp,
     fps,
+    single_task,
 ):
     control_loop(
         robot=robot,
@@ -195,6 +196,7 @@ def record_episode(
         use_amp=use_amp,
         fps=fps,
         teleoperate=policy is None,
+        single_task=single_task,
     )
 
 
@@ -210,6 +212,7 @@ def control_loop(
     device: torch.device | str | None = None,
     use_amp: bool | None = None,
     fps: int | None = None,
+    single_task: str | None = None,
 ):
     # TODO(rcadene): Add option to record logs
     if not robot.is_connected:
@@ -224,6 +227,9 @@ def control_loop(
     if teleoperate and policy is not None:
         raise ValueError("When `teleoperate` is True, `policy` should be None.")
 
+    if dataset is not None and single_task is None:
+        raise ValueError("You need to provide a task as argument to `single_task`.")
+
     if dataset is not None and fps is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
 
@@ -248,7 +254,7 @@ def control_loop(
             action = {"action": action}
 
         if dataset is not None:
-            frame = {**observation, **action}
+            frame = {**observation, **action, "task": single_task}
             dataset.add_frame(frame)
 
         if display_cameras and not is_headless():
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 3fdb0acc..de67e331 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -263,8 +263,8 @@ def record(
         log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
         record_episode(
-            dataset=dataset,
             robot=robot,
+            dataset=dataset,
             events=events,
             episode_time_s=cfg.episode_time_s,
             display_cameras=cfg.display_cameras,
@@ -272,6 +272,7 @@ def record(
             device=cfg.device,
             use_amp=cfg.use_amp,
             fps=cfg.fps,
+            single_task=cfg.single_task,
         )
 
         # Execute a few seconds without recording to give time to manually reset the environment
@@ -291,7 +292,7 @@ def record(
             dataset.clear_episode_buffer()
             continue
 
-        dataset.save_episode(cfg.single_task)
+        dataset.save_episode()
         recorded_episodes += 1
 
         if events["stop_recording"]:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 2945df41..460954c1 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
     assert dataset.num_frames == len(dataset)
 
 
+def test_add_frame_no_task(tmp_path):
+    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
+    with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionary."):
+        dataset.add_frame({"1d": torch.randn(1)})
+
+
+def test_add_frame(tmp_path):
+    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
+    dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+    assert len(dataset) == 1
+    assert dataset[0]["task"] == "dummy"
+    assert dataset[0]["task_index"] == 0
+
+
 # TODO(aliberts):
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames
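
Usage note (editor's illustration, not part of the commit): after this patch, the
task is declared per frame and save_episode() takes no task argument; it collects
the tasks seen in the episode buffer and registers any new ones through
LeRobotDatasetMetadata.add_task(). A minimal sketch mirroring test_add_frame
above, where repo_id, root, and the "1d" feature are placeholder values:

    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Placeholder feature set, copied from test_add_frame.
    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = LeRobotDataset.create(repo_id="user/demo", fps=30, root="/tmp/demo", features=features)

    # Each frame now carries its own natural-language task string.
    dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})

    # Task-to-index resolution happens here, not per frame: new tasks are
    # appended to meta/tasks.jsonl via add_task().
    dataset.save_episode(encode_videos=False)
    dataset.consolidate(run_compute_stats=False)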