Loads episode_data_index and stats during dataset __init__ (#85)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
parent e2168163cd
commit 1030ea0070
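In short, dataset classes no longer derive episode boundaries from per-frame columns or leave statistics to be computed later: the Hugging Face dataset, the `episode_data_index` ("from"/"to" frame ranges per episode) and the precomputed `stats` are all loaded when the dataset object is constructed, via new helpers in `lerobot.common.datasets.utils`. A minimal sketch of the new flow, condensed from the `AlohaDataset` changes further down in this diff (full signature, docstring and error handling omitted):

```python
import torch

from lerobot.common.datasets.utils import (
    load_episode_data_index,
    load_hf_dataset,
    load_previous_and_future_frames,
    load_stats,
)


class AlohaDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_id, version="v1.1", root=None, split="train", transform=None, delta_timestamps=None):
        ...
        # load data from hub or locally when root is provided
        self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
        self.episode_data_index = load_episode_data_index(dataset_id, version, root)
        self.stats = load_stats(dataset_id, version, root)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]
        if self.delta_timestamps is not None:
            # episode_data_index is now passed explicitly instead of reading the removed
            # per-frame "episode_data_index_from"/"episode_data_index_to" columns
            item = load_previous_and_future_frames(
                item,
                self.hf_dataset,
                self.episode_data_index,
                self.delta_timestamps,
                tol=1 / self.fps - 1e-4,
            )
        return self.transform(item) if self.transform is not None else item
```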
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -522,21 +522,21 @@ toml = ["tomli"]
 
 [[package]]
 name = "datasets"
-version = "2.18.0"
+version = "2.19.0"
 description = "HuggingFace community-driven open-source library of datasets"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
-    {file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
+    {file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
+    {file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
 ]
 
 [package.dependencies]
 aiohttp = "*"
 dill = ">=0.3.0,<0.3.9"
 filelock = "*"
-fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
-huggingface-hub = ">=0.19.4"
+fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
+huggingface-hub = ">=0.21.2"
 multiprocess = "*"
 numpy = ">=1.17"
 packaging = "*"
@@ -552,15 +552,15 @@ xxhash = "*"
 apache-beam = ["apache-beam (>=2.26.0)"]
 audio = ["librosa", "soundfile (>=0.12.1)"]
 benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
-dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
-docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
+dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
 jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
 metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
 quality = ["ruff (>=0.3.0)"]
 s3 = ["s3fs"]
-tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
-tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
-tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+tensorflow = ["tensorflow (>=2.6.0)"]
+tensorflow-gpu = ["tensorflow (>=2.6.0)"]
+tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
 torch = ["torch"]
 vision = ["Pillow (>=6.2.1)"]
 
@@ -1524,7 +1524,6 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li
 optional = true
 python-versions = ">=3.6"
 files = [
-    {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
     {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
     {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
     {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
@@ -1534,7 +1533,6 @@ files = [
     {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
     {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
     {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
-    {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
     {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
     {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
     {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
@@ -1544,7 +1542,6 @@ files = [
     {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
     {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
     {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
-    {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
     {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
     {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
     {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
@@ -1570,8 +1567,8 @@ files = [
     {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
     {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
     {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
-    {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
     {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
+    {file = "lxml-5.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cfbac9f6149174f76df7e08c2e28b19d74aed90cad60383ad8671d3af7d0502f"},
     {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
     {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
     {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
@@ -1579,7 +1576,6 @@ files = [
     {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
     {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
     {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
-    {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
     {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
     {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
     {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
@@ -2688,7 +2684,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -3919,4 +3914,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "bd9c506d2499d5e1e3b5e8b1a0f65df45c8feef38d89d0daeade56847fdb6a2e"
+content-hash = "e526416d1282dea2550680b2be7fcf9ff6e1c67ac89d34c684b486d94a6addee"

@@ -53,7 +53,7 @@ pre-commit = {version = "^3.7.0", optional = true}
 debugpy = {version = "^1.8.1", optional = true}
 pytest = {version = "^8.1.0", optional = true}
 pytest-cov = {version = "^5.0.0", optional = true}
-datasets = "^2.18.0"
+datasets = "^2.19.0"
 
 
 [tool.poetry.extras]

@@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS
 
 You will need to set the corresponding version as a default argument in your dataset class:
 ```python
-version: str | None = "v1.0",
+version: str | None = "v1.1",
 ```
 See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
 
@@ -4,6 +4,7 @@ useless dependencies when using datasets.
 """
 
 import io
+import json
 import pickle
 import shutil
 from pathlib import Path
@@ -14,16 +15,20 @@ import numpy as np
 import torch
 import tqdm
 from datasets import Dataset, Features, Image, Sequence, Value
+from huggingface_hub import HfApi
 from PIL import Image as PILImage
+from safetensors.torch import save_file
+
+from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
 
 
-def download_and_upload(root, root_tests, dataset_id):
+def download_and_upload(root, revision, dataset_id):
     if "pusht" in dataset_id:
-        download_and_upload_pusht(root, root_tests, dataset_id)
+        download_and_upload_pusht(root, revision, dataset_id)
     elif "xarm" in dataset_id:
-        download_and_upload_xarm(root, root_tests, dataset_id)
+        download_and_upload_xarm(root, revision, dataset_id)
     elif "aloha" in dataset_id:
-        download_and_upload_aloha(root, root_tests, dataset_id)
+        download_and_upload_aloha(root, revision, dataset_id)
     else:
         raise ValueError(dataset_id)
 
@@ -56,7 +61,102 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return False
 
 
-def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
+def concatenate_episodes(ep_dicts):
+    data_dict = {}
+
+    keys = ep_dicts[0].keys()
+    for key in keys:
+        if torch.is_tensor(ep_dicts[0][key][0]):
+            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
+        else:
+            if key not in data_dict:
+                data_dict[key] = []
+            for ep_dict in ep_dicts:
+                for x in ep_dict[key]:
+                    data_dict[key].append(x)
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+    return data_dict
+
+
+def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
+    # push to main to indicate latest version
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+
+    # push to version branch
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
+
+    # create and store meta_data
+    meta_data_dir = root / dataset_id / "meta_data"
+    meta_data_dir.mkdir(parents=True, exist_ok=True)
+
+    api = HfApi()
+
+    # info
+    info_path = meta_data_dir / "info.json"
+    with open(str(info_path), "w") as f:
+        json.dump(info, f, indent=4)
+    api.upload_file(
+        path_or_fileobj=info_path,
+        path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
+        repo_id=f"lerobot/{dataset_id}",
+        repo_type="dataset",
+    )
+    api.upload_file(
+        path_or_fileobj=info_path,
+        path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
+        repo_id=f"lerobot/{dataset_id}",
+        repo_type="dataset",
+        revision=revision,
+    )
+
+    # stats
+    stats_path = meta_data_dir / "stats.safetensors"
+    save_file(flatten_dict(stats), stats_path)
+    api.upload_file(
+        path_or_fileobj=stats_path,
+        path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
+        repo_id=f"lerobot/{dataset_id}",
+        repo_type="dataset",
+    )
+    api.upload_file(
+        path_or_fileobj=stats_path,
+        path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
+        repo_id=f"lerobot/{dataset_id}",
+        repo_type="dataset",
+        revision=revision,
+    )
+
+    # episode_data_index
+    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
+    ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
+    save_file(episode_data_index, ep_data_idx_path)
+    api.upload_file(
+        path_or_fileobj=ep_data_idx_path,
+        path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
+        repo_id=f"lerobot/{dataset_id}",
+        repo_type="dataset",
+    )
+    api.upload_file(
+        path_or_fileobj=ep_data_idx_path,
+        path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
+        repo_id=f"lerobot/{dataset_id}",
+        repo_type="dataset",
+        revision=revision,
+    )
+
+    # copy in tests folder, the first episode and the meta_data directory
+    num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
+    hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
+        f"tests/data/{dataset_id}/train"
+    )
+    if Path(f"tests/data/{dataset_id}/meta_data").exists():
+        shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
+    shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
+
+
+def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
     try:
         import pymunk
         from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -99,6 +199,7 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
     actions = torch.from_numpy(dataset_dict["action"])
 
     ep_dicts = []
+    episode_data_index = {"from": [], "to": []}
 
     id_from = 0
     for episode_id in tqdm.tqdm(range(num_episodes)):
@@ -151,8 +252,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
             "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
             "observation.state": agent_pos,
             "action": actions[id_from:id_to],
-            "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
-            "frame_id": torch.arange(0, num_frames, 1),
+            "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
+            "frame_index": torch.arange(0, num_frames, 1),
             "timestamp": torch.arange(0, num_frames, 1) / fps,
             # "next.observation.image": image[1:],
             # "next.observation.state": agent_pos[1:],
@@ -160,28 +261,15 @@
             "next.reward": torch.cat([reward[1:], reward[[-1]]]),
             "next.done": torch.cat([done[1:], done[[-1]]]),
             "next.success": torch.cat([success[1:], success[[-1]]]),
-            "episode_data_index_from": torch.tensor([id_from] * num_frames),
-            "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
         }
         ep_dicts.append(ep_dict)
 
+        episode_data_index["from"].append(id_from)
+        episode_data_index["to"].append(id_from + num_frames)
+
         id_from += num_frames
 
-    data_dict = {}
-
-    keys = ep_dicts[0].keys()
-    for key in keys:
-        if torch.is_tensor(ep_dicts[0][key][0]):
-            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
-        else:
-            if key not in data_dict:
-                data_dict[key] = []
-            for ep_dict in ep_dicts:
-                for x in ep_dict[key]:
-                    data_dict[key].append(x)
-
-    total_frames = id_from
-    data_dict["index"] = torch.arange(0, total_frames, 1)
+    data_dict = concatenate_episodes(ep_dicts)
 
     features = {
         "observation.image": Image(),
@@ -189,35 +277,35 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
             length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
         ),
         "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
-        "episode_id": Value(dtype="int64", id=None),
-        "frame_id": Value(dtype="int64", id=None),
+        "episode_index": Value(dtype="int64", id=None),
+        "frame_index": Value(dtype="int64", id=None),
         "timestamp": Value(dtype="float32", id=None),
         "next.reward": Value(dtype="float32", id=None),
         "next.done": Value(dtype="bool", id=None),
         "next.success": Value(dtype="bool", id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_index_from": Value(dtype="int64", id=None),
-        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
+    hf_dataset.set_transform(hf_transform_to_torch)
 
-    num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    info = {
+        "fps": fps,
+    }
+    stats = compute_stats(hf_dataset)
+    push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
 
 
-def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
+def download_and_upload_xarm(root, revision, dataset_id, fps=15):
     root = Path(root)
-    raw_dir = root / f"{dataset_id}_raw"
+    raw_dir = root / "xarm_datasets_raw"
     if not raw_dir.exists():
         import zipfile
 
         import gdown
 
         raw_dir.mkdir(parents=True, exist_ok=True)
+        # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
         url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
         zip_path = raw_dir / "data.zip"
         gdown.download(url, str(zip_path), quiet=False)
@@ -234,13 +322,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
     with open(dataset_path, "rb") as f:
         dataset_dict = pickle.load(f)
 
-    total_frames = dataset_dict["actions"].shape[0]
-
     ep_dicts = []
+    episode_data_index = {"from": [], "to": []}
 
     id_from = 0
     id_to = 0
     episode_id = 0
+    total_frames = dataset_dict["actions"].shape[0]
     for i in tqdm.tqdm(range(total_frames)):
         id_to += 1
 
@@ -264,35 +352,23 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
             "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
             "observation.state": state,
             "action": action,
-            "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
-            "frame_id": torch.arange(0, num_frames, 1),
+            "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
+            "frame_index": torch.arange(0, num_frames, 1),
             "timestamp": torch.arange(0, num_frames, 1) / fps,
             # "next.observation.image": next_image,
             # "next.observation.state": next_state,
             "next.reward": next_reward,
             "next.done": next_done,
-            "episode_data_index_from": torch.tensor([id_from] * num_frames),
-            "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
         }
         ep_dicts.append(ep_dict)
 
+        episode_data_index["from"].append(id_from)
+        episode_data_index["to"].append(id_from + num_frames)
+
         id_from = id_to
         episode_id += 1
 
-    data_dict = {}
-    keys = ep_dicts[0].keys()
-    for key in keys:
-        if torch.is_tensor(ep_dicts[0][key][0]):
-            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
-        else:
-            if key not in data_dict:
-                data_dict[key] = []
-            for ep_dict in ep_dicts:
-                for x in ep_dict[key]:
-                    data_dict[key].append(x)
-
-    total_frames = id_from
-    data_dict["index"] = torch.arange(0, total_frames, 1)
+    data_dict = concatenate_episodes(ep_dicts)
 
     features = {
         "observation.image": Image(),
@@ -300,27 +376,26 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
             length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
         ),
         "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
-        "episode_id": Value(dtype="int64", id=None),
-        "frame_id": Value(dtype="int64", id=None),
+        "episode_index": Value(dtype="int64", id=None),
+        "frame_index": Value(dtype="int64", id=None),
         "timestamp": Value(dtype="float32", id=None),
         "next.reward": Value(dtype="float32", id=None),
         "next.done": Value(dtype="bool", id=None),
         #'next.success': Value(dtype='bool', id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_index_from": Value(dtype="int64", id=None),
-        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
     hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
+    hf_dataset.set_transform(hf_transform_to_torch)
 
-    num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    info = {
+        "fps": fps,
+    }
+    stats = compute_stats(hf_dataset)
+    push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
 
 
-def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
+def download_and_upload_aloha(root, revision, dataset_id, fps=50):
     folder_urls = {
         "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
         "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
@@ -381,6 +456,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
     gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
 
     ep_dicts = []
+    episode_data_index = {"from": [], "to": []}
 
     id_from = 0
     for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
@@ -408,40 +484,26 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
             {
                 "observation.state": state,
                 "action": action,
-                "episode_id": torch.tensor([ep_id] * num_frames),
-                "frame_id": torch.arange(0, num_frames, 1),
+                "episode_index": torch.tensor([ep_id] * num_frames),
+                "frame_index": torch.arange(0, num_frames, 1),
                 "timestamp": torch.arange(0, num_frames, 1) / fps,
                 # "next.observation.state": state,
                 # TODO(rcadene): compute reward and success
                 # "next.reward": reward,
                 "next.done": done,
                 # "next.success": success,
-                "episode_data_index_from": torch.tensor([id_from] * num_frames),
-                "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
             }
         )
 
         assert isinstance(ep_id, int)
         ep_dicts.append(ep_dict)
 
+        episode_data_index["from"].append(id_from)
+        episode_data_index["to"].append(id_from + num_frames)
+
         id_from += num_frames
 
-    data_dict = {}
-
-    data_dict = {}
-    keys = ep_dicts[0].keys()
-    for key in keys:
-        if torch.is_tensor(ep_dicts[0][key][0]):
-            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
-        else:
-            if key not in data_dict:
-                data_dict[key] = []
-            for ep_dict in ep_dicts:
-                for x in ep_dict[key]:
-                    data_dict[key].append(x)
-
-    total_frames = id_from
-    data_dict["index"] = torch.arange(0, total_frames, 1)
+    data_dict = concatenate_episodes(ep_dicts)
 
     features = {
         "observation.images.top": Image(),
@@ -449,39 +511,39 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
             length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
         ),
         "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
-        "episode_id": Value(dtype="int64", id=None),
-        "frame_id": Value(dtype="int64", id=None),
+        "episode_index": Value(dtype="int64", id=None),
+        "frame_index": Value(dtype="int64", id=None),
         "timestamp": Value(dtype="float32", id=None),
         #'next.reward': Value(dtype='float32', id=None),
         "next.done": Value(dtype="bool", id=None),
         #'next.success': Value(dtype='bool', id=None),
         "index": Value(dtype="int64", id=None),
-        "episode_data_index_from": Value(dtype="int64", id=None),
-        "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
    hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
+    hf_dataset.set_transform(hf_transform_to_torch)
 
-    num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    info = {
+        "fps": fps,
+    }
+    stats = compute_stats(hf_dataset)
+    push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
 
 
 if __name__ == "__main__":
     root = "data"
-    root_tests = "tests/data"
+    revision = "v1.1"
 
     dataset_ids = [
-        # "pusht",
-        # "xarm_lift_medium",
-        # "aloha_sim_insertion_human",
-        # "aloha_sim_insertion_scripted",
-        # "aloha_sim_transfer_cube_human",
+        "pusht",
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        "xarm_push_medium",
+        "xarm_push_medium_replay",
+        "aloha_sim_insertion_human",
+        "aloha_sim_insertion_scripted",
+        "aloha_sim_transfer_cube_human",
         "aloha_sim_transfer_cube_scripted",
     ]
     for dataset_id in dataset_ids:
-        download_and_upload(root, root_tests, dataset_id)
-        # assume stats have been precomputed
-        shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
+        download_and_upload(root, revision, dataset_id)

@@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
 This script supports several Hugging Face datasets, among which:
 1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
 2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
-3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
-4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
-5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
-6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
+3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
+4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
+5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
+6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
+7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
+8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
+9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
 
 To try a different Hugging Face dataset, you can replace this line:
 ```python
@@ -22,12 +25,16 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
 by one of these:
 ```python
 hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
 ```
 """
+# TODO(rcadene): remove this example file of using hf_dataset
 
 from pathlib import Path
 
@@ -37,19 +44,22 @@ from datasets import load_dataset
 # TODO(rcadene): list available datasets on lerobot page using `datasets`
 
 # download/load hugging face dataset in pyarrow format
-hf_dataset, fps = load_dataset("lerobot/pusht", revision="v1.0", split="train"), 10
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
 
 # display name of dataset and its features
+# TODO(rcadene): update to make the print pretty
 print(f"{hf_dataset=}")
 print(f"{hf_dataset.features=}")
 
 # display useful statistics about frames and episodes, which are sequences of frames from the same video
 print(f"number of frames: {len(hf_dataset)=}")
-print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
-print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
+print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
+print(
+    f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
+)
 
 # select the frames belonging to episode number 5
-hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
 
 # load all frames of episode 5 in RAM in PIL format
 frames = hf_dataset["observation.image"]
@@ -18,7 +18,10 @@ dataset = PushtDataset()
 ```
 by one of these:
 ```python
-dataset = XarmDataset()
+dataset = XarmDataset("xarm_lift_medium")
+dataset = XarmDataset("xarm_lift_medium_replay")
+dataset = XarmDataset("xarm_push_medium")
+dataset = XarmDataset("xarm_push_medium_replay")
 dataset = AlohaDataset("aloha_sim_insertion_human")
 dataset = AlohaDataset("aloha_sim_insertion_scripted")
 dataset = AlohaDataset("aloha_sim_transfer_cube_human")
@@ -44,6 +47,7 @@ from lerobot.common.datasets.pusht import PushtDataset
 dataset = PushtDataset()
 
 # All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
+# TODO(rcadene): update to make the print pretty
 print(f"{dataset=}")
 print(f"{dataset.hf_dataset=}")
 
@@ -55,13 +59,16 @@ print(f"frames per second used during data collection: {dataset.fps=}")
 print(f"keys to access images from cameras: {dataset.image_keys=}")
 
 # While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
-dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+# TODO(rcadene): remove this example of accessing hf_dataset
+dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
 
-# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
+# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
 frames = [sample["observation.image"] for sample in dataset]
 
-# but frames are now channel first to follow pytorch convention,
-# to view them, we convert to channel last
+# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
+# to view them, we convert to uint8 range [0,255]
+frames = [(frame * 255).type(torch.uint8) for frame in frames]
+# and to channel last (h,w,c)
 frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
 
 # and finally save them to a mp4 video
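The hunk above stops right at the "save them to a mp4 video" comment, so the actual write call is not shown in this diff. For readers following along, one way to finish the example — assuming `imageio` with its ffmpeg backend is available, which this commit does not itself establish — is:

```python
import imageio

# frames is the list of channel-last uint8 numpy arrays built above;
# dataset.fps is the recording frame rate exposed by the LeRobot dataset.
imageio.mimsave("episode_5.mp4", frames, fps=dataset.fps)
```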
@@ -50,7 +50,12 @@ available_datasets = {
         "aloha_sim_transfer_cube_scripted",
     ],
     "pusht": ["pusht"],
-    "xarm": ["xarm_lift_medium"],
+    "xarm": [
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        "xarm_push_medium",
+        "xarm_push_medium_replay",
+    ],
 }
 
 available_policies = [
@@ -1,9 +1,13 @@
 from pathlib import Path

 import torch
-from datasets import load_dataset, load_from_disk

-from lerobot.common.datasets.utils import load_previous_and_future_frames
+from lerobot.common.datasets.utils import (
+    load_episode_data_index,
+    load_hf_dataset,
+    load_previous_and_future_frames,
+    load_stats,
+)


 class AlohaDataset(torch.utils.data.Dataset):

@@ -27,7 +31,7 @@ class AlohaDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         dataset_id: str,
-        version: str | None = "v1.0",
+        version: str | None = "v1.1",
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,

@@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset):
         self.split = split
         self.transform = transform
         self.delta_timestamps = delta_timestamps
-        if self.root is not None:
-            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
-        else:
-            self.hf_dataset = load_dataset(
-                f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
-            )
-        self.hf_dataset = self.hf_dataset.with_format("torch")
+        # load data from hub or locally when root is provided
+        self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
+        self.episode_data_index = load_episode_data_index(dataset_id, version, root)
+        self.stats = load_stats(dataset_id, version, root)

     @property
     def num_samples(self) -> int:

@@ -54,7 +55,7 @@ class AlohaDataset(torch.utils.data.Dataset):

     @property
     def num_episodes(self) -> int:
-        return len(self.hf_dataset.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_index"))

     def __len__(self):
         return self.num_samples

@@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset):
            item = load_previous_and_future_frames(
                item,
                self.hf_dataset,
+               self.episode_data_index,
                self.delta_timestamps,
                tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
            )

-        # convert images from channel last (PIL) to channel first (pytorch)
-        for key in self.image_keys:
-            if item[key].ndim == 3:
-                item[key] = item[key].permute((2, 0, 1))  # h w c -> c h w
-            elif item[key].ndim == 4:
-                item[key] = item[key].permute((0, 3, 1, 2))  # t h w c -> t c h w
-            else:
-                raise ValueError(item[key].ndim)
-
         if self.transform is not None:
             item = self.transform(item)
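The dataset classes now delegate all loading to the helpers imported from `lerobot.common.datasets.utils`. A minimal sketch of the access pattern this `__init__` sets up; the `dataset_id` value below is only an illustration, not taken from this diff:

```python
from lerobot.common.datasets.utils import (
    load_episode_data_index,
    load_hf_dataset,
    load_stats,
)

# Illustrative arguments; any dataset published under the lerobot namespace would do.
dataset_id = "aloha_sim_insertion_human"  # hypothetical choice for the example
version = "v1.1"
root = None  # set to a local Path to read from disk instead of the hub
split = "train"

hf_dataset = load_hf_dataset(dataset_id, version, root, split)
episode_data_index = load_episode_data_index(dataset_id, version, root)
stats = load_stats(dataset_id, version, root)

# Episode boundaries no longer live in every row; they come from the index tensors.
from_idx = episode_data_index["from"][0].item()
to_idx = episode_data_index["to"][0].item()
print(f"episode 0 covers dataset rows [{from_idx}, {to_idx})")
```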
@@ -1,12 +1,10 @@
-import logging
 import os
 from pathlib import Path

 import torch
 from torchvision.transforms import v2

-from lerobot.common.datasets.utils import compute_stats
-from lerobot.common.transforms import NormalizeTransform, Prod
+from lerobot.common.transforms import NormalizeTransform

 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

@@ -52,32 +50,18 @@ def make_dataset(
         stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
         stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
     elif stats_path is None:
-        # load stats if the file exists already or compute stats and save it
-        if DATA_DIR is None:
-            # TODO(rcadene): clean stats
-            precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
-        else:
-            precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
-        if precomputed_stats_path.exists():
-            stats = torch.load(precomputed_stats_path)
-        else:
-            logging.info(f"compute_stats and save to {precomputed_stats_path}")
-            # Create a dataset for stats computation.
-            stats_dataset = clsfunc(
-                dataset_id=cfg.dataset_id,
-                split="train",
-                root=DATA_DIR,
-                transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
-            )
-            stats = compute_stats(stats_dataset)
-            precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(stats, precomputed_stats_path)
+        # load a first dataset to access precomputed stats
+        stats_dataset = clsfunc(
+            dataset_id=cfg.dataset_id,
+            split="train",
+            root=DATA_DIR,
+        )
+        stats = stats_dataset.stats
     else:
         stats = torch.load(stats_path)

     transforms = v2.Compose(
         [
-            Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
             NormalizeTransform(
                 stats,
                 in_keys=[
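Since stats are no longer recomputed and cached to `stats.pth`, the factory simply instantiates one dataset and reads its `stats` attribute. A self-contained sketch of how those per-modality tensors are typically applied, mirroring the mean/std and min/max fields referenced above; the action values are dummies:

```python
import torch

# Dummy stats in the same nested layout as dataset.stats (assumed 2-dof action space)
stats = {
    "action": {
        "mean": torch.tensor([0.0, 0.0]),
        "std": torch.tensor([1.0, 2.0]),
        "min": torch.tensor([12.0, 25.0]),
        "max": torch.tensor([511.0, 511.0]),
    }
}

action = torch.tensor([[100.0, 200.0]])
standardized = (action - stats["action"]["mean"]) / stats["action"]["std"]
min_max_scaled = (action - stats["action"]["min"]) / (stats["action"]["max"] - stats["action"]["min"])
```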
@@ -1,9 +1,13 @@
 from pathlib import Path

 import torch
-from datasets import load_dataset, load_from_disk

-from lerobot.common.datasets.utils import load_previous_and_future_frames
+from lerobot.common.datasets.utils import (
+    load_episode_data_index,
+    load_hf_dataset,
+    load_previous_and_future_frames,
+    load_stats,
+)


 class PushtDataset(torch.utils.data.Dataset):

@@ -25,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         dataset_id: str = "pusht",
-        version: str | None = "v1.0",
+        version: str | None = "v1.1",
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,

@@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
         self.split = split
         self.transform = transform
         self.delta_timestamps = delta_timestamps
-        if self.root is not None:
-            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
-        else:
-            self.hf_dataset = load_dataset(
-                f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
-            )
-        self.hf_dataset = self.hf_dataset.with_format("torch")
+        # load data from hub or locally when root is provided
+        self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
+        self.episode_data_index = load_episode_data_index(dataset_id, version, root)
+        self.stats = load_stats(dataset_id, version, root)

     @property
     def num_samples(self) -> int:

@@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):

     @property
     def num_episodes(self) -> int:
-        return len(self.hf_dataset.unique("episode_id"))
+        return len(self.episode_data_index["from"])

     def __len__(self):
         return self.num_samples

@@ -64,19 +65,11 @@ class PushtDataset(torch.utils.data.Dataset):
            item = load_previous_and_future_frames(
                item,
                self.hf_dataset,
+               self.episode_data_index,
                self.delta_timestamps,
                tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
            )

-        # convert images from channel last (PIL) to channel first (pytorch)
-        for key in self.image_keys:
-            if item[key].ndim == 3:
-                item[key] = item[key].permute((2, 0, 1))  # h w c -> c h w
-            elif item[key].ndim == 4:
-                item[key] = item[key].permute((0, 3, 1, 2))  # t h w c -> t c h w
-            else:
-                raise ValueError(item[key].ndim)
-
         if self.transform is not None:
             item = self.transform(item)
@@ -1,15 +1,121 @@
 from copy import deepcopy
 from math import ceil
+from pathlib import Path

 import datasets
 import einops
 import torch
 import tqdm
+from datasets import Image, load_dataset, load_from_disk
+from huggingface_hub import hf_hub_download
+from PIL import Image as PILImage
+from safetensors.torch import load_file
+from torchvision import transforms


+def flatten_dict(d, parent_key="", sep="/"):
+    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
+
+    For example:
+    ```
+    >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
+    >>> print(flatten_dict(dct))
+    {"a/b": 1, "a/c/d": 2, "e": 3}
+    """
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
+def unflatten_dict(d, sep="/"):
+    outdict = {}
+    for key, value in d.items():
+        parts = key.split(sep)
+        d = outdict
+        for part in parts[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        d[parts[-1]] = value
+    return outdict
+
+
+def hf_transform_to_torch(items_dict):
+    """Get a transform function that convert items from Hugging Face dataset (pyarrow)
+    to torch tensors. Importantly, images are converted from PIL, which corresponds to
+    a channel last representation (h w c) of uint8 type, to a torch image representation
+    with channel first (c h w) of float32 type in range [0,1].
+    """
+    for key in items_dict:
+        first_item = items_dict[key][0]
+        if isinstance(first_item, PILImage.Image):
+            to_tensor = transforms.ToTensor()
+            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        else:
+            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
+    return items_dict
+
+
+def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
+    """hf_dataset contains all the observations, states, actions, rewards, etc."""
+    if root is not None:
+        hf_dataset = load_from_disk(str(Path(root) / dataset_id / split))
+    else:
+        # TODO(rcadene): remove dataset_id everywhere and use repo_id instead
+        repo_id = f"lerobot/{dataset_id}"
+        hf_dataset = load_dataset(repo_id, revision=version, split=split)
+    hf_dataset.set_transform(hf_transform_to_torch)
+    return hf_dataset
+
+
+def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
+    """episode_data_index contains the range of indices for each episode
+
+    Example:
+    ```python
+    from_id = episode_data_index["from"][episode_id].item()
+    to_id = episode_data_index["to"][episode_id].item()
+    episode_frames = [dataset[i] for i in range(from_id, to_id)]
+    ```
+    """
+    if root is not None:
+        path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
+    else:
+        repo_id = f"lerobot/{dataset_id}"
+        path = hf_hub_download(
+            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
+        )
+
+    return load_file(path)
+
+
+def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
+    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
+
+    Example:
+    ```python
+    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
+    ```
+    """
+    if root is not None:
+        path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
+    else:
+        repo_id = f"lerobot/{dataset_id}"
+        path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
+
+    stats = load_file(path)
+    return unflatten_dict(stats)
+
+
 def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
     hf_dataset: datasets.Dataset,
+    episode_data_index: dict[str, torch.Tensor],
     delta_timestamps: dict[str, list[float]],
     tol: float,
 ) -> dict[torch.Tensor]:

@@ -31,6 +137,8 @@ def load_previous_and_future_frames(
    corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
    - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
    modality (e.g., "timestamp", "observation.image", "action").
+   - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
+   They indicate the start index and end index of each episode in the dataset.
    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
    retrieved. These deltas are added to the item timestamp to form the query timestamps.
    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query

@@ -46,12 +154,14 @@ def load_previous_and_future_frames(
    issues with timestamps during data collection.
    """
    # get indices of the frames associated to the episode, and their timestamps
-   ep_data_id_from = item["episode_data_index_from"].item()
-   ep_data_id_to = item["episode_data_index_to"].item()
+   ep_id = item["episode_index"].item()
+   ep_data_id_from = episode_data_index["from"][ep_id].item()
+   ep_data_id_to = episode_data_index["to"][ep_id].item()
    ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)

    # load timestamps
    ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
+   ep_timestamps = torch.stack(ep_timestamps)

    # we make the assumption that the timestamps are sorted
    ep_first_ts = ep_timestamps[0]

@@ -82,39 +192,57 @@ def load_previous_and_future_frames(

        # load frames modality
        item[key] = hf_dataset.select_columns(key)[data_ids][key]
+       item[key] = torch.stack(item[key])
        item[f"{key}_is_pad"] = is_pad

    return item


-def get_stats_einops_patterns(dataset):
-    """These einops patterns will be used to aggregate batches and compute statistics."""
-    stats_patterns = {
-        "action": "b c -> c",
-        "observation.state": "b c -> c",
-    }
-    for key in dataset.image_keys:
-        stats_patterns[key] = "b c h w -> c 1 1"
+def get_stats_einops_patterns(hf_dataset):
+    """These einops patterns will be used to aggregate batches and compute statistics.
+
+    Note: We assume the images of `hf_dataset` are in channel first format
+    """
+
+    dataloader = torch.utils.data.DataLoader(
+        hf_dataset,
+        num_workers=0,
+        batch_size=2,
+        shuffle=False,
+    )
+    batch = next(iter(dataloader))
+
+    stats_patterns = {}
+    for key, feats_type in hf_dataset.features.items():
+        # sanity check that tensors are not float64
+        assert batch[key].dtype != torch.float64
+
+        if isinstance(feats_type, Image):
+            # sanity check that images are channel first
+            _, c, h, w = batch[key].shape
+            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+
+            # sanity check that images are float32 in range [0,1]
+            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
+            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
+            assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
+
+            stats_patterns[key] = "b c h w -> c 1 1"
+        elif batch[key].ndim == 2:
+            stats_patterns[key] = "b c -> c "
+        elif batch[key].ndim == 1:
+            stats_patterns[key] = "b -> 1"
+        else:
+            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")

    return stats_patterns


-def compute_stats(dataset, batch_size=32, max_num_samples=None):
+def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
    if max_num_samples is None:
-       max_num_samples = len(dataset)
-   else:
-       raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
+       max_num_samples = len(hf_dataset)

-   dataloader = torch.utils.data.DataLoader(
-       dataset,
-       num_workers=4,
-       batch_size=batch_size,
-       shuffle=False,
-       # pin_memory=cfg.device != "cpu",
-       drop_last=False,
-   )
-
-   # get einops patterns to aggregate batches and compute statistics
-   stats_patterns = get_stats_einops_patterns(dataset)
+   stats_patterns = get_stats_einops_patterns(hf_dataset)

    # mean and std will be computed incrementally while max and min will track the running value.
    mean, std, max, min = {}, {}, {}, {}

@@ -124,10 +252,24 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
        max[key] = torch.tensor(-float("inf")).float()
        min[key] = torch.tensor(float("inf")).float()

+   def create_seeded_dataloader(hf_dataset, batch_size, seed):
+       generator = torch.Generator()
+       generator.manual_seed(seed)
+       dataloader = torch.utils.data.DataLoader(
+           hf_dataset,
+           num_workers=4,
+           batch_size=batch_size,
+           shuffle=True,
+           drop_last=False,
+           generator=generator,
+       )
+       return dataloader
+
    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
    # surprises when rerunning the sampler.
    first_batch = None
    running_item_count = 0  # for online mean computation
+   dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
    ):

@@ -153,6 +295,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):

    first_batch_ = None
    running_item_count = 0  # for online std computation
+   dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
    for i, batch in enumerate(
        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
    ):
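`load_stats` relies on `flatten_dict`/`unflatten_dict` because a safetensors file can only store a flat `{str: tensor}` mapping, while the rest of the code expects `{modality: {stat: tensor}}`. A small round-trip sketch; the file name and tensor shapes are illustrative:

```python
import torch
from safetensors.torch import load_file, save_file

from lerobot.common.datasets.utils import flatten_dict, unflatten_dict

stats = {
    "action": {"mean": torch.zeros(2), "std": torch.ones(2)},
    "observation.state": {"min": torch.zeros(2), "max": torch.ones(2)},
}

flat = flatten_dict(stats)  # {"action/mean": ..., "action/std": ..., "observation.state/min": ..., ...}
save_file(flat, "stats.safetensors")

restored = unflatten_dict(load_file("stats.safetensors"))
assert torch.equal(restored["action"]["std"], stats["action"]["std"])
```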
@@ -1,25 +1,37 @@
 from pathlib import Path

 import torch
-from datasets import load_dataset, load_from_disk

-from lerobot.common.datasets.utils import load_previous_and_future_frames
+from lerobot.common.datasets.utils import (
+    load_episode_data_index,
+    load_hf_dataset,
+    load_previous_and_future_frames,
+    load_stats,
+)


 class XarmDataset(torch.utils.data.Dataset):
     """
     https://huggingface.co/datasets/lerobot/xarm_lift_medium
+    https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
+    https://huggingface.co/datasets/lerobot/xarm_push_medium
+    https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
     """

     # Copied from lerobot/__init__.py
-    available_datasets = ["xarm_lift_medium"]
+    available_datasets = [
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        "xarm_push_medium",
+        "xarm_push_medium_replay",
+    ]
     fps = 15
     image_keys = ["observation.image"]

     def __init__(
         self,
-        dataset_id: str = "xarm_lift_medium",
-        version: str | None = "v1.0",
+        dataset_id: str,
+        version: str | None = "v1.1",
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,

@@ -32,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset):
         self.split = split
         self.transform = transform
         self.delta_timestamps = delta_timestamps
-        if self.root is not None:
-            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
-        else:
-            self.hf_dataset = load_dataset(
-                f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
-            )
-        self.hf_dataset = self.hf_dataset.with_format("torch")
+        # load data from hub or locally when root is provided
+        self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
+        self.episode_data_index = load_episode_data_index(dataset_id, version, root)
+        self.stats = load_stats(dataset_id, version, root)

     @property
     def num_samples(self) -> int:

@@ -46,7 +55,7 @@ class XarmDataset(torch.utils.data.Dataset):

     @property
     def num_episodes(self) -> int:
-        return len(self.hf_dataset.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_index"))

     def __len__(self):
         return self.num_samples

@@ -58,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset):
            item = load_previous_and_future_frames(
                item,
                self.hf_dataset,
+               self.episode_data_index,
                self.delta_timestamps,
                tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
            )

-        # convert images from channel last (PIL) to channel first (pytorch)
-        for key in self.image_keys:
-            if item[key].ndim == 3:
-                item[key] = item[key].permute((2, 0, 1))  # h w c -> c h w
-            elif item[key].ndim == 4:
-                item[key] = item[key].permute((0, 3, 1, 2))  # t h w c -> t c h w
-            else:
-                raise ValueError(item[key].ndim)
-
         if self.transform is not None:
             item = self.transform(item)
@@ -39,4 +39,5 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
             for _ in range(num_parallel_envs)
         ]
     )
+
     return env
@@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):

     for imgkey, img in imgs.items():
         img = torch.from_numpy(img)
-        # convert to (b c h w) torch format
+
+        # sanity check that images are channel last
+        _, h, w, c = img.shape
+        assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
+
+        # sanity check that images are uint8
+        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+
+        # convert to channel first of type float32 in range [0,1]
         img = einops.rearrange(img, "b h w c -> b c h w")
+        img = img.type(torch.float32)
+        img /= 255
+
         obs[imgkey] = img

     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
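The added asserts pin down the image contract at the environment boundary: envs hand back batched channel-last uint8 frames, and everything downstream now expects channel-first float32 in [0, 1], the same convention `hf_transform_to_torch` enforces for dataset images. A standalone sketch of that conversion on dummy data:

```python
import einops
import numpy as np
import torch

# dummy batch of 2 RGB frames as a gym env would return them: (b, h, w, c), uint8
img = np.random.randint(0, 256, size=(2, 96, 96, 3), dtype=np.uint8)

img = torch.from_numpy(img)
img = einops.rearrange(img, "b h w c -> b c h w")  # channel last -> channel first
img = img.type(torch.float32) / 255                # uint8 [0, 255] -> float32 [0, 1]

assert img.shape == (2, 3, 96, 96)
assert img.min() >= 0 and img.max() <= 1
```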
@@ -1,4 +1,3 @@
-import torch
 from torchvision.transforms.v2 import Compose, Transform


@@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform):
     return item


-class Prod(Transform):
-    invertible = True
-
-    def __init__(self, in_keys: list[str], prod: float):
-        super().__init__()
-        self.in_keys = in_keys
-        self.prod = prod
-        self.original_dtypes = {}
-
-    def forward(self, item):
-        for key in self.in_keys:
-            if key not in item:
-                continue
-            self.original_dtypes[key] = item[key].dtype
-            item[key] = item[key].type(torch.float32) * self.prod
-        return item
-
-    def inverse_transform(self, item):
-        for key in self.in_keys:
-            if key not in item:
-                continue
-            item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
-        return item
-
-    # def transform_observation_spec(self, obs_spec):
-    #     for key in self.in_keys:
-    #         if obs_spec.get(key, None) is None:
-    #             continue
-    #         obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
-    #         obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
-    #         obs_spec[key].dtype = torch.float32
-    #     return obs_spec
-
-
 class NormalizeTransform(Transform):
     invertible = True
@@ -47,6 +47,7 @@ from PIL import Image as PILImage
 from tqdm import trange

 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.utils import hf_transform_to_torch
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.logger import log_output_dir

@@ -208,11 +209,12 @@ def eval_policy(
     max_rewards.extend(batch_max_reward.tolist())
     all_successes.extend(batch_success.tolist())

-    # similar logic is implemented in dataset preprocessing
+    # similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
     ep_dicts = []
+    episode_data_index = {"from": [], "to": []}
     num_episodes = dones.shape[0]
     total_frames = 0
-    idx_from = 0
+    id_from = 0
     for ep_id in range(num_episodes):
         num_frames = done_indices[ep_id].item() + 1
         total_frames += num_frames

@@ -222,19 +224,20 @@ def eval_policy(
         if return_episode_data:
             ep_dict = {
                 "action": actions[ep_id, :num_frames],
-                "episode_id": torch.tensor([ep_id] * num_frames),
-                "frame_id": torch.arange(0, num_frames, 1),
+                "episode_index": torch.tensor([ep_id] * num_frames),
+                "frame_index": torch.arange(0, num_frames, 1),
                 "timestamp": torch.arange(0, num_frames, 1) / fps,
                 "next.done": dones[ep_id, :num_frames],
                 "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
-                "episode_data_index_from": torch.tensor([idx_from] * num_frames),
-                "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
             }
             for key in observations:
                 ep_dict[key] = observations[key][ep_id][:num_frames]
             ep_dicts.append(ep_dict)

-        idx_from += num_frames
+            episode_data_index["from"].append(id_from)
+            episode_data_index["to"].append(id_from + num_frames)
+
+        id_from += num_frames

     # similar logic is implemented in dataset preprocessing
     if return_episode_data:

@@ -247,14 +250,29 @@ def eval_policy(
             if key not in data_dict:
                 data_dict[key] = []
             for ep_dict in ep_dicts:
-                for x in ep_dict[key]:
-                    # c h w -> h w c
-                    img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
+                for img in ep_dict[key]:
+                    # sanity check that images are channel first
+                    c, h, w = img.shape
+                    assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
+
+                    # sanity check that images are float32 in range [0,1]
+                    assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
+                    assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
+                    assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
+
+                    # from float32 in range [0,1] to uint8 in range [0,255]
+                    img *= 255
+                    img = img.type(torch.uint8)
+
+                    # convert to channel last and numpy as expected by PIL
+                    img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
+
                     data_dict[key].append(img)

         data_dict["index"] = torch.arange(0, total_frames, 1)

-        hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
+        hf_dataset = Dataset.from_dict(data_dict)
+        hf_dataset.set_transform(hf_transform_to_torch)

     if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)

@@ -307,7 +325,10 @@ def eval_policy(
         },
     }
     if return_episode_data:
-        info["episodes"] = hf_dataset
+        info["episodes"] = {
+            "hf_dataset": hf_dataset,
+            "episode_data_index": episode_data_index,
+        }
     if max_episodes_rendered > 0:
         info["videos"] = videos
     return info
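`eval_policy` now returns the rollout episodes together with the `episode_data_index` it builds, rather than stamping `episode_data_index_from/to` onto every row. A tiny sketch of consuming that structure, with dummy boundaries standing in for the lists filled in the loop above:

```python
# Stand-in for eval_policy(...)["episodes"]["episode_data_index"]:
# three episodes of 25, 35 and 30 frames.
episode_data_index = {"from": [0, 25, 60], "to": [25, 60, 90]}

ep_id = 1
frame_ids = list(range(episode_data_index["from"][ep_id], episode_data_index["to"][ep_id]))
print(f"episode {ep_id}: {len(frame_ids)} frames, rows {frame_ids[0]}..{frame_ids[-1]}")
```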
@@ -136,6 +136,7 @@ def add_episodes_inplace(
     concat_dataset: torch.utils.data.ConcatDataset,
     sampler: torch.utils.data.WeightedRandomSampler,
     hf_dataset: datasets.Dataset,
+    episode_data_index: dict[str, torch.Tensor],
     pc_online_samples: float,
 ):
     """

@@ -151,13 +152,15 @@ def add_episodes_inplace(
    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
    reflect changes in the dataset sizes and specified sampling weights.
    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
+   - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
+   They indicate the start index and end index of each episode in the dataset.
    - pc_online_samples (float): The target percentage of samples that should come from
    the online dataset during sampling operations.

    Raises:
    - AssertionError: If the first episode_id or index in hf_dataset is not 0
    """
-   first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
+   first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
    first_index = hf_dataset.select_columns("index")[0]["index"].item()
    assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
    assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"

@@ -167,21 +170,22 @@ def add_episodes_inplace(
        online_dataset.hf_dataset = hf_dataset
    else:
        # find episode index and data frame indices according to previous episode in online_dataset
-       start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
+       start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
        start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1

        def shift_indices(example):
-           # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
-           example["episode_id"] += start_episode
+           # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
+           example["episode_index"] += start_episode
            example["index"] += start_index
-           example["episode_data_index_from"] += start_index
-           example["episode_data_index_to"] += start_index
            return example

        disable_progress_bars()  # map has a tqdm progress bar
        hf_dataset = hf_dataset.map(shift_indices)
        enable_progress_bars()

+       episode_data_index["from"] += start_index
+       episode_data_index["to"] += start_index
+
        # extend online dataset
        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])

@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
                seed=cfg.seed,
            )

-           online_pc_sampling = cfg.get("demo_schedule", 0.5)
            add_episodes_inplace(
-               online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
+               online_dataset,
+               concat_dataset,
+               sampler,
+               hf_dataset=eval_info["episodes"]["hf_dataset"],
+               episode_data_index=eval_info["episodes"]["episode_data_index"],
+               pc_online_samples=cfg.get("demo_schedule", 0.5),
            )

            for _ in range(cfg.policy.utd):
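When freshly collected episodes are appended to the online buffer, their row indices are shifted past the existing data and `episode_data_index` receives the same offset. A worked toy example, assuming the buffer already holds 100 rows and the boundaries are stored as tensors so the in-place `+=` above applies:

```python
import torch

start_index = 100  # rows already present in the online dataset (assumption for the example)

# boundaries of two new episodes as produced during evaluation
episode_data_index = {"from": torch.tensor([0, 40]), "to": torch.tensor([40, 90])}

episode_data_index["from"] += start_index
episode_data_index["to"] += start_index

print(episode_data_index["from"].tolist(), episode_data_index["to"].tolist())  # [100, 140] [140, 190]
```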
@@ -22,11 +22,24 @@ def visualize_dataset_cli(cfg: dict):


 def cat_and_write_video(video_path, frames, fps):
-    # Expects images in [0, 255].
     frames = torch.cat(frames)
-    assert frames.dtype == torch.uint8
-    frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
-    imageio.mimsave(video_path, frames, fps=fps)
+
+    # Expects images in [0, 1].
+    frame = frames[0]
+    if frame.ndim == 4:
+        raise NotImplementedError("We currently dont support multiple timestamps.")
+    c, h, w = frame.shape
+    assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
+
+    # sanity check that images are float32 in range [0,1]
+    assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
+    assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
+    assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
+
+    # convert to channel last uint8 [0, 255]
+    frames = einops.rearrange(frames, "b c h w -> b h w c")
+    frames = (frames * 255).type(torch.uint8)
+    imageio.mimsave(video_path, frames.numpy(), fps=fps)


 def visualize_dataset(cfg: dict, out_dir=None):

@@ -44,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
    )

    logging.info("Start rendering episodes from offline buffer")
-   video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
+   video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
    for video_path in video_paths:
        logging.info(video_path)
+   return video_paths


 def render_dataset(dataset, out_dir, max_num_episodes):

@@ -77,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
            # add current frame to list of frames to render
            frames[im_key].append(item[im_key])

-           end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
+           end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1

    out_dir.mkdir(parents=True, exist_ok=True)
    for im_key in dataset.image_keys:
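Because frames now reach the visualizer as float32 channel-first images in [0, 1], `cat_and_write_video` converts them back to channel-last uint8 before handing them to imageio. A minimal standalone sketch of that conversion; the frames and file name are dummies:

```python
import einops
import torch

frames = torch.rand(10, 3, 96, 96)  # (b, c, h, w), float32 in [0, 1]

frames = einops.rearrange(frames, "b c h w -> b h w c")
frames = (frames * 255).type(torch.uint8)
# imageio.mimsave("episode_0.mp4", frames.numpy(), fps=15)  # as in cat_and_write_video above
```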
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.

 [[package]]
 name = "absl-py"

@@ -522,21 +522,21 @@ toml = ["tomli"]

 [[package]]
 name = "datasets"
-version = "2.18.0"
+version = "2.19.0"
 description = "HuggingFace community-driven open-source library of datasets"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
-    {file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
+    {file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
+    {file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
 ]

 [package.dependencies]
 aiohttp = "*"
 dill = ">=0.3.0,<0.3.9"
 filelock = "*"
-fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
-huggingface-hub = ">=0.19.4"
+fsspec = {version = ">=2023.1.0,<=2024.3.1", extras = ["http"]}
+huggingface-hub = ">=0.21.2"
 multiprocess = "*"
 numpy = ">=1.17"
 packaging = "*"

@@ -552,15 +552,15 @@ xxhash = "*"
 apache-beam = ["apache-beam (>=2.26.0)"]
 audio = ["librosa", "soundfile (>=0.12.1)"]
 benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
-dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
-docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
+docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
 jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
 metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
 quality = ["ruff (>=0.3.0)"]
 s3 = ["s3fs"]
-tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
+tensorflow = ["tensorflow (>=2.6.0)"]
-tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
+tensorflow-gpu = ["tensorflow (>=2.6.0)"]
-tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
 torch = ["torch"]
 vision = ["Pillow (>=6.2.1)"]

@@ -2909,7 +2909,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},

@@ -4195,4 +4194,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "01ad4eb04061ec9f785d4574bf66d3e5cb4549e2ea11ab175895f94cb62c1f1c"
+content-hash = "7f5afa48aead953f598e686e767891d3d23f2862b80144f76dc064101ef80b4a"

@@ -53,7 +53,8 @@ pre-commit = {version = "^3.7.0", optional = true}
 debugpy = {version = "^1.8.1", optional = true}
 pytest = {version = "^8.1.0", optional = true}
 pytest-cov = {version = "^5.0.0", optional = true}
-datasets = "^2.18.0"
+datasets = "^2.19.0"

 [tool.poetry.extras]
 pusht = ["gym-pusht"]
Binary file not shown.

@@ -0,0 +1,3 @@
+{
+    "fps": 50
+}

Binary file not shown.
Binary file not shown.

@@ -21,11 +21,11 @@
             "length": 14,
             "_type": "Sequence"
         },
-        "episode_id": {
+        "episode_index": {
             "dtype": "int64",
             "_type": "Value"
         },
-        "frame_id": {
+        "frame_index": {
             "dtype": "int64",
             "_type": "Value"
         },

@@ -37,14 +37,6 @@
             "dtype": "bool",
             "_type": "Value"
         },
-        "episode_data_index_from": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
-        "episode_data_index_to": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
         "index": {
             "dtype": "int64",
             "_type": "Value"

@@ -4,7 +4,7 @@
             "filename": "data-00000-of-00001.arrow"
         }
     ],
-    "_fingerprint": "d79cf82ffc86f110",
+    "_fingerprint": "22eeca7a3f4725ee",
     "_format_columns": null,
     "_format_kwargs": {},
     "_format_type": "torch",
Binary file not shown.

@@ -0,0 +1,3 @@
+{
+    "fps": 50
+}

Binary file not shown.
Binary file not shown.

@@ -21,11 +21,11 @@
             "length": 14,
             "_type": "Sequence"
         },
-        "episode_id": {
+        "episode_index": {
             "dtype": "int64",
             "_type": "Value"
         },
-        "frame_id": {
+        "frame_index": {
             "dtype": "int64",
             "_type": "Value"
         },

@@ -37,14 +37,6 @@
             "dtype": "bool",
             "_type": "Value"
         },
-        "episode_data_index_from": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
-        "episode_data_index_to": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
         "index": {
             "dtype": "int64",
             "_type": "Value"

@@ -4,7 +4,7 @@
             "filename": "data-00000-of-00001.arrow"
         }
     ],
-    "_fingerprint": "d8e4a817b5449498",
+    "_fingerprint": "97c28d4ad1536e4c",
     "_format_columns": null,
     "_format_kwargs": {},
     "_format_type": "torch",
Binary file not shown.

@@ -0,0 +1,3 @@
+{
+    "fps": 50
+}

Binary file not shown.
Binary file not shown.

@@ -21,11 +21,11 @@
             "length": 14,
             "_type": "Sequence"
         },
-        "episode_id": {
+        "episode_index": {
             "dtype": "int64",
             "_type": "Value"
         },
-        "frame_id": {
+        "frame_index": {
             "dtype": "int64",
             "_type": "Value"
         },

@@ -37,14 +37,6 @@
             "dtype": "bool",
             "_type": "Value"
         },
-        "episode_data_index_from": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
-        "episode_data_index_to": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
         "index": {
             "dtype": "int64",
             "_type": "Value"

@@ -4,7 +4,7 @@
             "filename": "data-00000-of-00001.arrow"
         }
     ],
-    "_fingerprint": "f03482befa767127",
+    "_fingerprint": "cb9349b5c92951e8",
     "_format_columns": null,
     "_format_kwargs": {},
     "_format_type": "torch",
Binary file not shown.

@@ -0,0 +1,3 @@
+{
+    "fps": 50
+}

Binary file not shown.
Binary file not shown.

@@ -21,11 +21,11 @@
             "length": 14,
             "_type": "Sequence"
         },
-        "episode_id": {
+        "episode_index": {
             "dtype": "int64",
             "_type": "Value"
         },
-        "frame_id": {
+        "frame_index": {
             "dtype": "int64",
             "_type": "Value"
         },

@@ -37,14 +37,6 @@
             "dtype": "bool",
             "_type": "Value"
         },
-        "episode_data_index_from": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
-        "episode_data_index_to": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
         "index": {
             "dtype": "int64",
             "_type": "Value"

@@ -4,7 +4,7 @@
             "filename": "data-00000-of-00001.arrow"
         }
     ],
-    "_fingerprint": "93e03c6320c7d56e",
+    "_fingerprint": "e4d7ad2b360db1af",
     "_format_columns": null,
     "_format_kwargs": {},
     "_format_type": "torch",
Binary file not shown.
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"fps": 10
|
||||||
|
}
|
Binary file not shown.
Binary file not shown.
@@ -21,11 +21,11 @@
             "length": 2,
             "_type": "Sequence"
         },
-        "episode_id": {
+        "episode_index": {
             "dtype": "int64",
             "_type": "Value"
         },
-        "frame_id": {
+        "frame_index": {
             "dtype": "int64",
             "_type": "Value"
         },
@@ -45,14 +45,6 @@
             "dtype": "bool",
             "_type": "Value"
         },
-        "episode_data_index_from": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
-        "episode_data_index_to": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
         "index": {
             "dtype": "int64",
             "_type": "Value"
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+    "fps": 10
+}
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,7 +4,7 @@
             "filename": "data-00000-of-00001.arrow"
         }
     ],
-    "_fingerprint": "21bb9a76ed78a475",
+    "_fingerprint": "a04a9ce660122e23",
     "_format_columns": null,
     "_format_kwargs": {},
     "_format_type": "torch",
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+    "fps": 15
+}
Binary file not shown.
Binary file not shown.
@@ -21,11 +21,11 @@
             "length": 4,
             "_type": "Sequence"
         },
-        "episode_id": {
+        "episode_index": {
             "dtype": "int64",
             "_type": "Value"
         },
-        "frame_id": {
+        "frame_index": {
             "dtype": "int64",
             "_type": "Value"
         },
@@ -41,14 +41,6 @@
             "dtype": "bool",
             "_type": "Value"
         },
-        "episode_data_index_from": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
-        "episode_data_index_to": {
-            "dtype": "int64",
-            "_type": "Value"
-        },
         "index": {
             "dtype": "int64",
             "_type": "Value"
@@ -4,7 +4,7 @@
             "filename": "data-00000-of-00001.arrow"
         }
     ],
-    "_fingerprint": "a95cbec45e3bb9d6",
+    "_fingerprint": "cc6afdfcdd6f63ab",
     "_format_columns": null,
     "_format_kwargs": {},
     "_format_type": "torch",
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+    "fps": 15
+}
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,51 @@
+{
+    "citation": "",
+    "description": "",
+    "features": {
+        "observation.image": {
+            "_type": "Image"
+        },
+        "observation.state": {
+            "feature": {
+                "dtype": "float32",
+                "_type": "Value"
+            },
+            "length": 4,
+            "_type": "Sequence"
+        },
+        "action": {
+            "feature": {
+                "dtype": "float32",
+                "_type": "Value"
+            },
+            "length": 4,
+            "_type": "Sequence"
+        },
+        "episode_index": {
+            "dtype": "int64",
+            "_type": "Value"
+        },
+        "frame_index": {
+            "dtype": "int64",
+            "_type": "Value"
+        },
+        "timestamp": {
+            "dtype": "float32",
+            "_type": "Value"
+        },
+        "next.reward": {
+            "dtype": "float32",
+            "_type": "Value"
+        },
+        "next.done": {
+            "dtype": "bool",
+            "_type": "Value"
+        },
+        "index": {
+            "dtype": "int64",
+            "_type": "Value"
+        }
+    },
+    "homepage": "",
+    "license": ""
+}
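
For reference, the fixture schema added above maps onto the `datasets` feature types. A sketch of the equivalent
declaration (illustrative only, not part of this diff):

from datasets import Features, Image, Sequence, Value

features = Features(
    {
        "observation.image": Image(),
        "observation.state": Sequence(feature=Value("float32"), length=4),
        "action": Sequence(feature=Value("float32"), length=4),
        "episode_index": Value("int64"),
        "frame_index": Value("int64"),
        "timestamp": Value("float32"),
        "next.reward": Value("float32"),
        "next.done": Value("bool"),
        "index": Value("int64"),
    }
)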
@@ -0,0 +1,13 @@
+{
+    "_data_files": [
+        {
+            "filename": "data-00000-of-00001.arrow"
+        }
+    ],
+    "_fingerprint": "9f8e1a8c1845df55",
+    "_format_columns": null,
+    "_format_kwargs": {},
+    "_format_type": "torch",
+    "_output_all_columns": false,
+    "_split": null
+}
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+    "fps": 15
+}
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,51 @@
+{
+    "citation": "",
+    "description": "",
+    "features": {
+        "observation.image": {
+            "_type": "Image"
+        },
+        "observation.state": {
+            "feature": {
+                "dtype": "float32",
+                "_type": "Value"
+            },
+            "length": 4,
+            "_type": "Sequence"
+        },
+        "action": {
+            "feature": {
+                "dtype": "float32",
+                "_type": "Value"
+            },
+            "length": 3,
+            "_type": "Sequence"
+        },
+        "episode_index": {
+            "dtype": "int64",
+            "_type": "Value"
+        },
+        "frame_index": {
+            "dtype": "int64",
+            "_type": "Value"
+        },
+        "timestamp": {
+            "dtype": "float32",
+            "_type": "Value"
+        },
+        "next.reward": {
+            "dtype": "float32",
+            "_type": "Value"
+        },
+        "next.done": {
+            "dtype": "bool",
+            "_type": "Value"
+        },
+        "index": {
+            "dtype": "int64",
+            "_type": "Value"
+        }
+    },
+    "homepage": "",
+    "license": ""
+}
@@ -0,0 +1,13 @@
+{
+    "_data_files": [
+        {
+            "filename": "data-00000-of-00001.arrow"
+        }
+    ],
+    "_fingerprint": "c900258061dd0b3f",
+    "_format_columns": null,
+    "_format_kwargs": {},
+    "_format_type": "torch",
+    "_output_all_columns": false,
+    "_split": null
+}
Binary file not shown.
@@ -0,0 +1,3 @@
+{
+    "fps": 15
+}
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,51 @@
+{
+    "citation": "",
+    "description": "",
+    "features": {
+        "observation.image": {
+            "_type": "Image"
+        },
+        "observation.state": {
+            "feature": {
+                "dtype": "float32",
+                "_type": "Value"
+            },
+            "length": 4,
+            "_type": "Sequence"
+        },
+        "action": {
+            "feature": {
+                "dtype": "float32",
+                "_type": "Value"
+            },
+            "length": 3,
+            "_type": "Sequence"
+        },
+        "episode_index": {
+            "dtype": "int64",
+            "_type": "Value"
+        },
+        "frame_index": {
+            "dtype": "int64",
+            "_type": "Value"
+        },
+        "timestamp": {
+            "dtype": "float32",
+            "_type": "Value"
+        },
+        "next.reward": {
+            "dtype": "float32",
+            "_type": "Value"
+        },
+        "next.done": {
+            "dtype": "bool",
+            "_type": "Value"
+        },
+        "index": {
+            "dtype": "int64",
+            "_type": "Value"
+        }
+    },
+    "homepage": "",
+    "license": ""
+}
@@ -0,0 +1,13 @@
+{
+    "_data_files": [
+        {
+            "filename": "data-00000-of-00001.arrow"
+        }
+    ],
+    "_fingerprint": "e51c80a33c7688c0",
+    "_format_columns": null,
+    "_format_kwargs": {},
+    "_format_type": "torch",
+    "_output_all_columns": false,
+    "_split": null
+}
@@ -0,0 +1,71 @@
+"""
+This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
+when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frames from the
+dataset into corresponding safetensors files in a specified output directory.
+
+If you know that your change will break backward compatibility, you should write a short-lived test by modifying
+`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test passes. Your custom test
+doesn't need to be merged into the `main` branch. Then you need to run this script and update the test artifacts.
+
+Example usage:
+`python tests/script/save_dataset_to_safetensors.py`
+"""
+
+import shutil
+from pathlib import Path
+
+from safetensors.torch import save_file
+
+from lerobot.common.datasets.pusht import PushtDataset
+
+
+def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
+    data_dir = Path(output_dir) / dataset_id
+
+    if data_dir.exists():
+        shutil.rmtree(data_dir)
+
+    data_dir.mkdir(parents=True, exist_ok=True)
+
+    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
+    dataset = PushtDataset(
+        dataset_id=dataset_id,
+        split="train",
+    )
+
+    # save 2 first frames of first episode
+    i = dataset.episode_data_index["from"][0].item()
+    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
+    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
+
+    # save 2 frames at the middle of first episode
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
+    save_file(dataset[i + 1], data_dir / f"frame_{i+1}.safetensors")
+
+    # save 2 last frames of first episode
+    i = dataset.episode_data_index["to"][0].item()
+    save_file(dataset[i - 2], data_dir / f"frame_{i-2}.safetensors")
+    save_file(dataset[i - 1], data_dir / f"frame_{i-1}.safetensors")
+
+    # TODO(rcadene): Enable testing on second and last episode
+    # We currently can't because our test dataset only contains the first episode
+
+    # # save 2 first frames of second episode
+    # i = dataset.episode_data_index["from"][1].item()
+    # save_file(dataset[i], data_dir / f"frame_{i}.safetensors")
+    # save_file(dataset[i+1], data_dir / f"frame_{i+1}.safetensors")
+
+    # # save 2 last frames of second episode
+    # i = dataset.episode_data_index["to"][1].item()
+    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
+    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
+
+    # # save 2 last frames of last episode
+    # i = dataset.episode_data_index["to"][-1].item()
+    # save_file(dataset[i-2], data_dir / f"frame_{i-2}.safetensors")
+    # save_file(dataset[i-1], data_dir / f"frame_{i-1}.safetensors")
+
+
+if __name__ == "__main__":
+    save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors")
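
The artifacts written above are plain safetensors files: a flat mapping from string keys to tensors. A minimal
round-trip sanity check under that assumption (illustrative, with made-up keys and values):

import torch
from safetensors.torch import load_file, save_file

frame = {"observation.state": torch.zeros(4), "action": torch.zeros(4)}
save_file(frame, "frame_0.safetensors")  # every value must already be a torch.Tensor
assert torch.equal(load_file("frame_0.safetensors")["action"], frame["action"])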
@@ -1,20 +1,26 @@
+import json
 import logging
 import os
+from copy import deepcopy
 from pathlib import Path
 
 import einops
 import pytest
 import torch
 from datasets import Dataset
+from safetensors.torch import load_file
 
 import lerobot
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.common.datasets.utils import (
     compute_stats,
+    flatten_dict,
     get_stats_einops_patterns,
+    hf_transform_to_torch,
     load_previous_and_future_frames,
+    unflatten_dict,
 )
-from lerobot.common.transforms import Prod
 from lerobot.common.utils.utils import init_hydra_config
 
 from .utils import DEFAULT_CONFIG_PATH, DEVICE
@@ -39,8 +45,8 @@ def test_factory(env_name, dataset_id, policy_name):
 
     keys_ndim_required = [
         ("action", 1, True),
-        ("episode_id", 0, True),
-        ("frame_id", 0, True),
+        ("episode_index", 0, True),
+        ("frame_index", 0, True),
         ("timestamp", 0, True),
         # TODO(rcadene): should we rename it agent_pos?
         ("observation.state", 1, True),
@@ -48,12 +54,6 @@ def test_factory(env_name, dataset_id, policy_name):
         ("next.done", 0, False),
     ]
 
-    for key in image_keys:
-        keys_ndim_required.append(
-            (key, 3, True),
-        )
-        assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
-
     # test number of dimensions
     for key, ndim, required in keys_ndim_required:
         if key not in item:
@@ -94,26 +94,21 @@ def test_compute_stats_on_xarm():
     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
     because we are working with a small dataset).
     """
+    # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
     from lerobot.common.datasets.xarm import XarmDataset
 
-    data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-
-    # get transform to convert images from uint8 [0,255] to float32 [0,1]
-    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
-
     dataset = XarmDataset(
         dataset_id="xarm_lift_medium",
-        root=data_dir,
-        transform=transform,
+        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
     )
 
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
-    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
+    computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
 
     # get einops patterns to aggregate batches and compute statistics
-    stats_patterns = get_stats_einops_patterns(dataset)
+    stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
 
     # get all frames from the dataset in the same dtype and range as during compute_stats
     dataloader = torch.utils.data.DataLoader(
@@ -122,18 +117,19 @@ def test_compute_stats_on_xarm():
         batch_size=len(dataset),
         shuffle=False,
     )
-    hf_dataset = next(iter(dataloader))
+    full_batch = next(iter(dataloader))
 
     # compute stats based on all frames from the dataset without any batching
     expected_stats = {}
     for k, pattern in stats_patterns.items():
+        full_batch[k] = full_batch[k].float()
         expected_stats[k] = {}
-        expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
+        expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
         expected_stats[k]["std"] = torch.sqrt(
-            einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
+            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
         )
-        expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
-        expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
+        expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
 
     # test computed stats match expected stats
     for k in stats_patterns:
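
The einops patterns used here reduce a whole batch down to per-channel statistics. A sketch of the idea, with
patterns written out by hand for illustration (the real ones come from get_stats_einops_patterns):

import einops
import torch

batch = {"observation.image": torch.rand(8, 3, 96, 96), "action": torch.rand(8, 4)}
patterns = {"observation.image": "b c h w -> c 1 1", "action": "b c -> c"}
means = {k: einops.reduce(batch[k], pattern, "mean") for k, pattern in patterns.items()}
assert means["observation.image"].shape == (3, 1, 1) and means["action"].shape == (4,)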
@@ -142,11 +138,10 @@ def test_compute_stats_on_xarm():
         assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
         assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
 
-    # TODO(rcadene): check that the stats used for training are correct too
-    # # load stats that are expected to match the ones returned by computed_stats
-    # assert (dataset.data_dir / "stats.pth").exists()
-    # loaded_stats = torch.load(dataset.data_dir / "stats.pth")
+    # load stats used during training which are expected to match the ones returned by computed_stats
+    loaded_stats = dataset.stats  # noqa: F841
 
+    # TODO(rcadene): we can't test this because expected_stats is computed on a subset
     # # test loaded stats match expected stats
     # for k in stats_patterns:
     #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
@@ -160,15 +155,18 @@ def test_load_previous_and_future_frames_within_tolerance():
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_index": [0, 0, 0, 0, 0],
         }
     )
-    hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    hf_dataset.set_transform(hf_transform_to_torch)
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.2, 0, 0.139]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+    item = hf_dataset[2]
+    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
     assert not is_pad.any(), "Unexpected padding detected"
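
A sketch of the matching rule this test exercises (not the library's implementation): each delta is turned into a
query timestamp, matched to the closest frame of the episode, and flagged as padding when the gap exceeds tol.

import torch

def match_timestamps(query_ts, episode_ts, tol):
    # pairwise absolute distances between query timestamps and the episode's timestamps
    dist = torch.abs(query_ts[:, None] - episode_ts[None, :])
    min_dist, argmin = dist.min(dim=1)
    return argmin, min_dist > tol

episode_ts = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
query_ts = episode_ts[2] + torch.tensor([-0.2, 0.0, 0.139])
indices, is_pad = match_timestamps(query_ts, episode_ts, tol=0.04)
# reproduces the expectation asserted above: indices [0, 2, 3] with no padding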
@@ -179,16 +177,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_index": [0, 0, 0, 0, 0],
         }
     )
-    hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    hf_dataset.set_transform(hf_transform_to_torch)
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.2, 0, 0.141]}
     tol = 0.04
+    item = hf_dataset[2]
     with pytest.raises(AssertionError):
-        load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+        load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
 
 
 def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@@ -196,17 +197,102 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
         {
             "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
             "index": [0, 1, 2, 3, 4],
-            "episode_data_index_from": [0, 0, 0, 0, 0],
-            "episode_data_index_to": [5, 5, 5, 5, 5],
+            "episode_index": [0, 0, 0, 0, 0],
         }
     )
-    hf_dataset = hf_dataset.with_format("torch")
-    item = hf_dataset[2]
+    hf_dataset.set_transform(hf_transform_to_torch)
+    episode_data_index = {
+        "from": torch.tensor([0]),
+        "to": torch.tensor([5]),
+    }
     delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
+    item = hf_dataset[2]
+    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
     assert torch.equal(
         is_pad, torch.tensor([True, False, False, True, True])
     ), "Padding does not match expected values"
+
+
+def test_flatten_unflatten_dict():
+    d = {
+        "obs": {
+            "min": 0,
+            "max": 1,
+            "mean": 2,
+            "std": 3,
+        },
+        "action": {
+            "min": 4,
+            "max": 5,
+            "mean": 6,
+            "std": 7,
+        },
+    }
+
+    original_d = deepcopy(d)
+    d = unflatten_dict(flatten_dict(d))
+
+    # test equality between nested dicts
+    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
+
+
+def test_backward_compatibility():
+    """This tests that the artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
+    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
+    dataset_id = "pusht"
+    data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
+
+    dataset = PushtDataset(
+        dataset_id=dataset_id,
+        split="train",
+        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    )
+
+    def load_and_compare(i):
+        new_frame = dataset[i]
+        old_frame = load_file(data_dir / f"frame_{i}.safetensors")
+
+        new_keys = set(new_frame.keys())
+        old_keys = set(old_frame.keys())
+        assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
+
+        for key in new_frame:
+            assert (
+                new_frame[key] == old_frame[key]
+            ).all(), f"{key=} for index={i} does not contain the same value"
+
+    # test 2 first frames of first episode
+    i = dataset.episode_data_index["from"][0].item()
+    load_and_compare(i)
+    load_and_compare(i + 1)
+
+    # test 2 frames at the middle of first episode
+    i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
+    load_and_compare(i)
+    load_and_compare(i + 1)
+
+    # test 2 last frames of first episode
+    i = dataset.episode_data_index["to"][0].item()
+    load_and_compare(i - 2)
+    load_and_compare(i - 1)
+
+    # TODO(rcadene): Enable testing on second and last episode
+    # We currently can't because our test dataset only contains the first episode
+
+    # # test 2 first frames of second episode
+    # i = dataset.episode_data_index["from"][1].item()
+    # load_and_compare(i)
+    # load_and_compare(i+1)
+
+    # # test 2 last frames of second episode
+    # i = dataset.episode_data_index["to"][1].item()
+    # load_and_compare(i-2)
+    # load_and_compare(i-1)
+
+    # # test 2 last frames of last episode
+    # i = dataset.episode_data_index["to"][-1].item()
+    # load_and_compare(i-2)
+    # load_and_compare(i-1)
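
The flatten/unflatten round trip asserted above relies on a reversible key encoding for nested dicts. A sketch of
such a pair, assuming a "/" separator (the repository's actual separator and implementation may differ):

def flatten_dict(d, parent_key="", sep="/"):
    items = {}
    for k, v in d.items():
        key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep))  # recurse into nested dicts
        else:
            items[key] = v
    return items

def unflatten_dict(d, sep="/"):
    out = {}
    for key, value in d.items():
        node = out
        *parents, leaf = key.split(sep)
        for p in parents:
            node = node.setdefault(p, {})  # rebuild intermediate levels
        node[leaf] = value
    return out

assert unflatten_dict(flatten_dict({"obs": {"min": 0}})) == {"obs": {"min": 0}}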
@@ -1,3 +1,4 @@
+# TODO(aliberts): Mute logging for these tests
 import subprocess
 from pathlib import Path
 
@@ -0,0 +1,31 @@
+import pytest
+
+from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.visualize_dataset import visualize_dataset
+
+from .utils import DEFAULT_CONFIG_PATH
+
+
+@pytest.mark.parametrize(
+    "dataset_id",
+    [
+        "aloha_sim_insertion_human",
+    ],
+)
+def test_visualize_dataset(tmpdir, dataset_id):
+    # TODO(rcadene): this test might fail with other datasets/policies/envs, since visualize_dataset
+    # doesn't support multiple timesteps, which requires delta_timestamps to be None for images.
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            "policy=act",
+            "env=aloha",
+            f"dataset_id={dataset_id}",
+        ],
+    )
+    video_paths = visualize_dataset(cfg, out_dir=tmpdir)
+
+    assert len(video_paths) > 0
+
+    for video_path in video_paths:
+        assert video_path.exists()