Loads episode_data_index and stats during dataset __init__ (#85)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
e2168163cd
commit
1030ea0070
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
|
@ -522,21 +522,21 @@ toml = ["tomli"]
|
|||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "2.18.0"
|
||||
version = "2.19.0"
|
||||
description = "HuggingFace community-driven open-source library of datasets"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
|
||||
{file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
|
||||
{file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
|
||||
{file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = "*"
|
||||
dill = ">=0.3.0,<0.3.9"
|
||||
filelock = "*"
|
||||
fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
|
||||
huggingface-hub = ">=0.19.4"
|
||||
fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
|
||||
huggingface-hub = ">=0.21.2"
|
||||
multiprocess = "*"
|
||||
numpy = ">=1.17"
|
||||
packaging = "*"
|
||||
|
@ -552,15 +552,15 @@ xxhash = "*"
|
|||
apache-beam = ["apache-beam (>=2.26.0)"]
|
||||
audio = ["librosa", "soundfile (>=0.12.1)"]
|
||||
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
|
||||
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
|
||||
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
|
||||
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
|
||||
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
|
||||
quality = ["ruff (>=0.3.0)"]
|
||||
s3 = ["s3fs"]
|
||||
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
|
||||
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
||||
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
tensorflow = ["tensorflow (>=2.6.0)"]
|
||||
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
|
||||
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
torch = ["torch"]
|
||||
vision = ["Pillow (>=6.2.1)"]
|
||||
|
||||
|
@ -1524,7 +1524,6 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li
|
|||
optional = true
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
|
||||
|
@ -1534,7 +1533,6 @@ files = [
|
|||
{file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
|
||||
|
@ -1544,7 +1542,6 @@ files = [
|
|||
{file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
|
||||
|
@ -1570,8 +1567,8 @@ files = [
|
|||
{file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cfbac9f6149174f76df7e08c2e28b19d74aed90cad60383ad8671d3af7d0502f"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
|
||||
|
@ -1579,7 +1576,6 @@ files = [
|
|||
{file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
|
||||
|
@ -2688,7 +2684,6 @@ files = [
|
|||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
|
@ -3919,4 +3914,4 @@ xarm = ["gym-xarm"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "bd9c506d2499d5e1e3b5e8b1a0f65df45c8feef38d89d0daeade56847fdb6a2e"
|
||||
content-hash = "e526416d1282dea2550680b2be7fcf9ff6e1c67ac89d34c684b486d94a6addee"
|
||||
|
|
|
@ -53,7 +53,7 @@ pre-commit = {version = "^3.7.0", optional = true}
|
|||
debugpy = {version = "^1.8.1", optional = true}
|
||||
pytest = {version = "^8.1.0", optional = true}
|
||||
pytest-cov = {version = "^5.0.0", optional = true}
|
||||
datasets = "^2.18.0"
|
||||
datasets = "^2.19.0"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|
|
@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS
|
|||
|
||||
You will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.0",
|
||||
version: str | None = "v1.1",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ useless dependencies when using datasets.
|
|||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import pickle
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
@ -14,16 +15,20 @@ import numpy as np
|
|||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
|
||||
|
||||
|
||||
def download_and_upload(root, root_tests, dataset_id):
|
||||
def download_and_upload(root, revision, dataset_id):
|
||||
if "pusht" in dataset_id:
|
||||
download_and_upload_pusht(root, root_tests, dataset_id)
|
||||
download_and_upload_pusht(root, revision, dataset_id)
|
||||
elif "xarm" in dataset_id:
|
||||
download_and_upload_xarm(root, root_tests, dataset_id)
|
||||
download_and_upload_xarm(root, revision, dataset_id)
|
||||
elif "aloha" in dataset_id:
|
||||
download_and_upload_aloha(root, root_tests, dataset_id)
|
||||
download_and_upload_aloha(root, revision, dataset_id)
|
||||
else:
|
||||
raise ValueError(dataset_id)
|
||||
|
||||
|
@ -56,7 +61,102 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
||||
def concatenate_episodes(ep_dicts):
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
|
||||
# push to main to indicate latest version
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
|
||||
# push to version branch
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
|
||||
|
||||
# create and store meta_data
|
||||
meta_data_dir = root / dataset_id / "meta_data"
|
||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# info
|
||||
info_path = meta_data_dir / "info.json"
|
||||
with open(str(info_path), "w") as f:
|
||||
json.dump(info, f, indent=4)
|
||||
api.upload_file(
|
||||
path_or_fileobj=info_path,
|
||||
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
|
||||
repo_id=f"lerobot/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
)
|
||||
api.upload_file(
|
||||
path_or_fileobj=info_path,
|
||||
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
|
||||
repo_id=f"lerobot/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# stats
|
||||
stats_path = meta_data_dir / "stats.safetensors"
|
||||
save_file(flatten_dict(stats), stats_path)
|
||||
api.upload_file(
|
||||
path_or_fileobj=stats_path,
|
||||
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
|
||||
repo_id=f"lerobot/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
)
|
||||
api.upload_file(
|
||||
path_or_fileobj=stats_path,
|
||||
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
|
||||
repo_id=f"lerobot/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# episode_data_index
|
||||
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
||||
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
||||
save_file(episode_data_index, ep_data_idx_path)
|
||||
api.upload_file(
|
||||
path_or_fileobj=ep_data_idx_path,
|
||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
|
||||
repo_id=f"lerobot/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
)
|
||||
api.upload_file(
|
||||
path_or_fileobj=ep_data_idx_path,
|
||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
|
||||
repo_id=f"lerobot/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# copy in tests folder, the first episode and the meta_data directory
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
||||
f"tests/data/{dataset_id}/train"
|
||||
)
|
||||
if Path(f"tests/data/{dataset_id}/meta_data").exists():
|
||||
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
|
||||
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
|
||||
|
||||
|
||||
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
|
@ -99,6 +199,7 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
|||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
|
@ -151,8 +252,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
|||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[id_from:id_to],
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
|
@ -160,28 +261,15 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
|||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
features = {
|
||||
"observation.image": Image(),
|
||||
|
@ -189,35 +277,35 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
|||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
"next.success": Value(dtype="bool", id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
info = {
|
||||
"fps": fps,
|
||||
}
|
||||
stats = compute_stats(hf_dataset)
|
||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||
|
||||
|
||||
def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
||||
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||
root = Path(root)
|
||||
raw_dir = root / f"{dataset_id}_raw"
|
||||
raw_dir = root / "xarm_datasets_raw"
|
||||
if not raw_dir.exists():
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
|
@ -234,13 +322,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
|||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
id_to = 0
|
||||
episode_id = 0
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
id_to += 1
|
||||
|
||||
|
@ -264,35 +352,23 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
|||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from = id_to
|
||||
episode_id += 1
|
||||
|
||||
data_dict = {}
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
features = {
|
||||
"observation.image": Image(),
|
||||
|
@ -300,27 +376,26 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
|||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
info = {
|
||||
"fps": fps,
|
||||
}
|
||||
stats = compute_stats(hf_dataset)
|
||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||
|
||||
|
||||
def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
||||
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||
folder_urls = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
|
@ -381,6 +456,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
|||
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
|
||||
|
@ -408,40 +484,26 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
|||
{
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
|
||||
data_dict = {}
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
features = {
|
||||
"observation.images.top": Image(),
|
||||
|
@ -449,39 +511,39 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
|||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
#'next.reward': Value(dtype='float32', id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
info = {
|
||||
"fps": fps,
|
||||
}
|
||||
stats = compute_stats(hf_dataset)
|
||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
root = "data"
|
||||
root_tests = "tests/data"
|
||||
revision = "v1.1"
|
||||
|
||||
dataset_ids = [
|
||||
# "pusht",
|
||||
# "xarm_lift_medium",
|
||||
# "aloha_sim_insertion_human",
|
||||
# "aloha_sim_insertion_scripted",
|
||||
# "aloha_sim_transfer_cube_human",
|
||||
"pusht",
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
for dataset_id in dataset_ids:
|
||||
download_and_upload(root, root_tests, dataset_id)
|
||||
# assume stats have been precomputed
|
||||
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
|
||||
download_and_upload(root, revision, dataset_id)
|
||||
|
|
|
@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
|
|||
This script supports several Hugging Face datasets, among which:
|
||||
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
|
||||
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
|
||||
3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
|
||||
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
|
||||
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
|
||||
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
|
||||
To try a different Hugging Face dataset, you can replace this line:
|
||||
```python
|
||||
|
@ -22,12 +25,16 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
|||
by one of these:
|
||||
```python
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
|
||||
```
|
||||
"""
|
||||
# TODO(rcadene): remove this example file of using hf_dataset
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -37,19 +44,22 @@ from datasets import load_dataset
|
|||
# TODO(rcadene): list available datasets on lerobot page using `datasets`
|
||||
|
||||
# download/load hugging face dataset in pyarrow format
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", revision="v1.0", split="train"), 10
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
||||
|
||||
# display name of dataset and its features
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{hf_dataset=}")
|
||||
print(f"{hf_dataset.features=}")
|
||||
|
||||
# display useful statistics about frames and episodes, which are sequences of frames from the same video
|
||||
print(f"number of frames: {len(hf_dataset)=}")
|
||||
print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
|
||||
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
|
||||
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
|
||||
print(
|
||||
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
|
||||
)
|
||||
|
||||
# select the frames belonging to episode number 5
|
||||
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
|
||||
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
# load all frames of episode 5 in RAM in PIL format
|
||||
frames = hf_dataset["observation.image"]
|
||||
|
|
|
@ -18,7 +18,10 @@ dataset = PushtDataset()
|
|||
```
|
||||
by one of these:
|
||||
```python
|
||||
dataset = XarmDataset()
|
||||
dataset = XarmDataset("xarm_lift_medium")
|
||||
dataset = XarmDataset("xarm_lift_medium_replay")
|
||||
dataset = XarmDataset("xarm_push_medium")
|
||||
dataset = XarmDataset("xarm_push_medium_replay")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_human")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_scripted")
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
|
||||
|
@ -44,6 +47,7 @@ from lerobot.common.datasets.pusht import PushtDataset
|
|||
dataset = PushtDataset()
|
||||
|
||||
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{dataset=}")
|
||||
print(f"{dataset.hf_dataset=}")
|
||||
|
||||
|
@ -55,13 +59,16 @@ print(f"frames per second used during data collection: {dataset.fps=}")
|
|||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
|
||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [sample["observation.image"] for sample in dataset]
|
||||
|
||||
# but frames are now channel first to follow pytorch convention,
|
||||
# to view them, we convert to channel last
|
||||
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
|
||||
# to view them, we convert to uint8 range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c)
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
||||
# and finally save them to a mp4 video
|
||||
|
|
|
@ -50,7 +50,12 @@ available_datasets = {
|
|||
"aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"xarm": ["xarm_lift_medium"],
|
||||
"xarm": [
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
],
|
||||
}
|
||||
|
||||
available_policies = [
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class AlohaDataset(torch.utils.data.Dataset):
|
||||
|
@ -27,7 +31,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
|
@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
|
@ -54,7 +55,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
from lerobot.common.transforms import NormalizeTransform
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
@ -52,32 +50,18 @@ def make_dataset(
|
|||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
elif stats_path is None:
|
||||
# load stats if the file exists already or compute stats and save it
|
||||
if DATA_DIR is None:
|
||||
# TODO(rcadene): clean stats
|
||||
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
||||
else:
|
||||
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
|
||||
if precomputed_stats_path.exists():
|
||||
stats = torch.load(precomputed_stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
||||
# Create a dataset for stats computation.
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_stats(stats_dataset)
|
||||
precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(stats, precomputed_stats_path)
|
||||
# load a first dataset to access precomputed stats
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
)
|
||||
stats = stats_dataset.stats
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
|
@ -25,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
|||
def __init__(
|
||||
self,
|
||||
dataset_id: str = "pusht",
|
||||
version: str | None = "v1.0",
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
|
@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
|
|||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
|
@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
|||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
return len(self.episode_data_index["from"])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
@ -64,19 +65,11 @@ class PushtDataset(torch.utils.data.Dataset):
|
|||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
|
|
@ -1,15 +1,121 @@
|
|||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Image, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
||||
For example:
|
||||
```
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
|
||||
>>> print(flatten_dict(dct))
|
||||
{"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
|
||||
|
||||
def unflatten_dict(d, sep="/"):
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
d = outdict
|
||||
for part in parts[:-1]:
|
||||
if part not in d:
|
||||
d[part] = {}
|
||||
d = d[part]
|
||||
d[parts[-1]] = value
|
||||
return outdict
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||
with channel first (c h w) of float32 type in range [0,1].
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
else:
|
||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / dataset_id / split))
|
||||
else:
|
||||
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
|
||||
repo_id = f"lerobot/{dataset_id}"
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
|
||||
"""episode_data_index contains the range of indices for each episode
|
||||
|
||||
Example:
|
||||
```python
|
||||
from_id = episode_data_index["from"][episode_id].item()
|
||||
to_id = episode_data_index["to"][episode_id].item()
|
||||
episode_frames = [dataset[i] for i in range(from_id, to_id)]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
repo_id = f"lerobot/{dataset_id}"
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
|
||||
)
|
||||
|
||||
return load_file(path)
|
||||
|
||||
|
||||
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
||||
|
||||
Example:
|
||||
```python
|
||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
repo_id = f"lerobot/{dataset_id}"
|
||||
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float,
|
||||
) -> dict[torch.Tensor]:
|
||||
|
@ -31,6 +137,8 @@ def load_previous_and_future_frames(
|
|||
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
|
||||
modality (e.g., "timestamp", "observation.image", "action").
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
||||
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
|
||||
|
@ -46,12 +154,14 @@ def load_previous_and_future_frames(
|
|||
issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_id_from = item["episode_data_index_from"].item()
|
||||
ep_data_id_to = item["episode_data_index_to"].item()
|
||||
ep_id = item["episode_index"].item()
|
||||
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
||||
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
ep_timestamps = torch.stack(ep_timestamps)
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
|
@ -82,39 +192,57 @@ def load_previous_and_future_frames(
|
|||
|
||||
# load frames modality
|
||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
||||
item[key] = torch.stack(item[key])
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics."""
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
def get_stats_einops_patterns(hf_dataset):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||
|
||||
Note: We assume the images of `hf_dataset` are in channel first format
|
||||
"""
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
hf_dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
for key, feats_type in hf_dataset.features.items():
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
if isinstance(feats_type, Image):
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
stats_patterns[key] = "b c -> c "
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
else:
|
||||
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
||||
|
||||
return stats_patterns
|
||||
|
||||
|
||||
def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
else:
|
||||
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
|
||||
max_num_samples = len(hf_dataset)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
# pin_memory=cfg.device != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
stats_patterns = get_stats_einops_patterns(hf_dataset)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
|
@ -124,10 +252,24 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
|||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(hf_dataset, batch_size, seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
hf_dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
|
@ -153,6 +295,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
|||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
|
|
|
@ -1,25 +1,37 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class XarmDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
|
||||
https://huggingface.co/datasets/lerobot/xarm_push_medium
|
||||
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
|
||||
"""
|
||||
|
||||
# Copied from lerobot/__init__.py
|
||||
available_datasets = ["xarm_lift_medium"]
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
]
|
||||
fps = 15
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str = "xarm_lift_medium",
|
||||
version: str | None = "v1.0",
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
|
@ -32,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset):
|
|||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
|
@ -46,7 +55,7 @@ class XarmDataset(torch.utils.data.Dataset):
|
|||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
@ -58,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset):
|
|||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
|
|
@ -39,4 +39,5 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
|||
for _ in range(num_parallel_envs)
|
||||
]
|
||||
)
|
||||
|
||||
return env
|
||||
|
|
|
@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):
|
|||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
# convert to (b c h w) torch format
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
obs[imgkey] = img
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import torch
|
||||
from torchvision.transforms.v2 import Compose, Transform
|
||||
|
||||
|
||||
|
@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform):
|
|||
return item
|
||||
|
||||
|
||||
class Prod(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(self, in_keys: list[str], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
self.original_dtypes = {}
|
||||
|
||||
def forward(self, item):
|
||||
for key in self.in_keys:
|
||||
if key not in item:
|
||||
continue
|
||||
self.original_dtypes[key] = item[key].dtype
|
||||
item[key] = item[key].type(torch.float32) * self.prod
|
||||
return item
|
||||
|
||||
def inverse_transform(self, item):
|
||||
for key in self.in_keys:
|
||||
if key not in item:
|
||||
continue
|
||||
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
|
||||
return item
|
||||
|
||||
# def transform_observation_spec(self, obs_spec):
|
||||
# for key in self.in_keys:
|
||||
# if obs_spec.get(key, None) is None:
|
||||
# continue
|
||||
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
# obs_spec[key].dtype = torch.float32
|
||||
# return obs_spec
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
invertible = True
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ from PIL import Image as PILImage
|
|||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
|
@ -208,11 +209,12 @@ def eval_policy(
|
|||
max_rewards.extend(batch_max_reward.tolist())
|
||||
all_successes.extend(batch_success.tolist())
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
# similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
num_episodes = dones.shape[0]
|
||||
total_frames = 0
|
||||
idx_from = 0
|
||||
id_from = 0
|
||||
for ep_id in range(num_episodes):
|
||||
num_frames = done_indices[ep_id].item() + 1
|
||||
total_frames += num_frames
|
||||
|
@ -222,19 +224,20 @@ def eval_policy(
|
|||
if return_episode_data:
|
||||
ep_dict = {
|
||||
"action": actions[ep_id, :num_frames],
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
idx_from += num_frames
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
if return_episode_data:
|
||||
|
@ -247,14 +250,29 @@ def eval_policy(
|
|||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
# c h w -> h w c
|
||||
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
|
||||
for img in ep_dict[key]:
|
||||
# sanity check that images are channel first
|
||||
c, h, w = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
|
||||
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
|
||||
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
|
||||
|
||||
# from float32 in range [0,1] to uint8 in range [0,255]
|
||||
img *= 255
|
||||
img = img.type(torch.uint8)
|
||||
|
||||
# convert to channel last and numpy as expected by PIL
|
||||
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
|
||||
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
|
||||
hf_dataset = Dataset.from_dict(data_dict)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
|
@ -307,7 +325,10 @@ def eval_policy(
|
|||
},
|
||||
}
|
||||
if return_episode_data:
|
||||
info["episodes"] = hf_dataset
|
||||
info["episodes"] = {
|
||||
"hf_dataset": hf_dataset,
|
||||
"episode_data_index": episode_data_index,
|
||||
}
|
||||
if max_episodes_rendered > 0:
|
||||
info["videos"] = videos
|
||||
return info
|
||||
|
|
|
@ -136,6 +136,7 @@ def add_episodes_inplace(
|
|||
concat_dataset: torch.utils.data.ConcatDataset,
|
||||
sampler: torch.utils.data.WeightedRandomSampler,
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
pc_online_samples: float,
|
||||
):
|
||||
"""
|
||||
|
@ -151,13 +152,15 @@ def add_episodes_inplace(
|
|||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
||||
reflect changes in the dataset sizes and specified sampling weights.
|
||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- pc_online_samples (float): The target percentage of samples that should come from
|
||||
the online dataset during sampling operations.
|
||||
|
||||
Raises:
|
||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
||||
"""
|
||||
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
|
||||
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
||||
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
||||
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
||||
|
@ -167,21 +170,22 @@ def add_episodes_inplace(
|
|||
online_dataset.hf_dataset = hf_dataset
|
||||
else:
|
||||
# find episode index and data frame indices according to previous episode in online_dataset
|
||||
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
|
||||
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
|
||||
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
||||
|
||||
def shift_indices(example):
|
||||
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_id"] += start_episode
|
||||
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_index"] += start_episode
|
||||
example["index"] += start_index
|
||||
example["episode_data_index_from"] += start_index
|
||||
example["episode_data_index_to"] += start_index
|
||||
return example
|
||||
|
||||
disable_progress_bars() # map has a tqdm progress bar
|
||||
hf_dataset = hf_dataset.map(shift_indices)
|
||||
enable_progress_bars()
|
||||
|
||||
episode_data_index["from"] += start_index
|
||||
episode_data_index["to"] += start_index
|
||||
|
||||
# extend online dataset
|
||||
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
|
||||
|
||||
|
@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
seed=cfg.seed,
|
||||
)
|
||||
|
||||
online_pc_sampling = cfg.get("demo_schedule", 0.5)
|
||||
add_episodes_inplace(
|
||||
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
|
||||
online_dataset,
|
||||
concat_dataset,
|
||||
sampler,
|
||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
||||
pc_online_samples=cfg.get("demo_schedule", 0.5),
|
||||
)
|
||||
|
||||
for _ in range(cfg.policy.utd):
|
||||
|
|
|
@ -22,11 +22,24 @@ def visualize_dataset_cli(cfg: dict):
|
|||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
# Expects images in [0, 255].
|
||||
frames = torch.cat(frames)
|
||||
assert frames.dtype == torch.uint8
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||
imageio.mimsave(video_path, frames, fps=fps)
|
||||
|
||||
# Expects images in [0, 1].
|
||||
frame = frames[0]
|
||||
if frame.ndim == 4:
|
||||
raise NotImplementedError("We currently dont support multiple timestamps.")
|
||||
c, h, w = frame.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
|
||||
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
|
||||
assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
|
||||
|
||||
# convert to channel last uint8 [0, 255]
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||
frames = (frames * 255).type(torch.uint8)
|
||||
imageio.mimsave(video_path, frames.numpy(), fps=fps)
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
|
@ -44,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
|||
)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
|
||||
for video_path in video_paths:
|
||||
logging.info(video_path)
|
||||
return video_paths
|
||||
|
||||
|
||||
def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
|
@ -77,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
|||
# add current frame to list of frames to render
|
||||
frames[im_key].append(item[im_key])
|
||||
|
||||
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
|
||||
end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
|
@ -522,21 +522,21 @@ toml = ["tomli"]
|
|||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "2.18.0"
|
||||
version = "2.19.0"
|
||||
description = "HuggingFace community-driven open-source library of datasets"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
|
||||
{file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
|
||||
{file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
|
||||
{file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = "*"
|
||||
dill = ">=0.3.0,<0.3.9"
|
||||
filelock = "*"
|
||||
fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
|
||||
huggingface-hub = ">=0.19.4"
|
||||
fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
|
||||
huggingface-hub = ">=0.21.2"
|
||||
multiprocess = "*"
|
||||
numpy = ">=1.17"
|
||||
packaging = "*"
|
||||
|
@ -552,15 +552,15 @@ xxhash = "*"
|
|||
apache-beam = ["apache-beam (>=2.26.0)"]
|
||||
audio = ["librosa", "soundfile (>=0.12.1)"]
|
||||
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
|
||||
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
|
||||
dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
|
||||
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
|
||||
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
|
||||
quality = ["ruff (>=0.3.0)"]
|
||||
s3 = ["s3fs"]
|
||||
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
|
||||
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
||||
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
tensorflow = ["tensorflow (>=2.6.0)"]
|
||||
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
|
||||
tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
torch = ["torch"]
|
||||
vision = ["Pillow (>=6.2.1)"]
|
||||
|
||||
|
@ -2909,7 +2909,6 @@ files = [
|
|||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
|
@ -4195,4 +4194,4 @@ xarm = ["gym-xarm"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "01ad4eb04061ec9f785d4574bf66d3e5cb4549e2ea11ab175895f94cb62c1f1c"
|
||||
content-hash = "7f5afa48aead953f598e686e767891d3d23f2862b80144f76dc064101ef80b4a"
|
||||
|
|
|
@ -53,7 +53,8 @@ pre-commit = {version = "^3.7.0", optional = true}
|
|||
debugpy = {version = "^1.8.1", optional = true}
|
||||
pytest = {version = "^8.1.0", optional = true}
|
||||
pytest-cov = {version = "^5.0.0", optional = true}
|
||||
datasets = "^2.18.0"
|
||||
datasets = "^2.19.0"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
pusht = ["gym-pusht"]
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 50
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -21,11 +21,11 @@
|
|||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
@ -37,14 +37,6 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "d79cf82ffc86f110",
|
||||
"_fingerprint": "22eeca7a3f4725ee",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 50
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -21,11 +21,11 @@
|
|||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
@ -37,14 +37,6 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "d8e4a817b5449498",
|
||||
"_fingerprint": "97c28d4ad1536e4c",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 50
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -21,11 +21,11 @@
|
|||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
@ -37,14 +37,6 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "f03482befa767127",
|
||||
"_fingerprint": "cb9349b5c92951e8",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 50
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -21,11 +21,11 @@
|
|||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
@ -37,14 +37,6 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "93e03c6320c7d56e",
|
||||
"_fingerprint": "e4d7ad2b360db1af",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 10
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -21,11 +21,11 @@
|
|||
"length": 2,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
@ -45,14 +45,6 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 10
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "21bb9a76ed78a475",
|
||||
"_fingerprint": "a04a9ce660122e23",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 15
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -21,11 +21,11 @@
|
|||
"length": 4,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
@ -41,14 +41,6 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "a95cbec45e3bb9d6",
|
||||
"_fingerprint": "cc6afdfcdd6f63ab",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 15
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,51 @@
|
|||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.image": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 4,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 4,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "9f8e1a8c1845df55",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 15
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,51 @@
|
|||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.image": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 4,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 3,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "c900258061dd0b3f",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"fps": 15
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,51 @@
|
|||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.image": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 4,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 3,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "e51c80a33c7688c0",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
"""
|
||||
This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
|
||||
when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the
|
||||
dataset into a corresponding safetensors file in a specified output directory.
|
||||
|
||||
If you know that your change will break backward compatibility, you should write a shortlived test by modifying
|
||||
`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test pass. Your custom test
|
||||
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
|
||||
|
||||
Example usage:
|
||||
`python tests/script/save_dataset_to_safetensors.py`
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
|
||||
|
||||
def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
|
||||
data_dir = Path(output_dir) / dataset_id
|
||||
|
||||
if data_dir.exists():
|
||||
shutil.rmtree(data_dir)
|
||||
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
|
||||
dataset = PushtDataset(
|
||||
dataset_id=dataset_id,
|
||||
split="train",
|
||||
)
|
||||
|
||||
# save 2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
# save 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
# save 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors")
|
||||
save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors")
|
||||
|
||||
# TODO(rcadene): Enable testing on second and last episode
|
||||
# We currently cant because our test dataset only contains the first episode
|
||||
|
||||
# # save 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
|
||||
# save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
# # save 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
|
||||
# save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
|
||||
|
||||
# # save 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
|
||||
# save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors")
|
|
@ -1,20 +1,26 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
compute_stats,
|
||||
flatten_dict,
|
||||
get_stats_einops_patterns,
|
||||
hf_transform_to_torch,
|
||||
load_previous_and_future_frames,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.transforms import Prod
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||
|
@ -39,8 +45,8 @@ def test_factory(env_name, dataset_id, policy_name):
|
|||
|
||||
keys_ndim_required = [
|
||||
("action", 1, True),
|
||||
("episode_id", 0, True),
|
||||
("frame_id", 0, True),
|
||||
("episode_index", 0, True),
|
||||
("frame_index", 0, True),
|
||||
("timestamp", 0, True),
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
("observation.state", 1, True),
|
||||
|
@ -48,12 +54,6 @@ def test_factory(env_name, dataset_id, policy_name):
|
|||
("next.done", 0, False),
|
||||
]
|
||||
|
||||
for key in image_keys:
|
||||
keys_ndim_required.append(
|
||||
(key, 3, True),
|
||||
)
|
||||
assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
|
||||
|
||||
# test number of dimensions
|
||||
for key, ndim, required in keys_ndim_required:
|
||||
if key not in item:
|
||||
|
@ -94,26 +94,21 @@ def test_compute_stats_on_xarm():
|
|||
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||
because we are working with a small dataset).
|
||||
"""
|
||||
# TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
|
||||
data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
# get transform to convert images from uint8 [0,255] to float32 [0,1]
|
||||
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
|
||||
|
||||
dataset = XarmDataset(
|
||||
dataset_id="xarm_lift_medium",
|
||||
root=data_dir,
|
||||
transform=transform,
|
||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||
)
|
||||
|
||||
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||
# dataset into even batches.
|
||||
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
|
||||
computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
|
||||
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
|
||||
|
||||
# get all frames from the dataset in the same dtype and range as during compute_stats
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
@ -122,18 +117,19 @@ def test_compute_stats_on_xarm():
|
|||
batch_size=len(dataset),
|
||||
shuffle=False,
|
||||
)
|
||||
hf_dataset = next(iter(dataloader))
|
||||
full_batch = next(iter(dataloader))
|
||||
|
||||
# compute stats based on all frames from the dataset without any batching
|
||||
expected_stats = {}
|
||||
for k, pattern in stats_patterns.items():
|
||||
full_batch[k] = full_batch[k].float()
|
||||
expected_stats[k] = {}
|
||||
expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
|
||||
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
|
||||
expected_stats[k]["std"] = torch.sqrt(
|
||||
einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
|
||||
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
|
||||
)
|
||||
expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
|
||||
expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
|
||||
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
|
||||
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
|
||||
|
||||
# test computed stats match expected stats
|
||||
for k in stats_patterns:
|
||||
|
@ -142,11 +138,10 @@ def test_compute_stats_on_xarm():
|
|||
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
|
||||
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
# TODO(rcadene): check that the stats used for training are correct too
|
||||
# # load stats that are expected to match the ones returned by computed_stats
|
||||
# assert (dataset.data_dir / "stats.pth").exists()
|
||||
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
|
||||
# load stats used during training which are expected to match the ones returned by computed_stats
|
||||
loaded_stats = dataset.stats # noqa: F841
|
||||
|
||||
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
|
||||
# # test loaded stats match expected stats
|
||||
# for k in stats_patterns:
|
||||
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
|
||||
|
@ -160,15 +155,18 @@ def test_load_previous_and_future_frames_within_tolerance():
|
|||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_data_index_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_index_to": [5, 5, 5, 5, 5],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
item = hf_dataset[2]
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.139]}
|
||||
tol = 0.04
|
||||
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
|
||||
item = hf_dataset[2]
|
||||
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
@ -179,16 +177,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
|
|||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_data_index_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_index_to": [5, 5, 5, 5, 5],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
item = hf_dataset[2]
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.141]}
|
||||
tol = 0.04
|
||||
item = hf_dataset[2]
|
||||
with pytest.raises(AssertionError):
|
||||
load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
|
||||
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
|
||||
|
||||
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
|
||||
|
@ -196,17 +197,102 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
|
|||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_data_index_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_index_to": [5, 5, 5, 5, 5],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset = hf_dataset.with_format("torch")
|
||||
item = hf_dataset[2]
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
|
||||
tol = 0.04
|
||||
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
|
||||
item = hf_dataset[2]
|
||||
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
assert torch.equal(
|
||||
is_pad, torch.tensor([True, False, False, True, True])
|
||||
), "Padding does not match expected values"
|
||||
|
||||
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
def test_backward_compatibility():
|
||||
"""This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
|
||||
dataset_id = "pusht"
|
||||
data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
|
||||
|
||||
dataset = PushtDataset(
|
||||
dataset_id=dataset_id,
|
||||
split="train",
|
||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||
)
|
||||
|
||||
def load_and_compare(i):
|
||||
new_frame = dataset[i]
|
||||
old_frame = load_file(data_dir / f"frame_{i}.safetensors")
|
||||
|
||||
new_keys = set(new_frame.keys())
|
||||
old_keys = set(old_frame.keys())
|
||||
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
||||
|
||||
for key in new_frame:
|
||||
assert (
|
||||
new_frame[key] == old_frame[key]
|
||||
).all(), f"{key=} for index={i} does not contain the same value"
|
||||
|
||||
# test2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
load_and_compare(i - 2)
|
||||
load_and_compare(i - 1)
|
||||
|
||||
# TODO(rcadene): Enable testing on second and last episode
|
||||
# We currently cant because our test dataset only contains the first episode
|
||||
|
||||
# # test 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# load_and_compare(i)
|
||||
# load_and_compare(i+1)
|
||||
|
||||
# #test 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# load_and_compare(i-2)
|
||||
# load_and_compare(i-1)
|
||||
|
||||
# # test 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# load_and_compare(i-2)
|
||||
# load_and_compare(i-1)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# TODO(aliberts): Mute logging for these tests
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import pytest
|
||||
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dataset_id",
|
||||
[
|
||||
"aloha_sim_insertion_human",
|
||||
],
|
||||
)
|
||||
def test_visualize_dataset(tmpdir, dataset_id):
|
||||
# TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
|
||||
# doesnt support multiple timesteps which requires delta_timestamps to None for images.
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
"policy=act",
|
||||
"env=aloha",
|
||||
f"dataset_id={dataset_id}",
|
||||
],
|
||||
)
|
||||
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
|
||||
|
||||
assert len(video_paths) > 0
|
||||
|
||||
for video_path in video_paths:
|
||||
assert video_path.exists()
|
Loading…
Reference in New Issue