WIP add load functions + episode_data_index

Cadene 2024-04-18 23:54:52 +00:00
parent 0bd2ca8d82
commit 64b09ea7a7
4 changed files with 159 additions and 40 deletions

View File

@@ -17,9 +17,9 @@ import tqdm
 from datasets import Dataset, Features, Image, Sequence, Value
 from huggingface_hub import HfApi
 from PIL import Image as PILImage
-from safetensors.numpy import save_file
+from safetensors.torch import save_file
 
-from lerobot.common.datasets.utils import compute_stats
+from lerobot.common.datasets.utils import compute_stats, flatten_dict
 
 
 def download_and_upload(root, revision, dataset_id):
@@ -98,7 +98,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
     torch.save(stats, stats_pth_path)
 
     # create and store meta_data
-    meta_data_dir = root / dataset_id / "train" / "meta_data"
+    meta_data_dir = root / dataset_id / "meta_data"
     meta_data_dir.mkdir(parents=True, exist_ok=True)
 
     api = HfApi()
@@ -115,9 +115,8 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
     )
 
     # stats
-    for key in stats:
-        stats_path = meta_data_dir / f"stats_{key}.safetensors"
-        save_file(episode_data_index, stats_path)
+    stats_path = meta_data_dir / "stats.safetensors"
+    save_file(flatten_dict(stats), stats_path)
     api.upload_file(
         path_or_fileobj=stats_path,
         path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
@@ -126,7 +125,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
     )
 
     # episode_data_index
-    episode_data_index = {key: np.array(episode_data_index[key]) for key in episode_data_index}
+    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
     ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
     save_file(episode_data_index, ep_data_idx_path)
     api.upload_file(
@@ -139,7 +138,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
     # copy in tests folder, the first episode and the meta_data directory
     num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
     hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
-    shutil.copytree(meta_data_dir, f"tests/{meta_data_dir}")
+    shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
 
 
 def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
@@ -516,12 +515,12 @@ if __name__ == "__main__":
     revision = "v1.1"
 
     dataset_ids = [
-        # "pusht",
+        "pusht",
         # "xarm_lift_medium",
         # "aloha_sim_insertion_human",
        # "aloha_sim_insertion_scripted",
         # "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
+        # "aloha_sim_transfer_cube_scripted",
     ]
     for dataset_id in dataset_ids:
         download_and_upload(root, revision, dataset_id)
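For context on the stats change above: safetensors serializes a flat mapping from string keys to tensors, which is why the nested per-modality stats are now passed through flatten_dict before save_file. A minimal sketch of the round trip (the modality names and file path are illustrative, not taken from the diff):

```python
import torch
from safetensors.torch import save_file, load_file

# nested stats as compute_stats would produce them (illustrative modalities)
stats = {
    "action": {"mean": torch.zeros(2), "std": torch.ones(2)},
    "observation.state": {"mean": torch.zeros(4), "std": torch.ones(4)},
}

# equivalent of flatten_dict(stats): one flat key per tensor, "/" separated
flat = {f"{mod}/{name}": t for mod, sub in stats.items() for name, t in sub.items()}
save_file(flat, "stats.safetensors")

restored = load_file("stats.safetensors")  # flat dict again; unflatten_dict() rebuilds the nesting
assert torch.equal(restored["action/mean"], stats["action"]["mean"])
```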

View File

@@ -1,9 +1,13 @@
 from pathlib import Path
 
 import torch
-from datasets import load_dataset, load_from_disk
 
-from lerobot.common.datasets.utils import load_previous_and_future_frames
+from lerobot.common.datasets.utils import (
+    load_episode_data_index,
+    load_hf_dataset,
+    load_previous_and_future_frames,
+    load_stats,
+)
 
 
 class PushtDataset(torch.utils.data.Dataset):
@@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
         self.split = split
         self.transform = transform
         self.delta_timestamps = delta_timestamps
-        if self.root is not None:
-            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
-        else:
-            self.hf_dataset = load_dataset(
-                f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
-            )
-        self.hf_dataset = self.hf_dataset.with_format("torch")
+        # load data from hub or locally when root is provided
+        self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
+        self.episode_data_index = load_episode_data_index(dataset_id, version, root)
+        self.stats = load_stats(dataset_id, version, root)
 
     @property
     def num_samples(self) -> int:
@@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):
 
     @property
     def num_episodes(self) -> int:
-        return len(self.hf_dataset.unique("episode_id"))
+        return len(self.episode_data_index["from"])
 
     def __len__(self):
         return self.num_samples
@@ -64,6 +65,7 @@ class PushtDataset(torch.utils.data.Dataset):
             item = load_previous_and_future_frames(
                 item,
                 self.hf_dataset,
+                self.episode_data_index,
                 self.delta_timestamps,
                 tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
             )
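The episode_data_index loaded in the constructor replaces the per-frame episode_data_index_from/to columns. A usage sketch, assuming `dataset` is a constructed PushtDataset instance (variable names are illustrative), mirroring the docstring example of load_episode_data_index:

```python
# sketch: iterate over the frames of one episode using the new index
episode_id = 0
from_idx = dataset.episode_data_index["from"][episode_id].item()
to_idx = dataset.episode_data_index["to"][episode_id].item()
episode_frames = [dataset.hf_dataset[i] for i in range(from_idx, to_idx)]

# the new num_episodes property is just the number of "from" entries
assert dataset.num_episodes == len(dataset.episode_data_index["from"])
```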

View File

@@ -1,15 +1,93 @@
 from copy import deepcopy
 from math import ceil
+from pathlib import Path
 
 import datasets
 import einops
 import torch
 import tqdm
+from datasets import load_dataset, load_from_disk
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
 
 
+def flatten_dict(d, parent_key="", sep="/"):
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
+def unflatten_dict(d, sep="/"):
+    outdict = {}
+    for key, value in d.items():
+        parts = key.split(sep)
+        d = outdict
+        for part in parts[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        d[parts[-1]] = value
+    return outdict
+
+
+def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
+    """hf_dataset contains all the observations, states, actions, rewards, etc."""
+    if root is not None:
+        hf_dataset = load_from_disk(Path(root) / dataset_id / split)
+    else:
+        repo_id = f"lerobot/{dataset_id}"
+        hf_dataset = load_dataset(repo_id, revision=version, split=split)
+    return hf_dataset.with_format("torch")
+
+
+def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
+    """episode_data_index contains the range of indices for each episode
+
+    Example:
+    ```python
+    from_id = episode_data_index["from"][episode_id].item()
+    to_id = episode_data_index["to"][episode_id].item()
+    episode_frames = [dataset[i] for i in range(from_id, to_id)]
+    ```
+    """
+    if root is not None:
+        path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
+    else:
+        repo_id = f"lerobot/{dataset_id}"
+        path = hf_hub_download(
+            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
+        )
+
+    return load_file(path)
+
+
+def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
+    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
+
+    Example:
+    ```python
+    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
+    ```
+    """
+    if root is not None:
+        path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
+    else:
+        repo_id = f"lerobot/{dataset_id}"
+        path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
+
+    stats = load_file(path)
+    return unflatten_dict(stats)
+
+
 def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
     hf_dataset: datasets.Dataset,
+    episode_data_index: dict[str, torch.Tensor],
     delta_timestamps: dict[str, list[float]],
     tol: float,
 ) -> dict[torch.Tensor]:
@ -31,6 +109,8 @@ def load_previous_and_future_frames(
corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action"). modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps. retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
@@ -46,8 +126,9 @@ def load_previous_and_future_frames(
       issues with timestamps during data collection.
     """
     # get indices of the frames associated to the episode, and their timestamps
-    ep_data_id_from = item["episode_data_index_from"].item()
-    ep_data_id_to = item["episode_data_index_to"].item()
+    ep_id = item["episode_id"].item()
+    ep_data_id_from = episode_data_index["from"][ep_id].item()
+    ep_data_id_to = episode_data_index["to"][ep_id].item()
     ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
 
     # load timestamps
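To make the role of episode_data_index in load_previous_and_future_frames concrete, here is a standalone sketch of the matching logic (not the library code itself): query timestamps are formed from the current frame's timestamp plus the requested deltas, matched to the nearest frames of the same episode, and flagged as padding when no frame lies within the tolerance. The toy values below are illustrative.

```python
import torch

timestamps = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])  # one episode with 5 frames
ep_from, ep_to = 0, 5                                  # episode_data_index["from"/"to"][ep_id]
current_ts = timestamps[2]                             # the item at index 2
delta_timestamps = [-0.2, 0.0, 0.1]
tol = 0.04

query_ts = current_ts + torch.tensor(delta_timestamps)
# distance of every query timestamp to every frame timestamp of the episode
dist = torch.cdist(query_ts[:, None], timestamps[ep_from:ep_to, None], p=1)
min_dist, argmin = dist.min(dim=1)

is_pad = min_dist > tol                         # queries too far from any frame are padded
data_ids = torch.arange(ep_from, ep_to)[argmin]
print(data_ids, is_pad)                         # tensor([0, 2, 3]) tensor([False, False, False])
```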

View File

@@ -1,10 +1,13 @@
 import logging
+from copy import deepcopy
+import json
 import os
 from pathlib import Path
 
 import einops
 import pytest
 import torch
 from datasets import Dataset
 
 import lerobot
@@ -13,6 +16,8 @@ from lerobot.common.datasets.utils import (
     compute_stats,
     get_stats_einops_patterns,
     load_previous_and_future_frames,
+    flatten_dict,
+    unflatten_dict,
 )
 from lerobot.common.transforms import Prod
 from lerobot.common.utils.utils import init_hydra_config
@@ -160,15 +165,18 @@ def test_load_previous_and_future_frames_within_tolerance():
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_id": [0, 0, 0, 0, 0],
         }
     )
     hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.2, 0, 0.139]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+    item = hf_dataset[2]
+    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
     assert not is_pad.any(), "Unexpected padding detected"
@@ -179,16 +187,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_id": [0, 0, 0, 0, 0],
         }
     )
     hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.2, 0, 0.141]}
     tol = 0.04
+    item = hf_dataset[2]
     with pytest.raises(AssertionError):
-        load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+        load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
 
 
 def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@@ -196,17 +207,43 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_id": [0, 0, 0, 0, 0],
         }
     )
     hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+    item = hf_dataset[2]
+    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
     assert torch.equal(
         is_pad, torch.tensor([True, False, False, True, True])
     ), "Padding does not match expected values"
+
+
+def test_flatten_unflatten_dict():
+    d = {
+        "obs": {
+            "min": 0,
+            "max": 1,
+            "mean": 2,
+            "std": 3,
+        },
+        "action": {
+            "min": 4,
+            "max": 5,
+            "mean": 6,
+            "std": 7,
+        },
+    }
+
+    original_d = deepcopy(d)
+    d = unflatten_dict(flatten_dict(d))
+
+    # test equality between nested dicts
+    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
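For reference, given the flatten_dict added in this commit (default separator "/"), the nested dict in this test should round-trip through the following flat form:

```python
flatten_dict(d) == {
    "obs/min": 0, "obs/max": 1, "obs/mean": 2, "obs/std": 3,
    "action/min": 4, "action/max": 5, "action/mean": 6, "action/std": 7,
}
```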