From a805458c7eb57b93522610c7f3fa79e204567725 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Tue, 22 Oct 2024 19:57:52 +0200
Subject: [PATCH] Add local_files_only, encode_videos, fix bugs to pass tests (WIP)

---
 lerobot/common/datasets/lerobot_dataset.py | 127 ++++++++++++++++-----
 lerobot/common/datasets/utils.py           |  66 ++++++++---
 lerobot/common/datasets/video_utils.py     |   4 +-
 lerobot/scripts/control_robot.py           |  66 ++++++-----
 4 files changed, 183 insertions(+), 80 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index ffbcf0fb..ad5a37cf 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -17,6 +17,7 @@
 import json
 import logging
 import os
 import shutil
+from functools import cached_property
 from pathlib import Path
 from typing import Callable
@@ -30,20 +31,32 @@
 from huggingface_hub import snapshot_download, upload_folder

 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
+    EPISODES_PATH,
+    INFO_PATH,
+    TASKS_PATH,
     append_jsonl,
     check_delta_timestamps,
     check_timestamps_sync,
+    check_version_compatibility,
     create_branch,
     create_empty_dataset_info,
+    flatten_dict,
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
     hf_transform_to_torch,
-    load_metadata,
+    load_episode_dicts,
+    load_info,
+    load_stats,
+    load_tasks,
     unflatten_dict,
     write_json,
 )
-from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
+from lerobot.common.datasets.video_utils import (
+    VideoFrame,
+    decode_video_frames_torchvision,
+    encode_video_frames,
+)
 from lerobot.common.robot_devices.robots.utils import Robot

 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
@@ -61,6 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
         download_videos: bool = True,
+        local_files_only: bool = False,
         video_backend: str | None = None,
         image_writer: ImageWriter | None = None,
     ):
@@ -162,21 +176,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
-        self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
         self.delta_indices = None
         self.consolidated = True
         self.episode_buffer = {}
+        self.local_files_only = local_files_only

         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
-        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
         self.pull_from_repo(allow_patterns="meta/")
-        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
+        self.info = load_info(self.root)
+        self.stats = load_stats(self.root)
+        self.tasks = load_tasks(self.root)
+        self.episode_dicts = load_episode_dicts(self.root)
+
+        # Check version
+        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)

         # Load actual data
-        self.download_episodes()
+        self.download_episodes(download_videos)
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
@@ -199,6 +218,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)

+    @cached_property
+    def _hub_version(self) -> str | None:
+        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
+
+    @property
+    def _version(self) -> str:
+        """Codebase version used to create this dataset."""
+        return self.info["codebase_version"]
+
     def push_to_repo(self, push_videos: bool = True) -> None:
         if not self.consolidated:
             raise RuntimeError(
@@ -225,13 +253,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self._version,
+            revision=self._hub_version,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
+            local_files_only=self.local_files_only,
         )

-    def download_episodes(self) -> None:
+    def download_episodes(self, download_videos: bool = True) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
         dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -240,10 +269,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
         files = None
-        ignore_patterns = None if self.download_videos else "videos/"
+        ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
             files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
-            if len(self.video_keys) > 0 and self.download_videos:
+            if len(self.video_keys) > 0 and download_videos:
                 video_files = [
                     self.get_video_file_path(ep_idx, vid_key)
                     for vid_key in self.video_keys
@@ -495,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             item = {**video_frames, **item}

         if self.image_transforms is not None:
-            image_keys = self.camera_keys if self.download_videos else self.image_keys
+            image_keys = self.camera_keys
             for cam in image_keys:
                 item[cam] = self.image_transforms(item[cam])

@@ -521,6 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "timestamp": [],
             "next.done": [],
             **{key: [] for key in self.keys},
+            **{key: [] for key in self.image_keys},
         }

     def add_frame(self, frame: dict) -> None:
@@ -553,6 +583,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 image=frame[cam_key],
                 file_path=img_path,
             )
+            if cam_key in self.image_keys:
+                self.episode_buffer[cam_key].append(str(img_path))

     def add_episode(self, task: str, encode_videos: bool = False) -> None:
         """
@@ -574,6 +606,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer["next.done"][-1] = True

         for key in self.episode_buffer:
+            if key in self.image_keys:
+                continue
             if key in self.keys:
                 self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
             elif key == "episode_index":
@@ -583,11 +617,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
             else:
                 self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])

+        self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
         self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
         self._save_episode_table(episode_index)

-        if encode_videos:
-            pass  # TODO
+        if encode_videos and len(self.video_keys) > 0:
+            self.encode_videos()

         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
@@ -614,7 +649,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "task_index": task_index,
                 "task": task,
             }
-            append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
+            append_jsonl(task_dict, self.root / TASKS_PATH)

         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
-        write_json(self.info, self.root / "meta/info.json")
+        write_json(self.info, self.root / INFO_PATH)

         episode_dict = {
             "episode_index": episode_index,
             "tasks": [task],
             "length": episode_length,
         }
-        append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
+        self.episode_dicts.append(episode_dict)
+        append_jsonl(episode_dict, self.root / EPISODES_PATH)

     def delete_episode(self) -> None:
         episode_index = self.episode_buffer["episode_index"]
         if self.image_writer is not None:
             for cam_key in self.camera_keys:
-                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
-                if cam_dir.is_dir():
-                    shutil.rmtree(cam_dir)
+                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
+                if img_dir.is_dir():
+                    shutil.rmtree(img_dir)

         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
@@ -653,27 +689,54 @@ class LeRobotDataset(torch.utils.data.Dataset):
             updated_file_name = self.get_data_file_path(ep_idx)
             current_file_name.rename(updated_file_name)

+    def _remove_image_writer(self) -> None:
+        if self.image_writer is not None:
+            self.image_writer = None
+
+    def encode_videos(self) -> None:
+        # Use ffmpeg to convert frames stored as png into mp4 videos
+        for episode_index in range(self.num_episodes):
+            for key in self.video_keys:
+                # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
+                # to call self.image_writer here
+                tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
+                video_path = self.get_video_file_path(episode_index, key, return_str=False)
+                if video_path.is_file():
+                    # Skip if video is already encoded. Could be the case when resuming data recording.
+                    continue
+                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
+                # since video encoding with ffmpeg is already using multithreading.
+                encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
+                shutil.rmtree(tmp_imgs_dir)
+
     def consolidate(self, run_compute_stats: bool = True) -> None:
         self._update_data_file_names()
+        self.hf_dataset = self.load_hf_dataset()
+        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
+        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+
+        if len(self.video_keys) > 0:
+            self.encode_videos()
+
         if run_compute_stats:
             logging.info("Computing dataset statistics")
-            self.hf_dataset = self.load_hf_dataset()
+            self._remove_image_writer()
             self.stats = compute_stats(self)
-            serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
+            serialized_stats = flatten_dict(self.stats)
+            serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
             serialized_stats = unflatten_dict(serialized_stats)
             write_json(serialized_stats, self.root / "meta/stats.json")
+            self.consolidated = True
         else:
             logging.warning("Skipping computation of the dataset statistics.")

-        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
-        pass  # TODO
+        # TODO(aliberts)
         # Sanity checks:
         # - [ ] shapes
         # - [ ] ep_lenghts
         # - [ ] number of files
         # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
         # - [ ] no remaining self.image_writer.dir
-        self.consolidated = True

     @classmethod
     def create(
@@ -691,7 +754,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
-        obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer

@@ -702,21 +764,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )

         obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
-        obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
-        write_json(obj.info, obj.root / "meta/info.json")
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
+        write_json(obj.info, obj.root / INFO_PATH)

         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()

-        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
-        # It is used to know when certain operations are need (for instance, computing dataset statistics).
-        # In order to be able to push the dataset to the hub, it needs to be consolidation first.
+        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
+        # is used to know when certain operations are needed (for instance, computing dataset statistics). In
+        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
+        # self.consolidate().
         obj.consolidated = True
+        obj.local_files_only = True
+        obj.download_videos = False

+        obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None
         obj.delta_timestamps = None
+        obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 8985e449..8625808e 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -30,6 +30,12 @@
 from torchvision import transforms

 from lerobot.common.robot_devices.robots.utils import Robot

 DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
+
+INFO_PATH = "meta/info.json"
+EPISODES_PATH = "meta/episodes.jsonl"
+STATS_PATH = "meta/stats.json"
+TASKS_PATH = "meta/tasks.jsonl"
+
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = (
     "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
 )
@@ -104,6 +110,32 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict


+def _get_major_minor(version: str) -> tuple[int, int]:
+    split = version.strip("v").split(".")
+    return int(split[0]), int(split[1])
+
+
+def check_version_compatibility(
+    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
+) -> None:
+    current_major, _ = _get_major_minor(current_version)
+    major_to_check, _ = _get_major_minor(version_to_check)
+    if major_to_check < current_major and enforce_breaking_major:
+        raise ValueError(
+            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
+            format with v2.0 that is not backward compatible. Please use our conversion script
+            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
+        )
+    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
+        warnings.warn(
+            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
+            codebase. The current codebase version is {current_version}. You should be fine since
+            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
+            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
+            stacklevel=1,
+        )
+
+
 def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     num_version = float(version.strip("v"))
     if num_version < 2 and enforce_v2:
@@ -131,30 +163,28 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     return version


-def load_metadata(local_dir: Path) -> tuple[dict | list]:
-    """Loads metadata files from a dataset."""
-    info_path = local_dir / "meta/info.json"
-    episodes_path = local_dir / "meta/episodes.jsonl"
-    stats_path = local_dir / "meta/stats.json"
-    tasks_path = local_dir / "meta/tasks.jsonl"
+def load_info(local_dir: Path) -> dict:
+    with open(local_dir / INFO_PATH) as f:
+        return json.load(f)

-    with open(info_path) as f:
-        info = json.load(f)
-    with jsonlines.open(episodes_path, "r") as reader:
-        episode_dicts = list(reader)
-
-    with open(stats_path) as f:
+def load_stats(local_dir: Path) -> dict:
+    with open(local_dir / STATS_PATH) as f:
         stats = json.load(f)
+    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)

-    with jsonlines.open(tasks_path, "r") as reader:
+
+def load_tasks(local_dir: Path) -> dict:
+    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
         tasks = list(reader)

-    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
-    stats = unflatten_dict(stats)
-    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}

-    return info, episode_dicts, stats, tasks
+
+def load_episode_dicts(local_dir: Path) -> dict:
+    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
+        return list(reader)


 def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
@@ -229,7 +259,7 @@ def check_timestamps_sync(
     # Track original indices before masking
     original_indices = torch.arange(len(diffs))
     filtered_indices = original_indices[mask]
-    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
+    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance)  # .squeeze()
     outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
     episode_indices = torch.stack(hf_dataset["episode_index"])
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index 6a606415..b5d634ba 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -126,8 +126,8 @@ def decode_video_frames_torchvision(


 def encode_video_frames(
-    imgs_dir: Path,
-    video_path: Path,
+    imgs_dir: Path | str,
+    video_path: Path | str,
     fps: int,
     vcodec: str = "libsvtav1",
     pix_fmt: str = "yuv420p",
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 62d6760b..5bf427f4 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -194,19 +194,17 @@ def record(
     pretrained_policy_name_or_path: str | None = None,
     policy_overrides: List[str] | None = None,
     fps: int | None = None,
-    warmup_time_s=2,
-    episode_time_s=10,
-    reset_time_s=5,
-    num_episodes=50,
-    video=True,
-    run_compute_stats=True,
-    push_to_hub=True,
-    tags=None,
-    num_image_writer_processes=0,
-    num_image_writer_threads_per_camera=4,
-    force_override=False,
-    display_cameras=True,
-    play_sounds=True,
+    warmup_time_s: int | float = 2,
+    episode_time_s: int | float = 10,
+    reset_time_s: int | float = 5,
+    num_episodes: int = 50,
+    video: bool = True,
+    run_compute_stats: bool = True,
+    push_to_hub: bool = True,
+    num_image_writer_processes: int = 0,
+    num_image_writer_threads_per_camera: int = 4,
+    display_cameras: bool = True,
+    play_sounds: bool = True,
 ) -> LeRobotDataset:
     # TODO(rcadene): Add option to record logs
     listener = None
@@ -234,12 +232,18 @@ def record(
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    image_writer = ImageWriter(
-        write_dir=root,
-        num_processes=num_image_writer_processes,
-        num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
+    if len(robot.cameras) > 0:
+        image_writer = ImageWriter(
+            write_dir=root,
+            num_processes=num_image_writer_processes,
+            num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
+        )
+    else:
+        image_writer = None
+
+    dataset = LeRobotDataset.create(
+        repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
     )
-    dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)

     if not robot.is_connected:
         robot.connect()
@@ -307,8 +311,9 @@ def record(
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)

-    logging.info("Waiting for image writer to terminate...")
-    dataset.image_writer.stop()
+    if dataset.image_writer is not None:
+        logging.info("Waiting for image writer to terminate...")
+        dataset.image_writer.stop()

     dataset.consolidate(run_compute_stats)

@@ -322,27 +327,28 @@ def record(

 @safe_disconnect
 def replay(
-    robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
+    robot: Robot,
+    root: Path,
+    repo_id: str,
+    episode: int,
+    fps: int | None = None,
+    play_sounds: bool = True,
+    local_files_only: bool = True,
 ):
     # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
     # TODO(rcadene): Add option to record logs
-    local_dir = Path(root) / repo_id
-    if not local_dir.exists():
-        raise ValueError(local_dir)
-    dataset = LeRobotDataset(repo_id, root=root)
-    items = dataset.hf_dataset.select_columns("action")
-    from_idx = dataset.episode_data_index["from"][episode].item()
-    to_idx = dataset.episode_data_index["to"][episode].item()
+    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
+    actions = dataset.hf_dataset.select_columns("action")

     if not robot.is_connected:
         robot.connect()

     log_say("Replaying episode", play_sounds, blocking=True)
-    for idx in range(from_idx, to_idx):
+    for idx in range(dataset.num_samples):
         start_episode_t = time.perf_counter()
-        action = items[idx]["action"]
+        action = actions[idx]["action"]
         robot.send_action(action)
         dt_s = time.perf_counter() - start_episode_t
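# --- Illustrative sketch, not part of the patch above -----------------------------
# The replay() hunk is truncated right after `dt_s` is measured. The loop it belongs
# to replays recorded actions at a fixed rate, and a common way to finish such a loop
# is to wait out the remainder of each 1/fps period. The snippet below is a hedged,
# self-contained sketch of that pacing pattern only; `busy_wait_sketch` and
# `replay_at_fixed_rate` are hypothetical names, not functions from the LeRobot codebase.
import time


def busy_wait_sketch(seconds: float) -> None:
    # Spin rather than sleep to keep timing jitter low on short waits.
    end = time.perf_counter() + seconds
    while time.perf_counter() < end:
        pass


def replay_at_fixed_rate(actions, send_action, fps: int = 30) -> None:
    # Send each action, then wait out the rest of the 1/fps period.
    period = 1.0 / fps
    for action in actions:
        start_t = time.perf_counter()
        send_action(action)  # stands in for robot.send_action(action)
        dt_s = time.perf_counter() - start_t
        if period - dt_s > 0:
            busy_wait_sketch(period - dt_s)


if __name__ == "__main__":
    # Toy usage: "send" ten dummy actions at 10 Hz by printing them.
    replay_at_fixed_rate(range(10), send_action=print, fps=10)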