From cd1509d8059d190c0395b8ba1c02a3b4a87b8698 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Fri, 1 Nov 2024 10:58:09 +0100
Subject: [PATCH] Mock snapshot_download

---
Notes (placed after the three-dash line, so not part of the commit message):

* Registers tests.fixtures.hub as a pytest plugin and adds the
  mock_snapshot_download_factory fixture. The replacement it builds for
  snapshot_download lists the meta files plus one data file per episode,
  filters them with allow_patterns / ignore_patterns, and creates the
  remaining ones through the info_path, stats_path, tasks_path,
  episode_path and single_episode_parquet_path fixtures rather than
  making calls to the hub api.
* lerobot_dataset_factory no longer writes the files itself: it patches
  get_hub_safe_version and snapshot_download in
  lerobot.common.datasets.lerobot_dataset while LeRobotDataset is
  constructed, and forwards extra keyword arguments to the constructor.
* Review: _extract_episode_index_from_path is annotated "-> int" but
  returns None for paths that are not episode parquet files; the caller
  relies on that None, so "int | None" would describe it accurately.

 tests/conftest.py                   |  7 ++-
 tests/fixtures/dataset_factories.py | 37 +++++++-----
 tests/fixtures/defaults.py          |  3 +
 tests/fixtures/hub.py               | 87 +++++++++++++++++++++++++++++
 4 files changed, 119 insertions(+), 15 deletions(-)
 create mode 100644 tests/fixtures/hub.py

diff --git a/tests/conftest.py b/tests/conftest.py
index caf20148..8491eeba 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,7 +24,12 @@ from lerobot.common.utils.utils import init_hydra_config
 from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
 
 # Import fixture modules as plugins
-pytest_plugins = ["tests.fixtures.dataset", "tests.fixtures.dataset_factories", "tests.fixtures.files"]
+pytest_plugins = [
+    "tests.fixtures.dataset",
+    "tests.fixtures.dataset_factories",
+    "tests.fixtures.files",
+    "tests.fixtures.hub",
+]
 
 
 def pytest_collection_finish():
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 92195ee7..6ee077d5 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from unittest.mock import patch
 
 import datasets
 import numpy as np
@@ -223,31 +224,39 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
 @pytest.fixture(scope="session")
 def lerobot_dataset_factory(
     info,
-    info_path,
     stats,
-    stats_path,
     episodes,
-    episode_path,
     tasks,
-    tasks_path,
     hf_dataset,
-    multi_episode_parquet_path,
+    mock_snapshot_download_factory,
 ):
     def _create_lerobot_dataset(
         root: Path,
         info_dict: dict = info,
         stats_dict: dict = stats,
-        episode_dicts: list[dict] = episodes,
         task_dicts: list[dict] = tasks,
+        episode_dicts: list[dict] = episodes,
         hf_ds: datasets.Dataset = hf_dataset,
+        **kwargs,
     ) -> LeRobotDataset:
-        root.mkdir(parents=True, exist_ok=True)
-        # Create local files
-        _ = info_path(root, info_dict)
-        _ = stats_path(root, stats_dict)
-        _ = tasks_path(root, task_dicts)
-        _ = episode_path(root, episode_dicts)
-        _ = multi_episode_parquet_path(root, hf_ds)
-        return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, local_files_only=True)
+        mock_snapshot_download = mock_snapshot_download_factory(
+            info_dict=info_dict,
+            stats_dict=stats_dict,
+            tasks_dicts=task_dicts,
+            episodes_dicts=episode_dicts,
+            hf_ds=hf_ds,
+        )
+        with (
+            patch(
+                "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
+            ) as mock_get_hub_safe_version_patch,
+            patch(
+                "lerobot.common.datasets.lerobot_dataset.snapshot_download"
+            ) as mock_snapshot_download_patch,
+        ):
+            mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
+            mock_snapshot_download_patch.side_effect = mock_snapshot_download
+
+            return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
 
     return _create_lerobot_dataset
diff --git a/tests/fixtures/defaults.py b/tests/fixtures/defaults.py
index 1edb5132..27722e83 100644
--- a/tests/fixtures/defaults.py
+++ b/tests/fixtures/defaults.py
@@ -1,3 +1,6 @@
+from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
+
+LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
 DUMMY_REPO_ID = "dummy/repo"
 DUMMY_KEYS = ["state", "action"]
 DUMMY_CAMERA_KEYS = ["laptop", "phone"]
diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py
new file mode 100644
index 00000000..3422936c
--- /dev/null
+++ b/tests/fixtures/hub.py
@@ -0,0 +1,87 @@
+from pathlib import Path
+
+import pytest
+from huggingface_hub.utils import filter_repo_objects
+
+from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
+from tests.fixtures.defaults import LEROBOT_TEST_DIR
+
+
+@pytest.fixture(scope="session")
+def mock_snapshot_download_factory(
+    info,
+    info_path,
+    stats,
+    stats_path,
+    tasks,
+    tasks_path,
+    episodes,
+    episode_path,
+    single_episode_parquet_path,
+    hf_dataset,
+):
+    """
+    This factory allows to patch snapshot_download such that when called, it will create expected files rather
+    than making calls to the hub api. Its design allows to pass explicitly files which you want to be created.
+    """
+
+    def _mock_snapshot_download_func(
+        info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
+    ):
+        def _extract_episode_index_from_path(fpath: str) -> int:
+            path = Path(fpath)
+            if path.suffix == ".parquet" and path.stem.startswith("episode_"):
+                episode_index = int(path.stem[len("episode_") :])  # 'episode_000000' -> 0
+                return episode_index
+            else:
+                return None
+
+        def _mock_snapshot_download(
+            repo_id: str,
+            local_dir: str | Path | None = None,
+            allow_patterns: str | list[str] | None = None,
+            ignore_patterns: str | list[str] | None = None,
+            *args,
+            **kwargs,
+        ) -> str:
+            if not local_dir:
+                local_dir = LEROBOT_TEST_DIR
+
+            # List all possible files
+            all_files = []
+            meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
+            all_files.extend(meta_files)
+
+            data_files = []
+            for episode_dict in episodes_dicts:
+                ep_idx = episode_dict["episode_index"]
+                ep_chunk = ep_idx // info_dict["chunks_size"]
+                data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
+                data_files.append(data_path)
+            all_files.extend(data_files)
+
+            allowed_files = filter_repo_objects(
+                all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
+            )
+
+            # Create allowed files
+            for rel_path in allowed_files:
+                if rel_path.startswith("data/"):
+                    episode_index = _extract_episode_index_from_path(rel_path)
+                    if episode_index is not None:
+                        _ = single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index)
+                if rel_path == INFO_PATH:
+                    _ = info_path(local_dir, info_dict)
+                elif rel_path == STATS_PATH:
+                    _ = stats_path(local_dir, stats_dict)
+                elif rel_path == TASKS_PATH:
+                    _ = tasks_path(local_dir, tasks_dicts)
+                elif rel_path == EPISODES_PATH:
+                    _ = episode_path(local_dir, episodes_dicts)
+                else:
+                    pass
+            return str(local_dir)
+
+        return _mock_snapshot_download
+
+    return _mock_snapshot_download_func