Merge remote-tracking branch 'upstream/main' into finish_examples

This commit is contained in:
Alexander Soare 2024-04-01 11:31:31 +01:00
commit f1148b8c2d
8 changed files with 406 additions and 43 deletions

116
.github/poetry/cpu/poetry.lock generated vendored
View File

@ -327,6 +327,35 @@ files = [
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
] ]
[[package]]
name = "cmake"
version = "3.29.0.1"
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
optional = false
python-versions = ">=3.7"
files = [
{file = "cmake-3.29.0.1-py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:ec8b39fdeb75c48fd5a2894658a1ca75f94fb49b421c1f753d86d3e5d5e9f196"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ead7dc5176a6c6347b3fc19532c25ec328f9279b6213902ac930242334e7b621"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f7c7dabf3dd40cb830d2eded43d51c5a3737625bbc5ab6916041c04e352b74f8"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de8f55198f4a820daf2c57645a4bb8cd1064dc92d950ad95be14c5ffabc15bd4"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bb8aaa3419eafd466931e4dcc161d3e5e6a82730ab508c75946ff4fc883b3f6"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e956e4b6c2d8d6bbee399bf3c77e5b901d916fe8f35d6b2f58444d5892c4602b"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a8f7c8b07e6ab0dd444c5b74e658d5013ca0da456041029f734d751090bb7ec"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22493049b6383ea2baa7237a326c2914ab4a7b3e1642f4233245e3a34aae39f6"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d09573411901fc8a3ede7433713c000ac7b81cda0d771874e9182770acf29eb4"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:c85e35ec572e54152154637f24d6bde316fbcf94dfea644bd8f22b1855a09abe"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:a36f3b19a5caa6c63aa70bfa0b262bae8d296e68c6e6d8e918edc5c51d952bf8"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:85a0e28eaaee311d50fcee60f730e5a44a65b3cefc556a1163bbabd7328acd60"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:58879ef15dd8344e1583a36cead794fb0fee13f78c590a56749283ac2e27d30a"},
{file = "cmake-3.29.0.1-py3-none-win32.whl", hash = "sha256:068a3e7461dd9e487f5f3f720ae8072c2eeb37239ebe4642c4ae29058d83347f"},
{file = "cmake-3.29.0.1-py3-none-win_amd64.whl", hash = "sha256:c097892b3653e2d2d41d055c80a9af029fdd24f9eff2ecb66576e4da8b85b1c7"},
{file = "cmake-3.29.0.1-py3-none-win_arm64.whl", hash = "sha256:1808071047cb49ed0fe2359e4c310b49880c2805cad4ea9f03c959f51b881ac7"},
{file = "cmake-3.29.0.1.tar.gz", hash = "sha256:ec49a7a4480959c229d9d2aa77f7885859c17a45fc66981aaf4551ceffb4d030"},
]
[package.extras]
test = ["coverage (>=4.2)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)"]
[[package]] [[package]]
name = "colorama" name = "colorama"
version = "0.4.6" version = "0.4.6"
@ -338,6 +367,73 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
] ]
[[package]]
name = "coverage"
version = "7.4.4"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
{file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
{file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
{file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
{file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
{file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
{file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
{file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
{file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
{file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
{file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
{file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
{file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
{file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
{file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
{file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
{file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
{file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
{file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
{file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
{file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
{file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli"]
[[package]] [[package]]
name = "debugpy" name = "debugpy"
version = "1.8.1" version = "1.8.1"
@ -2103,6 +2199,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"
@ -3216,4 +3330,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "93c406139c456780b3d309d7ed3d68ea60cc0e8893c1ee717692984e573d3404" content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"

View File

@ -52,12 +52,14 @@ robomimic = "0.2.0"
huggingface-hub = "^0.21.4" huggingface-hub = "^0.21.4"
gymnasium-robotics = "^1.2.4" gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1" gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2" pre-commit = "^3.6.2"
debugpy = "^1.8.1" debugpy = "^1.8.1"
pytest = "^8.1.0" pytest = "^8.1.0"
pytest-cov = "^5.0.0"
[[tool.poetry.source]] [[tool.poetry.source]]

View File

@ -1,4 +1,4 @@
name: Test name: Tests
on: on:
pull_request: pull_request:
@ -10,7 +10,7 @@ on:
- main - main
jobs: jobs:
test: tests:
if: | if: |
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} || ${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
${{ github.event_name == 'push' }} ${{ github.event_name == 'push' }}
@ -19,7 +19,6 @@ jobs:
POETRY_VERSION: 1.8.2 POETRY_VERSION: 1.8.2
DATA_DIR: tests/data DATA_DIR: tests/data
MUJOCO_GL: egl MUJOCO_GL: egl
LEROBOT_TESTS_DEVICE: cpu
steps: steps:
#---------------------------------------------- #----------------------------------------------
# check-out repo and set-up python # check-out repo and set-up python
@ -110,34 +109,126 @@ jobs:
run: poetry install --no-interaction run: poetry install --no-interaction
#---------------------------------------------- #----------------------------------------------
# run tests # run tests & coverage
#---------------------------------------------- #----------------------------------------------
- name: Run tests - name: Run tests
env:
LEROBOT_TESTS_DEVICE: cpu
run: | run: |
source .venv/bin/activate source .venv/bin/activate
pytest tests pytest --cov=./lerobot --cov-report=xml tests
- name: Test train pusht end-to-end # TODO(aliberts): Link with HF Codecov account
# - name: Upload coverage reports to Codecov with GitHub Action
# uses: codecov/codecov-action@v4
# with:
# files: ./coverage.xml
# verbose: true
#----------------------------------------------
# run end-to-end tests
#----------------------------------------------
- name: Test train ACT on ALOHA end-to-end
run: | run: |
source .venv/bin/activate source .venv/bin/activate
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
hydra.job.name=pusht \ policy=act \
env=aloha \
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
device=cpu \
save_model=true \
save_freq=2 \
horizon=20 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/act/
- name: Test eval ACT on ALOHA end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config tests/outputs/act/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
# TODO(aliberts): This takes ~2mn to run, needs to be improved
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
# run: |
# source .venv/bin/activate
# python lerobot/scripts/eval.py \
# --config lerobot/configs/default.yaml \
# policy=act \
# env=aloha \
# eval_episodes=1 \
# device=cpu
- name: Test train Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=diffusion \
env=pusht \ env=pusht \
wandb.enable=False \ wandb.enable=False \
offline_steps=2 \ offline_steps=2 \
online_steps=0 \ online_steps=0 \
device=cpu \ device=cpu \
save_model=true \ save_model=true \
save_freq=1 \ save_freq=2 \
hydra.run.dir=tests/outputs/ hydra.run.dir=tests/outputs/diffusion/
- name: Test eval pusht end-to-end - name: Test eval Diffusion on PushT end-to-end
run: | run: |
source .venv/bin/activate source .venv/bin/activate
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--config tests/outputs/.hydra/config.yaml \ --config tests/outputs/diffusion/.hydra/config.yaml \
wandb.enable=False \
eval_episodes=1 \ eval_episodes=1 \
env.episode_length=8 \ env.episode_length=8 \
device=cpu \ device=cpu \
policy.pretrained_model_path=tests/outputs/models/1.pt policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
- name: Test eval Diffusion on PushT end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=diffusion \
env=pusht \
eval_episodes=1 \
device=cpu
- name: Test train TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=tdmpc \
env=simxarm \
wandb.enable=False \
offline_steps=1 \
online_steps=1 \
device=cpu \
save_model=true \
save_freq=2 \
hydra.run.dir=tests/outputs/tdmpc/
- name: Test eval TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config tests/outputs/tdmpc/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=tdmpc \
env=simxarm \
eval_episodes=1 \
device=cpu

View File

@ -8,9 +8,25 @@
<br/> <br/>
</p> </p>
# LeRobot <div align="center">
[![Tests](https://github.com/huggingface/lerobot/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/test.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/huggingface/lerobot/branch/main/graph/badge.svg?token=TODO)](https://codecov.io/gh/huggingface/lerobot)
[![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE)
[![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/)
[![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/)
[![Examples](https://img.shields.io/badge/Examples-green.svg)](https://github.com/huggingface/lerobot/tree/main/examples)
[![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb)
</div>
<h3 align="center">
<p>State-of-the-art Machine Learning for real-world robotics</p>
</h3>
---
**State-of-the-art machine learning for real-world robotics**
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models. 🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
@ -44,26 +60,21 @@
## Installation ## Installation
Create a virtual environment with Python 3.10, e.g. using `conda`: Download our source code:
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
```
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash ```bash
conda create -y -n lerobot python=3.10 conda create -y -n lerobot python=3.10
conda activate lerobot conda activate lerobot
``` ```
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already) Then, install 🤗 LeRobot:
```bash ```bash
curl -sSL https://install.python-poetry.org | python - python -m pip install .
```
Install dependencies
```bash
poetry install
```
If you encounter a disk space error, try to change your tmp dir to a location where you have enough disk space, e.g.
```bash
mkdir ~/tmp
export TMPDIR='~/tmp'
``` ```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
@ -174,11 +185,11 @@ hydra.run.dir=outputs/train/aloha_act
Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co). Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
**TODO** ### TODO
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46) If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
**Follow our style** ### Follow our style
```bash ```bash
# install if needed # install if needed
@ -187,25 +198,33 @@ pre-commit install
pre-commit pre-commit
``` ```
**Add dependencies** ### Add dependencies
Instead of `pip install some-package`, we use `poetry` to track the versions of our dependencies: Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
Install the project with:
```bash
poetry install
```
Then, the equivalent of `pip install some-package`, would just be:
```bash ```bash
poetry add some-package poetry add some-package
``` ```
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated CI. To do this, you should create a separate environment and add the new package there as well. For example: **NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
```bash ```bash
# add the new package to your main poetry env # Add the new package to your main poetry env
poetry add some-package poetry add some-package
# add the same package to the CPU-only env dedicated to CI # Add the same package to the CPU-only env dedicated to CI
conda create -y -n lerobot-ci python=3.10 conda create -y -n lerobot-ci python=3.10
conda activate lerobot-ci conda activate lerobot-ci
cd .github/poetry/cpu cd .github/poetry/cpu
poetry add some-package poetry add some-package
``` ```
**Run tests locally** ### Run tests locally
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already). Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
@ -236,7 +255,7 @@ Run tests
DATA_DIR="tests/data" pytest -sx tests DATA_DIR="tests/data" pytest -sx tests
``` ```
**Add a new dataset** ### Add a new dataset
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```bash ```bash
@ -297,7 +316,7 @@ Finally, you might want to mock the dataset if you need to update the unit tests
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
``` ```
**Add a pretrained policy** ### Add a pretrained policy
Once you have trained a policy you may upload it to the HuggingFace hub. Once you have trained a policy you may upload it to the HuggingFace hub.
@ -333,7 +352,7 @@ huggingface-cli upload $HUB_ID to_upload
See `eval.py` for an example of how a user may use your policy. See `eval.py` for an example of how a user may use your policy.
**Improve your code with profiling** ### Improve your code with profiling
An example of a code snippet to profile the evaluation of a policy: An example of a code snippet to profile the evaluation of a policy:
```python ```python

View File

@ -206,7 +206,7 @@ class AlohaEnv(AbstractEnv):
if self.from_pixels: if self.from_pixels:
if isinstance(self.image_size, int): if isinstance(self.image_size, int):
image_shape = (3, self.image_size, self.image_size) image_shape = (3, self.image_size, self.image_size)
elif OmegaConf.is_list(self.image_size): elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list):
assert len(self.image_size) == 3 # c h w assert len(self.image_size) == 3 # c h w
assert self.image_size[0] == 3 # c is RGB assert self.image_size[0] == 3 # c is RGB
image_shape = tuple(self.image_size) image_shape = tuple(self.image_size)

116
poetry.lock generated
View File

@ -327,6 +327,35 @@ files = [
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
] ]
[[package]]
name = "cmake"
version = "3.29.0.1"
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
optional = false
python-versions = ">=3.7"
files = [
{file = "cmake-3.29.0.1-py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:ec8b39fdeb75c48fd5a2894658a1ca75f94fb49b421c1f753d86d3e5d5e9f196"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ead7dc5176a6c6347b3fc19532c25ec328f9279b6213902ac930242334e7b621"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f7c7dabf3dd40cb830d2eded43d51c5a3737625bbc5ab6916041c04e352b74f8"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de8f55198f4a820daf2c57645a4bb8cd1064dc92d950ad95be14c5ffabc15bd4"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bb8aaa3419eafd466931e4dcc161d3e5e6a82730ab508c75946ff4fc883b3f6"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e956e4b6c2d8d6bbee399bf3c77e5b901d916fe8f35d6b2f58444d5892c4602b"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a8f7c8b07e6ab0dd444c5b74e658d5013ca0da456041029f734d751090bb7ec"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22493049b6383ea2baa7237a326c2914ab4a7b3e1642f4233245e3a34aae39f6"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d09573411901fc8a3ede7433713c000ac7b81cda0d771874e9182770acf29eb4"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:c85e35ec572e54152154637f24d6bde316fbcf94dfea644bd8f22b1855a09abe"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:a36f3b19a5caa6c63aa70bfa0b262bae8d296e68c6e6d8e918edc5c51d952bf8"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:85a0e28eaaee311d50fcee60f730e5a44a65b3cefc556a1163bbabd7328acd60"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:58879ef15dd8344e1583a36cead794fb0fee13f78c590a56749283ac2e27d30a"},
{file = "cmake-3.29.0.1-py3-none-win32.whl", hash = "sha256:068a3e7461dd9e487f5f3f720ae8072c2eeb37239ebe4642c4ae29058d83347f"},
{file = "cmake-3.29.0.1-py3-none-win_amd64.whl", hash = "sha256:c097892b3653e2d2d41d055c80a9af029fdd24f9eff2ecb66576e4da8b85b1c7"},
{file = "cmake-3.29.0.1-py3-none-win_arm64.whl", hash = "sha256:1808071047cb49ed0fe2359e4c310b49880c2805cad4ea9f03c959f51b881ac7"},
{file = "cmake-3.29.0.1.tar.gz", hash = "sha256:ec49a7a4480959c229d9d2aa77f7885859c17a45fc66981aaf4551ceffb4d030"},
]
[package.extras]
test = ["coverage (>=4.2)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)"]
[[package]] [[package]]
name = "colorama" name = "colorama"
version = "0.4.6" version = "0.4.6"
@ -338,6 +367,73 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
] ]
[[package]]
name = "coverage"
version = "7.4.4"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
{file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
{file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
{file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
{file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
{file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
{file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
{file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
{file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
{file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
{file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
{file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
{file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
{file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
{file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
{file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
{file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
{file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
{file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
{file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
{file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
{file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli"]
[[package]] [[package]]
name = "debugpy" name = "debugpy"
version = "1.8.1" version = "1.8.1"
@ -2307,6 +2403,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"
@ -3475,4 +3589,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "99addbfc02bcd35a308f4ecc5b4285c9c5054118f4aadea27650d8bf355d9616" content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05"

View File

@ -51,12 +51,14 @@ huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
robomimic = "0.2.0" robomimic = "0.2.0"
gymnasium-robotics = "^1.2.4" gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1" gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2" pre-commit = "^3.6.2"
debugpy = "^1.8.1" debugpy = "^1.8.1"
pytest = "^8.1.0" pytest = "^8.1.0"
pytest-cov = "^5.0.0"
[tool.ruff] [tool.ruff]

View File

@ -4,6 +4,7 @@ import torch
from torchrl.envs.utils import check_env_specs, step_mdp from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm.env import SimxarmEnv from lerobot.common.envs.simxarm.env import SimxarmEnv
@ -39,6 +40,26 @@ def print_spec_rollout(env):
print("data from rollout:", simple_rollout(100)) print("data from rollout:", simple_rollout(100))
@pytest.mark.parametrize(
"task,from_pixels,pixels_only",
[
("sim_insertion", True, False),
("sim_insertion", True, True),
("sim_transfer_cube", True, False),
("sim_transfer_cube", True, True),
],
)
def test_aloha(task, from_pixels, pixels_only):
env = AlohaEnv(
task,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=[3, 480, 640] if from_pixels else None,
)
# print_spec_rollout(env)
check_env_specs(env)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"task,from_pixels,pixels_only", "task,from_pixels,pixels_only",
[ [