From 64b09ea7a7253a32aa22716b2bc3b0924c16b620 Mon Sep 17 00:00:00 2001 From: Cadene Date: Thu, 18 Apr 2024 23:54:52 +0000 Subject: [PATCH] WIP add load functions + episode_data_index --- download_and_upload_dataset.py | 31 ++++++------ lerobot/common/datasets/pusht.py | 22 +++++---- lerobot/common/datasets/utils.py | 85 +++++++++++++++++++++++++++++++- tests/test_datasets.py | 61 ++++++++++++++++++----- 4 files changed, 159 insertions(+), 40 deletions(-) diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py index 6061a450..6a54833e 100644 --- a/download_and_upload_dataset.py +++ b/download_and_upload_dataset.py @@ -17,9 +17,9 @@ import tqdm from datasets import Dataset, Features, Image, Sequence, Value from huggingface_hub import HfApi from PIL import Image as PILImage -from safetensors.numpy import save_file +from safetensors.torch import save_file -from lerobot.common.datasets.utils import compute_stats +from lerobot.common.datasets.utils import compute_stats, flatten_dict def download_and_upload(root, revision, dataset_id): @@ -98,7 +98,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id torch.save(stats, stats_pth_path) # create and store meta_data - meta_data_dir = root / dataset_id / "train" / "meta_data" + meta_data_dir = root / dataset_id / "meta_data" meta_data_dir.mkdir(parents=True, exist_ok=True) api = HfApi() @@ -115,18 +115,17 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id ) # stats - for key in stats: - stats_path = meta_data_dir / f"stats_{key}.safetensors" - save_file(episode_data_index, stats_path) - api.upload_file( - path_or_fileobj=stats_path, - path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""), - repo_id=f"lerobot/{dataset_id}", - repo_type="dataset", - ) + stats_path = meta_data_dir / "stats.safetensors" + save_file(flatten_dict(stats), stats_path) + api.upload_file( + path_or_fileobj=stats_path, + path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""), + repo_id=f"lerobot/{dataset_id}", + repo_type="dataset", + ) # episode_data_index - episode_data_index = {key: np.array(episode_data_index[key]) for key in episode_data_index} + episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index} ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors" save_file(episode_data_index, ep_data_idx_path) api.upload_file( @@ -139,7 +138,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id # copy in tests folder, the first episode and the meta_data directory num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") - shutil.copytree(meta_data_dir, f"tests/{meta_data_dir}") + shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data") def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): @@ -516,12 +515,12 @@ if __name__ == "__main__": revision = "v1.1" dataset_ids = [ - # "pusht", + "pusht", # "xarm_lift_medium", # "aloha_sim_insertion_human", # "aloha_sim_insertion_scripted", # "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", + # "aloha_sim_transfer_cube_scripted", ] for dataset_id in dataset_ids: download_and_upload(root, revision, dataset_id) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 2879c177..7fdd88e0 100644 --- a/lerobot/common/datasets/pusht.py +++ 
b/lerobot/common/datasets/pusht.py @@ -1,9 +1,13 @@ from pathlib import Path import torch -from datasets import load_dataset, load_from_disk -from lerobot.common.datasets.utils import load_previous_and_future_frames +from lerobot.common.datasets.utils import ( + load_episode_data_index, + load_hf_dataset, + load_previous_and_future_frames, + load_stats, +) class PushtDataset(torch.utils.data.Dataset): @@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset): self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - if self.root is not None: - self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split) - else: - self.hf_dataset = load_dataset( - f"lerobot/{self.dataset_id}", revision=self.version, split=self.split - ) - self.hf_dataset = self.hf_dataset.with_format("torch") + # load data from hub or locally when root is provided + self.hf_dataset = load_hf_dataset(dataset_id, version, root, split) + self.episode_data_index = load_episode_data_index(dataset_id, version, root) + self.stats = load_stats(dataset_id, version, root) @property def num_samples(self) -> int: @@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset): @property def num_episodes(self) -> int: - return len(self.hf_dataset.unique("episode_id")) + return len(self.episode_data_index["from"]) def __len__(self): return self.num_samples @@ -64,6 +65,7 @@ class PushtDataset(torch.utils.data.Dataset): item = load_previous_and_future_frames( item, self.hf_dataset, + self.episode_data_index, self.delta_timestamps, tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error ) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 50c50856..92799c2a 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,15 +1,93 @@ from copy import deepcopy from math import ceil +from pathlib import Path import datasets import einops import torch import tqdm +from datasets import load_dataset, load_from_disk +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + + +def flatten_dict(d, parent_key="", sep="/"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_dict(d, sep="/"): + outdict = {} + for key, value in d.items(): + parts = key.split(sep) + d = outdict + for part in parts[:-1]: + if part not in d: + d[part] = {} + d = d[part] + d[parts[-1]] = value + return outdict + + +def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset: + """hf_dataset contains all the observations, states, actions, rewards, etc.""" + if root is not None: + hf_dataset = load_from_disk(Path(root) / dataset_id / split) + else: + repo_id = f"lerobot/{dataset_id}" + hf_dataset = load_dataset(repo_id, revision=version, split=split) + return hf_dataset.with_format("torch") + + +def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]: + """episode_data_index contains the range of indices for each episode + + Example: + ```python + from_id = episode_data_index["from"][episode_id].item() + to_id = episode_data_index["to"][episode_id].item() + episode_frames = [dataset[i] for i in range(from_id, to_id)] + ``` + """ + if root is not None: + path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors" + else: + repo_id = 
f"lerobot/{dataset_id}" + path = hf_hub_download( + repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version + ) + + return load_file(path) + + +def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]: + """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std + + Example: + ```python + normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"] + ``` + """ + if root is not None: + path = Path(root) / dataset_id / "meta_data" / "stats.safetensors" + else: + repo_id = f"lerobot/{dataset_id}" + path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) + + stats = load_file(path) + return unflatten_dict(stats) def load_previous_and_future_frames( item: dict[str, torch.Tensor], hf_dataset: datasets.Dataset, + episode_data_index: dict[str, torch.Tensor], delta_timestamps: dict[str, list[float]], tol: float, ) -> dict[torch.Tensor]: @@ -31,6 +109,8 @@ def load_previous_and_future_frames( corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. + They indicate the start index and end index of each episode in the dataset. - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps. - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query @@ -46,8 +126,9 @@ def load_previous_and_future_frames( issues with timestamps during data collection. 
""" # get indices of the frames associated to the episode, and their timestamps - ep_data_id_from = item["episode_data_index_from"].item() - ep_data_id_to = item["episode_data_index_to"].item() + ep_id = item["episode_id"].item() + ep_data_id_from = episode_data_index["from"][ep_id].item() + ep_data_id_to = episode_data_index["to"][ep_id].item() ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) # load timestamps diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e488c30b..3dee5fba 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,10 +1,13 @@ import logging +from copy import deepcopy +import json import os from pathlib import Path import einops import pytest import torch + from datasets import Dataset import lerobot @@ -13,6 +16,8 @@ from lerobot.common.datasets.utils import ( compute_stats, get_stats_einops_patterns, load_previous_and_future_frames, + flatten_dict, + unflatten_dict, ) from lerobot.common.transforms import Prod from lerobot.common.utils.utils import init_hydra_config @@ -160,15 +165,18 @@ def test_load_previous_and_future_frames_within_tolerance(): { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], + "episode_id": [0, 0, 0, 0, 0], } ) hf_dataset = hf_dataset.with_format("torch") - item = hf_dataset[2] + episode_data_index = { + "from": torch.tensor([0]), + "to": torch.tensor([5]), + } delta_timestamps = {"index": [-0.2, 0, 0.139]} tol = 0.04 - item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) + item = hf_dataset[2] + item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol) data, is_pad = item["index"], item["index_is_pad"] assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values" assert not is_pad.any(), "Unexpected padding detected" @@ -179,16 +187,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range( { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], + "episode_id": [0, 0, 0, 0, 0], } ) hf_dataset = hf_dataset.with_format("torch") - item = hf_dataset[2] + episode_data_index = { + "from": torch.tensor([0]), + "to": torch.tensor([5]), + } delta_timestamps = {"index": [-0.2, 0, 0.141]} tol = 0.04 + item = hf_dataset[2] with pytest.raises(AssertionError): - load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) + load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol) def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range(): @@ -196,17 +207,43 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], + "episode_id": [0, 0, 0, 0, 0], } ) hf_dataset = hf_dataset.with_format("torch") - item = hf_dataset[2] + episode_data_index = { + "from": torch.tensor([0]), + "to": torch.tensor([5]), + } delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} tol = 0.04 - item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) + item = hf_dataset[2] + item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol) data, is_pad = item["index"], item["index_is_pad"] assert 
torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" assert torch.equal( is_pad, torch.tensor([True, False, False, True, True]) ), "Padding does not match expected values" + + +def test_flatten_unflatten_dict(): + d = { + "obs": { + "min": 0, + "max": 1, + "mean": 2, + "std": 3, + }, + "action": { + "min": 4, + "max": 5, + "mean": 6, + "std": 7, + }, + } + + original_d = deepcopy(d) + d = unflatten_dict(flatten_dict(d)) + + # test equality between nested dicts + assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
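
Below is a minimal usage sketch for the loaders introduced in lerobot/common/datasets/utils.py. It is an illustration, not a definitive recipe: it assumes hub access to the re-uploaded lerobot/pusht dataset at revision "v1.1" (the revision used in download_and_upload_dataset.py), a "train" split, and the "action" and "episode_id" columns referenced in the docstrings and tests above. Set root to a local path to read from disk instead of the hub.

from lerobot.common.datasets.utils import (
    load_episode_data_index,
    load_hf_dataset,
    load_previous_and_future_frames,
    load_stats,
)

dataset_id = "pusht"
version = "v1.1"
root = None  # None -> fetch from the hub; a local path -> load from disk (see load_hf_dataset)

# Load the frames, the per-episode index ranges, and the per-modality statistics.
hf_dataset = load_hf_dataset(dataset_id, version, root, split="train")
episode_data_index = load_episode_data_index(dataset_id, version, root)
stats = load_stats(dataset_id, version, root)

# Slice out all frames of the first episode using the "from"/"to" ranges,
# as in load_episode_data_index()'s docstring.
from_id = episode_data_index["from"][0].item()
to_id = episode_data_index["to"][0].item()
episode_frames = [hf_dataset[i] for i in range(from_id, to_id)]

# Normalize an action with the per-modality statistics, as in load_stats()'s docstring.
action = episode_frames[0]["action"]
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]

# Query a window of past/current/future actions around one frame.
# pusht is recorded at fps=10, so neighbouring frames are 0.1 s apart.
fps = 10
item = hf_dataset[from_id]
item = load_previous_and_future_frames(
    item,
    hf_dataset,
    episode_data_index,
    delta_timestamps={"action": [-0.1, 0.0, 0.1]},
    tol=1 / fps - 1e-4,  # same tolerance as PushtDataset.__getitem__
)
actions = item["action"]        # stacked frames for the requested deltas
is_pad = item["action_is_pad"]  # True where a query fell outside the episode

Queries that land outside the episode boundaries are padded with the closest frame and flagged in "action_is_pad" rather than raising, mirroring test_load_previous_and_future_frames_outside_tolerance_outside_episode_range above.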