Mock snapshot_download

This commit is contained in:
Simon Alibert 2024-11-01 10:58:09 +01:00
parent 5ea7c78237
commit cd1509d805
4 changed files with 119 additions and 15 deletions

View File

@ -24,7 +24,12 @@ from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
# Import fixture modules as plugins
pytest_plugins = ["tests.fixtures.dataset", "tests.fixtures.dataset_factories", "tests.fixtures.files"]
pytest_plugins = [
"tests.fixtures.dataset",
"tests.fixtures.dataset_factories",
"tests.fixtures.files",
"tests.fixtures.hub",
]
def pytest_collection_finish():

View File

@ -1,4 +1,5 @@
from pathlib import Path
from unittest.mock import patch
import datasets
import numpy as np
@ -223,31 +224,39 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info,
info_path,
stats,
stats_path,
episodes,
episode_path,
tasks,
tasks_path,
hf_dataset,
multi_episode_parquet_path,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset(
root: Path,
info_dict: dict = info,
stats_dict: dict = stats,
episode_dicts: list[dict] = episodes,
task_dicts: list[dict] = tasks,
episode_dicts: list[dict] = episodes,
hf_ds: datasets.Dataset = hf_dataset,
**kwargs,
) -> LeRobotDataset:
root.mkdir(parents=True, exist_ok=True)
# Create local files
_ = info_path(root, info_dict)
_ = stats_path(root, stats_dict)
_ = tasks_path(root, task_dicts)
_ = episode_path(root, episode_dicts)
_ = multi_episode_parquet_path(root, hf_ds)
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, local_files_only=True)
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
tasks_dicts=task_dicts,
episodes_dicts=episode_dicts,
hf_ds=hf_ds,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)
return _create_lerobot_dataset

View File

@ -1,3 +1,6 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_KEYS = ["state", "action"]
DUMMY_CAMERA_KEYS = ["laptop", "phone"]

87
tests/fixtures/hub.py vendored Normal file
View File

@ -0,0 +1,87 @@
from pathlib import Path
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from tests.fixtures.defaults import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info,
info_path,
stats,
stats_path,
tasks,
tasks_path,
episodes,
episode_path,
single_episode_parquet_path,
hf_dataset,
):
"""
This factory allows to patch snapshot_download such that when called, it will create expected files rather
than making calls to the hub api. Its design allows to pass explicitly files which you want to be created.
"""
def _mock_snapshot_download_func(
info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
):
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
return episode_index
else:
return None
def _mock_snapshot_download(
repo_id: str,
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes_dicts:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info_dict["chunks_size"]
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
# Create allowed files
for rel_path in allowed_files:
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info_dict)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats_dict)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks_dicts)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes_dicts)
else:
pass
return str(local_dir)
return _mock_snapshot_download
return _mock_snapshot_download_func