lerobot/tests/fixtures/files.py

95 lines
3.2 KiB
Python

import json
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
@pytest.fixture(scope="session")
def info_path(info):
    """Provide a factory that writes an info dict as JSON under ``dir / INFO_PATH``.

    The returned callable takes a root directory and an optional ``info_dict``
    (defaulting to the ``info`` fixture value). It creates any missing parent
    directories, writes the dict as indented JSON, and returns the path of the
    written file.
    """

    def _create_info_json_file(dir: Path, info_dict: dict = info) -> Path:
        fpath = dir / INFO_PATH
        fpath.parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII characters, so pin the file
        # encoding instead of relying on the platform default (which may not be
        # UTF-8 and could then raise UnicodeEncodeError).
        with open(fpath, "w", encoding="utf-8") as f:
            json.dump(info_dict, f, indent=4, ensure_ascii=False)
        return fpath

    return _create_info_json_file
@pytest.fixture(scope="session")
def stats_path(stats):
    """Provide a factory that writes a stats dict as JSON under ``dir / STATS_PATH``.

    The returned callable takes a root directory and an optional ``stats_dict``
    (defaulting to the ``stats`` fixture value). It creates any missing parent
    directories, writes the dict as indented JSON, and returns the path of the
    written file.
    """

    def _create_stats_json_file(dir: Path, stats_dict: dict = stats) -> Path:
        fpath = dir / STATS_PATH
        fpath.parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII characters, so pin the file
        # encoding instead of relying on the platform default (which may not be
        # UTF-8 and could then raise UnicodeEncodeError).
        with open(fpath, "w", encoding="utf-8") as f:
            json.dump(stats_dict, f, indent=4, ensure_ascii=False)
        return fpath

    return _create_stats_json_file
@pytest.fixture(scope="session")
def tasks_path(tasks):
    """Return a factory that dumps task records as JSON Lines to ``dir / TASKS_PATH``.

    Calling the factory with a root directory (and optionally a custom list of
    records, defaulting to the ``tasks`` fixture value) creates any missing
    parent directories, writes one JSON object per line, and returns the
    written file's path.
    """

    def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
        target = dir / TASKS_PATH
        target.parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(target, "w") as sink:
            # One record per line, in the order given.
            for record in task_dicts:
                sink.write(record)
        return target

    return _create_tasks_jsonl_file
@pytest.fixture(scope="session")
def episode_path(episodes):
    """Return a factory that dumps episode records as JSON Lines to ``dir / EPISODES_PATH``.

    Calling the factory with a root directory (and optionally a custom list of
    records, defaulting to the ``episodes`` fixture value) creates any missing
    parent directories, writes one JSON object per line, and returns the
    written file's path.
    """

    def _create_episodes_jsonl_file(dir: Path, episode_dicts: list = episodes) -> Path:
        target = dir / EPISODES_PATH
        target.parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(target, "w") as sink:
            # One record per line, in the order given.
            for record in episode_dicts:
                sink.write(record)
        return target

    return _create_episodes_jsonl_file
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset, info):
    """Return a factory that writes one episode's rows to a Parquet file.

    The factory filters the dataset's Arrow table to rows whose
    ``episode_index`` equals ``ep_idx``. It writes them to
    ``dir / info["data_path"]``, formatted with the episode's chunk
    (``ep_idx // info["chunks_size"]``) and index, and returns that path.
    """

    def _create_single_episode_parquet(
        dir: Path, hf_ds: datasets.Dataset = hf_dataset, ep_idx: int = 0
    ) -> Path:
        chunk_index = ep_idx // info["chunks_size"]
        relative_path = info["data_path"].format(episode_chunk=chunk_index, episode_index=ep_idx)
        out_path = dir / relative_path
        out_path.parent.mkdir(parents=True, exist_ok=True)

        full_table = hf_ds.data.table
        episode_mask = pc.equal(full_table["episode_index"], ep_idx)
        pq.write_table(full_table.filter(episode_mask), out_path)
        return out_path

    return _create_single_episode_parquet
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset, info):
    """Return a factory that writes one Parquet file per episode.

    For every episode index in ``range(info["total_episodes"])``, the factory
    filters the dataset's Arrow table to rows whose ``episode_index`` equals
    that index. It writes them to ``dir / info["data_path"]``, formatted with
    the episode's chunk (``ep_idx // info["chunks_size"]``) and index. It
    returns ``dir / "data"``.
    """

    def _create_multi_episode_parquet(dir: Path, hf_ds: datasets.Dataset = hf_dataset) -> Path:
        data_path = info["data_path"]
        chunks_size = info["chunks_size"]
        total_episodes = info["total_episodes"]
        # The Arrow table and its episode_index column are the same for every
        # episode, so fetch them once instead of on every loop iteration.
        table = hf_ds.data.table
        episode_indices = table["episode_index"]
        for ep_idx in range(total_episodes):
            ep_chunk = ep_idx // chunks_size
            fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
            fpath.parent.mkdir(parents=True, exist_ok=True)
            ep_table = table.filter(pc.equal(episode_indices, ep_idx))
            pq.write_table(ep_table, fpath)
        # NOTE(review): this assumes info["data_path"] is rooted at "data/" --
        # confirm if the data_path template ever changes.
        return dir / "data"

    return _create_multi_episode_parquet