# lerobot/tests/fixtures/files.py
from pathlib import Path

import datasets
import pandas as pd
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest

from lerobot.common.datasets.utils import (
    write_episodes,
    write_hf_dataset,
    write_info,
    write_stats,
    write_tasks,
)

@pytest.fixture(scope="session")
def create_info(info_factory):
    def _create_info(dir: Path, info: dict | None = None):
        if info is None:
            info = info_factory()
        write_info(info, dir)

    return _create_info

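# A minimal usage sketch (a hypothetical test, not part of this module). Every
# fixture below follows the same factory pattern: it returns a callable that
# writes one artifact into `dir`, generating a default payload from the
# corresponding factory fixture when none is passed:
#
#     def test_info_roundtrip(create_info, tmp_path):
#         create_info(tmp_path)  # writes the default info produced by info_factory()
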
@pytest.fixture(scope="session")
def create_stats(stats_factory):
    def _create_stats(dir: Path, stats: dict | None = None):
        if stats is None:
            stats = stats_factory()
        write_stats(stats, dir)

    return _create_stats

# @pytest.fixture(scope="session")
# def create_episodes_stats(episodes_stats_factory):
#     def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
#         if episodes_stats is None:
#             episodes_stats = episodes_stats_factory()
#         write_episodes_stats(episodes_stats, dir)
#     return _create_episodes_stats

@pytest.fixture(scope="session")
def create_tasks(tasks_factory):
    def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
        if tasks is None:
            tasks = tasks_factory()
        write_tasks(tasks, dir)

    return _create_tasks

@pytest.fixture(scope="session")
def create_episodes(episodes_factory):
    def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
        if episodes is None:
            episodes = episodes_factory()
        write_episodes(episodes, dir)

    return _create_episodes

@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
    def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
        if hf_dataset is None:
            hf_dataset = hf_dataset_factory()
        write_hf_dataset(hf_dataset, dir)

    return _create_hf_dataset

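# Sketch: composing the factories above to materialize a complete dataset
# directory (a hypothetical test; the on-disk layout is determined by the
# `write_*` helpers in lerobot.common.datasets.utils):
#
#     def test_full_dataset_dir(
#         create_info, create_stats, create_tasks, create_episodes, create_hf_dataset, tmp_path
#     ):
#         for create in (create_info, create_stats, create_tasks, create_episodes, create_hf_dataset):
#             create(tmp_path)
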
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
    def _create_single_episode_parquet(
        dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
    ) -> Path:
        # NOTE: currently disabled; the code below is unreachable until this
        # raise is removed, and is kept for reference. It filters the backing
        # Arrow table down to one episode and writes it to the chunked path
        # given by the `data_path` template in `info`.
        raise NotImplementedError()
        if info is None:
            info = info_factory()
        if hf_dataset is None:
            hf_dataset = hf_dataset_factory()
        data_path = info["data_path"]
        chunks_size = info["chunks_size"]
        ep_chunk = ep_idx // chunks_size
        fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
        fpath.parent.mkdir(parents=True, exist_ok=True)
        table = hf_dataset.data.table
        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
        pq.write_table(ep_table, fpath)
        return fpath

    return _create_single_episode_parquet

@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
    def _create_multi_episode_parquet(
        dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
    ) -> Path:
        # NOTE: currently disabled; the code below is unreachable until this
        # raise is removed, and is kept for reference. It writes one parquet
        # file per episode, grouped into chunks of `chunks_size` episodes.
        raise NotImplementedError()
        if info is None:
            info = info_factory()
        if hf_dataset is None:
            hf_dataset = hf_dataset_factory()
        data_path = info["data_path"]
        chunks_size = info["chunks_size"]
        total_episodes = info["total_episodes"]
        for ep_idx in range(total_episodes):
            ep_chunk = ep_idx // chunks_size
            fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
            fpath.parent.mkdir(parents=True, exist_ok=True)
            table = hf_dataset.data.table
            ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
            pq.write_table(ep_table, fpath)
        return dir / "data"

    return _create_multi_episode_parquet

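# Sketch: once re-enabled, a per-episode file written by these fixtures could
# be checked with plain pyarrow, reading the table back and confirming it only
# contains the requested episode (`fpath` and `ep_idx` as used above):
#
#     table = pq.read_table(fpath)
#     assert pc.all(pc.equal(table["episode_index"], ep_idx)).as_py()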