Update LeRobotDataset.__getitem__

This commit is contained in:
Simon Alibert 2024-10-10 21:32:14 +02:00
parent 3113038beb
commit b417cebc4e
3 changed files with 232 additions and 128 deletions

View File

@ -15,25 +15,27 @@
# limitations under the License.
import logging
import os
from itertools import accumulate
from pathlib import Path
from typing import Callable
import datasets
import torch
import torch.utils
from huggingface_hub import snapshot_download
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
download_episodes,
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
get_episode_data_index,
get_hub_safe_version,
load_hf_dataset,
load_info,
load_previous_and_future_frames,
load_stats,
load_tasks,
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.0"
@ -49,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
video_backend: str | None = None,
):
super().__init__()
@ -58,7 +61,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
@ -68,34 +73,60 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tasks = load_tasks(repo_id, self._version, self.root)
# Load actual data
download_episodes(
repo_id,
self._version,
self.root,
self.data_path,
self.video_keys,
self.num_episodes,
self.episodes,
self.videos_path,
)
self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
self.episode_data_index = self.get_episode_data_index()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
# Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# TODO(aliberts):
# - [ ] Update __getitem__
# - [X] Move delta_timestamp logic outside __getitem__
# - [X] Update __getitem__
# - [ ] Add self.add_frame()
# - [ ] Add self.consolidate() for:
# - [X] Check timestamps sync
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
# TODO(aliberts): remove (deprecated)
# if split == "train":
# self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
# else:
# self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
# self.hf_dataset = reset_episode_index(self.hf_dataset)
# if self.video:
# self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
def download_episodes(self) -> None:
    """Download this dataset (or a subset of its episodes) from the hub.

    When `self.episodes` is None the whole dataset is fetched; otherwise only the
    data/video files of the selected episode indices are pulled via `allow_patterns`.
    Thanks to the behavior of `snapshot_download`, files already present in
    `self.root` are not downloaded again.

    Note: Currently, if you're running this code offline but you already have the
    files in `self.root`, `snapshot_download` will still fail. This behavior will be
    fixed in an upcoming update of huggingface_hub.
    """
    # TODO(rcadene, aliberts): implement faster transfer
    # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
    files = None
    if self.episodes is not None:
        # Restrict the snapshot to the selected episodes' parquet files...
        files = [
            self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
            for ep_idx in self.episodes
        ]
        # ...and, when the dataset stores videos, their mp4 files as well.
        if len(self.video_keys) > 0:
            files.extend(
                self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
                for vid_key in self.video_keys
                for ep_idx in self.episodes
            )
    snapshot_download(
        self.repo_id,
        repo_type="dataset",
        revision=self._version,
        local_dir=self.root,
        allow_patterns=files,
    )
@property
def data_path(self) -> str:
@ -134,17 +165,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video streams from cameras."""
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
return self.image_keys + self.video_keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
"""
DEPRECATED, USE 'video_keys' INSTEAD
Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
# TODO(aliberts): remove
video_frame_keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, VideoFrame):
@ -166,54 +200,97 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
# @property
# def tolerance_s(self) -> float:
# """Tolerance in seconds used to discard loaded frames when their timestamps
# are not close enough from the requested frames. It is used at the init of the dataset to make sure
# that each timestamps is separated to the next by 1/fps +/- tolerance. It is only used when
# `delta_timestamps` is provided or when loading video frames from mp4 files.
# """
# # 1e-4 to account for possible numerical error
# return 1e-4
@property
def shapes(self) -> dict:
    """Shapes for the different features."""
    # Bug fix: the value was looked up but never returned, so the property
    # always evaluated to None.
    return self.info.get("shapes")
def get_episode_data_index(self) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
def current_episode_index(self, idx: int) -> int:
episode_index = self.hf_dataset["episode_index"][idx]
if self.episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
# get episode_index from selected episodes
episode_index = self.episodes.index(episode_index)
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return episode_index
def episode_length(self, episode_index: int) -> int:
    """Number of samples/frames for given episode."""
    # Reads per-episode metadata from self.info — presumably loaded from the
    # dataset's info file at init; TODO confirm the "episodes" schema.
    return self.info["episodes"][episode_index]["length"]
def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
# Pad values outside of current episode range
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
def _get_query_timestamps(
    self, query_indices: dict[str, list[int]], current_ts: float
) -> dict[str, list[float]]:
    """Map each video key to the timestamps of the frames that must be decoded.

    Keys present in `query_indices` get the timestamps of the corresponding rows
    of the hf_dataset; keys without query indices fall back to the current
    frame's timestamp only.
    """
    result = {}
    for vid_key in self.video_keys:
        if vid_key not in query_indices:
            result[vid_key] = [current_ts]
            continue
        selected_ts = self.hf_dataset.select(query_indices[vid_key])["timestamp"]
        result[vid_key] = torch.stack(selected_ts).tolist()
    return result
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
    """Gather the non-video features at the given row indices, stacked into tensors."""
    stacked = {}
    for key, q_idx in query_indices.items():
        # Video streams are decoded separately from mp4 files, not from the table.
        if key in self.video_keys:
            continue
        stacked[key] = torch.stack(self.hf_dataset.select(q_idx)[key])
    return stacked
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
    """Decode, for each video key, the frames closest to the requested timestamps.

    Note: When using data workers (e.g. DataLoader with num_workers>0), do not call
    this function in the main process (e.g. by using a second Dataloader with
    num_workers=0). It will result in a Segmentation Fault. This probably happens
    because a memory reference to the video loader is created in the main process
    and a subprocess fails to access it.
    """
    decoded = {}
    for vid_key, query_ts in query_timestamps.items():
        video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
        decoded[vid_key] = decode_video_frames_torchvision(
            video_path, query_ts, self.tolerance_s, self.video_backend
        )
    return decoded
def __len__(self):
    # Total number of samples/frames across the selected episodes.
    return self.num_samples
def __getitem__(self, idx):
def __getitem__(self, idx) -> dict:
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
self.tolerance_s,
)
if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
query_indices = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices)
for key, val in query_result.items():
item[key] = val
if self.video:
item = load_from_videos(
item,
self.video_keys,
self.videos_dir,
self.tolerance_s,
self.video_backend,
)
if len(self.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(query_indices, current_ts)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.image_transforms is not None:
for cam in self.camera_keys:

View File

@ -16,13 +16,15 @@
import json
import warnings
from functools import cache
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from typing import Dict
import datasets
import torch
from datasets import load_dataset
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from huggingface_hub import DatasetCard, HfApi, hf_hub_download
from PIL import Image as PILImage
from torchvision import transforms
@ -193,40 +195,102 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
return json.load(f)
def download_episodes(
repo_id: str,
version: str,
local_dir: Path,
data_path: str,
video_keys: list,
total_episodes: int,
episodes: list[int] | None = None,
videos_path: str | None = None,
) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
in 'local_dir', they won't be downloaded again.
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
    """Compute per-episode [from, to) frame index ranges from episode metadata.

    Args:
        episodes: Optional list of selected episode indices; None selects all episodes.
        episode_dicts: One dict per episode, each with a "length" entry (frame count).

    Returns:
        {"from": start indices, "to": exclusive end indices} as LongTensors,
        cumulative over the selected episodes in order.
    """
    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
    if episodes is not None:
        # Keep only the selected episodes; ranges are then relative to the filtered dataset.
        episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

    cumulative_lengths = list(accumulate(episode_lengths.values()))
    return {
        "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
        "to": torch.LongTensor(cumulative_lengths),
    }
def check_timestamps_sync(
    hf_dataset: "datasets.Dataset",
    episode_data_index: dict[str, torch.Tensor],
    fps: int,
    tolerance_s: float,
    raise_value_error: bool = True,
) -> bool:
    """Check that consecutive timestamps are separated by 1/fps +/- tolerance_s.

    Diffs spanning the boundary between two episodes are masked out, since those
    are expected to be discontinuous. Returns True when all intra-episode diffs
    are within tolerance; otherwise raises ValueError (or returns False when
    raise_value_error=False) with a report of the offending timestamps.
    """
    timestamps = torch.stack(hf_dataset["timestamp"])
    diffs = torch.diff(timestamps)
    within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s

    # We mask differences between the timestamp at the end of an episode
    # and the one at the start of the next episode since these are expected
    # to be outside tolerance.
    mask = torch.ones(len(diffs), dtype=torch.bool)
    ignored_diffs = episode_data_index["to"][:-1] - 1
    mask[ignored_diffs] = False
    filtered_within_tolerance = within_tolerance[mask]

    if not torch.all(filtered_within_tolerance):
        # Track original indices before masking
        original_indices = torch.arange(len(diffs))
        filtered_indices = original_indices[mask]
        # Bug fix: squeeze(1), not squeeze(). A plain squeeze() on a single
        # violation collapses the (1, 1) nonzero output to a 0-d tensor, which
        # breaks both the indexing below and the iteration over the results.
        outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze(1)
        outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
        episode_indices = torch.stack(hf_dataset["episode_index"])

        outside_tolerances = []
        for idx in outside_tolerance_indices:
            entry = {
                "timestamps": [timestamps[idx], timestamps[idx + 1]],
                "diff": diffs[idx],
                "episode_index": episode_indices[idx].item(),
            }
            outside_tolerances.append(entry)

        if raise_value_error:
            raise ValueError(
                f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
                This might be due to synchronization issues with timestamps during data collection.
                \n{pformat(outside_tolerances)}"""
            )
        return False

    return True
def check_delta_timestamps(
    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
    """Check that every delta timestamp is a multiple of 1/fps +/- tolerance_s.

    Returns True when all deltas are valid; otherwise raises ValueError (or
    returns False when raise_value_error=False) listing the offending values.
    """
    outside_tolerance = {}
    for key, delta_ts in delta_timestamps.items():
        abs_delta_ts = torch.abs(torch.tensor(delta_ts))
        # Bug fix: a value just *below* a multiple of 1/fps leaves a remainder
        # close to 1/fps, not close to 0, so the single-sided modulo check
        # wrongly rejected it. Accept remainders near either end of [0, 1/fps].
        remainders = abs_delta_ts % (1 / fps)
        within_tolerance = (remainders <= tolerance_s) | (remainders >= 1 / fps - tolerance_s)
        if not torch.all(within_tolerance):
            outside_tolerance[key] = torch.tensor(delta_ts)[~within_tolerance]

    if len(outside_tolerance) > 0:
        if raise_value_error:
            raise ValueError(
                f"""
                The following delta_timestamps are found outside of tolerance range.
                Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
                their values accordingly.
                \n{pformat(outside_tolerance)}
                """
            )
        return False
    return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    """Convert delta timestamps (in seconds) into integer frame offsets at `fps`.

    Bug fix: round to the nearest integer instead of truncating. `.long()` alone
    truncates toward zero, so e.g. 0.7 * 10 = 6.999... (float representation)
    mapped to 6 instead of 7.
    """
    delta_indices = {}
    for key, delta_ts in delta_timestamps.items():
        delta_indices[key] = (torch.tensor(delta_ts) * fps).round().long().tolist()
    return delta_indices
def load_previous_and_future_frames(

View File

@ -27,45 +27,8 @@ import torchvision
from datasets.features.features import register_feature
def load_from_videos(
    item: dict[str, torch.Tensor],
    video_frame_keys: list[str],
    videos_dir: Path,
    tolerance_s: float,
    backend: str = "pyav",
):
    """Replace, in place, each video-frame entry of `item` by the decoded frame(s).

    Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
    in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
    This probably happens because a memory reference to the video loader is created in the main process and a
    subprocess fails to access it.
    """
    # The stored paths already contain "videos"
    # (e.g. videos_dir="data/videos", path="videos/episode_0.mp4"), hence the parent.
    data_dir = videos_dir.parent

    for key in video_frame_keys:
        entry = item[key]
        if not isinstance(entry, list):
            # Single frame requested (delta_timestamps is None).
            frames = decode_video_frames_torchvision(
                data_dir / entry["path"], [entry["timestamp"]], tolerance_s, backend
            )
            item[key] = frames[0]
        else:
            # Several frames at once (expected when delta_timestamps is not None).
            timestamps = [frame["timestamp"] for frame in entry]
            paths = [frame["path"] for frame in entry]
            if len(set(paths)) > 1:
                raise NotImplementedError("All video paths are expected to be the same for now.")
            item[key] = decode_video_frames_torchvision(
                data_dir / paths[0], timestamps, tolerance_s, backend
            )
    return item
def decode_video_frames_torchvision(
video_path: str,
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",