Mock snapshot_download
This commit is contained in:
parent
5ea7c78237
commit
cd1509d805
|
@ -24,7 +24,12 @@ from lerobot.common.utils.utils import init_hydra_config
|
|||
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
|
||||
|
||||
# Import fixture modules as plugins
|
||||
pytest_plugins = ["tests.fixtures.dataset", "tests.fixtures.dataset_factories", "tests.fixtures.files"]
|
||||
pytest_plugins = [
|
||||
"tests.fixtures.dataset",
|
||||
"tests.fixtures.dataset_factories",
|
||||
"tests.fixtures.files",
|
||||
"tests.fixtures.hub",
|
||||
]
|
||||
|
||||
|
||||
def pytest_collection_finish():
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
@ -223,31 +224,39 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
|
|||
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
    info,
    stats,
    episodes,
    tasks,
    hf_dataset,
    mock_snapshot_download_factory,
):
    """
    Session-scoped factory returning a function that builds a `LeRobotDataset` for `DUMMY_REPO_ID`.

    While the dataset is being constructed, `snapshot_download` and `get_hub_safe_version` are patched
    in `lerobot.common.datasets.lerobot_dataset`, so the files are produced by the mock built from
    `mock_snapshot_download_factory` instead of being requested from the hub.
    """

    def _create_lerobot_dataset(
        root: Path,
        info_dict: dict = info,
        stats_dict: dict = stats,
        task_dicts: list[dict] = tasks,
        episode_dicts: list[dict] = episodes,
        hf_ds: datasets.Dataset = hf_dataset,
        **kwargs,
    ) -> LeRobotDataset:
        """
        Create a `LeRobotDataset` rooted at `root`.

        Args:
            root: Directory handed to `LeRobotDataset` as its `root`.
            info_dict, stats_dict, task_dicts, episode_dicts, hf_ds: Content the mocked
                `snapshot_download` serves; each defaults to the matching session fixture.
            **kwargs: Forwarded unchanged to the `LeRobotDataset` constructor.
        """
        mock_snapshot_download = mock_snapshot_download_factory(
            info_dict=info_dict,
            stats_dict=stats_dict,
            tasks_dicts=task_dicts,
            episodes_dicts=episode_dicts,
            hf_ds=hf_ds,
        )
        with (
            patch(
                "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
            ) as mock_get_hub_safe_version_patch,
            patch(
                "lerobot.common.datasets.lerobot_dataset.snapshot_download"
            ) as mock_snapshot_download_patch,
        ):
            # Hand back the requested version untouched rather than running the real lookup.
            mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
            # Every `snapshot_download` call made by the dataset is served by the mock.
            mock_snapshot_download_patch.side_effect = mock_snapshot_download

            # The dataset must be built inside the `with` block: the patches are undone on exit.
            return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, **kwargs)

    return _create_lerobot_dataset
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
|
||||
# Directory named "_testing" under LEROBOT_HOME; the mocked `snapshot_download`
# falls back to it when no `local_dir` is given.
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
# Repo id passed to `LeRobotDataset` by the dataset factory fixture.
DUMMY_REPO_ID = "dummy/repo"
# NOTE(review): presumably the non-camera feature keys of the dummy dataset -- confirm against the fixtures using it.
DUMMY_KEYS = ["state", "action"]
# NOTE(review): presumably the camera names of the dummy dataset -- confirm against the fixtures using it.
DUMMY_CAMERA_KEYS = ["laptop", "phone"]
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
|
||||
from tests.fixtures.defaults import LEROBOT_TEST_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
    info,
    info_path,
    stats,
    stats_path,
    tasks,
    tasks_path,
    episodes,
    episode_path,
    single_episode_parquet_path,
    hf_dataset,
):
    """
    This factory allows patching snapshot_download such that, when called, it creates the expected files
    rather than making calls to the hub api. Its design allows explicitly passing the content of the files
    you want to be created.
    """

    def _mock_snapshot_download_func(
        info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
    ):
        """
        Return a `snapshot_download` stand-in serving the given content.

        Every argument defaults to the matching session fixture. The returned function is meant to be
        used as the `side_effect` of a patched `snapshot_download`.
        """
        # Metadata file -> (fixture that writes it, content to write). The insertion order is also the
        # order in which the files are offered to the pattern filter.
        meta_writers = {
            INFO_PATH: (info_path, info_dict),
            STATS_PATH: (stats_path, stats_dict),
            TASKS_PATH: (tasks_path, tasks_dicts),
            EPISODES_PATH: (episode_path, episodes_dicts),
        }

        def _extract_episode_index_from_path(fpath: str) -> int | None:
            """Return the index encoded in an `episode_<index>.parquet` file name, or None if it does not match."""
            path = Path(fpath)
            if path.suffix == ".parquet" and path.stem.startswith("episode_"):
                return int(path.stem[len("episode_") :])  # 'episode_000000' -> 0
            return None

        def _mock_snapshot_download(
            repo_id: str,
            local_dir: str | Path | None = None,
            allow_patterns: str | list[str] | None = None,
            ignore_patterns: str | list[str] | None = None,
            *args,
            **kwargs,
        ) -> str:
            """
            Create under `local_dir` the files selected by the allow/ignore patterns.

            `repo_id` and any extra arguments are accepted for signature compatibility and ignored.
            Returns `local_dir` as a string.
            """
            # The default directory is built with `/`, so the file helpers are already exercised with
            # path objects; convert a str `local_dir` so they receive the same kind of value.
            local_dir = Path(local_dir) if local_dir else LEROBOT_TEST_DIR

            # List all possible files: the metadata files, then one data file per episode.
            data_files = [
                info_dict["data_path"].format(
                    episode_chunk=episode_dict["episode_index"] // info_dict["chunks_size"],
                    episode_index=episode_dict["episode_index"],
                )
                for episode_dict in episodes_dicts
            ]
            all_files = [*meta_writers, *data_files]

            allowed_files = filter_repo_objects(
                all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
            )

            # Create allowed files
            for rel_path in allowed_files:
                if rel_path.startswith("data/"):
                    episode_index = _extract_episode_index_from_path(rel_path)
                    if episode_index is not None:
                        single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index)

                if rel_path in meta_writers:
                    write_file, content = meta_writers[rel_path]
                    write_file(local_dir, content)

            return str(local_dir)

        return _mock_snapshot_download

    return _mock_snapshot_download_func
|
Loading…
Reference in New Issue