Per-episode stats (#521)

Co-authored-by: Remi Cadene <re.cadene@gmail.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert 2025-02-15 15:47:16 +01:00 committed by GitHub
parent 7c2bbee613
commit 8426c64f42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 906 additions and 798 deletions

View File

@ -148,6 +148,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
action = zarr_data["action"][:] action = zarr_data["action"][:]
image = zarr_data["img"] # (b, h, w, c) image = zarr_data["img"] # (b, h, w, c)
if image.dtype == np.float32 and image.max() == np.float32(255):
# HACK: images are loaded as float32 but they actually encode uint8 data
image = image.astype(np.uint8)
episode_data_index = { episode_data_index = {
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])), "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
"to": zarr_data.meta["episode_ends"], "to": zarr_data.meta["episode_ends"],

View File

@ -13,202 +13,148 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from copy import deepcopy import numpy as np
from math import ceil
import einops from lerobot.common.datasets.utils import load_image_as_numpy
import torch
import tqdm
def get_stats_einops_patterns(dataset, num_workers=0): def estimate_num_samples(
"""These einops patterns will be used to aggregate batches and compute statistics. dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
"""Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size.
Lower the power for less number of samples.
Note: We assume the images are in channel first format For default arguments, we have:
- from 1 to ~500, num_samples=100
- at 1000, num_samples=177
- at 2000, num_samples=299
- at 5000, num_samples=594
- at 10000, num_samples=1000
- at 20000, num_samples=1681
""" """
if dataset_len < min_num_samples:
min_num_samples = dataset_len
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {} def sample_indices(data_len: int) -> list[int]:
num_samples = estimate_num_samples(data_len)
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
for key in dataset.features:
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
# if isinstance(feats_type, (VideoFrame, Image)): def sample_images(image_paths: list[str]) -> np.ndarray:
if key in dataset.meta.camera_keys: sampled_indices = sample_indices(len(image_paths))
# sanity check that images are channel first images = []
_, c, h, w = batch[key].shape for idx in sampled_indices:
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" path = image_paths[idx]
# we load as uint8 to reduce memory usage
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
images.append(img)
# sanity check that images are float32 in range [0,1] images = np.stack(images)
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}" return images
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2: def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
stats_patterns[key] = "b c -> c " return {
elif batch[key].ndim == 1: "min": np.min(array, axis=axis, keepdims=keepdims),
stats_patterns[key] = "b -> 1" "max": np.max(array, axis=axis, keepdims=keepdims),
"mean": np.mean(array, axis=axis, keepdims=keepdims),
"std": np.std(array, axis=axis, keepdims=keepdims),
"count": np.array([len(array)]),
}
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue # HACK: we should receive np.arrays of strings
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data) # data is a list of image paths
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else: else:
raise ValueError(f"{key}, {batch[key].shape}") ep_ft_array = data # data is alreay a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
return stats_patterns ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None): def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset.""" for i in range(len(stats_list)):
if max_num_samples is None: for fkey in stats_list[i]:
max_num_samples = len(dataset) for k, v in stats_list[i][fkey].items():
if not isinstance(v, np.ndarray):
# for more info on why we need to set the same number of workers, see `load_from_videos` raise ValueError(
stats_patterns = get_stats_einops_patterns(dataset, num_workers) f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
# mean and std will be computed incrementally while max and min will track the running value. if v.ndim == 0:
mean, std, max, min = {}, {}, {}, {} raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
for key in stats_patterns: if k == "count" and v.shape != (1,):
mean[key] = torch.tensor(0.0).float() raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
std[key] = torch.tensor(0.0).float() if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
max[key] = torch.tensor(-float("inf")).float() raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(dataset, batch_size, seed):
generator = torch.Generator()
generator.manual_seed(seed)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=False,
generator=generator,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation.
batch_mean = einops.reduce(batch[key], pattern, "mean")
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
if i == ceil(max_num_samples / batch_size) - 1:
break
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
if i == ceil(max_num_samples / batch_size) - 1:
break
for key in stats_patterns:
std[key] = torch.sqrt(std[key])
stats = {}
for key in stats_patterns:
stats[key] = {
"mean": mean[key],
"std": std[key],
"max": max[key],
"min": min[key],
}
return stats
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]: def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch. """Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
counts = np.stack([s["count"] for s in stats_ft_list])
total_count = counts.sum(axis=0)
The final stats will have the union of all data keys from each of the datasets. # Prepare weighted mean by matching number of dimensions
while counts.ndim < means.ndim:
counts = np.expand_dims(counts, axis=-1)
The final stats will have the union of all data keys from each of the datasets. For instance: # Compute the weighted mean
- new_max = max(max_dataset_0, max_dataset_1, ...) weighted_means = means * counts
total_mean = weighted_means.sum(axis=0) / total_count
# Compute the variance using the parallel algorithm
delta_means = means - total_mean
weighted_variances = (variances + delta_means**2) * counts
total_variance = weighted_variances.sum(axis=0) / total_count
return {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"mean": total_mean,
"std": np.sqrt(total_variance),
"count": total_count,
}
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts.
For instance:
- new_min = min(min_dataset_0, min_dataset_1, ...) - new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data) - new_max = max(max_dataset_0, max_dataset_1, ...)
- new_mean = (mean of all data, weighted by counts)
- new_std = (std of all data) - new_std = (std of all data)
""" """
data_keys = set()
for dataset in ls_datasets: _assert_type_and_shape(stats_list)
data_keys.update(dataset.meta.stats.keys())
stats = {k: {} for k in data_keys} data_keys = {key for stats in stats_list for key in stats}
for data_key in data_keys: aggregated_stats = {key: {} for key in data_keys}
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)` for key in data_keys:
stats[data_key][stat_key] = einops.reduce( stats_with_key = [stats[key] for stats in stats_list if key in stats]
torch.stack( aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
dim=0, return aggregated_stats
),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.meta.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(
d.meta.stats[data_key]["std"] ** 2
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
)
* (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.meta.stats
)
)
return stats

View File

@ -26,18 +26,17 @@ import PIL.Image
import torch import torch
import torch.utils import torch.utils
from datasets import load_dataset from datasets import load_dataset
from huggingface_hub import create_repo, snapshot_download, upload_folder from huggingface_hub import HfApi, snapshot_download
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_FEATURES, DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH, DEFAULT_IMAGE_PATH,
EPISODES_PATH,
INFO_PATH, INFO_PATH,
STATS_PATH,
TASKS_PATH, TASKS_PATH,
append_jsonlines, append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps, check_delta_timestamps,
check_frame_features, check_frame_features,
check_timestamps_sync, check_timestamps_sync,
@ -52,10 +51,13 @@ from lerobot.common.datasets.utils import (
get_hub_safe_version, get_hub_safe_version,
hf_transform_to_torch, hf_transform_to_torch,
load_episodes, load_episodes,
load_episodes_stats,
load_info, load_info,
load_stats, load_stats,
load_tasks, load_tasks,
serialize_dict, write_episode,
write_episode_stats,
write_info,
write_json, write_json,
write_parquet, write_parquet,
) )
@ -90,6 +92,17 @@ class LeRobotDatasetMetadata:
self.stats = load_stats(self.root) self.stats = load_stats(self.root)
self.tasks, self.task_to_task_index = load_tasks(self.root) self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root) self.episodes = load_episodes(self.root)
try:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
except FileNotFoundError:
logging.warning(
f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead.
Convert your dataset stats to the new format using this command:
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """
)
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
def pull_from_repo( def pull_from_repo(
self, self,
@ -228,7 +241,13 @@ class LeRobotDatasetMetadata:
} }
append_jsonlines(task_dict, self.root / TASKS_PATH) append_jsonlines(task_dict, self.root / TASKS_PATH)
def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None: def save_episode(
self,
episode_index: int,
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
) -> None:
self.info["total_episodes"] += 1 self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length self.info["total_frames"] += episode_length
@ -238,21 +257,19 @@ class LeRobotDatasetMetadata:
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys) self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / INFO_PATH) write_info(self.info, self.root)
episode_dict = { episode_dict = {
"episode_index": episode_index, "episode_index": episode_index,
"tasks": episode_tasks, "tasks": episode_tasks,
"length": episode_length, "length": episode_length,
} }
self.episodes.append(episode_dict) self.episodes[episode_index] = episode_dict
append_jsonlines(episode_dict, self.root / EPISODES_PATH) write_episode(episode_dict, self.root)
# TODO(aliberts): refactor stats in save_episodes self.episodes_stats[episode_index] = episode_stats
# image_sampling = int(self.fps / 2) # sample 2 img/s for the stats self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
# ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling) write_episode_stats(episode_index, episode_stats, self.root)
# ep_stats = serialize_dict(ep_stats)
# append_jsonlines(ep_stats, self.root / STATS_PATH)
def write_video_info(self) -> None: def write_video_info(self) -> None:
""" """
@ -309,6 +326,7 @@ class LeRobotDatasetMetadata:
) )
else: else:
# TODO(aliberts, rcadene): implement sanity check for features # TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}
# check if none of the features contains a "/" in their names, # check if none of the features contains a "/" in their names,
# as this would break the dict flattening in the stats computation, which uses '/' as separator # as this would break the dict flattening in the stats computation, which uses '/' as separator
@ -319,7 +337,7 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES} features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.task_to_task_index = {}, {} obj.tasks, obj.task_to_task_index = {}, {}
obj.stats, obj.episodes = {}, [] obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos: if len(obj.video_keys) > 0 and not use_videos:
raise ValueError() raise ValueError()
@ -457,6 +475,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Load metadata # Load metadata
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only) self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
if self.episodes is not None and self.meta._version == CODEBASE_VERSION:
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)
# Check version # Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
@ -479,10 +500,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def push_to_hub( def push_to_hub(
self, self,
branch: str | None = None,
create_card: bool = True,
tags: list | None = None, tags: list | None = None,
license: str | None = "apache-2.0", license: str | None = "apache-2.0",
push_videos: bool = True, push_videos: bool = True,
private: bool = False, private: bool = False,
allow_patterns: list[str] | str | None = None,
**card_kwargs, **card_kwargs,
) -> None: ) -> None:
if not self.consolidated: if not self.consolidated:
@ -496,24 +520,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not push_videos: if not push_videos:
ignore_patterns.append("videos/") ignore_patterns.append("videos/")
create_repo( hub_api = HfApi()
hub_api.create_repo(
repo_id=self.repo_id, repo_id=self.repo_id,
private=private, private=private,
repo_type="dataset", repo_type="dataset",
exist_ok=True, exist_ok=True,
) )
if branch:
create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset")
upload_folder( hub_api.upload_folder(
repo_id=self.repo_id, repo_id=self.repo_id,
folder_path=self.root, folder_path=self.root,
repo_type="dataset", repo_type="dataset",
revision=branch,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
) )
card = create_lerobot_dataset_card( if create_card:
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs card = create_lerobot_dataset_card(
) tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset") )
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset") card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if not branch:
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
def pull_from_repo( def pull_from_repo(
self, self,
@ -630,7 +662,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys if key not in self.meta.video_keys
} }
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@ -660,8 +692,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices = None query_indices = None
if self.delta_indices is not None: if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx query_indices, padding = self._get_query_indices(idx, ep_idx)
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices) query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding} item = {**item, **padding}
for key, val in query_result.items(): for key, val in query_result.items():
@ -735,11 +766,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.episode_buffer is None: if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer() self.episode_buffer = self.create_episode_buffer()
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"] frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp) self.episode_buffer["timestamp"].append(timestamp)
# Add frame features to episode_buffer
for key in frame: for key in frame:
if key == "task": if key == "task":
# Note: we associate the task in natural language to its task index during `save_episode` # Note: we associate the task in natural language to its task index during `save_episode`
@ -787,7 +820,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts): Add option to use existing episode_index # TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError( raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't " "You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes in the dataset. This is not supported for now." "match the total number of episodes already in the dataset. This is not supported for now."
) )
if episode_length == 0: if episode_length == 0:
@ -821,8 +854,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._wait_image_writer() self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index) self._save_episode_table(episode_buffer, episode_index)
ep_stats = compute_episode_stats(episode_buffer, self.features)
self.meta.save_episode(episode_index, episode_length, episode_tasks) self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
if encode_videos and len(self.meta.video_keys) > 0: if encode_videos and len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index) video_paths = self.encode_episode_videos(episode_index)
@ -908,7 +941,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths return video_paths
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: def consolidate(self, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@ -928,17 +961,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
parquet_files = list(self.root.rglob("*.parquet")) parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes assert len(parquet_files) == self.num_episodes
if run_compute_stats: self.consolidated = True
self.stop_image_writer()
# TODO(aliberts): refactor stats in save_episodes
self.meta.stats = compute_stats(self)
serialized_stats = serialize_dict(self.meta.stats)
write_json(serialized_stats, self.root / STATS_PATH)
self.consolidated = True
else:
logging.warning(
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
)
@classmethod @classmethod
def create( def create(
@ -1056,7 +1079,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.image_transforms = image_transforms self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets) # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property @property
def repo_id_to_index(self): def repo_id_to_index(self):

View File

@ -43,6 +43,7 @@ DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json" INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl" EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json" STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl" TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
@ -113,7 +114,16 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()} serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
return unflatten_dict(serialized_dict) return unflatten_dict(serialized_dict)
@ -154,6 +164,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
writer.write(data) writer.write(data)
def write_info(info: dict, local_dir: Path):
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict: def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH) info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values(): for ft in info["features"].values():
@ -161,12 +175,29 @@ def load_info(local_dir: Path) -> dict:
return info return info
def load_stats(local_dir: Path) -> dict: def write_stats(stats: dict, local_dir: Path):
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
if not (local_dir / STATS_PATH).exists(): if not (local_dir / STATS_PATH).exists():
return None return None
stats = load_json(local_dir / STATS_PATH) stats = load_json(local_dir / STATS_PATH)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()} return cast_stats_to_numpy(stats)
return unflatten_dict(stats)
def write_task(task_index: int, task: dict, local_dir: Path):
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, local_dir / TASKS_PATH)
def load_tasks(local_dir: Path) -> dict: def load_tasks(local_dir: Path) -> dict:
@ -176,16 +207,42 @@ def load_tasks(local_dir: Path) -> dict:
return tasks, task_to_task_index return tasks, task_to_task_index
def write_episode(episode: dict, local_dir: Path):
append_jsonlines(episode, local_dir / EPISODES_PATH)
def load_episodes(local_dir: Path) -> dict: def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH) episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray: def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionnary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
def load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
return {ep_idx: stats for ep_idx in episodes}
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB") img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype) img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W) if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1)) img_array = np.transpose(img_array, (2, 0, 1))
if "float" in dtype: if np.issubdtype(dtype, np.floating):
img_array /= 255.0 img_array /= 255.0
return img_array return img_array
@ -370,9 +427,9 @@ def create_empty_dataset_info(
def get_episode_data_index( def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)} episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None: if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

View File

@ -0,0 +1,87 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It performs the following:
- Generates per-episodes stats and writes them in `episodes_stats.jsonl`
- Removes the deprecated `stats.json` (by default)
- Updates codebase_version in `info.json`
Usage:
```bash
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
--repo-id=aliberts/koch_tutorial
```
"""
# TODO(rcadene, aliberts): ensure this script works for any other changes for the final v2.1
import argparse
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
def main(
repo_id: str,
test_branch: str | None = None,
delete_old_stats: bool = False,
num_workers: int = 4,
):
dataset = LeRobotDataset(repo_id)
if (dataset.root / EPISODES_STATS_PATH).is_file():
raise FileExistsError("episodes_stats.jsonl already exists.")
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=test_branch, create_card=False, allow_patterns="meta/")
if delete_old_stats:
if (dataset.root / STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if hub_api.file_exists(
STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
):
hub_api.delete_file(
STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--test-branch",
type=str,
default=None,
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
)
parser.add_argument(
"--delete-old-stats",
type=bool,
default=False,
help="Delete the deprecated `stats.json`",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for parallelizing compute",
)
args = parser.parse_args()
main(**vars(args))

View File

@ -0,0 +1,85 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
video_frames = dataset._query_videos(query_timestamps, episode_index)
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
ep_stats = {}
for key, ft in dataset.features.items():
if ft["dtype"] == "video":
# We sample only for videos
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
else:
ep_ft_data = np.array(ep_data[key])
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
if num_workers > 0:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
for ep_idx in range(total_episodes)
}
for future in tqdm(as_completed(futures), total=total_episodes):
future.result()
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
default_rtol_atol: tuple[float] = (5e-6, 0.0),
):
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
for key, ft in dataset.features.items():
# These values might need some fine-tuning
if ft["dtype"] == "video":
# to account for image sub-sampling
rtol, atol = video_rtol_atol
else:
rtol, atol = default_rtol_atol
for stat, val in agg_stats[key].items():
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
)

View File

@ -69,8 +69,8 @@ def decode_video_frames_torchvision(
# set the first and last requested timestamps # set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame # Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = timestamps[0] first_ts = min(timestamps)
last_ts = timestamps[-1] last_ts = max(timestamps)
# access closest key frame of the first requested frame # access closest key frame of the first requested frame
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video) # Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -77,17 +78,29 @@ def create_stats_buffers(
} }
) )
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats: if stats:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated if isinstance(stats[key]["mean"], np.ndarray):
# tensors anywhere (for example, when we use the same stats for normalization and if norm_mode is NormalizationMode.MEAN_STD:
# unnormalization). See the logic here buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
if norm_mode is NormalizationMode.MEAN_STD: elif norm_mode is NormalizationMode.MIN_MAX:
buffer["mean"].data = stats[key]["mean"].clone() buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone() buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX: elif isinstance(stats[key]["mean"], torch.Tensor):
buffer["min"].data = stats[key]["min"].clone() # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
buffer["max"].data = stats[key]["max"].clone() # tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer stats_buffers[key] = buffer
return stats_buffers return stats_buffers
@ -141,6 +154,7 @@ class Normalize(nn.Module):
batch = dict(batch) # shallow copy avoids mutating the input batch batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items(): for key, ft in self.features.items():
if key not in batch: if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
continue continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)

View File

@ -60,8 +60,6 @@ class RecordControlConfig(ControlConfig):
num_episodes: int = 50 num_episodes: int = 50
# Encode frames in the dataset into video # Encode frames in the dataset into video
video: bool = True video: bool = True
# By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
run_compute_stats: bool = True
# Upload dataset to Hugging Face hub. # Upload dataset to Hugging Face hub.
push_to_hub: bool = True push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub. # Upload on private repository on the Hugging Face hub.

View File

@ -301,10 +301,7 @@ def record(
log_say("Stop recording", cfg.play_sounds, blocking=True) log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras) stop_recording(robot, listener, cfg.display_cameras)
if cfg.run_compute_stats: dataset.consolidate()
logging.info("Computing dataset statistics")
dataset.consolidate(cfg.run_compute_stats)
if cfg.push_to_hub: if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private) dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

View File

@ -29,7 +29,7 @@ from tests.fixtures.constants import (
def get_task_index(task_dicts: dict, task: str) -> int: def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts} tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task] return task_to_task_index[task]
@ -142,6 +142,7 @@ def stats_factory():
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(), "mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(), "min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
"count": [10],
} }
else: else:
stats[key] = { stats[key] = {
@ -149,20 +150,38 @@ def stats_factory():
"mean": np.full(shape, 0.5, dtype=dtype).tolist(), "mean": np.full(shape, 0.5, dtype=dtype).tolist(),
"min": np.full(shape, 0, dtype=dtype).tolist(), "min": np.full(shape, 0, dtype=dtype).tolist(),
"std": np.full(shape, 0.25, dtype=dtype).tolist(), "std": np.full(shape, 0.25, dtype=dtype).tolist(),
"count": [10],
} }
return stats return stats
return _create_stats return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_factory(stats_factory):
def _create_episodes_stats(
features: dict[str],
total_episodes: int = 3,
) -> dict:
episodes_stats = {}
for episode_index in range(total_episodes):
episodes_stats[episode_index] = {
"episode_index": episode_index,
"stats": stats_factory(features),
}
return episodes_stats
return _create_episodes_stats
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tasks_factory(): def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int: def _create_tasks(total_tasks: int = 3) -> int:
tasks_list = [] tasks = {}
for i in range(total_tasks): for task_index in range(total_tasks):
task_dict = {"task_index": i, "task": f"Perform action {i}."} task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
tasks_list.append(task_dict) tasks[task_index] = task_dict
return tasks_list return tasks
return _create_tasks return _create_tasks
@ -191,10 +210,10 @@ def episodes_factory(tasks_factory):
# Generate random lengths that sum up to total_length # Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks] tasks_list = [task_dict["task"] for task_dict in tasks.values()]
num_tasks_available = len(tasks_list) num_tasks_available = len(tasks_list)
episodes_list = [] episodes = {}
remaining_tasks = tasks_list.copy() remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes): for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
@ -204,15 +223,13 @@ def episodes_factory(tasks_factory):
for task in episode_tasks: for task in episode_tasks:
remaining_tasks.remove(task) remaining_tasks.remove(task)
episodes_list.append( episodes[ep_idx] = {
{ "episode_index": ep_idx,
"episode_index": ep_idx, "tasks": episode_tasks,
"tasks": episode_tasks, "length": lengths[ep_idx],
"length": lengths[ep_idx], }
}
)
return episodes_list return episodes
return _create_episodes return _create_episodes
@ -236,7 +253,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
frame_index_col = np.array([], dtype=np.int64) frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64) episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64)
for ep_dict in episodes: for ep_dict in episodes.values():
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate( episode_index_col = np.concatenate(
@ -279,6 +296,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
def lerobot_dataset_metadata_factory( def lerobot_dataset_metadata_factory(
info_factory, info_factory,
stats_factory, stats_factory,
episodes_stats_factory,
tasks_factory, tasks_factory,
episodes_factory, episodes_factory,
mock_snapshot_download_factory, mock_snapshot_download_factory,
@ -288,6 +306,7 @@ def lerobot_dataset_metadata_factory(
repo_id: str = DUMMY_REPO_ID, repo_id: str = DUMMY_REPO_ID,
info: dict | None = None, info: dict | None = None,
stats: dict | None = None, stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None, tasks: list[dict] | None = None,
episodes: list[dict] | None = None, episodes: list[dict] | None = None,
local_files_only: bool = False, local_files_only: bool = False,
@ -296,6 +315,10 @@ def lerobot_dataset_metadata_factory(
info = info_factory() info = info_factory()
if not stats: if not stats:
stats = stats_factory(features=info["features"]) stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks: if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes: if not episodes:
@ -306,6 +329,7 @@ def lerobot_dataset_metadata_factory(
mock_snapshot_download = mock_snapshot_download_factory( mock_snapshot_download = mock_snapshot_download_factory(
info=info, info=info,
stats=stats, stats=stats,
episodes_stats=episodes_stats,
tasks=tasks, tasks=tasks,
episodes=episodes, episodes=episodes,
) )
@ -329,6 +353,7 @@ def lerobot_dataset_metadata_factory(
def lerobot_dataset_factory( def lerobot_dataset_factory(
info_factory, info_factory,
stats_factory, stats_factory,
episodes_stats_factory,
tasks_factory, tasks_factory,
episodes_factory, episodes_factory,
hf_dataset_factory, hf_dataset_factory,
@ -344,6 +369,7 @@ def lerobot_dataset_factory(
multi_task: bool = False, multi_task: bool = False,
info: dict | None = None, info: dict | None = None,
stats: dict | None = None, stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None, tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None, episode_dicts: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None,
@ -355,6 +381,8 @@ def lerobot_dataset_factory(
) )
if not stats: if not stats:
stats = stats_factory(features=info["features"]) stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
if not tasks: if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts: if not episode_dicts:
@ -370,6 +398,7 @@ def lerobot_dataset_factory(
mock_snapshot_download = mock_snapshot_download_factory( mock_snapshot_download = mock_snapshot_download_factory(
info=info, info=info,
stats=stats, stats=stats,
episodes_stats=episodes_stats,
tasks=tasks, tasks=tasks,
episodes=episode_dicts, episodes=episode_dicts,
hf_dataset=hf_dataset, hf_dataset=hf_dataset,
@ -379,6 +408,7 @@ def lerobot_dataset_factory(
repo_id=repo_id, repo_id=repo_id,
info=info, info=info,
stats=stats, stats=stats,
episodes_stats=episodes_stats,
tasks=tasks, tasks=tasks,
episodes=episode_dicts, episodes=episode_dicts,
local_files_only=kwargs.get("local_files_only", False), local_files_only=kwargs.get("local_files_only", False),
@ -406,7 +436,7 @@ def empty_lerobot_dataset_factory():
robot: Robot | None = None, robot: Robot | None = None,
robot_type: str | None = None, robot_type: str | None = None,
features: dict | None = None, features: dict | None = None,
): ) -> LeRobotDataset:
return LeRobotDataset.create( return LeRobotDataset.create(
repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
) )

View File

@ -7,7 +7,13 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq import pyarrow.parquet as pq
import pytest import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH from lerobot.common.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -38,6 +44,20 @@ def stats_path(stats_factory):
return _create_stats_json_file return _create_stats_json_file
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes_stats.values())
return fpath
return _create_episodes_stats_jsonl_file
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tasks_path(tasks_factory): def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
@ -46,7 +66,7 @@ def tasks_path(tasks_factory):
fpath = dir / TASKS_PATH fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True) fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer: with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks) writer.write_all(tasks.values())
return fpath return fpath
return _create_tasks_jsonl_file return _create_tasks_jsonl_file
@ -60,7 +80,7 @@ def episode_path(episodes_factory):
fpath = dir / EPISODES_PATH fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True) fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer: with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes) writer.write_all(episodes.values())
return fpath return fpath
return _create_episodes_jsonl_file return _create_episodes_jsonl_file

21
tests/fixtures/hub.py vendored
View File

@ -4,7 +4,13 @@ import datasets
import pytest import pytest
from huggingface_hub.utils import filter_repo_objects from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH from lerobot.common.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR from tests.fixtures.constants import LEROBOT_TEST_DIR
@ -14,6 +20,8 @@ def mock_snapshot_download_factory(
info_path, info_path,
stats_factory, stats_factory,
stats_path, stats_path,
episodes_stats_factory,
episodes_stats_path,
tasks_factory, tasks_factory,
tasks_path, tasks_path,
episodes_factory, episodes_factory,
@ -29,6 +37,7 @@ def mock_snapshot_download_factory(
def _mock_snapshot_download_func( def _mock_snapshot_download_func(
info: dict | None = None, info: dict | None = None,
stats: dict | None = None, stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None, tasks: list[dict] | None = None,
episodes: list[dict] | None = None, episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None,
@ -37,6 +46,10 @@ def mock_snapshot_download_factory(
info = info_factory() info = info_factory()
if not stats: if not stats:
stats = stats_factory(features=info["features"]) stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks: if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes: if not episodes:
@ -67,11 +80,11 @@ def mock_snapshot_download_factory(
# List all possible files # List all possible files
all_files = [] all_files = []
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH] meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files) all_files.extend(meta_files)
data_files = [] data_files = []
for episode_dict in episodes: for episode_dict in episodes.values():
ep_idx = episode_dict["episode_index"] ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"] ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
@ -92,6 +105,8 @@ def mock_snapshot_download_factory(
_ = info_path(local_dir, info) _ = info_path(local_dir, info)
elif rel_path == STATS_PATH: elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats) _ = stats_path(local_dir, stats)
elif rel_path == EPISODES_STATS_PATH:
_ = episodes_stats_path(local_dir, episodes_stats)
elif rel_path == TASKS_PATH: elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks) _ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH: elif rel_path == EPISODES_PATH:

View File

@ -182,7 +182,7 @@ def test_camera(request, camera_type, mock):
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) @pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera @require_camera
def test_save_images_from_cameras(tmpdir, request, camera_type, mock): def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
# TODO(rcadene): refactor # TODO(rcadene): refactor
if camera_type == "opencv": if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
@ -190,4 +190,4 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
# Small `record_time_s` to speedup unit tests # Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock) save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)

311
tests/test_compute_stats.py Normal file
View File

@ -0,0 +1,311 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import numpy as np
import pytest
from lerobot.common.datasets.compute_stats import (
_assert_type_and_shape,
aggregate_feature_stats,
aggregate_stats,
compute_episode_stats,
estimate_num_samples,
get_feature_stats,
sample_images,
sample_indices,
)
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
@pytest.fixture
def sample_array():
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def test_estimate_num_samples():
assert estimate_num_samples(1) == 1
assert estimate_num_samples(10) == 10
assert estimate_num_samples(100) == 100
assert estimate_num_samples(200) == 100
assert estimate_num_samples(1000) == 177
assert estimate_num_samples(2000) == 299
assert estimate_num_samples(5000) == 594
assert estimate_num_samples(10_000) == 1000
assert estimate_num_samples(20_000) == 1681
assert estimate_num_samples(50_000) == 3343
assert estimate_num_samples(500_000) == 10_000
def test_sample_indices():
indices = sample_indices(10)
assert len(indices) > 0
assert indices[0] == 0
assert indices[-1] == 9
assert len(indices) == estimate_num_samples(10)
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
assert isinstance(images, np.ndarray)
assert images.shape[1:] == (3, 32, 32)
assert images.dtype == np.uint8
assert len(images) == estimate_num_samples(100)
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
np.testing.assert_equal(stats["count"], np.array([100]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = {
"min": np.array([[1, 2, 3]]),
"max": np.array([[7, 8, 9]]),
"mean": np.array([[4.0, 5.0, 6.0]]),
"std": np.array([[2.44948974, 2.44948974, 2.44948974]]),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_axis_1(sample_array):
expected = {
"min": np.array([1, 4, 7]),
"max": np.array([3, 6, 9]),
"mean": np.array([2.0, 5.0, 8.0]),
"std": np.array([0.81649658, 0.81649658, 0.81649658]),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_no_axis(sample_array):
expected = {
"min": np.array(1),
"max": np.array(9),
"mean": np.array(5.0),
"std": np.array(2.5819889),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=None, keepdims=False)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_empty_array():
array = np.array([])
with pytest.raises(ValueError):
get_feature_stats(array, axis=(0,), keepdims=True)
def test_get_feature_stats_single_value():
array = np.array([[1337]])
result = get_feature_stats(array, axis=None, keepdims=True)
np.testing.assert_equal(result["min"], np.array(1337))
np.testing.assert_equal(result["max"], np.array(1337))
np.testing.assert_equal(result["mean"], np.array(1337.0))
np.testing.assert_equal(result["std"], np.array(0.0))
np.testing.assert_equal(result["count"], np.array([1]))
def test_compute_episode_stats():
episode_data = {
"observation.image": [f"image_{i}.jpg" for i in range(100)],
"observation.state": np.random.rand(100, 10),
}
features = {
"observation.image": {"dtype": "image"},
"observation.state": {"dtype": "numeric"},
}
with patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
):
stats = compute_episode_stats(episode_data, features)
assert "observation.image" in stats and "observation.state" in stats
assert stats["observation.image"]["count"].item() == 100
assert stats["observation.state"]["count"].item() == 100
assert stats["observation.image"]["mean"].shape == (3, 1, 1)
def test_assert_type_and_shape_valid():
valid_stats = [
{
"feature1": {
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
}
}
]
_assert_type_and_shape(valid_stats)
def test_assert_type_and_shape_invalid_type():
invalid_stats = [
{
"feature1": {
"min": [1.0], # Not a numpy array
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
}
}
]
with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
_assert_type_and_shape(invalid_stats)
def test_assert_type_and_shape_invalid_shape():
invalid_stats = [
{
"feature1": {
"count": np.array([1, 2]), # Wrong shape
}
}
]
with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
_assert_type_and_shape(invalid_stats)
def test_aggregate_feature_stats():
stats_ft_list = [
{
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
},
{
"min": np.array([2.0]),
"max": np.array([12.0]),
"mean": np.array([6.0]),
"std": np.array([2.5]),
"count": np.array([1]),
},
]
result = aggregate_feature_stats(stats_ft_list)
np.testing.assert_allclose(result["min"], np.array([1.0]))
np.testing.assert_allclose(result["max"], np.array([12.0]))
np.testing.assert_allclose(result["mean"], np.array([5.5]))
np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6)
np.testing.assert_allclose(result["count"], np.array([2]))
def test_aggregate_stats():
all_stats = [
{
"observation.image": {
"min": [1, 2, 3],
"max": [10, 20, 30],
"mean": [5.5, 10.5, 15.5],
"std": [2.87, 5.87, 8.87],
"count": 10,
},
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
},
{
"observation.image": {
"min": [2, 1, 0],
"max": [15, 10, 5],
"mean": [8.5, 5.5, 2.5],
"std": [3.42, 2.42, 1.42],
"count": 15,
},
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
},
]
expected_agg_stats = {
"observation.image": {
"min": [1, 1, 0],
"max": [15, 20, 30],
"mean": [7.3, 7.5, 7.7],
"std": [3.5317, 4.8267, 8.5581],
"count": 25,
},
"observation.state": {
"min": 1,
"max": 15,
"mean": 7.3,
"std": 3.5317,
"count": 25,
},
"extra_key_0": {
"min": 5,
"max": 25,
"mean": 15.0,
"std": 6.0,
"count": 6,
},
"extra_key_1": {
"min": 0,
"max": 20,
"mean": 10.0,
"std": 5.0,
"count": 5,
},
}
# cast to numpy
for ep_stats in all_stats:
for fkey, stats in ep_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
# cast to numpy
for fkey, stats in expected_agg_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
results = aggregate_stats(all_stats)
for fkey in expected_agg_stats:
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose(
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])

View File

@ -24,7 +24,6 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
""" """
import multiprocessing import multiprocessing
from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -45,7 +44,7 @@ from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot @require_robot
def test_teleoperate(tmpdir, request, robot_type, mock): def test_teleoperate(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha": if mock and robot_type != "aloha":
@ -53,8 +52,7 @@ def test_teleoperate(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
tmpdir = Path(tmpdir) calibration_dir = tmp_path / robot_type
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
else: else:
@ -70,15 +68,14 @@ def test_teleoperate(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot @require_robot
def test_calibrate(tmpdir, request, robot_type, mock): def test_calibrate(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock: if mock:
request.getfixturevalue("patch_builtins_input") request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
tmpdir = Path(tmpdir) calibration_dir = tmp_path / robot_type
calibration_dir = tmpdir / robot_type
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
robot = make_robot(**robot_kwargs) robot = make_robot(**robot_kwargs)
@ -89,7 +86,7 @@ def test_calibrate(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot @require_robot
def test_record_without_cameras(tmpdir, request, robot_type, mock): def test_record_without_cameras(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
# Avoid using cameras # Avoid using cameras
@ -100,7 +97,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
calibration_dir = Path(tmpdir) / robot_type calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
else: else:
@ -108,7 +105,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
pass pass
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id root = tmp_path / "data" / repo_id
single_task = "Do something." single_task = "Do something."
robot = make_robot(**robot_kwargs) robot = make_robot(**robot_kwargs)
@ -121,7 +118,6 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
episode_time_s=1, episode_time_s=1,
reset_time_s=0.1, reset_time_s=0.1,
num_episodes=2, num_episodes=2,
run_compute_stats=False,
push_to_hub=False, push_to_hub=False,
video=False, video=False,
play_sounds=False, play_sounds=False,
@ -131,8 +127,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot @require_robot
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
tmpdir = Path(tmpdir)
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha": if mock and robot_type != "aloha":
@ -140,7 +135,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
else: else:
@ -148,7 +143,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
pass pass
repo_id = "lerobot_test/debug" repo_id = "lerobot_test/debug"
root = tmpdir / "data" / repo_id root = tmp_path / "data" / repo_id
single_task = "Do something." single_task = "Do something."
robot = make_robot(**robot_kwargs) robot = make_robot(**robot_kwargs)
@ -180,7 +175,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
policy_cfg = ACTConfig() policy_cfg = ACTConfig()
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE) policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
out_dir = tmpdir / "logger" out_dir = tmp_path / "logger"
pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model" pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
policy.save_pretrained(pretrained_policy_path) policy.save_pretrained(pretrained_policy_path)
@ -207,7 +202,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
num_image_writer_processes = 0 num_image_writer_processes = 0
eval_repo_id = "lerobot/eval_debug" eval_repo_id = "lerobot/eval_debug"
eval_root = tmpdir / "data" / eval_repo_id eval_root = tmp_path / "data" / eval_repo_id
rec_eval_cfg = RecordControlConfig( rec_eval_cfg = RecordControlConfig(
repo_id=eval_repo_id, repo_id=eval_repo_id,
@ -218,7 +213,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
episode_time_s=1, episode_time_s=1,
reset_time_s=0.1, reset_time_s=0.1,
num_episodes=2, num_episodes=2,
run_compute_stats=False,
push_to_hub=False, push_to_hub=False,
video=False, video=False,
display_cameras=False, display_cameras=False,
@ -240,7 +234,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot @require_robot
def test_resume_record(tmpdir, request, robot_type, mock): def test_resume_record(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha": if mock and robot_type != "aloha":
@ -248,7 +242,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
else: else:
@ -258,7 +252,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):
robot = make_robot(**robot_kwargs) robot = make_robot(**robot_kwargs)
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id root = tmp_path / "data" / repo_id
single_task = "Do something." single_task = "Do something."
rec_cfg = RecordControlConfig( rec_cfg = RecordControlConfig(
@ -272,7 +266,6 @@ def test_resume_record(tmpdir, request, robot_type, mock):
video=False, video=False,
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
run_compute_stats=False,
local_files_only=True, local_files_only=True,
num_episodes=1, num_episodes=1,
) )
@ -291,7 +284,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot @require_robot
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha": if mock and robot_type != "aloha":
@ -299,7 +292,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
else: else:
@ -316,7 +309,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
mock_listener.return_value = (None, mock_events) mock_listener.return_value = (None, mock_events)
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id root = tmp_path / "data" / repo_id
single_task = "Do something." single_task = "Do something."
rec_cfg = RecordControlConfig( rec_cfg = RecordControlConfig(
@ -331,7 +324,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
video=False, video=False,
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
run_compute_stats=False,
) )
dataset = record(robot, rec_cfg) dataset = record(robot, rec_cfg)
@ -342,7 +334,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot @require_robot
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock: if mock:
@ -350,7 +342,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
else: else:
@ -367,7 +359,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
mock_listener.return_value = (None, mock_events) mock_listener.return_value = (None, mock_events)
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id root = tmp_path / "data" / repo_id
single_task = "Do something." single_task = "Do something."
rec_cfg = RecordControlConfig( rec_cfg = RecordControlConfig(
@ -382,7 +374,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
video=False, video=False,
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
run_compute_stats=False,
) )
dataset = record(robot, rec_cfg) dataset = record(robot, rec_cfg)
@ -395,7 +386,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)] "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
) )
@require_robot @require_robot
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes): def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
robot_kwargs = {"robot_type": robot_type, "mock": mock} robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock: if mock:
@ -403,7 +394,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder # and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir
else: else:
@ -420,7 +411,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
mock_listener.return_value = (None, mock_events) mock_listener.return_value = (None, mock_events)
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id root = tmp_path / "data" / repo_id
single_task = "Do something." single_task = "Do something."
rec_cfg = RecordControlConfig( rec_cfg = RecordControlConfig(
@ -436,7 +427,6 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
video=False, video=False,
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
run_compute_stats=False,
num_image_writer_processes=num_image_writer_processes, num_image_writer_processes=num_image_writer_processes,
) )

View File

@ -20,21 +20,14 @@ from copy import deepcopy
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
import einops
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from datasets import Dataset
from huggingface_hub import HfApi from huggingface_hub import HfApi
from PIL import Image from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
import lerobot import lerobot
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
compute_stats,
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.image_writer import image_array_to_pil_image from lerobot.common.datasets.image_writer import image_array_to_pil_image
from lerobot.common.datasets.lerobot_dataset import ( from lerobot.common.datasets.lerobot_dataset import (
@ -44,13 +37,11 @@ from lerobot.common.datasets.lerobot_dataset import (
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
create_branch, create_branch,
flatten_dict, flatten_dict,
hf_transform_to_torch,
unflatten_dict, unflatten_dict,
) )
from lerobot.common.envs.factory import make_env_config from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
@ -196,12 +187,12 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
def test_add_frame(tmp_path, empty_lerobot_dataset_factory): def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "dummy"}) dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert len(dataset) == 1 assert len(dataset) == 1
assert dataset[0]["task"] == "dummy" assert dataset[0]["task"] == "Dummy task"
assert dataset[0]["task_index"] == 0 assert dataset[0]["task_index"] == 0
assert dataset[0]["state"].ndim == 0 assert dataset[0]["state"].ndim == 0
@ -209,9 +200,9 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "dummy"}) dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["state"].shape == torch.Size([2]) assert dataset[0]["state"].shape == torch.Size([2])
@ -219,9 +210,9 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "dummy"}) dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["state"].shape == torch.Size([2, 4]) assert dataset[0]["state"].shape == torch.Size([2, 4])
@ -229,9 +220,9 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "dummy"}) dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@ -239,9 +230,9 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "dummy"}) dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@ -249,9 +240,9 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "dummy"}) dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@ -261,7 +252,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["state"].ndim == 0 assert dataset[0]["state"].ndim == 0
@ -271,7 +262,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["caption"] == "Dummy caption" assert dataset[0]["caption"] == "Dummy caption"
@ -307,7 +298,7 @@ def test_add_frame_image(image_dataset):
dataset = image_dataset dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -316,7 +307,7 @@ def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -326,7 +317,7 @@ def test_add_frame_image_uint8(image_dataset):
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"}) dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -336,7 +327,7 @@ def test_add_frame_image_pil(image_dataset):
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.save_episode(encode_videos=False) dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False) dataset.consolidate()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -463,67 +454,6 @@ def test_multidataset_frames():
assert torch.equal(sub_dataset_item[k], dataset_item[k]) assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
# TODO(rcadene, aliberts): remove dataset download
dataset = LeRobotDataset("lerobot/xarm_lift_medium", episodes=[0])
# reduce size of dataset sample on which stats compute is tested to 10 frames
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=len(dataset),
shuffle=False,
)
full_batch = next(iter(dataloader))
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
full_batch[k] = full_batch[k].float()
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
# test computed stats match expected stats
for k in stats_patterns:
assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.meta.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
# TODO(aliberts): Move to more appropriate location # TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict(): def test_flatten_unflatten_dict():
d = { d = {
@ -627,35 +557,6 @@ def test_backward_compatibility(repo_id):
# load_and_compare(i - 1) # load_and_compare(i - 1)
@pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):
data_a = torch.rand(30, dtype=torch.float32)
data_b = torch.rand(20, dtype=torch.float32)
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
@pytest.mark.skip("Requires internet access") @pytest.mark.skip("Requires internet access")
def test_create_branch(): def test_create_branch():
api = HfApi() api = HfApi()

View File

@ -1,370 +0,0 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
we skip them for now in our CI.
Example to run backward compatiblity tests locally:
```
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
from pathlib import Path
import numpy as np
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg
def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/img",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "cup_in_the_wild.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/camera0_rgb",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_end_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_start_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/robot0_eef_rot_axis_angle",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_gripper_width",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_xarm(raw_dir, num_frames=4):
import pickle
dataset_dict = {
"observations": {
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
"state": np.random.randn(num_frames, 4),
},
"actions": np.random.randn(num_frames, 3),
"rewards": np.random.randn(num_frames),
"masks": np.random.randn(num_frames),
"dones": np.array([False, True, True, True]),
}
raw_dir.mkdir(parents=True, exist_ok=True)
pkl_path = raw_dir / "buffer.pkl"
with open(pkl_path, "wb") as f:
pickle.dump(dataset_dict, f)
def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
import h5py
for ep_idx in range(num_episodes):
raw_dir.mkdir(parents=True, exist_ok=True)
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
with h5py.File(str(path_h5), "w") as f:
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset(
"observations/images/top",
data=np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
),
)
def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
from datetime import datetime, timedelta, timezone
import pandas
def write_parquet(key, timestamps, values):
data = {
"timestamp_utc": timestamps,
key: values,
}
df = pandas.DataFrame(data)
raw_dir.mkdir(parents=True, exist_ok=True)
df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")
episode_indices = [None, None, -1, None, None, -1, None, None, -1]
episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]
cam_key = "observation.images.cam_high"
timestamps = []
actions = []
states = []
frames = []
# `+ num_episodes`` for buffer frames associated to episode_index=-1
for i, frame_idx in enumerate(frame_indices):
t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
action = np.random.randn(21).tolist()
state = np.random.randn(21).tolist()
ep_idx = episode_indices_mapping[i]
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
timestamps.append(t_utc)
actions.append(action)
states.append(state)
frames.append(frame)
write_parquet(cam_key, timestamps, frames)
write_parquet("observation.state", timestamps, states)
write_parquet("action", timestamps, actions)
write_parquet("episode_index", timestamps, episode_indices)
# write fake mp4 file for each episode
for ep_idx in range(num_episodes):
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
tmp_imgs_dir = raw_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
video_path = raw_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps, vcodec="libx264")
def _mock_download_raw(raw_dir, repo_id):
if "wrist_gripper" in repo_id:
_mock_download_raw_dora(raw_dir)
elif "aloha" in repo_id:
_mock_download_raw_aloha(raw_dir)
elif "pusht" in repo_id:
_mock_download_raw_pusht(raw_dir)
elif "xarm" in repo_id:
_mock_download_raw_xarm(raw_dir)
elif "umi" in repo_id:
_mock_download_raw_umi(raw_dir)
else:
raise ValueError(repo_id)
@pytest.mark.skip("push_dataset_to_hub is deprecated")
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
with pytest.raises(ValueError):
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
@pytest.mark.skip("push_dataset_to_hub is deprecated")
def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
tmpdir = Path(tmpdir)
out_dir = tmpdir / "out"
raw_dir = tmpdir / "raw"
# mkdir to skip download
raw_dir.mkdir(parents=True, exist_ok=True)
with pytest.raises(ValueError):
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format="some_format",
repo_id="user/dataset",
local_dir=out_dir,
force_override=False,
)
@pytest.mark.skip("push_dataset_to_hub is deprecated")
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id, make_test_data",
[
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", False),
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", True),
(None, "xarm_pkl", "lerobot/xarm_lift_medium", False),
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False),
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False),
(None, "dora_parquet", "cadene/wrist_gripper", False),
],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
num_episodes = 3
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{repo_id}_raw"
_mock_download_raw(raw_dir, repo_id)
local_dir = tmpdir / repo_id
lerobot_dataset = push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
tests_data_dir=tmpdir / "tests/data" if make_test_data else None,
encoding={"vcodec": "libx264"},
)
# minimal generic tests on the local directory containing LeRobotDataset
assert (local_dir / "meta_data" / "info.json").exists()
assert (local_dir / "meta_data" / "stats.safetensors").exists()
assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
for i in range(num_episodes):
for cam_key in lerobot_dataset.camera_keys:
assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
assert (local_dir / "train" / "dataset_info.json").exists()
assert (local_dir / "train" / "state.json").exists()
assert len(list((local_dir / "train").glob("*.arrow"))) > 0
# minimal generic tests on the item
item = lerobot_dataset[0]
assert "index" in item
assert "episode_index" in item
assert "timestamp" in item
for cam_key in lerobot_dataset.camera_keys:
assert cam_key in item
if make_test_data:
# Check that only the first episode is selected.
test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data")
num_frames = sum(
i == lerobot_dataset.hf_dataset["episode_index"][0]
for i in lerobot_dataset.hf_dataset["episode_index"]
).item()
assert (
test_dataset.hf_dataset["episode_index"]
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
)
for k in ["from", "to"]:
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
@pytest.mark.skip("push_dataset_to_hub is deprecated")
@pytest.mark.parametrize(
"raw_format, repo_id",
[
# TODO(rcadene): add raw dataset test artifacts
("pusht_zarr", "lerobot/pusht"),
("xarm_pkl", "lerobot/xarm_lift_medium"),
("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
("umi_zarr", "lerobot/umi_cup_in_the_wild"),
("dora_parquet", "cadene/wrist_gripper"),
],
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{dataset_id}_raw"
local_dir = tmpdir / repo_id
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
episodes=[0],
)
ds_actual = LeRobotDataset(repo_id, root=tmpdir)
ds_reference = LeRobotDataset(repo_id)
assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)
def check_same_items(item1, item2):
assert item1.keys() == item2.keys(), "Keys mismatch"
for key in item1:
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
else:
assert item1[key] == item2[key], f"Mismatch found in key: {key}"
for i in range(len(ds_reference.hf_dataset)):
item_reference = ds_reference.hf_dataset[i]
item_actual = ds_actual.hf_dataset[i]
check_same_items(item_reference, item_actual)

View File

@ -23,8 +23,6 @@ pytest -sx 'tests/test_robots.py::test_robot[aloha-True]'
``` ```
""" """
from pathlib import Path
import pytest import pytest
import torch import torch
@ -35,7 +33,7 @@ from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot @require_robot
def test_robot(tmpdir, request, robot_type, mock): def test_robot(tmp_path, request, robot_type, mock):
# TODO(rcadene): measure fps in nightly? # TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs # TODO(rcadene): test logs
# TODO(rcadene): add compatibility with other robots # TODO(rcadene): add compatibility with other robots
@ -50,8 +48,7 @@ def test_robot(tmpdir, request, robot_type, mock):
request.getfixturevalue("patch_builtins_input") request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration # Create an empty calibration directory to trigger manual calibration
tmpdir = Path(tmpdir) calibration_dir = tmp_path / robot_type
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir) mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir robot_kwargs["calibration_dir"] = calibration_dir