lerobot/tests/fixtures/dataset.py

68 lines
1.7 KiB
Python
Raw Normal View History

2024-10-31 20:46:46 +08:00
import datasets
import pytest
from lerobot.common.datasets.utils import get_episode_data_index
from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
2024-10-31 20:46:46 +08:00
@pytest.fixture(scope="session")
def empty_info(info_factory) -> dict:
    """Session-wide ``info`` dict for a dataset that declares no features.

    Calls ``info_factory`` with every feature-related argument (keys, image
    keys, video keys, shapes, names) set to an empty container; all other
    arguments are left at the factory's defaults.
    """
    no_features = {
        "keys": [],
        "image_keys": [],
        "video_keys": [],
        "shapes": {},
        "names": {},
    }
    return info_factory(**no_features)
2024-10-31 20:46:46 +08:00
@pytest.fixture(scope="session")
def info(info_factory) -> dict:
    """Session-wide ``info`` dict for the small multi-episode dummy dataset.

    The totals are chosen to agree with the ``episodes`` and ``tasks``
    fixtures in this module: 4 episodes, 420 frames in total and 3 distinct
    tasks.
    """
    totals = dict(
        total_episodes=4,
        total_frames=420,  # sum of the ``episodes`` lengths: 100 + 80 + 90 + 150
        total_tasks=3,
        total_videos=8,  # presumably one video per camera key per episode — TODO confirm
        total_chunks=1,
    )
    return info_factory(**totals)
@pytest.fixture(scope="session")
def stats(stats_factory) -> list:
    """Session-wide stats produced by ``stats_factory`` with its default arguments."""
    # NOTE(review): the ``list`` return annotation is kept unchanged; confirm it
    # matches the type ``stats_factory`` actually returns.
    generated_stats = stats_factory()
    return generated_stats
2024-10-31 20:46:46 +08:00
@pytest.fixture(scope="session")
def tasks() -> list:
    """Session-wide task table.

    Returns a list of dicts, one per task, each holding a ``task_index``
    (its position in the list, starting at 0) and the ``task`` description.
    """
    descriptions = ("Pick up the block.", "Open the box.", "Make paperclips.")
    return [{"task_index": idx, "task": text} for idx, text in enumerate(descriptions)]
@pytest.fixture(scope="session")
def episodes() -> list:
    """Session-wide episode table.

    Returns one dict per episode with its ``episode_index`` (0-based), the
    single-element list of ``tasks`` it performs and its ``length`` in
    frames. The task strings are drawn from the ``tasks`` fixture.
    """
    # (task description, number of frames) for each episode, in index order.
    layout = [
        ("Pick up the block.", 100),
        ("Open the box.", 80),
        ("Pick up the block.", 90),
        ("Make paperclips.", 150),
    ]
    return [
        {"episode_index": ep_idx, "tasks": [task], "length": n_frames}
        for ep_idx, (task, n_frames) in enumerate(layout)
    ]
@pytest.fixture(scope="session")
def episode_data_index(episodes) -> dict:
    """Session-wide episode index derived from the ``episodes`` fixture.

    Delegates to ``get_episode_data_index``; presumably the result maps to
    per-episode frame boundaries — confirm against that helper.
    """
    index = get_episode_data_index(episodes)
    return index
2024-10-31 20:46:46 +08:00
@pytest.fixture(scope="session")
def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
    """Session-wide Hugging Face dataset built by ``hf_dataset_factory`` defaults."""
    dataset = hf_dataset_factory()
    return dataset
2024-10-31 20:46:46 +08:00
@pytest.fixture(scope="session")
def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset:
    """Session-wide Hugging Face dataset whose image keys are ``DUMMY_CAMERA_KEYS``.

    Identical to the default factory output except that ``image_keys`` is
    set to the shared dummy camera keys.
    """
    return hf_dataset_factory(image_keys=DUMMY_CAMERA_KEYS)