Add streaming
This commit is contained in:
parent
fe483b1d0d
commit
c1b4dae6d0
|
@ -50,6 +50,7 @@ from lerobot.common.datasets.utils import (
|
|||
get_hf_features_from_features,
|
||||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
item_to_torch,
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
|
@ -214,6 +215,9 @@ class LeRobotDatasetMetadata:
|
|||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def html_root(self) -> str:
|
||||
return f"https://huggingface.co/datasets/{self.repo_id}/resolve/main"
|
||||
|
||||
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
@ -334,6 +338,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
streaming: bool = False,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
|
@ -431,6 +436,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
will be made. Defaults to False.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
streaming (bool, optional): If set to True, don't download the data files. Instead, it streams the data
|
||||
progressively while iterating on the dataset. Default to False.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
|
@ -440,10 +447,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.video_backend = video_backend if video_backend else "pyav"
|
||||
self.delta_indices = None
|
||||
self.local_files_only = local_files_only
|
||||
self.streaming = streaming
|
||||
|
||||
# Unused attributes
|
||||
self.delta_indices = None
|
||||
self.image_writer = None
|
||||
self.episode_buffer = None
|
||||
|
||||
|
@ -456,16 +464,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes(download_videos)
|
||||
if not self.streaming:
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
if self.streaming:
|
||||
self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
if not self.streaming:
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
if not self.streaming:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Available stats implies all videos have been encoded and dataset is iterable
|
||||
|
@ -550,13 +563,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train", streaming=self.streaming)
|
||||
else:
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train", streaming=self.streaming)
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
if not self.streaming:
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
return hf_dataset
|
||||
|
||||
|
@ -632,7 +646,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
root = self.meta.html_root if self.streaming else self.root
|
||||
video_path = Path(root) / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames_torchvision(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
|
@ -649,7 +664,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
item = self.hf_dataset[idx]
|
||||
if self.streaming:
|
||||
try:
|
||||
item = next(self.hf_dataset_iter)
|
||||
except StopIteration:
|
||||
self.hf_dataset_iter = iter(self.hf_dataset.shuffle(buffer_size=1000))
|
||||
item = next(self.hf_dataset_iter)
|
||||
item = item_to_torch(item)
|
||||
else:
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
query_indices = None
|
||||
|
|
|
@ -205,6 +205,18 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
|||
return items_dict
|
||||
|
||||
|
||||
def item_to_torch(item: dict):
|
||||
for key, value in item.items():
|
||||
if isinstance(value, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
item[key] = to_tensor(value)
|
||||
elif value is None or isinstance(value, str):
|
||||
pass
|
||||
else:
|
||||
item[key] = torch.tensor(value)
|
||||
return item
|
||||
|
||||
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
|
Loading…
Reference in New Issue