#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable

import datasets
import torch
import torch.utils

from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    load_episode_data_index,
    load_hf_dataset,
    load_info,
    load_previous_and_future_frames,
    load_stats,
    load_videos,
    reset_episode_index,
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos

# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v1.6"
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

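# Usage sketch for `DATA_DIR` (the path below is an assumption, any directory holding
# datasets in the expected layout works): when the environment variable is set, datasets
# are read from that local directory instead of being downloaded from the Hugging Face Hub.
#
#     DATA_DIR=/path/to/local/data python my_training_script.py
#
# When `DATA_DIR` is unset, datasets are fetched from the hub using their `repo_id`.
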
class LeRobotDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        repo_id: str,
        root: Path | None = DATA_DIR,
        split: str = "train",
        image_transforms: Callable | None = None,
        delta_timestamps: dict[str, list[float]] | None = None,
        video_backend: str | None = None,
    ):
        super().__init__()
        self.repo_id = repo_id
        self.root = root
        self.split = split
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        # Load data from the hub, or locally when `root` is provided.
        # TODO(rcadene, aliberts): implement faster transfer
        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
        if split == "train":
            self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
        else:
            self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
            self.hf_dataset = reset_episode_index(self.hf_dataset)
        self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
        self.info = load_info(repo_id, CODEBASE_VERSION, root)
        if self.video:
            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
            self.video_backend = video_backend if video_backend is not None else "pyav"

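    # A minimal sketch of the expected `delta_timestamps` structure (the keys and repo id
    # shown are assumptions; keys must match feature names present in the loaded dataset):
    #
    #     delta_timestamps = {
    #         # For each frame, also load the frames 1.0 s and 0.5 s before it.
    #         "observation.image": [-1.0, -0.5, 0.0],
    #         # Load the current action plus the next 5 actions (0.1 s apart at 10 fps).
    #         "action": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
    #     }
    #     dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
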
    @property
    def fps(self) -> int:
        """Frames per second used during data collection."""
        return self.info["fps"]

    @property
    def video(self) -> bool:
        """Returns True if this dataset loads video frames from mp4 files.
        Returns False if it only loads images from png files.
        """
        return self.info.get("video", False)

    @property
    def features(self) -> datasets.Features:
        return self.hf_dataset.features

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access image and video streams from cameras."""
        keys = []
        for key, feats in self.hf_dataset.features.items():
            if isinstance(feats, (datasets.Image, VideoFrame)):
                keys.append(key)
        return keys

    @property
    def video_frame_keys(self) -> list[str]:
        """Keys to access video frames that need to be decoded into images.

        Note: It is empty if the dataset contains images only,
        or equal to `self.camera_keys` if the dataset contains videos only,
        or can even be a subset of `self.camera_keys` in the case of a mixed image/video dataset.
        """
        video_frame_keys = []
        for key, feats in self.hf_dataset.features.items():
            if isinstance(feats, VideoFrame):
                video_frame_keys.append(key)
        return video_frame_keys

    @property
    def num_samples(self) -> int:
        """Number of samples/frames."""
        return len(self.hf_dataset)

    @property
    def num_episodes(self) -> int:
        """Number of episodes."""
        return len(self.hf_dataset.unique("episode_index"))

    @property
    def tolerance_s(self) -> float:
        """Tolerance in seconds used to discard loaded frames when their timestamps
        are not close enough to the requested frames. It is only used when `delta_timestamps`
        is provided or when loading video frames from mp4 files.
        """
        # 1e-4 to account for possible numerical error
        return 1 / self.fps - 1e-4

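    # Worked example (values are illustrative): at 30 fps, 1 / 30 ≈ 0.0333 s, so
    # tolerance_s ≈ 0.0332 s; a frame whose timestamp deviates from the requested time by
    # roughly a full frame period or more is discarded.
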
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]

        if self.delta_timestamps is not None:
            item = load_previous_and_future_frames(
                item,
                self.hf_dataset,
                self.episode_data_index,
                self.delta_timestamps,
                self.tolerance_s,
            )

        if self.video:
            item = load_from_videos(
                item,
                self.video_frame_keys,
                self.videos_dir,
                self.tolerance_s,
                self.video_backend,
            )

        if self.image_transforms is not None:
            for cam in self.camera_keys:
                item[cam] = self.image_transforms(item[cam])

        return item

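    # A minimal usage sketch (the repo id is an assumption; any LeRobot-format dataset works).
    # Since this is a standard `torch.utils.data.Dataset`, it can be wrapped in a DataLoader:
    #
    #     dataset = LeRobotDataset("lerobot/pusht")
    #     loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    #     batch = next(iter(loader))  # dict of batched tensors, one entry per feature key
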
    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository ID: '{self.repo_id}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"
            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
            f"  Recorded Frames per Second: {self.fps},\n"
            f"  Camera Keys: {self.camera_keys},\n"
            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f"  Transformations: {self.image_transforms},\n"
            f"  Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
            f")"
        )

    @classmethod
    def from_preloaded(
        cls,
        repo_id: str = "from_preloaded",
        root: Path | None = None,
        split: str = "train",
        transform: Callable | None = None,
        delta_timestamps: dict[str, list[float]] | None = None,
        # additional preloaded attributes
        hf_dataset=None,
        episode_data_index=None,
        stats=None,
        info=None,
        videos_dir=None,
        video_backend=None,
    ) -> "LeRobotDataset":
        """Create a LeRobotDataset from existing data and attributes instead of loading from the filesystem.

        It is especially useful when converting raw data into a LeRobotDataset before saving the dataset
        on the filesystem or uploading it to the hub.

        Note: Meta-data attributes like `repo_id`, `version`, `root`, etc. are optional and potentially
        meaningless depending on the downstream usage of the returned dataset.
        """
        # Create an empty object of type LeRobotDataset.
        obj = cls.__new__(cls)
        obj.repo_id = repo_id
        obj.root = root
        obj.split = split
        obj.image_transforms = transform
        obj.delta_timestamps = delta_timestamps
        obj.hf_dataset = hf_dataset
        obj.episode_data_index = episode_data_index
        obj.stats = stats
        obj.info = info if info is not None else {}
        obj.videos_dir = videos_dir
        obj.video_backend = video_backend if video_backend is not None else "pyav"
        return obj

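# A minimal sketch of `LeRobotDataset.from_preloaded` (all argument values are assumptions;
# in practice `hf_dataset`, `episode_data_index`, `stats` and `info` come from a raw-data
# conversion step such as a push_dataset_to_hub script):
#
#     dataset = LeRobotDataset.from_preloaded(
#         repo_id="my_user/my_converted_dataset",
#         hf_dataset=hf_dataset,
#         episode_data_index=episode_data_index,
#         stats=stats,
#         info={"fps": 30, "video": False},
#     )
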
class MultiLeRobotDataset(torch.utils.data.Dataset):
    """A dataset consisting of multiple underlying `LeRobotDataset`s.

    The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
    structure of `LeRobotDataset`.
    """

    def __init__(
        self,
        repo_ids: list[str],
        root: Path | None = DATA_DIR,
        split: str = "train",
        image_transforms: Callable | None = None,
        delta_timestamps: dict[str, list[float]] | None = None,
        video_backend: str | None = None,
    ):
        super().__init__()
        self.repo_ids = repo_ids
        # Construct the underlying datasets, passing the shared arguments down to each of them.
        self._datasets = [
            LeRobotDataset(
                repo_id,
                root=root,
                split=split,
                delta_timestamps=delta_timestamps,
                image_transforms=image_transforms,
                video_backend=video_backend,
            )
            for repo_id in repo_ids
        ]
        # Check that some properties are consistent across datasets. Note: We may relax some of these
        # consistency requirements in future iterations of this class.
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            if dataset.info != self._datasets[0].info:
                raise ValueError(
                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
                    "not yet supported."
                )
        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
        # restriction in future iterations of this class. For now, this is necessary at least for being able
        # to use PyTorch's default DataLoader collate function.
        self.disabled_data_keys = set()
        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
        for dataset in self._datasets:
            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
        if len(intersection_data_keys) == 0:
            raise RuntimeError(
                "Multiple datasets were provided but they had no keys common to all of them. The "
                "multi-dataset functionality currently only keeps common keys."
            )
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
            logging.warning(
                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
                "other datasets."
            )
            self.disabled_data_keys.update(extra_keys)

        self.root = root
        self.split = split
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        self.stats = aggregate_stats(self._datasets)

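    # A minimal construction sketch (the repo ids are assumptions; all listed datasets must
    # share the same `info`, e.g. the same fps and video setting):
    #
    #     multi_dataset = MultiLeRobotDataset(["my_user/dataset_a", "my_user/dataset_b"])
    #     item = multi_dataset[0]
    #     item["dataset_index"]  # tensor identifying which sub-dataset the frame came from
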
    @property
    def repo_id_to_index(self):
        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.

        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
        """
        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}

    @property
    def repo_index_to_id(self):
        """Return the inverse mapping of `repo_id_to_index`."""
        return {v: k for k, v in self.repo_id_to_index.items()}

    @property
    def fps(self) -> int:
        """Frames per second used during data collection.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info["fps"]

    @property
    def video(self) -> bool:
        """Returns True if this dataset loads video frames from mp4 files.

        Returns False if it only loads images from png files.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info.get("video", False)

    @property
    def features(self) -> datasets.Features:
        features = {}
        for dataset in self._datasets:
            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
        return features

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access image and video streams from cameras."""
        keys = []
        for key, feats in self.features.items():
            if isinstance(feats, (datasets.Image, VideoFrame)):
                keys.append(key)
        return keys

    @property
    def video_frame_keys(self) -> list[str]:
        """Keys to access video frames that need to be decoded into images.

        Note: It is empty if the dataset contains images only,
        or equal to `self.camera_keys` if the dataset contains videos only,
        or can even be a subset of `self.camera_keys` in the case of a mixed image/video dataset.
        """
        video_frame_keys = []
        for key, feats in self.features.items():
            if isinstance(feats, VideoFrame):
                video_frame_keys.append(key)
        return video_frame_keys

    @property
    def num_samples(self) -> int:
        """Number of samples/frames."""
        return sum(d.num_samples for d in self._datasets)

    @property
    def num_episodes(self) -> int:
        """Number of episodes."""
        return sum(d.num_episodes for d in self._datasets)

    @property
    def tolerance_s(self) -> float:
        """Tolerance in seconds used to discard loaded frames when their timestamps
        are not close enough to the requested frames. It is only used when `delta_timestamps`
        is provided or when loading video frames from mp4 files.
        """
        # 1e-4 to account for possible numerical error
        return 1 / self.fps - 1e-4

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if idx >= len(self):
            raise IndexError(f"Index {idx} out of bounds.")
        # Determine which dataset to get an item from based on the index.
        start_idx = 0
        dataset_idx = 0
        for dataset in self._datasets:
            if idx >= start_idx + dataset.num_samples:
                start_idx += dataset.num_samples
                dataset_idx += 1
                continue
            break
        else:
            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
        item = self._datasets[dataset_idx][idx - start_idx]
        item["dataset_index"] = torch.tensor(dataset_idx)
        for data_key in self.disabled_data_keys:
            if data_key in item:
                del item[data_key]

        return item

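    # Worked example (sizes are illustrative): with two sub-datasets of 100 and 50 samples,
    # idx=120 skips the first dataset (start_idx=100, dataset_idx=1) and returns item 20 of
    # the second one, with item["dataset_index"] == tensor(1).
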
    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository IDs: '{self.repo_ids}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"
            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
            f"  Recorded Frames per Second: {self.fps},\n"
            f"  Camera Keys: {self.camera_keys},\n"
            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f"  Transformations: {self.image_transforms},\n"
            f")"
        )