Add LeRobotDatasetMetadata
parent ac79e8cb36
commit e4ba084e25
@@ -266,7 +266,7 @@ def benchmark_encoding_decoding(
    )

    ep_num_images = dataset.episode_data_index["to"][0].item()
-    width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
+    width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
    num_pixels = width * height
    video_size_bytes = video_path.stat().st_size
    images_size_bytes = get_directory_size(imgs_dir)

@@ -13,6 +13,7 @@ Features included in this script:
The script ends with examples of how to batch process data using PyTorch's DataLoader.
"""

+# TODO(aliberts, rcadene): Update this script with the new v2 api
from pathlib import Path
from pprint import pprint

@@ -31,7 +32,7 @@ repo_id = "lerobot/pusht"
# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)

-# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
+# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets/index for more information).
print(dataset)
print(dataset.hf_dataset)

@@ -39,7 +40,7 @@ print(dataset.hf_dataset)
# And provides additional utilities for robotics and compatibility with Pytorch
print(f"\naverage number of frames per episode: {dataset.num_frames / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
-print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
+print(f"keys to access images from cameras: {dataset.meta.camera_keys=}\n")

# Access frame indexes associated to first episode
episode_index = 0
@@ -60,14 +61,15 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)


# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
# differences with the currently loaded frame. For instance:
delta_timestamps = {
    # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
    "observation.image": [-1, -0.5, -0.20, 0],
-    # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
-    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
+    # loads 6 state vectors: 1.5 seconds before, 1 second before, 500 ms, 200 ms, 100 ms before, and current frame
+    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
    # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
    "action": [t / dataset.fps for t in range(64)],
}
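The change running through these example scripts is that dataset-level metadata now lives on a 'meta' attribute (a LeRobotDatasetMetadata instance) instead of directly on the dataset. A minimal sketch of the resulting access pattern, assuming the same "lerobot/pusht" repository used above; the printed values are illustrative:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset("lerobot/pusht")
    # Camera keys, stats, and episode/task counts are now grouped under dataset.meta
    print(dataset.meta.camera_keys)   # e.g. ["observation.image"] for PushT
    print(dataset.meta.stats)         # per-feature min/max/mean/std, used for policy normalization
    # A few convenience properties remain on the dataset itself and forward to the metadata
    print(dataset.fps)                # same value as dataset.meta.fps
    print(dataset.num_frames, dataset.num_episodes)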
@ -40,7 +40,7 @@ dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
|||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ dataset = LeRobotDataset(dataset_repo_id)
|
|||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
|
||||
# Get the frame corresponding to the first camera
|
||||
frame = dataset[first_idx][dataset.camera_keys[0]]
|
||||
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
|
||||
|
||||
|
||||
# Define the transformations
|
||||
|
@ -36,7 +36,7 @@ transforms = v2.Compose(
|
|||
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
|
|
|
@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
|
|||
on the target environment, whether that be in simulation or the real world.
|
||||
"""
|
||||
|
||||
# TODO(aliberts, rcadene): Update this script with the new v2 api
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
|||
assert batch[key].dtype != torch.float64
|
||||
|
||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
||||
if key in dataset.camera_keys:
|
||||
if key in dataset.meta.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
|
|
@ -111,6 +111,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
|||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
|
|
@ -45,7 +45,7 @@ from lerobot.common.datasets.utils import (
|
|||
get_episode_data_index,
|
||||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
load_episode_dicts,
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
|
@ -66,6 +66,237 @@ CODEBASE_VERSION = "v2.0"
|
|||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
self.local_files_only = local_files_only
|
||||
|
||||
# Load metadata
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._hub_version,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _hub_version(self) -> str | None:
|
||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
|
||||
@property
|
||||
def _version(self) -> str:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return self.info["codebase_version"]
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
return ep_index // self.chunks_size
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def videos_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def keys(self) -> list[str]:
|
||||
"""Keys to access non-image data (state, actions etc.)."""
|
||||
return self.info["keys"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
return self.info["image_keys"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return self.info["video_keys"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return self.image_keys + self.video_keys
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list[str]]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return self.info["names"]
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
return self.info["total_chunks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
return self.info["shapes"]
|
||||
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise creates a new task_index.
|
||||
"""
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def add_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": [task],
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episodes.append(episode_dict)
|
||||
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
||||
|
||||
def write_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
for key in self.video_keys:
|
||||
if key not in self.info["videos"]:
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["videos"][key] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
root: Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
robot_type: str | None = None,
|
||||
keys: list[str] | None = None,
|
||||
image_keys: list[str] | None = None,
|
||||
video_keys: list[str] | None = None,
|
||||
shapes: dict | None = None,
|
||||
names: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
obj.image_writer = None
|
||||
|
||||
if robot is not None:
|
||||
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
elif (
|
||||
robot_type is None
|
||||
or keys is None
|
||||
or image_keys is None
|
||||
or video_keys is None
|
||||
or shapes is None
|
||||
or names is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
|
||||
)
|
||||
|
||||
if len(video_keys) > 0 and not use_videos:
|
||||
raise ValueError("Video keys were provided, but 'use_videos' is set to False.")
|
||||
|
||||
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.local_files_only = True
|
||||
return obj
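
Because __init__ above only pulls the "meta/" folder, the new class can be used on its own to inspect a dataset without downloading parquet data or videos. A rough usage sketch; the printed values and the parquet path template are illustrative (the actual template comes from info.json):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata

    meta = LeRobotDatasetMetadata("lerobot/pusht")
    print(meta.fps, meta.total_episodes, meta.total_frames, meta.camera_keys)

    # Episodes are grouped into chunks of at most meta.chunks_size episodes.
    # get_data_file_path() formats meta.data_path with the chunk and episode index,
    # yielding something like "data/chunk-000/episode_000042.parquet".
    print(meta.get_episode_chunk(42), meta.get_data_file_path(42))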
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -86,9 +317,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
- On your local disk in the 'root' folder. This is typically the case when you recorded your
|
||||
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
|
||||
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
|
||||
internet connection).
|
||||
internet connection); in that case, use local_files_only=True.
|
||||
|
||||
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and is not on
|
||||
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
|
||||
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
|
||||
the dataset from that address and load it, provided your dataset is compliant with
|
||||
codebase_version v2.0. If your dataset has been created before this new format, you will be
|
||||
|
@ -96,9 +327,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
|
||||
|
||||
|
||||
2. Your dataset doesn't already exists (either on local disk or on the Hub):
|
||||
You can create an empty LeRobotDataset with the 'create' classmethod. This can be used for
|
||||
recording a dataset or port an existing dataset to the LeRobotDataset format.
|
||||
2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty
|
||||
LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or porting an
|
||||
existing dataset to the LeRobotDataset format.
|
||||
|
||||
|
||||
In terms of files, LeRobotDataset encapsulates 3 main things:
|
||||
|
@ -192,21 +423,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.image_writer = None
|
||||
self.episode_buffer = {}
|
||||
|
||||
# Load metadata
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.episode_dicts = load_episode_dicts(self.root)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
|
||||
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
@ -216,26 +444,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [X] Move delta_timestamp logic outside __get_item__
|
||||
# - [X] Update __get_item__
|
||||
# - [/] Add doc
|
||||
# - [ ] Add self.add_frame()
|
||||
# - [ ] Add self.consolidate() for:
|
||||
# - [X] Check timestamps sync
|
||||
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
|
||||
# - [ ] Update episode_index (arg update=True)
|
||||
# - [ ] Update info.json (arg update=True)
|
||||
|
||||
@cached_property
|
||||
def _hub_version(self) -> str | None:
|
||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
|
||||
@property
|
||||
def _version(self) -> str:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return self.info["codebase_version"]
|
||||
|
||||
def push_to_hub(self, push_videos: bool = True) -> None:
|
||||
if not self.consolidated:
|
||||
raise RuntimeError(
|
||||
|
@ -262,7 +470,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._hub_version,
|
||||
revision=self.meta._hub_version,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
|
@ -280,11 +488,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
if len(self.video_keys) > 0 and download_videos:
|
||||
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
if len(self.meta.video_keys) > 0 and download_videos:
|
||||
video_files = [
|
||||
str(self.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.video_keys
|
||||
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.meta.video_keys
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files += video_files
|
||||
|
@ -297,108 +505,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
return ep_index // self.chunks_size
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def videos_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def keys(self) -> list[str]:
|
||||
"""Keys to access non-image data (state, actions etc.)."""
|
||||
return self.info["keys"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
return self.info["image_keys"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return self.info["video_keys"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return self.image_keys + self.video_keys
|
||||
|
||||
@property
|
||||
def names(self) -> dict[list[str]]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return self.info["names"]
|
||||
return self.meta.fps
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
|
||||
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
return len(self.episodes) if self.episodes is not None else self.total_episodes
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
return self.info["total_chunks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
return self.info["shapes"]
|
||||
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
||||
|
||||
@property
|
||||
def features(self) -> list[str]:
|
||||
return list(self._features) + self.video_keys
|
||||
return list(self._features) + self.meta.video_keys
|
||||
|
||||
@property
|
||||
def _features(self) -> datasets.Features:
|
||||
|
@ -418,39 +548,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
features[key] = datasets.Value(dtype="bool")
|
||||
elif key in ["timestamp", "next.reward"]:
|
||||
features[key] = datasets.Value(dtype="float32")
|
||||
elif key in self.image_keys:
|
||||
elif key in self.meta.image_keys:
|
||||
features[key] = datasets.Image()
|
||||
elif key in self.keys:
|
||||
elif key in self.meta.keys:
|
||||
features[key] = datasets.Sequence(
|
||||
length=self.shapes[key], feature=datasets.Value(dtype="float32")
|
||||
length=self.meta.shapes[key], feature=datasets.Value(dtype="float32")
|
||||
)
|
||||
|
||||
return datasets.Features(features)
|
||||
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise creates a new task_index.
|
||||
"""
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def current_episode_index(self, idx: int) -> int:
|
||||
episode_index = self.hf_dataset["episode_index"][idx]
|
||||
if self.episodes is not None:
|
||||
# get episode_index from selected episodes
|
||||
episode_index = self.episodes.index(episode_index)
|
||||
|
||||
return episode_index
|
||||
|
||||
def episode_length(self, episode_index) -> int:
|
||||
"""Number of samples/frames for given episode."""
|
||||
return self.info["episodes"][episode_index]["length"]
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
|
@ -472,7 +578,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.video_keys:
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
|
@ -485,7 +591,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.video_keys
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
|
@ -496,7 +602,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames_torchvision(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
|
@ -525,14 +631,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.camera_keys
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
|
@ -545,20 +651,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
f" Selected episodes: {self.episodes},\n"
|
||||
f" Number of selected episodes: {self.num_episodes},\n"
|
||||
f" Number of selected samples: {self.num_frames},\n"
|
||||
f"\n{json.dumps(self.info, indent=4)}\n"
|
||||
f"\n{json.dumps(self.meta.info, indent=4)}\n"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
# TODO(aliberts): Handle resume
|
||||
return {
|
||||
"size": 0,
|
||||
"episode_index": self.total_episodes if episode_index is None else episode_index,
|
||||
"episode_index": self.meta.total_episodes if episode_index is None else episode_index,
|
||||
"task_index": None,
|
||||
"frame_index": [],
|
||||
"timestamp": [],
|
||||
"next.done": [],
|
||||
**{key: [] for key in self.keys},
|
||||
**{key: [] for key in self.image_keys},
|
||||
**{key: [] for key in self.meta.keys},
|
||||
**{key: [] for key in self.meta.image_keys},
|
||||
}
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
|
@ -573,7 +679,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.episode_buffer["next.done"].append(False)
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in self.keys:
|
||||
for key in self.meta.keys:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
@ -582,7 +688,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return
|
||||
|
||||
# Save images
|
||||
for cam_key in self.camera_keys:
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_path = self.image_writer.get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
|
||||
)
|
||||
|
@ -594,7 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
fpath=img_path,
|
||||
)
|
||||
|
||||
if cam_key in self.image_keys:
|
||||
if cam_key in self.meta.image_keys:
|
||||
self.episode_buffer[cam_key].append(str(img_path))
|
||||
|
||||
def add_episode(self, task: str, encode_videos: bool = False) -> None:
|
||||
|
@ -609,17 +715,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""
|
||||
episode_length = self.episode_buffer.pop("size")
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if episode_index != self.total_episodes:
|
||||
if episode_index != self.meta.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError()
|
||||
|
||||
task_index = self.get_task_index(task)
|
||||
task_index = self.meta.get_task_index(task)
|
||||
self.episode_buffer["next.done"][-1] = True
|
||||
|
||||
for key in self.episode_buffer:
|
||||
if key in self.image_keys:
|
||||
if key in self.meta.image_keys:
|
||||
continue
|
||||
elif key in self.keys:
|
||||
elif key in self.meta.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
|
@ -628,13 +734,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
else:
|
||||
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
|
||||
|
||||
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
|
||||
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
|
||||
self.episode_buffer["index"] = torch.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
self.meta.add_episode(episode_index, episode_length, task, task_index)
|
||||
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_index)
|
||||
|
||||
if encode_videos and len(self.video_keys) > 0:
|
||||
if encode_videos and len(self.meta.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
|
||||
# Reset the buffer
|
||||
|
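For context, add_frame() and add_episode() are the per-frame and per-episode halves of the recording loop; the episode counters are now updated through self.meta.add_episode() instead of on the dataset itself. A sketch of the intended calling pattern, assuming 'dataset' was built with LeRobotDataset.create() and that the key names, shapes, frame count, and task string below are placeholders:

    import torch

    for _ in range(200):  # one episode worth of frames
        frame = {
            "observation.state": torch.zeros(6),  # placeholder shapes; real keys come from the robot
            "action": torch.zeros(6),
        }
        dataset.add_frame(frame)
    dataset.add_episode(task="Push the T to the target", encode_videos=False)

    # After the last episode: encode videos, recompute stats, and run sanity checks.
    dataset.consolidate(run_compute_stats=True)
    dataset.push_to_hub()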
@ -643,45 +751,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
write_parquet(ep_dataset, ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": [task],
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episode_dicts.append(episode_dict)
|
||||
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
||||
|
||||
def clear_episode_buffer(self) -> None:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if self.image_writer is not None:
|
||||
for cam_key in self.camera_keys:
|
||||
for cam_key in self.meta.camera_keys:
|
||||
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
@ -717,12 +794,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in range(self.total_episodes):
|
||||
for key in self.video_keys:
|
||||
for episode_index in range(self.meta.total_episodes):
|
||||
for key in self.meta.video_keys:
|
||||
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
|
||||
# to call self.image_writer here
|
||||
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
|
||||
video_path = self.root / self.get_video_file_path(episode_index, key)
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
|
@ -730,40 +807,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
def _write_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
for key in self.video_keys:
|
||||
if key not in self.info["videos"]:
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["videos"][key] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
self._write_video_info()
|
||||
self.meta.write_video_info()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.write_dir)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.video_keys)
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writer()
|
||||
self.stats = compute_stats(self)
|
||||
write_stats(self.stats, self.root / STATS_PATH)
|
||||
self.meta.stats = compute_stats(self)
|
||||
write_stats(self.meta.stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
else:
|
||||
logging.warning(
|
||||
|
@ -780,60 +845,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
root: Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
robot_type: str | None = None,
|
||||
keys: list[str] | None = None,
|
||||
image_keys: list[str] | None = None,
|
||||
video_keys: list[str] = None,
|
||||
shapes: dict | None = None,
|
||||
names: dict | None = None,
|
||||
metadata: LeRobotDatasetMetadata,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads_per_camera: int = 0,
|
||||
use_videos: bool = True,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
obj.meta = metadata
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.local_files_only = obj.meta.local_files_only
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
|
||||
if robot is not None:
|
||||
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
||||
obj.start_image_writer(
|
||||
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
||||
)
|
||||
elif (
|
||||
robot_type is None
|
||||
or keys is None
|
||||
or image_keys is None
|
||||
or video_keys is None
|
||||
or shapes is None
|
||||
or names is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
|
||||
)
|
||||
|
||||
if len(video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
|
||||
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj._create_episode_buffer()
|
||||
|
@ -849,7 +877,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.local_files_only = True
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
@ -889,7 +916,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
# Check that some properties are consistent across datasets. Note: We may relax some of these
|
||||
# consistency requirements in future iterations of this class.
|
||||
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
|
||||
if dataset.info != self._datasets[0].info:
|
||||
if dataset.meta.info != self._datasets[0].meta.info:
|
||||
raise ValueError(
|
||||
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
|
||||
"not yet supported."
|
||||
|
@ -938,7 +965,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].info["fps"]
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
|
@ -948,7 +975,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].info.get("video", False)
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
|
|
|
@ -139,7 +139,7 @@ def load_tasks(local_dir: Path) -> dict:
|
|||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def load_episode_dicts(local_dir: Path) -> dict:
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
|
|
|
@ -105,7 +105,7 @@ from pathlib import Path
|
|||
from typing import List
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
has_method,
|
||||
|
@ -234,15 +234,18 @@ def record(
|
|||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
dataset_metadata = LeRobotDatasetMetadata.create(
|
||||
repo_id,
|
||||
fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads_per_camera=num_image_writer_threads_per_camera,
|
||||
use_videos=video,
|
||||
)
|
||||
dataset = LeRobotDataset.create(
|
||||
dataset_metadata,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
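Dataset creation in the record script is now a two-step sequence: the metadata is created first (from the robot's cameras and features), and the dataset is then built on top of it. A condensed sketch of that sequence; the repo id, fps, and image-writer settings are illustrative:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

    dataset_metadata = LeRobotDatasetMetadata.create(
        repo_id="lerobot/my_recordings",
        fps=30,
        root=None,        # defaults to LEROBOT_HOME / repo_id
        robot=robot,      # a lerobot Robot instance supplies keys, shapes, and camera info
        use_videos=True,
    )
    dataset = LeRobotDataset.create(
        dataset_metadata,
        image_writer_processes=0,
        image_writer_threads=4,  # number of image writer threads (illustrative)
    )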
@ -315,7 +318,6 @@ def record(
|
|||
|
||||
dataset.consolidate(run_compute_stats)
|
||||
|
||||
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
|
|
@ -484,7 +484,7 @@ def main(
|
|||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy.eval()
|
||||
|
|
|
@ -328,7 +328,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.info("make_policy")
|
||||
policy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
dataset_stats=offline_dataset.stats if not cfg.resume else None,
|
||||
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
|
|
@ -153,7 +153,7 @@ def visualize_dataset(
|
|||
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
|
||||
|
||||
# display each camera image
|
||||
for key in dataset.camera_keys:
|
||||
for key in dataset.meta.camera_keys:
|
||||
# TODO(rcadene): add `.compress()`? is it lossless?
|
||||
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
|
||||
|
||||
|
|
|
@ -97,8 +97,8 @@ def run_server(
|
|||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
|
||||
tasks = dataset.episode_dicts[episode_id]["tasks"]
|
||||
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.name}
|
||||
for video_path in video_paths
|
||||
|
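The 'tasks' lookup works because dataset.meta.episodes is the list of per-episode records loaded from meta/episodes.jsonl, with the same shape that add_episode() writes. An illustrative record (values made up):

    {"episode_index": 3, "tasks": ["Push the T to the target"], "length": 212}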
@ -170,7 +170,8 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
|||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||
for key in dataset.meta.video_keys
|
||||
]
|
||||
|
||||
|
||||
|
@ -202,8 +203,8 @@ def visualize_dataset_html(
|
|||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if len(dataset.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
|
||||
if len(dataset.meta.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
|
|
|
@ -157,7 +157,7 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
|
|||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get 1st frame from 1st camera of 1st episode
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
|
||||
print("\nOriginal frame saved to:")
|
||||
print(f" {output_dir / 'original_frame.png'}.")
|
||||
|
|
|
@ -8,7 +8,7 @@ import PIL.Image
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
|
@ -33,8 +33,8 @@ def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | No
|
|||
return shapes
|
||||
|
||||
|
||||
def get_task_index(tasks_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
|
||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
return task_to_task_index[task]
|
||||
|
||||
|
@ -313,6 +313,47 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
|
|||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info,
|
||||
stats,
|
||||
tasks,
|
||||
episodes,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_metadata(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
**kwargs,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(
|
||||
repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
|
||||
)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info,
|
||||
|
@ -321,6 +362,7 @@ def lerobot_dataset_factory(
|
|||
episodes,
|
||||
hf_dataset,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
|
@ -335,19 +377,26 @@ def lerobot_dataset_factory(
|
|||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
tasks_dicts=task_dicts,
|
||||
episodes_dicts=episode_dicts,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
hf_ds=hf_ds,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
**kwargs,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_metadata_patch.return_value = mock_metadata
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
|
|
@ -36,11 +36,11 @@ def stats_path(stats):
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
|
||||
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(tasks_dicts)
|
||||
writer.write_all(task_dicts)
|
||||
return fpath
|
||||
|
||||
return _create_tasks_jsonl_file
|
||||
|
|
|
@ -26,7 +26,7 @@ def mock_snapshot_download_factory(
|
|||
"""
|
||||
|
||||
def _mock_snapshot_download_func(
|
||||
info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
|
||||
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
|
||||
):
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
|
@ -53,7 +53,7 @@ def mock_snapshot_download_factory(
|
|||
all_files.extend(meta_files)
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes_dicts:
|
||||
for episode_dict in episode_dicts:
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info_dict["chunks_size"]
|
||||
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
|
@ -75,9 +75,9 @@ def mock_snapshot_download_factory(
|
|||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats_dict)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks_dicts)
|
||||
_ = tasks_path(local_dir, task_dicts)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes_dicts)
|
||||
_ = episode_path(local_dir, episode_dicts)
|
||||
else:
|
||||
pass
|
||||
return str(local_dir)
|
||||
|
|
|
@ -76,7 +76,7 @@ def main():
|
|||
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
|
||||
save_single_transforms(original_frame, output_dir)
|
||||
save_default_config_transform(original_frame, output_dir)
|
||||
|
|
|
@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
|||
)
|
||||
set_global_seed(1337)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
|
|
|
@ -155,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
assert dataset.total_episodes == 2
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
|
@ -193,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||
overrides=overrides,
|
||||
)
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
out_dir = tmpdir / "logger"
|
||||
logger = Logger(cfg, out_dir, wandb_job_name="debug")
|
||||
|
|
|
@ -33,7 +33,11 @@ from lerobot.common.datasets.compute_stats import (
|
|||
get_stats_einops_patterns,
|
||||
)
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
|
@ -53,14 +57,17 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
|||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
metadata_create = LeRobotDatasetMetadata.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
|
||||
)
|
||||
dataset_create = LeRobotDataset.create(metadata_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||
_ = dataset_init._hub_version
|
||||
_ = dataset_create._hub_version
|
||||
_ = dataset_init.meta._hub_version
|
||||
_ = dataset_create.meta._hub_version
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
@ -78,8 +85,8 @@ def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path)
|
|||
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
|
||||
|
||||
assert dataset.repo_id == kwargs["repo_id"]
|
||||
assert dataset.total_episodes == kwargs["total_episodes"]
|
||||
assert dataset.total_frames == kwargs["total_frames"]
|
||||
assert dataset.meta.total_episodes == kwargs["total_episodes"]
|
||||
assert dataset.meta.total_frames == kwargs["total_frames"]
|
||||
assert dataset.episodes == kwargs["episodes"]
|
||||
assert dataset.num_episodes == len(kwargs["episodes"])
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
@ -118,7 +125,7 @@ def test_factory(env_name, repo_id, policy_name):
|
|||
)
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.camera_keys
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
|
@ -251,7 +258,7 @@ def test_compute_stats_on_xarm():
|
|||
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
# load stats used during training which are expected to match the ones returned by computed_stats
|
||||
loaded_stats = dataset.stats # noqa: F841
|
||||
loaded_stats = dataset.meta.stats # noqa: F841
|
||||
|
||||
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
|
||||
# # test loaded stats match expected stats
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# TODO(aliberts): Mute logging for these tests
|
||||
|
||||
import io
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -29,6 +29,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
|
|||
return text
|
||||
|
||||
|
||||
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
|
||||
def _run_script(path):
|
||||
subprocess.run([sys.executable, path], check=True)
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
|
|||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
|
@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
|
@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
env.step(action)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
|
@ -213,7 +214,7 @@ def test_act_backbone_lr():
|
|||
assert cfg.training.lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||
|
|
|
@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"required_packages, raw_format, repo_id, make_test_data",
|
||||
[
|
||||
|
|