Add streaming
parent fe483b1d0d
commit c1b4dae6d0
lerobot/common/datasets/lerobot_dataset.py
@@ -50,6 +50,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     get_hub_safe_version,
     hf_transform_to_torch,
+    item_to_torch,
     load_episodes,
     load_info,
     load_stats,
@@ -214,6 +215,10 @@ class LeRobotDatasetMetadata:
         task_index = self.task_to_task_index.get(task, None)
         return task_index if task_index is not None else self.total_tasks

+    @property
+    def html_root(self) -> str:
+        return f"https://huggingface.co/datasets/{self.repo_id}/resolve/main"
+
     def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length
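`html_root` is read later as a plain attribute (`self.meta.html_root`), hence the `@property`. A small illustration of the URL it builds, using a hypothetical `repo_id`:

```python
# Illustration only: the Hub "resolve" URL built by html_root.
# "lerobot/pusht" is a hypothetical repo_id.
repo_id = "lerobot/pusht"
html_root = f"https://huggingface.co/datasets/{repo_id}/resolve/main"
# -> "https://huggingface.co/datasets/lerobot/pusht/resolve/main"
# Joining a relative file path onto this root yields a directly streamable URL.
```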
@@ -334,6 +338,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         download_videos: bool = True,
         local_files_only: bool = False,
         video_backend: str | None = None,
+        streaming: bool = False,
     ):
         """
         2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -431,6 +436,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 will be made. Defaults to False.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+            streaming (bool, optional): If set to True, the data files are not downloaded; instead, the data
+                is streamed progressively while iterating over the dataset. Defaults to False.
         """
         super().__init__()
         self.repo_id = repo_id
@@ -440,10 +447,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.video_backend = video_backend if video_backend else "pyav"
-        self.delta_indices = None
         self.local_files_only = local_files_only
+        self.streaming = streaming

         # Unused attributes
+        self.delta_indices = None
         self.image_writer = None
         self.episode_buffer = None
@@ -456,16 +464,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)

         # Load actual data
-        self.download_episodes(download_videos)
+        if not self.streaming:
+            self.download_episodes(download_videos)
         self.hf_dataset = self.load_hf_dataset()
+        if self.streaming:
+            self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

         # Check timestamps
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+        if not self.streaming:
+            check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)

         # Setup delta_indices
         if self.delta_timestamps is not None:
-            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
+            if not self.streaming:
+                check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

         # Available stats implies all videos have been encoded and dataset is iterable
@@ -550,13 +563,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         if self.episodes is None:
             path = str(self.root / "data")
-            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
+            hf_dataset = load_dataset("parquet", data_dir=path, split="train", streaming=self.streaming)
         else:
             files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            hf_dataset = load_dataset("parquet", data_files=files, split="train")
+            hf_dataset = load_dataset("parquet", data_files=files, split="train", streaming=self.streaming)

-        # TODO(aliberts): hf_dataset.set_format("torch")
-        hf_dataset.set_transform(hf_transform_to_torch)
+        if not self.streaming:
+            # TODO(aliberts): hf_dataset.set_format("torch")
+            hf_dataset.set_transform(hf_transform_to_torch)

         return hf_dataset
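The gating above matters because with `streaming=True`, `load_dataset` returns a `datasets.IterableDataset`, which supports neither random access nor `set_transform`; per-sample conversion happens later via `item_to_torch`. A minimal sketch of the difference (the parquet directory is a placeholder):

```python
from datasets import load_dataset

# streaming=False -> datasets.Dataset: random access, set_transform available.
# streaming=True  -> datasets.IterableDataset: sequential access only.
ds = load_dataset("parquet", data_dir="path/to/data", split="train", streaming=True)

sample = next(iter(ds))  # a plain dict of Python values / PIL images, no tensors yet
```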
@@ -632,7 +646,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         item = {}
         for vid_key, query_ts in query_timestamps.items():
-            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+            if self.streaming:
+                # Keep the remote path a plain string: Path() would collapse "https://" into "https:/".
+                video_path = f"{self.meta.html_root}/{self.meta.get_video_file_path(ep_idx, vid_key)}"
+            else:
+                video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
             frames = decode_video_frames_torchvision(
                 video_path, query_ts, self.tolerance_s, self.video_backend
             )
|
@ -649,7 +664,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
return self.num_frames
|
return self.num_frames
|
||||||
|
|
||||||
def __getitem__(self, idx) -> dict:
|
def __getitem__(self, idx) -> dict:
|
||||||
item = self.hf_dataset[idx]
|
if self.streaming:
|
||||||
|
try:
|
||||||
|
item = next(self.hf_dataset_iter)
|
||||||
|
except StopIteration:
|
||||||
|
self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
|
||||||
|
item = next(self.hf_dataset_iter)
|
||||||
|
item = item_to_torch(item)
|
||||||
|
else:
|
||||||
|
item = self.hf_dataset[idx]
|
||||||
ep_idx = item["episode_index"].item()
|
ep_idx = item["episode_index"].item()
|
||||||
|
|
||||||
query_indices = None
|
query_indices = None
|
||||||
|
|
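A hedged end-to-end sketch of the new mode. The `repo_id` and import path are assumptions here; note that in streaming mode `__getitem__` ignores `idx` and simply draws the next sample from the shuffled stream, so the default sequential sampler is fine:

```python
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Hypothetical repo_id; any LeRobot-format dataset on the Hub should do.
dataset = LeRobotDataset("lerobot/pusht", streaming=True)

# idx is ignored when streaming, so no special sampler is needed.
# num_workers=0 keeps a single shared iterator; each worker process would
# otherwise hold its own copy of the stream.
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=0)
batch = next(iter(loader))
print(batch["episode_index"].shape)  # torch.Size([8])
```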
lerobot/common/datasets/utils.py
@@ -205,6 +205,18 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict


+def item_to_torch(item: dict):
+    for key, value in item.items():
+        if isinstance(value, PILImage.Image):
+            to_tensor = transforms.ToTensor()
+            item[key] = to_tensor(value)
+        elif value is None or isinstance(value, str):
+            pass
+        else:
+            item[key] = torch.tensor(value)
+    return item
+
+
 def _get_major_minor(version: str) -> tuple[int]:
     split = version.strip("v").split(".")
     return int(split[0]), int(split[1])
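A quick sketch of what `item_to_torch` does to a raw streamed sample; the keys are made up for illustration, and the import assumes the function lands in `lerobot.common.datasets.utils` as the hunk above shows:

```python
import torch
from PIL import Image
from lerobot.common.datasets.utils import item_to_torch

# Hypothetical raw sample as yielded by a streaming parquet dataset.
raw = {
    "observation.state": [0.1, 0.2, 0.3],           # list -> torch.tensor
    "episode_index": 4,                             # int  -> 0-dim tensor
    "task": "push the block",                       # str  -> left untouched
    "observation.image": Image.new("RGB", (4, 4)),  # PIL  -> float CHW tensor in [0, 1]
}

item = item_to_torch(raw)
assert item["observation.image"].shape == (3, 4, 4)
assert item["episode_index"].item() == 4
assert item["task"] == "push the block"
```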