Rework LeRobotDataset.__init__

This commit is contained in:
Simon Alibert 2024-10-09 14:33:26 +02:00
parent 2d75b93ba0
commit 096824b5ff
2 changed files with 189 additions and 114 deletions

View File

@ -15,6 +15,7 @@
# limitations under the License.
import logging
import os
from itertools import accumulate
from pathlib import Path
from typing import Callable
@ -24,27 +25,27 @@ import torch.utils
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
load_episode_data_index,
download_episodes,
get_hub_safe_version,
load_hf_dataset,
load_info,
load_previous_and_future_frames,
load_stats,
load_videos,
reset_episode_index,
load_tasks,
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v1.6"
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
CODEBASE_VERSION = "v2.0"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_id: str,
root: Path | None = DATA_DIR,
root: Path | None = None,
episodes: list[int] | None = None,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
@ -52,49 +53,89 @@ class LeRobotDataset(torch.utils.data.Dataset):
):
super().__init__()
self.repo_id = repo_id
self.root = root
self.root = root if root is not None else LEROBOT_HOME / repo_id
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
if split == "train":
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
else:
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
self.hf_dataset = reset_episode_index(self.hf_dataset)
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
self.info = load_info(repo_id, CODEBASE_VERSION, root)
if self.video:
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
self.episodes = episodes
self.video_backend = video_backend if video_backend is not None else "pyav"
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
self.info = load_info(repo_id, self._version, self.root)
self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root)
# Load actual data
download_episodes(
repo_id,
self._version,
self.root,
self.data_path,
self.video_keys,
self.num_episodes,
self.episodes,
self.videos_path,
)
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
self.episode_data_index = self.get_episode_data_index()
# TODO(aliberts):
# - [ ] Update __get_item__
# - [ ] Add self.consolidate() for:
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
# TODO(aliberts): remove (deprecated)
# if split == "train":
# self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
# else:
# self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
# self.hf_dataset = reset_episode_index(self.hf_dataset)
# if self.video:
# self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def videos_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
@property
def episode_dicts(self) -> list[dict]:
"""List of dictionary containing information for each episode, indexed by episode_index."""
return self.info["episodes"]
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
"""
return self.info.get("video", False)
def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.)."""
return self.info["keys"]
@property
def features(self) -> datasets.Features:
return self.hf_dataset.features
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return self.info["image_keys"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return self.info["video_keys"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
"""Keys to access image and video streams from cameras."""
return self.image_keys + self.video_keys
@property
def video_frame_keys(self) -> list[str]:
@ -117,8 +158,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return len(self.hf_dataset.unique("episode_index"))
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self.total_episodes
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def tolerance_s(self) -> float:
@ -129,6 +175,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
self.info.get("shapes")
def get_episode_data_index(self) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
if self.episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def __len__(self):
return self.num_samples
@ -147,7 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.video:
item = load_from_videos(
item,
self.video_frame_keys,
self.video_keys,
self.videos_dir,
self.tolerance_s,
self.video_backend,
@ -225,7 +287,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_ids: list[str],
root: Path | None = DATA_DIR,
root: Path | None = LEROBOT_HOME,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import warnings
from functools import cache
from pathlib import Path
@ -22,10 +21,9 @@ from typing import Dict
import datasets
import torch
from datasets import load_dataset, load_from_disk
from datasets import load_dataset
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
DATASET_CARD_TEMPLATE = """
@ -96,7 +94,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
@cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
def get_hub_safe_version(repo_id: str, version: str) -> str:
num_version = float(version.strip("v"))
if num_version < 2:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
)
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
@ -116,56 +121,27 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
return version
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
def load_hf_dataset(
local_dir: Path,
data_path: str,
total_episodes: int,
episodes: list[int] | None = None,
split: str = "train",
) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
if "%" in split:
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
from_frame_index = int(match_from.group(1))
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
elif match_to:
to_frame_index = int(match_to.group(1))
hf_dataset = hf_dataset.select(range(to_frame_index))
if episodes is None:
path = str(local_dir / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split=split)
else:
raise ValueError(
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
)
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
files = [str(local_dir / fpath) for fpath in files]
hf_dataset = load_dataset("parquet", data_files=files, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
)
return load_file(path)
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
@ -173,47 +149,84 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
fpath = hf_hub_download(
repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
stats = json.load(f)
stats = load_file(path)
stats = flatten_dict(stats)
stats = {key: torch.tensor(value) for key, value in stats.items()}
return unflatten_dict(stats)
def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
"""info contains structural information about the dataset. It should be the reference and
act as the 'source of thruth' for what's inside the dataset.
Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
with open(path) as f:
info = json.load(f)
return info
fpath = hf_hub_download(
repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
return json.load(f)
def load_videos(repo_id, version, root) -> Path:
if root is not None:
path = Path(root) / repo_id / "videos"
else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this
safe_version = get_hf_dataset_safe_version(repo_id, version)
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
path = Path(repo_dir) / "videos"
def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
"""tasks contains all the tasks of the dataset, indexed by their task_index.
return path
Example:
```json
{
"0": "Pick the Lego block and drop it in the box on the right."
}
```
"""
fpath = hf_hub_download(
repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
)
with open(fpath) as f:
return json.load(f)
def download_episodes(
repo_id: str,
version: str,
local_dir: Path,
data_path: str,
video_keys: list,
total_episodes: int,
episodes: list[int] | None = None,
videos_path: str | None = None,
) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
in 'local_dir', they won't be downloaded again.
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
if episodes is not None:
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
if len(video_keys) > 0:
video_files = [
videos_path.format(video_key=vid_key, episode_index=ep_idx)
for vid_key in video_keys
for ep_idx in episodes
]
files += video_files
snapshot_download(
repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
)
def load_previous_and_future_frames(