Add LeRobotDatasetMetadata

This commit is contained in:
Simon Alibert 2024-11-03 18:07:37 +01:00
parent ac79e8cb36
commit e4ba084e25
25 changed files with 419 additions and 327 deletions

View File

@ -266,7 +266,7 @@ def benchmark_encoding_decoding(
)
ep_num_images = dataset.episode_data_index["to"][0].item()
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
images_size_bytes = get_directory_size(imgs_dir)

View File

@ -13,6 +13,7 @@ Features included in this script:
The script ends with examples of how to batch process data using PyTorch's DataLoader.
"""
# TODO(aliberts, rcadene): Update this script with the new v2 api
from pathlib import Path
from pprint import pprint
@ -31,7 +32,7 @@ repo_id = "lerobot/pusht"
# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets/index for more information).
print(dataset)
print(dataset.hf_dataset)
@ -39,7 +40,7 @@ print(dataset.hf_dataset)
# And provides additional utilities for robotics and compatibility with PyTorch
print(f"\naverage number of frames per episode: {dataset.num_frames / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
print(f"keys to access images from cameras: {dataset.meta.camera_keys=}\n")
# Access frame indexes associated to first episode
episode_index = 0
@ -60,14 +61,15 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
# differences relative to the currently loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}

View File

@ -40,7 +40,7 @@ dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
policy.to(device)

View File

@ -20,7 +20,7 @@ dataset = LeRobotDataset(dataset_repo_id)
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
# Define the transformations
@ -36,7 +36,7 @@ transforms = v2.Compose(
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")

View File

@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
on the target environment, whether that be in simulation or the real world.
"""
# TODO(aliberts, rcadene): Update this script with the new v2 api
import math
from pathlib import Path

View File

@ -42,7 +42,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
assert batch[key].dtype != torch.float64
# if isinstance(feats_type, (VideoFrame, Image)):
if key in dataset.camera_keys:
if key in dataset.meta.camera_keys:
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"

View File

@ -111,6 +111,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset

View File

@ -45,7 +45,7 @@ from lerobot.common.datasets.utils import (
get_episode_data_index,
get_hub_safe_version,
hf_transform_to_torch,
load_episode_dicts,
load_episodes,
load_info,
load_stats,
load_tasks,
@ -66,6 +66,237 @@ CODEBASE_VERSION = "v2.0"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
class LeRobotDatasetMetadata:
def __init__(
self,
repo_id: str,
root: Path | None = None,
local_files_only: bool = False,
):
self.repo_id = repo_id
self.root = root if root is not None else LEROBOT_HOME / repo_id
self.local_files_only = local_files_only
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
self.episodes = load_episodes(self.root)
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._hub_version,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
@property
def _version(self) -> str:
"""Codebase version used to create this dataset."""
return self.info["codebase_version"]
def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
return ep_index // self.chunks_size
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def videos_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.)."""
return self.info["keys"]
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return self.info["image_keys"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return self.info["video_keys"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return self.image_keys + self.video_keys
@property
def names(self) -> dict[list[str]]:
"""Names of the various dimensions of vector modalities."""
return self.info["names"]
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
@property
def total_chunks(self) -> int:
"""Total number of chunks (groups of episodes)."""
return self.info["total_chunks"]
@property
def chunks_size(self) -> int:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return self.info["shapes"]
@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
def get_task_index(self, task: str) -> int:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise creates a new task_index.
"""
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
def add_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / INFO_PATH)
episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"length": episode_length,
}
self.episodes.append(episode_dict)
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
def write_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
for key in self.video_keys:
if key not in self.info["videos"]:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["videos"][key] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
@classmethod
def create(
cls,
repo_id: str,
fps: int,
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
keys: list[str] | None = None,
image_keys: list[str] | None = None,
video_keys: list[str] = None,
shapes: dict | None = None,
names: dict | None = None,
use_videos: bool = True,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.image_writer = None
if robot is not None:
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
elif (
robot_type is None
or keys is None
or image_keys is None
or video_keys is None
or shapes is None
or names is None
):
raise ValueError(
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
)
if len(video_keys) > 0 and not use_videos:
raise ValueError()
obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
)
write_json(obj.info, obj.root / INFO_PATH)
obj.local_files_only = True
return obj
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
@ -86,9 +317,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
- On your local disk in the 'root' folder. This is typically the case when you recorded your
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
internet connection).
internet connection), in that case, use local_files_only=True.
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and is not on
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
the dataset from that address and load it, pending your dataset is compliant with
codebase_version v2.0. If your dataset has been created before this new format, you will be
@ -96,9 +327,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
2. Your dataset doesn't already exists (either on local disk or on the Hub):
You can create an empty LeRobotDataset with the 'create' classmethod. This can be used for
recording a dataset or port an existing dataset to the LeRobotDataset format.
2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty
LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or porting an
existing dataset to the LeRobotDataset format.
In terms of files, LeRobotDataset encapsulates 3 main things:
@ -192,21 +423,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.image_writer = None
self.episode_buffer = {}
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
self.episode_dicts = load_episode_dicts(self.root)
# Load metadata
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
# Check version
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
# Load actual data
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@ -216,26 +444,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# TODO(aliberts):
# - [X] Move delta_timestamp logic outside __get_item__
# - [X] Update __get_item__
# - [/] Add doc
# - [ ] Add self.add_frame()
# - [ ] Add self.consolidate() for:
# - [X] Check timestamps sync
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
@property
def _version(self) -> str:
"""Codebase version used to create this dataset."""
return self.info["codebase_version"]
def push_to_hub(self, push_videos: bool = True) -> None:
if not self.consolidated:
raise RuntimeError(
@ -262,7 +470,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._hub_version,
revision=self.meta._hub_version,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
@ -280,11 +488,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
if len(self.video_keys) > 0 and download_videos:
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
if len(self.meta.video_keys) > 0 and download_videos:
video_files = [
str(self.get_video_file_path(ep_idx, vid_key))
for vid_key in self.video_keys
str(self.meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self.meta.video_keys
for ep_idx in self.episodes
]
files += video_files
@ -297,108 +505,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
path = str(self.root / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else:
files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
hf_dataset = load_dataset("parquet", data_files=files, split="train")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
return ep_index // self.chunks_size
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def videos_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.)."""
return self.info["keys"]
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return self.info["image_keys"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return self.info["video_keys"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return self.image_keys + self.video_keys
@property
def names(self) -> dict[list[str]]:
"""Names of the various dimensions of vector modalities."""
return self.info["names"]
return self.meta.fps
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self.total_episodes
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
@property
def total_chunks(self) -> int:
"""Total number of chunks (groups of episodes)."""
return self.info["total_chunks"]
@property
def chunks_size(self) -> int:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return self.info["shapes"]
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
@property
def features(self) -> list[str]:
return list(self._features) + self.video_keys
return list(self._features) + self.meta.video_keys
@property
def _features(self) -> datasets.Features:
@ -418,39 +548,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
features[key] = datasets.Value(dtype="bool")
elif key in ["timestamp", "next.reward"]:
features[key] = datasets.Value(dtype="float32")
elif key in self.image_keys:
elif key in self.meta.image_keys:
features[key] = datasets.Image()
elif key in self.keys:
elif key in self.meta.keys:
features[key] = datasets.Sequence(
length=self.shapes[key], feature=datasets.Value(dtype="float32")
length=self.meta.shapes[key], feature=datasets.Value(dtype="float32")
)
return datasets.Features(features)
@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
def get_task_index(self, task: str) -> int:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise creates a new task_index.
"""
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
def current_episode_index(self, idx: int) -> int:
episode_index = self.hf_dataset["episode_index"][idx]
if self.episodes is not None:
# get episode_index from selected episodes
episode_index = self.episodes.index(episode_index)
return episode_index
def episode_length(self, episode_index) -> int:
"""Number of samples/frames for given episode."""
return self.info["episodes"][episode_index]["length"]
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
@ -472,7 +578,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.video_keys:
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
@ -485,7 +591,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.video_keys
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
@ -496,7 +602,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchvision(
video_path, query_ts, self.tolerance_s, self.video_backend
)
@ -525,14 +631,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, val in query_result.items():
item[key] = val
if len(self.video_keys) > 0:
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.camera_keys
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
@ -545,20 +651,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Selected episodes: {self.episodes},\n"
f" Number of selected episodes: {self.num_episodes},\n"
f" Number of selected samples: {self.num_frames},\n"
f"\n{json.dumps(self.info, indent=4)}\n"
f"\n{json.dumps(self.meta.info, indent=4)}\n"
)
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
# TODO(aliberts): Handle resume
return {
"size": 0,
"episode_index": self.total_episodes if episode_index is None else episode_index,
"episode_index": self.meta.total_episodes if episode_index is None else episode_index,
"task_index": None,
"frame_index": [],
"timestamp": [],
"next.done": [],
**{key: [] for key in self.keys},
**{key: [] for key in self.image_keys},
**{key: [] for key in self.meta.keys},
**{key: [] for key in self.meta.image_keys},
}
def add_frame(self, frame: dict) -> None:
@ -573,7 +679,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["next.done"].append(False)
# Save all observed modalities except images
for key in self.keys:
for key in self.meta.keys:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
@ -582,7 +688,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return
# Save images
for cam_key in self.camera_keys:
for cam_key in self.meta.camera_keys:
img_path = self.image_writer.get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
)
@ -594,7 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
fpath=img_path,
)
if cam_key in self.image_keys:
if cam_key in self.meta.image_keys:
self.episode_buffer[cam_key].append(str(img_path))
def add_episode(self, task: str, encode_videos: bool = False) -> None:
@ -609,17 +715,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
episode_length = self.episode_buffer.pop("size")
episode_index = self.episode_buffer["episode_index"]
if episode_index != self.total_episodes:
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError()
task_index = self.get_task_index(task)
task_index = self.meta.get_task_index(task)
self.episode_buffer["next.done"][-1] = True
for key in self.episode_buffer:
if key in self.image_keys:
if key in self.meta.image_keys:
continue
elif key in self.keys:
elif key in self.meta.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index":
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
@ -628,13 +734,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
self.episode_buffer["index"] = torch.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
self.meta.add_episode(episode_index, episode_length, task, task_index)
self._wait_image_writer()
self._save_episode_table(episode_index)
if encode_videos and len(self.video_keys) > 0:
if encode_videos and len(self.meta.video_keys) > 0:
self.encode_videos()
# Reset the buffer
@ -643,45 +751,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_index: int) -> None:
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
write_parquet(ep_dataset, ep_data_path)
def _save_episode_to_metadata(
self, episode_index: int, episode_length: int, task: str, task_index: int
) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / INFO_PATH)
episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"length": episode_length,
}
self.episode_dicts.append(episode_dict)
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
def clear_episode_buffer(self) -> None:
episode_index = self.episode_buffer["episode_index"]
if self.image_writer is not None:
for cam_key in self.camera_keys:
for cam_key in self.meta.camera_keys:
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
if img_dir.is_dir():
shutil.rmtree(img_dir)
@ -717,12 +794,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in range(self.total_episodes):
for key in self.video_keys:
for episode_index in range(self.meta.total_episodes):
for key in self.meta.video_keys:
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
video_path = self.root / self.get_video_file_path(episode_index, key)
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
@ -730,40 +807,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
def _write_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
for key in self.video_keys:
if key not in self.info["videos"]:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["videos"][key] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.video_keys) > 0:
if len(self.meta.video_keys) > 0:
self.encode_videos()
self._write_video_info()
self.meta.write_video_info()
if not keep_image_files and self.image_writer is not None:
shutil.rmtree(self.image_writer.write_dir)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.video_keys)
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
if run_compute_stats:
self.stop_image_writer()
self.stats = compute_stats(self)
write_stats(self.stats, self.root / STATS_PATH)
self.meta.stats = compute_stats(self)
write_stats(self.meta.stats, self.root / STATS_PATH)
self.consolidated = True
else:
logging.warning(
@ -780,60 +845,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod
def create(
cls,
repo_id: str,
fps: int,
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
keys: list[str] | None = None,
image_keys: list[str] | None = None,
video_keys: list[str] = None,
shapes: dict | None = None,
names: dict | None = None,
metadata: LeRobotDatasetMetadata,
tolerance_s: float = 1e-4,
image_writer_processes: int = 0,
image_writer_threads_per_camera: int = 0,
use_videos: bool = True,
image_writer_threads: int = 0,
video_backend: str | None = None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.meta = metadata
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
obj.tolerance_s = tolerance_s
obj.image_writer = None
if robot is not None:
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writer(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)
elif (
robot_type is None
or keys is None
or image_keys is None
or video_keys is None
or shapes is None
or names is None
):
raise ValueError(
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
)
if len(video_keys) > 0 and not use_videos:
raise ValueError()
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
)
write_json(obj.info, obj.root / INFO_PATH)
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
@ -849,7 +877,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj.local_files_only = True
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj
@ -889,7 +916,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
# Check that some properties are consistent across datasets. Note: We may relax some of these
# consistency requirements in future iterations of this class.
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
if dataset.info != self._datasets[0].info:
if dataset.meta.info != self._datasets[0].meta.info:
raise ValueError(
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
"not yet supported."
@ -938,7 +965,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info["fps"]
return self._datasets[0].meta.info["fps"]
@property
def video(self) -> bool:
@ -948,7 +975,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info.get("video", False)
return self._datasets[0].meta.info.get("video", False)
@property
def features(self) -> datasets.Features:

View File

@ -139,7 +139,7 @@ def load_tasks(local_dir: Path) -> dict:
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
def load_episode_dicts(local_dir: Path) -> dict:
def load_episodes(local_dir: Path) -> dict:
    """Load the episodes metadata file located under ``local_dir``.

    The file at ``local_dir / EPISODES_PATH`` is read with ``load_jsonlines``
    and the result is returned unchanged.
    """
    # NOTE(review): the `dict` return annotation may not match what
    # `load_jsonlines` actually returns -- confirm against that helper.
    episodes_file = local_dir / EPISODES_PATH
    return load_jsonlines(episodes_file)

View File

@ -105,7 +105,7 @@ from pathlib import Path
from typing import List
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.robot_devices.control_utils import (
control_loop,
has_method,
@ -234,15 +234,18 @@ def record(
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = LeRobotDataset.create(
dataset_metadata = LeRobotDatasetMetadata.create(
repo_id,
fps,
root=root,
robot=robot,
image_writer_processes=num_image_writer_processes,
image_writer_threads_per_camera=num_image_writer_threads_per_camera,
use_videos=video,
)
dataset = LeRobotDataset.create(
dataset_metadata,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera,
)
if not robot.is_connected:
robot.connect()
@ -315,7 +318,6 @@ def record(
dataset.consolidate(run_compute_stats)
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
if push_to_hub:
dataset.push_to_hub()

View File

@ -484,7 +484,7 @@ def main(
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
assert isinstance(policy, nn.Module)
policy.eval()

View File

@ -328,7 +328,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_policy")
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats if not cfg.resume else None,
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
assert isinstance(policy, nn.Module)

View File

@ -153,7 +153,7 @@ def visualize_dataset(
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
for key in dataset.camera_keys:
for key in dataset.meta.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))

View File

@ -97,8 +97,8 @@ def run_server(
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
}
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
tasks = dataset.episode_dicts[episode_id]["tasks"]
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
tasks = dataset.meta.episodes[episode_id]["tasks"]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
@ -170,7 +170,8 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.meta.video_keys
]
@ -202,8 +203,8 @@ def visualize_dataset_html(
dataset = LeRobotDataset(repo_id, root=root)
if len(dataset.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
if len(dataset.meta.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}"

View File

@ -157,7 +157,7 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
output_dir.mkdir(parents=True, exist_ok=True)
# Get 1st frame from 1st camera of 1st episode
original_frame = dataset[0][dataset.camera_keys[0]]
original_frame = dataset[0][dataset.meta.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")

View File

@ -8,7 +8,7 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
@ -33,8 +33,8 @@ def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | No
return shapes
def get_task_index(tasks_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
def get_task_index(task_dicts: list[dict], task: str) -> int:
    """Return the ``task_index`` associated with the task description ``task``.

    Args:
        task_dicts: Task entries, each a mapping with a ``"task_index"`` and a
            ``"task"`` key.
        task: The task description to look up.

    Returns:
        The index of the entry whose ``"task"`` value equals ``task``.

    Raises:
        KeyError: If no entry has ``task`` as its ``"task"`` value.
    """
    # Two-step mapping (index -> task, then task -> index) is kept on purpose so
    # that duplicated indices/descriptions resolve exactly as they always have.
    index_to_task = {entry["task_index"]: entry["task"] for entry in task_dicts}
    # Distinct loop names avoid shadowing the `task` parameter.
    task_to_index = {name: index for index, name in index_to_task.items()}
    return task_to_index[task]
@ -313,6 +313,47 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
return _create_hf_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
    info,
    stats,
    tasks,
    episodes,
    mock_snapshot_download_factory,
):
    """Session-scoped fixture returning a builder of ``LeRobotDatasetMetadata`` objects.

    The returned builder constructs the metadata object while
    ``get_hub_safe_version`` and ``snapshot_download`` are patched in
    ``lerobot.common.datasets.lerobot_dataset``, so construction goes through
    the mocked snapshot download created from the supplied dictionaries.
    """

    def _create_lerobot_dataset_metadata(
        root: Path,
        repo_id: str = DUMMY_REPO_ID,
        info_dict: dict = info,
        stats_dict: dict = stats,
        task_dicts: list[dict] = tasks,
        episode_dicts: list[dict] = episodes,
        **kwargs,
    ) -> LeRobotDatasetMetadata:
        # Fake `snapshot_download` built from the metadata passed to this call.
        fake_snapshot_download = mock_snapshot_download_factory(
            info_dict=info_dict,
            stats_dict=stats_dict,
            task_dicts=task_dicts,
            episode_dicts=episode_dicts,
        )

        def _echo_requested_version(repo_id, version, enforce_v2=True):
            # Stand-in for `get_hub_safe_version`: hand back the requested version.
            return version

        # Only `local_files_only` is read from the extra keyword arguments.
        local_files_only = kwargs.get("local_files_only", False)

        with (
            patch(
                "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
            ) as hub_version_patch,
            patch(
                "lerobot.common.datasets.lerobot_dataset.snapshot_download"
            ) as snapshot_download_patch,
        ):
            hub_version_patch.side_effect = _echo_requested_version
            snapshot_download_patch.side_effect = fake_snapshot_download
            return LeRobotDatasetMetadata(
                repo_id=repo_id, root=root, local_files_only=local_files_only
            )

    return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info,
@ -321,6 +362,7 @@ def lerobot_dataset_factory(
episodes,
hf_dataset,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
def _create_lerobot_dataset(
root: Path,
@ -335,19 +377,26 @@ def lerobot_dataset_factory(
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
tasks_dicts=task_dicts,
episodes_dicts=episode_dicts,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
hf_ds=hf_ds,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info_dict=info_dict,
stats_dict=stats_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
**kwargs,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_metadata_patch.return_value = mock_metadata
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

View File

@ -36,11 +36,11 @@ def stats_path(stats):
@pytest.fixture(scope="session")
def tasks_path(tasks):
def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks_dicts)
writer.write_all(task_dicts)
return fpath
return _create_tasks_jsonl_file

View File

@ -26,7 +26,7 @@ def mock_snapshot_download_factory(
"""
def _mock_snapshot_download_func(
info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
):
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
@ -53,7 +53,7 @@ def mock_snapshot_download_factory(
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes_dicts:
for episode_dict in episode_dicts:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info_dict["chunks_size"]
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
@ -75,9 +75,9 @@ def mock_snapshot_download_factory(
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats_dict)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks_dicts)
_ = tasks_path(local_dir, task_dicts)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes_dicts)
_ = episode_path(local_dir, episode_dicts)
else:
pass
return str(local_dir)

View File

@ -76,7 +76,7 @@ def main():
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.camera_keys[0]]
original_frame = dataset[0][dataset.meta.camera_keys[0]]
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)

View File

@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
)
set_global_seed(1337)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)

View File

@ -155,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
)
assert dataset.total_episodes == 2
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
@ -193,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
overrides=overrides,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
out_dir = tmpdir / "logger"
logger = Logger(cfg, out_dir, wandb_job_name="debug")

View File

@ -33,7 +33,11 @@ from lerobot.common.datasets.compute_stats import (
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.common.datasets.utils import (
create_branch,
flatten_dict,
@ -53,14 +57,17 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
metadata_create = LeRobotDatasetMetadata.create(
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
)
dataset_create = LeRobotDataset.create(metadata_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
# Access the '_hub_version' cached_property in both instances to force its creation
_ = dataset_init._hub_version
_ = dataset_create._hub_version
_ = dataset_init.meta._hub_version
_ = dataset_create.meta._hub_version
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
@ -78,8 +85,8 @@ def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path)
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
assert dataset.repo_id == kwargs["repo_id"]
assert dataset.total_episodes == kwargs["total_episodes"]
assert dataset.total_frames == kwargs["total_frames"]
assert dataset.meta.total_episodes == kwargs["total_episodes"]
assert dataset.meta.total_frames == kwargs["total_frames"]
assert dataset.episodes == kwargs["episodes"]
assert dataset.num_episodes == len(kwargs["episodes"])
assert dataset.num_frames == len(dataset)
@ -118,7 +125,7 @@ def test_factory(env_name, repo_id, policy_name):
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.camera_keys
camera_keys = dataset.meta.camera_keys
item = dataset[0]
@ -251,7 +258,7 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats # noqa: F841
loaded_stats = dataset.meta.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import io
import subprocess
import sys
@ -29,6 +29,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
def _run_script(path):
    """Execute the Python script at ``path`` with the current interpreter.

    Because the subprocess is launched with ``check=True``, a non-zero exit
    status raises ``subprocess.CalledProcessError``.
    """
    command = [sys.executable, path]
    subprocess.run(command, check=True)

View File

@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
env.step(action)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
@ -213,7 +214,7 @@ def test_act_backbone_lr():
assert cfg.training.lr_backbone == 0.001
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr

View File

@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id, make_test_data",
[