Fix tests
This commit is contained in:
parent
aed9f4036a
commit
f3630ad910
|
@ -22,8 +22,6 @@ import numpy as np
|
|||
import PIL.Image
|
||||
import torch
|
||||
|
||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
|
@ -87,7 +85,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
|
|||
t.join()
|
||||
|
||||
|
||||
class ImageWriter:
|
||||
class AsyncImageWriter:
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
|
@ -102,11 +100,7 @@ class ImageWriter:
|
|||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||
self.write_dir = write_dir
|
||||
self.write_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.image_path = DEFAULT_IMAGE_PATH
|
||||
|
||||
def __init__(self, num_processes: int = 0, num_threads: int = 1):
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = num_threads
|
||||
self.queue = None
|
||||
|
@ -134,17 +128,6 @@ class ImageWriter:
|
|||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = self.image_path.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.write_dir / fpath
|
||||
|
||||
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self.get_image_file_path(
|
||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||
).parent
|
||||
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
|
|
|
@ -22,15 +22,18 @@ from pathlib import Path
|
|||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
|
@ -44,6 +47,7 @@ from lerobot.common.datasets.utils import (
|
|||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_features_from_features,
|
||||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
load_episodes,
|
||||
|
@ -140,14 +144,9 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
""""""
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def keys(self) -> list[str]:
|
||||
"""Keys to access non-image data (state, actions etc.)."""
|
||||
return self.info["keys"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
|
@ -268,7 +267,7 @@ class LeRobotDatasetMetadata:
|
|||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot)
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
|
@ -522,35 +521,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
||||
|
||||
@property
|
||||
def features(self) -> list[str]:
|
||||
return list(self._features) + self.meta.video_keys
|
||||
def features(self) -> dict[str, dict]:
|
||||
return self.meta.features
|
||||
|
||||
@property
|
||||
def _features(self) -> datasets.Features:
|
||||
def hf_features(self) -> datasets.Features:
|
||||
"""Features of the hf_dataset."""
|
||||
if self.hf_dataset is not None:
|
||||
return self.hf_dataset.features
|
||||
elif self.episode_buffer is None:
|
||||
raise NotImplementedError(
|
||||
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
|
||||
)
|
||||
|
||||
features = {}
|
||||
for key in self.episode_buffer:
|
||||
if key in ["episode_index", "frame_index", "index", "task_index"]:
|
||||
features[key] = datasets.Value(dtype="int64")
|
||||
elif key in ["next.done", "next.success"]:
|
||||
features[key] = datasets.Value(dtype="bool")
|
||||
elif key in ["timestamp", "next.reward"]:
|
||||
features[key] = datasets.Value(dtype="float32")
|
||||
elif key in self.meta.image_keys:
|
||||
features[key] = datasets.Image()
|
||||
elif key in self.meta.keys:
|
||||
features[key] = datasets.Sequence(
|
||||
length=self.meta.shapes[key], feature=datasets.Value(dtype="float32")
|
||||
)
|
||||
|
||||
return datasets.Features(features)
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
|
@ -650,17 +630,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
# TODO(aliberts): Handle resume
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
return {
|
||||
"size": 0,
|
||||
"episode_index": self.meta.total_episodes if episode_index is None else episode_index,
|
||||
"task_index": None,
|
||||
"frame_index": [],
|
||||
"timestamp": [],
|
||||
**{key: [] for key in self.meta.features},
|
||||
**{key: [] for key in self.meta.image_keys},
|
||||
**{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
|
||||
}
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
write_image(image, fpath)
|
||||
else:
|
||||
self.image_writer.save_image(image=image, fpath=fpath)
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
"""
|
||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||
|
@ -668,35 +657,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
then needs to be called.
|
||||
"""
|
||||
frame_index = self.episode_buffer["size"]
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(frame_index / self.fps)
|
||||
self.episode_buffer["next.done"].append(False)
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in self.meta.keys:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
for key, ft in self.features.items():
|
||||
if key == "frame_index":
|
||||
self.episode_buffer[key].append(frame_index)
|
||||
elif key == "timestamp":
|
||||
self.episode_buffer[key].append(frame_index / self.fps)
|
||||
elif key in frame and ft["dtype"] not in ["image", "video"]:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
elif key in frame and ft["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._save_image(frame[key], img_path)
|
||||
if ft["dtype"] == "image":
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
if self.image_writer is None:
|
||||
return
|
||||
|
||||
# Save images
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_path = self.image_writer.get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.image_writer.save_image(
|
||||
image=frame[cam_key],
|
||||
fpath=img_path,
|
||||
)
|
||||
|
||||
if cam_key in self.meta.image_keys:
|
||||
self.episode_buffer[cam_key].append(str(img_path))
|
||||
|
||||
def add_episode(self, task: str, encode_videos: bool = False) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
|
@ -714,23 +693,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
raise NotImplementedError()
|
||||
|
||||
task_index = self.meta.get_task_index(task)
|
||||
self.episode_buffer["next.done"][-1] = True
|
||||
|
||||
for key in self.episode_buffer:
|
||||
if key in self.meta.image_keys:
|
||||
continue
|
||||
elif key in self.meta.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
if not set(self.episode_buffer.keys()) == set(self.features):
|
||||
raise ValueError()
|
||||
|
||||
for key, ft in self.features.items():
|
||||
if key == "index":
|
||||
self.episode_buffer[key] = np.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
self.episode_buffer[key] = np.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), task_index)
|
||||
else:
|
||||
self.episode_buffer[key] = np.full((episode_length,), task_index)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif ft["shape"][0] == 1:
|
||||
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
|
||||
elif ft["shape"][0] > 1:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.episode_buffer["index"] = torch.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
self.meta.add_episode(episode_index, episode_length, task, task_index)
|
||||
|
||||
self._wait_image_writer()
|
||||
|
@ -744,7 +728,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.hf_features, split="train")
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
|
@ -753,7 +737,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
episode_index = self.episode_buffer["episode_index"]
|
||||
if self.image_writer is not None:
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=cam_key, frame_index=0
|
||||
).parent
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
|
@ -761,13 +747,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
if isinstance(self.image_writer, ImageWriter):
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
logging.warning(
|
||||
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
|
||||
"You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
|
||||
)
|
||||
|
||||
self.image_writer = ImageWriter(
|
||||
write_dir=self.root / "images",
|
||||
self.image_writer = AsyncImageWriter(
|
||||
num_processes=num_processes,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
@ -787,19 +772,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.image_writer.wait_until_done()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
for episode_index in range(self.meta.total_episodes):
|
||||
for key in self.meta.video_keys:
|
||||
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
|
||||
# to call self.image_writer here
|
||||
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
@ -810,8 +797,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.encode_videos()
|
||||
self.meta.write_video_info()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.write_dir)
|
||||
if not keep_image_files:
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
@ -989,7 +978,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
|
||||
features.update(
|
||||
{k: v for k, v in dataset.hf_features.items() if k not in self.disabled_data_keys}
|
||||
)
|
||||
return features
|
||||
|
||||
@property
|
||||
|
|
|
@ -22,6 +22,7 @@ from typing import Any
|
|||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pyarrow.compute as pc
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, HfApi
|
||||
|
@ -39,6 +40,7 @@ TASKS_PATH = "meta/tasks.jsonl"
|
|||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
|
@ -222,6 +224,24 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
|||
return version
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
|
@ -270,6 +290,31 @@ def get_episode_data_index(
|
|||
}
|
||||
|
||||
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
|
|
|
@ -25,7 +25,6 @@ from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_mo
|
|||
|
||||
# Import fixture modules as plugins
|
||||
pytest_plugins = [
|
||||
"tests.fixtures.dataset",
|
||||
"tests.fixtures.dataset_factories",
|
||||
"tests.fixtures.files",
|
||||
"tests.fixtures.hub",
|
||||
|
|
|
@ -1,67 +0,0 @@
|
|||
import datasets
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.utils import get_episode_data_index
|
||||
from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def empty_info(info_factory) -> dict:
|
||||
return info_factory(
|
||||
keys=[],
|
||||
image_keys=[],
|
||||
video_keys=[],
|
||||
shapes={},
|
||||
names={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info(info_factory) -> dict:
|
||||
return info_factory(
|
||||
total_episodes=4,
|
||||
total_frames=420,
|
||||
total_tasks=3,
|
||||
total_videos=8,
|
||||
total_chunks=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats(stats_factory) -> list:
|
||||
return stats_factory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks() -> list:
|
||||
return [
|
||||
{"task_index": 0, "task": "Pick up the block."},
|
||||
{"task_index": 1, "task": "Open the box."},
|
||||
{"task_index": 2, "task": "Make paperclips."},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes() -> list:
|
||||
return [
|
||||
{"episode_index": 0, "tasks": ["Pick up the block."], "length": 100},
|
||||
{"episode_index": 1, "tasks": ["Open the box."], "length": 80},
|
||||
{"episode_index": 2, "tasks": ["Pick up the block."], "length": 90},
|
||||
{"episode_index": 3, "tasks": ["Make paperclips."], "length": 150},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episode_data_index(episodes) -> dict:
|
||||
return get_episode_data_index(episodes)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
|
||||
return hf_dataset_factory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset:
|
||||
image_keys = DUMMY_CAMERA_KEYS
|
||||
return hf_dataset_factory(image_keys=image_keys)
|
|
@ -11,16 +11,19 @@ import torch
|
|||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.defaults import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_KEYS,
|
||||
DUMMY_KEYS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
DUMMY_VIDEO_INFO,
|
||||
)
|
||||
|
||||
|
||||
|
@ -73,16 +76,33 @@ def img_factory(img_array_factory):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory():
|
||||
def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||
}
|
||||
else:
|
||||
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
|
||||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
**DEFAULT_FEATURES,
|
||||
}
|
||||
|
||||
return _create_features
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory(features_factory):
|
||||
def _create_info(
|
||||
codebase_version: str = CODEBASE_VERSION,
|
||||
fps: int = DEFAULT_FPS,
|
||||
robot_type: str = DUMMY_ROBOT_TYPE,
|
||||
keys: list[str] = DUMMY_KEYS,
|
||||
image_keys: list[str] | None = None,
|
||||
video_keys: list[str] = DUMMY_CAMERA_KEYS,
|
||||
shapes: dict | None = None,
|
||||
names: dict | None = None,
|
||||
total_episodes: int = 0,
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
|
@ -90,30 +110,14 @@ def info_factory():
|
|||
total_chunks: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
videos_path: str = DEFAULT_VIDEO_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if not image_keys:
|
||||
image_keys = []
|
||||
if not shapes:
|
||||
shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
|
||||
if not names:
|
||||
names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys}
|
||||
|
||||
video_info = {"videos_path": videos_path}
|
||||
for key in video_keys:
|
||||
video_info[key] = {
|
||||
"video.fps": fps,
|
||||
"video.width": shapes[key]["width"],
|
||||
"video.height": shapes[key]["height"],
|
||||
"video.channels": shapes[key]["channels"],
|
||||
"video.codec": "av1",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"data_path": data_path,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": total_frames,
|
||||
|
@ -123,12 +127,9 @@ def info_factory():
|
|||
"chunks_size": chunks_size,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"keys": keys,
|
||||
"video_keys": video_keys,
|
||||
"image_keys": image_keys,
|
||||
"shapes": shapes,
|
||||
"names": names,
|
||||
"videos": video_info if len(video_keys) > 0 else None,
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
return _create_info
|
||||
|
@ -137,32 +138,26 @@ def info_factory():
|
|||
@pytest.fixture(scope="session")
|
||||
def stats_factory():
|
||||
def _create_stats(
|
||||
keys: list[str] = DUMMY_KEYS,
|
||||
image_keys: list[str] | None = None,
|
||||
video_keys: list[str] = DUMMY_CAMERA_KEYS,
|
||||
shapes: dict | None = None,
|
||||
features: dict[str] | None = None,
|
||||
) -> dict:
|
||||
if not image_keys:
|
||||
image_keys = []
|
||||
if not shapes:
|
||||
shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
|
||||
stats = {}
|
||||
for key in keys:
|
||||
shape = shapes[key]
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full(shape, 0, dtype=np.float32).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
for key in [*image_keys, *video_keys]:
|
||||
shape = (3, 1, 1)
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full(shape, 0, dtype=np.float32).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
dtype = ft["dtype"]
|
||||
if dtype in ["image", "video"]:
|
||||
stats[key] = {
|
||||
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
else:
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
|
||||
"min": np.full(shape, 0, dtype=dtype).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
|
||||
}
|
||||
return stats
|
||||
|
||||
return _create_stats
|
||||
|
@ -185,7 +180,7 @@ def episodes_factory(tasks_factory):
|
|||
def _create_episodes(
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
task_dicts: dict | None = None,
|
||||
tasks: dict | None = None,
|
||||
multi_task: bool = False,
|
||||
):
|
||||
if total_episodes <= 0 or total_frames <= 0:
|
||||
|
@ -193,18 +188,18 @@ def episodes_factory(tasks_factory):
|
|||
if total_frames < total_episodes:
|
||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||
|
||||
if not task_dicts:
|
||||
if not tasks:
|
||||
min_tasks = 2 if multi_task else 1
|
||||
total_tasks = random.randint(min_tasks, total_episodes)
|
||||
task_dicts = tasks_factory(total_tasks)
|
||||
tasks = tasks_factory(total_tasks)
|
||||
|
||||
if total_episodes < len(task_dicts) and not multi_task:
|
||||
if total_episodes < len(tasks) and not multi_task:
|
||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||
|
||||
# Generate random lengths that sum up to total_length
|
||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||
|
||||
tasks_list = [task_dict["task"] for task_dict in task_dicts]
|
||||
tasks_list = [task_dict["task"] for task_dict in tasks]
|
||||
num_tasks_available = len(tasks_list)
|
||||
|
||||
episodes_list = []
|
||||
|
@ -231,81 +226,56 @@ def episodes_factory(tasks_factory):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(img_array_factory, episodes, tasks):
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
episode_dicts: list[dict] = episodes,
|
||||
task_dicts: list[dict] = tasks,
|
||||
keys: list[str] = DUMMY_KEYS,
|
||||
image_keys: list[str] | None = None,
|
||||
shapes: dict | None = None,
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
fps: int = DEFAULT_FPS,
|
||||
) -> datasets.Dataset:
|
||||
if not image_keys:
|
||||
image_keys = []
|
||||
if not shapes:
|
||||
shapes = make_dummy_shapes(keys=keys, camera_keys=image_keys)
|
||||
key_features = {
|
||||
key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
|
||||
for key in keys
|
||||
}
|
||||
image_features = {key: datasets.Image() for key in image_keys} if image_keys else {}
|
||||
common_features = {
|
||||
"episode_index": datasets.Value(dtype="int64"),
|
||||
"frame_index": datasets.Value(dtype="int64"),
|
||||
"timestamp": datasets.Value(dtype="float32"),
|
||||
"next.done": datasets.Value(dtype="bool"),
|
||||
"index": datasets.Value(dtype="int64"),
|
||||
"task_index": datasets.Value(dtype="int64"),
|
||||
}
|
||||
features = datasets.Features(
|
||||
{
|
||||
**key_features,
|
||||
**image_features,
|
||||
**common_features,
|
||||
}
|
||||
)
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
if not features:
|
||||
features = features_factory()
|
||||
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
next_done_col = np.array([], dtype=bool)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
|
||||
for ep_dict in episode_dicts:
|
||||
for ep_dict in episodes:
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
episode_index_col = np.concatenate(
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
)
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
next_done_ep = np.full(ep_dict["length"], False, dtype=bool)
|
||||
next_done_ep[-1] = True
|
||||
next_done_col = np.concatenate((next_done_col, next_done_ep))
|
||||
ep_task_index = get_task_index(task_dicts, ep_dict["tasks"][0])
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
|
||||
index_col = np.arange(len(episode_index_col))
|
||||
key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys}
|
||||
|
||||
image_cols = {}
|
||||
if image_keys:
|
||||
for key in image_keys:
|
||||
image_cols[key] = [
|
||||
img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"])
|
||||
robot_cols = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(width=ft["shapes"][0], height=ft["shapes"][1])
|
||||
for _ in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||
|
||||
hf_features = get_hf_features_from_features(features)
|
||||
dataset = datasets.Dataset.from_dict(
|
||||
{
|
||||
**key_cols,
|
||||
**image_cols,
|
||||
"episode_index": episode_index_col,
|
||||
"frame_index": frame_index_col,
|
||||
**robot_cols,
|
||||
"timestamp": timestamp_col,
|
||||
"next.done": next_done_col,
|
||||
"frame_index": frame_index_col,
|
||||
"episode_index": episode_index_col,
|
||||
"index": index_col,
|
||||
"task_index": task_index,
|
||||
},
|
||||
features=features,
|
||||
features=hf_features,
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
return dataset
|
||||
|
@ -315,26 +285,37 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info,
|
||||
stats,
|
||||
tasks,
|
||||
episodes,
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_metadata(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
**kwargs,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
local_files_only: bool = False,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
|
@ -347,48 +328,68 @@ def lerobot_dataset_metadata_factory(
|
|||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(
|
||||
repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
|
||||
)
|
||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info,
|
||||
stats,
|
||||
tasks,
|
||||
episodes,
|
||||
hf_dataset,
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
hf_ds: datasets.Dataset = hf_dataset,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episode_dicts: list[dict] | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episode_dicts:
|
||||
episode_dicts = episodes_factory(
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
hf_ds=hf_ds,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
hf_dataset=hf_dataset,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
**kwargs,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
|
@ -402,44 +403,3 @@ def lerobot_dataset_factory(
|
|||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
return _create_lerobot_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_from_episodes_factory(
|
||||
info_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
lerobot_dataset_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_total_episodes(
|
||||
root: Path,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
**kwargs,
|
||||
):
|
||||
info_dict = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
task_dicts = tasks_factory(total_tasks)
|
||||
episode_dicts = episodes_factory(
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
task_dicts=task_dicts,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
|
||||
return lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
hf_ds=hf_dataset,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _create_lerobot_dataset_total_episodes
|
||||
|
|
|
@ -3,6 +3,27 @@ from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
|||
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
|
||||
DUMMY_REPO_ID = "dummy/repo"
|
||||
DUMMY_ROBOT_TYPE = "dummy_robot"
|
||||
DUMMY_KEYS = ["state", "action"]
|
||||
DUMMY_CAMERA_KEYS = ["laptop", "phone"]
|
||||
DUMMY_MOTOR_FEATURES = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
},
|
||||
"state": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
|
||||
"phone": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
DUMMY_VIDEO_INFO = {
|
||||
"video.fps": DEFAULT_FPS,
|
||||
"video.codec": "av1",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
|
|
|
@ -11,64 +11,77 @@ from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH,
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_path(info):
|
||||
def _create_info_json_file(dir: Path, info_dict: dict = info) -> Path:
|
||||
def info_path(info_factory):
|
||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
fpath = dir / INFO_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(info_dict, f, indent=4, ensure_ascii=False)
|
||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
return _create_info_json_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_path(stats):
|
||||
def _create_stats_json_file(dir: Path, stats_dict: dict = stats) -> Path:
|
||||
def stats_path(stats_factory):
|
||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
||||
if not stats:
|
||||
stats = stats_factory()
|
||||
fpath = dir / STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(stats_dict, f, indent=4, ensure_ascii=False)
|
||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
return _create_stats_json_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks):
|
||||
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
|
||||
def tasks_path(tasks_factory):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(task_dicts)
|
||||
writer.write_all(tasks)
|
||||
return fpath
|
||||
|
||||
return _create_tasks_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episode_path(episodes):
|
||||
def _create_episodes_jsonl_file(dir: Path, episode_dicts: list = episodes) -> Path:
|
||||
def episode_path(episodes_factory):
|
||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
fpath = dir / EPISODES_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episode_dicts)
|
||||
writer.write_all(episodes)
|
||||
return fpath
|
||||
|
||||
return _create_episodes_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def single_episode_parquet_path(hf_dataset, info):
|
||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_single_episode_parquet(
|
||||
dir: Path, hf_ds: datasets.Dataset = hf_dataset, ep_idx: int = 0
|
||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
data_path = info["data_path"]
|
||||
chunks_size = info["chunks_size"]
|
||||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_ds.data.table
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
pq.write_table(ep_table, fpath)
|
||||
return fpath
|
||||
|
@ -77,8 +90,15 @@ def single_episode_parquet_path(hf_dataset, info):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def multi_episode_parquet_path(hf_dataset, info):
|
||||
def _create_multi_episode_parquet(dir: Path, hf_ds: datasets.Dataset = hf_dataset) -> Path:
|
||||
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_multi_episode_parquet(
|
||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
data_path = info["data_path"]
|
||||
chunks_size = info["chunks_size"]
|
||||
total_episodes = info["total_episodes"]
|
||||
|
@ -86,7 +106,7 @@ def multi_episode_parquet_path(hf_dataset, info):
|
|||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_ds.data.table
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
pq.write_table(ep_table, fpath)
|
||||
return dir / "data"
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
|
@ -9,16 +10,16 @@ from tests.fixtures.defaults import LEROBOT_TEST_DIR
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_snapshot_download_factory(
|
||||
info,
|
||||
info_factory,
|
||||
info_path,
|
||||
stats,
|
||||
stats_factory,
|
||||
stats_path,
|
||||
tasks,
|
||||
tasks_factory,
|
||||
tasks_path,
|
||||
episodes,
|
||||
episodes_factory,
|
||||
episode_path,
|
||||
single_episode_parquet_path,
|
||||
hf_dataset,
|
||||
hf_dataset_factory,
|
||||
):
|
||||
"""
|
||||
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
||||
|
@ -26,8 +27,25 @@ def mock_snapshot_download_factory(
|
|||
"""
|
||||
|
||||
def _mock_snapshot_download_func(
|
||||
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
):
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
||||
|
@ -53,10 +71,10 @@ def mock_snapshot_download_factory(
|
|||
all_files.extend(meta_files)
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episode_dicts:
|
||||
for episode_dict in episodes:
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info_dict["chunks_size"]
|
||||
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
ep_chunk = ep_idx // info["chunks_size"]
|
||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
data_files.append(data_path)
|
||||
all_files.extend(data_files)
|
||||
|
||||
|
@ -69,15 +87,15 @@ def mock_snapshot_download_factory(
|
|||
if rel_path.startswith("data/"):
|
||||
episode_index = _extract_episode_index_from_path(rel_path)
|
||||
if episode_index is not None:
|
||||
_ = single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index)
|
||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
||||
if rel_path == INFO_PATH:
|
||||
_ = info_path(local_dir, info_dict)
|
||||
_ = info_path(local_dir, info)
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats_dict)
|
||||
_ = stats_path(local_dir, stats)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, task_dicts)
|
||||
_ = tasks_path(local_dir, tasks)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episode_dicts)
|
||||
_ = episode_path(local_dir, episodes)
|
||||
else:
|
||||
pass
|
||||
return str(local_dir)
|
||||
|
|
|
@ -35,7 +35,6 @@ from lerobot.common.datasets.compute_stats import (
|
|||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
|
@ -57,10 +56,7 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
|||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
metadata_create = LeRobotDatasetMetadata.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
|
||||
)
|
||||
dataset_create = LeRobotDataset.create(metadata_create)
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
@ -75,14 +71,14 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
|||
assert init_attr == create_attr
|
||||
|
||||
|
||||
def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path):
|
||||
def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
|
||||
kwargs = {
|
||||
"repo_id": DUMMY_REPO_ID,
|
||||
"total_episodes": 10,
|
||||
"total_frames": 400,
|
||||
"episodes": [2, 5, 6],
|
||||
}
|
||||
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
|
||||
dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
|
||||
|
||||
assert dataset.repo_id == kwargs["repo_id"]
|
||||
assert dataset.meta.total_episodes == kwargs["total_episodes"]
|
||||
|
|
|
@ -3,12 +3,13 @@ import torch
|
|||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.defaults import DUMMY_KEYS
|
||||
from tests.fixtures.defaults import DUMMY_MOTOR_FEATURES
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -53,7 +54,7 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_KEYS) -> dict:
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
|
@ -63,7 +64,7 @@ def valid_delta_timestamps_factory():
|
|||
@pytest.fixture(scope="module")
|
||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_invalid_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just outside tolerance
|
||||
|
@ -77,7 +78,7 @@ def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|||
@pytest.fixture(scope="module")
|
||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_slightly_off_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_KEYS
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just inside tolerance
|
||||
|
@ -90,14 +91,15 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices(keys: list = DUMMY_KEYS) -> dict:
|
||||
def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
return {key: list(range(-10, 10)) for key in keys}
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
synced_hf_dataset = synced_hf_dataset_factory(fps)
|
||||
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=synced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
|
@ -107,10 +109,11 @@ def test_check_timestamps_sync_synced(synced_hf_dataset_factory, episode_data_in
|
|||
assert result is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
with pytest.raises(ValueError):
|
||||
check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
|
@ -120,10 +123,11 @@ def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory, episode_dat
|
|||
)
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
|
@ -134,10 +138,11 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory
|
|||
assert result is False
|
||||
|
||||
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory, episode_data_index):
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=slightly_off_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.image_writer import (
|
||||
ImageWriter,
|
||||
AsyncImageWriter,
|
||||
image_array_to_image,
|
||||
safe_stop_image_writer,
|
||||
write_image,
|
||||
|
@ -17,8 +17,8 @@ from lerobot.common.datasets.image_writer import (
|
|||
DUMMY_IMAGE = "test_image.png"
|
||||
|
||||
|
||||
def test_init_threading(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
|
||||
def test_init_threading():
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
||||
try:
|
||||
assert writer.num_processes == 0
|
||||
assert writer.num_threads == 2
|
||||
|
@ -30,8 +30,8 @@ def test_init_threading(tmp_path):
|
|||
writer.stop()
|
||||
|
||||
|
||||
def test_init_multiprocessing(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
def test_init_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
assert writer.num_processes == 2
|
||||
assert writer.num_threads == 2
|
||||
|
@ -43,35 +43,9 @@ def test_init_multiprocessing(tmp_path):
|
|||
writer.stop()
|
||||
|
||||
|
||||
def test_write_dir_created(tmp_path):
|
||||
write_dir = tmp_path / "non_existent_dir"
|
||||
assert not write_dir.exists()
|
||||
writer = ImageWriter(write_dir=write_dir)
|
||||
try:
|
||||
assert write_dir.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_get_image_file_path_and_episode_dir(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
try:
|
||||
episode_index = 1
|
||||
image_key = "test_key"
|
||||
frame_index = 10
|
||||
expected_episode_dir = tmp_path / f"{image_key}/episode_{episode_index:06d}"
|
||||
expected_path = expected_episode_dir / f"frame_{frame_index:06d}.png"
|
||||
image_file_path = writer.get_image_file_path(episode_index, image_key, frame_index)
|
||||
assert image_file_path == expected_path
|
||||
episode_dir = writer.get_episode_dir(episode_index, image_key)
|
||||
assert episode_dir == expected_episode_dir
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_zero_threads(tmp_path):
|
||||
def test_zero_threads():
|
||||
with pytest.raises(ValueError):
|
||||
ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=0)
|
||||
AsyncImageWriter(num_processes=0, num_threads=0)
|
||||
|
||||
|
||||
def test_image_array_to_image_rgb(img_array_factory):
|
||||
|
@ -148,7 +122,7 @@ def test_write_image_exception(tmp_path):
|
|||
|
||||
|
||||
def test_save_image_numpy(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
@ -163,7 +137,7 @@ def test_save_image_numpy(tmp_path, img_array_factory):
|
|||
|
||||
|
||||
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
@ -177,7 +151,7 @@ def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
|
|||
|
||||
|
||||
def test_save_image_torch(tmp_path, img_tensor_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
@ -193,7 +167,7 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
|
|||
|
||||
|
||||
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
@ -208,7 +182,7 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
|||
|
||||
|
||||
def test_save_image_pil(tmp_path, img_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_pil = img_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
@ -223,7 +197,7 @@ def test_save_image_pil(tmp_path, img_factory):
|
|||
|
||||
|
||||
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_pil = img_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
@ -237,10 +211,10 @@ def test_save_image_pil_multiprocessing(tmp_path, img_factory):
|
|||
|
||||
|
||||
def test_save_image_invalid_data(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = "invalid data"
|
||||
fpath = writer.get_image_file_path(0, "test_key", 0)
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with patch("builtins.print") as mock_print:
|
||||
writer.save_image(image_array, fpath)
|
||||
|
@ -252,47 +226,47 @@ def test_save_image_invalid_data(tmp_path):
|
|||
|
||||
|
||||
def test_save_image_after_stop(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
writer.stop()
|
||||
image_array = img_array_factory()
|
||||
fpath = writer.get_image_file_path(0, "test_key", 0)
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
writer.save_image(image_array, fpath)
|
||||
time.sleep(1)
|
||||
assert not fpath.exists()
|
||||
|
||||
|
||||
def test_stop(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
|
||||
def test_stop():
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
||||
writer.stop()
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_stop_multiprocessing(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
def test_stop_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
writer.stop()
|
||||
assert not any(p.is_alive() for p in writer.processes)
|
||||
|
||||
|
||||
def test_multiple_stops(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
def test_multiple_stops():
|
||||
writer = AsyncImageWriter()
|
||||
writer.stop()
|
||||
writer.stop() # Should not raise an exception
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_multiple_stops_multiprocessing(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
def test_multiple_stops_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
writer.stop()
|
||||
writer.stop() # Should not raise an exception
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_wait_until_done(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=4)
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
|
||||
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_array, fpath)
|
||||
|
@ -306,11 +280,11 @@ def test_wait_until_done(tmp_path, img_array_factory):
|
|||
|
||||
|
||||
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory() for _ in range(num_images)]
|
||||
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_array, fpath)
|
||||
|
@ -324,7 +298,7 @@ def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
|
|||
|
||||
|
||||
def test_exception_handling(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
with (
|
||||
|
@ -338,7 +312,7 @@ def test_exception_handling(tmp_path, img_array_factory):
|
|||
|
||||
|
||||
def test_with_different_image_formats(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
formats = ["png", "jpeg", "bmp"]
|
||||
|
@ -353,7 +327,7 @@ def test_with_different_image_formats(tmp_path, img_array_factory):
|
|||
def test_safe_stop_image_writer_decorator():
|
||||
class MockDataset:
|
||||
def __init__(self):
|
||||
self.image_writer = MagicMock(spec=ImageWriter)
|
||||
self.image_writer = MagicMock(spec=AsyncImageWriter)
|
||||
|
||||
@safe_stop_image_writer
|
||||
def function_that_raises_exception(dataset=None):
|
||||
|
@ -369,10 +343,10 @@ def test_safe_stop_image_writer_decorator():
|
|||
|
||||
|
||||
def test_main_process_time(tmp_path, img_tensor_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / "test_main_process_time.png"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
start_time = time.perf_counter()
|
||||
writer.save_image(image_tensor, fpath)
|
||||
end_time = time.perf_counter()
|
||||
|
|
|
@ -213,15 +213,13 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
|||
@pytest.mark.parametrize("online_dataset_size", [0, 4])
|
||||
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
|
||||
def test_compute_sampler_weights_trivial(
|
||||
lerobot_dataset_from_episodes_factory,
|
||||
lerobot_dataset_factory,
|
||||
tmp_path,
|
||||
offline_dataset_size: int,
|
||||
online_dataset_size: int,
|
||||
online_sampling_ratio: float,
|
||||
):
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(
|
||||
tmp_path, total_episodes=1, total_frames=offline_dataset_size
|
||||
)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
if online_dataset_size > 0:
|
||||
online_dataset.add_data(
|
||||
|
@ -241,9 +239,9 @@ def test_compute_sampler_weights_trivial(
|
|||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_factory, tmp_path):
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_sampling_ratio = 0.8
|
||||
|
@ -255,11 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_from_episodes_
|
|||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
||||
lerobot_dataset_from_episodes_factory, tmp_path
|
||||
):
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
weights = compute_sampler_weights(
|
||||
|
@ -270,9 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
|||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_from_episodes_factory, tmp_path):
|
||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
|
||||
"""Note: test copied from test_sampler."""
|
||||
offline_dataset = lerobot_dataset_from_episodes_factory(tmp_path, total_episodes=1, total_frames=2)
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
|
||||
|
|
Loading…
Reference in New Issue