628 lines
26 KiB
Python
628 lines
26 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
import datasets
|
|
import torch
|
|
import torch.utils
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
|
from lerobot.common.datasets.utils import (
|
|
check_delta_timestamps,
|
|
check_timestamps_sync,
|
|
get_delta_indices,
|
|
get_episode_data_index,
|
|
get_hub_safe_version,
|
|
load_hf_dataset,
|
|
load_metadata,
|
|
)
|
|
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
|
|
|
|
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.0"

# Local cache root for datasets; can be overridden through the LEROBOT_HOME environment variable.
LEROBOT_HOME = Path(os.environ.get("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()

# Maximum number of episodes grouped under a single "chunk-XXX" directory.
DEFAULT_CHUNK_SIZE = 1_000

# Formattable templates (str.format) locating per-episode files relative to a dataset root.
DEFAULT_VIDEO_PATH = (
    "videos/chunk-{episode_chunk:03d}/"
    "{video_key}/episode_{episode_index:06d}.mp4"
)
DEFAULT_PARQUET_PATH = (
    "data/chunk-{episode_chunk:03d}/"
    "train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
|
|
|
|
|
|
class LeRobotDataset(torch.utils.data.Dataset):
|
|
def __init__(
|
|
self,
|
|
repo_id: str,
|
|
root: Path | None = None,
|
|
episodes: list[int] | None = None,
|
|
image_transforms: Callable | None = None,
|
|
delta_timestamps: dict[list[float]] | None = None,
|
|
tolerance_s: float = 1e-4,
|
|
download_videos: bool = True,
|
|
video_backend: str | None = None,
|
|
):
|
|
"""LeRobotDataset encapsulates 3 main things:
|
|
- metadata:
|
|
- info contains various information about the dataset like shapes, keys, fps etc.
|
|
- stats stores the dataset statistics of the different modalities for normalization
|
|
- tasks contains the prompts for each task of the dataset, which can be used for
|
|
task-conditionned training.
|
|
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
|
- (optional) videos from which frames are loaded to be synchronous with data from parquet files.
|
|
|
|
3 modes are available for this class, depending on 3 different use cases:
|
|
|
|
1. Your dataset already exists on the Hugging Face Hub at the address
|
|
https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
|
|
Instantiating this class with this 'repo_id' will download the dataset from that address and load
|
|
it, pending your dataset is compliant with codebase_version v2.0. If your dataset has been created
|
|
before this new format, you will be prompted to convert it using our conversion script from v1.6
|
|
to v2.0, which you can find at lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
|
|
|
|
2. Your dataset already exists on your local disk in the 'root' folder:
|
|
This is typically the case when you recorded your dataset locally and you may or may not have
|
|
pushed it to the hub yet. Instantiating this class with 'root' will load your dataset directly
|
|
from disk. This can happen while you're offline (no internet connection).
|
|
|
|
3. Your dataset doesn't already exists (either on local disk or on the Hub):
|
|
[TODO(aliberts): add classmethod for this case?]
|
|
|
|
|
|
In terms of files, a typical LeRobotDataset looks like this from its root path:
|
|
.
|
|
├── data
|
|
│ ├── chunk-000
|
|
│ │ ├── train-00000-of-03603.parquet
|
|
│ │ ├── train-00001-of-03603.parquet
|
|
│ │ ├── train-00002-of-03603.parquet
|
|
│ │ └── ...
|
|
│ ├── chunk-001
|
|
│ │ ├── train-01000-of-03603.parquet
|
|
│ │ ├── train-01001-of-03603.parquet
|
|
│ │ ├── train-01002-of-03603.parquet
|
|
│ │ └── ...
|
|
│ └── ...
|
|
├── meta
|
|
│ ├── episodes.jsonl
|
|
│ ├── info.json
|
|
│ ├── stats.json
|
|
│ └── tasks.json
|
|
└── videos (optional)
|
|
├── chunk-000
|
|
│ ├── observation.images.laptop
|
|
│ │ ├── episode_000000.mp4
|
|
│ │ ├── episode_000001.mp4
|
|
│ │ ├── episode_000002.mp4
|
|
│ │ └── ...
|
|
│ ├── observation.images.phone
|
|
│ │ ├── episode_000000.mp4
|
|
│ │ ├── episode_000001.mp4
|
|
│ │ ├── episode_000002.mp4
|
|
│ │ └── ...
|
|
├── chunk-001
|
|
└── ...
|
|
|
|
Note that this file-based structure is designed to be as versatile as possible. The files are split by
|
|
episodes which allows a more granular control over which episodes one wants to use and download. The
|
|
structure of the dataset is entirely described in the info.json file, which can be easily downloaded
|
|
or viewed directly on the hub before downloading any actual data. The type of files used are very
|
|
simple and do not need complex tools to be read, it only uses .parquet, .json and .mp4 files (and .md
|
|
for the README).
|
|
|
|
Args:
|
|
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
|
will be stored under root/repo_id.
|
|
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
|
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
|
'~/.cache/huggingface/lerobot'.
|
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
|
their episode_index in this list. Defaults to None.
|
|
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
|
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
|
|
from videos or images). Defaults to None.
|
|
delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None.
|
|
tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
|
|
sync with the fps value. It is used at the init of the dataset to make sure that each
|
|
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
|
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
|
multiples of 1/fps. Defaults to 1e-4.
|
|
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
|
video files are already present on local disk, they won't be downloaded again. Defaults to
|
|
True.
|
|
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
|
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
|
"""
|
|
super().__init__()
|
|
self.repo_id = repo_id
|
|
self.root = root if root is not None else LEROBOT_HOME / repo_id
|
|
self.image_transforms = image_transforms
|
|
self.delta_timestamps = delta_timestamps
|
|
self.episodes = episodes
|
|
self.tolerance_s = tolerance_s
|
|
self.download_videos = download_videos
|
|
self.video_backend = video_backend if video_backend is not None else "pyav"
|
|
self.delta_indices = None
|
|
|
|
# Load metadata
|
|
self.root.mkdir(exist_ok=True, parents=True)
|
|
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
|
|
self.download_metadata()
|
|
self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
|
|
|
|
# Load actual data
|
|
self.download_episodes()
|
|
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
|
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
|
|
|
# Check timestamps
|
|
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
|
|
|
# Setup delta_indices
|
|
if self.delta_timestamps is not None:
|
|
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
|
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
|
|
|
# TODO(aliberts):
|
|
# - [X] Move delta_timestamp logic outside __get_item__
|
|
# - [X] Update __get_item__
|
|
# - [/] Add doc
|
|
# - [ ] Add self.add_frame()
|
|
# - [ ] Add self.consolidate() for:
|
|
# - [X] Check timestamps sync
|
|
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
|
|
# - [ ] Update episode_index (arg update=True)
|
|
# - [ ] Update info.json (arg update=True)
|
|
|
|
def download_metadata(self) -> None:
|
|
snapshot_download(
|
|
self.repo_id,
|
|
repo_type="dataset",
|
|
revision=self._version,
|
|
local_dir=self.root,
|
|
allow_patterns="meta/",
|
|
)
|
|
|
|
def download_episodes(self) -> None:
|
|
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
|
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
|
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
|
in 'local_dir', they won't be downloaded again.
|
|
"""
|
|
# TODO(rcadene, aliberts): implement faster transfer
|
|
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
|
files = None
|
|
ignore_patterns = None if self.download_videos else "videos/"
|
|
if self.episodes is not None:
|
|
files = [
|
|
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
|
|
for ep_idx in self.episodes
|
|
]
|
|
if len(self.video_keys) > 0 and self.download_videos:
|
|
video_files = [
|
|
self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
|
for vid_key in self.video_keys
|
|
for ep_idx in self.episodes
|
|
]
|
|
files += video_files
|
|
|
|
snapshot_download(
|
|
self.repo_id,
|
|
repo_type="dataset",
|
|
revision=self._version,
|
|
local_dir=self.root,
|
|
allow_patterns=files,
|
|
ignore_patterns=ignore_patterns,
|
|
)
|
|
|
|
@property
|
|
def data_path(self) -> str:
|
|
"""Formattable string for the parquet files."""
|
|
return self.info["data_path"]
|
|
|
|
@property
|
|
def videos_path(self) -> str | None:
|
|
"""Formattable string for the video files."""
|
|
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
|
|
|
|
@property
|
|
def fps(self) -> int:
|
|
"""Frames per second used during data collection."""
|
|
return self.info["fps"]
|
|
|
|
@property
|
|
def keys(self) -> list[str]:
|
|
"""Keys to access non-image data (state, actions etc.)."""
|
|
return self.info["keys"]
|
|
|
|
@property
|
|
def image_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities stored as images."""
|
|
return self.info["image_keys"]
|
|
|
|
@property
|
|
def video_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities stored as videos."""
|
|
return self.info["video_keys"]
|
|
|
|
@property
|
|
def camera_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
|
return self.image_keys + self.video_keys
|
|
|
|
@property
|
|
def names(self) -> dict[list[str]]:
|
|
"""Names of the various dimensions of vector modalities."""
|
|
return self.info["names"]
|
|
|
|
@property
|
|
def num_samples(self) -> int:
|
|
"""Number of samples/frames."""
|
|
return len(self.hf_dataset)
|
|
|
|
@property
|
|
def num_episodes(self) -> int:
|
|
"""Number of episodes selected."""
|
|
return len(self.episodes) if self.episodes is not None else self.total_episodes
|
|
|
|
@property
|
|
def total_episodes(self) -> int:
|
|
"""Total number of episodes available."""
|
|
return self.info["total_episodes"]
|
|
|
|
@property
|
|
def total_chunks(self) -> int:
|
|
"""Total number of chunks (groups of episodes)."""
|
|
return self.info["total_chunks"]
|
|
|
|
@property
|
|
def chunks_size(self) -> int:
|
|
"""Max number of episodes per chunk."""
|
|
return self.info["chunks_size"]
|
|
|
|
@property
|
|
def shapes(self) -> dict:
|
|
"""Shapes for the different features."""
|
|
self.info.get("shapes")
|
|
|
|
def current_episode_index(self, idx: int) -> int:
|
|
episode_index = self.hf_dataset["episode_index"][idx]
|
|
if self.episodes is not None:
|
|
# get episode_index from selected episodes
|
|
episode_index = self.episodes.index(episode_index)
|
|
|
|
return episode_index
|
|
|
|
def episode_length(self, episode_index) -> int:
|
|
"""Number of samples/frames for given episode."""
|
|
return self.info["episodes"][episode_index]["length"]
|
|
|
|
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
|
ep_start = self.episode_data_index["from"][ep_idx]
|
|
ep_end = self.episode_data_index["to"][ep_idx]
|
|
query_indices = {
|
|
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
|
for key, delta_idx in self.delta_indices.items()
|
|
}
|
|
padding = { # Pad values outside of current episode range
|
|
f"{key}_is_pad": torch.BoolTensor(
|
|
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
|
)
|
|
for key, delta_idx in self.delta_indices.items()
|
|
}
|
|
return query_indices, padding
|
|
|
|
def _get_query_timestamps(
|
|
self,
|
|
current_ts: float,
|
|
query_indices: dict[str, list[int]] | None = None,
|
|
) -> dict[str, list[float]]:
|
|
query_timestamps = {}
|
|
for key in self.video_keys:
|
|
if query_indices is not None and key in query_indices:
|
|
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
|
else:
|
|
query_timestamps[key] = [current_ts]
|
|
|
|
return query_timestamps
|
|
|
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
|
return {
|
|
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
|
for key, q_idx in query_indices.items()
|
|
if key not in self.video_keys
|
|
}
|
|
|
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
|
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
|
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
|
the main process and a subprocess fails to access it.
|
|
"""
|
|
item = {}
|
|
for vid_key, query_ts in query_timestamps.items():
|
|
video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
|
frames = decode_video_frames_torchvision(
|
|
video_path, query_ts, self.tolerance_s, self.video_backend
|
|
)
|
|
item[vid_key] = frames
|
|
|
|
return item
|
|
|
|
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
|
for key, val in padding.items():
|
|
item[key] = torch.BoolTensor(val)
|
|
return item
|
|
|
|
def __len__(self):
|
|
return self.num_samples
|
|
|
|
def __getitem__(self, idx) -> dict:
|
|
item = self.hf_dataset[idx]
|
|
ep_idx = item["episode_index"].item()
|
|
|
|
query_indices = None
|
|
if self.delta_indices is not None:
|
|
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
|
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
|
|
query_result = self._query_hf_dataset(query_indices)
|
|
item = {**item, **padding}
|
|
for key, val in query_result.items():
|
|
item[key] = val
|
|
|
|
if len(self.video_keys) > 0:
|
|
current_ts = item["timestamp"].item()
|
|
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
|
item = {**video_frames, **item}
|
|
|
|
if self.image_transforms is not None:
|
|
image_keys = self.camera_keys if self.download_videos else self.image_keys
|
|
for cam in image_keys:
|
|
item[cam] = self.image_transforms(item[cam])
|
|
|
|
return item
|
|
|
|
def __repr__(self):
|
|
return (
|
|
f"{self.__class__.__name__}(\n"
|
|
f" Repository ID: '{self.repo_id}',\n"
|
|
f" Number of Samples: {self.num_samples},\n"
|
|
f" Number of Episodes: {self.num_episodes},\n"
|
|
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
|
f" Recorded Frames per Second: {self.fps},\n"
|
|
f" Camera Keys: {self.camera_keys},\n"
|
|
f" Video Frame Keys: {self.camera_keys if self.video else 'N/A'},\n"
|
|
f" Transformations: {self.image_transforms},\n"
|
|
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
|
|
f")"
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
repo_id: str,
|
|
root: Path | None = None,
|
|
image_transforms: Callable | None = None,
|
|
delta_timestamps: dict[list[float]] | None = None,
|
|
tolerance_s: float = 1e-4,
|
|
video_backend: str | None = None,
|
|
) -> "LeRobotDataset":
|
|
"""Create a LeRobot Dataset from scratch in order to record data."""
|
|
# create an empty object of type LeRobotDataset
|
|
obj = cls.__new__(cls)
|
|
obj.repo_id = repo_id
|
|
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
|
# obj.episodes = None
|
|
# obj.image_transforms = None
|
|
# obj.delta_timestamps = None
|
|
# obj.episode_data_index = episode_data_index
|
|
# obj.stats = stats
|
|
# obj.info = info if info is not None else {}
|
|
# obj.videos_dir = videos_dir
|
|
# obj.video_backend = video_backend if video_backend is not None else "pyav"
|
|
return obj
|
|
|
|
|
|
class MultiLeRobotDataset(torch.utils.data.Dataset):
    """A dataset consisting of multiple underlying `LeRobotDataset`s.

    The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
    structure of `LeRobotDataset`.
    """

    def __init__(
        self,
        repo_ids: list[str],
        root: Path | None = LEROBOT_HOME,
        split: str = "train",
        image_transforms: Callable | None = None,
        delta_timestamps: dict[str, list[float]] | None = None,
        video_backend: str | None = None,
    ):
        """
        Args:
            repo_ids: Repo ids of the datasets to concatenate, in order.
            root: Parent directory; each dataset is stored under `root / repo_id`. When None, each
                dataset uses its own default location.
            split: Only kept for display in `__repr__`; `LeRobotDataset` takes no split argument.
            image_transforms: Forwarded to every underlying dataset.
            delta_timestamps: Forwarded to every underlying dataset.
            video_backend: Forwarded to every underlying dataset.
        """
        super().__init__()
        self.repo_ids = repo_ids
        # Construct the underlying datasets. `transform`s and `delta_timestamps` are applied by each of them.
        # Each repo gets its own sub-directory so that datasets do not overwrite each other's files.
        self._datasets = [
            LeRobotDataset(
                repo_id,
                root=Path(root) / repo_id if root is not None else None,
                delta_timestamps=delta_timestamps,
                image_transforms=image_transforms,
                video_backend=video_backend,
            )
            for repo_id in repo_ids
        ]
        # Check that some properties are consistent across datasets. Note: We may relax some of these
        # consistency requirements in future iterations of this class.
        # NOTE(review): `info` also holds per-dataset counts (e.g. "total_episodes", read by
        # LeRobotDataset.total_episodes), so this strict equality may reject datasets that only differ in
        # size — TODO confirm and compare only the relevant fields.
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            if dataset.info != self._datasets[0].info:
                raise ValueError(
                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
                    "not yet supported."
                )
        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
        # restriction in future iterations of this class. For now, this is necessary at least for being able
        # to use PyTorch's default DataLoader collate function.
        self.disabled_data_keys = set()
        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
        for dataset in self._datasets:
            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
        if len(intersection_data_keys) == 0:
            raise RuntimeError(
                "Multiple datasets were provided but they had no keys common to all of them. The "
                "multi-dataset functionality currently only keeps common keys."
            )
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
            if not extra_keys:
                # Nothing disabled for this dataset: avoid a misleading warning.
                continue
            logging.warning(
                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
                "other datasets."
            )
            self.disabled_data_keys.update(extra_keys)

        self.root = root
        self.split = split
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        self.stats = aggregate_stats(self._datasets)

    @property
    def repo_id_to_index(self):
        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.

        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
        """
        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}

    @property
    def repo_index_to_id(self):
        """Return the inverse mapping of repo_id_to_index."""
        return {v: k for k, v in self.repo_id_to_index.items()}

    @property
    def fps(self) -> int:
        """Frames per second used during data collection.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info["fps"]

    @property
    def video(self) -> bool:
        """Returns True if this dataset loads video frames from mp4 files.

        Returns False if it only loads images from png files.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        # Same criterion as LeRobotDataset: the dataset is video-based when it declares video keys.
        return len(self._datasets[0].video_keys) > 0

    @property
    def features(self) -> datasets.Features:
        """Merged Hugging Face features of all sub-datasets, excluding the disabled keys."""
        features = {}
        for dataset in self._datasets:
            features.update(
                {k: v for k, v in dataset.hf_dataset.features.items() if k not in self.disabled_data_keys}
            )
        return features

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access image and video stream from cameras."""
        keys = []
        for key, feats in self.features.items():
            if isinstance(feats, (datasets.Image, VideoFrame)):
                keys.append(key)
        return keys

    @property
    def video_frame_keys(self) -> list[str]:
        """Keys to access video frames that requires to be decoded into images.

        Note: It is empty if the dataset contains images only,
        or equal to `self.camera_keys` if the dataset contains videos only,
        or can even be a subset of `self.camera_keys` in a case of a mixed image/video dataset.
        """
        video_frame_keys = []
        for key, feats in self.features.items():
            if isinstance(feats, VideoFrame):
                video_frame_keys.append(key)
        return video_frame_keys

    @property
    def num_samples(self) -> int:
        """Number of samples/frames."""
        return sum(d.num_samples for d in self._datasets)

    @property
    def num_episodes(self) -> int:
        """Number of episodes."""
        return sum(d.num_episodes for d in self._datasets)

    @property
    def tolerance_s(self) -> float:
        """Tolerance in seconds used to discard loaded frames when their timestamps
        are not close enough from the requested frames. It is only used when `delta_timestamps`
        is provided or when loading video frames from mp4 files.

        NOTE(review): this value is not passed to the underlying datasets, which use their own
        `tolerance_s` (LeRobotDataset default: 1e-4).
        """
        # 1e-4 to account for possible numerical error
        return 1 / self.fps - 1e-4

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return sample `idx` of the concatenation, tagged with a `dataset_index` tensor.

        Negative indices count from the end, like Python sequences. Disabled keys are removed.
        """
        total = len(self)
        requested_idx = idx
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"Index {requested_idx} out of bounds.")
        # Determine which dataset to get an item from based on the index.
        start_idx = 0
        for dataset_idx, dataset in enumerate(self._datasets):
            if idx < start_idx + dataset.num_samples:
                break
            start_idx += dataset.num_samples
        else:
            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
        item = self._datasets[dataset_idx][idx - start_idx]
        item["dataset_index"] = torch.tensor(dataset_idx)
        for data_key in self.disabled_data_keys:
            if data_key in item:
                del item[data_key]

        return item

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f" Repository IDs: '{self.repo_ids}',\n"
            f" Split: '{self.split}',\n"
            f" Number of Samples: {self.num_samples},\n"
            f" Number of Episodes: {self.num_episodes},\n"
            f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
            f" Recorded Frames per Second: {self.fps},\n"
            f" Camera Keys: {self.camera_keys},\n"
            f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f" Transformations: {self.image_transforms},\n"
            f")"
        )
|