from pathlib import Path import datasets import pandas as pd import pytest from huggingface_hub.utils import filter_repo_objects from lerobot.common.datasets.utils import ( DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_TASKS_PATH, INFO_PATH, STATS_PATH, ) from tests.fixtures.constants import LEROBOT_TEST_DIR @pytest.fixture(scope="session") def mock_snapshot_download_factory( info_factory, create_info, stats_factory, create_stats, tasks_factory, create_tasks, episodes_factory, create_episodes, hf_dataset_factory, create_hf_dataset, ): """ This factory allows to patch snapshot_download such that when called, it will create expected files rather than making calls to the hub api. Its design allows to pass explicitly files which you want to be created. """ def _mock_snapshot_download_func( info: dict | None = None, stats: dict | None = None, tasks: pd.DataFrame | None = None, episodes: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None, ): if info is None: info = info_factory() if stats is None: stats = stats_factory(features=info["features"]) if tasks is None: tasks = tasks_factory(total_tasks=info["total_tasks"]) if episodes is None: episodes = episodes_factory( features=info["features"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks, ) if hf_dataset is None: hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) def _mock_snapshot_download( repo_id: str, # TODO(rcadene): repo_id should be used no? local_dir: str | Path | None = None, allow_patterns: str | list[str] | None = None, ignore_patterns: str | list[str] | None = None, *args, **kwargs, ) -> str: if local_dir is None: local_dir = LEROBOT_TEST_DIR # List all possible files all_files = [ INFO_PATH, STATS_PATH, # TODO(rcadene): remove naive chunk 0 file 0 ? DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0), DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0), DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0), ] allowed_files = filter_repo_objects( all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns ) has_info = False has_tasks = False has_episodes = False has_stats = False has_data = False for rel_path in allowed_files: if rel_path.startswith("meta/info.json"): has_info = True elif rel_path.startswith("meta/stats"): has_stats = True elif rel_path.startswith("meta/tasks"): has_tasks = True elif rel_path.startswith("meta/episodes"): has_episodes = True elif rel_path.startswith("data/"): has_data = True else: raise ValueError(f"{rel_path} not supported.") if has_info: create_info(local_dir, info) if has_stats: create_stats(local_dir, stats) if has_tasks: create_tasks(local_dir, tasks) if has_episodes: create_episodes(local_dir, episodes) if has_data: create_hf_dataset(local_dir, hf_dataset) return str(local_dir) return _mock_snapshot_download return _mock_snapshot_download_func