Loads episode_data_index and stats during dataset __init__ (#85)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi 2024-04-23 14:13:25 +02:00 committed by GitHub
parent e2168163cd
commit 1030ea0070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
89 changed files with 1008 additions and 432 deletions

31
.github/poetry/cpu/poetry.lock generated vendored
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"
@ -522,21 +522,21 @@ toml = ["tomli"]
[[package]]
name = "datasets"
version = "2.18.0"
version = "2.19.0"
description = "HuggingFace community-driven open-source library of datasets"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
{file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
{file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
{file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
]
[package.dependencies]
aiohttp = "*"
dill = ">=0.3.0,<0.3.9"
filelock = "*"
fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
huggingface-hub = ">=0.19.4"
fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
huggingface-hub = ">=0.21.2"
multiprocess = "*"
numpy = ">=1.17"
packaging = "*"
@ -552,15 +552,15 @@ xxhash = "*"
apache-beam = ["apache-beam (>=2.26.0)"]
audio = ["librosa", "soundfile (>=0.12.1)"]
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
quality = ["ruff (>=0.3.0)"]
s3 = ["s3fs"]
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
tensorflow = ["tensorflow (>=2.6.0)"]
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
torch = ["torch"]
vision = ["Pillow (>=6.2.1)"]
@ -1524,7 +1524,6 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li
optional = true
python-versions = ">=3.6"
files = [
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
{file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
@ -1534,7 +1533,6 @@ files = [
{file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
{file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
{file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
{file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
@ -1544,7 +1542,6 @@ files = [
{file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
{file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
{file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
{file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
@ -1570,8 +1567,8 @@ files = [
{file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
{file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
{file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
{file = "lxml-5.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cfbac9f6149174f76df7e08c2e28b19d74aed90cad60383ad8671d3af7d0502f"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
@ -1579,7 +1576,6 @@ files = [
{file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
{file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
{file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
{file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
@ -2688,7 +2684,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -3919,4 +3914,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "bd9c506d2499d5e1e3b5e8b1a0f65df45c8feef38d89d0daeade56847fdb6a2e"
content-hash = "e526416d1282dea2550680b2be7fcf9ff6e1c67ac89d34c684b486d94a6addee"

View File

@ -53,7 +53,7 @@ pre-commit = {version = "^3.7.0", optional = true}
debugpy = {version = "^1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.18.0"
datasets = "^2.19.0"
[tool.poetry.extras]

View File

@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS
You will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.0",
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)

View File

@ -4,6 +4,7 @@ useless dependencies when using datasets.
"""
import io
import json
import pickle
import shutil
from pathlib import Path
@ -14,16 +15,20 @@ import numpy as np
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
def download_and_upload(root, root_tests, dataset_id):
def download_and_upload(root, revision, dataset_id):
if "pusht" in dataset_id:
download_and_upload_pusht(root, root_tests, dataset_id)
download_and_upload_pusht(root, revision, dataset_id)
elif "xarm" in dataset_id:
download_and_upload_xarm(root, root_tests, dataset_id)
download_and_upload_xarm(root, revision, dataset_id)
elif "aloha" in dataset_id:
download_and_upload_aloha(root, root_tests, dataset_id)
download_and_upload_aloha(root, revision, dataset_id)
else:
raise ValueError(dataset_id)
@ -56,7 +61,102 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False
def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
# push to main to indicate latest version
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
# push to version branch
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
# create and store meta_data
meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True)
api = HfApi()
# info
info_path = meta_data_dir / "info.json"
with open(str(info_path), "w") as f:
json.dump(info, f, indent=4)
api.upload_file(
path_or_fileobj=info_path,
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=info_path,
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# stats
stats_path = meta_data_dir / "stats.safetensors"
save_file(flatten_dict(stats), stats_path)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# episode_data_index
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
save_file(episode_data_index, ep_data_idx_path)
api.upload_file(
path_or_fileobj=ep_data_idx_path,
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=ep_data_idx_path,
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
f"tests/data/{dataset_id}/train"
)
if Path(f"tests/data/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@ -99,6 +199,7 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
actions = torch.from_numpy(dataset_dict["action"])
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
@ -151,8 +252,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos,
"action": actions[id_from:id_to],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
@ -160,28 +261,15 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.image": Image(),
@ -189,35 +277,35 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
"next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
raw_dir = root / "xarm_datasets_raw"
if not raw_dir.exists():
import zipfile
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
@ -234,13 +322,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
id_to = 0
episode_id = 0
total_frames = dataset_dict["actions"].shape[0]
for i in tqdm.tqdm(range(total_frames)):
id_to += 1
@ -264,35 +352,23 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state,
"action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from = id_to
episode_id += 1
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.image": Image(),
@ -300,27 +376,26 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
folder_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
@ -381,6 +456,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
@ -408,40 +484,26 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
{
"observation.state": state,
"action": action,
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward,
"next.done": done,
# "next.success": success,
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
)
assert isinstance(ep_id, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
data_dict = {}
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.images.top": Image(),
@ -449,39 +511,39 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
if __name__ == "__main__":
root = "data"
root_tests = "tests/data"
revision = "v1.1"
dataset_ids = [
# "pusht",
# "xarm_lift_medium",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
"pusht",
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
for dataset_id in dataset_ids:
download_and_upload(root, root_tests, dataset_id)
# assume stats have been precomputed
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
download_and_upload(root, revision, dataset_id)

View File

@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
This script supports several Hugging Face datasets, among which:
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
To try a different Hugging Face dataset, you can replace this line:
```python
@ -22,12 +25,16 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
by one of these:
```python
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
```
"""
# TODO(rcadene): remove this example file of using hf_dataset
from pathlib import Path
@ -37,19 +44,22 @@ from datasets import load_dataset
# TODO(rcadene): list available datasets on lerobot page using `datasets`
# download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", revision="v1.0", split="train"), 10
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
# display name of dataset and its features
# TODO(rcadene): update to make the print pretty
print(f"{hf_dataset=}")
print(f"{hf_dataset.features=}")
# display useful statistics about frames and episodes, which are sequences of frames from the same video
print(f"number of frames: {len(hf_dataset)=}")
print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
print(
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
)
# select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# load all frames of episode 5 in RAM in PIL format
frames = hf_dataset["observation.image"]

View File

@ -18,7 +18,10 @@ dataset = PushtDataset()
```
by one of these:
```python
dataset = XarmDataset()
dataset = XarmDataset("xarm_lift_medium")
dataset = XarmDataset("xarm_lift_medium_replay")
dataset = XarmDataset("xarm_push_medium")
dataset = XarmDataset("xarm_push_medium_replay")
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
@ -44,6 +47,7 @@ from lerobot.common.datasets.pusht import PushtDataset
dataset = PushtDataset()
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
# TODO(rcadene): update to make the print pretty
print(f"{dataset=}")
print(f"{dataset.hf_dataset=}")
@ -55,13 +59,16 @@ print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
# TODO(rcadene): remove this example of accessing hf_dataset
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
frames = [sample["observation.image"] for sample in dataset]
# but frames are now channel first to follow pytorch convention,
# to view them, we convert to channel last
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
# to view them, we convert to uint8 range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c)
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video

View File

@ -50,7 +50,12 @@ available_datasets = {
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
"xarm": ["xarm_lift_medium"],
"xarm": [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
],
}
available_policies = [

View File

@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class AlohaDataset(torch.utils.data.Dataset):
@ -27,7 +31,7 @@ class AlohaDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@ -54,7 +55,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@ -1,12 +1,10 @@
import logging
import os
from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
from lerobot.common.transforms import NormalizeTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@ -52,32 +50,18 @@ def make_dataset(
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load stats if the file exists already or compute stats and save it
if DATA_DIR is None:
# TODO(rcadene): clean stats
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
else:
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
# Create a dataset for stats computation.
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset)
precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(stats, precomputed_stats_path)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
NormalizeTransform(
stats,
in_keys=[

View File

@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class PushtDataset(torch.utils.data.Dataset):
@ -25,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str = "pusht",
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.episode_data_index["from"])
def __len__(self):
return self.num_samples
@ -64,19 +65,11 @@ class PushtDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@ -1,15 +1,121 @@
from copy import deepcopy
from math import ceil
from pathlib import Path
import datasets
import einops
import torch
import tqdm
from datasets import Image, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
"""
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
def hf_transform_to_torch(items_dict):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / dataset_id / split))
else:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
return load_file(path)
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tol: float,
) -> dict[torch.Tensor]:
@ -31,6 +137,8 @@ def load_previous_and_future_frames(
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
@ -46,12 +154,14 @@ def load_previous_and_future_frames(
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_data_id_from = item["episode_data_index_from"].item()
ep_data_id_to = item["episode_data_index_to"].item()
ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
@ -82,39 +192,57 @@ def load_previous_and_future_frames(
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
def get_stats_einops_patterns(dataset):
"""These einops patterns will be used to aggregate batches and compute statistics."""
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images of `hf_dataset` are in channel first format
"""
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=0,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in hf_dataset.features.items():
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, Image):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
def compute_stats(dataset, batch_size=32, max_num_samples=None):
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
max_num_samples = len(hf_dataset)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
stats_patterns = get_stats_einops_patterns(hf_dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
@ -124,10 +252,24 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(hf_dataset, batch_size, seed):
generator = torch.Generator()
generator.manual_seed(seed)
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
drop_last=False,
generator=generator,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
@ -153,6 +295,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):

View File

@ -1,25 +1,37 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = ["xarm_lift_medium"]
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str = "xarm_lift_medium",
version: str | None = "v1.0",
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@ -32,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@ -46,7 +55,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
@ -58,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@ -39,4 +39,5 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
for _ in range(num_parallel_envs)
]
)
return env

View File

@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# convert to (b c h w) torch format
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w")
img = img.type(torch.float32)
img /= 255
obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"

View File

@ -1,4 +1,3 @@
import torch
from torchvision.transforms.v2 import Compose, Transform
@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform):
return item
class Prod(Transform):
invertible = True
def __init__(self, in_keys: list[str], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def forward(self, item):
for key in self.in_keys:
if key not in item:
continue
self.original_dtypes[key] = item[key].dtype
item[key] = item[key].type(torch.float32) * self.prod
return item
def inverse_transform(self, item):
for key in self.in_keys:
if key not in item:
continue
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
return item
# def transform_observation_spec(self, obs_spec):
# for key in self.in_keys:
# if obs_spec.get(key, None) is None:
# continue
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
# obs_spec[key].dtype = torch.float32
# return obs_spec
class NormalizeTransform(Transform):
invertible = True

View File

@ -47,6 +47,7 @@ from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.logger import log_output_dir
@ -208,11 +209,12 @@ def eval_policy(
max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist())
# similar logic is implemented in dataset preprocessing
# similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
num_episodes = dones.shape[0]
total_frames = 0
idx_from = 0
id_from = 0
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
total_frames += num_frames
@ -222,19 +224,20 @@ def eval_policy(
if return_episode_data:
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
idx_from += num_frames
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
# similar logic is implemented in dataset preprocessing
if return_episode_data:
@ -247,14 +250,29 @@ def eval_policy(
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
for img in ep_dict[key]:
# sanity check that images are channel first
c, h, w = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are float32 in range [0,1]
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
# from float32 in range [0,1] to uint8 in range [0,255]
img *= 255
img = img.type(torch.uint8)
# convert to channel last and numpy as expected by PIL
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
hf_dataset = Dataset.from_dict(data_dict)
hf_dataset.set_transform(hf_transform_to_torch)
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@ -307,7 +325,10 @@ def eval_policy(
},
}
if return_episode_data:
info["episodes"] = hf_dataset
info["episodes"] = {
"hf_dataset": hf_dataset,
"episode_data_index": episode_data_index,
}
if max_episodes_rendered > 0:
info["videos"] = videos
return info

View File

@ -136,6 +136,7 @@ def add_episodes_inplace(
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
@ -151,13 +152,15 @@ def add_episodes_inplace(
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
@ -167,21 +170,22 @@ def add_episodes_inplace(
online_dataset.hf_dataset = hf_dataset
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example["episode_index"] += start_episode
example["index"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices)
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
seed=cfg.seed,
)
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5),
)
for _ in range(cfg.policy.utd):

View File

@ -22,11 +22,24 @@ def visualize_dataset_cli(cfg: dict):
def cat_and_write_video(video_path, frames, fps):
# Expects images in [0, 255].
frames = torch.cat(frames)
assert frames.dtype == torch.uint8
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
imageio.mimsave(video_path, frames, fps=fps)
# Expects images in [0, 1].
frame = frames[0]
if frame.ndim == 4:
raise NotImplementedError("We currently dont support multiple timestamps.")
c, h, w = frame.shape
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
# sanity check that images are float32 in range [0,1]
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
# convert to channel last uint8 [0, 255]
frames = einops.rearrange(frames, "b c h w -> b h w c")
frames = (frames * 255).type(torch.uint8)
imageio.mimsave(video_path, frames.numpy(), fps=fps)
def visualize_dataset(cfg: dict, out_dir=None):
@ -44,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
)
logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
for video_path in video_paths:
logging.info(video_path)
return video_paths
def render_dataset(dataset, out_dir, max_num_episodes):
@ -77,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
# add current frame to list of frames to render
frames[im_key].append(item[im_key])
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:

25
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"
@ -522,21 +522,21 @@ toml = ["tomli"]
[[package]]
name = "datasets"
version = "2.18.0"
version = "2.19.0"
description = "HuggingFace community-driven open-source library of datasets"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
{file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
{file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
{file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
]
[package.dependencies]
aiohttp = "*"
dill = ">=0.3.0,<0.3.9"
filelock = "*"
fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
huggingface-hub = ">=0.19.4"
fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
huggingface-hub = ">=0.21.2"
multiprocess = "*"
numpy = ">=1.17"
packaging = "*"
@ -552,15 +552,15 @@ xxhash = "*"
apache-beam = ["apache-beam (>=2.26.0)"]
audio = ["librosa", "soundfile (>=0.12.1)"]
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
quality = ["ruff (>=0.3.0)"]
s3 = ["s3fs"]
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
tensorflow = ["tensorflow (>=2.6.0)"]
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
torch = ["torch"]
vision = ["Pillow (>=6.2.1)"]
@ -2909,7 +2909,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -4195,4 +4194,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "01ad4eb04061ec9f785d4574bf66d3e5cb4549e2ea11ab175895f94cb62c1f1c"
content-hash = "7f5afa48aead953f598e686e767891d3d23f2862b80144f76dc064101ef80b4a"

View File

@ -53,7 +53,8 @@ pre-commit = {version = "^3.7.0", optional = true}
debugpy = {version = "^1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.18.0"
datasets = "^2.19.0"
[tool.poetry.extras]
pusht = ["gym-pusht"]

View File

@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "d79cf82ffc86f110",
"_fingerprint": "22eeca7a3f4725ee",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "d8e4a817b5449498",
"_fingerprint": "97c28d4ad1536e4c",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f03482befa767127",
"_fingerprint": "cb9349b5c92951e8",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "93e03c6320c7d56e",
"_fingerprint": "e4d7ad2b360db1af",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@ -0,0 +1,3 @@
{
"fps": 10
}

Binary file not shown.

View File

@ -21,11 +21,11 @@
"length": 2,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@ -45,14 +45,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@ -0,0 +1,3 @@
{
"fps": 10
}

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "21bb9a76ed78a475",
"_fingerprint": "a04a9ce660122e23",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@ -0,0 +1,3 @@
{
"fps": 15
}

View File

@ -21,11 +21,11 @@
"length": 4,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@ -41,14 +41,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "a95cbec45e3bb9d6",
"_fingerprint": "cc6afdfcdd6f63ab",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@ -0,0 +1,3 @@
{
"fps": 15
}

View File

@ -0,0 +1,51 @@
{
"citation": "",
"description": "",
"features": {
"observation.image": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.reward": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "9f8e1a8c1845df55",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@ -0,0 +1,3 @@
{
"fps": 15
}

View File

@ -0,0 +1,51 @@
{
"citation": "",
"description": "",
"features": {
"observation.image": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 3,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.reward": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "c900258061dd0b3f",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@ -0,0 +1,3 @@
{
"fps": 15
}

View File

@ -0,0 +1,51 @@
{
"citation": "",
"description": "",
"features": {
"observation.image": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 3,
"_type": "Sequence"
},
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.reward": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "e51c80a33c7688c0",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@ -0,0 +1,71 @@
"""
This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the
dataset into a corresponding safetensors file in a specified output directory.
If you know that your change will break backward compatibility, you should write a shortlived test by modifying
`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test pass. Your custom test
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
Example usage:
`python tests/script/save_dataset_to_safetensors.py`
"""
import shutil
from pathlib import Path
from safetensors.torch import save_file
from lerobot.common.datasets.pusht import PushtDataset
def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
data_dir = Path(output_dir) / dataset_id
if data_dir.exists():
shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
dataset = PushtDataset(
dataset_id=dataset_id,
split="train",
)
# save 2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
# save 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors")
save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors")
# TODO(rcadene): Enable testing on second and last episode
# We currently cant because our test dataset only contains the first episode
# # save 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
# save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors")
# # save 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
# # save 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
if __name__ == "__main__":
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors")

View File

@ -1,20 +1,26 @@
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
import einops
import pytest
import torch
from datasets import Dataset
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.utils import (
compute_stats,
flatten_dict,
get_stats_einops_patterns,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE
@ -39,8 +45,8 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required = [
("action", 1, True),
("episode_id", 0, True),
("frame_id", 0, True),
("episode_index", 0, True),
("frame_index", 0, True),
("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos?
("observation.state", 1, True),
@ -48,12 +54,6 @@ def test_factory(env_name, dataset_id, policy_name):
("next.done", 0, False),
]
for key in image_keys:
keys_ndim_required.append(
(key, 3, True),
)
assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required:
if key not in item:
@ -94,26 +94,21 @@ def test_compute_stats_on_xarm():
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
# TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
from lerobot.common.datasets.xarm import XarmDataset
data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=data_dir,
transform=transform,
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
@ -122,18 +117,19 @@ def test_compute_stats_on_xarm():
batch_size=len(dataset),
shuffle=False,
)
hf_dataset = next(iter(dataloader))
full_batch = next(iter(dataloader))
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
full_batch[k] = full_batch[k].float()
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
# test computed stats match expected stats
for k in stats_patterns:
@ -142,11 +138,10 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# TODO(rcadene): check that the stats used for training are correct too
# # load stats that are expected to match the ones returned by computed_stats
# assert (dataset.data_dir / "stats.pth").exists()
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
@ -160,15 +155,18 @@ def test_load_previous_and_future_frames_within_tolerance():
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
@ -179,16 +177,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
item = hf_dataset[2]
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@ -196,17 +197,102 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
def test_backward_compatibility():
"""This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
dataset_id = "pusht"
data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
dataset = PushtDataset(
dataset_id=dataset_id,
split="train",
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)
def load_and_compare(i):
new_frame = dataset[i]
old_frame = load_file(data_dir / f"frame_{i}.safetensors")
new_keys = set(new_frame.keys())
old_keys = set(old_frame.keys())
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
for key in new_frame:
assert (
new_frame[key] == old_frame[key]
).all(), f"{key=} for index={i} does not contain the same value"
# test2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
load_and_compare(i)
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
load_and_compare(i)
load_and_compare(i + 1)
# test 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
load_and_compare(i - 2)
load_and_compare(i - 1)
# TODO(rcadene): Enable testing on second and last episode
# We currently cant because our test dataset only contains the first episode
# # test 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# load_and_compare(i)
# load_and_compare(i+1)
# #test 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# load_and_compare(i-2)
# load_and_compare(i-1)
# # test 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i-2)
# load_and_compare(i-1)

View File

@ -1,3 +1,4 @@
# TODO(aliberts): Mute logging for these tests
import subprocess
from pathlib import Path

View File

@ -0,0 +1,31 @@
import pytest
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.visualize_dataset import visualize_dataset
from .utils import DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"dataset_id",
[
"aloha_sim_insertion_human",
],
)
def test_visualize_dataset(tmpdir, dataset_id):
# TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
# doesnt support multiple timesteps which requires delta_timestamps to None for images.
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
"policy=act",
"env=aloha",
f"dataset_id={dataset_id}",
],
)
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
assert len(video_paths) > 0
for video_path in video_paths:
assert video_path.exists()