From f3630ad91042e9ee323d8522b8371f4c7ab5f8e6 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 5 Nov 2024 19:09:12 +0100 Subject: [PATCH] Fix tests --- lerobot/common/datasets/image_writer.py | 21 +- lerobot/common/datasets/lerobot_dataset.py | 177 +++++------ lerobot/common/datasets/utils.py | 45 +++ tests/conftest.py | 1 - tests/fixtures/dataset.py | 67 ---- tests/fixtures/dataset_factories.py | 348 +++++++++------------ tests/fixtures/defaults.py | 25 +- tests/fixtures/files.py | 56 ++-- tests/fixtures/hub.py | 46 ++- tests/test_datasets.py | 10 +- tests/test_delta_timestamps.py | 23 +- tests/test_image_writer.py | 94 ++---- tests/test_online_buffer.py | 20 +- 13 files changed, 437 insertions(+), 496 deletions(-) delete mode 100644 tests/fixtures/dataset.py diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 180069d7..13df091b 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -22,8 +22,6 @@ import numpy as np import PIL.Image import torch -DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" - def safe_stop_image_writer(func): def wrapper(*args, **kwargs): @@ -87,7 +85,7 @@ def worker_process(queue: queue.Queue, num_threads: int): t.join() -class ImageWriter: +class AsyncImageWriter: """ This class abstract away the initialisation of processes or/and threads to save images on disk asynchrounously, which is critical to control a robot and record data @@ -102,11 +100,7 @@ class ImageWriter: the number of threads. If it is still not stable, try to use 1 subprocess, or more. """ - def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1): - self.write_dir = write_dir - self.write_dir.mkdir(parents=True, exist_ok=True) - self.image_path = DEFAULT_IMAGE_PATH - + def __init__(self, num_processes: int = 0, num_threads: int = 1): self.num_processes = num_processes self.num_threads = num_threads self.queue = None @@ -134,17 +128,6 @@ class ImageWriter: p.start() self.processes.append(p) - def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: - fpath = self.image_path.format( - image_key=image_key, episode_index=episode_index, frame_index=frame_index - ) - return self.write_dir / fpath - - def get_episode_dir(self, episode_index: int, image_key: str) -> Path: - return self.get_image_file_path( - episode_index=episode_index, image_key=image_key, frame_index=0 - ).parent - def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path): if isinstance(image, torch.Tensor): # Convert tensor to numpy array to minimize main process time diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f03d6826..ac6f7721 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -22,15 +22,18 @@ from pathlib import Path from typing import Callable import datasets +import numpy as np +import PIL.Image import torch import torch.utils from datasets import load_dataset from huggingface_hub import snapshot_download, upload_folder from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats -from lerobot.common.datasets.image_writer import ImageWriter +from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( DEFAULT_FEATURES, + DEFAULT_IMAGE_PATH, EPISODES_PATH, INFO_PATH, STATS_PATH, @@ -44,6 +47,7 @@ from 
lerobot.common.datasets.utils import ( get_delta_indices, get_episode_data_index, get_features_from_robot, + get_hf_features_from_features, get_hub_safe_version, hf_transform_to_torch, load_episodes, @@ -140,14 +144,9 @@ class LeRobotDatasetMetadata: @property def features(self) -> dict[str, dict]: - """""" + """All features contained in the dataset.""" return self.info["features"] - @property - def keys(self) -> list[str]: - """Keys to access non-image data (state, actions etc.).""" - return self.info["keys"] - @property def image_keys(self) -> list[str]: """Keys to access visual modalities stored as images.""" @@ -268,7 +267,7 @@ class LeRobotDatasetMetadata: obj.root = root if root is not None else LEROBOT_HOME / repo_id if robot is not None: - features = get_features_from_robot(robot) + features = get_features_from_robot(robot, use_videos) robot_type = robot.robot_type if not all(cam.fps == fps for cam in robot.cameras.values()): logging.warning( @@ -522,35 +521,16 @@ class LeRobotDataset(torch.utils.data.Dataset): return len(self.episodes) if self.episodes is not None else self.meta.total_episodes @property - def features(self) -> list[str]: - return list(self._features) + self.meta.video_keys + def features(self) -> dict[str, dict]: + return self.meta.features @property - def _features(self) -> datasets.Features: + def hf_features(self) -> datasets.Features: """Features of the hf_dataset.""" if self.hf_dataset is not None: return self.hf_dataset.features - elif self.episode_buffer is None: - raise NotImplementedError( - "Dataset features must be infered from an existing hf_dataset or episode_buffer." - ) - - features = {} - for key in self.episode_buffer: - if key in ["episode_index", "frame_index", "index", "task_index"]: - features[key] = datasets.Value(dtype="int64") - elif key in ["next.done", "next.success"]: - features[key] = datasets.Value(dtype="bool") - elif key in ["timestamp", "next.reward"]: - features[key] = datasets.Value(dtype="float32") - elif key in self.meta.image_keys: - features[key] = datasets.Image() - elif key in self.meta.keys: - features[key] = datasets.Sequence( - length=self.meta.shapes[key], feature=datasets.Value(dtype="float32") - ) - - return datasets.Features(features) + else: + return get_hf_features_from_features(self.features) def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: ep_start = self.episode_data_index["from"][ep_idx] @@ -650,17 +630,26 @@ class LeRobotDataset(torch.utils.data.Dataset): ) def _create_episode_buffer(self, episode_index: int | None = None) -> dict: - # TODO(aliberts): Handle resume + current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index return { "size": 0, - "episode_index": self.meta.total_episodes if episode_index is None else episode_index, - "task_index": None, - "frame_index": [], - "timestamp": [], - **{key: [] for key in self.meta.features}, - **{key: [] for key in self.meta.image_keys}, + **{key: [] if key != "episode_index" else current_ep_idx for key in self.features}, } + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + fpath = DEFAULT_IMAGE_PATH.format( + image_key=image_key, episode_index=episode_index, frame_index=frame_index + ) + return self.root / fpath + + def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: + if self.image_writer is None: + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + write_image(image, fpath) + else: + 
self.image_writer.save_image(image=image, fpath=fpath) + def add_frame(self, frame: dict) -> None: """ This function only adds the frame to the episode_buffer. Apart from images — which are written in a @@ -668,35 +657,25 @@ class LeRobotDataset(torch.utils.data.Dataset): then needs to be called. """ frame_index = self.episode_buffer["size"] - self.episode_buffer["frame_index"].append(frame_index) - self.episode_buffer["timestamp"].append(frame_index / self.fps) - self.episode_buffer["next.done"].append(False) - - # Save all observed modalities except images - for key in self.meta.keys: - self.episode_buffer[key].append(frame[key]) + for key, ft in self.features.items(): + if key == "frame_index": + self.episode_buffer[key].append(frame_index) + elif key == "timestamp": + self.episode_buffer[key].append(frame_index / self.fps) + elif key in frame and ft["dtype"] not in ["image", "video"]: + self.episode_buffer[key].append(frame[key]) + elif key in frame and ft["dtype"] in ["image", "video"]: + img_path = self._get_image_file_path( + episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + ) + if frame_index == 0: + img_path.parent.mkdir(parents=True, exist_ok=True) + self._save_image(frame[key], img_path) + if ft["dtype"] == "image": + self.episode_buffer[key].append(str(img_path)) self.episode_buffer["size"] += 1 - if self.image_writer is None: - return - - # Save images - for cam_key in self.meta.camera_keys: - img_path = self.image_writer.get_image_file_path( - episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index - ) - if frame_index == 0: - img_path.parent.mkdir(parents=True, exist_ok=True) - - self.image_writer.save_image( - image=frame[cam_key], - fpath=img_path, - ) - - if cam_key in self.meta.image_keys: - self.episode_buffer[cam_key].append(str(img_path)) - def add_episode(self, task: str, encode_videos: bool = False) -> None: """ This will save to disk the current episode in self.episode_buffer. 
Note that since it affects files on
        disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
        the hub.

        Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
        you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
        time for video encoding.
        """
        episode_length = self.episode_buffer.pop("size")
        episode_index = self.episode_buffer["episode_index"]
        if episode_index != self.meta.total_episodes:
            # TODO(aliberts): Add option to use existing episode_index
            raise NotImplementedError()

        task_index = self.meta.get_task_index(task)
-        self.episode_buffer["next.done"][-1] = True

-        for key in self.episode_buffer:
-            if key in self.meta.image_keys:
-                continue
-            elif key in self.meta.keys:
-                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
+        if set(self.episode_buffer.keys()) != set(self.features):
+            raise ValueError("Mismatch between episode_buffer keys and the dataset features.")
+
+        for key, ft in self.features.items():
+            if key == "index":
+                self.episode_buffer[key] = np.arange(
+                    self.meta.total_frames, self.meta.total_frames + episode_length
+                )
             elif key == "episode_index":
-                self.episode_buffer[key] = torch.full((episode_length,), episode_index)
+                self.episode_buffer[key] = np.full((episode_length,), episode_index)
             elif key == "task_index":
-                self.episode_buffer[key] = torch.full((episode_length,), task_index)
+                self.episode_buffer[key] = np.full((episode_length,), task_index)
-            else:
+            elif ft["dtype"] in ["image", "video"]:
+                continue
+            elif ft["shape"][0] == 1:
                 self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
+            elif ft["shape"][0] > 1:
+                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
+            else:
+                raise ValueError(f"Unsupported feature '{key}' (dtype: {ft['dtype']}, shape: {ft['shape']}).")

-        self.episode_buffer["index"] = torch.arange(
-            self.meta.total_frames, self.meta.total_frames + episode_length
-        )
         self.meta.add_episode(episode_index, episode_length, task, task_index)

         self._wait_image_writer()
@@ -744,7 +728,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.consolidated = False

     def _save_episode_table(self, episode_index: int) -> None:
-        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
+        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.hf_features, split="train")
         ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
         write_parquet(ep_dataset, ep_data_path)
@@ -753,7 +737,9 @@
         episode_index = self.episode_buffer["episode_index"]
         if self.image_writer is not None:
             for cam_key in self.meta.camera_keys:
-                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
+                img_dir = self._get_image_file_path(
+                    episode_index=episode_index, image_key=cam_key, frame_index=0
+                ).parent
                 if img_dir.is_dir():
                     shutil.rmtree(img_dir)

         self.episode_buffer = self._create_episode_buffer()

     def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
-        if isinstance(self.image_writer, ImageWriter):
+        if isinstance(self.image_writer, AsyncImageWriter):
             logging.warning(
-                "You are starting a new ImageWriter that is replacing an already exising one in the dataset."
+                "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
             )
-        self.image_writer = ImageWriter(
-            write_dir=self.root / "images",
+        self.image_writer = AsyncImageWriter(
             num_processes=num_processes,
             num_threads=num_threads,
         )
@@ -787,19 +772,21 @@
         self.image_writer.wait_until_done()

     def encode_videos(self) -> None:
-        # Use ffmpeg to convert frames stored as png into mp4 videos
+        """
+        Use ffmpeg to convert frames stored as png into mp4 videos.
+        Note: `encode_video_frames` is a blocking call. 
Making it asynchronous shouldn't speed up encoding,
+        since video encoding with ffmpeg is already using multithreading.
+        """
         for episode_index in range(self.meta.total_episodes):
             for key in self.meta.video_keys:
-                # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
-                # to call self.image_writer here
-                tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
                 video_path = self.root / self.meta.get_video_file_path(episode_index, key)
                 if video_path.is_file():
                     # Skip if video is already encoded. Could be the case when resuming data recording.
                     continue
-                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
-                # since video encoding with ffmpeg is already using multithreading.
-                encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
+                img_dir = self._get_image_file_path(
+                    episode_index=episode_index, image_key=key, frame_index=0
+                ).parent
+                encode_video_frames(img_dir, video_path, self.fps, overwrite=True)

     def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
         self.hf_dataset = self.load_hf_dataset()
@@ -810,8 +797,10 @@
             self.encode_videos()
             self.meta.write_video_info()

-        if not keep_image_files and self.image_writer is not None:
-            shutil.rmtree(self.image_writer.write_dir)
+        if not keep_image_files:
+            img_dir = self.root / "images"
+            if img_dir.is_dir():
+                shutil.rmtree(img_dir)

         video_files = list(self.root.rglob("*.mp4"))
         assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
@@ -989,7 +978,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def features(self) -> datasets.Features:
         features = {}
         for dataset in self._datasets:
-            features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
+            features.update(
+                {k: v for k, v in dataset.hf_features.items() if k not in self.disabled_data_keys}
+            )
         return features

     @property
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index eef319d9..8af6dadc 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -22,6 +22,7 @@ from typing import Any

 import datasets
 import jsonlines
+import pyarrow.compute as pc
 import torch
 from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, HfApi
@@ -39,6 +40,7 @@ TASKS_PATH = "meta/tasks.jsonl"

 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
+DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

 DATASET_CARD_TEMPLATE = """
---
@@ -222,6 +224,24 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version


+def get_hf_features_from_features(features: dict) -> datasets.Features:
+    hf_features = {}
+    for key, ft in features.items():
+        if ft["dtype"] == "video":
+            continue
+        elif ft["dtype"] == "image":
+            hf_features[key] = datasets.Image()
+        elif ft["shape"] == (1,):
+            hf_features[key] = datasets.Value(dtype=ft["dtype"])
+        else:
+            assert len(ft["shape"]) == 1
+            hf_features[key] = datasets.Sequence(
+                length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
+            )
+
+    return datasets.Features(hf_features)
+
+
 def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
     camera_ft = {}
     if robot.cameras:
@@ -270,6 
+290,31 @@ get_episode_data_index(
     }


+def calculate_total_episode(
+    hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
+) -> int:
+    episode_indices = sorted(hf_dataset.unique("episode_index"))
+    total_episodes = len(episode_indices)
+    if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
+        raise ValueError("episode_index values are not sorted and contiguous.")
+    return total_episodes
+
+
+def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
+    episode_lengths = []
+    table = hf_dataset.data.table
+    total_episodes = calculate_total_episode(hf_dataset)
+    for ep_idx in range(total_episodes):
+        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+        episode_lengths.insert(ep_idx, len(ep_table))
+
+    cumulative_lengths = list(accumulate(episode_lengths))
+    return {
+        "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
+        "to": torch.LongTensor(cumulative_lengths),
+    }
+
+
 def check_timestamps_sync(
     hf_dataset: datasets.Dataset,
     episode_data_index: dict[str, torch.Tensor],
diff --git a/tests/conftest.py b/tests/conftest.py
index 8491eeba..2075c2aa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,7 +25,6 @@ from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_mo

 # Import fixture modules as plugins
 pytest_plugins = [
-    "tests.fixtures.dataset",
     "tests.fixtures.dataset_factories",
     "tests.fixtures.files",
     "tests.fixtures.hub",
diff --git a/tests/fixtures/dataset.py b/tests/fixtures/dataset.py
deleted file mode 100644
index bd2060b6..00000000
--- a/tests/fixtures/dataset.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import datasets
-import pytest
-
-from lerobot.common.datasets.utils import get_episode_data_index
-from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
-
-
-@pytest.fixture(scope="session")
-def empty_info(info_factory) -> dict:
-    return info_factory(
-        keys=[],
-        image_keys=[],
-        video_keys=[],
-        shapes={},
-        names={},
-    )
-
-
-@pytest.fixture(scope="session")
-def info(info_factory) -> dict:
-    return info_factory(
-        total_episodes=4,
-        total_frames=420,
-        total_tasks=3,
-        total_videos=8,
-        total_chunks=1,
-    )
-
-
-@pytest.fixture(scope="session")
-def stats(stats_factory) -> list:
-    return stats_factory()
-
-
-@pytest.fixture(scope="session")
-def tasks() -> list:
-    return [
-        {"task_index": 0, "task": "Pick up the block."},
-        {"task_index": 1, "task": "Open the box."},
-        {"task_index": 2, "task": "Make paperclips."},
-    ]
-
-
-@pytest.fixture(scope="session")
-def episodes() -> list:
-    return [
-        {"episode_index": 0, "tasks": ["Pick up the block."], "length": 100},
-        {"episode_index": 1, "tasks": ["Open the box."], "length": 80},
-        {"episode_index": 2, "tasks": ["Pick up the block."], "length": 90},
-        {"episode_index": 3, "tasks": ["Make paperclips."], "length": 150},
-    ]
-
-
-@pytest.fixture(scope="session")
-def episode_data_index(episodes) -> dict:
-    return get_episode_data_index(episodes)
-
-
-@pytest.fixture(scope="session")
-def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
-    return hf_dataset_factory()
-
-
-@pytest.fixture(scope="session")
-def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset:
-    image_keys = DUMMY_CAMERA_KEYS
-    return hf_dataset_factory(image_keys=image_keys)
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index bbd485b7..c773dac8 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -11,16 +11,19 @@ import torch

 from 
lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, + DEFAULT_FEATURES, DEFAULT_PARQUET_PATH, DEFAULT_VIDEO_PATH, + get_hf_features_from_features, hf_transform_to_torch, ) from tests.fixtures.defaults import ( DEFAULT_FPS, - DUMMY_CAMERA_KEYS, - DUMMY_KEYS, + DUMMY_CAMERA_FEATURES, + DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE, + DUMMY_VIDEO_INFO, ) @@ -73,16 +76,33 @@ def img_factory(img_array_factory): @pytest.fixture(scope="session") -def info_factory(): +def features_factory(): + def _create_features( + motor_features: dict = DUMMY_MOTOR_FEATURES, + camera_features: dict = DUMMY_CAMERA_FEATURES, + use_videos: bool = True, + ) -> dict: + if use_videos: + camera_ft = { + key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items() + } + else: + camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()} + return { + **motor_features, + **camera_ft, + **DEFAULT_FEATURES, + } + + return _create_features + + +@pytest.fixture(scope="session") +def info_factory(features_factory): def _create_info( codebase_version: str = CODEBASE_VERSION, fps: int = DEFAULT_FPS, robot_type: str = DUMMY_ROBOT_TYPE, - keys: list[str] = DUMMY_KEYS, - image_keys: list[str] | None = None, - video_keys: list[str] = DUMMY_CAMERA_KEYS, - shapes: dict | None = None, - names: dict | None = None, total_episodes: int = 0, total_frames: int = 0, total_tasks: int = 0, @@ -90,30 +110,14 @@ def info_factory(): total_chunks: int = 0, chunks_size: int = DEFAULT_CHUNK_SIZE, data_path: str = DEFAULT_PARQUET_PATH, - videos_path: str = DEFAULT_VIDEO_PATH, + video_path: str = DEFAULT_VIDEO_PATH, + motor_features: dict = DUMMY_MOTOR_FEATURES, + camera_features: dict = DUMMY_CAMERA_FEATURES, + use_videos: bool = True, ) -> dict: - if not image_keys: - image_keys = [] - if not shapes: - shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys]) - if not names: - names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys} - - video_info = {"videos_path": videos_path} - for key in video_keys: - video_info[key] = { - "video.fps": fps, - "video.width": shapes[key]["width"], - "video.height": shapes[key]["height"], - "video.channels": shapes[key]["channels"], - "video.codec": "av1", - "video.pix_fmt": "yuv420p", - "video.is_depth_map": False, - "has_audio": False, - } + features = features_factory(motor_features, camera_features, use_videos) return { "codebase_version": codebase_version, - "data_path": data_path, "robot_type": robot_type, "total_episodes": total_episodes, "total_frames": total_frames, @@ -123,12 +127,9 @@ def info_factory(): "chunks_size": chunks_size, "fps": fps, "splits": {}, - "keys": keys, - "video_keys": video_keys, - "image_keys": image_keys, - "shapes": shapes, - "names": names, - "videos": video_info if len(video_keys) > 0 else None, + "data_path": data_path, + "video_path": video_path if use_videos else None, + "features": features, } return _create_info @@ -137,32 +138,26 @@ def info_factory(): @pytest.fixture(scope="session") def stats_factory(): def _create_stats( - keys: list[str] = DUMMY_KEYS, - image_keys: list[str] | None = None, - video_keys: list[str] = DUMMY_CAMERA_KEYS, - shapes: dict | None = None, + features: dict[str] | None = None, ) -> dict: - if not image_keys: - image_keys = [] - if not shapes: - shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys]) stats = {} - for 
key in keys: - shape = shapes[key] - stats[key] = { - "max": np.full(shape, 1, dtype=np.float32).tolist(), - "mean": np.full(shape, 0.5, dtype=np.float32).tolist(), - "min": np.full(shape, 0, dtype=np.float32).tolist(), - "std": np.full(shape, 0.25, dtype=np.float32).tolist(), - } - for key in [*image_keys, *video_keys]: - shape = (3, 1, 1) - stats[key] = { - "max": np.full(shape, 1, dtype=np.float32).tolist(), - "mean": np.full(shape, 0.5, dtype=np.float32).tolist(), - "min": np.full(shape, 0, dtype=np.float32).tolist(), - "std": np.full(shape, 0.25, dtype=np.float32).tolist(), - } + for key, ft in features.items(): + shape = ft["shape"] + dtype = ft["dtype"] + if dtype in ["image", "video"]: + stats[key] = { + "max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(), + "mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(), + "min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(), + "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), + } + else: + stats[key] = { + "max": np.full(shape, 1, dtype=dtype).tolist(), + "mean": np.full(shape, 0.5, dtype=dtype).tolist(), + "min": np.full(shape, 0, dtype=dtype).tolist(), + "std": np.full(shape, 0.25, dtype=dtype).tolist(), + } return stats return _create_stats @@ -185,7 +180,7 @@ def episodes_factory(tasks_factory): def _create_episodes( total_episodes: int = 3, total_frames: int = 400, - task_dicts: dict | None = None, + tasks: dict | None = None, multi_task: bool = False, ): if total_episodes <= 0 or total_frames <= 0: @@ -193,18 +188,18 @@ def episodes_factory(tasks_factory): if total_frames < total_episodes: raise ValueError("total_length must be greater than or equal to num_episodes.") - if not task_dicts: + if not tasks: min_tasks = 2 if multi_task else 1 total_tasks = random.randint(min_tasks, total_episodes) - task_dicts = tasks_factory(total_tasks) + tasks = tasks_factory(total_tasks) - if total_episodes < len(task_dicts) and not multi_task: + if total_episodes < len(tasks) and not multi_task: raise ValueError("The number of tasks should be less than the number of episodes.") # Generate random lengths that sum up to total_length lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() - tasks_list = [task_dict["task"] for task_dict in task_dicts] + tasks_list = [task_dict["task"] for task_dict in tasks] num_tasks_available = len(tasks_list) episodes_list = [] @@ -231,81 +226,56 @@ def episodes_factory(tasks_factory): @pytest.fixture(scope="session") -def hf_dataset_factory(img_array_factory, episodes, tasks): +def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): def _create_hf_dataset( - episode_dicts: list[dict] = episodes, - task_dicts: list[dict] = tasks, - keys: list[str] = DUMMY_KEYS, - image_keys: list[str] | None = None, - shapes: dict | None = None, + features: dict | None = None, + tasks: list[dict] | None = None, + episodes: list[dict] | None = None, fps: int = DEFAULT_FPS, ) -> datasets.Dataset: - if not image_keys: - image_keys = [] - if not shapes: - shapes = make_dummy_shapes(keys=keys, camera_keys=image_keys) - key_features = { - key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32")) - for key in keys - } - image_features = {key: datasets.Image() for key in image_keys} if image_keys else {} - common_features = { - "episode_index": datasets.Value(dtype="int64"), - "frame_index": datasets.Value(dtype="int64"), - "timestamp": datasets.Value(dtype="float32"), - "next.done": datasets.Value(dtype="bool"), - 
"index": datasets.Value(dtype="int64"), - "task_index": datasets.Value(dtype="int64"), - } - features = datasets.Features( - { - **key_features, - **image_features, - **common_features, - } - ) + if not tasks: + tasks = tasks_factory() + if not episodes: + episodes = episodes_factory() + if not features: + features = features_factory() - episode_index_col = np.array([], dtype=np.int64) - frame_index_col = np.array([], dtype=np.int64) timestamp_col = np.array([], dtype=np.float32) - next_done_col = np.array([], dtype=bool) + frame_index_col = np.array([], dtype=np.int64) + episode_index_col = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64) - - for ep_dict in episode_dicts: + for ep_dict in episodes: + timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) + frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) episode_index_col = np.concatenate( (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int)) ) - frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) - timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) - next_done_ep = np.full(ep_dict["length"], False, dtype=bool) - next_done_ep[-1] = True - next_done_col = np.concatenate((next_done_col, next_done_ep)) - ep_task_index = get_task_index(task_dicts, ep_dict["tasks"][0]) + ep_task_index = get_task_index(tasks, ep_dict["tasks"][0]) task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))) index_col = np.arange(len(episode_index_col)) - key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys} - image_cols = {} - if image_keys: - for key in image_keys: - image_cols[key] = [ - img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"]) + robot_cols = {} + for key, ft in features.items(): + if ft["dtype"] == "image": + robot_cols[key] = [ + img_array_factory(width=ft["shapes"][0], height=ft["shapes"][1]) for _ in range(len(index_col)) ] + elif ft["shape"][0] > 1 and ft["dtype"] != "video": + robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"]) + hf_features = get_hf_features_from_features(features) dataset = datasets.Dataset.from_dict( { - **key_cols, - **image_cols, - "episode_index": episode_index_col, - "frame_index": frame_index_col, + **robot_cols, "timestamp": timestamp_col, - "next.done": next_done_col, + "frame_index": frame_index_col, + "episode_index": episode_index_col, "index": index_col, "task_index": task_index, }, - features=features, + features=hf_features, ) dataset.set_transform(hf_transform_to_torch) return dataset @@ -315,26 +285,37 @@ def hf_dataset_factory(img_array_factory, episodes, tasks): @pytest.fixture(scope="session") def lerobot_dataset_metadata_factory( - info, - stats, - tasks, - episodes, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, mock_snapshot_download_factory, ): def _create_lerobot_dataset_metadata( root: Path, repo_id: str = DUMMY_REPO_ID, - info_dict: dict = info, - stats_dict: dict = stats, - task_dicts: list[dict] = tasks, - episode_dicts: list[dict] = episodes, - **kwargs, + info: dict | None = None, + stats: dict | None = None, + tasks: list[dict] | None = None, + episodes: list[dict] | None = None, + local_files_only: bool = False, ) -> LeRobotDatasetMetadata: + if not info: + info = info_factory() + if not stats: + stats = 
stats_factory(features=info["features"]) + if not tasks: + tasks = tasks_factory(total_tasks=info["total_tasks"]) + if not episodes: + episodes = episodes_factory( + total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks + ) + mock_snapshot_download = mock_snapshot_download_factory( - info_dict=info_dict, - stats_dict=stats_dict, - task_dicts=task_dicts, - episode_dicts=episode_dicts, + info=info, + stats=stats, + tasks=tasks, + episodes=episodes, ) with ( patch( @@ -347,48 +328,68 @@ def lerobot_dataset_metadata_factory( mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version mock_snapshot_download_patch.side_effect = mock_snapshot_download - return LeRobotDatasetMetadata( - repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False) - ) + return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only) return _create_lerobot_dataset_metadata @pytest.fixture(scope="session") def lerobot_dataset_factory( - info, - stats, - tasks, - episodes, - hf_dataset, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, mock_snapshot_download_factory, lerobot_dataset_metadata_factory, ): def _create_lerobot_dataset( root: Path, repo_id: str = DUMMY_REPO_ID, - info_dict: dict = info, - stats_dict: dict = stats, - task_dicts: list[dict] = tasks, - episode_dicts: list[dict] = episodes, - hf_ds: datasets.Dataset = hf_dataset, + total_episodes: int = 3, + total_frames: int = 150, + total_tasks: int = 1, + multi_task: bool = False, + info: dict | None = None, + stats: dict | None = None, + tasks: list[dict] | None = None, + episode_dicts: list[dict] | None = None, + hf_dataset: datasets.Dataset | None = None, **kwargs, ) -> LeRobotDataset: + if not info: + info = info_factory( + total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks + ) + if not stats: + stats = stats_factory(features=info["features"]) + if not tasks: + tasks = tasks_factory(total_tasks=info["total_tasks"]) + if not episode_dicts: + episode_dicts = episodes_factory( + total_episodes=info["total_episodes"], + total_frames=info["total_frames"], + tasks=tasks, + multi_task=multi_task, + ) + if not hf_dataset: + hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"]) + mock_snapshot_download = mock_snapshot_download_factory( - info_dict=info_dict, - stats_dict=stats_dict, - task_dicts=task_dicts, - episode_dicts=episode_dicts, - hf_ds=hf_ds, + info=info, + stats=stats, + tasks=tasks, + episodes=episode_dicts, + hf_dataset=hf_dataset, ) mock_metadata = lerobot_dataset_metadata_factory( root=root, repo_id=repo_id, - info_dict=info_dict, - stats_dict=stats_dict, - task_dicts=task_dicts, - episode_dicts=episode_dicts, - **kwargs, + info=info, + stats=stats, + tasks=tasks, + episodes=episode_dicts, + local_files_only=kwargs.get("local_files_only", False), ) with ( patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, @@ -402,44 +403,3 @@ def lerobot_dataset_factory( return LeRobotDataset(repo_id=repo_id, root=root, **kwargs) return _create_lerobot_dataset - - -@pytest.fixture(scope="session") -def lerobot_dataset_from_episodes_factory( - info_factory, - tasks_factory, - episodes_factory, - hf_dataset_factory, - lerobot_dataset_factory, -): - def _create_lerobot_dataset_total_episodes( - root: Path, - total_episodes: int = 3, - total_frames: int = 150, - total_tasks: int = 1, - multi_task: bool = 
False, - repo_id: str = DUMMY_REPO_ID, - **kwargs, - ): - info_dict = info_factory( - total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks - ) - task_dicts = tasks_factory(total_tasks) - episode_dicts = episodes_factory( - total_episodes=total_episodes, - total_frames=total_frames, - task_dicts=task_dicts, - multi_task=multi_task, - ) - hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts) - return lerobot_dataset_factory( - root=root, - repo_id=repo_id, - info_dict=info_dict, - task_dicts=task_dicts, - episode_dicts=episode_dicts, - hf_ds=hf_dataset, - **kwargs, - ) - - return _create_lerobot_dataset_total_episodes diff --git a/tests/fixtures/defaults.py b/tests/fixtures/defaults.py index 3072e0c7..a430ead8 100644 --- a/tests/fixtures/defaults.py +++ b/tests/fixtures/defaults.py @@ -3,6 +3,27 @@ from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing" DUMMY_REPO_ID = "dummy/repo" DUMMY_ROBOT_TYPE = "dummy_robot" -DUMMY_KEYS = ["state", "action"] -DUMMY_CAMERA_KEYS = ["laptop", "phone"] +DUMMY_MOTOR_FEATURES = { + "action": { + "dtype": "float32", + "shape": (6,), + "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], + }, + "state": { + "dtype": "float32", + "shape": (6,), + "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], + }, +} +DUMMY_CAMERA_FEATURES = { + "laptop": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None}, + "phone": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None}, +} DEFAULT_FPS = 30 +DUMMY_VIDEO_INFO = { + "video.fps": DEFAULT_FPS, + "video.codec": "av1", + "video.pix_fmt": "yuv420p", + "video.is_depth_map": False, + "has_audio": False, +} diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index a9ee2c35..5fe8a314 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -11,64 +11,77 @@ from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, @pytest.fixture(scope="session") -def info_path(info): - def _create_info_json_file(dir: Path, info_dict: dict = info) -> Path: +def info_path(info_factory): + def _create_info_json_file(dir: Path, info: dict | None = None) -> Path: + if not info: + info = info_factory() fpath = dir / INFO_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with open(fpath, "w") as f: - json.dump(info_dict, f, indent=4, ensure_ascii=False) + json.dump(info, f, indent=4, ensure_ascii=False) return fpath return _create_info_json_file @pytest.fixture(scope="session") -def stats_path(stats): - def _create_stats_json_file(dir: Path, stats_dict: dict = stats) -> Path: +def stats_path(stats_factory): + def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path: + if not stats: + stats = stats_factory() fpath = dir / STATS_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with open(fpath, "w") as f: - json.dump(stats_dict, f, indent=4, ensure_ascii=False) + json.dump(stats, f, indent=4, ensure_ascii=False) return fpath return _create_stats_json_file @pytest.fixture(scope="session") -def tasks_path(tasks): - def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path: +def tasks_path(tasks_factory): + def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: + if not tasks: + tasks = tasks_factory() fpath = dir / TASKS_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(fpath, 
"w") as writer: - writer.write_all(task_dicts) + writer.write_all(tasks) return fpath return _create_tasks_jsonl_file @pytest.fixture(scope="session") -def episode_path(episodes): - def _create_episodes_jsonl_file(dir: Path, episode_dicts: list = episodes) -> Path: +def episode_path(episodes_factory): + def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path: + if not episodes: + episodes = episodes_factory() fpath = dir / EPISODES_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(fpath, "w") as writer: - writer.write_all(episode_dicts) + writer.write_all(episodes) return fpath return _create_episodes_jsonl_file @pytest.fixture(scope="session") -def single_episode_parquet_path(hf_dataset, info): +def single_episode_parquet_path(hf_dataset_factory, info_factory): def _create_single_episode_parquet( - dir: Path, hf_ds: datasets.Dataset = hf_dataset, ep_idx: int = 0 + dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None ) -> Path: + if not info: + info = info_factory() + if hf_dataset is None: + hf_dataset = hf_dataset_factory() + data_path = info["data_path"] chunks_size = info["chunks_size"] ep_chunk = ep_idx // chunks_size fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx) fpath.parent.mkdir(parents=True, exist_ok=True) - table = hf_ds.data.table + table = hf_dataset.data.table ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) pq.write_table(ep_table, fpath) return fpath @@ -77,8 +90,15 @@ def single_episode_parquet_path(hf_dataset, info): @pytest.fixture(scope="session") -def multi_episode_parquet_path(hf_dataset, info): - def _create_multi_episode_parquet(dir: Path, hf_ds: datasets.Dataset = hf_dataset) -> Path: +def multi_episode_parquet_path(hf_dataset_factory, info_factory): + def _create_multi_episode_parquet( + dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None + ) -> Path: + if not info: + info = info_factory() + if hf_dataset is None: + hf_dataset = hf_dataset_factory() + data_path = info["data_path"] chunks_size = info["chunks_size"] total_episodes = info["total_episodes"] @@ -86,7 +106,7 @@ def multi_episode_parquet_path(hf_dataset, info): ep_chunk = ep_idx // chunks_size fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx) fpath.parent.mkdir(parents=True, exist_ok=True) - table = hf_ds.data.table + table = hf_dataset.data.table ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) pq.write_table(ep_table, fpath) return dir / "data" diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 8dd9e966..2300c883 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -1,5 +1,6 @@ from pathlib import Path +import datasets import pytest from huggingface_hub.utils import filter_repo_objects @@ -9,16 +10,16 @@ from tests.fixtures.defaults import LEROBOT_TEST_DIR @pytest.fixture(scope="session") def mock_snapshot_download_factory( - info, + info_factory, info_path, - stats, + stats_factory, stats_path, - tasks, + tasks_factory, tasks_path, - episodes, + episodes_factory, episode_path, single_episode_parquet_path, - hf_dataset, + hf_dataset_factory, ): """ This factory allows to patch snapshot_download such that when called, it will create expected files rather @@ -26,8 +27,25 @@ def mock_snapshot_download_factory( """ def _mock_snapshot_download_func( - info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset + info: dict | None = None, + 
stats: dict | None = None, + tasks: list[dict] | None = None, + episodes: list[dict] | None = None, + hf_dataset: datasets.Dataset | None = None, ): + if not info: + info = info_factory() + if not stats: + stats = stats_factory(features=info["features"]) + if not tasks: + tasks = tasks_factory(total_tasks=info["total_tasks"]) + if not episodes: + episodes = episodes_factory( + total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks + ) + if not hf_dataset: + hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) + def _extract_episode_index_from_path(fpath: str) -> int: path = Path(fpath) if path.suffix == ".parquet" and path.stem.startswith("episode_"): @@ -53,10 +71,10 @@ def mock_snapshot_download_factory( all_files.extend(meta_files) data_files = [] - for episode_dict in episode_dicts: + for episode_dict in episodes: ep_idx = episode_dict["episode_index"] - ep_chunk = ep_idx // info_dict["chunks_size"] - data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) + ep_chunk = ep_idx // info["chunks_size"] + data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) data_files.append(data_path) all_files.extend(data_files) @@ -69,15 +87,15 @@ def mock_snapshot_download_factory( if rel_path.startswith("data/"): episode_index = _extract_episode_index_from_path(rel_path) if episode_index is not None: - _ = single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index) + _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info) if rel_path == INFO_PATH: - _ = info_path(local_dir, info_dict) + _ = info_path(local_dir, info) elif rel_path == STATS_PATH: - _ = stats_path(local_dir, stats_dict) + _ = stats_path(local_dir, stats) elif rel_path == TASKS_PATH: - _ = tasks_path(local_dir, task_dicts) + _ = tasks_path(local_dir, tasks) elif rel_path == EPISODES_PATH: - _ = episode_path(local_dir, episode_dicts) + _ = episode_path(local_dir, episodes) else: pass return str(local_dir) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index d1d49b31..7c7bb5e4 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -35,7 +35,6 @@ from lerobot.common.datasets.compute_stats import ( from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.lerobot_dataset import ( LeRobotDataset, - LeRobotDatasetMetadata, MultiLeRobotDataset, ) from lerobot.common.datasets.utils import ( @@ -57,10 +56,7 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path): # Instantiate both ways robot = make_robot("koch", mock=True) root_create = tmp_path / "create" - metadata_create = LeRobotDatasetMetadata.create( - repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create - ) - dataset_create = LeRobotDataset.create(metadata_create) + dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init) @@ -75,14 +71,14 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path): assert init_attr == create_attr -def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path): +def test_dataset_initialization(lerobot_dataset_factory, tmp_path): kwargs = { "repo_id": DUMMY_REPO_ID, "total_episodes": 10, "total_frames": 400, "episodes": [2, 5, 6], } - dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs) + dataset = lerobot_dataset_factory(root=tmp_path, **kwargs) assert 
dataset.repo_id == kwargs["repo_id"] assert dataset.meta.total_episodes == kwargs["total_episodes"] diff --git a/tests/test_delta_timestamps.py b/tests/test_delta_timestamps.py index 3dea95b8..c862a135 100644 --- a/tests/test_delta_timestamps.py +++ b/tests/test_delta_timestamps.py @@ -3,12 +3,13 @@ import torch from datasets import Dataset from lerobot.common.datasets.utils import ( + calculate_episode_data_index, check_delta_timestamps, check_timestamps_sync, get_delta_indices, hf_transform_to_torch, ) -from tests.fixtures.defaults import DUMMY_KEYS +from tests.fixtures.defaults import DUMMY_MOTOR_FEATURES @pytest.fixture(scope="module") @@ -53,7 +54,7 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory): @pytest.fixture(scope="module") def valid_delta_timestamps_factory(): - def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_KEYS) -> dict: + def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict: delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys} return delta_timestamps @@ -63,7 +64,7 @@ def valid_delta_timestamps_factory(): @pytest.fixture(scope="module") def invalid_delta_timestamps_factory(valid_delta_timestamps_factory): def _create_invalid_delta_timestamps( - fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS + fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES ) -> dict: delta_timestamps = valid_delta_timestamps_factory(fps, keys) # Modify a single timestamp just outside tolerance @@ -77,7 +78,7 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory): @pytest.fixture(scope="module") def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory): def _create_slightly_off_delta_timestamps( - fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS + fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES ) -> dict: delta_timestamps = valid_delta_timestamps_factory(fps, keys) # Modify a single timestamp just inside tolerance @@ -90,14 +91,15 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory): @pytest.fixture(scope="module") -def delta_indices(keys: list = DUMMY_KEYS) -> dict: +def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict: return {key: list(range(-10, 10)) for key in keys} -def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_index): +def test_check_timestamps_sync_synced(synced_hf_dataset_factory): fps = 30 tolerance_s = 1e-4 synced_hf_dataset = synced_hf_dataset_factory(fps) + episode_data_index = calculate_episode_data_index(synced_hf_dataset) result = check_timestamps_sync( hf_dataset=synced_hf_dataset, episode_data_index=episode_data_index, @@ -107,10 +109,11 @@ def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_in assert result is True -def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_data_index): +def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory): fps = 30 tolerance_s = 1e-4 unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s) + episode_data_index = calculate_episode_data_index(unsynced_hf_dataset) with pytest.raises(ValueError): check_timestamps_sync( hf_dataset=unsynced_hf_dataset, @@ -120,10 +123,11 @@ def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_dat ) -def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory, episode_data_index): +def 
test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory): fps = 30 tolerance_s = 1e-4 unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s) + episode_data_index = calculate_episode_data_index(unsynced_hf_dataset) result = check_timestamps_sync( hf_dataset=unsynced_hf_dataset, episode_data_index=episode_data_index, @@ -134,10 +138,11 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory assert result is False -def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory, episode_data_index): +def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory): fps = 30 tolerance_s = 1e-4 slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s) + episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset) result = check_timestamps_sync( hf_dataset=slightly_off_hf_dataset, episode_data_index=episode_data_index, diff --git a/tests/test_image_writer.py b/tests/test_image_writer.py index 3f3045d0..2b0884a1 100644 --- a/tests/test_image_writer.py +++ b/tests/test_image_writer.py @@ -8,7 +8,7 @@ import pytest from PIL import Image from lerobot.common.datasets.image_writer import ( - ImageWriter, + AsyncImageWriter, image_array_to_image, safe_stop_image_writer, write_image, @@ -17,8 +17,8 @@ from lerobot.common.datasets.image_writer import ( DUMMY_IMAGE = "test_image.png" -def test_init_threading(tmp_path): - writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2) +def test_init_threading(): + writer = AsyncImageWriter(num_processes=0, num_threads=2) try: assert writer.num_processes == 0 assert writer.num_threads == 2 @@ -30,8 +30,8 @@ def test_init_threading(tmp_path): writer.stop() -def test_init_multiprocessing(tmp_path): - writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2) +def test_init_multiprocessing(): + writer = AsyncImageWriter(num_processes=2, num_threads=2) try: assert writer.num_processes == 2 assert writer.num_threads == 2 @@ -43,35 +43,9 @@ def test_init_multiprocessing(tmp_path): writer.stop() -def test_write_dir_created(tmp_path): - write_dir = tmp_path / "non_existent_dir" - assert not write_dir.exists() - writer = ImageWriter(write_dir=write_dir) - try: - assert write_dir.exists() - finally: - writer.stop() - - -def test_get_image_file_path_and_episode_dir(tmp_path): - writer = ImageWriter(write_dir=tmp_path) - try: - episode_index = 1 - image_key = "test_key" - frame_index = 10 - expected_episode_dir = tmp_path / f"{image_key}/episode_{episode_index:06d}" - expected_path = expected_episode_dir / f"frame_{frame_index:06d}.png" - image_file_path = writer.get_image_file_path(episode_index, image_key, frame_index) - assert image_file_path == expected_path - episode_dir = writer.get_episode_dir(episode_index, image_key) - assert episode_dir == expected_episode_dir - finally: - writer.stop() - - -def test_zero_threads(tmp_path): +def test_zero_threads(): with pytest.raises(ValueError): - ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=0) + AsyncImageWriter(num_processes=0, num_threads=0) def test_image_array_to_image_rgb(img_array_factory): @@ -148,7 +122,7 @@ def test_write_image_exception(tmp_path): def test_save_image_numpy(tmp_path, img_array_factory): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() try: image_array = img_array_factory() fpath = tmp_path / DUMMY_IMAGE @@ -163,7 +137,7 @@ def test_save_image_numpy(tmp_path, img_array_factory): def 
test_save_image_numpy_multiprocessing(tmp_path, img_array_factory): - writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2) + writer = AsyncImageWriter(num_processes=2, num_threads=2) try: image_array = img_array_factory() fpath = tmp_path / DUMMY_IMAGE @@ -177,7 +151,7 @@ def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory): def test_save_image_torch(tmp_path, img_tensor_factory): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() try: image_tensor = img_tensor_factory() fpath = tmp_path / DUMMY_IMAGE @@ -193,7 +167,7 @@ def test_save_image_torch(tmp_path, img_tensor_factory): def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory): - writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2) + writer = AsyncImageWriter(num_processes=2, num_threads=2) try: image_tensor = img_tensor_factory() fpath = tmp_path / DUMMY_IMAGE @@ -208,7 +182,7 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory): def test_save_image_pil(tmp_path, img_factory): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() try: image_pil = img_factory() fpath = tmp_path / DUMMY_IMAGE @@ -223,7 +197,7 @@ def test_save_image_pil(tmp_path, img_factory): def test_save_image_pil_multiprocessing(tmp_path, img_factory): - writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2) + writer = AsyncImageWriter(num_processes=2, num_threads=2) try: image_pil = img_factory() fpath = tmp_path / DUMMY_IMAGE @@ -237,10 +211,10 @@ def test_save_image_pil_multiprocessing(tmp_path, img_factory): def test_save_image_invalid_data(tmp_path): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() try: image_array = "invalid data" - fpath = writer.get_image_file_path(0, "test_key", 0) + fpath = tmp_path / DUMMY_IMAGE fpath.parent.mkdir(parents=True, exist_ok=True) with patch("builtins.print") as mock_print: writer.save_image(image_array, fpath) @@ -252,47 +226,47 @@ def test_save_image_invalid_data(tmp_path): def test_save_image_after_stop(tmp_path, img_array_factory): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() writer.stop() image_array = img_array_factory() - fpath = writer.get_image_file_path(0, "test_key", 0) + fpath = tmp_path / DUMMY_IMAGE writer.save_image(image_array, fpath) time.sleep(1) assert not fpath.exists() -def test_stop(tmp_path): - writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2) +def test_stop(): + writer = AsyncImageWriter(num_processes=0, num_threads=2) writer.stop() assert not any(t.is_alive() for t in writer.threads) -def test_stop_multiprocessing(tmp_path): - writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2) +def test_stop_multiprocessing(): + writer = AsyncImageWriter(num_processes=2, num_threads=2) writer.stop() assert not any(p.is_alive() for p in writer.processes) -def test_multiple_stops(tmp_path): - writer = ImageWriter(write_dir=tmp_path) +def test_multiple_stops(): + writer = AsyncImageWriter() writer.stop() writer.stop() # Should not raise an exception assert not any(t.is_alive() for t in writer.threads) -def test_multiple_stops_multiprocessing(tmp_path): - writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2) +def test_multiple_stops_multiprocessing(): + writer = AsyncImageWriter(num_processes=2, num_threads=2) writer.stop() writer.stop() # Should not raise an exception assert not any(t.is_alive() for t in writer.threads) def test_wait_until_done(tmp_path, 
img_array_factory): - writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=4) + writer = AsyncImageWriter(num_processes=0, num_threads=4) try: num_images = 100 image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)] - fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)] + fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)] for image_array, fpath in zip(image_arrays, fpaths, strict=True): fpath.parent.mkdir(parents=True, exist_ok=True) writer.save_image(image_array, fpath) @@ -306,11 +280,11 @@ def test_wait_until_done(tmp_path, img_array_factory): def test_wait_until_done_multiprocessing(tmp_path, img_array_factory): - writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2) + writer = AsyncImageWriter(num_processes=2, num_threads=2) try: num_images = 100 image_arrays = [img_array_factory() for _ in range(num_images)] - fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)] + fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)] for image_array, fpath in zip(image_arrays, fpaths, strict=True): fpath.parent.mkdir(parents=True, exist_ok=True) writer.save_image(image_array, fpath) @@ -324,7 +298,7 @@ def test_wait_until_done_multiprocessing(tmp_path, img_array_factory): def test_exception_handling(tmp_path, img_array_factory): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() try: image_array = img_array_factory() with ( @@ -338,7 +312,7 @@ def test_exception_handling(tmp_path, img_array_factory): def test_with_different_image_formats(tmp_path, img_array_factory): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() try: image_array = img_array_factory() formats = ["png", "jpeg", "bmp"] @@ -353,7 +327,7 @@ def test_with_different_image_formats(tmp_path, img_array_factory): def test_safe_stop_image_writer_decorator(): class MockDataset: def __init__(self): - self.image_writer = MagicMock(spec=ImageWriter) + self.image_writer = MagicMock(spec=AsyncImageWriter) @safe_stop_image_writer def function_that_raises_exception(dataset=None): @@ -369,10 +343,10 @@ def test_safe_stop_image_writer_decorator(): def test_main_process_time(tmp_path, img_tensor_factory): - writer = ImageWriter(write_dir=tmp_path) + writer = AsyncImageWriter() try: image_tensor = img_tensor_factory() - fpath = tmp_path / "test_main_process_time.png" + fpath = tmp_path / DUMMY_IMAGE start_time = time.perf_counter() writer.save_image(image_tensor, fpath) end_time = time.perf_counter() diff --git a/tests/test_online_buffer.py b/tests/test_online_buffer.py index 20e26177..092cd3d0 100644 --- a/tests/test_online_buffer.py +++ b/tests/test_online_buffer.py @@ -213,15 +213,13 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range(): @pytest.mark.parametrize("online_dataset_size", [0, 4]) @pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0]) def test_compute_sampler_weights_trivial( - lerobot_dataset_from_episodes_factory, + lerobot_dataset_factory, tmp_path, offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float, ): - offline_dataset = lerobot_dataset_from_episodes_factory( - tmp_path, total_episodes=1, total_frames=offline_dataset_size - ) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size) online_dataset, _ = make_new_buffer() if online_dataset_size > 0: online_dataset.add_data( @@ -241,9 +239,9 @@ def 
test_compute_sampler_weights_trivial( assert torch.allclose(weights, expected_weights) -def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path): +def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) online_dataset, _ = make_new_buffer() online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) online_sampling_ratio = 0.8 @@ -255,11 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_ ) -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n( - lerobot_dataset_from_episodes_factory, tmp_path -): +def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) online_dataset, _ = make_new_buffer() online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) weights = compute_sampler_weights( @@ -270,9 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n( ) -def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_from_episodes_factory, tmp_path): +def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path): """Note: test copied from test_sampler.""" - offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=2) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2) online_dataset, _ = make_new_buffer() online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
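
Below are a few usage sketches for the APIs this patch touches. First, the renamed `AsyncImageWriter`: it no longer owns a `write_dir` or computes file paths, so callers create directories and pass full paths themselves. A minimal sketch based on the tests above (the output path is an arbitrary placeholder):

```python
from pathlib import Path

import numpy as np

from lerobot.common.datasets.image_writer import AsyncImageWriter

writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    fpath = Path("/tmp/demo/images/laptop/episode_000000/frame_000000.png")
    fpath.parent.mkdir(parents=True, exist_ok=True)  # directory creation is now the caller's job
    writer.save_image(image, fpath)  # returns immediately; a worker thread writes the png
    writer.wait_until_done()  # block until every queued image is on disk
finally:
    writer.stop()
```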
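Those file paths follow `DEFAULT_IMAGE_PATH`, which the patch moves from `image_writer.py` into `lerobot/common/datasets/utils.py` and anchors at the dataset root. A quick sketch of the resulting layout, with a placeholder root:

```python
from pathlib import Path

from lerobot.common.datasets.utils import DEFAULT_IMAGE_PATH

root = Path("/tmp/demo")  # stands in for LeRobotDataset.root
fpath = root / DEFAULT_IMAGE_PATH.format(image_key="laptop", episode_index=0, frame_index=17)
print(fpath)  # /tmp/demo/images/laptop/episode_000000/frame_000017.png
```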
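On the schema side, `info["features"]` becomes the single source of truth, and `get_hf_features_from_features` derives the Hugging Face schema from it. A sketch of the mapping (the joint names here are made up; only the structure matters):

```python
from lerobot.common.datasets.utils import get_hf_features_from_features

features = {
    "action": {"dtype": "float32", "shape": (6,), "names": ["j0", "j1", "j2", "j3", "j4", "j5"]},
    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
    "laptop": {"dtype": "video", "shape": (640, 480, 3), "names": ["width", "height", "channels"]},
}

hf_features = get_hf_features_from_features(features)
# action    -> Sequence(length=6, feature=Value("float32"))
# timestamp -> Value("float32")
# laptop    -> skipped: video frames live in mp4 files, not in the arrow table
```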
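Similarly, the new `calculate_episode_data_index` replaces the session-scoped `episode_data_index` fixture that the deleted `tests/fixtures/dataset.py` used to provide. A worked example:

```python
import datasets

from lerobot.common.datasets.utils import calculate_episode_data_index

hf_dataset = datasets.Dataset.from_dict({"episode_index": [0, 0, 0, 1, 1, 2]})
episode_data_index = calculate_episode_data_index(hf_dataset)
# episode lengths are [3, 2, 1], so cumulative lengths are [3, 5, 6]:
# episode_data_index["from"] -> tensor([0, 3, 5])
# episode_data_index["to"]   -> tensor([3, 5, 6])
```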
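The test-side theme is the same idea: session-scoped data fixtures (`info`, `stats`, `tasks`, `episodes`, `hf_dataset`) become factories that resolve their defaults lazily, so a test can override any single piece without a bespoke fixture. A simplified sketch of the pattern (not the full factories from `tests/fixtures/dataset_factories.py`):

```python
import pytest


@pytest.fixture(scope="session")
def tasks_factory():
    def _create_tasks(total_tasks: int = 3) -> list[dict]:
        return [{"task_index": i, "task": f"task_{i}"} for i in range(total_tasks)]

    return _create_tasks


@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
    def _create_episodes(total_episodes: int = 3, tasks: list[dict] | None = None) -> list[dict]:
        if not tasks:
            tasks = tasks_factory()  # defaults are built at call time, not at fixture setup
        return [
            {"episode_index": i, "tasks": [tasks[i % len(tasks)]["task"]], "length": 100}
            for i in range(total_episodes)
        ]

    return _create_episodes
```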
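With the real factories in place, a test can then assemble exactly the synthetic dataset it needs, for example (assuming the fixture modules from this patch are registered as pytest plugins, as `tests/conftest.py` does):

```python
def test_hf_dataset_size(hf_dataset_factory, tasks_factory, episodes_factory):
    tasks = tasks_factory(total_tasks=2)
    episodes = episodes_factory(total_episodes=4, total_frames=400, tasks=tasks)
    hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=30)
    assert len(hf_dataset) == 400  # episode lengths always sum to total_frames
```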