WIP add load functions + episode_data_index

Cadene 2024-04-18 23:54:52 +00:00
parent 0bd2ca8d82
commit 64b09ea7a7
4 changed files with 159 additions and 40 deletions


@@ -17,9 +17,9 @@ import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.numpy import save_file
from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.datasets.utils import compute_stats, flatten_dict
def download_and_upload(root, revision, dataset_id):
@@ -98,7 +98,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
torch.save(stats, stats_pth_path)
# create and store meta_data
meta_data_dir = root / dataset_id / "train" / "meta_data"
meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True)
api = HfApi()
@@ -115,9 +115,8 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
)
# stats
for key in stats:
stats_path = meta_data_dir / f"stats_{key}.safetensors"
save_file(episode_data_index, stats_path)
stats_path = meta_data_dir / "stats.safetensors"
save_file(flatten_dict(stats), stats_path)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
@@ -126,7 +125,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
)
# episode_data_index
episode_data_index = {key: np.array(episode_data_index[key]) for key in episode_data_index}
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
save_file(episode_data_index, ep_data_idx_path)
api.upload_file(
@@ -139,7 +138,7 @@ def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id
# copy the first episode and the meta_data directory into the tests folder
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
shutil.copytree(meta_data_dir, f"tests/{meta_data_dir}")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
@@ -516,12 +515,12 @@ if __name__ == "__main__":
revision = "v1.1"
dataset_ids = [
# "pusht",
"pusht",
# "xarm_lift_medium",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
# "aloha_sim_transfer_cube_scripted",
]
for dataset_id in dataset_ids:
download_and_upload(root, revision, dataset_id)
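
The change above from safetensors.numpy to safetensors.torch, together with flatten_dict, reflects a constraint of the safetensors format: save_file only accepts a flat mapping from string keys to tensors, so the nested per-modality stats are flattened (with "/" separators) before serialization. A minimal round-trip sketch, assuming torch, safetensors, and the lerobot package at this commit are installed (the tensor values are toy data):

```python
import torch
from safetensors.torch import load_file, save_file

from lerobot.common.datasets.utils import flatten_dict, unflatten_dict

# toy nested stats, mimicking what compute_stats produces per modality
stats = {
    "action": {"mean": torch.tensor([0.5]), "std": torch.tensor([0.1])},
    "observation.state": {"mean": torch.tensor([0.0]), "std": torch.tensor([1.0])},
}

# safetensors only serializes a flat {str: torch.Tensor} dict, hence flatten_dict
save_file(flatten_dict(stats), "stats.safetensors")

# loading returns the flat mapping; unflatten_dict restores the nested layout
reloaded = unflatten_dict(load_file("stats.safetensors"))
assert torch.equal(reloaded["action"]["mean"], stats["action"]["mean"])
```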


@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class PushtDataset(torch.utils.data.Dataset):
@@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.episode_data_index["from"])
def __len__(self):
return self.num_samples
@@ -64,6 +65,7 @@ class PushtDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
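
In the dataset class above, the per-frame episode_data_index_from/episode_data_index_to columns are gone: episode boundaries now come from a standalone episode_data_index dictionary loaded from meta_data, num_episodes is derived from it, and it is passed explicitly to load_previous_and_future_frames. A toy sketch of that bookkeeping, assuming only the datasets and torch packages (the two-episode data below is made up):

```python
import torch
from datasets import Dataset

# made-up dataset with two episodes of 3 and 2 frames
hf_dataset = Dataset.from_dict(
    {
        "timestamp": [0.1, 0.2, 0.3, 0.1, 0.2],
        "index": [0, 1, 2, 3, 4],
        "episode_id": [0, 0, 0, 1, 1],
    }
).with_format("torch")

# "from" holds the first frame index of each episode, "to" one past its last frame
episode_data_index = {
    "from": torch.tensor([0, 3]),
    "to": torch.tensor([3, 5]),
}

num_episodes = len(episode_data_index["from"])  # 2, no scan over unique episode_id needed

# gather all frames of episode 1 from its index range
ep_id = 1
frames = [
    hf_dataset[i]
    for i in range(
        episode_data_index["from"][ep_id].item(),
        episode_data_index["to"][ep_id].item(),
    )
]
assert num_episodes == 2 and len(frames) == 2
```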


@@ -1,15 +1,93 @@
from copy import deepcopy
from math import ceil
from pathlib import Path
import datasets
import einops
import torch
import tqdm
from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
def flatten_dict(d, parent_key="", sep="/"):
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(Path(root) / dataset_id / split)
else:
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
return hf_dataset.with_format("torch")
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
return load_file(path)
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tol: float,
) -> dict[str, torch.Tensor]:
@@ -31,6 +109,8 @@ def load_previous_and_future_frames(
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated with dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
@@ -46,8 +126,9 @@ def load_previous_and_future_frames(
issues with timestamps during data collection.
"""
# get indices of the frames associated with the episode, and their timestamps
ep_data_id_from = item["episode_data_index_from"].item()
ep_data_id_to = item["episode_data_index_to"].item()
ep_id = item["episode_id"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
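
The three loaders above resolve their inputs the same way: when root is given they read from the local directory layout written by the upload script, otherwise they download from the f"lerobot/{dataset_id}" dataset repo on the Hub at the requested revision. A sketch of the Hub code path, assuming the lerobot package at this commit, network access, and that the meta_data files have already been pushed for "pusht":

```python
from lerobot.common.datasets.utils import load_episode_data_index, load_stats

dataset_id = "pusht"
version = "v1.1"  # revision used by the upload script above
root = None       # None -> fetch meta_data from the "lerobot/pusht" repo on the Hub

episode_data_index = load_episode_data_index(dataset_id, version, root)
stats = load_stats(dataset_id, version, root)

print("num episodes:", len(episode_data_index["from"]))
# stats comes back nested again (via unflatten_dict), e.g. for normalization:
# normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```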


@@ -1,10 +1,13 @@
import logging
from copy import deepcopy
import json
import os
from pathlib import Path
import einops
import pytest
import torch
from datasets import Dataset
import lerobot
@@ -13,6 +16,8 @@ from lerobot.common.datasets.utils import (
compute_stats,
get_stats_einops_patterns,
load_previous_and_future_frames,
flatten_dict,
unflatten_dict,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
@@ -160,15 +165,18 @@ def test_load_previous_and_future_frames_within_tolerance():
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_id": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
@@ -179,16 +187,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_id": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
item = hf_dataset[2]
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@@ -196,17 +207,43 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_id": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
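
For the outside-episode-range test, the expected values follow from the query timestamps: item 2 has timestamp 0.3, so the deltas yield queries [0.0, 0.06, 0.3, 0.56, 0.6]; each query is clamped to the nearest frame, and any query further than tol from every frame timestamp is flagged as padding. A plain-Python check of those assertions (it mirrors the nearest-frame-plus-tolerance comparison, not the exact implementation of load_previous_and_future_frames):

```python
deltas = [-0.3, -0.24, 0, 0.26, 0.3]
episode_timestamps = [0.1, 0.2, 0.3, 0.4, 0.5]
current_ts = 0.3  # timestamp of item 2
tol = 0.04

frames, padding = [], []
for delta in deltas:
    query = current_ts + delta
    # nearest frame to the query timestamp, and its distance
    dist, idx = min((abs(query - ts), i) for i, ts in enumerate(episode_timestamps))
    frames.append(idx)
    padding.append(dist > tol)  # too far from every frame -> clamped and padded

assert frames == [0, 0, 2, 4, 4]
assert padding == [True, False, False, True, True]
```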