Update LeRobotDataset.__get_item__
This commit is contained in:
parent
3113038beb
commit
b417cebc4e
|
@ -15,25 +15,27 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from itertools import accumulate
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.utils
|
import torch.utils
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
download_episodes,
|
check_delta_timestamps,
|
||||||
|
check_timestamps_sync,
|
||||||
|
get_delta_indices,
|
||||||
|
get_episode_data_index,
|
||||||
get_hub_safe_version,
|
get_hub_safe_version,
|
||||||
load_hf_dataset,
|
load_hf_dataset,
|
||||||
load_info,
|
load_info,
|
||||||
load_previous_and_future_frames,
|
|
||||||
load_stats,
|
load_stats,
|
||||||
load_tasks,
|
load_tasks,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
|
||||||
|
|
||||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||||
CODEBASE_VERSION = "v2.0"
|
CODEBASE_VERSION = "v2.0"
|
||||||
|
@ -49,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
delta_timestamps: dict[list[float]] | None = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
|
tolerance_s: float = 1e-4,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -58,7 +61,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
|
self.tolerance_s = tolerance_s
|
||||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||||
|
self.delta_indices = None
|
||||||
|
|
||||||
# Load metadata
|
# Load metadata
|
||||||
self.root.mkdir(exist_ok=True, parents=True)
|
self.root.mkdir(exist_ok=True, parents=True)
|
||||||
|
@ -68,34 +73,60 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.tasks = load_tasks(repo_id, self._version, self.root)
|
self.tasks = load_tasks(repo_id, self._version, self.root)
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
download_episodes(
|
self.download_episodes()
|
||||||
repo_id,
|
|
||||||
self._version,
|
|
||||||
self.root,
|
|
||||||
self.data_path,
|
|
||||||
self.video_keys,
|
|
||||||
self.num_episodes,
|
|
||||||
self.episodes,
|
|
||||||
self.videos_path,
|
|
||||||
)
|
|
||||||
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
|
||||||
self.episode_data_index = self.get_episode_data_index()
|
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||||
|
|
||||||
|
# Check timestamps
|
||||||
|
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||||
|
|
||||||
|
# Setup delta_indices
|
||||||
|
if self.delta_timestamps is not None:
|
||||||
|
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||||
|
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||||
|
|
||||||
# TODO(aliberts):
|
# TODO(aliberts):
|
||||||
# - [ ] Update __get_item__
|
# - [X] Move delta_timestamp logic outside __get_item__
|
||||||
|
# - [X] Update __get_item__
|
||||||
|
# - [ ] Add self.add_frame()
|
||||||
# - [ ] Add self.consolidate() for:
|
# - [ ] Add self.consolidate() for:
|
||||||
|
# - [X] Check timestamps sync
|
||||||
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
|
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
|
||||||
# - [ ] Update episode_index (arg update=True)
|
# - [ ] Update episode_index (arg update=True)
|
||||||
# - [ ] Update info.json (arg update=True)
|
# - [ ] Update info.json (arg update=True)
|
||||||
|
|
||||||
# TODO(aliberts): remove (deprecated)
|
def download_episodes(self) -> None:
|
||||||
# if split == "train":
|
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||||
# self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
|
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||||
# else:
|
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||||
# self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
in 'local_dir', they won't be downloaded again.
|
||||||
# self.hf_dataset = reset_episode_index(self.hf_dataset)
|
|
||||||
# if self.video:
|
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
|
||||||
# self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
|
||||||
|
"""
|
||||||
|
# TODO(rcadene, aliberts): implement faster transfer
|
||||||
|
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||||
|
files = None
|
||||||
|
if self.episodes is not None:
|
||||||
|
files = [
|
||||||
|
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
|
||||||
|
for ep_idx in self.episodes
|
||||||
|
]
|
||||||
|
if len(self.video_keys) > 0:
|
||||||
|
video_files = [
|
||||||
|
self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
||||||
|
for vid_key in self.video_keys
|
||||||
|
for ep_idx in self.episodes
|
||||||
|
]
|
||||||
|
files += video_files
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
self.repo_id,
|
||||||
|
repo_type="dataset",
|
||||||
|
revision=self._version,
|
||||||
|
local_dir=self.root,
|
||||||
|
allow_patterns=files,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_path(self) -> str:
|
def data_path(self) -> str:
|
||||||
|
@ -134,17 +165,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_keys(self) -> list[str]:
|
def camera_keys(self) -> list[str]:
|
||||||
"""Keys to access image and video streams from cameras."""
|
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
|
||||||
return self.image_keys + self.video_keys
|
return self.image_keys + self.video_keys
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video_frame_keys(self) -> list[str]:
|
def video_frame_keys(self) -> list[str]:
|
||||||
"""Keys to access video frames that requires to be decoded into images.
|
"""
|
||||||
|
DEPRECATED, USE 'video_keys' INSTEAD
|
||||||
|
Keys to access video frames that requires to be decoded into images.
|
||||||
|
|
||||||
Note: It is empty if the dataset contains images only,
|
Note: It is empty if the dataset contains images only,
|
||||||
or equal to `self.cameras` if the dataset contains videos only,
|
or equal to `self.cameras` if the dataset contains videos only,
|
||||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||||
"""
|
"""
|
||||||
|
# TODO(aliberts): remove
|
||||||
video_frame_keys = []
|
video_frame_keys = []
|
||||||
for key, feats in self.hf_dataset.features.items():
|
for key, feats in self.hf_dataset.features.items():
|
||||||
if isinstance(feats, VideoFrame):
|
if isinstance(feats, VideoFrame):
|
||||||
|
@ -166,54 +200,97 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
"""Total number of episodes available."""
|
"""Total number of episodes available."""
|
||||||
return self.info["total_episodes"]
|
return self.info["total_episodes"]
|
||||||
|
|
||||||
@property
|
# @property
|
||||||
def tolerance_s(self) -> float:
|
# def tolerance_s(self) -> float:
|
||||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
# """Tolerance in seconds used to discard loaded frames when their timestamps
|
||||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
# are not close enough from the requested frames. It is used at the init of the dataset to make sure
|
||||||
is provided or when loading video frames from mp4 files.
|
# that each timestamps is separated to the next by 1/fps +/- tolerance. It is only used when
|
||||||
"""
|
# `delta_timestamps` is provided or when loading video frames from mp4 files.
|
||||||
# 1e-4 to account for possible numerical error
|
# """
|
||||||
return 1 / self.fps - 1e-4
|
# # 1e-4 to account for possible numerical error
|
||||||
|
# return 1e-4
|
||||||
|
|
||||||
@property
def shapes(self) -> dict:
    """Shapes for the different features.

    Returns:
        The value stored under the "shapes" key of the dataset info, or None when
        that key is absent.
    """
    # The lookup result must be returned; evaluating it alone always yielded None.
    return self.info.get("shapes")
|
||||||
|
|
||||||
def get_episode_data_index(self) -> dict[str, torch.Tensor]:
|
def current_episode_index(self, idx: int) -> int:
|
||||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
|
episode_index = self.hf_dataset["episode_index"][idx]
|
||||||
if self.episodes is not None:
|
if self.episodes is not None:
|
||||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
|
# get episode_index from selected episodes
|
||||||
|
episode_index = self.episodes.index(episode_index)
|
||||||
|
|
||||||
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
return episode_index
|
||||||
|
|
||||||
|
def episode_length(self, episode_index) -> int:
    """Return the number of samples/frames recorded for the given episode.

    The value is read from the "length" field of the matching entry in
    `self.info["episodes"]`.
    """
    episode_info = self.info["episodes"][episode_index]
    return episode_info["length"]
|
||||||
|
|
||||||
|
def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
|
||||||
|
# Pad values outside of current episode range
|
||||||
|
ep_start = self.episode_data_index["from"][ep_idx]
|
||||||
|
ep_end = self.episode_data_index["to"][ep_idx]
|
||||||
return {
|
return {
|
||||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||||
"to": torch.LongTensor(cumulative_lenghts),
|
for key, delta_idx in self.delta_indices.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_query_timestamps(
|
||||||
|
self, query_indices: dict[str, list[int]], current_ts: float
|
||||||
|
) -> dict[str, list[float]]:
|
||||||
|
query_timestamps = {}
|
||||||
|
for key in self.video_keys:
|
||||||
|
if key in query_indices:
|
||||||
|
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||||
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||||
|
else:
|
||||||
|
query_timestamps[key] = [current_ts]
|
||||||
|
|
||||||
|
return query_timestamps
|
||||||
|
|
||||||
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
|
return {
|
||||||
|
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||||
|
for key, q_idx in query_indices.items()
|
||||||
|
if key not in self.video_keys
|
||||||
|
}
|
||||||
|
|
||||||
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||||
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||||
|
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||||
|
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||||
|
the main process and a subprocess fails to access it.
|
||||||
|
"""
|
||||||
|
item = {}
|
||||||
|
for vid_key, query_ts in query_timestamps.items():
|
||||||
|
video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
||||||
|
frames = decode_video_frames_torchvision(
|
||||||
|
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||||
|
)
|
||||||
|
item[vid_key] = frames
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples
|
return self.num_samples
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx) -> dict:
|
||||||
item = self.hf_dataset[idx]
|
item = self.hf_dataset[idx]
|
||||||
|
ep_idx = item["episode_index"].item()
|
||||||
|
|
||||||
if self.delta_timestamps is not None:
|
if self.delta_indices is not None:
|
||||||
item = load_previous_and_future_frames(
|
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||||
item,
|
query_indices = self._get_query_indices(idx, current_ep_idx)
|
||||||
self.hf_dataset,
|
query_result = self._query_hf_dataset(query_indices)
|
||||||
self.episode_data_index,
|
for key, val in query_result.items():
|
||||||
self.delta_timestamps,
|
item[key] = val
|
||||||
self.tolerance_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.video:
|
if len(self.video_keys) > 0:
|
||||||
item = load_from_videos(
|
current_ts = item["timestamp"].item()
|
||||||
item,
|
query_timestamps = self._get_query_timestamps(query_indices, current_ts)
|
||||||
self.video_keys,
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||||
self.videos_dir,
|
item = {**video_frames, **item}
|
||||||
self.tolerance_s,
|
|
||||||
self.video_backend,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.image_transforms is not None:
|
if self.image_transforms is not None:
|
||||||
for cam in self.camera_keys:
|
for cam in self.camera_keys:
|
||||||
|
|
|
@ -16,13 +16,15 @@
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
from itertools import accumulate
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from pprint import pformat
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
from huggingface_hub import DatasetCard, HfApi, hf_hub_download
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
@ -193,40 +195,102 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def download_episodes(
|
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||||
repo_id: str,
|
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||||
version: str,
|
|
||||||
local_dir: Path,
|
|
||||||
data_path: str,
|
|
||||||
video_keys: list,
|
|
||||||
total_episodes: int,
|
|
||||||
episodes: list[int] | None = None,
|
|
||||||
videos_path: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
|
|
||||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
|
||||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
|
||||||
in 'local_dir', they won't be downloaded again.
|
|
||||||
|
|
||||||
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
|
|
||||||
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
|
|
||||||
"""
|
|
||||||
# TODO(rcadene, aliberts): implement faster transfer
|
|
||||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
|
||||||
files = None
|
|
||||||
if episodes is not None:
|
if episodes is not None:
|
||||||
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
|
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||||
if len(video_keys) > 0:
|
|
||||||
video_files = [
|
|
||||||
videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
|
||||||
for vid_key in video_keys
|
|
||||||
for ep_idx in episodes
|
|
||||||
]
|
|
||||||
files += video_files
|
|
||||||
|
|
||||||
snapshot_download(
|
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
||||||
repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
|
return {
|
||||||
|
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||||
|
"to": torch.LongTensor(cumulative_lenghts),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def check_timestamps_sync(
    hf_dataset: "datasets.Dataset",
    episode_data_index: dict[str, torch.Tensor],
    fps: int,
    tolerance_s: float,
    raise_value_error: bool = True,
) -> bool:
    """Check that consecutive timestamps inside each episode are spaced by 1/fps.

    Each timestamp must be separated from the next one by 1/fps +/- `tolerance_s`,
    the tolerance accounting for possible numerical error.

    Args:
        hf_dataset: Dataset whose "timestamp" and "episode_index" columns are lists of
            scalar tensors.
        episode_data_index: Dict with a "to" tensor holding the exclusive end index of
            each episode.
        fps: Expected number of frames per second.
        tolerance_s: Allowed deviation, in seconds, from 1/fps.
        raise_value_error: If True, raise instead of returning False on violation.

    Returns:
        True if all in-episode differences are within tolerance, False otherwise
        (only when `raise_value_error` is False).

    Raises:
        ValueError: If a violation is found and `raise_value_error` is True.
    """
    timestamps = torch.stack(hf_dataset["timestamp"])
    diffs = torch.diff(timestamps)
    within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s

    # We mask differences between the timestamp at the end of an episode
    # and the one at the start of the next episode since these are expected
    # to be outside tolerance.
    mask = torch.ones(len(diffs), dtype=torch.bool)
    ignored_diffs = episode_data_index["to"][:-1] - 1
    mask[ignored_diffs] = False

    # flatten() keeps a 1-D tensor even when a single diff is out of tolerance;
    # squeeze() would produce a 0-d tensor that cannot be iterated over.
    outside_tolerance_indices = torch.nonzero(mask & ~within_tolerance).flatten()
    if len(outside_tolerance_indices) == 0:
        return True

    if raise_value_error:
        episode_indices = torch.stack(hf_dataset["episode_index"])
        outside_tolerances = [
            {
                "timestamps": [timestamps[idx].item(), timestamps[idx + 1].item()],
                "diff": diffs[idx].item(),
                "episode_index": episode_indices[idx].item(),
            }
            for idx in outside_tolerance_indices
        ]
        raise ValueError(
            "One or several timestamps unexpectedly violate the tolerance inside episode range.\n"
            "This might be due to synchronization issues with timestamps during data collection.\n"
            f"\n{pformat(outside_tolerances)}"
        )

    return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_delta_timestamps(
    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
    """Check that every delta timestamp is a multiple of 1/fps, within `tolerance_s`.

    Args:
        delta_timestamps: Time offsets in seconds, keyed by feature name.
        fps: Number of frames per second.
        tolerance_s: Allowed deviation, in seconds, from the nearest multiple of 1/fps.
        raise_value_error: If True, raise instead of returning False on violation.

    Returns:
        True if all values are within tolerance, False otherwise (only when
        `raise_value_error` is False).

    Raises:
        ValueError: If a value is outside tolerance and `raise_value_error` is True.
    """
    outside_tolerance = {}
    for key, delta_ts in delta_timestamps.items():
        # Measure the distance to the NEAREST multiple of 1/fps. A plain remainder is
        # one-sided: a value sitting just below a multiple yields a remainder close to
        # 1/fps and would be wrongly rejected.
        offending = [ts for ts in delta_ts if abs(ts * fps - round(ts * fps)) / fps > tolerance_s]
        if offending:
            outside_tolerance[key] = offending

    if not outside_tolerance:
        return True

    if raise_value_error:
        raise ValueError(
            f"""
The following delta_timestamps are found outside of tolerance range.
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
their values accordingly.
\n{pformat(outside_tolerance)}
"""
        )

    return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    """Convert delta timestamps (seconds) into frame-index offsets.

    Args:
        delta_timestamps: Time offsets in seconds, keyed by feature name.
        fps: Number of frames per second.

    Returns:
        A dict with the same keys, each value being the list of integer frame offsets.
    """
    # Round to the nearest frame: truncating would turn a product such as 2.9999
    # (floating-point error on 3.0) into 2.
    return {key: [round(ts * fps) for ts in delta_ts] for key, delta_ts in delta_timestamps.items()}
|
||||||
|
|
||||||
|
|
||||||
def load_previous_and_future_frames(
|
def load_previous_and_future_frames(
|
||||||
|
|
|
@ -27,45 +27,8 @@ import torchvision
|
||||||
from datasets.features.features import register_feature
|
from datasets.features.features import register_feature
|
||||||
|
|
||||||
|
|
||||||
def load_from_videos(
|
|
||||||
item: dict[str, torch.Tensor],
|
|
||||||
video_frame_keys: list[str],
|
|
||||||
videos_dir: Path,
|
|
||||||
tolerance_s: float,
|
|
||||||
backend: str = "pyav",
|
|
||||||
):
|
|
||||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
|
||||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
|
|
||||||
This probably happens because a memory reference to the video loader is created in the main process and a
|
|
||||||
subprocess fails to access it.
|
|
||||||
"""
|
|
||||||
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
|
|
||||||
data_dir = videos_dir.parent
|
|
||||||
|
|
||||||
for key in video_frame_keys:
|
|
||||||
if isinstance(item[key], list):
|
|
||||||
# load multiple frames at once (expected when delta_timestamps is not None)
|
|
||||||
timestamps = [frame["timestamp"] for frame in item[key]]
|
|
||||||
paths = [frame["path"] for frame in item[key]]
|
|
||||||
if len(set(paths)) > 1:
|
|
||||||
raise NotImplementedError("All video paths are expected to be the same for now.")
|
|
||||||
video_path = data_dir / paths[0]
|
|
||||||
|
|
||||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
|
||||||
item[key] = frames
|
|
||||||
else:
|
|
||||||
# load one frame
|
|
||||||
timestamps = [item[key]["timestamp"]]
|
|
||||||
video_path = data_dir / item[key]["path"]
|
|
||||||
|
|
||||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
|
||||||
item[key] = frames[0]
|
|
||||||
|
|
||||||
return item
|
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames_torchvision(
|
def decode_video_frames_torchvision(
|
||||||
video_path: str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
backend: str = "pyav",
|
backend: str = "pyav",
|
||||||
|
|
Loading…
Reference in New Issue