Split fixtures into factories and files
parent c70b8d0abc
commit 1267c3e955

@@ -24,7 +24,7 @@ from lerobot.common.utils.utils import init_hydra_config
 from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
 
 # Import fixture modules as plugins
-pytest_plugins = ["tests.fixtures.dataset"]
+pytest_plugins = ["tests.fixtures.dataset", "tests.fixtures.dataset_factories", "tests.fixtures.files"]
 
 
 def pytest_collection_finish():
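Side note on the conftest change above: listing the two new modules in pytest_plugins is what makes the split transparent to tests, since pytest registers every fixture defined in those modules globally. A minimal sketch of a test relying on that wiring (hypothetical test module and assertions, shown only to illustrate the mechanism):

# Hypothetical test, for illustration only; no import of the fixture modules is
# needed because they are registered via `pytest_plugins` in the conftest above.
def test_fixture_modules_are_registered(tasks, episodes):
    # `tasks` and `episodes` are the session-scoped fixtures from tests.fixtures.dataset.
    assert all({"task_index", "task"} <= d.keys() for d in tasks)
    assert all({"episode_index", "tasks", "length"} <= d.keys() for d in episodes)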
@@ -1,20 +1,40 @@
 import datasets
-import numpy as np
 import pytest
 
-from lerobot.common.datasets.utils import get_episode_data_index, hf_transform_to_torch
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import get_episode_data_index
+from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
 
 
 @pytest.fixture(scope="session")
-def img_array_factory():
-    def _create_img_array(width=100, height=100) -> np.ndarray:
-        return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
-
-    return _create_img_array
+def empty_info(info_factory) -> dict:
+    return info_factory(
+        keys=[],
+        image_keys=[],
+        video_keys=[],
+        shapes={},
+        names={},
+    )
+
+
+@pytest.fixture(scope="session")
+def info(info_factory) -> dict:
+    return info_factory(
+        total_episodes=4,
+        total_frames=420,
+        total_tasks=3,
+        total_videos=8,
+        total_chunks=1,
+    )
+
+
+@pytest.fixture(scope="session")
+def stats(stats_factory) -> list:
+    return stats_factory()
 
 
 @pytest.fixture(scope="session")
-def tasks():
+def tasks() -> list:
     return [
         {"task_index": 0, "task": "Pick up the block."},
         {"task_index": 1, "task": "Open the box."},
@@ -23,7 +43,7 @@ def tasks():
 
 
 @pytest.fixture(scope="session")
-def episode_dicts():
+def episodes() -> list:
     return [
         {"episode_index": 0, "tasks": ["Pick up the block."], "length": 100},
         {"episode_index": 1, "tasks": ["Open the box."], "length": 80},
@@ -33,120 +53,22 @@ def episode_dicts():
 
 
 @pytest.fixture(scope="session")
-def episode_data_index(episode_dicts):
-    return get_episode_data_index(episode_dicts)
+def episode_data_index(episodes) -> dict:
+    return get_episode_data_index(episodes)
 
 
 @pytest.fixture(scope="session")
-def hf_dataset(hf_dataset_factory, episode_dicts, tasks):
-    keys = ["state", "action"]
-    shapes = {
-        "state": 10,
-        "action": 10,
-    }
-    return hf_dataset_factory(episode_dicts, tasks, keys, shapes)
+def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
+    return hf_dataset_factory()
 
 
 @pytest.fixture(scope="session")
-def hf_dataset_image(hf_dataset_factory, episode_dicts, tasks):
-    keys = ["state", "action"]
-    image_keys = ["image"]
-    shapes = {
-        "state": 10,
-        "action": 10,
-        "image": {
-            "width": 100,
-            "height": 70,
-            "channels": 3,
-        },
-    }
-    return hf_dataset_factory(episode_dicts, tasks, keys, shapes, image_keys=image_keys)
+def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset:
+    image_keys = DUMMY_CAMERA_KEYS
+    return hf_dataset_factory(image_keys=image_keys)
-
-
-def get_task_index(tasks_dicts: dict, task: str) -> int:
-    """
-    Given a task in natural language, returns its task_index if the task already exists in the dataset,
-    otherwise creates a new task_index.
-    """
-    tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
-    task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
-    return task_to_task_index[task]
 
 
 @pytest.fixture(scope="session")
-def hf_dataset_factory(img_array_factory):
-    def _create_hf_dataset(
-        episode_dicts: list[dict],
-        tasks: list[dict],
-        keys: list[str],
-        shapes: dict,
-        fps: int = 30,
-        image_keys: list[str] | None = None,
-    ):
-        key_features = {
-            key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
-            for key in keys
-        }
-        image_features = {key: datasets.Image() for key in image_keys} if image_keys else {}
-        common_features = {
-            "episode_index": datasets.Value(dtype="int64"),
-            "frame_index": datasets.Value(dtype="int64"),
-            "timestamp": datasets.Value(dtype="float32"),
-            "next.done": datasets.Value(dtype="bool"),
-            "index": datasets.Value(dtype="int64"),
-            "task_index": datasets.Value(dtype="int64"),
-        }
-        features = datasets.Features(
-            {
-                **key_features,
-                **image_features,
-                **common_features,
-            }
-        )
-
-        episode_index_col = np.array([], dtype=np.int64)
-        frame_index_col = np.array([], dtype=np.int64)
-        timestamp_col = np.array([], dtype=np.float32)
-        next_done_col = np.array([], dtype=bool)
-        task_index = np.array([], dtype=np.int64)
-
-        for ep_dict in episode_dicts:
-            episode_index_col = np.concatenate(
-                (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
-            )
-            frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
-            timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
-            next_done_ep = np.full(ep_dict["length"], False, dtype=bool)
-            next_done_ep[-1] = True
-            next_done_col = np.concatenate((next_done_col, next_done_ep))
-            ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
-            task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
-
-        index_col = np.arange(len(episode_index_col))
-        key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys}
-
-        image_cols = {}
-        if image_keys:
-            for key in image_keys:
-                image_cols[key] = [
-                    img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"])
-                    for _ in range(len(index_col))
-                ]
-
-        dataset = datasets.Dataset.from_dict(
-            {
-                **key_cols,
-                **image_cols,
-                "episode_index": episode_index_col,
-                "frame_index": frame_index_col,
-                "timestamp": timestamp_col,
-                "next.done": next_done_col,
-                "index": index_col,
-                "task_index": task_index,
-            },
-            features=features,
-        )
-        dataset.set_transform(hf_transform_to_torch)
-        return dataset
-
-    return _create_hf_dataset
+def lerobot_dataset(lerobot_dataset_factory, tmp_path_factory) -> LeRobotDataset:
+    root = tmp_path_factory.getbasetemp()
+    return lerobot_dataset_factory(root=root)
@@ -0,0 +1,253 @@
+from pathlib import Path
+
+import datasets
+import numpy as np
+import pytest
+
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.utils import (
+    DEFAULT_CHUNK_SIZE,
+    DEFAULT_PARQUET_PATH,
+    DEFAULT_VIDEO_PATH,
+    hf_transform_to_torch,
+)
+from tests.fixtures.defaults import DUMMY_CAMERA_KEYS, DUMMY_KEYS, DUMMY_REPO_ID
+
+
+def get_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | None = None) -> dict:
+    shapes = {}
+    if keys:
+        shapes.update({key: 10 for key in keys})
+    if camera_keys:
+        shapes.update({key: {"width": 100, "height": 70, "channels": 3} for key in camera_keys})
+    return shapes
+
+
+def get_task_index(tasks_dicts: dict, task: str) -> int:
+    """
+    Given a task in natural language, returns its task_index if the task already exists in the dataset,
+    otherwise creates a new task_index.
+    """
+    tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
+    task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
+    return task_to_task_index[task]
+
+
+@pytest.fixture(scope="session")
+def img_array_factory():
+    def _create_img_array(width=100, height=100) -> np.ndarray:
+        return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
+
+    return _create_img_array
+
+
+@pytest.fixture(scope="session")
+def info_factory():
+    def _create_info(
+        codebase_version: str = CODEBASE_VERSION,
+        fps: int = 30,
+        robot_type: str = "dummy_robot",
+        keys: list[str] = DUMMY_KEYS,
+        image_keys: list[str] | None = None,
+        video_keys: list[str] = DUMMY_CAMERA_KEYS,
+        shapes: dict | None = None,
+        names: dict | None = None,
+        total_episodes: int = 0,
+        total_frames: int = 0,
+        total_tasks: int = 0,
+        total_videos: int = 0,
+        total_chunks: int = 0,
+        chunks_size: int = DEFAULT_CHUNK_SIZE,
+        data_path: str = DEFAULT_PARQUET_PATH,
+        videos_path: str = DEFAULT_VIDEO_PATH,
+    ) -> dict:
+        if not image_keys:
+            image_keys = []
+        if not shapes:
+            shapes = get_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
+        if not names:
+            names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys}
+
+        video_info = {"videos_path": videos_path}
+        for key in video_keys:
+            video_info[key] = {
+                "video.fps": fps,
+                "video.width": shapes[key]["width"],
+                "video.height": shapes[key]["height"],
+                "video.channels": shapes[key]["channels"],
+                "video.codec": "av1",
+                "video.pix_fmt": "yuv420p",
+                "video.is_depth_map": False,
+                "has_audio": False,
+            }
+        return {
+            "codebase_version": codebase_version,
+            "data_path": data_path,
+            "robot_type": robot_type,
+            "total_episodes": total_episodes,
+            "total_frames": total_frames,
+            "total_tasks": total_tasks,
+            "total_videos": total_videos,
+            "total_chunks": total_chunks,
+            "chunks_size": chunks_size,
+            "fps": fps,
+            "splits": {},
+            "keys": keys,
+            "video_keys": video_keys,
+            "image_keys": image_keys,
+            "shapes": shapes,
+            "names": names,
+            "videos": video_info if len(video_keys) > 0 else None,
+        }
+
+    return _create_info
+
+
+@pytest.fixture(scope="session")
+def stats_factory():
+    def _create_stats(
+        keys: list[str] = DUMMY_KEYS,
+        image_keys: list[str] | None = None,
+        video_keys: list[str] = DUMMY_CAMERA_KEYS,
+        shapes: dict | None = None,
+    ) -> dict:
+        if not image_keys:
+            image_keys = []
+        if not shapes:
+            shapes = get_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
+        stats = {}
+        for key in keys:
+            shape = shapes[key]
+            stats[key] = {
+                "max": np.full(shape, 1, dtype=np.float32).tolist(),
+                "mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
+                "min": np.full(shape, 0, dtype=np.float32).tolist(),
+                "std": np.full(shape, 0.25, dtype=np.float32).tolist(),
+            }
+        for key in [*image_keys, *video_keys]:
+            shape = (3, 1, 1)
+            stats[key] = {
+                "max": np.full(shape, 1, dtype=np.float32).tolist(),
+                "mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
+                "min": np.full(shape, 0, dtype=np.float32).tolist(),
+                "std": np.full(shape, 0.25, dtype=np.float32).tolist(),
+            }
+        return stats
+
+    return _create_stats
+
+
+@pytest.fixture(scope="session")
+def hf_dataset_factory(img_array_factory, episodes, tasks):
+    def _create_hf_dataset(
+        episode_dicts: list[dict] = episodes,
+        task_dicts: list[dict] = tasks,
+        keys: list[str] = DUMMY_KEYS,
+        image_keys: list[str] | None = None,
+        shapes: dict | None = None,
+        fps: int = 30,
+    ) -> datasets.Dataset:
+        if not image_keys:
+            image_keys = []
+        if not shapes:
+            shapes = get_dummy_shapes(keys=keys, camera_keys=image_keys)
+        key_features = {
+            key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
+            for key in keys
+        }
+        image_features = {key: datasets.Image() for key in image_keys} if image_keys else {}
+        common_features = {
+            "episode_index": datasets.Value(dtype="int64"),
+            "frame_index": datasets.Value(dtype="int64"),
+            "timestamp": datasets.Value(dtype="float32"),
+            "next.done": datasets.Value(dtype="bool"),
+            "index": datasets.Value(dtype="int64"),
+            "task_index": datasets.Value(dtype="int64"),
+        }
+        features = datasets.Features(
+            {
+                **key_features,
+                **image_features,
+                **common_features,
+            }
+        )
+
+        episode_index_col = np.array([], dtype=np.int64)
+        frame_index_col = np.array([], dtype=np.int64)
+        timestamp_col = np.array([], dtype=np.float32)
+        next_done_col = np.array([], dtype=bool)
+        task_index = np.array([], dtype=np.int64)
+
+        for ep_dict in episode_dicts:
+            episode_index_col = np.concatenate(
+                (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
+            )
+            frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
+            timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
+            next_done_ep = np.full(ep_dict["length"], False, dtype=bool)
+            next_done_ep[-1] = True
+            next_done_col = np.concatenate((next_done_col, next_done_ep))
+            ep_task_index = get_task_index(task_dicts, ep_dict["tasks"][0])
+            task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
+
+        index_col = np.arange(len(episode_index_col))
+        key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys}
+
+        image_cols = {}
+        if image_keys:
+            for key in image_keys:
+                image_cols[key] = [
+                    img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"])
+                    for _ in range(len(index_col))
+                ]
+
+        dataset = datasets.Dataset.from_dict(
+            {
+                **key_cols,
+                **image_cols,
+                "episode_index": episode_index_col,
+                "frame_index": frame_index_col,
+                "timestamp": timestamp_col,
+                "next.done": next_done_col,
+                "index": index_col,
+                "task_index": task_index,
+            },
+            features=features,
+        )
+        dataset.set_transform(hf_transform_to_torch)
+        return dataset
+
+    return _create_hf_dataset
+
+
+@pytest.fixture(scope="session")
+def lerobot_dataset_factory(
+    info,
+    info_path,
+    stats,
+    stats_path,
+    episodes,
+    episode_path,
+    tasks,
+    tasks_path,
+    hf_dataset,
+    multi_episode_parquet_path,
+):
+    def _create_lerobot_dataset(
+        root: Path,
+        info_dict: dict = info,
+        stats_dict: dict = stats,
+        episode_dicts: list[dict] = episodes,
+        task_dicts: list[dict] = tasks,
+        hf_ds: datasets.Dataset = hf_dataset,
+    ) -> LeRobotDataset:
+        root.mkdir(parents=True, exist_ok=True)
+        # Create local files
+        _ = info_path(root, info_dict)
+        _ = stats_path(root, stats_dict)
+        _ = tasks_path(root, task_dicts)
+        _ = episode_path(root, episode_dicts)
+        _ = multi_episode_parquet_path(root, hf_ds)
+        return LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root, local_files_only=True)
+
+    return _create_lerobot_dataset
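A note on how the factories above are meant to be used: every parameter has a default taken from tests.fixtures.defaults, so a test overrides only what it cares about. A rough usage sketch (hypothetical test body; the fixture names and arguments are the ones defined in this new file):

# Hypothetical test, for illustration only.
def test_factory_defaults_can_be_overridden(info_factory, hf_dataset_factory):
    # Override only the fields under test; everything else keeps its default.
    info = info_factory(total_episodes=1, total_frames=100)
    assert info["total_episodes"] == 1
    assert info["keys"] == ["state", "action"]  # DUMMY_KEYS default

    # Build an in-memory HF dataset that also carries an image column.
    hf_ds = hf_dataset_factory(image_keys=["image"])
    assert "image" in hf_ds.features
    assert "state" in hf_ds.features and "action" in hf_ds.features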
@@ -0,0 +1,3 @@
+DUMMY_REPO_ID = "dummy/repo"
+DUMMY_KEYS = ["state", "action"]
+DUMMY_CAMERA_KEYS = ["laptop", "phone"]
@@ -0,0 +1,94 @@
+import json
+from pathlib import Path
+
+import datasets
+import jsonlines
+import pyarrow.compute as pc
+import pyarrow.parquet as pq
+import pytest
+
+from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
+
+
+@pytest.fixture(scope="session")
+def info_path(info):
+    def _create_info_json_file(dir: Path, info_dict: dict = info) -> Path:
+        fpath = dir / INFO_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with open(fpath, "w") as f:
+            json.dump(info_dict, f, indent=4, ensure_ascii=False)
+        return fpath
+
+    return _create_info_json_file
+
+
+@pytest.fixture(scope="session")
+def stats_path(stats):
+    def _create_stats_json_file(dir: Path, stats_dict: dict = stats) -> Path:
+        fpath = dir / STATS_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with open(fpath, "w") as f:
+            json.dump(stats_dict, f, indent=4, ensure_ascii=False)
+        return fpath
+
+    return _create_stats_json_file
+
+
+@pytest.fixture(scope="session")
+def tasks_path(tasks):
+    def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
+        fpath = dir / TASKS_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with jsonlines.open(fpath, "w") as writer:
+            writer.write_all(tasks_dicts)
+        return fpath
+
+    return _create_tasks_jsonl_file
+
+
+@pytest.fixture(scope="session")
+def episode_path(episodes):
+    def _create_episodes_jsonl_file(dir: Path, episode_dicts: list = episodes) -> Path:
+        fpath = dir / EPISODES_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with jsonlines.open(fpath, "w") as writer:
+            writer.write_all(episode_dicts)
+        return fpath
+
+    return _create_episodes_jsonl_file
+
+
+@pytest.fixture(scope="session")
+def single_episode_parquet_path(hf_dataset, info):
+    def _create_single_episode_parquet(
+        dir: Path, hf_ds: datasets.Dataset = hf_dataset, ep_idx: int = 0
+    ) -> Path:
+        data_path = info["data_path"]
+        chunks_size = info["chunks_size"]
+        ep_chunk = ep_idx // chunks_size
+        fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        table = hf_ds.data.table
+        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+        pq.write_table(ep_table, fpath)
+        return fpath
+
+    return _create_single_episode_parquet
+
+
+@pytest.fixture(scope="session")
+def multi_episode_parquet_path(hf_dataset, info):
+    def _create_multi_episode_parquet(dir: Path, hf_ds: datasets.Dataset = hf_dataset) -> Path:
+        data_path = info["data_path"]
+        chunks_size = info["chunks_size"]
+        total_episodes = info["total_episodes"]
+        for ep_idx in range(total_episodes):
+            ep_chunk = ep_idx // chunks_size
+            fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+            fpath.parent.mkdir(parents=True, exist_ok=True)
+            table = hf_ds.data.table
+            ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+            pq.write_table(ep_table, fpath)
+        return dir / "data"
+
+    return _create_multi_episode_parquet
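These path fixtures produce the on-disk layout that lerobot_dataset_factory stitches together (metadata JSON/JSONL plus per-episode parquet files). A condensed sketch of using them directly against a temporary directory (hypothetical test body; fixture names as defined above):

# Hypothetical test, for illustration only.
def test_dataset_files_are_written(tmp_path, info_path, tasks_path, episode_path, multi_episode_parquet_path):
    # Each fixture returns a writer callable; invoking it creates the file(s) under the given root.
    assert info_path(tmp_path).is_file()      # info JSON written at INFO_PATH
    assert tasks_path(tmp_path).is_file()     # tasks JSONL written at TASKS_PATH
    assert episode_path(tmp_path).is_file()   # episodes JSONL written at EPISODES_PATH
    data_dir = multi_episode_parquet_path(tmp_path)  # one parquet file per episode
    # Assumes the default data_path keeps the parquet files under the returned data/ dir.
    assert any(data_dir.rglob("*.parquet"))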
@@ -8,25 +8,21 @@ from lerobot.common.datasets.utils import (
     get_delta_indices,
     hf_transform_to_torch,
 )
+from tests.fixtures.defaults import DUMMY_KEYS
 
 
 @pytest.fixture(scope="module")
-def synced_hf_dataset_factory(hf_dataset_factory, episode_dicts, tasks):
-    def _create_synced_hf_dataset(fps: int = 30, keys: list | None = None) -> Dataset:
-        if not keys:
-            keys = ["state", "action"]
-        shapes = {key: 10 for key in keys}
-        return hf_dataset_factory(episode_dicts, tasks, keys, shapes, fps=fps)
+def synced_hf_dataset_factory(hf_dataset_factory):
+    def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
+        return hf_dataset_factory(fps=fps)
 
     return _create_synced_hf_dataset
 
 
 @pytest.fixture(scope="module")
 def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
-    def _create_unsynced_hf_dataset(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
-    ) -> Dataset:
-        hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
+    def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
+        hf_dataset = synced_hf_dataset_factory(fps=fps)
         features = hf_dataset.features
         df = hf_dataset.to_pandas()
         dtype = df["timestamp"].dtype  # This is to avoid pandas type warning
@@ -41,10 +37,8 @@ def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
 
 @pytest.fixture(scope="module")
 def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
-    def _create_slightly_off_hf_dataset(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
-    ) -> Dataset:
-        hf_dataset = synced_hf_dataset_factory(fps=fps, keys=keys)
+    def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
+        hf_dataset = synced_hf_dataset_factory(fps=fps)
         features = hf_dataset.features
         df = hf_dataset.to_pandas()
         dtype = df["timestamp"].dtype  # This is to avoid pandas type warning
@@ -59,9 +53,7 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
 
 @pytest.fixture(scope="module")
 def valid_delta_timestamps_factory():
-    def _create_valid_delta_timestamps(fps: int = 30, keys: list | None = None) -> dict:
-        if not keys:
-            keys = ["state", "action"]
+    def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_KEYS) -> dict:
         delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
         return delta_timestamps
 
@@ -71,10 +63,8 @@ def valid_delta_timestamps_factory():
 @pytest.fixture(scope="module")
 def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_invalid_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
     ) -> dict:
-        if not keys:
-            keys = ["state", "action"]
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just outside tolerance
         for key in keys:
@@ -87,10 +77,8 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
 @pytest.fixture(scope="module")
 def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_slightly_off_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list | None = None
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
    ) -> dict:
-        if not keys:
-            keys = ["state", "action"]
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just inside tolerance
         for key in delta_timestamps:
@@ -102,9 +90,7 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
 
 
 @pytest.fixture(scope="module")
-def delta_indices(keys: list | None = None) -> dict:
-    if not keys:
-        keys = ["state", "action"]
+def delta_indices(keys: list = DUMMY_KEYS) -> dict:
     return {key: list(range(-10, 10)) for key in keys}
 
 
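For reference, with the DUMMY_KEYS default the valid factory yields a 20-step window of offsets per key spaced 1/fps apart, which the invalid and slightly-off variants then perturb around tolerance_s. A quick sketch of the returned structure (hypothetical check, assuming the defaults shown above):

# Hypothetical check, for illustration only.
def test_valid_delta_timestamps_window(valid_delta_timestamps_factory):
    delta_timestamps = valid_delta_timestamps_factory(fps=30)
    assert set(delta_timestamps) == {"state", "action"}  # DUMMY_KEYS default
    for offsets in delta_timestamps.values():
        assert len(offsets) == 20        # one offset per i in range(-10, 10)
        assert offsets[10] == 0.0        # i == 0 gives a zero offset
        assert offsets[0] < 0 < offsets[-1]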