From 299451af81e268eae963134e7ae9b6c8213c3ed8 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 21 Oct 2024 19:30:20 +0200 Subject: [PATCH] Add add_episode & task logic --- lerobot/common/datasets/lerobot_dataset.py | 179 +++++++++++++++++- lerobot/common/datasets/utils.py | 5 + lerobot/common/robot_devices/control_utils.py | 2 +- lerobot/scripts/control_robot.py | 35 +++- 4 files changed, 203 insertions(+), 18 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 61331c5a..53b3c4af 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -19,16 +19,19 @@ from pathlib import Path from typing import Callable import datasets +import pyarrow.parquet as pq import torch import torch.utils from datasets import load_dataset -from huggingface_hub import snapshot_download +from huggingface_hub import snapshot_download, upload_folder from lerobot.common.datasets.compute_stats import aggregate_stats from lerobot.common.datasets.image_writer import ImageWriter from lerobot.common.datasets.utils import ( + append_jsonl, check_delta_timestamps, check_timestamps_sync, + create_branch, create_empty_dataset_info, get_delta_indices, get_episode_data_index, @@ -160,6 +163,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.video_backend = video_backend if video_backend is not None else "pyav" self.image_writer = image_writer self.episode_buffer = {} + self.consolidated = True self.delta_indices = None # Load metadata @@ -192,6 +196,24 @@ class LeRobotDataset(torch.utils.data.Dataset): # - [ ] Update episode_index (arg update=True) # - [ ] Update info.json (arg update=True) + def push_to_repo(self, push_videos: bool = True) -> None: + if not self.consolidated: + raise RuntimeError( + "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet." + "Please use the '.consolidate()' method first." 
+            )
+        ignore_patterns = ["images/"]
+        if not push_videos:
+            ignore_patterns.append("videos/")
+
+        upload_folder(
+            repo_id=self.repo_id,
+            folder_path=self.root,
+            repo_type="dataset",
+            ignore_patterns=ignore_patterns,
+        )
+        create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
+
     def pull_from_repo(
         self,
         allow_patterns: list[str] | str | None = None,
@@ -303,11 +325,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Number of samples/frames in selected episodes."""
         return len(self.hf_dataset)

-    @property
-    def total_frames(self) -> int:
-        """Total number of frames saved in this dataset."""
-        return self.info["total_frames"]
-
     @property
     def num_episodes(self) -> int:
         """Number of episodes selected."""
@@ -318,6 +335,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Total number of episodes available."""
         return self.info["total_episodes"]

+    @property
+    def total_frames(self) -> int:
+        """Total number of frames saved in this dataset."""
+        return self.info["total_frames"]
+
+    @property
+    def total_tasks(self) -> int:
+        """Total number of different tasks performed in this dataset."""
+        return self.info["total_tasks"]
+
     @property
     def total_chunks(self) -> int:
         """Total number of chunks (groups of episodes)."""
@@ -331,7 +358,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def shapes(self) -> dict:
         """Shapes for the different features."""
-        self.info.get("shapes")
+        return self.info["shapes"]
+
+    @property
+    def features(self) -> datasets.Features:
+        """Features of this dataset, taken from the hf_dataset or inferred from the episode_buffer."""
+        if self.hf_dataset is not None:
+            return self.hf_dataset.features
+        elif self.episode_buffer is None:
+            raise NotImplementedError(
+                "Dataset features must be inferred from an existing hf_dataset or episode_buffer."
+            )
+
+        features = {}
+        for key in self.episode_buffer:
+            if key in ["episode_index", "frame_index", "index", "task_index"]:
+                features[key] = datasets.Value(dtype="int64")
+            elif key in ["next.done", "next.success"]:
+                features[key] = datasets.Value(dtype="bool")
+            elif key in ["timestamp", "next.reward"]:
+                features[key] = datasets.Value(dtype="float32")
+            elif key in self.image_keys:
+                features[key] = datasets.Image()
+            elif key in self.keys:
+                features[key] = datasets.Sequence(
+                    length=self.shapes[key], feature=datasets.Value(dtype="float32")
+                )
+
+        return datasets.Features(features)
+
+    @property
+    def task_to_task_index(self) -> dict:
+        return {task: task_idx for task_idx, task in self.tasks.items()}
+
+    def get_task_index(self, task: str) -> int:
+        """
+        Given a task in natural language, returns its task_index if the task already exists in the dataset,
+        otherwise returns the index it will be assigned once added (i.e. the current total_tasks).
+ """ + task_index = self.task_to_task_index.get(task, None) + return task_index if task_index is not None else self.total_tasks def current_episode_index(self, idx: int) -> int: episode_index = self.hf_dataset["episode_index"][idx] @@ -447,12 +513,12 @@ class LeRobotDataset(torch.utils.data.Dataset): f")" ) - def _create_episode_buffer(self) -> dict: + def _create_episode_buffer(self, episode_index: int | None = None) -> dict: # TODO(aliberts): Handle resume return { - "chunk": self.total_chunks, - "episode_index": self.total_episodes, "size": 0, + "episode_index": self.total_episodes if episode_index is None else episode_index, + "task_index": None, "frame_index": [], "timestamp": [], "next.done": [], @@ -490,6 +556,92 @@ class LeRobotDataset(torch.utils.data.Dataset): file_path=img_path, ) + def add_episode(self, task: str, encode_videos: bool = False) -> None: + """ + This will save to disk the current episode in self.episode_buffer. Note that since it affects files on + disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to + the hub. + + Use encode_videos if you want to encode videos during the saving of each episode. Otherwise, + you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend + time for video encoding. + """ + episode_length = self.episode_buffer.pop("size") + episode_index = self.episode_buffer["episode_index"] + task_index = self.get_task_index(task) + self.episode_buffer["next.done"][-1] = True + + for key in self.episode_buffer: + if key in self.keys: + self.episode_buffer[key] = torch.stack(self.episode_buffer[key]) + elif key == "episode_index": + self.episode_buffer[key] = torch.full((episode_length,), episode_index) + elif key == "task_index": + self.episode_buffer[key] = torch.full((episode_length,), task_index) + else: + self.episode_buffer[key] = torch.tensor(self.episode_buffer[key]) + + self._save_episode_to_metadata(episode_index, episode_length, task, task_index) + self._save_episode_table(episode_index) + + if encode_videos: + pass # TODO + + # Reset the buffer + self.episode_buffer = self._create_episode_buffer() + self.consolidated = False + + def _save_episode_table(self, episode_index: int) -> None: + features = self.features + ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train") + ep_table = ep_dataset._data.table + ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False) + ep_data_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(ep_table, ep_data_path) + + def _save_episode_to_metadata( + self, episode_index: int, episode_length: int, task: str, task_index: int + ) -> None: + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + + if task_index not in self.tasks: + self.info["total_tasks"] += 1 + self.tasks[task_index] = task + task_dict = { + "task_index": task_index, + "task": task, + } + append_jsonl(task_dict, self.root / "meta/tasks.jsonl") + + chunk = self.get_episode_chunk(episode_index) + if chunk >= self.total_chunks: + self.info["total_chunks"] += 1 + + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + self.info["total_videos"] += len(self.video_keys) + write_json(self.info, self.root / "meta/info.json") + + episode_dict = { + "episode_index": episode_index, + "tasks": [task], + "length": episode_length, + } + append_jsonl(episode_dict, self.root / "meta/episodes.jsonl") + + def delete_episode(self) -> None: + pass # TODO + + 
    def consolidate(self) -> None:
+        pass  # TODO
+        # Sanity checks:
+        # - [ ] shapes
+        # - [ ] ep_lengths
+        # - [ ] number of files
+        # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
+        # - [ ] no remaining self.image_writer.dir
+        self.consolidated = True
+
     @classmethod
     def create(
         cls,
@@ -508,6 +660,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer
+        obj.hf_dataset = None

         if not all(cam.fps == fps for cam in robot.cameras.values()):
             logging.warn(
@@ -515,12 +668,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
             )

+        obj.tasks = {}
         obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
         write_json(obj.info, obj.root / "meta/info.json")

         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()

+        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
+        # It is used to know when certain operations are needed (for instance, computing dataset statistics).
+        # In order to be able to push the dataset to the hub, it needs to be consolidated first.
+        obj.consolidated = True
+
         # obj.episodes = None
         # obj.image_transforms = None
         # obj.delta_timestamps = None
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 79459882..8985e449 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -81,6 +81,11 @@ def write_json(data: dict, fpath: Path) -> None:
         json.dump(data, f, indent=4, ensure_ascii=False)


+def append_jsonl(data: dict, fpath: Path) -> None:
+    with jsonlines.open(fpath, "a") as writer:
+        writer.write(data)
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors.
Importantly, images are converted from PIL, which corresponds to diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 6a8805dc..9bcdaea3 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -248,7 +248,7 @@ def control_loop( if teleoperate and policy is not None: raise ValueError("When `teleoperate` is True, `policy` should be None.") - if dataset is not None and fps is not None and dataset["fps"] != fps: + if dataset is not None and fps is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") timestamp = 0 diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 3d9073b0..86233251 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -109,8 +109,6 @@ from lerobot.common.datasets.image_writer import ImageWriter from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.populate_dataset import ( create_lerobot_dataset, - delete_current_episode, - save_current_episode, ) from lerobot.common.robot_devices.control_utils import ( control_loop, @@ -195,6 +193,7 @@ def record( robot: Robot, root: str, repo_id: str, + single_task: str, pretrained_policy_name_or_path: str | None = None, policy_overrides: List[str] | None = None, fps: int | None = None, @@ -219,6 +218,11 @@ def record( device = None use_amp = None + if single_task: + task = single_task + else: + raise NotImplementedError("Only single-task recording is supported for now") + # Load pretrained policy if pretrained_policy_name_or_path is not None: policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) @@ -235,8 +239,8 @@ def record( sanity_check_dataset_name(repo_id, policy) image_writer = ImageWriter( write_dir=root, - num_image_writer_processes=num_image_writer_processes, - num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras, + num_processes=num_image_writer_processes, + num_threads=num_image_writer_threads_per_camera * robot.num_cameras, ) dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer) @@ -261,7 +265,12 @@ def record( if recorded_episodes >= num_episodes: break - episode_index = dataset["num_episodes"] + # TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if + # input() messes with them. 
+ # if multi_task: + # task = input("Enter your task description: ") + + episode_index = dataset.episode_buffer["episode_index"] log_say(f"Recording episode {episode_index}", play_sounds) record_episode( dataset=dataset, @@ -289,11 +298,11 @@ def record( log_say("Re-record episode", play_sounds) events["rerecord_episode"] = False events["exit_early"] = False - delete_current_episode(dataset) + dataset.delete_episode() continue # Increment by one dataset["current_episode_index"] - save_current_episode(dataset) + dataset.add_episode(task) if events["stop_recording"]: break @@ -378,9 +387,21 @@ if __name__ == "__main__": ) parser_record = subparsers.add_parser("record", parents=[base_parser]) + task_args = parser_record.add_mutually_exclusive_group(required=True) parser_record.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" ) + task_args.add_argument( + "--single-task", + type=str, + help="A short but accurate description of the task performed during the recording.", + ) + # TODO(aliberts): add multi-task support + # task_args.add_argument( + # "--multi-task", + # type=int, + # help="You will need to enter the task performed at the start of each episode.", + # ) parser_record.add_argument( "--root", type=Path,
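
For illustration, here is how the pieces added in this patch (`add_episode`, `consolidate`, `push_to_repo` and the `consolidated` flag) are meant to fit together. This is a hedged sketch, not code from the patch: the frame-capture step is elided because it goes through `record_episode`/`control_loop`, `consolidate()` is still a TODO beyond setting the flag, and the `create()` call simply mirrors the one in `record()` above.

```python
# Sketch of the intended lifecycle of the new API (method names are taken from this patch;
# `robot` and `image_writer` are assumed to be configured exactly as in record()).
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def record_one_episode_and_upload(repo_id: str, fps: int, robot, image_writer, task: str) -> None:
    dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)

    # ... teleoperate or run a policy here; frames accumulate in dataset.episode_buffer ...

    # Writes the episode parquet file, updates meta/info.json, meta/tasks.jsonl and
    # meta/episodes.jsonl, resets the buffer and flips dataset.consolidated to False.
    dataset.add_episode(task)

    # Intended to run the sanity checks listed in the TODO (and optional video encoding);
    # for now it only sets consolidated back to True, which push_to_repo requires.
    dataset.consolidate()
    dataset.push_to_repo(push_videos=True)
```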
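
The task logic itself is small: `task_to_task_index` inverts `self.tasks`, and a task that is not yet known gets the next free index, i.e. the current `total_tasks`. A standalone illustration of that behaviour, with made-up task strings:

```python
# Standalone illustration of get_task_index / task_to_task_index (task strings are made up).
tasks = {0: "pick up the cube", 1: "open the drawer"}  # mirrors self.tasks: task_index -> task


def get_task_index(task: str) -> int:
    task_to_task_index = {t: i for i, t in tasks.items()}
    task_index = task_to_task_index.get(task, None)
    return task_index if task_index is not None else len(tasks)  # len(tasks) plays the role of total_tasks


assert get_task_index("open the drawer") == 1   # known task keeps its index
assert get_task_index("fold the t-shirt") == 2  # unknown task gets the next free index
```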
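
Each call to `add_episode` appends one record to `meta/episodes.jsonl` (and, for a previously unseen task, one to `meta/tasks.jsonl`) via `append_jsonl`. A minimal sketch of what ends up on disk, reusing the record shapes from `_save_episode_to_metadata`; the root path and the values are illustrative:

```python
# Minimal sketch of the per-episode metadata written by _save_episode_to_metadata.
# The record shapes come from the patch; the root path and values are illustrative.
from pathlib import Path

import jsonlines  # same dependency used by append_jsonl

root = Path("/tmp/lerobot_example")  # hypothetical dataset root
(root / "meta").mkdir(parents=True, exist_ok=True)

task_dict = {"task_index": 0, "task": "pick up the cube"}
episode_dict = {"episode_index": 0, "tasks": ["pick up the cube"], "length": 250}

with jsonlines.open(root / "meta/tasks.jsonl", "a") as writer:
    writer.write(task_dict)
with jsonlines.open(root / "meta/episodes.jsonl", "a") as writer:
    writer.write(episode_dict)

with jsonlines.open(root / "meta/episodes.jsonl") as reader:
    print(list(reader))  # [{'episode_index': 0, 'tasks': ['pick up the cube'], 'length': 250}]
```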
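
`_save_episode_table`, combined with the `features` property, boils down to: describe each buffer key with a `datasets` feature type, build an in-memory `Dataset` from the buffer, and write its Arrow table as the episode's parquet file. A toy, self-contained version of that path; the keys, shapes and output filename are made up, image/video keys are left out, and the public `ep_dataset.data.table` stands in for the private `_data.table` used above:

```python
# Toy version of features + _save_episode_table: map buffer keys to feature types,
# then dump the episode as a parquet file. Keys, shapes and the filename are illustrative.
import datasets
import pyarrow.parquet as pq

episode_buffer = {
    "episode_index": [0, 0, 0],
    "frame_index": [0, 1, 2],
    "task_index": [0, 0, 0],
    "timestamp": [0.0, 0.033, 0.066],
    "next.done": [False, False, True],
    "observation.state": [[0.0] * 6, [0.1] * 6, [0.2] * 6],  # pretend this feature has shape (6,)
}

features = datasets.Features(
    {
        "episode_index": datasets.Value(dtype="int64"),
        "frame_index": datasets.Value(dtype="int64"),
        "task_index": datasets.Value(dtype="int64"),
        "timestamp": datasets.Value(dtype="float32"),
        "next.done": datasets.Value(dtype="bool"),
        "observation.state": datasets.Sequence(length=6, feature=datasets.Value(dtype="float32")),
    }
)

ep_dataset = datasets.Dataset.from_dict(episode_buffer, features=features, split="train")
pq.write_table(ep_dataset.data.table, "episode_000000.parquet")
```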
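
On the CLI side, `--single-task` is registered in a required mutually exclusive group so that a future `--multi-task` flag can sit alongside it without both being accepted at once. A stripped-down reproduction of that pattern; the `prog` string and the example invocation are illustrative:

```python
# Stripped-down reproduction of the --single-task argument group added to the record
# sub-parser (prog and the example invocation are illustrative, not the real CLI).
import argparse

parser = argparse.ArgumentParser(prog="control_robot.py record")
task_args = parser.add_mutually_exclusive_group(required=True)
task_args.add_argument(
    "--single-task",
    type=str,
    help="A short but accurate description of the task performed during the recording.",
)
# A future --multi-task flag would be added to the same group, so argparse rejects
# invocations that pass both or neither.

args = parser.parse_args(["--single-task", "Pick up the cube and place it in the bin."])
print(args.single_task)
```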