Add doc, scrap video_frame_keys attribute

This commit is contained in:
Simon Alibert 2024-10-11 10:59:38 +02:00
parent b417cebc4e
commit 6d2bc11365
4 changed files with 85 additions and 31 deletions

View File

@ -54,6 +54,83 @@ class LeRobotDataset(torch.utils.data.Dataset):
tolerance_s: float = 1e-4,
video_backend: str | None = None,
):
"""LeRobotDataset encapsulates 3 main things:
- metadata:
- info contains various information about the dataset like shapes, keys, fps etc.
- stats stores the dataset statistics of the different modalities for normalization
- tasks contains the prompts for each task of the dataset, which can be used for
task-conditioned training.
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- (optional) videos from which frames are loaded to be synchronous with data from parquet files.
3 use modes are available for this class, depending on 3 different use cases:
1. Your dataset already exists on the Hugging Face Hub at the address
https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
Instantiating this class with this 'repo_id' will download the dataset from that address and load
it, provided your dataset is compliant with codebase_version v2.0. If your dataset has been created
before this new format, you will be prompted to convert it using our conversion script from v1.6
to v2.0, which you can find at [TODO(aliberts): move conversion script & add location here].
2. Your dataset already exists on your local disk in the 'root' folder:
This is typically the case when you recorded your dataset locally and you may or may not have
pushed it to the hub yet. Instantiating this class with 'root' will load your dataset directly
from disk. This can happen while you're offline (no internet connection).
3. Your dataset doesn't already exist (either on local disk or on the Hub):
[TODO(aliberts): add classmethod for this case?]
In terms of files, a typical LeRobotDataset looks like this from its root path:
.
README.md
data
train-00000-of-00050.parquet
train-00001-of-00050.parquet
train-00002-of-00050.parquet
...
meta
info.json
stats.json
tasks.json
videos (optional)
observation.images.laptop_episode_000000.mp4
observation.images.laptop_episode_000001.mp4
observation.images.laptop_episode_000002.mp4
...
observation.images.phone_episode_000000.mp4
observation.images.phone_episode_000001.mp4
observation.images.phone_episode_000002.mp4
...
Note that this file-based structure is designed to be as versatile as possible. The files are split by
episodes which allows a more granular control over which episodes one wants to use and download. The
structure of the dataset is entirely described in the info.json file, which can be easily downloaded
or viewed directly on the hub before downloading any actual data. The file types used are very
simple and do not need complex tools to be read: only .parquet, .json and .mp4 files are used (and .md
for the README).
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
split (str, optional): _description_. Defaults to "train".
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
from videos or images). Defaults to None.
delta_timestamps (dict[str, list[float]] | None, optional): _description_. Defaults to None.
tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
sync with the fps value. It is used at the init of the dataset to make sure that each
timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
super().__init__()
self.repo_id = repo_id
self.root = root if root is not None else LEROBOT_HOME / repo_id
@ -88,6 +165,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts):
# - [X] Move delta_timestamp logic outside __get_item__
# - [X] Update __get_item__
# - [/] Add doc
# - [ ] Add self.add_frame()
# - [ ] Add self.consolidate() for:
# - [X] Check timestamps sync
@ -168,23 +246,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
return self.image_keys + self.video_keys
@property
def video_frame_keys(self) -> list[str]:
"""
DEPRECATED, USE 'video_keys' INSTEAD
Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
# TODO(aliberts): remove
video_frame_keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_samples(self) -> int:
"""Number of samples/frames."""
@ -200,16 +261,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Total number of episodes available."""
return self.info["total_episodes"]
# @property
# def tolerance_s(self) -> float:
# """Tolerance in seconds used to discard loaded frames when their timestamps
# are not close enough from the requested frames. It is used at the init of the dataset to make sure
# that each timestamps is separated to the next by 1/fps +/- tolerance. It is only used when
# `delta_timestamps` is provided or when loading video frames from mp4 files.
# """
# # 1e-4 to account for possible numerical error
# return 1e-4
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
@ -308,7 +359,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Video Frame Keys: {self.camera_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
f")"

View File

@ -263,6 +263,10 @@ def check_timestamps_sync(
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
actual timestamps from the dataset.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
abs_delta_ts = torch.abs(torch.tensor(delta_ts))

View File

@ -260,7 +260,7 @@ def push_dataset_to_hub(
episode_index = 0
tests_videos_dir = tests_data_dir / repo_id / "videos"
tests_videos_dir.mkdir(parents=True, exist_ok=True)
for key in lerobot_dataset.video_frame_keys:
for key in lerobot_dataset.camera_keys:
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)

View File

@ -171,8 +171,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.video_frame_keys
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.camera_keys
]