Add streaming

This commit is contained in:
Remi Cadene 2025-02-17 17:23:07 +01:00
parent fe483b1d0d
commit c1b4dae6d0
2 changed files with 45 additions and 10 deletions

View File

@ -50,6 +50,7 @@ from lerobot.common.datasets.utils import (
get_hf_features_from_features, get_hf_features_from_features,
get_hub_safe_version, get_hub_safe_version,
hf_transform_to_torch, hf_transform_to_torch,
item_to_torch,
load_episodes, load_episodes,
load_info, load_info,
load_stats, load_stats,
@ -214,6 +215,9 @@ class LeRobotDatasetMetadata:
task_index = self.task_to_task_index.get(task, None) task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks return task_index if task_index is not None else self.total_tasks
def html_root(self) -> str:
return f"https://huggingface.co/datasets/{self.repo_id}/resolve/main"
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None: def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
self.info["total_episodes"] += 1 self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length self.info["total_frames"] += episode_length
@ -334,6 +338,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos: bool = True, download_videos: bool = True,
local_files_only: bool = False, local_files_only: bool = False,
video_backend: str | None = None, video_backend: str | None = None,
streaming: bool = False,
): ):
""" """
2 modes are available for instantiating this class, depending on 2 different use cases: 2 modes are available for instantiating this class, depending on 2 different use cases:
@ -431,6 +436,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
will be made. Defaults to False. will be made. Defaults to False.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav. a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
streaming (bool, optional): If set to True, don't download the data files. Instead, it streams the data
progressively while iterating on the dataset. Default to False.
""" """
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
@ -440,10 +447,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.video_backend = video_backend if video_backend else "pyav" self.video_backend = video_backend if video_backend else "pyav"
self.delta_indices = None
self.local_files_only = local_files_only self.local_files_only = local_files_only
self.streaming = streaming
# Unused attributes # Unused attributes
self.delta_indices = None
self.image_writer = None self.image_writer = None
self.episode_buffer = None self.episode_buffer = None
@ -456,16 +464,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
# Load actual data # Load actual data
self.download_episodes(download_videos) if not self.streaming:
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
if self.streaming:
self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps # Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) if not self.streaming:
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
# Setup delta_indices # Setup delta_indices
if self.delta_timestamps is not None: if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) if not self.streaming:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Available stats implies all videos have been encoded and dataset is iterable # Available stats implies all videos have been encoded and dataset is iterable
@ -550,13 +563,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""hf_dataset contains all the observations, states, actions, rewards, etc.""" """hf_dataset contains all the observations, states, actions, rewards, etc."""
if self.episodes is None: if self.episodes is None:
path = str(self.root / "data") path = str(self.root / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split="train") hf_dataset = load_dataset("parquet", data_dir=path, split="train", streaming=self.streaming)
else: else:
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
hf_dataset = load_dataset("parquet", data_files=files, split="train") hf_dataset = load_dataset("parquet", data_files=files, split="train", streaming=self.streaming)
# TODO(aliberts): hf_dataset.set_format("torch") if not self.streaming:
hf_dataset.set_transform(hf_transform_to_torch) # TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset
@ -632,7 +646,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
""" """
item = {} item = {}
for vid_key, query_ts in query_timestamps.items(): for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) root = self.meta.html_root if self.streaming else self.root
video_path = Path(root) / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchvision( frames = decode_video_frames_torchvision(
video_path, query_ts, self.tolerance_s, self.video_backend video_path, query_ts, self.tolerance_s, self.video_backend
) )
@ -649,7 +664,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.num_frames return self.num_frames
def __getitem__(self, idx) -> dict: def __getitem__(self, idx) -> dict:
item = self.hf_dataset[idx] if self.streaming:
try:
item = next(self.hf_dataset_iter)
except StopIteration:
self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
item = next(self.hf_dataset_iter)
item = item_to_torch(item)
else:
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item() ep_idx = item["episode_index"].item()
query_indices = None query_indices = None

View File

@ -205,6 +205,18 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict return items_dict
def item_to_torch(item: dict):
for key, value in item.items():
if isinstance(value, PILImage.Image):
to_tensor = transforms.ToTensor()
item[key] = to_tensor(value)
elif value is None or isinstance(value, str):
pass
else:
item[key] = torch.tensor(value)
return item
def _get_major_minor(version: str) -> tuple[int]: def _get_major_minor(version: str) -> tuple[int]:
split = version.strip("v").split(".") split = version.strip("v").split(".")
return int(split[0]), int(split[1]) return int(split[0]), int(split[1])