1144 lines
47 KiB
Python
1144 lines
47 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import os
|
|
import shutil
|
|
from functools import cached_property
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
import datasets
|
|
import numpy as np
|
|
import PIL.Image
|
|
import torch
|
|
import torch.utils
|
|
from datasets import load_dataset
|
|
from huggingface_hub import create_repo, snapshot_download, upload_folder
|
|
|
|
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
|
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
|
from lerobot.common.datasets.utils import (
|
|
DEFAULT_FEATURES,
|
|
DEFAULT_IMAGE_PATH,
|
|
EPISODES_PATH,
|
|
INFO_PATH,
|
|
STATS_PATH,
|
|
TASKS_PATH,
|
|
append_jsonlines,
|
|
check_delta_timestamps,
|
|
check_timestamps_sync,
|
|
check_version_compatibility,
|
|
create_branch,
|
|
create_empty_dataset_info,
|
|
create_lerobot_dataset_card,
|
|
get_delta_indices,
|
|
get_episode_data_index,
|
|
get_features_from_robot,
|
|
get_hf_features_from_features,
|
|
get_hub_safe_version,
|
|
hf_transform_to_torch,
|
|
load_episodes,
|
|
load_info,
|
|
load_stats,
|
|
load_tasks,
|
|
serialize_dict,
|
|
write_json,
|
|
write_parquet,
|
|
)
|
|
from lerobot.common.datasets.video_utils import (
|
|
VideoFrame,
|
|
decode_video_frames_torchvision,
|
|
encode_video_frames,
|
|
get_video_info,
|
|
)
|
|
from lerobot.common.robot_devices.robots.utils import Robot
|
|
|
|
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.0"

# Local root under which datasets are stored by default. The LEROBOT_HOME environment
# variable takes precedence over the built-in location; a leading "~" is expanded.
LEROBOT_HOME = Path(
    os.environ.get("LEROBOT_HOME", "~/.cache/huggingface/lerobot")
).expanduser()
|
|
|
|
|
|
class LeRobotDatasetMetadata:
|
|
def __init__(
|
|
self,
|
|
repo_id: str,
|
|
root: str | Path | None = None,
|
|
local_files_only: bool = False,
|
|
):
|
|
self.repo_id = repo_id
|
|
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
|
self.local_files_only = local_files_only
|
|
|
|
# Load metadata
|
|
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
|
self.pull_from_repo(allow_patterns="meta/")
|
|
self.info = load_info(self.root)
|
|
self.stats = load_stats(self.root)
|
|
self.tasks = load_tasks(self.root)
|
|
self.episodes = load_episodes(self.root)
|
|
|
|
def pull_from_repo(
|
|
self,
|
|
allow_patterns: list[str] | str | None = None,
|
|
ignore_patterns: list[str] | str | None = None,
|
|
) -> None:
|
|
snapshot_download(
|
|
self.repo_id,
|
|
repo_type="dataset",
|
|
revision=self._hub_version,
|
|
local_dir=self.root,
|
|
allow_patterns=allow_patterns,
|
|
ignore_patterns=ignore_patterns,
|
|
local_files_only=self.local_files_only,
|
|
)
|
|
|
|
@cached_property
|
|
def _hub_version(self) -> str | None:
|
|
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
|
|
|
@property
|
|
def _version(self) -> str:
|
|
"""Codebase version used to create this dataset."""
|
|
return self.info["codebase_version"]
|
|
|
|
def get_data_file_path(self, ep_index: int) -> Path:
|
|
ep_chunk = self.get_episode_chunk(ep_index)
|
|
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
|
return Path(fpath)
|
|
|
|
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
|
ep_chunk = self.get_episode_chunk(ep_index)
|
|
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
|
return Path(fpath)
|
|
|
|
def get_episode_chunk(self, ep_index: int) -> int:
|
|
return ep_index // self.chunks_size
|
|
|
|
@property
|
|
def data_path(self) -> str:
|
|
"""Formattable string for the parquet files."""
|
|
return self.info["data_path"]
|
|
|
|
@property
|
|
def video_path(self) -> str | None:
|
|
"""Formattable string for the video files."""
|
|
return self.info["video_path"]
|
|
|
|
@property
|
|
def robot_type(self) -> str | None:
|
|
"""Robot type used in recording this dataset."""
|
|
return self.info["robot_type"]
|
|
|
|
@property
|
|
def fps(self) -> int:
|
|
"""Frames per second used during data collection."""
|
|
return self.info["fps"]
|
|
|
|
@property
|
|
def features(self) -> dict[str, dict]:
|
|
"""All features contained in the dataset."""
|
|
return self.info["features"]
|
|
|
|
@property
|
|
def image_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities stored as images."""
|
|
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
|
|
|
|
@property
|
|
def video_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities stored as videos."""
|
|
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
|
|
|
@property
|
|
def camera_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
|
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
|
|
|
@property
|
|
def names(self) -> dict[str, list | dict]:
|
|
"""Names of the various dimensions of vector modalities."""
|
|
return {key: ft["names"] for key, ft in self.features.items()}
|
|
|
|
@property
|
|
def shapes(self) -> dict:
|
|
"""Shapes for the different features."""
|
|
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
|
|
|
|
@property
|
|
def total_episodes(self) -> int:
|
|
"""Total number of episodes available."""
|
|
return self.info["total_episodes"]
|
|
|
|
@property
|
|
def total_frames(self) -> int:
|
|
"""Total number of frames saved in this dataset."""
|
|
return self.info["total_frames"]
|
|
|
|
@property
|
|
def total_tasks(self) -> int:
|
|
"""Total number of different tasks performed in this dataset."""
|
|
return self.info["total_tasks"]
|
|
|
|
@property
|
|
def total_chunks(self) -> int:
|
|
"""Total number of chunks (groups of episodes)."""
|
|
return self.info["total_chunks"]
|
|
|
|
@property
|
|
def chunks_size(self) -> int:
|
|
"""Max number of episodes per chunk."""
|
|
return self.info["chunks_size"]
|
|
|
|
@property
|
|
def task_to_task_index(self) -> dict:
|
|
return {task: task_idx for task_idx, task in self.tasks.items()}
|
|
|
|
def get_task_index(self, task: str) -> int:
|
|
"""
|
|
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
|
otherwise creates a new task_index.
|
|
"""
|
|
task_index = self.task_to_task_index.get(task, None)
|
|
return task_index if task_index is not None else self.total_tasks
|
|
|
|
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
|
self.info["total_episodes"] += 1
|
|
self.info["total_frames"] += episode_length
|
|
|
|
if task_index not in self.tasks:
|
|
self.info["total_tasks"] += 1
|
|
self.tasks[task_index] = task
|
|
task_dict = {
|
|
"task_index": task_index,
|
|
"task": task,
|
|
}
|
|
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
|
|
|
chunk = self.get_episode_chunk(episode_index)
|
|
if chunk >= self.total_chunks:
|
|
self.info["total_chunks"] += 1
|
|
|
|
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
|
self.info["total_videos"] += len(self.video_keys)
|
|
write_json(self.info, self.root / INFO_PATH)
|
|
|
|
episode_dict = {
|
|
"episode_index": episode_index,
|
|
"tasks": [task],
|
|
"length": episode_length,
|
|
}
|
|
self.episodes.append(episode_dict)
|
|
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
|
|
|
# TODO(aliberts): refactor stats in save_episodes
|
|
# image_sampling = int(self.fps / 2) # sample 2 img/s for the stats
|
|
# ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
|
|
# ep_stats = serialize_dict(ep_stats)
|
|
# append_jsonlines(ep_stats, self.root / STATS_PATH)
|
|
|
|
def write_video_info(self) -> None:
|
|
"""
|
|
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
|
"""
|
|
for key in self.video_keys:
|
|
if not self.features[key].get("info", None):
|
|
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
|
self.info["features"][key]["info"] = get_video_info(video_path)
|
|
|
|
write_json(self.info, self.root / INFO_PATH)
|
|
|
|
def __repr__(self):
|
|
feature_keys = list(self.features)
|
|
return (
|
|
f"{self.__class__.__name__}({{\n"
|
|
f" Repository ID: '{self.repo_id}',\n"
|
|
f" Total episodes: '{self.total_episodes}',\n"
|
|
f" Total frames: '{self.total_frames}',\n"
|
|
f" Features: '{feature_keys}',\n"
|
|
"})',\n"
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
repo_id: str,
|
|
fps: int,
|
|
root: str | Path | None = None,
|
|
robot: Robot | None = None,
|
|
robot_type: str | None = None,
|
|
features: dict | None = None,
|
|
use_videos: bool = True,
|
|
) -> "LeRobotDatasetMetadata":
|
|
"""Creates metadata for a LeRobotDataset."""
|
|
obj = cls.__new__(cls)
|
|
obj.repo_id = repo_id
|
|
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
|
|
|
obj.root.mkdir(parents=True, exist_ok=False)
|
|
|
|
if robot is not None:
|
|
features = get_features_from_robot(robot, use_videos)
|
|
robot_type = robot.robot_type
|
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
|
logging.warning(
|
|
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
|
)
|
|
elif features is None:
|
|
raise ValueError(
|
|
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
|
)
|
|
else:
|
|
# TODO(aliberts, rcadene): implement sanity check for features
|
|
features = {**features, **DEFAULT_FEATURES}
|
|
|
|
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
|
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
|
if len(obj.video_keys) > 0 and not use_videos:
|
|
raise ValueError()
|
|
write_json(obj.info, obj.root / INFO_PATH)
|
|
obj.local_files_only = True
|
|
return obj
|
|
|
|
|
|
class LeRobotDataset(torch.utils.data.Dataset):
|
|
def __init__(
    self,
    repo_id: str,
    root: str | Path | None = None,
    episodes: list[int] | None = None,
    image_transforms: Callable | None = None,
    delta_timestamps: dict[str, list[float]] | None = None,
    tolerance_s: float = 1e-4,
    download_videos: bool = True,
    local_files_only: bool = False,
    video_backend: str | None = None,
):
    """
    2 modes are available for instantiating this class, depending on 2 different use cases:

    1. Your dataset already exists:
        - On your local disk in the 'root' folder. This is typically the case when you recorded your
          dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
          with 'root' will load your dataset directly from disk. This can happen while you're offline (no
          internet connection), in that case, use local_files_only=True.

        - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
          your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
          the dataset from that address and load it, provided your dataset is compliant with
          codebase_version v2.0. If your dataset has been created before this new format, you will be
          prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at
          lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.

    2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty
       LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or porting an
       existing dataset to the LeRobotDataset format.

    In terms of files, LeRobotDataset encapsulates 3 main things:
        - metadata:
            - info contains various information about the dataset like shapes, keys, fps etc.
            - stats stores the dataset statistics of the different modalities for normalization
            - tasks contains the prompts for each task of the dataset, which can be used for
              task-conditioned training.
        - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
        - videos (optional) from which frames are loaded to be synchronous with data from parquet files.

    A typical LeRobotDataset looks like this from its root path:
    .
    ├── data
    │   ├── chunk-000
    │   │   ├── episode_000000.parquet
    │   │   ├── episode_000001.parquet
    │   │   ├── episode_000002.parquet
    │   │   └── ...
    │   ├── chunk-001
    │   │   ├── episode_001000.parquet
    │   │   ├── episode_001001.parquet
    │   │   ├── episode_001002.parquet
    │   │   └── ...
    │   └── ...
    ├── meta
    │   ├── episodes.jsonl
    │   ├── info.json
    │   ├── stats.json
    │   └── tasks.jsonl
    └── videos
        ├── chunk-000
        │   ├── observation.images.laptop
        │   │   ├── episode_000000.mp4
        │   │   ├── episode_000001.mp4
        │   │   ├── episode_000002.mp4
        │   │   └── ...
        │   ├── observation.images.phone
        │   │   ├── episode_000000.mp4
        │   │   ├── episode_000001.mp4
        │   │   ├── episode_000002.mp4
        │   │   └── ...
        ├── chunk-001
        └── ...

    Note that this file-based structure is designed to be as versatile as possible. The files are split by
    episodes which allows a more granular control over which episodes one wants to use and download. The
    structure of the dataset is entirely described in the info.json file, which can be easily downloaded
    or viewed directly on the hub before downloading any actual data. The type of files used are very
    simple and do not need complex tools to be read, it only uses .parquet, .json and .mp4 files (and .md
    for the README).

    Args:
        repo_id (str): This is the repo id that will be used to fetch the dataset.
        root (Path | None, optional): Local directory to use for downloading/writing files. When not
            provided, the dataset is stored under LEROBOT_HOME / repo_id, where LEROBOT_HOME is
            '~/.cache/huggingface/lerobot' unless the LEROBOT_HOME environment variable points to a
            different location. Defaults to None.
        episodes (list[int] | None, optional): If specified, this will only load episodes specified by
            their episode_index in this list. Defaults to None.
        image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
            torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
            from videos or images). Defaults to None.
        delta_timestamps (dict[str, list[float]] | None, optional): For each listed feature key, the
            time offsets (in seconds, relative to the requested frame) of the frames to return for that
            key instead of the single requested frame. Offsets must be multiples of 1/fps (see
            `tolerance_s`). Offsets pointing outside of the episode are clamped to its first/last frame
            and flagged in the extra '{key}_is_pad' entry of the returned item. Defaults to None.
        tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
            sync with the fps value. It is used at the init of the dataset to make sure that each
            timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
            decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
            multiples of 1/fps. Defaults to 1e-4.
        download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
            video files are already present on local disk, they won't be downloaded again. Defaults to
            True.
        local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
            will be made. Defaults to False.
        video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
            a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
    """
    super().__init__()
    self.repo_id = repo_id
    self.root = Path(root) if root else LEROBOT_HOME / repo_id
    self.image_transforms = image_transforms
    self.delta_timestamps = delta_timestamps
    self.episodes = episodes
    self.tolerance_s = tolerance_s
    self.video_backend = video_backend if video_backend else "pyav"
    self.delta_indices = None
    self.local_files_only = local_files_only

    # Not used when loading an existing dataset; they are only filled in while
    # recording (see start_image_writer / add_frame).
    self.image_writer = None
    self.episode_buffer = None

    self.root.mkdir(exist_ok=True, parents=True)

    # Load metadata
    self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)

    # Check version
    check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)

    # Load actual data
    self.download_episodes(download_videos)
    self.hf_dataset = self.load_hf_dataset()
    self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

    # Check timestamps
    check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)

    # Setup delta_indices
    if self.delta_timestamps is not None:
        check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
        self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

    # Available stats implies all videos have been encoded and dataset is iterable
    self.consolidated = self.meta.stats is not None
|
|
|
|
def push_to_hub(
    self,
    tags: list | None = None,
    license: str | None = "apache-2.0",
    push_videos: bool = True,
    private: bool = False,
    **card_kwargs,
) -> None:
    """Upload the dataset folder and its dataset card to the hub.

    A dataset that is not consolidated yet is consolidated first. The "images/"
    folder is never uploaded; "videos/" is skipped as well when ``push_videos`` is
    False. Once uploaded, a branch named after ``CODEBASE_VERSION`` is created.

    Args:
        tags: Forwarded to ``create_lerobot_dataset_card``.
        license: Forwarded to ``create_lerobot_dataset_card``.
        push_videos: Whether the "videos/" folder is uploaded.
        private: Forwarded to ``create_repo``.
        **card_kwargs: Extra keyword arguments for ``create_lerobot_dataset_card``.
    """
    if not self.consolidated:
        logging.warning(
            "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
            "Consolidating first."
        )
        self.consolidate()

    excluded = ["images/"] if push_videos else ["images/", "videos/"]

    repo_kwargs = {"repo_id": self.repo_id, "repo_type": "dataset"}
    create_repo(private=private, exist_ok=True, **repo_kwargs)
    upload_folder(folder_path=self.root, ignore_patterns=excluded, **repo_kwargs)

    dataset_card = create_lerobot_dataset_card(
        tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
    )
    dataset_card.push_to_hub(**repo_kwargs)
    create_branch(branch=CODEBASE_VERSION, **repo_kwargs)
|
|
|
|
def pull_from_repo(
    self,
    allow_patterns: list[str] | str | None = None,
    ignore_patterns: list[str] | str | None = None,
) -> None:
    """Download files of the dataset repository into ``self.root``.

    The revision is the one resolved by the metadata (``self.meta._hub_version``).

    Args:
        allow_patterns: Forwarded to ``snapshot_download``.
        ignore_patterns: Forwarded to ``snapshot_download``.
    """
    download_options = {
        "repo_type": "dataset",
        "revision": self.meta._hub_version,
        "local_dir": self.root,
        "allow_patterns": allow_patterns,
        "ignore_patterns": ignore_patterns,
        "local_files_only": self.local_files_only,
    }
    snapshot_download(self.repo_id, **download_options)
|
|
|
|
def download_episodes(self, download_videos: bool = True) -> None:
    """Pull the data (and optionally video) files of the selected episodes.

    With ``self.episodes`` set, only the files of those episodes are requested;
    otherwise no allow-list is passed and the whole repository is pulled. Video
    files are excluded when ``download_videos`` is False. Files already present in
    the local directory are not downloaded again by ``snapshot_download``.
    """
    # TODO(rcadene, aliberts): implement faster transfer
    # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
    ignore_patterns = None if download_videos else "videos/"

    wanted = None
    if self.episodes is not None:
        wanted = [str(self.meta.get_data_file_path(episode)) for episode in self.episodes]
        if len(self.meta.video_keys) > 0 and download_videos:
            # Video files are grouped per camera key, then per episode.
            for vid_key in self.meta.video_keys:
                for episode in self.episodes:
                    wanted.append(str(self.meta.get_video_file_path(episode, vid_key)))

    self.pull_from_repo(allow_patterns=wanted, ignore_patterns=ignore_patterns)
|
|
|
|
def load_hf_dataset(self) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc.

    Loads every parquet file found under ``root / "data"``, or only the files of
    the selected episodes when ``self.episodes`` is set.
    """
    if self.episodes is not None:
        episode_files = [
            str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes
        ]
        source = {"data_files": episode_files}
    else:
        source = {"data_dir": str(self.root / "data")}

    hf_dataset = load_dataset("parquet", split="train", **source)

    # TODO(aliberts): hf_dataset.set_format("torch")
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
|
|
|
|
@property
def fps(self) -> int:
    """Frames per second used during data collection (read from the metadata)."""
    metadata = self.meta
    return metadata.fps
|
|
|
|
@property
def num_frames(self) -> int:
    """Number of frames in selected episodes.

    Falls back to the total frame count of the metadata when no hf_dataset is loaded.
    """
    if self.hf_dataset is None:
        return self.meta.total_frames
    return len(self.hf_dataset)
|
|
|
|
@property
def num_episodes(self) -> int:
    """Number of episodes selected.

    Falls back to the total episode count of the metadata when no selection was made.
    """
    if self.episodes is None:
        return self.meta.total_episodes
    return len(self.episodes)
|
|
|
|
@property
def features(self) -> dict[str, dict]:
    """Features of the dataset, as described by the metadata."""
    metadata = self.meta
    return metadata.features
|
|
|
|
@property
def hf_features(self) -> datasets.Features:
    """Features of the hf_dataset.

    When no hf_dataset is loaded, they are derived from the dataset features instead.
    """
    if self.hf_dataset is None:
        return get_hf_features_from_features(self.features)
    return self.hf_dataset.features
|
|
|
|
def _get_query_indices(
    self, idx: int, ep_idx: int
) -> tuple[dict[str, list[int]], dict[str, torch.BoolTensor]]:
    """Compute the frame indices to query around ``idx`` for every delta key.

    Args:
        idx: Index of the requested frame in the hf_dataset.
        ep_idx: Position of the frame's episode in ``self.episode_data_index``.

    Returns:
        A pair ``(query_indices, padding)``:
            - ``query_indices`` maps each key of ``self.delta_indices`` to
              ``idx + delta`` for each of its deltas, clamped to the episode range.
            - ``padding`` maps ``"{key}_is_pad"`` to a BoolTensor that is True where
              ``idx + delta`` fell outside of the episode range (i.e. was clamped).
    """
    # The episode bounds do not depend on the delta: read them from the tensors once.
    ep_start = self.episode_data_index["from"][ep_idx].item()
    ep_end = self.episode_data_index["to"][ep_idx].item()  # exclusive

    query_indices = {}
    padding = {}  # Pad values outside of current episode range
    for key, delta_idx in self.delta_indices.items():
        targets = [idx + delta for delta in delta_idx]
        query_indices[key] = [max(ep_start, min(ep_end - 1, target)) for target in targets]
        padding[f"{key}_is_pad"] = torch.BoolTensor(
            [(target < ep_start) | (target >= ep_end) for target in targets]
        )
    return query_indices, padding
|
|
|
|
def _get_query_timestamps(
|
|
self,
|
|
current_ts: float,
|
|
query_indices: dict[str, list[int]] | None = None,
|
|
) -> dict[str, list[float]]:
|
|
query_timestamps = {}
|
|
for key in self.meta.video_keys:
|
|
if query_indices is not None and key in query_indices:
|
|
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
|
else:
|
|
query_timestamps[key] = [current_ts]
|
|
|
|
return query_timestamps
|
|
|
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
    """Fetch and stack the hf_dataset values at ``query_indices`` for every non-video key."""
    video_keys = self.meta.video_keys
    stacked = {}
    for key, indices in query_indices.items():
        if key in video_keys:
            # Video frames are not stored in the hf_dataset.
            continue
        stacked[key] = torch.stack(self.hf_dataset.select(indices)[key])
    return stacked
|
|
|
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
    """Decode, for every video key, the frames of episode ``ep_idx`` at the given timestamps.

    Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
    in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
    Segmentation Fault. This probably happens because a memory reference to the video loader is created in
    the main process and a subprocess fails to access it.
    """

    def decode(vid_key: str, timestamps: list[float]):
        video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
        frames = decode_video_frames_torchvision(
            video_path, timestamps, self.tolerance_s, self.video_backend
        )
        return frames.squeeze(0)

    return {vid_key: decode(vid_key, query_ts) for vid_key, query_ts in query_timestamps.items()}
|
|
|
|
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
    """Store every padding mask in ``item`` as a BoolTensor and return ``item``."""
    for pad_key in padding:
        item[pad_key] = torch.BoolTensor(padding[pad_key])
    return item
|
|
|
|
def __len__(self):
    """Return the number of frames in the selected episodes."""
    frame_count = self.num_frames
    return frame_count
|
|
|
|
def __getitem__(self, idx) -> dict:
    """Return the frame at ``idx``.

    On top of the hf_dataset row, the item holds the values queried through
    ``delta_indices`` with their padding masks (when set) and the decoded video
    frames (when the dataset has video keys). ``image_transforms`` (when set) is
    applied to every camera key.
    """
    item = self.hf_dataset[idx]
    ep_idx = item["episode_index"].item()

    query_indices = None
    if self.delta_indices is not None:
        # episode_data_index is positional: translate the episode index into its
        # position within the selected episodes.
        position = ep_idx if self.episodes is None else self.episodes.index(ep_idx)
        query_indices, padding = self._get_query_indices(idx, position)
        queried = self._query_hf_dataset(query_indices)
        item = {**item, **padding, **queried}

    if self.meta.video_keys:
        timestamps = self._get_query_timestamps(item["timestamp"].item(), query_indices)
        frames = self._query_videos(timestamps, ep_idx)
        item = {**frames, **item}

    transform = self.image_transforms
    if transform is not None:
        for cam_key in self.meta.camera_keys:
            item[cam_key] = transform(item[cam_key])

    return item
|
|
|
|
def __repr__(self):
    """Return a multi-line summary: repository id, selected episodes/samples and feature keys."""
    feature_keys = list(self.features)
    return (
        f"{self.__class__.__name__}({{\n"
        f" Repository ID: '{self.repo_id}',\n"
        f" Number of selected episodes: '{self.num_episodes}',\n"
        f" Number of selected samples: '{self.num_frames}',\n"
        f" Features: '{feature_keys}',\n"
        # Close what the first line opened, without the stray "'," and newline.
        "})"
    )
|
|
|
|
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
|
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
|
return {
|
|
"size": 0,
|
|
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
|
|
}
|
|
|
|
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
    """Return the path, under ``self.root``, of the image file of one frame of one camera."""
    placeholders = {
        "image_key": image_key,
        "episode_index": episode_index,
        "frame_index": frame_index,
    }
    return self.root / DEFAULT_IMAGE_PATH.format(**placeholders)
|
|
|
|
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
    """Write ``image`` to ``fpath``, through the image writer when one is running."""
    writer = self.image_writer
    if writer is not None:
        writer.save_image(image=image, fpath=fpath)
        return

    # No writer: save right away, converting torch tensors to numpy arrays first.
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    write_image(image, fpath)
|
|
|
|
def add_frame(self, frame: dict) -> None:
    """
    This function only adds the frame to the episode_buffer. Apart from images — which are written in a
    temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
    then needs to be called.

    Args:
        frame: Values of the frame, keyed by feature name. An optional "timestamp" entry overrides the
            default timestamp (frame_index / fps). The dict is left untouched.

    Raises:
        ValueError: If the frame contains keys that are not features of the dataset. The episode buffer
            is not modified in that case.
    """
    # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
    # check the dtype and shape matches, etc.

    # Validate before touching the buffer so that a bad frame cannot leave it half-updated.
    unknown_keys = [key for key in frame if key != "timestamp" and key not in self.features]
    if unknown_keys:
        raise ValueError(f"Frame keys {unknown_keys} are not features of this dataset.")

    if self.episode_buffer is None:
        self.episode_buffer = self.create_episode_buffer()
    buffer = self.episode_buffer

    frame_index = buffer["size"]
    timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
    buffer["frame_index"].append(frame_index)
    buffer["timestamp"].append(timestamp)

    for key, value in frame.items():
        if key == "timestamp":
            # Already stored above.
            continue

        if self.features[key]["dtype"] in ["image", "video"]:
            img_path = self._get_image_file_path(
                episode_index=buffer["episode_index"], image_key=key, frame_index=frame_index
            )
            if frame_index == 0:
                # The first frame of the episode creates the image folder.
                img_path.parent.mkdir(parents=True, exist_ok=True)
            self._save_image(value, img_path)
            buffer[key].append(str(img_path))
        else:
            buffer[key].append(value.numpy() if isinstance(value, torch.Tensor) else value)

    buffer["size"] += 1
|
|
|
|
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
|
"""
|
|
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
|
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
|
the hub.
|
|
|
|
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
|
|
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
|
time for video encoding.
|
|
"""
|
|
if not episode_data:
|
|
episode_buffer = self.episode_buffer
|
|
|
|
episode_length = episode_buffer.pop("size")
|
|
episode_index = episode_buffer["episode_index"]
|
|
if episode_index != self.meta.total_episodes:
|
|
# TODO(aliberts): Add option to use existing episode_index
|
|
raise NotImplementedError(
|
|
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
|
"match the total number of episodes in the dataset. This is not supported for now."
|
|
)
|
|
|
|
if episode_length == 0:
|
|
raise ValueError(
|
|
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
|
)
|
|
|
|
task_index = self.meta.get_task_index(task)
|
|
|
|
if not set(episode_buffer.keys()) == set(self.features):
|
|
raise ValueError()
|
|
|
|
for key, ft in self.features.items():
|
|
if key == "index":
|
|
episode_buffer[key] = np.arange(
|
|
self.meta.total_frames, self.meta.total_frames + episode_length
|
|
)
|
|
elif key == "episode_index":
|
|
episode_buffer[key] = np.full((episode_length,), episode_index)
|
|
elif key == "task_index":
|
|
episode_buffer[key] = np.full((episode_length,), task_index)
|
|
elif ft["dtype"] in ["image", "video"]:
|
|
continue
|
|
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
|
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
|
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
|
|
episode_buffer[key] = np.stack(episode_buffer[key])
|
|
else:
|
|
raise ValueError(key)
|
|
|
|
self._wait_image_writer()
|
|
self._save_episode_table(episode_buffer, episode_index)
|
|
|
|
self.meta.save_episode(episode_index, episode_length, task, task_index)
|
|
|
|
if encode_videos and len(self.meta.video_keys) > 0:
|
|
video_paths = self.encode_episode_videos(episode_index)
|
|
for key in self.meta.video_keys:
|
|
episode_buffer[key] = video_paths[key]
|
|
|
|
if not episode_data: # Reset the buffer
|
|
self.episode_buffer = self.create_episode_buffer()
|
|
|
|
self.consolidated = False
|
|
|
|
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
    """Write the hf_features columns of ``episode_buffer`` to the episode's parquet file."""
    hf_features = self.hf_features
    columns = {name: episode_buffer[name] for name in hf_features}
    episode_table = datasets.Dataset.from_dict(columns, features=hf_features, split="train")

    destination = self.root / self.meta.get_data_file_path(ep_index=episode_index)
    destination.parent.mkdir(parents=True, exist_ok=True)
    write_parquet(episode_table, destination)
|
|
|
|
def clear_episode_buffer(self) -> None:
    """Discard the current episode buffer and start a fresh one.

    When an image writer is set, the image folders of the discarded episode are
    removed from disk for every camera key.
    """
    episode_index = self.episode_buffer["episode_index"]

    camera_keys = self.meta.camera_keys if self.image_writer is not None else []
    for cam_key in camera_keys:
        first_frame = self._get_image_file_path(
            episode_index=episode_index, image_key=cam_key, frame_index=0
        )
        episode_img_dir = first_frame.parent
        if episode_img_dir.is_dir():
            shutil.rmtree(episode_img_dir)

    # Reset the buffer
    self.episode_buffer = self.create_episode_buffer()
|
|
|
|
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
    """Attach a new AsyncImageWriter to the dataset.

    A warning is logged when the dataset already has one; the new writer replaces it.

    Args:
        num_processes: Forwarded to ``AsyncImageWriter``.
        num_threads: Forwarded to ``AsyncImageWriter``.
    """
    already_running = isinstance(self.image_writer, AsyncImageWriter)
    if already_running:
        logging.warning(
            "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
        )

    writer = AsyncImageWriter(num_processes=num_processes, num_threads=num_threads)
    self.image_writer = writer
|
|
|
|
def stop_image_writer(self) -> None:
    """
    Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
    remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
    """
    writer = self.image_writer
    if writer is None:
        return
    writer.stop()
    self.image_writer = None
|
|
|
|
def _wait_image_writer(self) -> None:
|
|
"""Wait for asynchronous image writer to finish."""
|
|
if self.image_writer is not None:
|
|
self.image_writer.wait_until_done()
|
|
|
|
def encode_videos(self) -> None:
    """
    Encode the videos of every episode, from index 0 up to `self.meta.total_episodes - 1`.

    Use ffmpeg to convert frames stored as png into mp4 videos.
    Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
    since video encoding with ffmpeg is already using multithreading.
    """
    total_episodes = self.meta.total_episodes
    for episode_idx in range(total_episodes):
        self.encode_episode_videos(episode_idx)
|
|
|
|
def encode_episode_videos(self, episode_index: int) -> dict:
    """
    Use ffmpeg to convert the frames of one episode, stored as png, into one mp4 video per video key.

    Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
    since video encoding with ffmpeg is already using multithreading.

    Returns:
        A dict mapping each video key to the path (as a string) of its mp4 file. Keys whose file
        already exists are included as well, but are not re-encoded.
    """
    encoded_paths = {}
    for video_key in self.meta.video_keys:
        target = self.root / self.meta.get_video_file_path(episode_index, video_key)
        encoded_paths[video_key] = str(target)
        # An existing file means the video is already encoded, which could be the case when
        # resuming data recording: only encode when the file is missing.
        if not target.is_file():
            first_frame_path = self._get_image_file_path(
                episode_index=episode_index, image_key=video_key, frame_index=0
            )
            encode_video_frames(first_frame_path.parent, target, self.fps, overwrite=True)

    return encoded_paths
|
|
|
|
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
    """Bring the in-memory dataset back in sync with the files written on disk.

    The hf dataset and the episode data index are reloaded and the timestamps are checked. When the
    dataset has video keys, the videos are encoded and the video info is written. The `images`
    directory is then removed unless `keep_image_files` is set, and the number of mp4 and parquet
    files found under `self.root` is checked against the number of episodes.

    Args:
        run_compute_stats: When True, stop the image writer, compute the dataset statistics, write
            them to `STATS_PATH` and mark the dataset as consolidated. When False, only a warning is
            logged and `self.consolidated` is left untouched.
        keep_image_files: When True, the `images` directory under `self.root` is kept.
    """
    self.hf_dataset = self.load_hf_dataset()
    self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
    check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)

    has_videos = len(self.meta.video_keys) > 0
    if has_videos:
        self.encode_videos()
        self.meta.write_video_info()

    if not keep_image_files:
        images_root = self.root / "images"
        if images_root.is_dir():
            shutil.rmtree(images_root)

    # Sanity checks on what is actually present on disk.
    mp4_count = len(list(self.root.rglob("*.mp4")))
    assert mp4_count == self.num_episodes * len(self.meta.video_keys)

    parquet_count = len(list(self.root.rglob("*.parquet")))
    assert parquet_count == self.num_episodes

    if not run_compute_stats:
        logging.warning(
            "Skipping computation of the dataset statistics, dataset is not fully consolidated."
        )
        return

    self.stop_image_writer()
    # TODO(aliberts): refactor stats in save_episodes
    self.meta.stats = compute_stats(self)
    write_json(serialize_dict(self.meta.stats), self.root / STATS_PATH)
    self.consolidated = True
|
|
|
|
@classmethod
def create(
    cls,
    repo_id: str,
    fps: int,
    root: str | Path | None = None,
    robot: Robot | None = None,
    robot_type: str | None = None,
    features: dict | None = None,
    use_videos: bool = True,
    tolerance_s: float = 1e-4,
    image_writer_processes: int = 0,
    image_writer_threads: int = 0,
    video_backend: str | None = None,
) -> "LeRobotDataset":
    """Create a LeRobot Dataset from scratch in order to record data.

    The instance is allocated without running `__init__`. Its metadata comes from
    `LeRobotDatasetMetadata.create`, an image writer is started when either `image_writer_processes`
    or `image_writer_threads` is non-zero, and `video_backend` falls back to "pyav" when it is None.
    """
    dataset = cls.__new__(cls)

    metadata = LeRobotDatasetMetadata.create(
        repo_id=repo_id,
        fps=fps,
        root=root,
        robot=robot,
        robot_type=robot_type,
        features=features,
        use_videos=use_videos,
    )
    dataset.meta = metadata
    dataset.repo_id = metadata.repo_id
    dataset.root = metadata.root
    dataset.local_files_only = metadata.local_files_only
    dataset.tolerance_s = tolerance_s

    # The attribute has to exist before `start_image_writer` inspects it.
    dataset.image_writer = None
    wants_image_writer = image_writer_processes or image_writer_threads
    if wants_image_writer:
        dataset.start_image_writer(image_writer_processes, image_writer_threads)

    # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
    dataset.episode_buffer = dataset.create_episode_buffer()

    # This flag tells whether the current LeRobotDataset instance is in sync with the files on disk.
    # It is used to know when certain operations are needed (for instance, computing dataset
    # statistics). Before the dataset can be pushed to the hub, it has to be consolidated by calling
    # self.consolidate().
    dataset.consolidated = True

    # Nothing has been recorded or configured yet for a freshly created dataset.
    dataset.episodes = None
    dataset.hf_dataset = None
    dataset.image_transforms = None
    dataset.delta_timestamps = None
    dataset.delta_indices = None
    dataset.episode_data_index = None

    if video_backend is None:
        dataset.video_backend = "pyav"
    else:
        dataset.video_backend = video_backend
    return dataset
|
|
|
|
|
|
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
|
|
|
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
|
structure of `LeRobotDataset`.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
repo_ids: list[str],
|
|
root: str | Path | None = None,
|
|
episodes: dict | None = None,
|
|
image_transforms: Callable | None = None,
|
|
delta_timestamps: dict[list[float]] | None = None,
|
|
tolerances_s: dict | None = None,
|
|
download_videos: bool = True,
|
|
local_files_only: bool = False,
|
|
video_backend: str | None = None,
|
|
):
|
|
super().__init__()
|
|
self.repo_ids = repo_ids
|
|
self.root = Path(root) if root else LEROBOT_HOME
|
|
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
|
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
|
# are handled by this class.
|
|
self._datasets = [
|
|
LeRobotDataset(
|
|
repo_id,
|
|
root=self.root / repo_id,
|
|
episodes=episodes[repo_id] if episodes else None,
|
|
image_transforms=image_transforms,
|
|
delta_timestamps=delta_timestamps,
|
|
tolerance_s=self.tolerances_s[repo_id],
|
|
download_videos=download_videos,
|
|
local_files_only=local_files_only,
|
|
video_backend=video_backend,
|
|
)
|
|
for repo_id in repo_ids
|
|
]
|
|
|
|
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
|
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
|
# to use PyTorch's default DataLoader collate function.
|
|
self.disabled_features = set()
|
|
intersection_features = set(self._datasets[0].features)
|
|
for ds in self._datasets:
|
|
intersection_features.intersection_update(ds.features)
|
|
if len(intersection_features) == 0:
|
|
raise RuntimeError(
|
|
"Multiple datasets were provided but they had no keys common to all of them. "
|
|
"The multi-dataset functionality currently only keeps common keys."
|
|
)
|
|
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
|
extra_keys = set(ds.features).difference(intersection_features)
|
|
logging.warning(
|
|
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
|
"other datasets."
|
|
)
|
|
self.disabled_features.update(extra_keys)
|
|
|
|
self.image_transforms = image_transforms
|
|
self.delta_timestamps = delta_timestamps
|
|
self.stats = aggregate_stats(self._datasets)
|
|
|
|
@property
|
|
def repo_id_to_index(self):
|
|
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
|
|
|
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
|
"""
|
|
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
|
|
|
@property
|
|
def repo_index_to_id(self):
|
|
"""Return the inverse mapping if repo_id_to_index."""
|
|
return {v: k for k, v in self.repo_id_to_index}
|
|
|
|
@property
|
|
def fps(self) -> int:
|
|
"""Frames per second used during data collection.
|
|
|
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
|
"""
|
|
return self._datasets[0].meta.info["fps"]
|
|
|
|
@property
|
|
def video(self) -> bool:
|
|
"""Returns True if this dataset loads video frames from mp4 files.
|
|
|
|
Returns False if it only loads images from png files.
|
|
|
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
|
"""
|
|
return self._datasets[0].meta.info.get("video", False)
|
|
|
|
@property
|
|
def features(self) -> datasets.Features:
|
|
features = {}
|
|
for dataset in self._datasets:
|
|
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
|
return features
|
|
|
|
@property
|
|
def camera_keys(self) -> list[str]:
|
|
"""Keys to access image and video stream from cameras."""
|
|
keys = []
|
|
for key, feats in self.features.items():
|
|
if isinstance(feats, (datasets.Image, VideoFrame)):
|
|
keys.append(key)
|
|
return keys
|
|
|
|
@property
|
|
def video_frame_keys(self) -> list[str]:
|
|
"""Keys to access video frames that requires to be decoded into images.
|
|
|
|
Note: It is empty if the dataset contains images only,
|
|
or equal to `self.cameras` if the dataset contains videos only,
|
|
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
|
"""
|
|
video_frame_keys = []
|
|
for key, feats in self.features.items():
|
|
if isinstance(feats, VideoFrame):
|
|
video_frame_keys.append(key)
|
|
return video_frame_keys
|
|
|
|
@property
|
|
def num_frames(self) -> int:
|
|
"""Number of samples/frames."""
|
|
return sum(d.num_frames for d in self._datasets)
|
|
|
|
@property
|
|
def num_episodes(self) -> int:
|
|
"""Number of episodes."""
|
|
return sum(d.num_episodes for d in self._datasets)
|
|
|
|
@property
|
|
def tolerance_s(self) -> float:
|
|
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
|
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
|
is provided or when loading video frames from mp4 files.
|
|
"""
|
|
# 1e-4 to account for possible numerical error
|
|
return 1 / self.fps - 1e-4
|
|
|
|
def __len__(self):
|
|
return self.num_frames
|
|
|
|
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
|
if idx >= len(self):
|
|
raise IndexError(f"Index {idx} out of bounds.")
|
|
# Determine which dataset to get an item from based on the index.
|
|
start_idx = 0
|
|
dataset_idx = 0
|
|
for dataset in self._datasets:
|
|
if idx >= start_idx + dataset.num_frames:
|
|
start_idx += dataset.num_frames
|
|
dataset_idx += 1
|
|
continue
|
|
break
|
|
else:
|
|
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
|
item = self._datasets[dataset_idx][idx - start_idx]
|
|
item["dataset_index"] = torch.tensor(dataset_idx)
|
|
for data_key in self.disabled_features:
|
|
if data_key in item:
|
|
del item[data_key]
|
|
|
|
return item
|
|
|
|
def __repr__(self):
|
|
return (
|
|
f"{self.__class__.__name__}(\n"
|
|
f" Repository IDs: '{self.repo_ids}',\n"
|
|
f" Number of Samples: {self.num_frames},\n"
|
|
f" Number of Episodes: {self.num_episodes},\n"
|
|
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
|
f" Recorded Frames per Second: {self.fps},\n"
|
|
f" Camera Keys: {self.camera_keys},\n"
|
|
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
|
f" Transformations: {self.image_transforms},\n"
|
|
f")"
|
|
)
|