Rework LeRobotDataset.__init__

Simon Alibert 2024-10-09 14:33:26 +02:00
parent 2d75b93ba0
commit 096824b5ff
2 changed files with 189 additions and 114 deletions

View File

@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 import os
+from itertools import accumulate
 from pathlib import Path
 from typing import Callable
@@ -24,27 +25,27 @@ import torch.utils

 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
-    load_episode_data_index,
+    download_episodes,
+    get_hub_safe_version,
     load_hf_dataset,
     load_info,
     load_previous_and_future_frames,
     load_stats,
-    load_videos,
-    reset_episode_index,
+    load_tasks,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos

 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
-CODEBASE_VERSION = "v1.6"
-DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+CODEBASE_VERSION = "v2.0"
+LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()

 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = DATA_DIR,
+        root: Path | None = None,
+        episodes: list[int] | None = None,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
@@ -52,49 +53,89 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ):
         super().__init__()
         self.repo_id = repo_id
-        self.root = root
+        self.root = root if root is not None else LEROBOT_HOME / repo_id
         self.split = split
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
-        # load data from hub or locally when root is provided
-        # TODO(rcadene, aliberts): implement faster transfer
-        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
-        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
-        if split == "train":
-            self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
-        else:
-            self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
-            self.hf_dataset = reset_episode_index(self.hf_dataset)
-        self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
-        self.info = load_info(repo_id, CODEBASE_VERSION, root)
-        if self.video:
-            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+        self.episodes = episodes
         self.video_backend = video_backend if video_backend is not None else "pyav"
+        # Load metadata
+        self.root.mkdir(exist_ok=True, parents=True)
+        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
+        self.info = load_info(repo_id, self._version, self.root)
+        self.stats = load_stats(repo_id, self._version, self.root)
+        self.tasks = load_tasks(repo_id, self._version, self.root)

+        # Load actual data
+        download_episodes(
+            repo_id,
+            self._version,
+            self.root,
+            self.data_path,
+            self.video_keys,
+            self.num_episodes,
+            self.episodes,
+            self.videos_path,
+        )
+        self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
+        self.episode_data_index = self.get_episode_data_index()

+        # TODO(aliberts):
+        # - [ ] Update __get_item__
+        # - [ ] Add self.consolidate() for:
+        #     - [ ] Sanity checks (episodes num, shapes, files, etc.)
+        #     - [ ] Update episode_index (arg update=True)
+        #     - [ ] Update info.json (arg update=True)

+        # TODO(aliberts): remove (deprecated)
+        # if split == "train":
+        #     self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
+        # else:
+        #     self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
+        #     self.hf_dataset = reset_episode_index(self.hf_dataset)
+        # if self.video:
+        #     self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+    @property
+    def data_path(self) -> str:
+        """Formattable string for the parquet files."""
+        return self.info["data_path"]

+    @property
+    def videos_path(self) -> str | None:
+        """Formattable string for the video files."""
+        return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None

+    @property
+    def episode_dicts(self) -> list[dict]:
+        """List of dictionaries containing information for each episode, indexed by episode_index."""
+        return self.info["episodes"]
     @property
     def fps(self) -> int:
         """Frames per second used during data collection."""
         return self.info["fps"]

     @property
-    def video(self) -> bool:
-        """Returns True if this dataset loads video frames from mp4 files.
-        Returns False if it only loads images from png files.
-        """
-        return self.info.get("video", False)
+    def keys(self) -> list[str]:
+        """Keys to access non-image data (state, actions etc.)."""
+        return self.info["keys"]

     @property
-    def features(self) -> datasets.Features:
-        return self.hf_dataset.features
+    def image_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as images."""
+        return self.info["image_keys"]

+    @property
+    def video_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as videos."""
+        return self.info["video_keys"]

     @property
     def camera_keys(self) -> list[str]:
-        """Keys to access image and video stream from cameras."""
-        keys = []
-        for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, (datasets.Image, VideoFrame)):
-                keys.append(key)
-        return keys
+        """Keys to access image and video streams from cameras."""
+        return self.image_keys + self.video_keys

     @property
     def video_frame_keys(self) -> list[str]:
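The properties added in this hunk are thin accessors over meta/info.json, which this commit treats as the source of truth for the dataset layout. The sketch below assembles a hypothetical info dict using only the keys those accessors read; every value, and both path templates, is invented purely for illustration.

```python
# Hypothetical contents of meta/info.json, limited to the keys read by the
# properties above (fps, total_episodes, keys, image_keys, video_keys, shapes,
# data_path, videos.videos_path, episodes). All values are made up.
info = {
    "fps": 30,
    "total_episodes": 2,
    "keys": ["observation.state", "action"],
    "image_keys": [],
    "video_keys": ["observation.images.top"],
    "shapes": {"observation.state": 14, "action": 14},
    # Formattable templates; the real patterns come from the dataset itself.
    "data_path": "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet",
    "videos": {"videos_path": "videos/{video_key}_episode_{episode_index:06d}.mp4"},
    "episodes": [
        {"episode_index": 0, "length": 120},
        {"episode_index": 1, "length": 98},
    ],
}
```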
@@ -117,8 +158,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def num_episodes(self) -> int:
-        """Number of episodes."""
-        return len(self.hf_dataset.unique("episode_index"))
+        """Number of episodes selected."""
+        return len(self.episodes) if self.episodes is not None else self.total_episodes

+    @property
+    def total_episodes(self) -> int:
+        """Total number of episodes available."""
+        return self.info["total_episodes"]

     @property
     def tolerance_s(self) -> float:
@@ -129,6 +175,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # 1e-4 to account for possible numerical error
         return 1 / self.fps - 1e-4

+    @property
+    def shapes(self) -> dict:
+        """Shapes for the different features."""
+        return self.info.get("shapes")

+    def get_episode_data_index(self) -> dict[str, torch.Tensor]:
+        episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
+        if self.episodes is not None:
+            episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}

+        cumulative_lengths = list(accumulate(episode_lengths.values()))
+        return {
+            "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
+            "to": torch.LongTensor(cumulative_lengths),
+        }

     def __len__(self):
         return self.num_samples
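get_episode_data_index now derives frame ranges from the per-episode lengths in info.json instead of loading a precomputed safetensors file. A standalone sketch of the same computation, with invented episode lengths, also shows how the resulting from/to tensors are meant to be used:

```python
from itertools import accumulate

import torch

# Invented per-episode frame counts; in the dataset they come from each
# episode dict's "length" field in info.json.
episode_lengths = [120, 98, 143]

cumulative_lengths = list(accumulate(episode_lengths))
episode_data_index = {
    "from": torch.LongTensor([0] + cumulative_lengths[:-1]),  # first frame index of each episode
    "to": torch.LongTensor(cumulative_lengths),                # one past the last frame index
}

# Frames of episode 1 occupy indices [from, to) in the flattened dataset.
ep_idx = 1
from_id = episode_data_index["from"][ep_idx].item()
to_id = episode_data_index["to"][ep_idx].item()
frame_ids = list(range(from_id, to_id))
```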
@@ -147,7 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.video:
             item = load_from_videos(
                 item,
-                self.video_frame_keys,
+                self.video_keys,
                 self.videos_dir,
                 self.tolerance_s,
                 self.video_backend,
@@ -225,7 +287,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_ids: list[str],
-        root: Path | None = DATA_DIR,
+        root: Path | None = LEROBOT_HOME,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
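Taken together, the changes in this file replace the DATA_DIR environment variable with a per-repo cache under LEROBOT_HOME and let callers request a subset of episodes at construction time. A minimal usage sketch follows; the repo id, episode indices, and import path are placeholders assumed by the editor rather than part of this commit:

```python
from pathlib import Path

# Assumed import path for the class edited above.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Default root resolves to ~/.cache/huggingface/lerobot/<repo_id>
# (or $LEROBOT_HOME/<repo_id> if that environment variable is set),
# and only episodes 0, 3 and 7 are downloaded and indexed.
dataset = LeRobotDataset("lerobot/pusht", episodes=[0, 3, 7])

# An explicit root bypasses LEROBOT_HOME entirely.
local_dataset = LeRobotDataset("lerobot/pusht", root=Path("/tmp/lerobot/pusht"))
```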

View File

@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import re
 import warnings
 from functools import cache
 from pathlib import Path
@@ -22,10 +21,9 @@ from typing import Dict

 import datasets
 import torch
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
 from PIL import Image as PILImage
-from safetensors.torch import load_file
 from torchvision import transforms

 DATASET_CARD_TEMPLATE = """
@@ -96,7 +94,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
 @cache
-def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
+def get_hub_safe_version(repo_id: str, version: str) -> str:
+    num_version = float(version.strip("v"))
+    if num_version < 2:
+        raise ValueError(
+            f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
+            format with v2.0 that is not backward compatible. Please use our conversion script
+            first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
+        )
     api = HfApi()
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
     branches = [b.name for b in dataset_info.branches]
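The renamed helper now refuses pre-v2.0 datasets before querying the Hub, simply by parsing the numeric part of the version tag. The check in isolation, as a rough sketch rather than the library function itself:

```python
def is_v2_or_later(version: str) -> bool:
    # Mirrors the guard above: strip the leading "v" and compare numerically,
    # so "v1.6" is rejected while "v2.0" and later pass.
    return float(version.strip("v")) >= 2

assert is_v2_or_later("v2.0")
assert not is_v2_or_later("v1.6")
```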
@@ -116,56 +121,27 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
     return version

-def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
+def load_hf_dataset(
+    local_dir: Path,
+    data_path: str,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    split: str = "train",
+) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
-    if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
-        # TODO(rcadene): clean this which enables getting a subset of dataset
-        if split != "train":
-            if "%" in split:
-                raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
-            match_from = re.search(r"train\[(\d+):\]", split)
-            match_to = re.search(r"train\[:(\d+)\]", split)
-            if match_from:
-                from_frame_index = int(match_from.group(1))
-                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
-            elif match_to:
-                to_frame_index = int(match_to.group(1))
-                hf_dataset = hf_dataset.select(range(to_frame_index))
-            else:
-                raise ValueError(
-                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
-                )
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
+    if episodes is None:
+        path = str(local_dir / "data")
+        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
+    else:
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        files = [str(local_dir / fpath) for fpath in files]
+        hf_dataset = load_dataset("parquet", data_files=files, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
     return hf_dataset
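When load_hf_dataset receives an explicit episode list, it expands the formattable data_path from info.json into concrete parquet paths and hands them to datasets.load_dataset. A sketch of that expansion with a hypothetical template and local directory (the files must already exist locally, e.g. after download_episodes further below):

```python
from pathlib import Path

from datasets import load_dataset

# Hypothetical values; the real template and directory come from info.json and LEROBOT_HOME.
data_path = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
local_dir = Path("~/.cache/huggingface/lerobot/some_user/some_dataset").expanduser()
total_episodes = 50
episodes = [0, 3, 7]

# e.g. "data/train-00000-of-00050.parquet" for episode 0
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
files = [str(local_dir / fpath) for fpath in files]
hf_dataset = load_dataset("parquet", data_files=files, split="train")
```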
-def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
-    """episode_data_index contains the range of indices for each episode
-    Example:
-    ```python
-    from_id = episode_data_index["from"][episode_id].item()
-    to_id = episode_data_index["to"][episode_id].item()
-    episode_frames = [dataset[i] for i in range(from_id, to_id)]
-    ```
-    """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
-        )
-    return load_file(path)

-def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
+def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
     """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
     Example:
@@ -173,47 +149,84 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
     normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
-        )
-
-    stats = load_file(path)
+    fpath = hf_hub_download(
+        repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        stats = json.load(f)
+
+    stats = flatten_dict(stats)
+    stats = {key: torch.tensor(value) for key, value in stats.items()}
     return unflatten_dict(stats)
-def load_info(repo_id, version, root) -> dict:
-    """info contains useful information regarding the dataset that are not stored elsewhere
+def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
+    """info contains structural information about the dataset. It should be the reference and
+    act as the 'source of truth' for what's inside the dataset.

     Example:
     ```python
     print("frame per second used to collect the video", info["fps"])
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "info.json"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
-    with open(path) as f:
-        info = json.load(f)
-    return info
+    fpath = hf_hub_download(
+        repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)
-def load_videos(repo_id, version, root) -> Path:
-    if root is not None:
-        path = Path(root) / repo_id / "videos"
-    else:
-        # TODO(rcadene): we download the whole repo here. see if we can avoid this
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
-        path = Path(repo_dir) / "videos"
-    return path
+def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
+    """tasks contains all the tasks of the dataset, indexed by their task_index.
+
+    Example:
+    ```json
+    {
+        "0": "Pick the Lego block and drop it in the box on the right."
+    }
+    ```
+    """
+    fpath = hf_hub_download(
+        repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)
+def download_episodes(
+    repo_id: str,
+    version: str,
+    local_dir: Path,
+    data_path: str,
+    video_keys: list,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    videos_path: str | None = None,
+) -> None:
+    """Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
+    will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
+    dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
+    in 'local_dir', they won't be downloaded again.
+
+    Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
+    snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
+    """
+    # TODO(rcadene, aliberts): implement faster transfer
+    # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
+    files = None
+    if episodes is not None:
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        if len(video_keys) > 0:
+            video_files = [
+                videos_path.format(video_key=vid_key, episode_index=ep_idx)
+                for vid_key in video_keys
+                for ep_idx in episodes
+            ]
+            files += video_files
+
+    snapshot_download(
+        repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
+    )

 def load_previous_and_future_frames(
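download_episodes leans on snapshot_download's allow_patterns filter: with a specific episode list it requests only the matching parquet and mp4 files, and with episodes=None the filter stays unset so the whole dataset repo is mirrored into local_dir. A minimal standalone sketch of that call; the repo id, version tag, camera key and both path templates are placeholders, not values from this commit:

```python
from pathlib import Path

from huggingface_hub import snapshot_download

repo_id = "some_user/some_dataset"  # placeholder
local_dir = Path("~/.cache/huggingface/lerobot").expanduser() / repo_id
# Hypothetical templates standing in for info["data_path"] and info["videos"]["videos_path"].
data_path = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
videos_path = "videos/{video_key}_episode_{episode_index:06d}.mp4"

episodes = [0, 1]
files = [data_path.format(episode_index=ep, total_episodes=50) for ep in episodes]
files += [videos_path.format(video_key="observation.images.top", episode_index=ep) for ep in episodes]

# Only the listed files are fetched; files already present under local_dir are reused.
snapshot_download(
    repo_id,
    repo_type="dataset",
    revision="v2.0",
    local_dir=local_dir,
    allow_patterns=files,
)
```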