Fix tests

Simon Alibert 2024-11-05 19:09:12 +01:00
parent aed9f4036a
commit f3630ad910
13 changed files with 437 additions and 496 deletions


@@ -22,8 +22,6 @@ import numpy as np
 import PIL.Image
 import torch
-DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
 def safe_stop_image_writer(func):
     def wrapper(*args, **kwargs):
@@ -87,7 +85,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
         t.join()
-class ImageWriter:
+class AsyncImageWriter:
     """
     This class abstract away the initialisation of processes or/and threads to
     save images on disk asynchrounously, which is critical to control a robot and record data
@@ -102,11 +100,7 @@ class ImageWriter:
     the number of threads. If it is still not stable, try to use 1 subprocess, or more.
     """
-    def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
-        self.write_dir = write_dir
-        self.write_dir.mkdir(parents=True, exist_ok=True)
-        self.image_path = DEFAULT_IMAGE_PATH
+    def __init__(self, num_processes: int = 0, num_threads: int = 1):
         self.num_processes = num_processes
         self.num_threads = num_threads
         self.queue = None
@@ -134,17 +128,6 @@ class ImageWriter:
             p.start()
             self.processes.append(p)
-    def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
-        fpath = self.image_path.format(
-            image_key=image_key, episode_index=episode_index, frame_index=frame_index
-        )
-        return self.write_dir / fpath
-    def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
-        return self.get_image_file_path(
-            episode_index=episode_index, image_key=image_key, frame_index=0
-        ).parent
     def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
         if isinstance(image, torch.Tensor):
            # Convert tensor to numpy array to minimize main process time
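Since the writer no longer owns a `write_dir`, callers now build the image path themselves and hand it to `save_image`. A minimal usage sketch, based only on the names visible in this diff (`AsyncImageWriter`, `save_image`, `wait_until_done`, `stop`); the frame and path below are made up for illustration:

```python
from pathlib import Path

import numpy as np

from lerobot.common.datasets.image_writer import AsyncImageWriter

# Hypothetical frame and destination; the caller is now responsible for the directory layout.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
fpath = Path("images/laptop/episode_000000/frame_000000.png")
fpath.parent.mkdir(parents=True, exist_ok=True)

writer = AsyncImageWriter(num_processes=0, num_threads=2)
try:
    writer.save_image(image=frame, fpath=fpath)  # queued and written asynchronously
    writer.wait_until_done()                     # block until the queue is drained
finally:
    writer.stop()
```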


@@ -22,15 +22,18 @@ from pathlib import Path
 from typing import Callable
 import datasets
+import numpy as np
+import PIL.Image
 import torch
 import torch.utils
 from datasets import load_dataset
 from huggingface_hub import snapshot_download, upload_folder
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
-from lerobot.common.datasets.image_writer import ImageWriter
+from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
+    DEFAULT_IMAGE_PATH,
     EPISODES_PATH,
     INFO_PATH,
     STATS_PATH,
@@ -44,6 +47,7 @@ from lerobot.common.datasets.utils import (
     get_delta_indices,
     get_episode_data_index,
     get_features_from_robot,
+    get_hf_features_from_features,
     get_hub_safe_version,
     hf_transform_to_torch,
     load_episodes,
@@ -140,14 +144,9 @@ class LeRobotDatasetMetadata:
     @property
     def features(self) -> dict[str, dict]:
-        """"""
+        """All features contained in the dataset."""
         return self.info["features"]
-    @property
-    def keys(self) -> list[str]:
-        """Keys to access non-image data (state, actions etc.)."""
-        return self.info["keys"]
     @property
     def image_keys(self) -> list[str]:
         """Keys to access visual modalities stored as images."""
@@ -268,7 +267,7 @@ class LeRobotDatasetMetadata:
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
         if robot is not None:
-            features = get_features_from_robot(robot)
+            features = get_features_from_robot(robot, use_videos)
             robot_type = robot.robot_type
             if not all(cam.fps == fps for cam in robot.cameras.values()):
                 logging.warning(
@@ -522,35 +521,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
     @property
-    def features(self) -> list[str]:
-        return list(self._features) + self.meta.video_keys
+    def features(self) -> dict[str, dict]:
+        return self.meta.features
     @property
-    def _features(self) -> datasets.Features:
+    def hf_features(self) -> datasets.Features:
         """Features of the hf_dataset."""
         if self.hf_dataset is not None:
             return self.hf_dataset.features
-        elif self.episode_buffer is None:
-            raise NotImplementedError(
-                "Dataset features must be infered from an existing hf_dataset or episode_buffer."
-            )
-        features = {}
-        for key in self.episode_buffer:
-            if key in ["episode_index", "frame_index", "index", "task_index"]:
-                features[key] = datasets.Value(dtype="int64")
-            elif key in ["next.done", "next.success"]:
-                features[key] = datasets.Value(dtype="bool")
-            elif key in ["timestamp", "next.reward"]:
-                features[key] = datasets.Value(dtype="float32")
-            elif key in self.meta.image_keys:
-                features[key] = datasets.Image()
-            elif key in self.meta.keys:
-                features[key] = datasets.Sequence(
-                    length=self.meta.shapes[key], feature=datasets.Value(dtype="float32")
-                )
-        return datasets.Features(features)
+        else:
+            return get_hf_features_from_features(self.features)
     def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
         ep_start = self.episode_data_index["from"][ep_idx]
@@ -650,17 +630,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
     def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
-        # TODO(aliberts): Handle resume
+        current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
         return {
             "size": 0,
-            "episode_index": self.meta.total_episodes if episode_index is None else episode_index,
-            "task_index": None,
-            "frame_index": [],
-            "timestamp": [],
-            **{key: [] for key in self.meta.features},
-            **{key: [] for key in self.meta.image_keys},
+            **{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
         }
+    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
+        fpath = DEFAULT_IMAGE_PATH.format(
+            image_key=image_key, episode_index=episode_index, frame_index=frame_index
+        )
+        return self.root / fpath
+    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
+        if self.image_writer is None:
+            if isinstance(image, torch.Tensor):
+                image = image.cpu().numpy()
+            write_image(image, fpath)
+        else:
+            self.image_writer.save_image(image=image, fpath=fpath)
     def add_frame(self, frame: dict) -> None:
         """
         This function only adds the frame to the episode_buffer. Apart from images which are written in a
@@ -668,35 +657,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
         then needs to be called.
         """
         frame_index = self.episode_buffer["size"]
-        self.episode_buffer["frame_index"].append(frame_index)
-        self.episode_buffer["timestamp"].append(frame_index / self.fps)
-        self.episode_buffer["next.done"].append(False)
-        # Save all observed modalities except images
-        for key in self.meta.keys:
-            self.episode_buffer[key].append(frame[key])
+        for key, ft in self.features.items():
+            if key == "frame_index":
+                self.episode_buffer[key].append(frame_index)
+            elif key == "timestamp":
+                self.episode_buffer[key].append(frame_index / self.fps)
+            elif key in frame and ft["dtype"] not in ["image", "video"]:
+                self.episode_buffer[key].append(frame[key])
+            elif key in frame and ft["dtype"] in ["image", "video"]:
+                img_path = self._get_image_file_path(
+                    episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
+                )
+                if frame_index == 0:
+                    img_path.parent.mkdir(parents=True, exist_ok=True)
+                self._save_image(frame[key], img_path)
+                if ft["dtype"] == "image":
+                    self.episode_buffer[key].append(str(img_path))
         self.episode_buffer["size"] += 1
-        if self.image_writer is None:
-            return
-        # Save images
-        for cam_key in self.meta.camera_keys:
-            img_path = self.image_writer.get_image_file_path(
-                episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
-            )
-            if frame_index == 0:
-                img_path.parent.mkdir(parents=True, exist_ok=True)
-            self.image_writer.save_image(
-                image=frame[cam_key],
-                fpath=img_path,
-            )
-            if cam_key in self.meta.image_keys:
-                self.episode_buffer[cam_key].append(str(img_path))
     def add_episode(self, task: str, encode_videos: bool = False) -> None:
         """
         This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
@@ -714,23 +693,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
             raise NotImplementedError()
         task_index = self.meta.get_task_index(task)
-        self.episode_buffer["next.done"][-1] = True
-        for key in self.episode_buffer:
-            if key in self.meta.image_keys:
-                continue
-            elif key in self.meta.keys:
-                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
-            elif key == "episode_index":
-                self.episode_buffer[key] = torch.full((episode_length,), episode_index)
-            elif key == "task_index":
-                self.episode_buffer[key] = torch.full((episode_length,), task_index)
-            else:
-                self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
-        self.episode_buffer["index"] = torch.arange(
-            self.meta.total_frames, self.meta.total_frames + episode_length
-        )
+        if not set(self.episode_buffer.keys()) == set(self.features):
+            raise ValueError()
+        for key, ft in self.features.items():
+            if key == "index":
+                self.episode_buffer[key] = np.arange(
+                    self.meta.total_frames, self.meta.total_frames + episode_length
+                )
+            elif key == "episode_index":
+                self.episode_buffer[key] = np.full((episode_length,), episode_index)
+            elif key == "task_index":
+                self.episode_buffer[key] = np.full((episode_length,), task_index)
+            elif ft["dtype"] in ["image", "video"]:
+                continue
+            elif ft["shape"][0] == 1:
+                self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
+            elif ft["shape"][0] > 1:
+                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
+            else:
+                raise ValueError()
         self.meta.add_episode(episode_index, episode_length, task, task_index)
         self._wait_image_writer()
@@ -744,7 +728,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.consolidated = False
     def _save_episode_table(self, episode_index: int) -> None:
-        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
+        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.hf_features, split="train")
         ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
         write_parquet(ep_dataset, ep_data_path)
@@ -753,7 +737,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         episode_index = self.episode_buffer["episode_index"]
         if self.image_writer is not None:
             for cam_key in self.meta.camera_keys:
-                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
+                img_dir = self._get_image_file_path(
+                    episode_index=episode_index, image_key=cam_key, frame_index=0
+                ).parent
                 if img_dir.is_dir():
                     shutil.rmtree(img_dir)
@@ -761,13 +747,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer = self._create_episode_buffer()
     def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
-        if isinstance(self.image_writer, ImageWriter):
+        if isinstance(self.image_writer, AsyncImageWriter):
             logging.warning(
-                "You are starting a new ImageWriter that is replacing an already exising one in the dataset."
+                "You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
             )
-        self.image_writer = ImageWriter(
-            write_dir=self.root / "images",
+        self.image_writer = AsyncImageWriter(
             num_processes=num_processes,
             num_threads=num_threads,
         )
@@ -787,19 +772,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.image_writer.wait_until_done()
     def encode_videos(self) -> None:
-        # Use ffmpeg to convert frames stored as png into mp4 videos
+        """
+        Use ffmpeg to convert frames stored as png into mp4 videos.
+        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+        since video encoding with ffmpeg is already using multithreading.
+        """
         for episode_index in range(self.meta.total_episodes):
             for key in self.meta.video_keys:
-                # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
-                # to call self.image_writer here
-                tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
                 video_path = self.root / self.meta.get_video_file_path(episode_index, key)
                 if video_path.is_file():
                     # Skip if video is already encoded. Could be the case when resuming data recording.
                     continue
-                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
-                # since video encoding with ffmpeg is already using multithreading.
-                encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
+                img_dir = self._get_image_file_path(
+                    episode_index=episode_index, image_key=key, frame_index=0
+                ).parent
+                encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
     def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
         self.hf_dataset = self.load_hf_dataset()
@@ -810,8 +797,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.encode_videos()
             self.meta.write_video_info()
-        if not keep_image_files and self.image_writer is not None:
-            shutil.rmtree(self.image_writer.write_dir)
+        if not keep_image_files:
+            img_dir = self.root / "images"
+            if img_dir.is_dir():
+                shutil.rmtree(self.root / "images")
         video_files = list(self.root.rglob("*.mp4"))
         assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
@@ -989,7 +978,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def features(self) -> datasets.Features:
         features = {}
         for dataset in self._datasets:
-            features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
+            features.update(
+                {k: v for k, v in dataset.hf_features.items() if k not in self.disabled_data_keys}
+            )
         return features
     @property
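The changes above move path handling and image writes into the dataset itself (`_get_image_file_path`, `_save_image`) and make `episode_buffer` and `add_frame` generic over `self.features`. A rough sketch of the resulting recording flow, assuming an already-created `LeRobotDataset`; the feature keys and shapes below are illustrative rather than the library's canonical names:

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def record_dummy_episode(dataset: LeRobotDataset, num_frames: int = 50) -> None:
    """Sketch only: the frame keys must match dataset.features."""
    for _ in range(num_frames):
        frame = {
            "action": torch.zeros(6),
            "state": torch.zeros(6),
            # Image/video keys are written to disk by _save_image(), either directly
            # or through the AsyncImageWriter if one was started.
            "laptop": torch.zeros(3, 480, 640),
        }
        dataset.add_frame(frame)
    dataset.add_episode(task="Pick up the block.")
    dataset.consolidate(run_compute_stats=False)  # encodes videos, removes root/"images"
```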


@@ -22,6 +22,7 @@ from typing import Any
 import datasets
 import jsonlines
+import pyarrow.compute as pc
 import torch
 from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, HfApi
@@ -39,6 +40,7 @@ TASKS_PATH = "meta/tasks.jsonl"
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
+DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
 DATASET_CARD_TEMPLATE = """
 ---
@@ -222,6 +224,24 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version
+def get_hf_features_from_features(features: dict) -> datasets.Features:
+    hf_features = {}
+    for key, ft in features.items():
+        if ft["dtype"] == "video":
+            continue
+        elif ft["dtype"] == "image":
+            hf_features[key] = datasets.Image()
+        elif ft["shape"] == (1,):
+            hf_features[key] = datasets.Value(dtype=ft["dtype"])
+        else:
+            assert len(ft["shape"]) == 1
+            hf_features[key] = datasets.Sequence(
+                length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
+            )
+    return datasets.Features(hf_features)
 def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
     camera_ft = {}
     if robot.cameras:
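A hand-written example of what the new `get_hf_features_from_features` helper returns; the feature dict below is illustrative (note that `video` entries are skipped, since videos are not stored in the Arrow table):

```python
import datasets

from lerobot.common.datasets.utils import get_hf_features_from_features

features = {
    "action": {"dtype": "float32", "shape": (6,)},        # -> Sequence(length=6, Value("float32"))
    "frame_index": {"dtype": "int64", "shape": (1,)},     # -> Value("int64")
    "phone": {"dtype": "image", "shape": (480, 640, 3)},  # -> Image()
    "laptop": {"dtype": "video", "shape": (480, 640, 3)}, # skipped
}
hf_features = get_hf_features_from_features(features)
assert "laptop" not in hf_features
assert isinstance(hf_features["action"], datasets.Sequence)
```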
@@ -270,6 +290,31 @@ def get_episode_data_index(
     }
+def calculate_total_episode(
+    hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
+) -> dict[str, torch.Tensor]:
+    episode_indices = sorted(hf_dataset.unique("episode_index"))
+    total_episodes = len(episode_indices)
+    if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
+        raise ValueError("episode_index values are not sorted and contiguous.")
+    return total_episodes
+def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
+    episode_lengths = []
+    table = hf_dataset.data.table
+    total_episodes = calculate_total_episode(hf_dataset)
+    for ep_idx in range(total_episodes):
+        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
+        episode_lengths.insert(ep_idx, len(ep_table))
+    cumulative_lenghts = list(accumulate(episode_lengths))
+    return {
+        "from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
+        "to": torch.LongTensor(cumulative_lenghts),
+    }
 def check_timestamps_sync(
     hf_dataset: datasets.Dataset,
     episode_data_index: dict[str, torch.Tensor],
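For reference, a toy illustration of the new `calculate_episode_data_index` helper, which the updated tests below use in place of the old fixture-provided index: it returns per-episode frame ranges as `from`/`to` tensors. The inline dataset is made up for the example:

```python
import datasets

from lerobot.common.datasets.utils import calculate_episode_data_index

# Two episodes with lengths 3 and 2.
hf_dataset = datasets.Dataset.from_dict({"episode_index": [0, 0, 0, 1, 1]})
episode_data_index = calculate_episode_data_index(hf_dataset)
print(episode_data_index["from"].tolist())  # [0, 3]
print(episode_data_index["to"].tolist())    # [3, 5]
```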


@@ -25,7 +25,6 @@ from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_mo
 # Import fixture modules as plugins
 pytest_plugins = [
-    "tests.fixtures.dataset",
     "tests.fixtures.dataset_factories",
     "tests.fixtures.files",
     "tests.fixtures.hub",


@@ -1,67 +0,0 @@
-import datasets
-import pytest
-from lerobot.common.datasets.utils import get_episode_data_index
-from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
-@pytest.fixture(scope="session")
-def empty_info(info_factory) -> dict:
-    return info_factory(
-        keys=[],
-        image_keys=[],
-        video_keys=[],
-        shapes={},
-        names={},
-    )
-@pytest.fixture(scope="session")
-def info(info_factory) -> dict:
-    return info_factory(
-        total_episodes=4,
-        total_frames=420,
-        total_tasks=3,
-        total_videos=8,
-        total_chunks=1,
-    )
-@pytest.fixture(scope="session")
-def stats(stats_factory) -> list:
-    return stats_factory()
-@pytest.fixture(scope="session")
-def tasks() -> list:
-    return [
-        {"task_index": 0, "task": "Pick up the block."},
-        {"task_index": 1, "task": "Open the box."},
-        {"task_index": 2, "task": "Make paperclips."},
-    ]
-@pytest.fixture(scope="session")
-def episodes() -> list:
-    return [
-        {"episode_index": 0, "tasks": ["Pick up the block."], "length": 100},
-        {"episode_index": 1, "tasks": ["Open the box."], "length": 80},
-        {"episode_index": 2, "tasks": ["Pick up the block."], "length": 90},
-        {"episode_index": 3, "tasks": ["Make paperclips."], "length": 150},
-    ]
-@pytest.fixture(scope="session")
-def episode_data_index(episodes) -> dict:
-    return get_episode_data_index(episodes)
-@pytest.fixture(scope="session")
-def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
-    return hf_dataset_factory()
-@pytest.fixture(scope="session")
-def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset:
-    image_keys = DUMMY_CAMERA_KEYS
-    return hf_dataset_factory(image_keys=image_keys)


@@ -11,16 +11,19 @@ import torch
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_FEATURES,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
+    get_hf_features_from_features,
     hf_transform_to_torch,
 )
 from tests.fixtures.defaults import (
     DEFAULT_FPS,
-    DUMMY_CAMERA_KEYS,
-    DUMMY_KEYS,
+    DUMMY_CAMERA_FEATURES,
+    DUMMY_MOTOR_FEATURES,
     DUMMY_REPO_ID,
     DUMMY_ROBOT_TYPE,
+    DUMMY_VIDEO_INFO,
 )
@@ -73,16 +76,33 @@ def img_factory(img_array_factory):
 @pytest.fixture(scope="session")
-def info_factory():
+def features_factory():
+    def _create_features(
+        motor_features: dict = DUMMY_MOTOR_FEATURES,
+        camera_features: dict = DUMMY_CAMERA_FEATURES,
+        use_videos: bool = True,
+    ) -> dict:
+        if use_videos:
+            camera_ft = {
+                key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
+            }
+        else:
+            camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
+        return {
+            **motor_features,
+            **camera_ft,
+            **DEFAULT_FEATURES,
+        }
+    return _create_features
+@pytest.fixture(scope="session")
+def info_factory(features_factory):
     def _create_info(
         codebase_version: str = CODEBASE_VERSION,
         fps: int = DEFAULT_FPS,
         robot_type: str = DUMMY_ROBOT_TYPE,
-        keys: list[str] = DUMMY_KEYS,
-        image_keys: list[str] | None = None,
-        video_keys: list[str] = DUMMY_CAMERA_KEYS,
-        shapes: dict | None = None,
-        names: dict | None = None,
         total_episodes: int = 0,
         total_frames: int = 0,
         total_tasks: int = 0,
@@ -90,30 +110,14 @@ def info_factory():
         total_chunks: int = 0,
         chunks_size: int = DEFAULT_CHUNK_SIZE,
         data_path: str = DEFAULT_PARQUET_PATH,
-        videos_path: str = DEFAULT_VIDEO_PATH,
+        video_path: str = DEFAULT_VIDEO_PATH,
+        motor_features: dict = DUMMY_MOTOR_FEATURES,
+        camera_features: dict = DUMMY_CAMERA_FEATURES,
+        use_videos: bool = True,
     ) -> dict:
-        if not image_keys:
-            image_keys = []
-        if not shapes:
-            shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
-        if not names:
-            names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys}
-        video_info = {"videos_path": videos_path}
-        for key in video_keys:
-            video_info[key] = {
-                "video.fps": fps,
-                "video.width": shapes[key]["width"],
-                "video.height": shapes[key]["height"],
-                "video.channels": shapes[key]["channels"],
-                "video.codec": "av1",
-                "video.pix_fmt": "yuv420p",
-                "video.is_depth_map": False,
-                "has_audio": False,
-            }
+        features = features_factory(motor_features, camera_features, use_videos)
         return {
             "codebase_version": codebase_version,
-            "data_path": data_path,
             "robot_type": robot_type,
             "total_episodes": total_episodes,
             "total_frames": total_frames,
@@ -123,12 +127,9 @@ def info_factory():
             "chunks_size": chunks_size,
             "fps": fps,
             "splits": {},
-            "keys": keys,
-            "video_keys": video_keys,
-            "image_keys": image_keys,
-            "shapes": shapes,
-            "names": names,
-            "videos": video_info if len(video_keys) > 0 else None,
+            "data_path": data_path,
+            "video_path": video_path if use_videos else None,
+            "features": features,
         }
     return _create_info
@@ -137,32 +138,26 @@ def info_factory():
 @pytest.fixture(scope="session")
 def stats_factory():
     def _create_stats(
-        keys: list[str] = DUMMY_KEYS,
-        image_keys: list[str] | None = None,
-        video_keys: list[str] = DUMMY_CAMERA_KEYS,
-        shapes: dict | None = None,
+        features: dict[str] | None = None,
     ) -> dict:
-        if not image_keys:
-            image_keys = []
-        if not shapes:
-            shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
         stats = {}
-        for key in keys:
-            shape = shapes[key]
-            stats[key] = {
-                "max": np.full(shape, 1, dtype=np.float32).tolist(),
-                "mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
-                "min": np.full(shape, 0, dtype=np.float32).tolist(),
-                "std": np.full(shape, 0.25, dtype=np.float32).tolist(),
-            }
-        for key in [*image_keys, *video_keys]:
-            shape = (3, 1, 1)
-            stats[key] = {
-                "max": np.full(shape, 1, dtype=np.float32).tolist(),
-                "mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
-                "min": np.full(shape, 0, dtype=np.float32).tolist(),
-                "std": np.full(shape, 0.25, dtype=np.float32).tolist(),
-            }
+        for key, ft in features.items():
+            shape = ft["shape"]
+            dtype = ft["dtype"]
+            if dtype in ["image", "video"]:
+                stats[key] = {
+                    "max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
+                    "mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
+                    "min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
+                    "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
+                }
+            else:
+                stats[key] = {
+                    "max": np.full(shape, 1, dtype=dtype).tolist(),
+                    "mean": np.full(shape, 0.5, dtype=dtype).tolist(),
+                    "min": np.full(shape, 0, dtype=dtype).tolist(),
+                    "std": np.full(shape, 0.25, dtype=dtype).tolist(),
+                }
         return stats
     return _create_stats
@@ -185,7 +180,7 @@ def episodes_factory(tasks_factory):
     def _create_episodes(
         total_episodes: int = 3,
         total_frames: int = 400,
-        task_dicts: dict | None = None,
+        tasks: dict | None = None,
         multi_task: bool = False,
     ):
         if total_episodes <= 0 or total_frames <= 0:
@@ -193,18 +188,18 @@ def episodes_factory(tasks_factory):
         if total_frames < total_episodes:
             raise ValueError("total_length must be greater than or equal to num_episodes.")
-        if not task_dicts:
+        if not tasks:
             min_tasks = 2 if multi_task else 1
             total_tasks = random.randint(min_tasks, total_episodes)
-            task_dicts = tasks_factory(total_tasks)
+            tasks = tasks_factory(total_tasks)
-        if total_episodes < len(task_dicts) and not multi_task:
+        if total_episodes < len(tasks) and not multi_task:
             raise ValueError("The number of tasks should be less than the number of episodes.")
         # Generate random lengths that sum up to total_length
         lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
-        tasks_list = [task_dict["task"] for task_dict in task_dicts]
+        tasks_list = [task_dict["task"] for task_dict in tasks]
         num_tasks_available = len(tasks_list)
         episodes_list = []
@@ -231,81 +226,56 @@ def episodes_factory(tasks_factory):
 @pytest.fixture(scope="session")
-def hf_dataset_factory(img_array_factory, episodes, tasks):
+def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
-        episode_dicts: list[dict] = episodes,
-        task_dicts: list[dict] = tasks,
-        keys: list[str] = DUMMY_KEYS,
-        image_keys: list[str] | None = None,
-        shapes: dict | None = None,
+        features: dict | None = None,
+        tasks: list[dict] | None = None,
+        episodes: list[dict] | None = None,
         fps: int = DEFAULT_FPS,
     ) -> datasets.Dataset:
-        if not image_keys:
-            image_keys = []
-        if not shapes:
-            shapes = make_dummy_shapes(keys=keys, camera_keys=image_keys)
-        key_features = {
-            key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
-            for key in keys
-        }
-        image_features = {key: datasets.Image() for key in image_keys} if image_keys else {}
-        common_features = {
-            "episode_index": datasets.Value(dtype="int64"),
-            "frame_index": datasets.Value(dtype="int64"),
-            "timestamp": datasets.Value(dtype="float32"),
-            "next.done": datasets.Value(dtype="bool"),
-            "index": datasets.Value(dtype="int64"),
-            "task_index": datasets.Value(dtype="int64"),
-        }
-        features = datasets.Features(
-            {
-                **key_features,
-                **image_features,
-                **common_features,
-            }
-        )
-        episode_index_col = np.array([], dtype=np.int64)
-        frame_index_col = np.array([], dtype=np.int64)
-        timestamp_col = np.array([], dtype=np.float32)
-        next_done_col = np.array([], dtype=bool)
-        task_index = np.array([], dtype=np.int64)
-        for ep_dict in episode_dicts:
-            episode_index_col = np.concatenate(
-                (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
-            )
-            frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
-            timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
-            next_done_ep = np.full(ep_dict["length"], False, dtype=bool)
-            next_done_ep[-1] = True
-            next_done_col = np.concatenate((next_done_col, next_done_ep))
-            ep_task_index = get_task_index(task_dicts, ep_dict["tasks"][0])
-            task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
+        if not tasks:
+            tasks = tasks_factory()
+        if not episodes:
+            episodes = episodes_factory()
+        if not features:
+            features = features_factory()
+        timestamp_col = np.array([], dtype=np.float32)
+        frame_index_col = np.array([], dtype=np.int64)
+        episode_index_col = np.array([], dtype=np.int64)
+        task_index = np.array([], dtype=np.int64)
+        for ep_dict in episodes:
+            timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
+            frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
+            episode_index_col = np.concatenate(
+                (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
+            )
+            ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
+            task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
         index_col = np.arange(len(episode_index_col))
-        key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys}
-        image_cols = {}
-        if image_keys:
-            for key in image_keys:
-                image_cols[key] = [
-                    img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"])
-                    for _ in range(len(index_col))
-                ]
+        robot_cols = {}
+        for key, ft in features.items():
+            if ft["dtype"] == "image":
+                robot_cols[key] = [
+                    img_array_factory(width=ft["shapes"][0], height=ft["shapes"][1])
+                    for _ in range(len(index_col))
+                ]
+            elif ft["shape"][0] > 1 and ft["dtype"] != "video":
+                robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
+        hf_features = get_hf_features_from_features(features)
         dataset = datasets.Dataset.from_dict(
             {
-                **key_cols,
-                **image_cols,
-                "episode_index": episode_index_col,
-                "frame_index": frame_index_col,
+                **robot_cols,
                 "timestamp": timestamp_col,
-                "next.done": next_done_col,
+                "frame_index": frame_index_col,
+                "episode_index": episode_index_col,
                 "index": index_col,
                 "task_index": task_index,
             },
-            features=features,
+            features=hf_features,
         )
         dataset.set_transform(hf_transform_to_torch)
         return dataset
@@ -315,26 +285,37 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
 @pytest.fixture(scope="session")
 def lerobot_dataset_metadata_factory(
-    info,
-    stats,
-    tasks,
-    episodes,
+    info_factory,
+    stats_factory,
+    tasks_factory,
+    episodes_factory,
     mock_snapshot_download_factory,
 ):
     def _create_lerobot_dataset_metadata(
         root: Path,
         repo_id: str = DUMMY_REPO_ID,
-        info_dict: dict = info,
-        stats_dict: dict = stats,
-        task_dicts: list[dict] = tasks,
-        episode_dicts: list[dict] = episodes,
-        **kwargs,
+        info: dict | None = None,
+        stats: dict | None = None,
+        tasks: list[dict] | None = None,
+        episodes: list[dict] | None = None,
+        local_files_only: bool = False,
     ) -> LeRobotDatasetMetadata:
+        if not info:
+            info = info_factory()
+        if not stats:
+            stats = stats_factory(features=info["features"])
+        if not tasks:
+            tasks = tasks_factory(total_tasks=info["total_tasks"])
+        if not episodes:
+            episodes = episodes_factory(
+                total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
+            )
         mock_snapshot_download = mock_snapshot_download_factory(
-            info_dict=info_dict,
-            stats_dict=stats_dict,
-            task_dicts=task_dicts,
-            episode_dicts=episode_dicts,
+            info=info,
+            stats=stats,
+            tasks=tasks,
+            episodes=episodes,
         )
         with (
             patch(
@@ -347,48 +328,68 @@ def lerobot_dataset_metadata_factory(
             mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
             mock_snapshot_download_patch.side_effect = mock_snapshot_download
-            return LeRobotDatasetMetadata(
-                repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
-            )
+            return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
     return _create_lerobot_dataset_metadata
 @pytest.fixture(scope="session")
 def lerobot_dataset_factory(
-    info,
-    stats,
-    tasks,
-    episodes,
-    hf_dataset,
+    info_factory,
+    stats_factory,
+    tasks_factory,
+    episodes_factory,
+    hf_dataset_factory,
     mock_snapshot_download_factory,
     lerobot_dataset_metadata_factory,
 ):
     def _create_lerobot_dataset(
         root: Path,
         repo_id: str = DUMMY_REPO_ID,
-        info_dict: dict = info,
-        stats_dict: dict = stats,
-        task_dicts: list[dict] = tasks,
-        episode_dicts: list[dict] = episodes,
-        hf_ds: datasets.Dataset = hf_dataset,
+        total_episodes: int = 3,
+        total_frames: int = 150,
+        total_tasks: int = 1,
+        multi_task: bool = False,
+        info: dict | None = None,
+        stats: dict | None = None,
+        tasks: list[dict] | None = None,
+        episode_dicts: list[dict] | None = None,
+        hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
+        if not info:
+            info = info_factory(
+                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+            )
+        if not stats:
+            stats = stats_factory(features=info["features"])
+        if not tasks:
+            tasks = tasks_factory(total_tasks=info["total_tasks"])
+        if not episode_dicts:
+            episode_dicts = episodes_factory(
+                total_episodes=info["total_episodes"],
+                total_frames=info["total_frames"],
+                tasks=tasks,
+                multi_task=multi_task,
+            )
+        if not hf_dataset:
+            hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
         mock_snapshot_download = mock_snapshot_download_factory(
-            info_dict=info_dict,
-            stats_dict=stats_dict,
-            task_dicts=task_dicts,
-            episode_dicts=episode_dicts,
-            hf_ds=hf_ds,
+            info=info,
+            stats=stats,
+            tasks=tasks,
+            episodes=episode_dicts,
+            hf_dataset=hf_dataset,
         )
         mock_metadata = lerobot_dataset_metadata_factory(
             root=root,
             repo_id=repo_id,
-            info_dict=info_dict,
-            stats_dict=stats_dict,
-            task_dicts=task_dicts,
-            episode_dicts=episode_dicts,
-            **kwargs,
+            info=info,
+            stats=stats,
+            tasks=tasks,
+            episodes=episode_dicts,
+            local_files_only=kwargs.get("local_files_only", False),
        )
         with (
             patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
@@ -402,44 +403,3 @@ def lerobot_dataset_factory(
             return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
     return _create_lerobot_dataset
-@pytest.fixture(scope="session")
-def lerobot_dataset_from_episodes_factory(
-    info_factory,
-    tasks_factory,
-    episodes_factory,
-    hf_dataset_factory,
-    lerobot_dataset_factory,
-):
-    def _create_lerobot_dataset_total_episodes(
-        root: Path,
-        total_episodes: int = 3,
-        total_frames: int = 150,
-        total_tasks: int = 1,
-        multi_task: bool = False,
-        repo_id: str = DUMMY_REPO_ID,
-        **kwargs,
-    ):
-        info_dict = info_factory(
-            total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
-        )
-        task_dicts = tasks_factory(total_tasks)
-        episode_dicts = episodes_factory(
-            total_episodes=total_episodes,
-            total_frames=total_frames,
-            task_dicts=task_dicts,
-            multi_task=multi_task,
-        )
-        hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
-        return lerobot_dataset_factory(
-            root=root,
-            repo_id=repo_id,
-            info_dict=info_dict,
-            task_dicts=task_dicts,
-            episode_dicts=episode_dicts,
-            hf_ds=hf_dataset,
-            **kwargs,
-        )
-    return _create_lerobot_dataset_total_episodes


@@ -3,6 +3,27 @@ from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
 LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
 DUMMY_REPO_ID = "dummy/repo"
 DUMMY_ROBOT_TYPE = "dummy_robot"
-DUMMY_KEYS = ["state", "action"]
-DUMMY_CAMERA_KEYS = ["laptop", "phone"]
+DUMMY_MOTOR_FEATURES = {
+    "action": {
+        "dtype": "float32",
+        "shape": (6,),
+        "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
+    },
+    "state": {
+        "dtype": "float32",
+        "shape": (6,),
+        "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
+    },
+}
+DUMMY_CAMERA_FEATURES = {
+    "laptop": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
+    "phone": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
+}
 DEFAULT_FPS = 30
+DUMMY_VIDEO_INFO = {
+    "video.fps": DEFAULT_FPS,
+    "video.codec": "av1",
+    "video.pix_fmt": "yuv420p",
+    "video.is_depth_map": False,
+    "has_audio": False,
+}
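For illustration, once `features_factory` (in the updated fixtures) merges one of these camera entries with `DUMMY_VIDEO_INFO`, the resulting feature roughly looks like this (assuming `use_videos=True`):

```python
# Approximate result of {"dtype": "video", **DUMMY_CAMERA_FEATURES["laptop"], **DUMMY_VIDEO_INFO}
laptop_feature = {
    "dtype": "video",
    "shape": (640, 480, 3),
    "names": ["width", "height", "channels"],
    "info": None,
    "video.fps": 30,
    "video.codec": "av1",
    "video.pix_fmt": "yuv420p",
    "video.is_depth_map": False,
    "has_audio": False,
}
```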


@@ -11,64 +11,77 @@ from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH,
 @pytest.fixture(scope="session")
-def info_path(info):
-    def _create_info_json_file(dir: Path, info_dict: dict = info) -> Path:
+def info_path(info_factory):
+    def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
+        if not info:
+            info = info_factory()
         fpath = dir / INFO_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with open(fpath, "w") as f:
-            json.dump(info_dict, f, indent=4, ensure_ascii=False)
+            json.dump(info, f, indent=4, ensure_ascii=False)
         return fpath
     return _create_info_json_file
 @pytest.fixture(scope="session")
-def stats_path(stats):
-    def _create_stats_json_file(dir: Path, stats_dict: dict = stats) -> Path:
+def stats_path(stats_factory):
+    def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
+        if not stats:
+            stats = stats_factory()
         fpath = dir / STATS_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with open(fpath, "w") as f:
-            json.dump(stats_dict, f, indent=4, ensure_ascii=False)
+            json.dump(stats, f, indent=4, ensure_ascii=False)
         return fpath
     return _create_stats_json_file
 @pytest.fixture(scope="session")
-def tasks_path(tasks):
-    def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
+def tasks_path(tasks_factory):
+    def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
+        if not tasks:
+            tasks = tasks_factory()
         fpath = dir / TASKS_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with jsonlines.open(fpath, "w") as writer:
-            writer.write_all(task_dicts)
+            writer.write_all(tasks)
         return fpath
     return _create_tasks_jsonl_file
 @pytest.fixture(scope="session")
-def episode_path(episodes):
-    def _create_episodes_jsonl_file(dir: Path, episode_dicts: list = episodes) -> Path:
+def episode_path(episodes_factory):
+    def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
+        if not episodes:
+            episodes = episodes_factory()
         fpath = dir / EPISODES_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with jsonlines.open(fpath, "w") as writer:
-            writer.write_all(episode_dicts)
+            writer.write_all(episodes)
         return fpath
     return _create_episodes_jsonl_file
 @pytest.fixture(scope="session")
-def single_episode_parquet_path(hf_dataset, info):
+def single_episode_parquet_path(hf_dataset_factory, info_factory):
     def _create_single_episode_parquet(
-        dir: Path, hf_ds: datasets.Dataset = hf_dataset, ep_idx: int = 0
+        dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
     ) -> Path:
+        if not info:
+            info = info_factory()
+        if hf_dataset is None:
+            hf_dataset = hf_dataset_factory()
         data_path = info["data_path"]
         chunks_size = info["chunks_size"]
         ep_chunk = ep_idx // chunks_size
         fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
         fpath.parent.mkdir(parents=True, exist_ok=True)
-        table = hf_ds.data.table
+        table = hf_dataset.data.table
         ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
         pq.write_table(ep_table, fpath)
         return fpath
@@ -77,8 +90,15 @@ def single_episode_parquet_path(hf_dataset, info):
 @pytest.fixture(scope="session")
-def multi_episode_parquet_path(hf_dataset, info):
-    def _create_multi_episode_parquet(dir: Path, hf_ds: datasets.Dataset = hf_dataset) -> Path:
+def multi_episode_parquet_path(hf_dataset_factory, info_factory):
+    def _create_multi_episode_parquet(
+        dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
+    ) -> Path:
+        if not info:
+            info = info_factory()
+        if hf_dataset is None:
+            hf_dataset = hf_dataset_factory()
         data_path = info["data_path"]
         chunks_size = info["chunks_size"]
         total_episodes = info["total_episodes"]
@@ -86,7 +106,7 @@ def multi_episode_parquet_path(hf_dataset, info):
         ep_chunk = ep_idx // chunks_size
         fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
         fpath.parent.mkdir(parents=True, exist_ok=True)
-        table = hf_ds.data.table
+        table = hf_dataset.data.table
         ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
         pq.write_table(ep_table, fpath)
         return dir / "data"

tests/fixtures/hub.py

@@ -1,5 +1,6 @@
 from pathlib import Path
+import datasets
 import pytest
 from huggingface_hub.utils import filter_repo_objects
@@ -9,16 +10,16 @@ from tests.fixtures.defaults import LEROBOT_TEST_DIR
 @pytest.fixture(scope="session")
 def mock_snapshot_download_factory(
-    info,
+    info_factory,
     info_path,
-    stats,
+    stats_factory,
     stats_path,
-    tasks,
+    tasks_factory,
     tasks_path,
-    episodes,
+    episodes_factory,
     episode_path,
     single_episode_parquet_path,
-    hf_dataset,
+    hf_dataset_factory,
 ):
     """
     This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -26,8 +27,25 @@ def mock_snapshot_download_factory(
     """
     def _mock_snapshot_download_func(
-        info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
+        info: dict | None = None,
+        stats: dict | None = None,
+        tasks: list[dict] | None = None,
+        episodes: list[dict] | None = None,
+        hf_dataset: datasets.Dataset | None = None,
     ):
+        if not info:
+            info = info_factory()
+        if not stats:
+            stats = stats_factory(features=info["features"])
+        if not tasks:
+            tasks = tasks_factory(total_tasks=info["total_tasks"])
+        if not episodes:
+            episodes = episodes_factory(
+                total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
+            )
+        if not hf_dataset:
+            hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
         def _extract_episode_index_from_path(fpath: str) -> int:
             path = Path(fpath)
             if path.suffix == ".parquet" and path.stem.startswith("episode_"):
@@ -53,10 +71,10 @@ def mock_snapshot_download_factory(
         all_files.extend(meta_files)
         data_files = []
-        for episode_dict in episode_dicts:
+        for episode_dict in episodes:
             ep_idx = episode_dict["episode_index"]
-            ep_chunk = ep_idx // info_dict["chunks_size"]
-            data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
+            ep_chunk = ep_idx // info["chunks_size"]
+            data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
             data_files.append(data_path)
         all_files.extend(data_files)
@@ -69,15 +87,15 @@ def mock_snapshot_download_factory(
             if rel_path.startswith("data/"):
                 episode_index = _extract_episode_index_from_path(rel_path)
                 if episode_index is not None:
-                    _ = single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index)
+                    _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
             if rel_path == INFO_PATH:
-                _ = info_path(local_dir, info_dict)
+                _ = info_path(local_dir, info)
             elif rel_path == STATS_PATH:
-                _ = stats_path(local_dir, stats_dict)
+                _ = stats_path(local_dir, stats)
             elif rel_path == TASKS_PATH:
-                _ = tasks_path(local_dir, task_dicts)
+                _ = tasks_path(local_dir, tasks)
             elif rel_path == EPISODES_PATH:
-                _ = episode_path(local_dir, episode_dicts)
+                _ = episode_path(local_dir, episodes)
             else:
                 pass
         return str(local_dir)


@@ -35,7 +35,6 @@ from lerobot.common.datasets.compute_stats import (
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.lerobot_dataset import (
     LeRobotDataset,
-    LeRobotDatasetMetadata,
     MultiLeRobotDataset,
 )
 from lerobot.common.datasets.utils import (
@@ -57,10 +56,7 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
     # Instantiate both ways
     robot = make_robot("koch", mock=True)
     root_create = tmp_path / "create"
-    metadata_create = LeRobotDatasetMetadata.create(
-        repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
-    )
-    dataset_create = LeRobotDataset.create(metadata_create)
+    dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
     root_init = tmp_path / "init"
     dataset_init = lerobot_dataset_factory(root=root_init)
@@ -75,14 +71,14 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
     assert init_attr == create_attr
-def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
+def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
     kwargs = {
         "repo_id": DUMMY_REPO_ID,
         "total_episodes": 10,
         "total_frames": 400,
         "episodes": [2, 5, 6],
     }
-    dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
+    dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
     assert dataset.repo_id == kwargs["repo_id"]
     assert dataset.meta.total_episodes == kwargs["total_episodes"]


@@ -3,12 +3,13 @@ import torch
 from datasets import Dataset
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     check_delta_timestamps,
     check_timestamps_sync,
     get_delta_indices,
     hf_transform_to_torch,
 )
-from tests.fixtures.defaults import DUMMY_KEYS
+from tests.fixtures.defaults import DUMMY_MOTOR_FEATURES
@@ -53,7 +54,7 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
 @pytest.fixture(scope="module")
 def valid_delta_timestamps_factory():
-    def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_KEYS) -> dict:
+    def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
         delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
         return delta_timestamps
@@ -63,7 +64,7 @@ def valid_delta_timestamps_factory():
 @pytest.fixture(scope="module")
 def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_invalid_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
     ) -> dict:
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just outside tolerance
@@ -77,7 +78,7 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
 @pytest.fixture(scope="module")
 def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
     def _create_slightly_off_delta_timestamps(
-        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
+        fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
     ) -> dict:
         delta_timestamps = valid_delta_timestamps_factory(fps, keys)
         # Modify a single timestamp just inside tolerance
@@ -90,14 +91,15 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
 @pytest.fixture(scope="module")
-def delta_indices(keys: list = DUMMY_KEYS) -> dict:
+def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
     return {key: list(range(-10, 10)) for key in keys}
-def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_index):
+def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
     fps = 30
     tolerance_s = 1e-4
     synced_hf_dataset = synced_hf_dataset_factory(fps)
+    episode_data_index = calculate_episode_data_index(synced_hf_dataset)
     result = check_timestamps_sync(
         hf_dataset=synced_hf_dataset,
         episode_data_index=episode_data_index,
@@ -107,10 +109,11 @@ def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_in
     assert result is True
-def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_data_index):
+def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
     fps = 30
     tolerance_s = 1e-4
     unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
+    episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
     with pytest.raises(ValueError):
         check_timestamps_sync(
             hf_dataset=unsynced_hf_dataset,
@@ -120,10 +123,11 @@ def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_dat
     )
-def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory, episode_data_index):
+def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
     fps = 30
     tolerance_s = 1e-4
     unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
+    episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
     result = check_timestamps_sync(
         hf_dataset=unsynced_hf_dataset,
         episode_data_index=episode_data_index,
@@ -134,10 +138,11 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory
     assert result is False
-def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory, episode_data_index):
+def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
     fps = 30
     tolerance_s = 1e-4
     slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
+    episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
     result = check_timestamps_sync(
         hf_dataset=slightly_off_hf_dataset,
         episode_data_index=episode_data_index,

View File

@@ -8,7 +8,7 @@ import pytest
 from PIL import Image

 from lerobot.common.datasets.image_writer import (
-    ImageWriter,
+    AsyncImageWriter,
     image_array_to_image,
     safe_stop_image_writer,
     write_image,
@@ -17,8 +17,8 @@ from lerobot.common.datasets.image_writer import (
 DUMMY_IMAGE = "test_image.png"


-def test_init_threading(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
+def test_init_threading():
+    writer = AsyncImageWriter(num_processes=0, num_threads=2)
     try:
         assert writer.num_processes == 0
         assert writer.num_threads == 2
@@ -30,8 +30,8 @@ def test_init_threading(tmp_path):
         writer.stop()


-def test_init_multiprocessing(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+def test_init_multiprocessing():
+    writer = AsyncImageWriter(num_processes=2, num_threads=2)
     try:
         assert writer.num_processes == 2
         assert writer.num_threads == 2
@@ -43,35 +43,9 @@ def test_init_multiprocessing(tmp_path):
         writer.stop()


-def test_write_dir_created(tmp_path):
-    write_dir = tmp_path / "non_existent_dir"
-    assert not write_dir.exists()
-    writer = ImageWriter(write_dir=write_dir)
-    try:
-        assert write_dir.exists()
-    finally:
-        writer.stop()
-
-
-def test_get_image_file_path_and_episode_dir(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path)
-    try:
-        episode_index = 1
-        image_key = "test_key"
-        frame_index = 10
-        expected_episode_dir = tmp_path / f"{image_key}/episode_{episode_index:06d}"
-        expected_path = expected_episode_dir / f"frame_{frame_index:06d}.png"
-        image_file_path = writer.get_image_file_path(episode_index, image_key, frame_index)
-        assert image_file_path == expected_path
-        episode_dir = writer.get_episode_dir(episode_index, image_key)
-        assert episode_dir == expected_episode_dir
-    finally:
-        writer.stop()
-
-
-def test_zero_threads(tmp_path):
+def test_zero_threads():
     with pytest.raises(ValueError):
-        ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=0)
+        AsyncImageWriter(num_processes=0, num_threads=0)


 def test_image_array_to_image_rgb(img_array_factory):
@@ -148,7 +122,7 @@ def test_write_image_exception(tmp_path):
 def test_save_image_numpy(tmp_path, img_array_factory):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     try:
         image_array = img_array_factory()
         fpath = tmp_path / DUMMY_IMAGE
@@ -163,7 +137,7 @@ def test_save_image_numpy(tmp_path, img_array_factory):
 def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    writer = AsyncImageWriter(num_processes=2, num_threads=2)
     try:
         image_array = img_array_factory()
         fpath = tmp_path / DUMMY_IMAGE
@@ -177,7 +151,7 @@ def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
 def test_save_image_torch(tmp_path, img_tensor_factory):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     try:
         image_tensor = img_tensor_factory()
         fpath = tmp_path / DUMMY_IMAGE
@@ -193,7 +167,7 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
 def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    writer = AsyncImageWriter(num_processes=2, num_threads=2)
     try:
         image_tensor = img_tensor_factory()
         fpath = tmp_path / DUMMY_IMAGE
@@ -208,7 +182,7 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
 def test_save_image_pil(tmp_path, img_factory):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     try:
         image_pil = img_factory()
         fpath = tmp_path / DUMMY_IMAGE
@@ -223,7 +197,7 @@ def test_save_image_pil(tmp_path, img_factory):
 def test_save_image_pil_multiprocessing(tmp_path, img_factory):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    writer = AsyncImageWriter(num_processes=2, num_threads=2)
     try:
         image_pil = img_factory()
         fpath = tmp_path / DUMMY_IMAGE
@@ -237,10 +211,10 @@ def test_save_image_pil_multiprocessing(tmp_path, img_factory):
 def test_save_image_invalid_data(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     try:
         image_array = "invalid data"
-        fpath = writer.get_image_file_path(0, "test_key", 0)
+        fpath = tmp_path / DUMMY_IMAGE
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with patch("builtins.print") as mock_print:
             writer.save_image(image_array, fpath)
@@ -252,47 +226,47 @@ def test_save_image_invalid_data(tmp_path):
 def test_save_image_after_stop(tmp_path, img_array_factory):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     writer.stop()
     image_array = img_array_factory()
-    fpath = writer.get_image_file_path(0, "test_key", 0)
+    fpath = tmp_path / DUMMY_IMAGE
     writer.save_image(image_array, fpath)
     time.sleep(1)
     assert not fpath.exists()


-def test_stop(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
+def test_stop():
+    writer = AsyncImageWriter(num_processes=0, num_threads=2)
     writer.stop()
     assert not any(t.is_alive() for t in writer.threads)


-def test_stop_multiprocessing(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+def test_stop_multiprocessing():
+    writer = AsyncImageWriter(num_processes=2, num_threads=2)
     writer.stop()
     assert not any(p.is_alive() for p in writer.processes)


-def test_multiple_stops(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path)
+def test_multiple_stops():
+    writer = AsyncImageWriter()
     writer.stop()
     writer.stop()  # Should not raise an exception
     assert not any(t.is_alive() for t in writer.threads)


-def test_multiple_stops_multiprocessing(tmp_path):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+def test_multiple_stops_multiprocessing():
+    writer = AsyncImageWriter(num_processes=2, num_threads=2)
     writer.stop()
     writer.stop()  # Should not raise an exception
     assert not any(t.is_alive() for t in writer.threads)


 def test_wait_until_done(tmp_path, img_array_factory):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=4)
+    writer = AsyncImageWriter(num_processes=0, num_threads=4)
     try:
         num_images = 100
         image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
-        fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
+        fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
         for image_array, fpath in zip(image_arrays, fpaths, strict=True):
             fpath.parent.mkdir(parents=True, exist_ok=True)
             writer.save_image(image_array, fpath)
@@ -306,11 +280,11 @@ def test_wait_until_done(tmp_path, img_array_factory):
 def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
-    writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
+    writer = AsyncImageWriter(num_processes=2, num_threads=2)
     try:
         num_images = 100
         image_arrays = [img_array_factory() for _ in range(num_images)]
-        fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
+        fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
         for image_array, fpath in zip(image_arrays, fpaths, strict=True):
             fpath.parent.mkdir(parents=True, exist_ok=True)
             writer.save_image(image_array, fpath)
@@ -324,7 +298,7 @@ def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
 def test_exception_handling(tmp_path, img_array_factory):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     try:
         image_array = img_array_factory()
         with (
@@ -338,7 +312,7 @@ def test_exception_handling(tmp_path, img_array_factory):
 def test_with_different_image_formats(tmp_path, img_array_factory):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     try:
         image_array = img_array_factory()
         formats = ["png", "jpeg", "bmp"]
@@ -353,7 +327,7 @@ def test_with_different_image_formats(tmp_path, img_array_factory):
 def test_safe_stop_image_writer_decorator():
     class MockDataset:
         def __init__(self):
-            self.image_writer = MagicMock(spec=ImageWriter)
+            self.image_writer = MagicMock(spec=AsyncImageWriter)

     @safe_stop_image_writer
     def function_that_raises_exception(dataset=None):
@@ -369,10 +343,10 @@ def test_safe_stop_image_writer_decorator():
 def test_main_process_time(tmp_path, img_tensor_factory):
-    writer = ImageWriter(write_dir=tmp_path)
+    writer = AsyncImageWriter()
     try:
         image_tensor = img_tensor_factory()
-        fpath = tmp_path / "test_main_process_time.png"
+        fpath = tmp_path / DUMMY_IMAGE
         start_time = time.perf_counter()
         writer.save_image(image_tensor, fpath)
         end_time = time.perf_counter()

View File

@@ -213,15 +213,13 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
 @pytest.mark.parametrize("online_dataset_size", [0, 4])
 @pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
 def test_compute_sampler_weights_trivial(
-    lerobot_dataset_from_episodes_factory,
+    lerobot_dataset_factory,
     tmp_path,
     offline_dataset_size: int,
     online_dataset_size: int,
     online_sampling_ratio: float,
 ):
-    offline_dataset = lerobot_dataset_from_episodes_factory(
-        tmp_path, total_episodes=1, total_frames=offline_dataset_size
-    )
+    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
     online_dataset, _ = make_new_buffer()
     if online_dataset_size > 0:
         online_dataset.add_data(
@@ -241,9 +239,9 @@ def test_compute_sampler_weights_trivial(
     assert torch.allclose(weights, expected_weights)


-def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path):
+def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
     # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
-    offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
+    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
     online_dataset, _ = make_new_buffer()
     online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
     online_sampling_ratio = 0.8
@@ -255,11 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path):
     )


-def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
-    lerobot_dataset_from_episodes_factory, tmp_path
-):
+def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
     # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
-    offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
+    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
     online_dataset, _ = make_new_buffer()
     online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
     weights = compute_sampler_weights(
@@ -270,9 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
     )


-def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_from_episodes_factory, tmp_path):
+def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
     """Note: test copied from test_sampler."""
-    offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=2)
+    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
     online_dataset, _ = make_new_buffer()
     online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))