diff --git a/tests/conftest.py b/tests/conftest.py
index d267f911..caf20148 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,7 +24,7 @@ from lerobot.common.utils.utils import init_hydra_config
 from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
 
 # Import fixture modules as plugins
-pytest_plugins = ["tests.fixtures.dataset"]
+pytest_plugins = ["tests.fixtures.dataset", "tests.fixtures.dataset_factories", "tests.fixtures.files"]
 
 
 def pytest_collection_finish():
diff --git a/tests/fixtures/dataset.py b/tests/fixtures/dataset.py
index ad70ff66..576486bb 100644
--- a/tests/fixtures/dataset.py
+++ b/tests/fixtures/dataset.py
@@ -1,20 +1,40 @@
 import datasets
-import numpy as np
 import pytest
 
-from lerobot.common.datasets.utils import get_episode_data_index, hf_transform_to_torch
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import get_episode_data_index
+from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
 
 
 @pytest.fixture(scope="session")
-def img_array_factory():
-    def _create_img_array(width=100, height=100) -> np.ndarray:
-        return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
-
-    return _create_img_array
+def empty_info(info_factory) -> dict:
+    return info_factory(
+        keys=[],
+        image_keys=[],
+        video_keys=[],
+        shapes={},
+        names={},
+    )
 
 
 @pytest.fixture(scope="session")
-def tasks():
+def info(info_factory) -> dict:
+    return info_factory(
+        total_episodes=4,
+        total_frames=420,
+        total_tasks=3,
+        total_videos=8,
+        total_chunks=1,
+    )
+
+
+@pytest.fixture(scope="session")
+def stats(stats_factory) -> dict:
+    return stats_factory()
+
+
+@pytest.fixture(scope="session")
+def tasks() -> list:
     return [
         {"task_index": 0, "task": "Pick up the block."},
         {"task_index": 1, "task": "Open the box."},
@@ -23,7 +43,7 @@
 
 
 @pytest.fixture(scope="session")
-def episode_dicts():
+def episodes() -> list:
     return [
         {"episode_index": 0, "tasks": ["Pick up the block."], "length": 100},
         {"episode_index": 1, "tasks": ["Open the box."], "length": 80},
@@ -33,120 +53,22 @@
 
 
 @pytest.fixture(scope="session")
-def episode_data_index(episode_dicts):
-    return get_episode_data_index(episode_dicts)
+def episode_data_index(episodes) -> dict:
+    return get_episode_data_index(episodes)
 
 
 @pytest.fixture(scope="session")
-def hf_dataset(hf_dataset_factory, episode_dicts, tasks):
-    keys = ["state", "action"]
-    shapes = {
-        "state": 10,
-        "action": 10,
-    }
-    return hf_dataset_factory(episode_dicts, tasks, keys, shapes)
+def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
+    return hf_dataset_factory()
 
 
 @pytest.fixture(scope="session")
-def hf_dataset_image(hf_dataset_factory, episode_dicts, tasks):
-    keys = ["state", "action"]
-    image_keys = ["image"]
-    shapes = {
-        "state": 10,
-        "action": 10,
-        "image": {
-            "width": 100,
-            "height": 70,
-            "channels": 3,
-        },
-    }
-    return hf_dataset_factory(episode_dicts, tasks, keys, shapes, image_keys=image_keys)
-
-
-def get_task_index(tasks_dicts: dict, task: str) -> int:
-    """
-    Given a task in natural language, returns its task_index if the task already exists in the dataset,
-    otherwise creates a new task_index.
- """ - tasks = {d["task_index"]: d["task"] for d in tasks_dicts} - task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} - return task_to_task_index[task] +def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset: + image_keys = DUMMY_CAMERA_KEYS + return hf_dataset_factory(image_keys=image_keys) @pytest.fixture(scope="session") -def hf_dataset_factory(img_array_factory): - def _create_hf_dataset( - episode_dicts: list[dict], - tasks: list[dict], - keys: list[str], - shapes: dict, - fps: int = 30, - image_keys: list[str] | None = None, - ): - key_features = { - key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32")) - for key in keys - } - image_features = {key: datasets.Image() for key in image_keys} if image_keys else {} - common_features = { - "episode_index": datasets.Value(dtype="int64"), - "frame_index": datasets.Value(dtype="int64"), - "timestamp": datasets.Value(dtype="float32"), - "next.done": datasets.Value(dtype="bool"), - "index": datasets.Value(dtype="int64"), - "task_index": datasets.Value(dtype="int64"), - } - features = datasets.Features( - { - **key_features, - **image_features, - **common_features, - } - ) - - episode_index_col = np.array([], dtype=np.int64) - frame_index_col = np.array([], dtype=np.int64) - timestamp_col = np.array([], dtype=np.float32) - next_done_col = np.array([], dtype=bool) - task_index = np.array([], dtype=np.int64) - - for ep_dict in episode_dicts: - episode_index_col = np.concatenate( - (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int)) - ) - frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) - timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) - next_done_ep = np.full(ep_dict["length"], False, dtype=bool) - next_done_ep[-1] = True - next_done_col = np.concatenate((next_done_col, next_done_ep)) - ep_task_index = get_task_index(tasks, ep_dict["tasks"][0]) - task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))) - - index_col = np.arange(len(episode_index_col)) - key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys} - - image_cols = {} - if image_keys: - for key in image_keys: - image_cols[key] = [ - img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"]) - for _ in range(len(index_col)) - ] - - dataset = datasets.Dataset.from_dict( - { - **key_cols, - **image_cols, - "episode_index": episode_index_col, - "frame_index": frame_index_col, - "timestamp": timestamp_col, - "next.done": next_done_col, - "index": index_col, - "task_index": task_index, - }, - features=features, - ) - dataset.set_transform(hf_transform_to_torch) - return dataset - - return _create_hf_dataset +def lerobot_dataset(lerobot_dataset_factory, tmp_path_factory) -> LeRobotDataset: + root = tmp_path_factory.getbasetemp() + return lerobot_dataset_factory(root=root) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py new file mode 100644 index 00000000..92195ee7 --- /dev/null +++ b/tests/fixtures/dataset_factories.py @@ -0,0 +1,253 @@ +from pathlib import Path + +import datasets +import numpy as np +import pytest + +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.common.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_PARQUET_PATH, + DEFAULT_VIDEO_PATH, + hf_transform_to_torch, +) +from tests.fixtures.defaults import 
+
+
+def get_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | None = None) -> dict:
+    shapes = {}
+    if keys:
+        shapes.update({key: 10 for key in keys})
+    if camera_keys:
+        shapes.update({key: {"width": 100, "height": 70, "channels": 3} for key in camera_keys})
+    return shapes
+
+
+def get_task_index(tasks_dicts: dict, task: str) -> int:
+    """
+    Given a task in natural language, returns its task_index if the task already exists in the dataset;
+    raises a KeyError otherwise.
+    """
+    tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
+    task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
+    return task_to_task_index[task]
+
+
+@pytest.fixture(scope="session")
+def img_array_factory():
+    def _create_img_array(width=100, height=100) -> np.ndarray:
+        return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
+
+    return _create_img_array
+
+
+@pytest.fixture(scope="session")
+def info_factory():
+    def _create_info(
+        codebase_version: str = CODEBASE_VERSION,
+        fps: int = 30,
+        robot_type: str = "dummy_robot",
+        keys: list[str] = DUMMY_KEYS,
+        image_keys: list[str] | None = None,
+        video_keys: list[str] = DUMMY_CAMERA_KEYS,
+        shapes: dict | None = None,
+        names: dict | None = None,
+        total_episodes: int = 0,
+        total_frames: int = 0,
+        total_tasks: int = 0,
+        total_videos: int = 0,
+        total_chunks: int = 0,
+        chunks_size: int = DEFAULT_CHUNK_SIZE,
+        data_path: str = DEFAULT_PARQUET_PATH,
+        videos_path: str = DEFAULT_VIDEO_PATH,
+    ) -> dict:
+        if not image_keys:
+            image_keys = []
+        if not shapes:
+            shapes = get_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
+        if not names:
+            names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys}
+
+        video_info = {"videos_path": videos_path}
+        for key in video_keys:
+            video_info[key] = {
+                "video.fps": fps,
+                "video.width": shapes[key]["width"],
+                "video.height": shapes[key]["height"],
+                "video.channels": shapes[key]["channels"],
+                "video.codec": "av1",
+                "video.pix_fmt": "yuv420p",
+                "video.is_depth_map": False,
+                "has_audio": False,
+            }
+        return {
+            "codebase_version": codebase_version,
+            "data_path": data_path,
+            "robot_type": robot_type,
+            "total_episodes": total_episodes,
+            "total_frames": total_frames,
+            "total_tasks": total_tasks,
+            "total_videos": total_videos,
+            "total_chunks": total_chunks,
+            "chunks_size": chunks_size,
+            "fps": fps,
+            "splits": {},
+            "keys": keys,
+            "video_keys": video_keys,
+            "image_keys": image_keys,
+            "shapes": shapes,
+            "names": names,
+            "videos": video_info if len(video_keys) > 0 else None,
+        }
+
+    return _create_info
+
+
+@pytest.fixture(scope="session")
+def stats_factory():
+    def _create_stats(
+        keys: list[str] = DUMMY_KEYS,
+        image_keys: list[str] | None = None,
+        video_keys: list[str] = DUMMY_CAMERA_KEYS,
+        shapes: dict | None = None,
+    ) -> dict:
+        if not image_keys:
+            image_keys = []
+        if not shapes:
+            shapes = get_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
+        stats = {}
+        for key in keys:
+            shape = shapes[key]
+            stats[key] = {
+                "max": np.full(shape, 1, dtype=np.float32).tolist(),
+                "mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
+                "min": np.full(shape, 0, dtype=np.float32).tolist(),
+                "std": np.full(shape, 0.25, dtype=np.float32).tolist(),
+            }
+        for key in [*image_keys, *video_keys]:
+            shape = (3, 1, 1)
+            stats[key] = {
+                "max": np.full(shape, 1, dtype=np.float32).tolist(),
+                "mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
+                "min": np.full(shape, 0, dtype=np.float32).tolist(),
+                "std": np.full(shape, 0.25, dtype=np.float32).tolist(),
+            }
+        return stats
+
+    return _create_stats
+
+
+@pytest.fixture(scope="session")
+def hf_dataset_factory(img_array_factory, episodes, tasks):
+    def _create_hf_dataset(
+        episode_dicts: list[dict] = episodes,
+        task_dicts: list[dict] = tasks,
+        keys: list[str] = DUMMY_KEYS,
+        image_keys: list[str] | None = None,
+        shapes: dict | None = None,
+        fps: int = 30,
+    ) -> datasets.Dataset:
+        if not image_keys:
+            image_keys = []
+        if not shapes:
+            shapes = get_dummy_shapes(keys=keys, camera_keys=image_keys)
+        key_features = {
+            key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
+            for key in keys
+        }
+        image_features = {key: datasets.Image() for key in image_keys} if image_keys else {}
+        common_features = {
+            "episode_index": datasets.Value(dtype="int64"),
+            "frame_index": datasets.Value(dtype="int64"),
+            "timestamp": datasets.Value(dtype="float32"),
+            "next.done": datasets.Value(dtype="bool"),
+            "index": datasets.Value(dtype="int64"),
+            "task_index": datasets.Value(dtype="int64"),
+        }
+        features = datasets.Features(
+            {
+                **key_features,
+                **image_features,
+                **common_features,
+            }
+        )
+
+        episode_index_col = np.array([], dtype=np.int64)
+        frame_index_col = np.array([], dtype=np.int64)
+        timestamp_col = np.array([], dtype=np.float32)
+        next_done_col = np.array([], dtype=bool)
+        task_index = np.array([], dtype=np.int64)
+
+        for ep_dict in episode_dicts:
+            episode_index_col = np.concatenate(
+                (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
+            )
+            frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
+            timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
+            next_done_ep = np.full(ep_dict["length"], False, dtype=bool)
+            next_done_ep[-1] = True
+            next_done_col = np.concatenate((next_done_col, next_done_ep))
+            ep_task_index = get_task_index(task_dicts, ep_dict["tasks"][0])
+            task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
+
+        index_col = np.arange(len(episode_index_col))
+        key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys}
+
+        image_cols = {}
+        if image_keys:
+            for key in image_keys:
+                image_cols[key] = [
+                    img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"])
+                    for _ in range(len(index_col))
+                ]
+
+        dataset = datasets.Dataset.from_dict(
+            {
+                **key_cols,
+                **image_cols,
+                "episode_index": episode_index_col,
+                "frame_index": frame_index_col,
+                "timestamp": timestamp_col,
+                "next.done": next_done_col,
+                "index": index_col,
+                "task_index": task_index,
+            },
+            features=features,
+        )
+        dataset.set_transform(hf_transform_to_torch)
+        return dataset
+
+    return _create_hf_dataset
+
+
+@pytest.fixture(scope="session")
+def lerobot_dataset_factory(
+    info,
+    info_path,
+    stats,
+    stats_path,
+    episodes,
+    episode_path,
+    tasks,
+    tasks_path,
+    hf_dataset,
+    multi_episode_parquet_path,
+):
+    def _create_lerobot_dataset(
+        root: Path,
+        info_dict: dict = info,
+        stats_dict: dict = stats,
+        episode_dicts: list[dict] = episodes,
+        task_dicts: list[dict] = tasks,
+        hf_ds: datasets.Dataset = hf_dataset,
+    ) -> LeRobotDataset:
+        root.mkdir(parents=True, exist_ok=True)
+        # Create local files
+        _ = info_path(root, info_dict)
+        _ = stats_path(root, stats_dict)
+        _ = tasks_path(root, task_dicts)
+        _ = episode_path(root, episode_dicts)
+        _ = multi_episode_parquet_path(root, hf_ds)
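+        # At this point `root` contains the metadata files (INFO_PATH, STATS_PATH,
+        # TASKS_PATH, EPISODES_PATH) plus one parquet file per episode, which is
+        # everything needed to load the dataset with local_files_only=True.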
+        return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, local_files_only=True)
+
+    return _create_lerobot_dataset
diff --git a/tests/fixtures/defaults.py b/tests/fixtures/defaults.py
new file mode 100644
index 00000000..1edb5132
--- /dev/null
+++ b/tests/fixtures/defaults.py
@@ -0,0 +1,3 @@
+DUMMY_REPO_ID = "dummy/repo"
+DUMMY_KEYS = ["state", "action"]
+DUMMY_CAMERA_KEYS = ["laptop", "phone"]
diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py
new file mode 100644
index 00000000..714824f9
--- /dev/null
+++ b/tests/fixtures/files.py
@@ -0,0 +1,94 @@
+import json
+from pathlib import Path
+
+import datasets
+import jsonlines
+import pyarrow.compute as pc
+import pyarrow.parquet as pq
+import pytest
+
+from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
+
+
+@pytest.fixture(scope="session")
+def info_path(info):
+    def _create_info_json_file(dir: Path, info_dict: dict = info) -> Path:
+        fpath = dir / INFO_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with open(fpath, "w") as f:
+            json.dump(info_dict, f, indent=4, ensure_ascii=False)
+        return fpath
+
+    return _create_info_json_file
+
+
+@pytest.fixture(scope="session")
+def stats_path(stats):
+    def _create_stats_json_file(dir: Path, stats_dict: dict = stats) -> Path:
+        fpath = dir / STATS_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with open(fpath, "w") as f:
+            json.dump(stats_dict, f, indent=4, ensure_ascii=False)
+        return fpath
+
+    return _create_stats_json_file
+
+
+@pytest.fixture(scope="session")
+def tasks_path(tasks):
+    def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
+        fpath = dir / TASKS_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with jsonlines.open(fpath, "w") as writer:
+            writer.write_all(tasks_dicts)
+        return fpath
+
+    return _create_tasks_jsonl_file
+
+
+@pytest.fixture(scope="session")
+def episode_path(episodes):
+    def _create_episodes_jsonl_file(dir: Path, episode_dicts: list = episodes) -> Path:
+        fpath = dir / EPISODES_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with jsonlines.open(fpath, "w") as writer:
+            writer.write_all(episode_dicts)
+        return fpath
+
+    return _create_episodes_jsonl_file
+
+
+@pytest.fixture(scope="session")
+def single_episode_parquet_path(hf_dataset, info):
+    def _create_single_episode_parquet(
+        dir: Path, hf_ds: datasets.Dataset = hf_dataset, ep_idx: int = 0
+    ) -> Path:
+        data_path = info["data_path"]
+        chunks_size = info["chunks_size"]
+        ep_chunk = ep_idx // chunks_size
+        fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        table = hf_ds.data.table
+        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+        pq.write_table(ep_table, fpath)
+        return fpath
+
+    return _create_single_episode_parquet
+
+
+@pytest.fixture(scope="session")
+def multi_episode_parquet_path(hf_dataset, info):
+    def _create_multi_episode_parquet(dir: Path, hf_ds: datasets.Dataset = hf_dataset) -> Path:
+        data_path = info["data_path"]
+        chunks_size = info["chunks_size"]
+        total_episodes = info["total_episodes"]
+        for ep_idx in range(total_episodes):
+            ep_chunk = ep_idx // chunks_size
+            fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+            fpath.parent.mkdir(parents=True, exist_ok=True)
+            table = hf_ds.data.table
+            ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+            pq.write_table(ep_table, fpath)
+        return dir / "data"
+
+    return _create_multi_episode_parquet
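As an aside (not part of the patch): a minimal sketch of how the factory fixtures above are meant to compose in a test. The test name is illustrative, and the assertion assumes LeRobotDataset keeps the repo_id it is constructed with.

def test_dataset_loads_from_local_files(lerobot_dataset_factory, tmp_path):
    # Writes info, stats, tasks, episodes and per-episode parquet files under tmp_path,
    # then loads them back through LeRobotDataset with local_files_only=True.
    dataset = lerobot_dataset_factory(root=tmp_path)
    assert dataset.repo_id == "dummy/repo"  # DUMMY_REPO_ID (assumed attribute)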
diff --git a/tests/test_delta_timestamps.py b/tests/test_delta_timestamps.py
index 29935fe4..ae5ba0aa 100644
--- a/tests/test_delta_timestamps.py
+++ b/tests/test_delta_timestamps.py
@@ -8,25 +8,21 @@ from lerobot.common.datasets.utils import (
     get_delta_indices,
     hf_transform_to_torch,
 )
+from tests.fixtures.defaults import DUMMY_KEYS
 
 
 @pytest.fixture(scope="module")
-def synced_hf_dataset_factory(hf_dataset_factory, episode_dicts, tasks):
-    def _create_synced_hf_dataset(fps: int = 30, keys: list | None = None) -> Dataset:
-        if not keys:
-            keys = ["state", "action"]
-        shapes = {key: 10 for key in keys}
-        return hf_dataset_factory(episode_dicts, tasks, keys, shapes, fps=fps)
+def synced_hf_dataset_factory(hf_dataset_factory):
+    def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
+        return hf_dataset_factory(fps=fps)
 
     return _create_synced_hf_dataset
 
 
 @pytest.fixture(scope="module")
 def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
-    def _create_unsynced_hf_dataset(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
-    ) -> Dataset:
-        hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
+    def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
+        hf_dataset = synced_hf_dataset_factory(fps=fps)
         features = hf_dataset.features
         df = hf_dataset.to_pandas()
         dtype = df["timestamp"].dtype  # This is to avoid pandas type warning
@@ -41,10 +37,8 @@ def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
 
 @pytest.fixture(scope="module")
 def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
-    def _create_slightly_off_hf_dataset(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
-    ) -> Dataset:
-        hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
+    def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
+        hf_dataset = synced_hf_dataset_factory(fps=fps)
         features = hf_dataset.features
         df = hf_dataset.to_pandas()
         dtype = df["timestamp"].dtype  # This is to avoid pandas type warning
@@ -59,9 +53,7 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
 
 @pytest.fixture(scope="module")
 def valid_delta_timestamps_factory():
-    def _create_valid_delta_timestamps(fps: int = 30, keys: list | None = None) -> dict:
-        if not keys:
-            keys = ["state", "action"]
+    def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_KEYS) -> dict:
         delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
         return delta_timestamps
 
@@ -71,10 +63,8 @@ def valid_delta_timestamps_factory():
 @pytest.fixture(scope="module")
 def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_invalid_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
     ) -> dict:
-        if not keys:
-            keys = ["state", "action"]
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just outside tolerance
         for key in keys:
@@ -87,10 +77,8 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
 @pytest.fixture(scope="module")
 def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_slightly_off_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
     ) -> dict:
-        if not keys:
-            keys = ["state", "action"]
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just inside tolerance
         for key in delta_timestamps:
@@ -102,9 +90,7 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
 
 
 @pytest.fixture(scope="module")
-def delta_indices(keys: list | None = None) -> dict:
-    if not keys:
-        keys = ["state", "action"]
+def delta_indices(keys: list = DUMMY_KEYS) -> dict:
    return {key: list(range(-10, 10)) for key in keys}
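As a closing aside (not part of the patch), a self-contained sketch of the relationship these fixtures encode: each delta timestamp is a frame offset divided by fps, so multiplying back by fps recovers the corresponding delta index.

fps = 30
keys = ["state", "action"]  # DUMMY_KEYS
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
delta_indices = {key: list(range(-10, 10)) for key in keys}
# Rounding guards against floating-point error in i * (1 / fps) * fps.
assert all(
    [round(ts * fps) for ts in delta_timestamps[key]] == delta_indices[key] for key in keys
)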