Merge remote-tracking branch 'origin/user/aliberts/2024_09_25_reshape_dataset' into 1-rework-support-for-reachy2

This commit is contained in:
Simon Alibert 2024-11-26 10:51:06 +01:00
commit c79d7ed146
67 changed files with 6043 additions and 2033 deletions

View File

@ -103,7 +103,7 @@ jobs:
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
&& rm -rf tests/outputs outputs
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
end-to-end:
name: End-to-end
runs-on: ubuntu-latest

View File

@ -266,7 +266,7 @@ def benchmark_encoding_decoding(
)
ep_num_images = dataset.episode_data_index["to"][0].item()
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
images_size_bytes = get_directory_size(imgs_dir)

View File

@ -3,78 +3,120 @@ This script demonstrates the use of `LeRobotDataset` class for handling and proc
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script:
- Loading a dataset and accessing its properties.
- Filtering data by episode number.
- Converting tensor data for visualization.
- Saving video files from dataset frames.
- Viewing a dataset's metadata and exploring its properties.
- Loading an existing dataset from the hub or a subset of it.
- Accessing frames by episode number.
- Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
The script ends with examples of how to batch process data using PyTorch's DataLoader.
"""
from pathlib import Path
from pprint import pprint
import imageio
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
pprint(lerobot.available_datasets)
# Let's take one for this example
repo_id = "lerobot/pusht"
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
pprint(repo_ids)
# You can easily load a dataset from a Hugging Face repository
# Or simply explore them in your web browser directly at:
# https://huggingface.co/datasets?other=LeRobot
# Let's take this one for this example
repo_id = "lerobot/aloha_mobile_cabinet"
# We can have a look and fetch its metadata to know more about it:
ds_meta = LeRobotDatasetMetadata(repo_id)
# By instantiating just this class, you can quickly access useful information about the content and the
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
print("Tasks:")
print(ds_meta.tasks)
print("Features:")
pprint(ds_meta.features)
# You can also get a short summary by simply printing the object:
print(ds_meta)
# You can then load the actual dataset from the hub.
# Either load any subset of episodes:
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
# And see how many frames you have:
print(f"Selected episodes: {dataset.episodes}")
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# Or simply load the entire dataset:
dataset = LeRobotDataset(repo_id)
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets/index for more information).
print(dataset)
# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
print(dataset.hf_dataset)
# And provides additional utilities for robotics and compatibility with Pytorch
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
# Access frame indexes associated to first episode
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
# frame indices associated to the first episode:
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
# them, we convert to uint8 in range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# The objects returned by the dataset are all torch.Tensors
print(type(frames[0]))
print(frames[0].shape)
# Finally, we save the frames to a mp4 video for visualization.
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
# We can compare this shape with the information available for that feature
pprint(dataset.features[camera_key])
# In particular:
print(dataset.features[camera_key]["shape"])
# The shape is in (h, w, c) which is a more universal format.
# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
# differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
camera_key: [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
# timestamp, you still get a valid timestamp.
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
@ -84,8 +126,9 @@ dataloader = torch.utils.data.DataLoader(
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
print(f"{batch['observation.state'].shape=}") # (32,8,c)
print(f"{batch['action'].shape=}") # (32,64,c)
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break

View File

@ -40,7 +40,7 @@ dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
policy.to(device)

View File

@ -1,7 +1,7 @@
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
transforms are applied to the observation images before they are returned in the dataset's __getitem__.
"""
from pathlib import Path
@ -20,7 +20,7 @@ dataset = LeRobotDataset(dataset_repo_id)
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
# Define the transformations
@ -36,7 +36,7 @@ transforms = v2.Compose(
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")

View File

@ -14,7 +14,7 @@ from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
device = torch.device("cuda")
@ -41,26 +41,20 @@ delta_timestamps = {
}
# Load the last 10% of episodes of the dataset as a validation set.
# - Load full dataset
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
# It utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
train_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
# - Load dataset metadata
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
# - Calculate train and val episodes
total_episodes = dataset_metadata.total_episodes
episodes = list(range(dataset_metadata.total_episodes))
num_train_episodes = math.floor(total_episodes * 90 / 100)
train_episodes = episodes[:num_train_episodes]
val_episodes = episodes[num_train_episodes:]
print(f"Number of episodes in full dataset: {total_episodes}")
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
# - Load train an val datasets
train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")

View File

@ -0,0 +1,222 @@
import shutil
from pathlib import Path
import numpy as np
import torch
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
PUSHT_FEATURES = {
"observation.state": {
"dtype": "float32",
"shape": (2,),
"names": {
"axes": ["x", "y"],
},
},
"action": {
"dtype": "float32",
"shape": (2,),
"names": {
"axes": ["x", "y"],
},
},
"next.reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"next.success": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"observation.environment_state": {
"dtype": "float32",
"shape": (16,),
"names": [
"keypoints",
],
},
"observation.image": {
"dtype": None,
"shape": (3, 96, 96),
"names": [
"channel",
"height",
"width",
],
},
}
def build_features(mode: str) -> dict:
features = PUSHT_FEATURES
if mode == "keypoints":
features.pop("observation.image")
else:
features.pop("observation.environment_state")
features["observation.image"]["dtype"] = mode
return features
def load_raw_dataset(zarr_path: Path, load_images: bool = True):
try:
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
return zarr_data
def calculate_coverage(zarr_data):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
block_pos = zarr_data["state"][:, 2:4]
block_angle = zarr_data["state"][:, 4]
num_frames = len(block_pos)
coverage = np.zeros((num_frames,))
# 8 keypoints with 2 coords each
keypoints = np.zeros((num_frames, 16))
# Set x, y, theta (in radians)
goal_pos_angle = np.array([256, 256, np.pi / 4])
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage[i] = intersection_area / goal_area
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
return coverage, keypoints
def calculate_success(coverage: float, success_threshold: float):
return coverage > success_threshold
def calculate_reward(coverage: float, success_threshold: float):
return np.clip(coverage / success_threshold, 0, 1)
def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
if mode not in ["video", "image", "keypoints"]:
raise ValueError(mode)
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if not raw_dir.exists():
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
env_state = zarr_data["state"][:]
agent_pos = env_state[:, :2]
action = zarr_data["action"][:]
image = zarr_data["img"] # (b, h, w, c)
episode_data_index = {
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
"to": zarr_data.meta["episode_ends"],
}
# Calculate success and reward based on the overlapping area
# of the T-object and the T-area.
coverage, keypoints = calculate_coverage(zarr_data)
success = calculate_success(coverage, success_threshold=0.95)
reward = calculate_reward(coverage, success_threshold=0.95)
features = build_features(mode)
dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=10,
robot_type="2d pointer",
features=features,
image_writer_threads=4,
)
episodes = range(len(episode_data_index["from"]))
for ep_idx in episodes:
from_idx = episode_data_index["from"][ep_idx]
to_idx = episode_data_index["to"][ep_idx]
num_frames = to_idx - from_idx
for frame_idx in range(num_frames):
i = from_idx + frame_idx
frame = {
"action": torch.from_numpy(action[i]),
# Shift reward and success by +1 until the last item of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
}
frame["observation.state"] = torch.from_numpy(agent_pos[i])
if mode == "keypoints":
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
else:
frame["observation.image"] = torch.from_numpy(image[i])
dataset.add_frame(frame)
dataset.save_episode(task=PUSHT_TASK)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
repo_id = "lerobot/pusht"
modes = ["video", "image", "keypoints"]
# Uncomment if you want to try with a specific mode
# modes = ["video"]
# modes = ["image"]
# modes = ["keypoints"]
raw_dir = Path("data/lerobot-raw/pusht_raw")
for mode in modes:
if mode in ["image", "keypoints"]:
repo_id += f"_{mode}"
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
main(raw_dir, repo_id=repo_id, mode=mode)
# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# breakpoint()

View File

@ -181,8 +181,8 @@ available_real_world_datasets = [
"lerobot/usc_cloth_sim",
]
available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
available_datasets = sorted(
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
)
# lists all available policies from `lerobot/common/policies`

View File

@ -0,0 +1,27 @@
---
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
{{ card_data }}
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## Dataset Description
{{ dataset_description | default("", true) }}
- **Homepage:** {{ url | default("[More Information Needed]", true)}}
- **Paper:** {{ paper | default("[More Information Needed]", true)}}
- **License:** {{ license | default("[More Information Needed]", true)}}
## Dataset Structure
{{ dataset_structure | default("[More Information Needed]", true)}}
## Citation
**BibTeX:**
```bibtex
{{ citation_bibtex | default("[More Information Needed]", true)}}
```

View File

@ -19,9 +19,6 @@ from math import ceil
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset, num_workers=0):
@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0):
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in dataset.features.items():
# NOTE: skip language_instruction embedding in stats computation
if key == "language_instruction":
continue
for key in dataset.features:
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, (VideoFrame, Image)):
# if isinstance(feats_type, (VideoFrame, Image)):
if key in dataset.meta.camera_keys:
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
raise ValueError(f"{key}, {batch[key].shape}")
return stats_patterns
@ -175,39 +170,45 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
data_keys.update(dataset.meta.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
torch.stack(
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
dim=0,
),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
if data_key in d.meta.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
(
d.meta.stats[data_key]["std"] ** 2
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
)
* (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
if data_key in d.meta.stats
)
)
return stats

View File

@ -91,9 +91,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
)
if isinstance(cfg.dataset_repo_id, str):
# TODO (aliberts): add 'episodes' arg from config after removing hydra
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
@ -101,7 +101,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
@ -112,6 +111,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset

View File

@ -0,0 +1,160 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import queue
import threading
from pathlib import Path
import numpy as np
import PIL.Image
import torch
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
dataset = kwargs.get("dataset", None)
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
print("Waiting for image writer to terminate...")
image_writer.stop()
raise e
return wrapper
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
if image_array.dtype != np.uint8:
# Assume the image is in [0, 1] range for floating-point data
image_array = np.clip(image_array, 0, 1)
image_array = (image_array * 255).astype(np.uint8)
return PIL.Image.fromarray(image_array)
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
if isinstance(image, np.ndarray):
img = image_array_to_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath)
except Exception as e:
print(f"Error writing image {fpath}: {e}")
def worker_thread_loop(queue: queue.Queue):
while True:
item = queue.get()
if item is None:
queue.task_done()
break
image_array, fpath = item
write_image(image_array, fpath)
queue.task_done()
def worker_process(queue: queue.Queue, num_threads: int):
threads = []
for _ in range(num_threads):
t = threading.Thread(target=worker_thread_loop, args=(queue,))
t.daemon = True
t.start()
threads.append(t)
for t in threads:
t.join()
class AsyncImageWriter:
"""
This class abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it creates a threads pool of size `num_threads`.
When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
def __init__(self, num_processes: int = 0, num_threads: int = 1):
self.num_processes = num_processes
self.num_threads = num_threads
self.queue = None
self.threads = []
self.processes = []
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError("Number of threads and processes must be greater than zero.")
if self.num_processes == 0:
# Use threading
self.queue = queue.Queue()
for _ in range(self.num_threads):
t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
t.daemon = True
t.start()
self.threads.append(t)
else:
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p.daemon = True
p.start()
self.processes.append(p)
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()
self.queue.put((image, fpath))
def wait_until_done(self):
self.queue.join()
def stop(self):
if self._stopped:
return
if self.num_processes == 0:
for _ in self.threads:
self.queue.put(None)
for t in self.threads:
t.join()
else:
num_nones = self.num_processes * self.num_threads
for _ in range(num_nones):
self.queue.put(None)
for p in self.processes:
p.join()
if p.is_alive():
p.terminate()
self.queue.close()
self.queue.join_thread()
self._stopped = True

File diff suppressed because it is too large Load Diff

View File

@ -187,7 +187,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
# Shift the incoming indices if necessary.
if self.num_samples > 0:
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
@ -227,11 +227,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
)
@property
def num_samples(self) -> int:
def num_frames(self) -> int:
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
def __len__(self):
return self.num_samples
return self.num_frames
def _item_to_tensors(self, item: dict) -> dict:
item_ = {}

View File

@ -1,468 +0,0 @@
"""Functions to create an empty dataset, and populate it with frames."""
# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset
import concurrent
import json
import logging
import multiprocessing
import shutil
from pathlib import Path
import torch
import tqdm
from PIL import Image
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.utils.utils import log_say
from lerobot.scripts.push_dataset_to_hub import (
push_dataset_card_to_hub,
push_meta_data_to_hub,
push_videos_to_hub,
save_meta_data,
)
########################################################################################
# Asynchrounous saving of images on disk
########################################################################################
def safe_stop_image_writer(func):
# TODO(aliberts): Allow to pass custom exceptions
# (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
image_writer = kwargs.get("dataset", {}).get("image_writer")
if image_writer is not None:
print("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)
raise e
return wrapper
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
img = Image.fromarray(img_tensor.numpy())
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def loop_to_save_images_in_threads(image_queue, num_threads):
if num_threads < 1:
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = image_queue.get()
# As usually done, exit loop when receiving None to stop the worker
if frame_data is None:
break
image, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures)
progress_bar.update(len(futures))
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
if num_processes < 1:
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
if num_threads_per_process < 1:
raise NotImplementedError(
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
)
processes = []
for _ in range(num_processes):
process = multiprocessing.Process(
target=loop_to_save_images_in_threads,
args=(image_queue, num_threads_per_process),
)
process.start()
processes.append(process)
return processes
def stop_processes(processes, queue, timeout):
# Send None to each process to signal them to stop
for _ in processes:
queue.put(None)
# Wait maximum 20 seconds for all processes to terminate
for process in processes:
process.join(timeout=timeout)
# If not terminated after 20 seconds, force termination
if process.is_alive():
process.terminate()
# Close the queue, no more items can be put in the queue
queue.close()
# Ensure all background queue threads have finished
queue.join_thread()
def start_image_writer(num_processes, num_threads):
"""This function abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
where each subprocess starts their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
image_writer = {}
if num_processes == 0:
futures = []
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
else:
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
image_queue = multiprocessing.Queue()
processes_pool = start_image_writer_processes(
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
)
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
return image_writer
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
called image writer which contains either a pool of processes or a pool of threads.
"""
if "threads_pool" in image_writer:
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
else:
image_queue = image_writer["image_queue"]
image_queue.put((image, key, frame_index, episode_index, videos_dir))
def stop_image_writer(image_writer, timeout):
if "threads_pool" in image_writer:
futures = image_writer["futures"]
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures, timeout=timeout)
progress_bar.update(len(futures))
else:
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
stop_processes(processes_pool, image_queue, timeout=timeout)
########################################################################################
# Functions to initialize, resume and populate a dataset
########################################################################################
def init_dataset(
repo_id,
root,
force_override,
fps,
video,
write_images,
num_image_writer_processes,
num_image_writer_threads,
):
local_dir = Path(root) / repo_id
if local_dir.exists() and force_override:
shutil.rmtree(local_dir)
episodes_dir = local_dir / "episodes"
episodes_dir.mkdir(parents=True, exist_ok=True)
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
# Logic to resume data recording
rec_info_path = episodes_dir / "data_recording_info.json"
if rec_info_path.exists():
with open(rec_info_path) as f:
rec_info = json.load(f)
num_episodes = rec_info["last_episode_index"] + 1
else:
num_episodes = 0
dataset = {
"repo_id": repo_id,
"local_dir": local_dir,
"videos_dir": videos_dir,
"episodes_dir": episodes_dir,
"fps": fps,
"video": video,
"rec_info_path": rec_info_path,
"num_episodes": num_episodes,
}
if write_images:
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
# which is critical to control a robot and record data at a high frame rate.
image_writer = start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads,
)
dataset["image_writer"] = image_writer
return dataset
def add_frame(dataset, observation, action):
if "current_episode" not in dataset:
# initialize episode dictionary
ep_dict = {}
for key in observation:
if key not in ep_dict:
ep_dict[key] = []
for key in action:
if key not in ep_dict:
ep_dict[key] = []
ep_dict["episode_index"] = []
ep_dict["frame_index"] = []
ep_dict["timestamp"] = []
ep_dict["next.done"] = []
dataset["current_episode"] = ep_dict
dataset["current_frame_index"] = 0
ep_dict = dataset["current_episode"]
episode_index = dataset["num_episodes"]
frame_index = dataset["current_frame_index"]
videos_dir = dataset["videos_dir"]
video = dataset["video"]
fps = dataset["fps"]
ep_dict["episode_index"].append(episode_index)
ep_dict["frame_index"].append(frame_index)
ep_dict["timestamp"].append(frame_index / fps)
ep_dict["next.done"].append(False)
img_keys = [key for key in observation if "image" in key]
non_img_keys = [key for key in observation if "image" not in key]
# Save all observed modalities except images
for key in non_img_keys:
ep_dict[key].append(observation[key])
# Save actions
for key in action:
ep_dict[key].append(action[key])
if "image_writer" not in dataset:
dataset["current_frame_index"] += 1
return
# Save images
image_writer = dataset["image_writer"]
for key in img_keys:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
async_save_image(
image_writer,
image=observation[key],
key=key,
frame_index=frame_index,
episode_index=episode_index,
videos_dir=str(videos_dir),
)
if video:
fname = f"{key}_episode_{episode_index:06d}.mp4"
frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
else:
frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
ep_dict[key].append(frame_info)
dataset["current_frame_index"] += 1
def delete_current_episode(dataset):
del dataset["current_episode"]
del dataset["current_frame_index"]
# delete temporary images
episode_index = dataset["num_episodes"]
videos_dir = dataset["videos_dir"]
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
shutil.rmtree(tmp_imgs_dir)
def save_current_episode(dataset):
episode_index = dataset["num_episodes"]
ep_dict = dataset["current_episode"]
episodes_dir = dataset["episodes_dir"]
rec_info_path = dataset["rec_info_path"]
ep_dict["next.done"][-1] = True
for key in ep_dict:
if "observation" in key and "image" not in key:
ep_dict[key] = torch.stack(ep_dict[key])
ep_dict["action"] = torch.stack(ep_dict["action"])
ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])
ep_path = episodes_dir / f"episode_{episode_index}.pth"
torch.save(ep_dict, ep_path)
rec_info = {
"last_episode_index": episode_index,
}
with open(rec_info_path, "w") as f:
json.dump(rec_info, f)
# force re-initialization of episode dictionnary during add_frame
del dataset["current_episode"]
dataset["num_episodes"] += 1
def encode_videos(dataset, image_keys, play_sounds):
log_say("Encoding videos", play_sounds)
num_episodes = dataset["num_episodes"]
videos_dir = dataset["videos_dir"]
local_dir = dataset["local_dir"]
fps = dataset["fps"]
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
# key = f"observation.images.{name}"
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
log_say("Consolidate episodes", play_sounds)
num_episodes = dataset["num_episodes"]
episodes_dir = dataset["episodes_dir"]
videos_dir = dataset["videos_dir"]
video = dataset["video"]
fps = dataset["fps"]
repo_id = dataset["repo_id"]
ep_dicts = []
for episode_index in tqdm.tqdm(range(num_episodes)):
ep_path = episodes_dir / f"episode_{episode_index}.pth"
ep_dict = torch.load(ep_path)
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
if video:
image_keys = [key for key in data_dict if "image" in key]
encode_videos(dataset, image_keys, play_sounds)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps,
"video": video,
}
if video:
info["encoding"] = get_default_encoding()
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
return lerobot_dataset
def save_lerobot_dataset_on_disk(lerobot_dataset):
hf_dataset = lerobot_dataset.hf_dataset
info = lerobot_dataset.info
stats = lerobot_dataset.stats
episode_data_index = lerobot_dataset.episode_data_index
local_dir = lerobot_dataset.videos_dir.parent
meta_data_dir = local_dir / "meta_data"
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
save_meta_data(info, stats, episode_data_index, meta_data_dir)
def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
hf_dataset = lerobot_dataset.hf_dataset
local_dir = lerobot_dataset.videos_dir.parent
videos_dir = lerobot_dataset.videos_dir
repo_id = lerobot_dataset.repo_id
video = lerobot_dataset.video
meta_data_dir = local_dir / "meta_data"
if not (local_dir / "train").exists():
raise ValueError(
"You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub."
)
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
if "image_writer" in dataset:
logging.info("Waiting for image writer to terminate...")
image_writer = dataset["image_writer"]
stop_image_writer(image_writer, timeout=20)
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
if run_compute_stats:
log_say("Computing dataset statistics", play_sounds)
lerobot_dataset.stats = compute_stats(lerobot_dataset)
else:
logging.info("Skipping computation of the dataset statistics")
lerobot_dataset.stats = {}
save_lerobot_dataset_on_disk(lerobot_dataset)
if push_to_hub:
push_lerobot_dataset_to_hub(lerobot_dataset, tags)
return lerobot_dataset

View File

@ -30,12 +30,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -24,8 +24,11 @@ from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
)
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame

View File

@ -26,8 +26,8 @@ import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame

View File

@ -42,12 +42,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -27,12 +27,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -28,12 +28,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -16,7 +16,9 @@
import inspect
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict
import datasets
import numpy
import PIL
import torch
@ -72,3 +74,58 @@ def check_repo_id(repo_id: str) -> None:
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
)
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index

View File

@ -27,12 +27,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

View File

@ -14,30 +14,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import warnings
from functools import cache
import logging
import textwrap
from itertools import accumulate
from pathlib import Path
from typing import Dict
from pprint import pformat
from typing import Any
import datasets
import jsonlines
import numpy as np
import pyarrow.compute as pc
import torch
from datasets import load_dataset, load_from_disk
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
from lerobot.common.robot_devices.robots.utils import Robot
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## {}
"""
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"index": {"dtype": "int64", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
def flatten_dict(d, parent_key="", sep="/"):
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
@ -56,7 +82,7 @@ def flatten_dict(d, parent_key="", sep="/"):
return dict(items)
def unflatten_dict(d, sep="/"):
def unflatten_dict(d: dict, sep: str = "/") -> dict:
outdict = {}
for key, value in d.items():
parts = key.split(sep)
@ -69,6 +95,82 @@ def unflatten_dict(d, sep="/"):
return outdict
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
return unflatten_dict(serialized_dict)
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
dataset.to_parquet(fpath)
def load_json(fpath: Path) -> Any:
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def load_jsonlines(fpath: Path) -> list[Any]:
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def write_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def load_stats(local_dir: Path) -> dict:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_tasks(local_dir: Path) -> dict:
tasks = load_jsonlines(local_dir / TASKS_PATH)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH)
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
if "float" in dtype:
img_array /= 255.0
return img_array
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
@ -80,14 +182,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif isinstance(first_item, str):
# TODO (michel-aractingi): add str2embedding via language tokenizer
# For now we leave this part up to the user to choose how to address
# language conditioned tasks
pass
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif first_item is None:
pass
else:
@ -95,19 +189,67 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict
@cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
def _get_major_minor(version: str) -> tuple[int]:
split = version.strip("v").split(".")
return int(split[0]), int(split[1])
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
""")
super().__init__(message)
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
logging.warning(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
codebase. The current codebase version is {current_version}. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
def get_hub_safe_version(repo_id: str, version: str) -> str:
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
warnings.warn(
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
logging.warning(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
stacklevel=1,
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")
@ -116,275 +258,184 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
return version
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
if "%" in split:
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
from_frame_index = int(match_from.group(1))
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
elif match_to:
to_frame_index = int(match_to.group(1))
hf_dataset = hf_dataset.select(range(to_frame_index))
else:
raise ValueError(
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
)
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
)
return load_file(path)
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
)
stats = load_file(path)
return unflatten_dict(stats)
def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
with open(path) as f:
info = json.load(f)
return info
def load_videos(repo_id, version, root) -> Path:
if root is not None:
path = Path(root) / repo_id / "videos"
else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this
safe_version = get_hf_dataset_safe_version(repo_id, version)
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
path = Path(repo_dir) / "videos"
return path
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tolerance_s: float,
) -> dict[torch.Tensor]:
"""
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
of the episode are the same as the first frame of the episode.
Parameters:
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
Returns:
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
each modality (e.g. "observation.image_is_pad").
Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
ep_last_ts = ep_timestamps[-1]
current_ts = item["timestamp"].item()
for key in delta_timestamps:
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# TODO(rcadene): synchronize timestamps + interpolation if needed
is_pad = min_ > tolerance_s
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
# get dataset indices corresponding to frames to be loaded
data_ids = ep_data_ids[argmin_]
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
if isinstance(item[key][0], dict) and "path" in item[key][0]:
# video mode where frame are expressed as dict of path and timestamp
item[key] = item[key]
def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
else:
item[key] = torch.stack(item[key])
assert len(ft["shape"]) == 1
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
item[f"{key}_is_pad"] = is_pad
return item
return datasets.Features(hf_features)
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
camera_ft = {
key: {"dtype": "video" if use_videos else "image", **ft}
for key, ft in robot.camera_features.items()
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
This brings the `episode_index` to the required format.
"""
if len(hf_dataset) == 0:
return hf_dataset
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
episode_idx_to_reset_idx_mapping = {
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
def create_empty_dataset_info(
codebase_version: str,
fps: int,
robot_type: str,
features: dict,
use_videos: bool,
) -> dict:
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
"total_videos": 0,
"total_chunks": 0,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": fps,
"splits": {},
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
def modify_ep_idx_func(example):
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
return hf_dataset
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lenghts = list(accumulate(episode_lengths))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def check_timestamps_sync(
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
account for possible numerical error.
"""
timestamps = torch.stack(hf_dataset["timestamp"])
diffs = torch.diff(timestamps)
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
# We mask differences between the timestamp at the end of an episode
# and the one at the start of the next episode since these are expected
# to be outside tolerance.
mask = torch.ones(len(diffs), dtype=torch.bool)
ignored_diffs = episode_data_index["to"][:-1] - 1
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
if not torch.all(filtered_within_tolerance):
# Track original indices before masking
original_indices = torch.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])
outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
"episode_index": episode_indices[idx].item(),
}
outside_tolerances.append(entry)
if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues with timestamps during data collection.
\n{pformat(outside_tolerances)}"""
)
return False
return True
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
actual timestamps from the dataset.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
]
if len(outside_tolerance) > 0:
if raise_value_error:
raise ValueError(
f"""
The following delta_timestamps are found outside of tolerance range.
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
their values accordingly.
\n{pformat(outside_tolerance)}
"""
)
return False
return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
return delta_indices
def cycle(iterable):
@ -400,7 +451,7 @@ def cycle(iterable):
iterator = iter(iterable)
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""
@ -415,12 +466,35 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
card = DatasetCard(DATASET_CARD_TEMPLATE)
card.data.task_categories = ["robotics"]
card.data.tags = ["LeRobot"]
if tags is not None:
card.data.tags += tags
if text is not None:
card.text += text
return card
def create_lerobot_dataset_card(
tags: list | None = None,
dataset_info: dict | None = None,
**kwargs,
) -> DatasetCard:
"""
Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""
card_tags = ["LeRobot"]
if tags:
card_tags += tags
if dataset_info:
dataset_structure = "[meta/info.json](meta/info.json):\n"
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
kwargs = {**kwargs, "dataset_structure": dataset_structure}
card_data = DatasetCardData(
license=kwargs.get("license"),
tags=card_tags,
task_categories=["robotics"],
configs=[
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
],
)
return DatasetCard.from_template(
card_data=card_data,
template_path="./lerobot/common/datasets/card_template.md",
**kwargs,
)

View File

@ -0,0 +1,882 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
lerobot/configs/robot/aloha.yaml before running this script.
"""
import traceback
from pathlib import Path
from textwrap import dedent
from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
LOCAL_DIR = Path("data/")
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
ALOHA_MOBILE_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"license": "mit",
"url": "https://mobile-aloha.github.io/",
"paper": "https://arxiv.org/abs/2401.02117",
"citation_bibtex": dedent(r"""
@inproceedings{fu2024mobile,
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
booktitle = {arXiv},
year = {2024},
}""").lstrip(),
}
ALOHA_STATIC_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://arxiv.org/abs/2304.13705",
"citation_bibtex": dedent(r"""
@article{Zhao2023LearningFB,
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
journal={RSS},
year={2023},
volume={abs/2304.13705},
url={https://arxiv.org/abs/2304.13705}
}""").lstrip(),
}
PUSHT_INFO = {
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://arxiv.org/abs/2303.04137v5",
"citation_bibtex": dedent(r"""
@article{chi2024diffusionpolicy,
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
journal = {The International Journal of Robotics Research},
year = {2024},
}""").lstrip(),
}
XARM_INFO = {
"license": "mit",
"url": "https://www.nicklashansen.com/td-mpc/",
"paper": "https://arxiv.org/abs/2203.04955",
"citation_bibtex": dedent(r"""
@inproceedings{Hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={ICML},
year={2022}
}
"""),
}
UNITREEH_INFO = {
"license": "apache-2.0",
}
DATASETS = {
"aloha_mobile_cabinet": {
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_chair": {
"single_task": "Push the chairs in front of the desk to place them against it.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_elevator": {
"single_task": "Take the elevator to the 1st floor.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_shrimp": {
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wash_pan": {
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wipe_wine": {
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
**ALOHA_MOBILE_INFO,
},
"aloha_static_battery": {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee_new": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
**ALOHA_STATIC_INFO,
},
"aloha_static_cups_open": {
"single_task": "Pick up the plastic cup and open its lid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_fork_pick_up": {
"single_task": "Pick up the fork and place it on the plate.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pingpong_test": {
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pro_pencil": {
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
**ALOHA_STATIC_INFO,
},
"aloha_static_screw_driver": {
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
**ALOHA_STATIC_INFO,
},
"aloha_static_tape": {
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
**ALOHA_STATIC_INFO,
},
"aloha_static_thread_velcro": {
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_towel": {
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup": {
"single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup_left": {
"single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
},
"unitreeh1_warehouse": {
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",
},
"asu_table_top": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023modularity,
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
booktitle={Conference on Robot Learning},
pages={1684--1695},
year={2023},
organization={PMLR}
}
@article{zhou2023learning,
title={Learning modular language-conditioned robot policies through attention},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
journal={Autonomous Robots},
pages={1--21},
year={2023},
publisher={Springer}
}""").lstrip(),
},
"austin_buds_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
"paper": "https://arxiv.org/abs/2109.13841",
"citation_bibtex": dedent(r"""
@article{zhu2022bottom,
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
journal={IEEE Robotics and Automation Letters},
volume={7},
number={2},
pages={4126--4133},
year={2022},
publisher={IEEE}
}""").lstrip(),
},
"austin_sailor_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sailor/",
"paper": "https://arxiv.org/abs/2210.11435",
"citation_bibtex": dedent(r"""
@inproceedings{nasiriany2022sailor,
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
booktitle={Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
"austin_sirius_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sirius/",
"paper": "https://arxiv.org/abs/2211.08416",
"citation_bibtex": dedent(r"""
@inproceedings{liu2022robot,
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
booktitle = {Robotics: Science and Systems (RSS)},
year = {2023}
}""").lstrip(),
},
"berkeley_autolab_ur5": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/berkeley-ur5/home",
"citation_bibtex": dedent(r"""
@misc{BerkeleyUR5Website,
title = {Berkeley {UR5} Demonstration Dataset},
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
}""").lstrip(),
},
"berkeley_cable_routing": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/cablerouting/home",
"paper": "https://arxiv.org/abs/2307.08927",
"citation_bibtex": dedent(r"""
@article{luo2023multistage,
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
journal = {arXiv pre-print},
year = {2023},
url = {https://arxiv.org/abs/2307.08927},
}""").lstrip(),
},
"berkeley_fanuc_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
"citation_bibtex": dedent(r"""
@article{fanuc_manipulation2023,
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
year={2023},
}""").lstrip(),
},
"berkeley_gnm_cory_hall": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://arxiv.org/abs/1709.10489",
"citation_bibtex": dedent(r"""
@inproceedings{kahn2018self,
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
pages={5129--5136},
year={2018},
organization={IEEE}
}""").lstrip(),
},
"berkeley_gnm_recon": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/recon-robot",
"paper": "https://arxiv.org/abs/2104.05859",
"citation_bibtex": dedent(r"""
@inproceedings{shah2021rapid,
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
booktitle={5th Annual Conference on Robot Learning },
year={2021},
url={https://openreview.net/forum?id=d_SWJhyKfVw}
}""").lstrip(),
},
"berkeley_gnm_sac_son": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/SACSoN-review",
"paper": "https://arxiv.org/abs/2306.01874",
"citation_bibtex": dedent(r"""
@article{hirose2023sacson,
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
journal={arXiv preprint arXiv:2306.01874},
year={2023}
}""").lstrip(),
},
"berkeley_mvp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://arxiv.org/abs/2203.06173",
"citation_bibtex": dedent(r"""
@InProceedings{Radosavovic2022,
title = {Real-World Robot Learning with Masked Visual Pre-training},
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
booktitle = {CoRL},
year = {2022}
}""").lstrip(),
},
"berkeley_rpt": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://arxiv.org/abs/2306.10007",
"citation_bibtex": dedent(r"""
@article{Radosavovic2023,
title={Robot Learning with Sensorimotor Pre-training},
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
year={2023},
journal={arXiv:2306.10007}
}""").lstrip(),
},
"cmu_franka_exploration_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://human-world-model.github.io/",
"paper": "https://arxiv.org/abs/2308.10901",
"citation_bibtex": dedent(r"""
@inproceedings{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={RSS},
year={2023}
}""").lstrip(),
},
"cmu_play_fusion": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-fusion.github.io/",
"paper": "https://arxiv.org/abs/2312.04549",
"citation_bibtex": dedent(r"""
@inproceedings{chen2023playfusion,
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
booktitle={CoRL},
year={2023}
}""").lstrip(),
},
"cmu_stretch": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robo-affordances.github.io/",
"paper": "https://arxiv.org/abs/2304.08488",
"citation_bibtex": dedent(r"""
@inproceedings{bahl2023affordances,
title={Affordances from Human Videos as a Versatile Representation for Robotics},
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
booktitle={CVPR},
year={2023}
}
@article{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={CoRL},
year={2023}
}""").lstrip(),
},
"columbia_cairlab_pusht_real": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://arxiv.org/abs/2303.04137v5",
"citation_bibtex": dedent(r"""
@inproceedings{chi2023diffusionpolicy,
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
year={2023}
}""").lstrip(),
},
"conq_hose_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
"citation_bibtex": dedent(r"""
@misc{ConqHoseManipData,
author={Peter Mitrano and Dmitry Berenson},
title={Conq Hose Manipulation Dataset, v1.15.0},
year={2024},
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
}""").lstrip(),
},
"dlr_edan_shared_control": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://ieeexplore.ieee.org/document/9341156",
"citation_bibtex": dedent(r"""
@inproceedings{vogel_edan_2020,
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
language = {en},
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
year = {2020}
}
@inproceedings{quere_shared_2020,
address = {Paris, France},
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
language = {en},
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
year = {2020},
pages = {7},
}""").lstrip(),
},
"dlr_sara_grid_clamp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
"citation_bibtex": dedent(r"""
@article{padalkar2023guided,
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
journal={Research square preprint rs-3289569/v1},
year={2023}
}""").lstrip(),
},
"dlr_sara_pour": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
"citation_bibtex": dedent(r"""
@inproceedings{padalkar2023guiding,
title={Guiding Reinforcement Learning with Shared Control Templates},
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"droid_100": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://droid-dataset.github.io/",
"paper": "https://arxiv.org/abs/2403.12945",
"citation_bibtex": dedent(r"""
@article{khazatsky2024droid,
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
year = {2024},
}""").lstrip(),
},
"fmb": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://functional-manipulation-benchmark.github.io/",
"paper": "https://arxiv.org/abs/2401.08553",
"citation_bibtex": dedent(r"""
@article{luo2024fmb,
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
journal={arXiv preprint arXiv:2401.08553},
year={2024}
}""").lstrip(),
},
"iamlab_cmu_pickup_insert": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
"paper": "https://arxiv.org/abs/2401.14502",
"citation_bibtex": dedent(r"""
@inproceedings{saxena2023multiresolution,
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=WuBv9-IGDUA}
}""").lstrip(),
},
"imperialcollege_sawyer_wrist_cam": {
"tasks_col": "language_instruction",
"license": "mit",
},
"jaco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
"citation_bibtex": dedent(r"""
@software{dass2023jacoplay,
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
title = {CLVR Jaco Play Dataset},
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
version = {1.0.0},
year = {2023}
}""").lstrip(),
},
"kaist_nonprehensile": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
"citation_bibtex": dedent(r"""
@article{kimpre,
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"nyu_door_opening_surprising_effectiveness": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://jyopari.github.io/VINN/",
"paper": "https://arxiv.org/abs/2112.01511",
"citation_bibtex": dedent(r"""
@misc{pari2021surprising,
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
year={2021},
eprint={2112.01511},
archivePrefix={arXiv},
primaryClass={cs.RO}
}""").lstrip(),
},
"nyu_franka_play_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-to-policy.github.io/",
"paper": "https://arxiv.org/abs/2210.10047",
"citation_bibtex": dedent(r"""
@article{cui2022play,
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
journal = {arXiv preprint arXiv:2210.10047},
year = {2022}
}""").lstrip(),
},
"nyu_rot_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://rot-robot.github.io/",
"paper": "https://arxiv.org/abs/2206.15469",
"citation_bibtex": dedent(r"""
@inproceedings{haldar2023watch,
title={Watch and match: Supercharging imitation with regularized optimal transport},
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
booktitle={Conference on Robot Learning},
pages={32--43},
year={2023},
organization={PMLR}
}""").lstrip(),
},
"roboturk": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://roboturk.stanford.edu/dataset_real.html",
"paper": "PAPER",
"citation_bibtex": dedent(r"""
@inproceedings{mandlekar2019scaling,
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
pages={1048--1055},
year={2019},
organization={IEEE}
}""").lstrip(),
},
"stanford_hydra_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/hydra-il-2023",
"paper": "https://arxiv.org/abs/2306.17237",
"citation_bibtex": dedent(r"""
@article{belkhale2023hydra,
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
journal={arxiv},
year={2023}
}""").lstrip(),
},
"stanford_kuka_multimodal_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/visionandtouch",
"paper": "https://arxiv.org/abs/1810.10191",
"citation_bibtex": dedent(r"""
@inproceedings{lee2019icra,
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
year={2019},
url={https://arxiv.org/abs/1810.10191}
}""").lstrip(),
},
"stanford_robocook": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://hshi74.github.io/robocook/",
"paper": "https://arxiv.org/abs/2306.14447",
"citation_bibtex": dedent(r"""
@article{shi2023robocook,
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
journal={arXiv preprint arXiv:2306.14447},
year={2023}
}""").lstrip(),
},
"taco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
"paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
"citation_bibtex": dedent(r"""
@inproceedings{rosete2022tacorl,
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
year = {2022}
}
@inproceedings{mees23hulc2,
title={Grounding Language with Visual Affordances over Unstructured Data},
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
year={2023},
address = {London, UK}
}""").lstrip(),
},
"tokyo_u_lsmo": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "URL",
"paper": "https://arxiv.org/abs/2107.05842",
"citation_bibtex": dedent(r"""
@Article{Osa22,
author = {Takayuki Osa},
journal = {The International Journal of Robotics Research},
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
year = {2022},
number = {3},
pages = {291--311},
volume = {41},
}""").lstrip(),
},
"toto": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://toto-benchmark.org/",
"paper": "https://arxiv.org/abs/2306.00942",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023train,
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
year={2023},
}""").lstrip(),
},
"ucsd_kitchen_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@ARTICLE{ucsd_kitchens,
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
title = {{ucsd kitchens Dataset}},
year = {2023},
month = {August}
}""").lstrip(),
},
"ucsd_pick_and_place_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://owmcorl.github.io/#",
"paper": "https://arxiv.org/abs/2310.16029",
"citation_bibtex": dedent(r"""
@preprint{Feng2023Finetuning,
title={Finetuning Offline World Models in the Real World},
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
year={2023}
}""").lstrip(),
},
"uiuc_d3field": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robopil.github.io/d3fields/",
"paper": "https://arxiv.org/abs/2309.16118",
"citation_bibtex": dedent(r"""
@article{wang2023d3field,
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
journal={arXiv preprint arXiv:},
year={2023},
}""").lstrip(),
},
"usc_cloth_sim": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://uscresl.github.io/dmfd/",
"paper": "https://arxiv.org/abs/2207.10148",
"citation_bibtex": dedent(r"""
@article{salhotra2022dmfd,
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
journal={IEEE Robotics and Automation Letters},
title={Learning Deformable Object Manipulation From Expert Demonstrations},
year={2022},
volume={7},
number={4},
pages={8775-8782},
doi={10.1109/LRA.2022.3187843}
}""").lstrip(),
},
"utaustin_mutex": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/MUTEX/",
"paper": "https://arxiv.org/abs/2309.14320",
"citation_bibtex": dedent(r"""
@inproceedings{shah2023mutex,
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=PwqiqaaEzJ}
}""").lstrip(),
},
"utokyo_pr2_opening_fridge": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_pr2_tabletop_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_saytap": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://saytap.github.io/",
"paper": "https://arxiv.org/abs/2306.07580",
"citation_bibtex": dedent(r"""
@article{saytap2023,
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
Tatsuya Harada},
title = {SayTap: Language to Quadrupedal Locomotion},
eprint = {arXiv:2306.07580},
url = {https://saytap.github.io},
note = {https://saytap.github.io},
year = {2023}
}""").lstrip(),
},
"utokyo_xarm_bimanual": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"utokyo_xarm_pick_and_place": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"viola": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/VIOLA/",
"paper": "https://arxiv.org/abs/2210.11339",
"citation_bibtex": dedent(r"""
@article{zhu2022viola,
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
journal={6th Annual Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
}
def batch_convert():
status = {}
logfile = LOCAL_DIR / "conversion_log.txt"
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
for num, (name, kwargs) in enumerate(DATASETS.items()):
repo_id = f"lerobot/{name}"
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
print("---------------------------------------------------------")
try:
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
status = f"{repo_id}: success."
with open(logfile, "a") as file:
file.write(status + "\n")
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
continue
if __name__ == "__main__":
batch_convert()

View File

@ -0,0 +1,665 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
one episode to the next.
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
recorded with. For now, only Aloha/Koch type robots are supported with this option.
# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option.
Examples:
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id lerobot/aloha_sim_insertion_human_image \
--single-task "Insert the peg into the socket." \
--robot-config lerobot/configs/robot/aloha.yaml \
--local-dir data
```
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id aliberts/koch_tutorial \
--single-task "Pick the Lego block and drop it in the box on the right." \
--robot-config lerobot/configs/robot/koch.yaml \
--local-dir data
```
# 2. Single task episodes
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
this column's name with the '--tasks-col' arg.
Example:
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
'--tasks-path' arg. This file should have the following structure where keys correspond to each
episode_index in the dataset, and values are the language instruction for that episode.
Example:
```json
{
"0": "Do something",
"1": "Do something else",
"2": "Do something",
"3": "Go there",
...
}
```
# 3. Multi task episodes
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg.
Example:
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
"""
import argparse
import contextlib
import filecmp
import json
import logging
import math
import shutil
import subprocess
import tempfile
from pathlib import Path
import datasets
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from safetensors.torch import load_file
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_hub_safe_version,
load_json,
unflatten_dict,
write_json,
write_jsonlines,
)
from lerobot.common.datasets.video_utils import (
VideoFrame, # noqa: F401
get_image_pixel_channels,
get_video_info,
)
from lerobot.common.utils.utils import init_hydra_config
V16 = "v1.6"
V20 = "v2.0"
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
robot_cfg = init_hydra_config(config_path, config_overrides)
if robot_cfg["robot_type"] in ["aloha", "koch"]:
state_names = [
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
for arm in robot_cfg["follower_arms"]
for motor in robot_cfg["follower_arms"][arm]["motors"]
]
action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
for arm in robot_cfg["leader_arms"]
for motor in robot_cfg["leader_arms"][arm]["motors"]
]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
raise NotImplementedError(
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
)
return {
"robot_type": robot_cfg["robot_type"],
"names": {
"observation.state": state_names,
"observation.effort": state_names,
"action": action_names,
},
}
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
safetensor_path = v1_dir / V1_STATS_PATH
stats = load_file(safetensor_path)
serialized_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
json_path = v2_dir / STATS_PATH
json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
json.dump(serialized_stats, f, indent=4)
# Sanity check
with open(json_path) as f:
stats_json = json.load(f)
stats_json = flatten_dict(stats_json)
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
for key in stats:
torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
dtype = ft.dtype
shape = (1,)
names = None
if isinstance(ft, datasets.Sequence):
assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
elif isinstance(ft, datasets.Image):
dtype = "image"
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.height, image.width, channels)
names = ["height", "width", "channel"]
elif ft._type == "VideoFrame":
dtype = "video"
shape = None # Add shape later
names = ["height", "width", "channel"]
features[key] = {
"dtype": dtype,
"shape": shape,
"names": names,
}
return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
return dataset, tasks
def add_task_index_from_tasks_col(
dataset: Dataset, tasks_col: str
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
df = dataset.to_pandas()
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
# Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
# Build the dataset back from df
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
dataset = dataset.remove_columns(tasks_col)
return dataset, tasks, tasks_by_episode
def split_parquet_by_episodes(
dataset: Dataset,
total_episodes: int,
total_chunks: int,
output_dir: Path,
) -> list:
table = dataset.data.table
episode_lengths = []
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
episode_chunk=ep_chunk, episode_index=ep_idx
)
pq.write_table(ep_table, output_file)
return episode_lengths
def move_videos(
repo_id: str,
video_keys: list[str],
total_episodes: int,
total_chunks: int,
work_dir: Path,
clean_gittatributes: Path,
branch: str = "main",
) -> None:
"""
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
commands to fetch git lfs video files references to move them into subdirectories without having to
actually download them.
"""
_lfs_clone(repo_id, work_dir, branch)
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
current_gittatributes = work_dir / ".gitattributes"
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
if lfs_untracked_videos:
fix_lfs_video_files_tracking(work_dir, video_files)
if videos_moved:
return
video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
for vid_key in video_keys:
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk, video_key=vid_key
)
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
for dir in video_dirs:
if (dir / video_file).is_file():
video_path = dir / video_file
break
video_path.rename(work_dir / target_path)
commit_message = "Move video files into chunk subdirectories"
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
"""
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
commit_message = "Track video files with git lfs"
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
check=True,
env=env,
)
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
for vid_key in video_keys
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
return videos_info_dict
def convert_dataset(
repo_id: str,
local_dir: Path,
single_task: str | None = None,
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: dict | None = None,
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_hub_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)
v20_dir.mkdir(parents=True, exist_ok=True)
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
)
branch = "main"
if test_branch:
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
features = get_features_from_hf_dataset(dataset, robot_config)
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
if single_task and "language_instruction" in dataset.column_names:
logging.warning(
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
)
single_task = None
tasks_col = "language_instruction"
# Episodes & chunks
episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(video_keys)
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
total_chunks += 1
# Tasks
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
else:
raise ValueError
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
# Videos
if video_keys:
assert metadata_v1.get("video", False)
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
videos_info[key].pop("video.width"),
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1:
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
if robot_config is not None:
robot_type = robot_config["robot_type"]
repo_tags = [robot_type]
else:
robot_type = "unknown"
repo_tags = None
# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V20,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": len(dataset),
"total_tasks": len(tasks),
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
"features": features,
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
folder_path=v20_dir / "data",
repo_type="dataset",
revision=branch,
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="meta",
folder_path=v20_dir / "meta",
repo_type="dataset",
revision=branch,
)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
if not test_branch:
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
def main():
parser = argparse.ArgumentParser()
task_args = parser.add_mutually_exclusive_group(required=True)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the single task performed in the dataset.",
)
task_args.add_argument(
"--tasks-col",
type=str,
help="The name of the column containing language instructions",
)
task_args.add_argument(
"--tasks-path",
type=Path,
help="The path to a .json file containing one language instruction for each episode_index",
)
parser.add_argument(
"--robot-config",
type=Path,
default=None,
help="Path to the robot's config yaml the dataset during conversion.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
)
parser.add_argument(
"--local-dir",
type=Path,
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
)
parser.add_argument(
"--license",
type=str,
default="apache-2.0",
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
)
parser.add_argument(
"--test-branch",
type=str,
default=None,
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
)
args = parser.parse_args()
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
del args.robot_config, args.robot_overrides
convert_dataset(**vars(args), robot_config=robot_config)
if __name__ == "__main__":
main()

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import subprocess
import warnings
@ -25,47 +26,11 @@ import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
def load_from_videos(
item: dict[str, torch.Tensor],
video_frame_keys: list[str],
videos_dir: Path,
tolerance_s: float,
backend: str = "pyav",
):
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
This probably happens because a memory reference to the video loader is created in the main process and a
subprocess fails to access it.
"""
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
data_dir = videos_dir.parent
for key in video_frame_keys:
if isinstance(item[key], list):
# load multiple frames at once (expected when delta_timestamps is not None)
timestamps = [frame["timestamp"] for frame in item[key]]
paths = [frame["path"] for frame in item[key]]
if len(set(paths)) > 1:
raise NotImplementedError("All video paths are expected to be the same for now.")
video_path = data_dir / paths[0]
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
item[key] = frames
else:
# load one frame
timestamps = [item[key]["timestamp"]]
video_path = data_dir / item[key]["path"]
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
item[key] = frames[0]
return item
from PIL import Image
def decode_video_frames_torchvision(
video_path: str,
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
@ -163,8 +128,8 @@ def decode_video_frames_torchvision(
def encode_video_frames(
imgs_dir: Path,
video_path: Path,
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
# vcodec: str = "libsvtav1",
vcodec: str = "libx264",
@ -248,3 +213,104 @@ with warnings.catch_warnings():
)
# to make VideoFrame available in HuggingFace `datasets`
register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a:0",
"-show_entries",
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
audio_stream_info = info["streams"][0] if info.get("streams") else None
if audio_stream_info is None:
return {"has_audio": False}
# Return the information, defaulting to None if no audio stream is present
return {
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
}
def get_video_info(video_path: Path | str) -> dict:
ffprobe_video_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
video_stream_info = info["streams"][0]
# Calculate fps from r_frame_rate
r_frame_rate = video_stream_info["r_frame_rate"]
num, denom = map(int, r_frame_rate.split("/"))
fps = num / denom
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
video_info = {
"video.fps": fps,
"video.height": video_stream_info["height"],
"video.width": video_stream_info["width"],
"video.channels": pixel_channels,
"video.codec": video_stream_info["codec_name"],
"video.pix_fmt": video_stream_info["pix_fmt"],
"video.is_depth_map": False,
**get_audio_info(video_path),
}
return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_image_pixel_channels(image: Image):
if image.mode == "L":
return 1 # Grayscale
elif image.mode == "LA":
return 2 # Grayscale + Alpha
elif image.mode == "RGB":
return 3 # RGB
elif image.mode == "RGBA":
return 4 # RGBA
else:
raise ValueError("Unknown format")

View File

@ -168,6 +168,7 @@ class IntelRealSenseCameraConfig:
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
use_depth: bool = False
force_hardware_reset: bool = True
rotation: int | None = None
@ -179,6 +180,8 @@ class IntelRealSenseCameraConfig:
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
@ -254,6 +257,7 @@ class IntelRealSenseCamera:
self.fps = config.fps
self.width = config.width
self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.use_depth = config.use_depth
self.force_hardware_reset = config.force_hardware_reset

View File

@ -192,6 +192,7 @@ class OpenCVCameraConfig:
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
@ -201,6 +202,8 @@ class OpenCVCameraConfig:
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
@ -268,6 +271,7 @@ class OpenCVCamera:
self.fps = config.fps
self.width = config.width
self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.mock = config.mock

View File

@ -13,9 +13,12 @@ from functools import cache
import cv2
import torch
import tqdm
from deepdiff import DeepDiff
from termcolor import colored
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
@ -227,7 +230,7 @@ def control_loop(
control_time_s=None,
teleoperate=False,
display_cameras=False,
dataset=None,
dataset: LeRobotDataset | None = None,
events=None,
policy=None,
device=None,
@ -247,7 +250,7 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and fps is not None and dataset["fps"] != fps:
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
@ -268,7 +271,8 @@ def control_loop(
action = {"action": action}
if dataset is not None:
add_frame(dataset, observation, action)
frame = {**observation, **action}
dataset.add_frame(frame)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
@ -324,7 +328,36 @@ def sanity_check_dataset_name(repo_id, policy):
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
if dataset_name.startswith("eval_") == (policy is None):
# Check if dataset_name starts with "eval_" but policy is missing
if dataset_name.startswith("eval_") and policy is None:
raise ValueError(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
)
# Check if dataset_name does not start with "eval_" but policy is provided
if not dataset_name.startswith("eval_") and policy is not None:
raise ValueError(
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
)
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
) -> None:
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
]
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value)
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)

View File

@ -226,6 +226,42 @@ class ManipulatorRobot:
self.is_connected = False
self.logs = {}
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
action_names = self.get_motor_names(self.leader_arms)
state_names = self.get_motor_names(self.leader_arms)
return {
"action": {
"dtype": "float32",
"shape": (len(action_names),),
"names": action_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(state_names),),
"names": state_names,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
@property
def has_camera(self):
return len(self.cameras) > 0

View File

@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
class Robot(Protocol):
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
robot_type: str
features: dict
def connect(self): ...
def run_calibration(self): ...

View File

@ -114,7 +114,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1

View File

@ -95,7 +95,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1

View File

@ -95,7 +95,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1

View File

@ -95,7 +95,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1

View File

@ -106,12 +106,6 @@ from typing import List
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (
create_lerobot_dataset,
delete_current_episode,
init_dataset,
save_current_episode,
)
from lerobot.common.robot_devices.control_utils import (
control_loop,
has_method,
@ -121,6 +115,7 @@ from lerobot.common.robot_devices.control_utils import (
record_episode,
reset_environment,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
stop_recording,
warmup_record,
)
@ -198,23 +193,25 @@ def record(
robot: Robot,
root: str,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
warmup_time_s=2,
episode_time_s=10,
reset_time_s=5,
num_episodes=50,
video=True,
run_compute_stats=True,
push_to_hub=True,
tags=None,
num_image_writer_processes=0,
num_image_writer_threads_per_camera=4,
force_override=False,
display_cameras=True,
play_sounds=True,
):
warmup_time_s: int | float = 2,
episode_time_s: int | float = 10,
reset_time_s: int | float = 5,
num_episodes: int = 50,
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
local_files_only: bool = False,
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
listener = None
events = None
@ -222,6 +219,11 @@ def record(
device = None
use_amp = None
if single_task:
task = single_task
else:
raise NotImplementedError("Only single-task recording is supported for now")
# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@ -234,18 +236,29 @@ def record(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
)
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = init_dataset(
repo_id,
root,
force_override,
fps,
video,
write_images=robot.has_camera,
num_image_writer_processes=num_image_writer_processes,
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
if resume:
dataset = LeRobotDataset(
repo_id,
root=root,
local_files_only=local_files_only,
)
dataset.start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
robot=robot,
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
if not robot.is_connected:
robot.connect()
@ -263,12 +276,17 @@ def record(
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
recorded_episodes = 0
while True:
if dataset["num_episodes"] >= num_episodes:
if recorded_episodes >= num_episodes:
break
episode_index = dataset["num_episodes"]
log_say(f"Recording episode {episode_index}", play_sounds)
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
# input() messes with them.
# if multi_task:
# task = input("Enter your task description: ")
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
record_episode(
dataset=dataset,
robot=robot,
@ -286,7 +304,7 @@ def record(
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(episode_index < num_episodes - 1) or events["rerecord_episode"]
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
@ -295,11 +313,11 @@ def record(
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
delete_current_episode(dataset)
dataset.clear_episode_buffer()
continue
# Increment by one dataset["current_episode_index"]
save_current_episode(dataset)
dataset.save_episode(task)
recorded_episodes += 1
if events["stop_recording"]:
break
@ -307,35 +325,42 @@ def record(
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
if run_compute_stats:
logging.info("Computing dataset statistics")
dataset.consolidate(run_compute_stats)
if push_to_hub:
dataset.push_to_hub()
log_say("Exiting", play_sounds)
return lerobot_dataset
return dataset
@safe_disconnect
def replay(
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
robot: Robot,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = True,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action")
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
for idx in range(from_idx, to_idx):
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = items[idx]["action"]
action = actions[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
@ -384,9 +409,21 @@ if __name__ == "__main__":
)
parser_record = subparsers.add_parser("record", parents=[base_parser])
task_args = parser_record.add_mutually_exclusive_group(required=True)
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the task performed during the recording.",
)
# TODO(aliberts): add multi-task support
# task_args.add_argument(
# "--multi-task",
# type=int,
# help="You will need to enter the task performed at the start of each episode.",
# )
parser_record.add_argument(
"--root",
type=Path,

View File

@ -484,7 +484,7 @@ def main(
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
assert isinstance(policy, nn.Module)
policy.eval()

View File

@ -1,30 +1,36 @@
import os
import time
from pathlib import Path
from serial.tools import list_ports # Part of pyserial library
def find_available_ports():
ports = []
for path in Path("/dev").glob("tty*"):
ports.append(str(path))
if os.name == "nt": # Windows
# List COM ports using pyserial
ports = [port.device for port in list_ports.comports()]
else: # Linux/macOS
# List /dev/tty* ports for Unix-based systems
ports = [str(path) for path in Path("/dev").glob("tty*")]
return ports
def find_port():
print("Finding all available ports for the MotorsBus.")
ports_before = find_available_ports()
print(ports_before)
print("Ports before disconnecting:", ports_before)
print("Remove the usb cable from your MotorsBus and press Enter when done.")
input()
print("Remove the USB cable from your MotorsBus and press Enter when done.")
input() # Wait for user to disconnect the device
time.sleep(0.5)
time.sleep(0.5) # Allow some time for port to be released
ports_after = find_available_ports()
ports_diff = list(set(ports_before) - set(ports_after))
if len(ports_diff) == 1:
port = ports_diff[0]
print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the usb cable.")
print("Reconnect the USB cable.")
elif len(ports_diff) == 0:
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
else:
@ -32,5 +38,5 @@ def find_port():
if __name__ == "__main__":
# Helper to find the usb port associated to all your MotorsBus.
# Helper to find the USB port associated with your MotorsBus.
find_port()

View File

@ -117,10 +117,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
def push_dataset_card_to_hub(
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
repo_id: str,
revision: str | None,
tags: list | None = None,
license: str = "apache-2.0",
**card_kwargs,
):
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
card = create_lerobot_dataset_card(tags=tags, text=text)
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
@ -260,7 +264,7 @@ def push_dataset_to_hub(
episode_index = 0
tests_videos_dir = tests_data_dir / repo_id / "videos"
tests_videos_dir.mkdir(parents=True, exist_ok=True)
for key in lerobot_dataset.video_frame_keys:
for key in lerobot_dataset.camera_keys:
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)

View File

@ -171,9 +171,9 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
@ -208,9 +208,9 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_samples
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
@ -328,7 +328,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_policy")
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats if not cfg.resume else None,
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
assert isinstance(policy, nn.Module)
@ -349,7 +349,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@ -573,7 +573,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.training.online_sampling_ratio,
)
sampler.num_samples = len(concat_dataset)
sampler.num_frames = len(concat_dataset)
update_online_buffer_s = time.perf_counter() - start_update_buffer_time

View File

@ -100,7 +100,7 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
def visualize_dataset(
repo_id: str,
dataset: LeRobotDataset,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
@ -108,7 +108,6 @@ def visualize_dataset(
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
root: Path | None = None,
output_dir: Path | None = None,
) -> Path | None:
if save:
@ -116,8 +115,7 @@ def visualize_dataset(
output_dir is not None
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root)
repo_id = dataset.repo_id
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
@ -153,7 +151,7 @@ def visualize_dataset(
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
for key in dataset.camera_keys:
for key in dataset.meta.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
@ -268,7 +266,14 @@ def main():
)
args = parser.parse_args()
visualize_dataset(**vars(args))
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
visualize_dataset(dataset, **vars(args))
if __name__ == "__main__":

View File

@ -93,18 +93,17 @@ def run_server(
def show_episode(dataset_namespace, dataset_name, episode_id):
dataset_info = {
"repo_id": dataset.repo_id,
"num_samples": dataset.num_samples,
"num_samples": dataset.num_frames,
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
}
video_paths = get_episode_video_paths(dataset, episode_id)
language_instruction = get_episode_language_instruction(dataset, episode_id)
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
tasks = dataset.meta.episodes[episode_id]["tasks"]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
]
if language_instruction:
videos_info[0]["language_instruction"] = language_instruction
videos_info[0]["language_instruction"] = tasks
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
return render_template(
@ -131,16 +130,16 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features
has_state = "observation.state" in dataset.features
has_action = "action" in dataset.features
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0])
dim_state = dataset.meta.shapes["observation.state"][0]
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = len(dataset.hf_dataset["action"][0])
dim_action = dataset.meta.shapes["action"][0]
header += [f"action_{i}" for i in range(dim_action)]
columns = ["timestamp"]
@ -172,27 +171,12 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.video_frame_keys
for key in dataset.meta.video_keys
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.hf_dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def visualize_dataset_html(
repo_id: str,
root: Path | None = None,
dataset: LeRobotDataset,
episodes: list[int] = None,
output_dir: Path | None = None,
serve: bool = True,
@ -202,13 +186,11 @@ def visualize_dataset_html(
) -> Path | None:
init_logging()
dataset = LeRobotDataset(repo_id, root=root)
if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
if len(dataset.meta.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
output_dir = Path(output_dir)
if output_dir.exists():
@ -225,7 +207,7 @@ def visualize_dataset_html(
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
template_dir = Path(__file__).resolve().parent.parent / "templates"
@ -297,7 +279,11 @@ def main():
)
args = parser.parse_args()
visualize_dataset_html(**vars(args))
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
visualize_dataset_html(dataset, **kwargs)
if __name__ == "__main__":

View File

@ -157,7 +157,7 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
output_dir.mkdir(parents=True, exist_ok=True)
# Get 1st frame from 1st camera of 1st episode
original_frame = dataset[0][dataset.camera_keys[0]]
original_frame = dataset[0][dataset.meta.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")

View File

@ -35,7 +35,7 @@
<ul>
<li>
Number of samples/frames: {{ dataset_info.num_samples }}
Number of samples/frames: {{ dataset_info.num_frames }}
</li>
<li>
Number of episodes: {{ dataset_info.num_episodes }}

1305
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -43,9 +43,8 @@ opencv-python = ">=4.9.0"
diffusers = ">=0.27.2"
torchvision = ">=0.17.1"
h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.0"}
# TODO(rcadene, aliberts): Make gym 1.0.0 work
gymnasium = "==0.29.1"
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"}
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work
cmake = ">=3.29.0.1"
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
gym-pusht = { version = ">=0.1.5", optional = true}
@ -72,6 +71,7 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
reachy2-sdk = {git = "https://github.com/pollen-robotics/reachy2-sdk", branch="450-opencv-dependency-version", optional = true}
jsonlines = ">=4.0.0"
[tool.poetry.extras]

View File

@ -23,6 +23,13 @@ from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
# Import fixture modules as plugins
pytest_plugins = [
"tests.fixtures.dataset_factories",
"tests.fixtures.files",
"tests.fixtures.hub",
]
def pytest_collection_finish():
print(f"\nTesting with {DEVICE=}")

29
tests/fixtures/constants.py vendored Normal file
View File

@ -0,0 +1,29 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {
"video.fps": DEFAULT_FPS,
"video.codec": "av1",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"has_audio": False,
}

396
tests/fixtures/dataset_factories.py vendored Normal file
View File

@ -0,0 +1,396 @@
import random
from pathlib import Path
from unittest.mock import patch
import datasets
import numpy as np
import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE,
DUMMY_VIDEO_INFO,
)
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task]
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
else:
raise ValueError(dtype)
return img_array
return _create_img_array
@pytest.fixture(scope="session")
def img_factory(img_array_factory):
def _create_img(height=100, width=100) -> PIL.Image.Image:
img_array = img_array_factory(height=height, width=width)
return PIL.Image.fromarray(img_array)
return _create_img
@pytest.fixture(scope="session")
def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
return {
**motor_features,
**camera_ft,
**DEFAULT_FEATURES,
}
return _create_features
@pytest.fixture(scope="session")
def info_factory(features_factory):
def _create_info(
codebase_version: str = CODEBASE_VERSION,
fps: int = DEFAULT_FPS,
robot_type: str = DUMMY_ROBOT_TYPE,
total_episodes: int = 0,
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
features = features_factory(motor_features, camera_features, use_videos)
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": chunks_size,
"fps": fps,
"splits": {},
"data_path": data_path,
"video_path": video_path if use_videos else None,
"features": features,
}
return _create_info
@pytest.fixture(scope="session")
def stats_factory():
def _create_stats(
features: dict[str] | None = None,
) -> dict:
stats = {}
for key, ft in features.items():
shape = ft["shape"]
dtype = ft["dtype"]
if dtype in ["image", "video"]:
stats[key] = {
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
}
else:
stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(),
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
"min": np.full(shape, 0, dtype=dtype).tolist(),
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
}
return stats
return _create_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks_list = []
for i in range(total_tasks):
task_dict = {"task_index": i, "task": f"Perform action {i}."}
tasks_list.append(task_dict)
return tasks_list
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def _create_episodes(
total_episodes: int = 3,
total_frames: int = 400,
tasks: dict | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
raise ValueError("num_episodes and total_length must be positive integers.")
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not tasks:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks]
num_tasks_available = len(tasks_list)
episodes_list = []
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
episodes_list.append(
{
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
)
return episodes_list
return _create_episodes
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
tasks = tasks_factory()
if not episodes:
episodes = episodes_factory()
if not features:
features = features_factory()
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
index_col = np.arange(len(episode_index_col))
robot_cols = {}
for key, ft in features.items():
if ft["dtype"] == "image":
robot_cols[key] = [
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
{
**robot_cols,
"timestamp": timestamp_col,
"frame_index": frame_index_col,
"episode_index": episode_index_col,
"index": index_col,
"task_index": task_index,
},
features=hf_features,
)
dataset.set_transform(hf_transform_to_torch)
return dataset
return _create_hf_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset_metadata(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
local_files_only: bool = False,
) -> LeRobotDatasetMetadata:
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
tasks=tasks,
episodes=episodes,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
total_episodes: int = 3,
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
episode_dicts = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
return _create_lerobot_dataset

114
tests/fixtures/files.py vendored Normal file
View File

@ -0,0 +1,114 @@
import json
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
@pytest.fixture(scope="session")
def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
if not info:
info = info_factory()
fpath = dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
return _create_info_json_file
@pytest.fixture(scope="session")
def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
if not stats:
stats = stats_factory()
fpath = dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
return _create_stats_json_file
@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
if not tasks:
tasks = tasks_factory()
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks)
return fpath
return _create_tasks_jsonl_file
@pytest.fixture(scope="session")
def episode_path(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
if not episodes:
episodes = episodes_factory()
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes)
return fpath
return _create_episodes_jsonl_file
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return fpath
return _create_single_episode_parquet
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
total_episodes = info["total_episodes"]
for ep_idx in range(total_episodes):
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return dir / "data"
return _create_multi_episode_parquet

105
tests/fixtures/hub.py vendored Normal file
View File

@ -0,0 +1,105 @@
from pathlib import Path
import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info_factory,
info_path,
stats_factory,
stats_path,
tasks_factory,
tasks_path,
episodes_factory,
episode_path,
single_episode_parquet_path,
hf_dataset_factory,
):
"""
This factory allows to patch snapshot_download such that when called, it will create expected files rather
than making calls to the hub api. Its design allows to pass explicitly files which you want to be created.
"""
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
return episode_index
else:
return None
def _mock_snapshot_download(
repo_id: str,
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
# Create allowed files
for rel_path in allowed_files:
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes)
else:
pass
return str(local_dir)
return _mock_snapshot_download
return _mock_snapshot_download_func

View File

@ -76,7 +76,7 @@ def main():
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.camera_keys[0]]
original_frame = dataset[0][dataset.meta.camera_keys[0]]
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)

View File

@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
)
set_global_seed(1337)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)

View File

@ -29,7 +29,6 @@ from unittest.mock import patch
import pytest
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config
@ -93,8 +92,9 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
mock_calibration_dir(calibration_dir)
overrides.append(f"calibration_dir={calibration_dir}")
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock)
record(
@ -102,6 +102,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
fps=30,
root=root,
repo_id=repo_id,
single_task=single_task,
warmup_time_s=1,
episode_time_s=1,
num_episodes=2,
@ -132,17 +133,18 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
env_name = "koch_real"
policy_name = "act_koch_real"
root = tmpdir / "data"
repo_id = "lerobot/debug"
eval_repo_id = "lerobot/eval_debug"
root = tmpdir / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock)
dataset = record(
robot,
root,
repo_id,
single_task,
fps=1,
warmup_time_s=1,
warmup_time_s=0.5,
episode_time_s=1,
reset_time_s=1,
num_episodes=2,
@ -153,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
)
assert dataset.num_episodes == 2
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
@ -191,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
overrides=overrides,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
out_dir = tmpdir / "logger"
logger = Logger(cfg, out_dir, wandb_job_name="debug")
@ -225,10 +227,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
else:
num_image_writer_processes = 0
record(
eval_repo_id = "lerobot/eval_debug"
eval_root = tmpdir / "data" / eval_repo_id
dataset = record(
robot,
root,
eval_root,
eval_repo_id,
single_task,
pretrained_policy_name_or_path,
warmup_time_s=1,
episode_time_s=1,
@ -265,51 +271,48 @@ def test_resume_record(tmpdir, request, robot_type, mock):
robot = make_robot(robot_type, overrides=overrides, mock=mock)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
record_kwargs = {
"robot": robot,
"root": root,
"repo_id": repo_id,
"single_task": single_task,
"fps": 1,
"warmup_time_s": 0,
"episode_time_s": 1,
"push_to_hub": False,
"video": False,
"display_cameras": False,
"play_sounds": False,
"run_compute_stats": False,
"local_files_only": True,
"num_episodes": 1,
}
init_dataset_return_value = {}
dataset = record(**record_kwargs)
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
def wrapped_init_dataset(*args, **kwargs):
nonlocal init_dataset_return_value
init_dataset_return_value = init_dataset(*args, **kwargs)
return init_dataset_return_value
# init_dataset_return_value = {}
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
assert (
init_dataset_return_value["num_episodes"] == 2
), "`init_dataset` should load the previous episode"
# def wrapped_init_dataset(*args, **kwargs):
# nonlocal init_dataset_return_value
# init_dataset_return_value = init_dataset(*args, **kwargs)
# return init_dataset_return_value
# with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
with pytest.raises(FileExistsError):
# Dataset already exists, but resume=False by default
record(**record_kwargs)
dataset = record(**record_kwargs, resume=True)
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
# assert (
# init_dataset_return_value["num_episodes"] == 2
# ), "`init_dataset` should load the previous episode"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@ -328,23 +331,22 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = True
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
single_task,
fps=1,
warmup_time_s=0,
episode_time_s=1,
@ -358,7 +360,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -378,23 +379,22 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
fps=2,
root=root,
single_task=single_task,
repo_id=repo_id,
warmup_time_s=0,
episode_time_s=1,
@ -407,7 +407,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -429,23 +428,22 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = True
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
single_task=single_task,
fps=1,
warmup_time_s=0,
episode_time_s=1,
@ -459,5 +457,4 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"

View File

@ -33,18 +33,72 @@ from lerobot.common.datasets.compute_stats import (
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
MultiLeRobotDataset,
)
from lerobot.common.datasets.utils import (
create_branch,
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
objects have the same sets of attributes defined.
"""
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
# Access the '_hub_version' cached_property in both instances to force its creation
_ = dataset_init.meta._hub_version
_ = dataset_create.meta._hub_version
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
assert init_attr == create_attr
def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
kwargs = {
"repo_id": DUMMY_REPO_ID,
"total_episodes": 10,
"total_frames": 400,
"episodes": [2, 5, 6],
}
dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
assert dataset.repo_id == kwargs["repo_id"]
assert dataset.meta.total_episodes == kwargs["total_episodes"]
assert dataset.meta.total_frames == kwargs["total_frames"]
assert dataset.episodes == kwargs["episodes"]
assert dataset.num_episodes == len(kwargs["episodes"])
assert dataset.num_frames == len(dataset)
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_frame
# - [ ] test add_episode
# - [ ] test consolidate
# - [ ] test push_to_hub
# - [ ] test smaller methods
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
lerobot.env_dataset_policy_triplets
@ -67,7 +121,7 @@ def test_factory(env_name, repo_id, policy_name):
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.camera_keys
camera_keys = dataset.meta.camera_keys
item = dataset[0]
@ -117,6 +171,7 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_multilerobotdataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
@ -130,7 +185,7 @@ def test_multilerobotdataset_frames():
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
@ -149,6 +204,8 @@ def test_multilerobotdataset_frames():
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
@ -197,7 +254,7 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats # noqa: F841
loaded_stats = dataset.meta.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats
@ -208,72 +265,7 @@ def test_compute_stats_on_xarm():
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
def test_load_previous_and_future_frames_within_tolerance():
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
item = hf_dataset[2]
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
d = {
"obs": {
@ -297,6 +289,7 @@ def test_flatten_unflatten_dict():
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"repo_id",
[
@ -368,6 +361,7 @@ def test_backward_compatibility(repo_id):
# load_and_compare(i - 1)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):

View File

@ -0,0 +1,256 @@
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
hf_transform_to_torch,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
@pytest.fixture(scope="module")
def synced_hf_dataset_factory(hf_dataset_factory):
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
return hf_dataset_factory(fps=fps)
return _create_synced_hf_dataset
@pytest.fixture(scope="module")
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just outside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
return _create_unsynced_hf_dataset
@pytest.fixture(scope="module")
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just inside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
return _create_slightly_off_hf_dataset
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
return delta_timestamps
return _create_valid_delta_timestamps
@pytest.fixture(scope="module")
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_invalid_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
) -> dict:
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just outside tolerance
for key in keys:
delta_timestamps[key][3] += tolerance_s * 1.1
return delta_timestamps
return _create_invalid_delta_timestamps
@pytest.fixture(scope="module")
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_slightly_off_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
) -> dict:
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just inside tolerance
for key in delta_timestamps:
delta_timestamps[key][3] += tolerance_s * 0.9
delta_timestamps[key][-3] += tolerance_s * 0.9
return delta_timestamps
return _create_slightly_off_delta_timestamps
@pytest.fixture(scope="module")
def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
return {key: list(range(-10, 10)) for key in keys}
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
synced_hf_dataset = synced_hf_dataset_factory(fps)
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
result = check_timestamps_sync(
hf_dataset=synced_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
with pytest.raises(ValueError):
check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
result = check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
)
assert result is False
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
result = check_timestamps_sync(
hf_dataset=slightly_off_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_single_timestamp():
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
fps = 30
tolerance_s = 1e-4
result = check_timestamps_sync(
hf_dataset=single_timestamp_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset
@pytest.mark.skip("TODO: fix")
def test_check_timestamps_sync_empty_dataset():
fps = 30
tolerance_s = 1e-4
empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []})
empty_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"to": torch.tensor([], dtype=torch.int64),
"from": torch.tensor([], dtype=torch.int64),
}
result = check_timestamps_sync(
hf_dataset=empty_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
valid_delta_timestamps = valid_delta_timestamps_factory(fps)
result = check_delta_timestamps(
delta_timestamps=valid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
result = check_delta_timestamps(
delta_timestamps=slightly_off_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
with pytest.raises(ValueError):
check_delta_timestamps(
delta_timestamps=invalid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
result = check_delta_timestamps(
delta_timestamps=invalid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
)
assert result is False
def test_check_delta_timestamps_empty():
delta_timestamps = {}
fps = 30
tolerance_s = 1e-4
result = check_delta_timestamps(
delta_timestamps=delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_delta_indices(valid_delta_timestamps_factory, delta_indices):
fps = 30
delta_timestamps = valid_delta_timestamps_factory(fps)
expected_delta_indices = delta_indices
actual_delta_indices = get_delta_indices(delta_timestamps, fps)
assert expected_delta_indices == actual_delta_indices

View File

@ -13,12 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import io
import subprocess
import sys
from pathlib import Path
import pytest
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.utils import require_package
@ -29,6 +32,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
def _run_script(path):
subprocess.run([sys.executable, path], check=True)
@ -38,12 +42,26 @@ def _read_file(path):
return file.read()
def test_example_1():
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
def test_example_1(tmp_path, lerobot_dataset_factory):
_ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID)
path = "examples/1_load_lerobot_dataset.py"
_run_script(path)
file_contents = _read_file(path)
file_contents = _find_and_replace(
file_contents,
[
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
(
"LeRobotDataset(repo_id",
f"LeRobotDataset(repo_id, root='{str(tmp_path)}', local_files_only=True",
),
],
)
exec(file_contents, {})
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
@require_package("gym_pusht")
def test_examples_basic2_basic3_advanced1():
"""
@ -111,7 +129,8 @@ def test_examples_basic2_basic3_advanced1():
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"),
("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"),

View File

@ -15,15 +15,12 @@
# limitations under the License.
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
@ -33,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
@pytest.fixture
def img():
dataset = LeRobotDataset(DATASET_REPO_ID)
return dataset[0][dataset.camera_keys[0]]
@pytest.fixture
def img_random():
return torch.rand(3, 480, 640)
@pytest.fixture
def color_jitters():
return [
@ -67,47 +49,54 @@ def default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
def test_get_image_transforms_no_transform(img):
def test_get_image_transforms_no_transform(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
torch.testing.assert_close(tf_actual(img), img)
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
tf_expected = v2.ColorJitter(brightness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
tf_expected = v2.ColorJitter(contrast=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
tf_expected = v2.ColorJitter(saturation=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
tf_expected = v2.ColorJitter(hue=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
tf_expected = SharpnessJitter(sharpness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
def test_get_image_transforms_max_num_transforms(img):
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
@ -125,12 +114,13 @@ def test_get_image_transforms_max_num_transforms(img):
SharpnessJitter(sharpness=(0.5, 0.5)),
]
)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
def test_get_image_transforms_random_order(img_tensor_factory):
out_imgs = []
img_tensor = img_tensor_factory()
tf = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
@ -141,13 +131,14 @@ def test_get_image_transforms_random_order(img):
)
with seeded_context(1337):
for _ in range(10):
out_imgs.append(tf(img))
out_imgs.append(tf(img_tensor))
for i in range(1, len(out_imgs)):
with pytest.raises(AssertionError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"transform, min_max_values",
[
@ -158,21 +149,24 @@ def test_get_image_transforms_random_order(img):
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
img_tensor = img_tensor_factory()
for min_max in min_max_values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
actual = tf(img)
actual = tf(img_tensor)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
expected = single_transforms[key]
torch.testing.assert_close(actual, expected)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
img_tensor = img_tensor_factory()
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
@ -191,7 +185,7 @@ def test_backward_compatibility_default_config(img, default_transforms):
)
with seeded_context(1337):
actual = default_tf(img)
actual = default_tf(img_tensor)
expected = default_transforms["default"]
@ -199,33 +193,36 @@ def test_backward_compatibility_default_config(img, default_transforms):
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
def test_random_subset_apply_single_choice(img_tensor_factory, p):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
actual = random_choice(img)
actual = random_choice(img_tensor)
p_horz, _ = p
if p_horz:
torch.testing.assert_close(actual, F.horizontal_flip(img))
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
else:
torch.testing.assert_close(actual, F.vertical_flip(img))
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
def test_random_subset_apply_random_order(img):
def test_random_subset_apply_random_order(img_tensor_factory):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# We can't really check whether the transforms are actually applied in random order. However,
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# applies them in random order, we can use a fixed order to compute the expected value.
actual = random_order(img)
expected = v2.Compose(flips)(img)
actual = random_order(img_tensor)
expected = v2.Compose(flips)(img_tensor)
torch.testing.assert_close(actual, expected)
def test_random_subset_apply_valid_transforms(color_jitters, img):
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
img_tensor = img_tensor_factory()
transform = RandomSubsetApply(color_jitters)
output = transform(img)
assert output.shape == img.shape
output = transform(img_tensor)
assert output.shape == img_tensor.shape
def test_random_subset_apply_probability_length_mismatch(color_jitters):
@ -239,16 +236,18 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
RandomSubsetApply(color_jitters, n_subset=n_subset)
def test_sharpness_jitter_valid_range_tuple(img):
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter((0.1, 2.0))
output = tf(img)
assert output.shape == img.shape
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_valid_range_float(img):
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter(0.5)
output = tf(img)
assert output.shape == img.shape
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_invalid_range_min_negative():
@ -261,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1))
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"repo_id, n_examples",
[

359
tests/test_image_writer.py Normal file
View File

@ -0,0 +1,359 @@
import queue
import time
from multiprocessing import queues
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from PIL import Image
from lerobot.common.datasets.image_writer import (
AsyncImageWriter,
image_array_to_image,
safe_stop_image_writer,
write_image,
)
DUMMY_IMAGE = "test_image.png"
def test_init_threading():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
try:
assert writer.num_processes == 0
assert writer.num_threads == 2
assert isinstance(writer.queue, queue.Queue)
assert len(writer.threads) == 2
assert len(writer.processes) == 0
assert all(t.is_alive() for t in writer.threads)
finally:
writer.stop()
def test_init_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
assert writer.num_processes == 2
assert writer.num_threads == 2
assert isinstance(writer.queue, queues.JoinableQueue)
assert len(writer.threads) == 0
assert len(writer.processes) == 2
assert all(p.is_alive() for p in writer.processes)
finally:
writer.stop()
def test_zero_threads():
with pytest.raises(ValueError):
AsyncImageWriter(num_processes=0, num_threads=0)
def test_image_array_to_image_rgb(img_array_factory):
img_array = img_array_factory(100, 100)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
def test_image_array_to_image_pytorch_format(img_array_factory):
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
@pytest.mark.skip("TODO: implement")
def test_image_array_to_image_single_channel(img_array_factory):
img_array = img_array_factory(channels=1)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "L"
def test_image_array_to_image_float_array(img_array_factory):
img_array = img_array_factory(dtype=np.float32)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
def test_image_array_to_image_out_of_bounds_float():
# Float array with values out of [0, 1]
img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
def test_write_image_numpy(tmp_path, img_array_factory):
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
write_image(image_array, fpath)
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
def test_write_image_image(tmp_path, img_factory):
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
write_image(image_pil, fpath)
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
assert np.array_equal(image_pil, saved_image)
def test_write_image_exception(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
with patch("builtins.print") as mock_print:
write_image(image_array, fpath)
mock_print.assert_called()
assert not fpath.exists()
def test_save_image_numpy(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
finally:
writer.stop()
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_array, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
finally:
writer.stop()
def test_save_image_torch(tmp_path, img_tensor_factory):
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_tensor, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_tensor, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
def test_save_image_pil(tmp_path, img_factory):
writer = AsyncImageWriter()
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_pil, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
finally:
writer.stop()
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_pil, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
finally:
writer.stop()
def test_save_image_invalid_data(tmp_path):
writer = AsyncImageWriter()
try:
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
with patch("builtins.print") as mock_print:
writer.save_image(image_array, fpath)
writer.wait_until_done()
mock_print.assert_called()
assert not fpath.exists()
finally:
writer.stop()
def test_save_image_after_stop(tmp_path, img_array_factory):
writer = AsyncImageWriter()
writer.stop()
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_array, fpath)
time.sleep(1)
assert not fpath.exists()
def test_stop():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
writer.stop()
assert not any(t.is_alive() for t in writer.threads)
def test_stop_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
assert not any(p.is_alive() for p in writer.processes)
def test_multiple_stops():
writer = AsyncImageWriter()
writer.stop()
writer.stop() # Should not raise an exception
assert not any(t.is_alive() for t in writer.threads)
def test_multiple_stops_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
writer.stop() # Should not raise an exception
assert not any(t.is_alive() for t in writer.threads)
def test_wait_until_done(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
num_images = 100
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
for i, fpath in enumerate(fpaths):
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(saved_image, image_arrays[i])
finally:
writer.stop()
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
num_images = 100
image_arrays = [img_array_factory() for _ in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
for i, fpath in enumerate(fpaths):
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(saved_image, image_arrays[i])
finally:
writer.stop()
def test_exception_handling(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
with (
patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
pytest.raises(queue.Full) as exc_info,
):
writer.save_image(image_array, tmp_path / "test.png")
assert str(exc_info.value) == "Queue is full"
finally:
writer.stop()
def test_with_different_image_formats(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
formats = ["png", "jpeg", "bmp"]
for fmt in formats:
fpath = tmp_path / f"test_image.{fmt}"
write_image(image_array, fpath)
assert fpath.exists()
finally:
writer.stop()
def test_safe_stop_image_writer_decorator():
class MockDataset:
def __init__(self):
self.image_writer = MagicMock(spec=AsyncImageWriter)
@safe_stop_image_writer
def function_that_raises_exception(dataset=None):
raise Exception("Test exception")
dataset = MockDataset()
with pytest.raises(Exception) as exc_info:
function_that_raises_exception(dataset=dataset)
assert str(exc_info.value) == "Test exception"
dataset.image_writer.stop.assert_called_once()
def test_main_process_time(tmp_path, img_tensor_factory):
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
start_time = time.perf_counter()
writer.save_image(image_tensor, fpath)
end_time = time.perf_counter()
time_spent = end_time - start_time
# Might need to adjust this threshold depending on hardware
assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
writer.wait_until_done()
assert fpath.exists()
finally:
writer.stop()

View File

@ -19,11 +19,8 @@ from uuid import uuid4
import numpy as np
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.utils import hf_transform_to_torch
# Some constants for OnlineBuffer tests.
data_key = "data"
@ -212,29 +209,17 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [0, 6])
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float
lerobot_dataset_factory,
tmp_path,
offline_dataset_size: int,
online_dataset_size: int,
online_sampling_ratio: float,
):
# Pass/skip the test if both datasets sizes are zero.
if offline_dataset_size + online_dataset_size == 0:
return
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(
hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))})
)
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
if offline_dataset_size == 0:
offline_dataset.episode_data_index = {}
else:
# Set up an episode_data_index with at least two episodes.
offline_dataset.episode_data_index = {
"from": torch.tensor([0, offline_dataset_size // 2]),
"to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]),
}
# Create spoof online datset.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
@ -254,16 +239,9 @@ def test_compute_sampler_weights_trivial(
assert torch.allclose(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio():
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {
"from": torch.tensor([0, 2]),
"to": torch.tensor([2, 4]),
}
# Create spoof online datset.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_sampling_ratio = 0.8
@ -275,16 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio():
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([4]),
}
# Create spoof online datset.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights(
@ -295,18 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
)
def test_compute_sampler_weights_drop_n_last_frames():
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
"""Note: test copied from test_sampler."""
data_dict = {
"timestamp": [0, 0.1],
"index": [0, 1],
"episode_index": [0, 0],
"frame_index": [0, 1],
}
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])}
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))

View File

@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
env.step(action)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
@ -213,7 +214,7 @@ def test_act_backbone_lr():
assert cfg.training.lr_backbone == 0.001
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
@ -351,6 +352,7 @@ def test_normalize(insert_temporal_dim):
unnormalize(output_batch)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides, file_name_extra",
[

View File

@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id, make_test_data",
[

View File

@ -15,9 +15,9 @@
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)

View File

@ -7,10 +7,9 @@ import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
reset_episode_index,
)
from lerobot.common.utils.utils import (
get_global_random_state,
@ -73,20 +72,6 @@ def test_calculate_episode_data_index():
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
def test_reset_episode_index():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [10, 10, 11, 12, 12, 12],
},
)
dataset.set_transform(hf_transform_to_torch)
correct_episode_index = [0, 0, 1, 2, 2, 2]
dataset = reset_episode_index(dataset)
assert dataset["episode_index"] == correct_episode_index
def test_init_hydra_config_empty():
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
with open(test_file, "w") as f:

View File

@ -13,25 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
@pytest.mark.skip("TODO: add dummy videos")
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
root = tmp_path / "dataset"
output_dir = tmp_path / "outputs"
dataset = lerobot_dataset_factory(root=root)
rrd_path = visualize_dataset(
repo_id,
dataset,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
root=root,
output_dir=output_dir,
)
assert rrd_path.exists()

View File

@ -14,23 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
def test_visualize_dataset_html(tmpdir, repo_id):
tmpdir = Path(tmpdir)
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
root = tmp_path / "dataset"
output_dir = tmp_path / "outputs"
dataset = lerobot_dataset_factory(root=root)
visualize_dataset_html(
repo_id,
dataset,
episodes=[0],
output_dir=tmpdir,
output_dir=output_dir,
serve=False,
)
assert (tmpdir / "static" / "episode_0.csv").exists()
assert (output_dir / "static" / "episode_0.csv").exists()