From a2a8538ac97407d3fa65da4842a33ddbd7f78e82 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Wed, 23 Oct 2024 11:38:07 +0200
Subject: [PATCH] add write_stats, change names, add some typing

---
 lerobot/common/datasets/factory.py         |  3 +-
 lerobot/common/datasets/image_writer.py    |  5 +--
 lerobot/common/datasets/lerobot_dataset.py | 36 ++++++++++------------
 lerobot/common/datasets/utils.py           | 10 ++++--
 lerobot/scripts/control_robot.py           |  5 ++-
 5 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 96a353fb..04b6e57b 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -91,9 +91,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
     )
 
     if isinstance(cfg.dataset_repo_id, str):
+        # TODO (aliberts): add 'episodes' arg from config after removing hydra
         dataset = LeRobotDataset(
             cfg.dataset_repo_id,
-            split=split,
             delta_timestamps=cfg.training.get("delta_timestamps"),
             image_transforms=image_transforms,
             video_backend=cfg.video_backend,
@@ -101,7 +101,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
     else:
         dataset = MultiLeRobotDataset(
             cfg.dataset_repo_id,
-            split=split,
             delta_timestamps=cfg.training.get("delta_timestamps"),
             image_transforms=image_transforms,
             video_backend=cfg.video_backend,
diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py
index 09f803e2..0900d910 100644
--- a/lerobot/common/datasets/image_writer.py
+++ b/lerobot/common/datasets/image_writer.py
@@ -21,7 +21,7 @@ import torch
 import tqdm
 from PIL import Image
 
-DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
 
 
 def safe_stop_image_writer(func):
@@ -54,7 +54,8 @@ class ImageWriter:
     """
 
     def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
-        self.dir = write_dir
+        self.dir = write_dir / "images"
+        self.dir.mkdir(parents=True, exist_ok=True)
         self.image_path = DEFAULT_IMAGE_PATH
         self.num_processes = num_processes
         self.num_threads = self.num_threads_per_process = num_threads
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 9721cd62..0c62756e 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -33,6 +33,7 @@ from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
     EPISODES_PATH,
     INFO_PATH,
+    STATS_PATH,
     TASKS_PATH,
     append_jsonl,
     check_delta_timestamps,
@@ -40,7 +41,6 @@ from lerobot.common.datasets.utils import (
     check_version_compatibility,
     create_branch,
     create_empty_dataset_info,
-    flatten_dict,
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
@@ -49,8 +49,8 @@ from lerobot.common.datasets.utils import (
     load_info,
     load_stats,
     load_tasks,
-    unflatten_dict,
     write_json,
+    write_stats,
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
@@ -227,11 +227,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Codebase version used to create this dataset."""
         return self.info["codebase_version"]
 
-    def push_to_repo(self, push_videos: bool = True) -> None:
+    def push_to_hub(self, push_videos: bool = True) -> None:
         if not self.consolidated:
             raise RuntimeError(
                 "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
- "Please use the '.consolidate()' method first." + "Please call the dataset 'consolidate()' method first." ) ignore_patterns = ["images/"] if not push_videos: @@ -675,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset): # Reset the buffer self.episode_buffer = self._create_episode_buffer() - def _remove_image_writer(self) -> None: + def read_mode(self) -> None: + """Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first.""" + # TODO(aliberts, rcadene): find better api/interface for this. if self.image_writer is not None: self.image_writer = None @@ -693,9 +695,8 @@ class LeRobotDataset(torch.utils.data.Dataset): # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, # since video encoding with ffmpeg is already using multithreading. encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True) - shutil.rmtree(tmp_imgs_dir) - def consolidate(self, run_compute_stats: bool = True) -> None: + def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) @@ -703,14 +704,13 @@ class LeRobotDataset(torch.utils.data.Dataset): if len(self.video_keys) > 0: self.encode_videos() + if not keep_image_files: + shutil.rmtree(self.image_writer.dir) + if run_compute_stats: - logging.info("Computing dataset statistics") - self._remove_image_writer() + self.read_mode() self.stats = compute_stats(self) - serialized_stats = flatten_dict(self.stats) - serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()} - serialized_stats = unflatten_dict(serialized_stats) - write_json(serialized_stats, self.root / "meta/stats.json") + write_stats(self.stats, self.root / STATS_PATH) self.consolidated = True else: logging.warning("Skipping computation of the dataset statistics.") @@ -784,8 +784,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): def __init__( self, repo_ids: list[str], - root: Path | None = LEROBOT_HOME, - split: str = "train", + root: Path | None = None, + episodes: dict | None = None, image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, video_backend: str | None = None, @@ -797,8 +797,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): self._datasets = [ LeRobotDataset( repo_id, - root=root, - split=split, + root=root / repo_id if root is not None else None, + episodes=episodes[repo_id] if episodes is not None else None, delta_timestamps=delta_timestamps, image_transforms=image_transforms, video_backend=video_backend, @@ -834,7 +834,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): self.disabled_data_keys.update(extra_keys) self.root = root - self.split = split self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.stats = aggregate_stats(self._datasets) @@ -948,7 +947,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): return ( f"{self.__class__.__name__}(\n" f" Repository IDs: '{self.repo_ids}',\n" - f" Split: '{self.split}',\n" f" Number of Samples: {self.num_samples},\n" f" Number of Episodes: {self.num_episodes},\n" f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index aa9c0c04..394723c0 100644 --- a/lerobot/common/datasets/utils.py +++ 
@@ -48,7 +48,7 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
 """
 
 
-def flatten_dict(d, parent_key="", sep="/"):
+def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
     """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
 
     For example:
@@ -67,7 +67,7 @@
     return dict(items)
 
 
-def unflatten_dict(d, sep="/"):
+def unflatten_dict(d: dict, sep: str = "/") -> dict:
     outdict = {}
     for key, value in d.items():
         parts = key.split(sep)
@@ -92,6 +92,12 @@ def append_jsonl(data: dict, fpath: Path) -> None:
         writer.write(data)
 
 
+def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
+    serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+    serialized_stats = unflatten_dict(serialized_stats)
+    write_json(serialized_stats, fpath)
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 5bf427f4..9ef50ced 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -315,11 +315,14 @@ def record(
     logging.info("Waiting for image writer to terminate...")
     dataset.image_writer.stop()
 
+    if run_compute_stats:
+        logging.info("Computing dataset statistics")
+
     dataset.consolidate(run_compute_stats)
 
     # lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
     if push_to_hub:
-        dataset.push_to_repo()
+        dataset.push_to_hub()
 
     log_say("Exiting", play_sounds)
     return dataset
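
Note: the new write_stats helper round-trips a nested dict of tensors through
flatten_dict/unflatten_dict so that every leaf tensor is serialized as a plain
JSON list. A minimal usage sketch; the stats keys and values below are
illustrative only, not taken from this patch:

    import torch
    from pathlib import Path

    from lerobot.common.datasets.utils import write_stats

    stats = {
        "observation.state": {"mean": torch.tensor([0.1, 0.2]), "std": torch.tensor([1.0, 1.0])},
        "action": {"min": torch.tensor([-1.0]), "max": torch.tensor([1.0])},
    }
    # flatten_dict turns this into {"observation.state/mean": tensor, ...},
    # .tolist() converts each tensor to a JSON-friendly list, and
    # unflatten_dict restores the nesting before write_json dumps it:
    # {"observation.state": {"mean": [0.1, 0.2], "std": [1.0, 1.0]},
    #  "action": {"min": [-1.0], "max": [1.0]}}
    write_stats(stats, Path("meta/stats.json"))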
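
Note: with the push_to_repo -> push_to_hub rename and the consolidated flag,
the expected end of a recording session becomes consolidate() followed by
push_to_hub(); calling push_to_hub() while self.consolidated is False raises a
RuntimeError. A rough sketch of the calling order, assuming `dataset` is a
LeRobotDataset being recorded with an attached ImageWriter, as set up by
lerobot/scripts/control_robot.py:

    dataset.image_writer.stop()  # flush pending frames to disk
    # consolidate() encodes videos, removes the temporary image files (unless
    # keep_image_files=True), calls read_mode() to drop the image writer before
    # compute_stats wraps the dataset in a DataLoader, writes the stats via
    # write_stats, and finally sets dataset.consolidated = True.
    dataset.consolidate(run_compute_stats=True, keep_image_files=False)
    dataset.push_to_hub(push_videos=True)  # safe now that the dataset is consolidated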