Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Alexander Soare 2024-05-09 17:01:28 +01:00
commit 001d74961e
17 changed files with 114 additions and 115 deletions

View File

@ -22,9 +22,8 @@ test-end-to-end:
${MAKE} test-act-ete-eval ${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval ${MAKE} test-diffusion-ete-eval
# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc ${MAKE} test-tdmpc-ete-train
# ${MAKE} test-tdmpc-ete-train ${MAKE} test-tdmpc-ete-eval
# ${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval ${MAKE} test-default-ete-eval
test-act-ete-train: test-act-ete-train:
@ -80,7 +79,7 @@ test-tdmpc-ete-train:
policy=tdmpc \ policy=tdmpc \
env=xarm \ env=xarm \
env.task=XarmLift-v0 \ env.task=XarmLift-v0 \
dataset_repo_id=lerobot/xarm_lift_medium_replay \ dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \ wandb.enable=False \
training.offline_steps=2 \ training.offline_steps=2 \
training.online_steps=2 \ training.online_steps=2 \

View File

@ -7,6 +7,11 @@ ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies # Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \ build-essential cmake \
git git-lfs openssh-client \
nano vim \
htop atop nvtop \
sed gawk grep curl wget \
tcpdump sysstat screen \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \ libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \ python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get clean && rm -rf /var/lib/apt/lists/*
@ -18,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot # Install LeRobot
COPY . /lerobot RUN git lfs install
RUN git clone https://github.com/huggingface/lerobot.git
WORKDIR /lerobot WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]" RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"

View File

@ -47,7 +47,7 @@ class TDMPCConfig:
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM. elites, when updating the gaussian parameters for CEM.
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
paramters optimized in CEM. Updates are calculated as μ αμ + (1-α)μ. parameters optimized in CEM. Updates are calculated as μ αμ + (1-α)μ.
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation. is applied. Note that the input images are assumed to be square for this augmentation.

View File

@ -3,6 +3,12 @@
seed: 1000 seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human dataset_repo_id: lerobot/aloha_sim_insertion_human
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training: training:
offline_steps: 80000 offline_steps: 80000
online_steps: 0 online_steps: 0
@ -18,12 +24,6 @@ training:
grad_clip_norm: 10 grad_clip_norm: 10
online_steps_between_rollouts: 1 online_steps_between_rollouts: 1
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
delta_timestamps: delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]" action: "[i / ${fps} for i in range(${policy.chunk_size})]"

View File

@ -7,6 +7,20 @@
seed: 100000 seed: 100000
dataset_repo_id: lerobot/pusht dataset_repo_id: lerobot/pusht
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
training: training:
offline_steps: 200000 offline_steps: 200000
online_steps: 0 online_steps: 0
@ -34,20 +48,6 @@ eval:
n_episodes: 50 n_episodes: 50
batch_size: 50 batch_size: 50
override_dataset_stats:
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
observation.image:
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
# from the original codebase, but we should remove these and train our own pretrained model
observation.state:
min: [13.456424, 32.938293]
max: [496.14618, 510.9579]
action:
min: [12.0, 25.0]
max: [511.0, 511.0]
policy: policy:
name: diffusion name: diffusion

View File

@ -1,7 +1,7 @@
# @package _global_ # @package _global_
seed: 1 seed: 1
dataset_repo_id: lerobot/xarm_lift_medium_replay dataset_repo_id: lerobot/xarm_lift_medium
training: training:
offline_steps: 25000 offline_steps: 25000

110
poetry.lock generated
View File

@ -131,17 +131,6 @@ files = [
{file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"}, {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"},
] ]
[[package]]
name = "appdirs"
version = "1.4.4"
description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
optional = false
python-versions = "*"
files = [
{file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"},
{file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"},
]
[[package]] [[package]]
name = "asciitree" name = "asciitree"
version = "0.3.3" version = "0.3.3"
@ -1108,67 +1097,67 @@ protobuf = ["grpcio-tools (>=1.63.0)"]
[[package]] [[package]]
name = "gym-aloha" name = "gym-aloha"
version = "0.1.0" version = "0.1.1"
description = "A gym environment for ALOHA" description = "A gym environment for ALOHA"
optional = true optional = true
python-versions = "<4.0,>=3.10" python-versions = "<4.0,>=3.10"
files = [ files = [
{file = "gym_aloha-0.1.0-py3-none-any.whl", hash = "sha256:62e36eeb09284422cbb7baca0292c6f65e38ec8774bf9b0bf7159ad5990cf29a"}, {file = "gym_aloha-0.1.1-py3-none-any.whl", hash = "sha256:2698037246dbb106828f0bc229b61007b0a21d5967c72cc373f7bc1083203584"},
{file = "gym_aloha-0.1.0.tar.gz", hash = "sha256:bab332f469ba5ffe655fc3e9647aead05d2cb3b950dfb1f299b9539b3857ad7e"}, {file = "gym_aloha-0.1.1.tar.gz", hash = "sha256:614ae1cf116323e7b5ae2f0e9bd282c4f052aee15e839e5587ddce45995359bc"},
] ]
[package.dependencies] [package.dependencies]
dm-control = "1.0.14" dm-control = ">=1.0.14"
gymnasium = ">=0.29.1,<0.30.0" gymnasium = ">=0.29.1"
imageio = {version = ">=2.34.0,<3.0.0", extras = ["ffmpeg"]} imageio = {version = ">=2.34.0", extras = ["ffmpeg"]}
mujoco = ">=2.3.7,<3.0.0" mujoco = ">=2.3.7,<3.0.0"
[package.extras] [package.extras]
dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"] dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"] test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
[[package]] [[package]]
name = "gym-pusht" name = "gym-pusht"
version = "0.1.1" version = "0.1.3"
description = "A gymnasium environment for PushT." description = "A gymnasium environment for PushT."
optional = true optional = true
python-versions = "<4.0,>=3.10" python-versions = "<4.0,>=3.10"
files = [ files = [
{file = "gym_pusht-0.1.1-py3-none-any.whl", hash = "sha256:dcf8644713db48286e907aabb11e005b0592632e323baa40d1a4f2dfbbc76c3d"}, {file = "gym_pusht-0.1.3-py3-none-any.whl", hash = "sha256:feeb02493a03d1aacc45d43d6397962c50ed779ab7e4019d73af11d2f0b3831b"},
{file = "gym_pusht-0.1.1.tar.gz", hash = "sha256:0d1c9ffd4ad0e2411efcc724003a365a853f20b6d596980c113e7ec181ac021f"}, {file = "gym_pusht-0.1.3.tar.gz", hash = "sha256:c8e9a5256035ba49841ebbc7c32a06c4fa2daa52f5fad80da941b607c4553e28"},
] ]
[package.dependencies] [package.dependencies]
gymnasium = ">=0.29.1,<0.30.0" gymnasium = ">=0.29.1"
opencv-python = ">=4.9.0.80,<5.0.0.0" opencv-python = ">=4.9.0"
pygame = ">=2.5.2,<3.0.0" pygame = ">=2.5.2"
pymunk = ">=6.6.0,<7.0.0" pymunk = ">=6.6.0"
scikit-image = ">=0.22.0" scikit-image = ">=0.22.0"
shapely = ">=2.0.3,<3.0.0" shapely = ">=2.0.3"
[package.extras] [package.extras]
dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"] dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"] test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
[[package]] [[package]]
name = "gym-xarm" name = "gym-xarm"
version = "0.1.0" version = "0.1.1"
description = "A gym environment for xArm" description = "A gym environment for xArm"
optional = true optional = true
python-versions = "<4.0,>=3.10" python-versions = "<4.0,>=3.10"
files = [ files = [
{file = "gym_xarm-0.1.0-py3-none-any.whl", hash = "sha256:d10ac19a59d302201a9b8bd913530211b1058467b787ad91a657907e40cdbc13"}, {file = "gym_xarm-0.1.1-py3-none-any.whl", hash = "sha256:3bd7e3c1c5521ba80a56536f01a5e11321580704d72160355ce47a828a8808ad"},
{file = "gym_xarm-0.1.0.tar.gz", hash = "sha256:fc05f9d02af1f0205275311669dc191ce431be484e221a96401eb544764eb986"}, {file = "gym_xarm-0.1.1.tar.gz", hash = "sha256:e455524561b02d06b92a4f7d524f448d84a7484d9a2dbc78600e3c66240e0fb7"},
] ]
[package.dependencies] [package.dependencies]
gymnasium = ">=0.29.1,<0.30.0" gymnasium = ">=0.29.1"
gymnasium-robotics = ">=1.2.4,<2.0.0" gymnasium-robotics = ">=1.2.4"
mujoco = ">=2.3.7,<3.0.0" mujoco = ">=2.3.7,<3.0.0"
[package.extras] [package.extras]
dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"] dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"] test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
[[package]] [[package]]
name = "gymnasium" name = "gymnasium"
@ -1258,13 +1247,13 @@ numpy = ">=1.17.3"
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.21.4" version = "0.23.0"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
{file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"}, {file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
{file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"}, {file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
] ]
[package.dependencies] [package.dependencies]
@ -1277,15 +1266,16 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3" typing-extensions = ">=3.7.4.3"
[package.extras] [package.extras]
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"] cli = ["InquirerPy (==0.3.4)"]
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
hf-transfer = ["hf-transfer (>=0.1.4)"] hf-transfer = ["hf-transfer (>=0.1.4)"]
inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] inference = ["aiohttp", "minijinja (>=1.0)"]
quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"] tensorflow = ["graphviz", "pydot", "tensorflow"]
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] tensorflow-testing = ["keras (<3.0)", "tensorflow"]
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
torch = ["safetensors", "torch"] torch = ["safetensors", "torch"]
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
@ -2587,7 +2577,7 @@ xmp = ["defusedxml"]
name = "platformdirs" name = "platformdirs"
version = "4.2.1" version = "4.2.1"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
optional = true optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"}, {file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"},
@ -4034,36 +4024,40 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
[[package]] [[package]]
name = "wandb" name = "wandb"
version = "0.16.6" version = "0.17.0"
description = "A CLI and library for interacting with the Weights & Biases API." description = "A CLI and library for interacting with the Weights & Biases API."
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"}, {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"},
{file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"}, {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"},
{file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"},
{file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"},
{file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"},
{file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"},
{file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"},
] ]
[package.dependencies] [package.dependencies]
appdirs = ">=1.4.3" click = ">=7.1,<8.0.0 || >8.0.0"
Click = ">=7.1,<8.0.0 || >8.0.0"
docker-pycreds = ">=0.4.0" docker-pycreds = ">=0.4.0"
GitPython = ">=1.0.0,<3.1.29 || >3.1.29" gitpython = ">=1.0.0,<3.1.29 || >3.1.29"
platformdirs = "*"
protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""} protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""}
psutil = ">=5.0.0" psutil = ">=5.0.0"
PyYAML = "*" pyyaml = "*"
requests = ">=2.0.0,<3" requests = ">=2.0.0,<3"
sentry-sdk = ">=1.0.0" sentry-sdk = ">=1.0.0"
setproctitle = "*" setproctitle = "*"
setuptools = "*" setuptools = "*"
[package.extras] [package.extras]
async = ["httpx (>=0.23.0)"]
aws = ["boto3"] aws = ["boto3"]
azure = ["azure-identity", "azure-storage-blob"] azure = ["azure-identity", "azure-storage-blob"]
gcp = ["google-cloud-storage"] gcp = ["google-cloud-storage"]
importers = ["filelock", "mlflow", "polars", "rich", "tenacity"] importers = ["filelock", "mlflow", "polars", "rich", "tenacity"]
kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"] launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "typing-extensions"]
media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"] media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"]
models = ["cloudpickle"] models = ["cloudpickle"]
perf = ["orjson"] perf = ["orjson"]
@ -4309,13 +4303,13 @@ multidict = ">=4.0"
[[package]] [[package]]
name = "zarr" name = "zarr"
version = "2.17.2" version = "2.18.0"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python" description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false optional = false
python-versions = ">=3.9" python-versions = ">=3.9"
files = [ files = [
{file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"}, {file = "zarr-2.18.0-py3-none-any.whl", hash = "sha256:7f8532b6a3f50f22e809e130e09353637ec8b5bb5e95a5a0bfaae91f63978b5d"},
{file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"}, {file = "zarr-2.18.0.tar.gz", hash = "sha256:c3b7d2c85b8a42b0ad0ad268a36fb6886ca852098358c125c6b126a417e0a598"},
] ]
[package.dependencies] [package.dependencies]
@ -4354,4 +4348,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "008a6af5ad9d9eafbd933c922c2c5d84fddae85aff8a9eefc0538b1319966f6e" content-hash = "21dd1d7404ac774bd1139e8cda44ea8e3ed97c30e524f2ed862de431d3d5fa87"

View File

@ -28,37 +28,37 @@ packages = [{include = "lerobot"}]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.10,<3.13" python = ">=3.10,<3.13"
termcolor = "^2.4.0" termcolor = ">=2.4.0"
omegaconf = "^2.3.0" omegaconf = ">=2.3.0"
wandb = "^0.16.3" wandb = ">=0.16.3"
imageio = {extras = ["ffmpeg"], version = "^2.34.0"} imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
gdown = "^5.1.0" gdown = ">=5.1.0"
hydra-core = "^1.3.2" hydra-core = ">=1.3.2"
einops = "^0.8.0" einops = ">=0.8.0"
pymunk = "^6.6.0" pymunk = ">=6.6.0"
zarr = "^2.17.0" zarr = ">=2.17.0"
numba = "^0.59.0" numba = ">=0.59.0"
torch = "^2.2.1" torch = "^2.2.1"
opencv-python = "^4.9.0.80" opencv-python = ">=4.9.0"
diffusers = "^0.27.2" diffusers = "^0.27.2"
torchvision = "^0.18.0" torchvision = ">=0.18.0"
h5py = "^3.10.0" h5py = ">=3.10.0"
huggingface-hub = "^0.21.4" huggingface-hub = ">=0.21.4"
robomimic = "0.2.0" robomimic = "0.2.0"
gymnasium = "^0.29.1" gymnasium = ">=0.29.1"
cmake = "^3.29.0.1" cmake = ">=3.29.0.1"
gym-pusht = { version = "^0.1.1", optional = true} gym-pusht = { version = ">=0.1.3", optional = true}
gym-xarm = { version = "^0.1.0", optional = true} gym-xarm = { version = ">=0.1.1", optional = true}
gym-aloha = { version = "^0.1.0", optional = true} gym-aloha = { version = ">=0.1.1", optional = true}
pre-commit = {version = "^3.7.0", optional = true} pre-commit = {version = ">=3.7.0", optional = true}
debugpy = {version = "^1.8.1", optional = true} debugpy = {version = ">=1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true} pytest = {version = ">=8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true} pytest-cov = {version = ">=5.0.0", optional = true}
datasets = "^2.19.0" datasets = ">=2.19.0"
imagecodecs = { version = "^2024.1.1", optional = true } imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = "^12.0.5" pyav = ">=12.0.5"
moviepy = "^1.0.3" moviepy = ">=1.0.3"
rerun-sdk = "^0.15.1" rerun-sdk = ">=0.15.1"
[tool.poetry.extras] [tool.poetry.extras]
@ -104,5 +104,5 @@ ignore-init-module-imports = true
[build-system] [build-system]
requires = ["poetry-core>=1.5.0"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@ -236,7 +236,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name, policy_name, extra_overrides", "env_name, policy_name, extra_overrides",
[ [
# ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), ("xarm", "tdmpc", []),
( (
"pusht", "pusht",
"diffusion", "diffusion",