Update LeRobotDataset.__getitem__
This commit is contained in:
parent 3113038beb
commit b417cebc4e
@@ -15,25 +15,27 @@
 # limitations under the License.
 import logging
 import os
 from itertools import accumulate
 from pathlib import Path
 from typing import Callable

 import datasets
 import torch
 import torch.utils
 from huggingface_hub import snapshot_download

 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.utils import (
-    download_episodes,
+    check_delta_timestamps,
+    check_timestamps_sync,
+    get_delta_indices,
+    get_episode_data_index,
     get_hub_safe_version,
     load_hf_dataset,
     load_info,
-    load_previous_and_future_frames,
     load_stats,
     load_tasks,
 )
-from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
+from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision

 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
 CODEBASE_VERSION = "v2.0"
@@ -49,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[str, list[float]] | None = None,
+        tolerance_s: float = 1e-4,
         video_backend: str | None = None,
     ):
         super().__init__()
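For reference, `delta_timestamps` maps a feature key to the relative times, in seconds, of the frames to return around the queried one. A minimal sketch of such a dict (the feature names and fps are hypothetical, not taken from this commit):

    # Hypothetical keys; every offset must be a multiple of 1/fps within tolerance_s,
    # which is what check_delta_timestamps enforces at init time.
    delta_timestamps = {
        "observation.image": [-1 / 30, 0.0],   # previous frame and current frame
        "action": [t / 30 for t in range(4)],  # current action plus the next three
    }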
@@ -58,7 +61,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
+        self.tolerance_s = tolerance_s
         self.video_backend = video_backend if video_backend is not None else "pyav"
+        self.delta_indices = None

         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
@@ -68,34 +73,60 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tasks = load_tasks(repo_id, self._version, self.root)

         # Load actual data
-        download_episodes(
-            repo_id,
-            self._version,
-            self.root,
-            self.data_path,
-            self.video_keys,
-            self.num_episodes,
-            self.episodes,
-            self.videos_path,
-        )
+        self.download_episodes()
         self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
-        self.episode_data_index = self.get_episode_data_index()
+        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)

+        # Check timestamps
+        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+
+        # Setup delta_indices
+        if self.delta_timestamps is not None:
+            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
+            self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

         # TODO(aliberts):
-        # - [ ] Update __get_item__
+        # - [X] Move delta_timestamp logic outside __getitem__
+        # - [X] Update __getitem__
         # - [ ] Add self.add_frame()
         # - [ ] Add self.consolidate() for:
+        # - [X] Check timestamps sync
         # - [ ] Sanity checks (episodes num, shapes, files, etc.)
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)

-        # TODO(aliberts): remove (deprecated)
-        # if split == "train":
-        #     self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
-        # else:
-        #     self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
-        #     self.hf_dataset = reset_episode_index(self.hf_dataset)
-        # if self.video:
-        #     self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+    def download_episodes(self) -> None:
+        """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
+        will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
+        dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
+        in 'local_dir', they won't be downloaded again.
+
+        Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
+        snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
+        """
+        # TODO(rcadene, aliberts): implement faster transfer
+        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
+        files = None
+        if self.episodes is not None:
+            files = [
+                self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
+                for ep_idx in self.episodes
+            ]
+            if len(self.video_keys) > 0:
+                video_files = [
+                    self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
+                    for vid_key in self.video_keys
+                    for ep_idx in self.episodes
+                ]
+                files += video_files
+
+        snapshot_download(
+            self.repo_id,
+            repo_type="dataset",
+            revision=self._version,
+            local_dir=self.root,
+            allow_patterns=files,
+        )
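The selective download above is plain `huggingface_hub.snapshot_download` filtering with `allow_patterns`. A standalone sketch of the same call, with a hypothetical repo id and file names (not taken from this commit):

    from huggingface_hub import snapshot_download

    # Only the listed files are fetched; allow_patterns=None pulls the whole repo.
    # Repo id and file names below are placeholders.
    snapshot_download(
        "lerobot/some_dataset",
        repo_type="dataset",
        revision="v2.0",
        local_dir="data/lerobot/some_dataset",
        allow_patterns=[
            "data/train-episode_000000.parquet",
            "videos/observation.image_episode_000000.mp4",
        ],
    )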

     @property
     def data_path(self) -> str:
@@ -134,17 +165,20 @@ class LeRobotDataset(torch.utils.data.Dataset):

     @property
     def camera_keys(self) -> list[str]:
-        """Keys to access image and video streams from cameras."""
+        """Keys to access image and video streams from cameras (regardless of their storage method)."""
         return self.image_keys + self.video_keys

     @property
     def video_frame_keys(self) -> list[str]:
-        """Keys to access video frames that requires to be decoded into images.
+        """
+        DEPRECATED, USE 'video_keys' INSTEAD
+        Keys to access video frames that need to be decoded into images.

         Note: It is empty if the dataset contains images only,
         or equal to `self.cameras` if the dataset contains videos only,
         or can even be a subset of `self.cameras` in the case of a mixed image/video dataset.
         """
+        # TODO(aliberts): remove
         video_frame_keys = []
         for key, feats in self.hf_dataset.features.items():
             if isinstance(feats, VideoFrame):
@@ -166,54 +200,97 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Total number of episodes available."""
         return self.info["total_episodes"]

-    @property
-    def tolerance_s(self) -> float:
-        """Tolerance in seconds used to discard loaded frames when their timestamps
-        are not close enough from the requested frames. It is only used when `delta_timestamps`
-        is provided or when loading video frames from mp4 files.
-        """
-        # 1e-4 to account for possible numerical error
-        return 1 / self.fps - 1e-4
+    # @property
+    # def tolerance_s(self) -> float:
+    #     """Tolerance in seconds used to discard loaded frames when their timestamps
+    #     are not close enough to the requested frames. It is used at the init of the dataset to make sure
+    #     that each timestamp is separated from the next by 1/fps +/- tolerance. It is only used when
+    #     `delta_timestamps` is provided or when loading video frames from mp4 files.
+    #     """
+    #     # 1e-4 to account for possible numerical error
+    #     return 1e-4

     @property
     def shapes(self) -> dict:
         """Shapes for the different features."""
         return self.info.get("shapes")

-    def get_episode_data_index(self) -> dict[str, torch.Tensor]:
-        episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
+    def current_episode_index(self, idx: int) -> int:
+        episode_index = self.hf_dataset["episode_index"][idx]
         if self.episodes is not None:
-            episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
+            # get episode_index from selected episodes
+            episode_index = self.episodes.index(episode_index)

-        cumulative_lenghts = list(accumulate(episode_lengths.values()))
+        return episode_index

+    def episode_length(self, episode_index) -> int:
+        """Number of samples/frames for given episode."""
+        return self.info["episodes"][episode_index]["length"]
+
+    def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
+        # Pad values outside of current episode range
+        ep_start = self.episode_data_index["from"][ep_idx]
+        ep_end = self.episode_data_index["to"][ep_idx]
         return {
-            "from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
-            "to": torch.LongTensor(cumulative_lenghts),
+            key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
+            for key, delta_idx in self.delta_indices.items()
         }
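To make the clamping above concrete, a small worked example with made-up numbers: for an episode covering indices 40 to 60 ("from"=40, "to"=60) and delta indices [-2, -1, 0, 1], a query at idx=40 pads the out-of-range offsets with the episode's first frame:

    # max(ep_start, min(ep_end - 1, idx + delta)) for each delta
    ep_start, ep_end, idx = 40, 60, 40
    deltas = [-2, -1, 0, 1]
    print([max(ep_start, min(ep_end - 1, idx + d)) for d in deltas])  # [40, 40, 40, 41]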

+    def _get_query_timestamps(
+        self, query_indices: dict[str, list[int]], current_ts: float
+    ) -> dict[str, list[float]]:
+        query_timestamps = {}
+        for key in self.video_keys:
+            if key in query_indices:
+                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
+                query_timestamps[key] = torch.stack(timestamps).tolist()
+            else:
+                query_timestamps[key] = [current_ts]
+
+        return query_timestamps
+
+    def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
+        return {
+            key: torch.stack(self.hf_dataset.select(q_idx)[key])
+            for key, q_idx in query_indices.items()
+            if key not in self.video_keys
+        }
+
+    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
+        """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
+        in the main process (e.g. by using a second DataLoader with num_workers=0). It will result in a
+        Segmentation Fault. This probably happens because a memory reference to the video loader is created in
+        the main process and a subprocess fails to access it.
+        """
+        item = {}
+        for vid_key, query_ts in query_timestamps.items():
+            video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
+            frames = decode_video_frames_torchvision(
+                video_path, query_ts, self.tolerance_s, self.video_backend
+            )
+            item[vid_key] = frames
+
+        return item

     def __len__(self):
         return self.num_samples

-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> dict:
         item = self.hf_dataset[idx]
+        ep_idx = item["episode_index"].item()

-        if self.delta_timestamps is not None:
-            item = load_previous_and_future_frames(
-                item,
-                self.hf_dataset,
-                self.episode_data_index,
-                self.delta_timestamps,
-                self.tolerance_s,
-            )
+        if self.delta_indices is not None:
+            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
+            query_indices = self._get_query_indices(idx, current_ep_idx)
+            query_result = self._query_hf_dataset(query_indices)
+            for key, val in query_result.items():
+                item[key] = val

-        if self.video:
-            item = load_from_videos(
-                item,
-                self.video_keys,
-                self.videos_dir,
-                self.tolerance_s,
-                self.video_backend,
-            )
+        if len(self.video_keys) > 0:
+            current_ts = item["timestamp"].item()
+            query_timestamps = self._get_query_timestamps(query_indices, current_ts)
+            video_frames = self._query_videos(query_timestamps, ep_idx)
+            item = {**video_frames, **item}

         if self.image_transforms is not None:
             for cam in self.camera_keys:
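Putting the pieces together, a hypothetical end-to-end use of the reworked `__getitem__` (repo id, feature keys, import path and constructor arguments are illustrative assumptions, not taken from this commit):

    import torch
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # assumed import path

    # Hypothetical dataset and keys; offsets must be multiples of 1/fps.
    dataset = LeRobotDataset(
        "lerobot/some_dataset",
        episodes=[0, 1],
        delta_timestamps={"action": [0.0, 1 / 30, 2 / 30]},
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4, shuffle=True)
    batch = next(iter(loader))
    # batch["action"] stacks, for every sampled frame, the current and next two actions,
    # padded at episode boundaries by _get_query_indices.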
@@ -16,13 +16,15 @@
 import json
 import warnings
 from functools import cache
+from itertools import accumulate
 from pathlib import Path
+from pprint import pformat
 from typing import Dict

 import datasets
 import torch
 from datasets import load_dataset
-from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
+from huggingface_hub import DatasetCard, HfApi, hf_hub_download
 from PIL import Image as PILImage
 from torchvision import transforms
@@ -193,40 +195,102 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
         return json.load(f)


-def download_episodes(
-    repo_id: str,
-    version: str,
-    local_dir: Path,
-    data_path: str,
-    video_keys: list,
-    total_episodes: int,
-    episodes: list[int] | None = None,
-    videos_path: str | None = None,
-) -> None:
-    """Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
-    will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
-    dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
-    in 'local_dir', they won't be downloaded again.
-
-    Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
-    snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
-    """
-    # TODO(rcadene, aliberts): implement faster transfer
-    # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
-    files = None
+def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
+    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
     if episodes is not None:
-        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
-        if len(video_keys) > 0:
-            video_files = [
-                videos_path.format(video_key=vid_key, episode_index=ep_idx)
-                for vid_key in video_keys
-                for ep_idx in episodes
-            ]
-            files += video_files
+        episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

-    snapshot_download(
-        repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
-    )
+    cumulative_lenghts = list(accumulate(episode_lengths.values()))
+    return {
+        "from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
+        "to": torch.LongTensor(cumulative_lenghts),
+    }
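As a quick illustration of the from/to computation above, with made-up episode lengths of 100, 50 and 75 frames:

    from itertools import accumulate
    import torch

    episode_lengths = {0: 100, 1: 50, 2: 75}                 # hypothetical lengths
    cumulative = list(accumulate(episode_lengths.values()))  # [100, 150, 225]
    episode_data_index = {
        "from": torch.LongTensor([0] + cumulative[:-1]),     # tensor([  0, 100, 150])
        "to": torch.LongTensor(cumulative),                  # tensor([100, 150, 225])
    }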


+def check_timestamps_sync(
+    hf_dataset: datasets.Dataset,
+    episode_data_index: dict[str, torch.Tensor],
+    fps: int,
+    tolerance_s: float,
+    raise_value_error: bool = True,
+) -> bool:
+    """
+    This check makes sure that each timestamp is separated from the next by 1/fps +/- tolerance to
+    account for possible numerical error.
+    """
+    timestamps = torch.stack(hf_dataset["timestamp"])
+    # timestamps[2] += tolerance_s  # TODO delete
+    # timestamps[-2] += tolerance_s/2  # TODO delete
+    diffs = torch.diff(timestamps)
+    within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
+
+    # We mask differences between the timestamp at the end of an episode
+    # and the one at the start of the next episode since these are expected
+    # to be outside tolerance.
+    mask = torch.ones(len(diffs), dtype=torch.bool)
+    ignored_diffs = episode_data_index["to"][:-1] - 1
+    mask[ignored_diffs] = False
+    filtered_within_tolerance = within_tolerance[mask]
+
+    if not torch.all(filtered_within_tolerance):
+        # Track original indices before masking
+        original_indices = torch.arange(len(diffs))
+        filtered_indices = original_indices[mask]
+        outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
+        outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
+        episode_indices = torch.stack(hf_dataset["episode_index"])
+
+        outside_tolerances = []
+        for idx in outside_tolerance_indices:
+            entry = {
+                "timestamps": [timestamps[idx], timestamps[idx + 1]],
+                "diff": diffs[idx],
+                "episode_index": episode_indices[idx].item(),
+            }
+            outside_tolerances.append(entry)
+
+        if raise_value_error:
+            raise ValueError(
+                f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
+                This might be due to synchronization issues with timestamps during data collection.
+                \n{pformat(outside_tolerances)}"""
+            )
+        return False
+
+    return True
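A small numeric sketch of what this check catches (made-up values): at fps=10 with tolerance_s=1e-4, a dropped frame inside an episode fails the check, while the jump between two episodes is masked out:

    import torch

    fps, tolerance_s = 10, 1e-4
    # Two episodes: [0.0, 0.1, 0.3] (a frame was dropped) and [0.0, 0.1].
    timestamps = torch.tensor([0.0, 0.1, 0.3, 0.0, 0.1])
    episode_data_index = {"to": torch.tensor([3, 5])}

    diffs = torch.diff(timestamps)                       # [0.1, 0.2, -0.3, 0.1]
    within = torch.abs(diffs - 1 / fps) <= tolerance_s   # [True, False, False, True]
    mask = torch.ones(len(diffs), dtype=torch.bool)
    mask[episode_data_index["to"][:-1] - 1] = False      # ignore the cross-episode diff (index 2)
    print(within[mask])                                  # tensor([ True, False,  True]) -> out of sync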


+def check_delta_timestamps(
+    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
+) -> bool:
+    outside_tolerance = {}
+    for key, delta_ts in delta_timestamps.items():
+        abs_delta_ts = torch.abs(torch.tensor(delta_ts))
+        within_tolerance = (abs_delta_ts % (1 / fps)) <= tolerance_s
+        if not torch.all(within_tolerance):
+            outside_tolerance[key] = torch.tensor(delta_ts)[~within_tolerance]
+
+    if len(outside_tolerance) > 0:
+        if raise_value_error:
+            raise ValueError(
+                f"""
+                The following delta_timestamps are found outside of the tolerance range.
+                Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
+                their values accordingly.
+                \n{pformat(outside_tolerance)}
+                """
+            )
+        return False
+
+    return True
+
+
+def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
+    delta_indices = {}
+    for key, delta_ts in delta_timestamps.items():
+        delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
+
+    return delta_indices
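For instance, at a hypothetical fps of 30 the conversion above turns offsets in seconds into frame offsets:

    delta_timestamps = {"observation.image": [-1 / 30, 0.0], "action": [0.0, 1 / 30, 2 / 30]}
    delta_indices = get_delta_indices(delta_timestamps, fps=30)
    # {'observation.image': [-1, 0], 'action': [0, 1, 2]}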


 def load_previous_and_future_frames(
@@ -27,45 +27,8 @@ import torchvision
 from datasets.features.features import register_feature


-def load_from_videos(
-    item: dict[str, torch.Tensor],
-    video_frame_keys: list[str],
-    videos_dir: Path,
-    tolerance_s: float,
-    backend: str = "pyav",
-):
-    """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
-    in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
-    This probably happens because a memory reference to the video loader is created in the main process and a
-    subprocess fails to access it.
-    """
-    # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
-    data_dir = videos_dir.parent
-
-    for key in video_frame_keys:
-        if isinstance(item[key], list):
-            # load multiple frames at once (expected when delta_timestamps is not None)
-            timestamps = [frame["timestamp"] for frame in item[key]]
-            paths = [frame["path"] for frame in item[key]]
-            if len(set(paths)) > 1:
-                raise NotImplementedError("All video paths are expected to be the same for now.")
-            video_path = data_dir / paths[0]
-
-            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
-            item[key] = frames
-        else:
-            # load one frame
-            timestamps = [item[key]["timestamp"]]
-            video_path = data_dir / item[key]["path"]
-
-            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
-            item[key] = frames[0]
-
-    return item
-
-
 def decode_video_frames_torchvision(
-    video_path: str,
+    video_path: Path | str,
     timestamps: list[float],
     tolerance_s: float,
     backend: str = "pyav",