From 659ec4434d4aa9cf633477be3b1cbf1dd6f4c2bf Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Wed, 26 Feb 2025 16:36:03 +0100 Subject: [PATCH] Fix nightly (#775) --- .github/workflows/test-docker-build.yml | 2 +- docker/lerobot-cpu/Dockerfile | 38 +++++------ docker/lerobot-gpu/Dockerfile | 7 +- .../aloha_act/actions.safetensors | 3 - .../aloha_act/grad_stats.safetensors | 3 - .../aloha_act/output_dict.safetensors | 3 - .../aloha_act/param_stats.safetensors | 3 - .../aloha_act_1000_steps/actions.safetensors | 3 - .../grad_stats.safetensors | 3 - .../output_dict.safetensors | 3 - .../param_stats.safetensors | 3 - .../actions.safetensors | 3 + .../grad_stats.safetensors | 3 + .../output_dict.safetensors | 3 + .../param_stats.safetensors | 3 + .../actions.safetensors | 3 + .../grad_stats.safetensors | 3 + .../output_dict.safetensors | 3 + .../param_stats.safetensors | 3 + .../actions.safetensors | 3 - .../grad_stats.safetensors | 3 - .../output_dict.safetensors | 3 - .../param_stats.safetensors | 3 - .../pusht_diffusion/actions.safetensors | 3 - .../pusht_diffusion/grad_stats.safetensors | 3 - .../pusht_diffusion/output_dict.safetensors | 3 - .../pusht_diffusion/param_stats.safetensors | 3 - .../pusht_diffusion_/actions.safetensors | 3 + .../pusht_diffusion_/grad_stats.safetensors | 3 + .../pusht_diffusion_/output_dict.safetensors | 3 + .../pusht_diffusion_/param_stats.safetensors | 3 + .../actions.safetensors | 3 + .../grad_stats.safetensors | 3 + .../output_dict.safetensors | 3 + .../param_stats.safetensors | 3 + .../actions.safetensors | 3 + .../grad_stats.safetensors | 3 + .../output_dict.safetensors | 3 + .../param_stats.safetensors | 3 + .../xarm_tdmpcuse_mpc/actions.safetensors | 3 - .../xarm_tdmpcuse_mpc/grad_stats.safetensors | 3 - .../xarm_tdmpcuse_mpc/output_dict.safetensors | 3 - .../xarm_tdmpcuse_mpc/param_stats.safetensors | 3 - .../xarm_tdmpcuse_policy/actions.safetensors | 3 - .../grad_stats.safetensors | 3 - .../output_dict.safetensors | 3 - .../param_stats.safetensors | 3 - tests/scripts/save_policy_to_safetensors.py | 60 ++++++++--------- tests/test_online_buffer.py | 10 +-- tests/test_policies.py | 66 +++++++++---------- tests/test_robots.py | 2 +- 51 files changed, 145 insertions(+), 172 deletions(-) delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/param_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/actions.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/grad_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/output_dict.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/param_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion_/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion_/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion_/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/pusht_diffusion_/param_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/actions.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors create mode 100644 tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/actions.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/grad_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/output_dict.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/param_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/actions.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/grad_stats.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/output_dict.safetensors delete mode 100644 tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/param_stats.safetensors diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml index 4d6e9ce5..3ee84a27 100644 --- a/.github/workflows/test-docker-build.yml +++ b/.github/workflows/test-docker-build.yml @@ -43,7 +43,7 @@ jobs: needs: get_changed_files runs-on: group: aws-general-8-plus - if: ${{ needs.get_changed_files.outputs.matrix }} != '' + if: needs.get_changed_files.outputs.matrix != '' strategy: fail-fast: false matrix: diff --git a/docker/lerobot-cpu/Dockerfile b/docker/lerobot-cpu/Dockerfile index 06673092..13a45d24 100644 --- a/docker/lerobot-cpu/Dockerfile +++ b/docker/lerobot-cpu/Dockerfile @@ -1,33 +1,29 @@ # Configure image ARG PYTHON_VERSION=3.10 - FROM python:${PYTHON_VERSION}-slim -ARG PYTHON_VERSION -ARG DEBIAN_FRONTEND=noninteractive -# Install apt dependencies +# Configure environment variables +ARG PYTHON_VERSION +ENV DEBIAN_FRONTEND=noninteractive +ENV MUJOCO_GL="egl" +ENV PATH="/opt/venv/bin:$PATH" + +# Install dependencies and set up Python in a single layer RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential cmake git git-lfs \ + build-essential cmake git \ libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ speech-dispatcher libgeos-dev \ - && apt-get clean && rm -rf /var/lib/apt/lists/* + && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \ + && python -m venv /opt/venv \ + && apt-get clean && rm -rf /var/lib/apt/lists/* \ + && echo "source /opt/venv/bin/activate" >> /root/.bashrc -# Create virtual environment -RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python -RUN python -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" -RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc - -# Install LeRobot -RUN git lfs install -RUN git clone https://github.com/huggingface/lerobot.git /lerobot +# Clone repository and install LeRobot in a single layer +COPY . /lerobot WORKDIR /lerobot -RUN pip install --upgrade --no-cache-dir pip -RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \ - --extra-index-url https://download.pytorch.org/whl/cpu - -# Set EGL as the rendering backend for MuJoCo -ENV MUJOCO_GL="egl" +RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \ + && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \ + --extra-index-url https://download.pytorch.org/whl/cpu # Execute in bash shell rather than python CMD ["/bin/bash"] diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile index b2898b97..642a8ded 100644 --- a/docker/lerobot-gpu/Dockerfile +++ b/docker/lerobot-gpu/Dockerfile @@ -8,7 +8,7 @@ ENV PATH="/opt/venv/bin:$PATH" # Install dependencies and set up Python in a single layer RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential cmake git git-lfs \ + build-essential cmake git \ libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ speech-dispatcher libgeos-dev \ python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ @@ -18,8 +18,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && echo "source /opt/venv/bin/activate" >> /root/.bashrc # Clone repository and install LeRobot in a single layer +COPY . /lerobot WORKDIR /lerobot -RUN git lfs install \ - && git clone https://github.com/huggingface/lerobot.git . \ - && /opt/venv/bin/pip install --upgrade --no-cache-dir pip \ +RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \ && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors deleted file mode 100644 index 2dd4dda3..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:eb7b74f919adf8d4478585f65c54997e6f3bccab67eadb4048300108586a4163 -size 5104 diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors deleted file mode 100644 index cd966518..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dfbc3b1ad5e3b94311edda0f04db002b26117b0719b73dfdb56dd483dc9c409d -size 31672 diff --git a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors deleted file mode 100644 index e957acb8..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e39afdf1f3db8a72a1095a5a0ffdb7e67f478a28bd73e59cda197687da8d236c -size 68 diff --git a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors deleted file mode 100644 index 35ba61bd..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5dd39a554c9c3db537e98c9ceade024d172c46c4fa7ce9e27601b94116445417 -size 33400 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors deleted file mode 100644 index ababdedf..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a5ec46abc5a3c85675a5ee4a1bb362eecb3ff4c546082ff309c89fc7821f38bd -size 515400 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors deleted file mode 100644 index e0b2f54a..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:50303d05caea725c4a240f1389424d6c2361961f2cee729a0010e909ebffed81 -size 31672 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors deleted file mode 100644 index 3c5d3b93..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9bb9b195d32e05550af0edd5df88fcc761c829ab8c4b129ba970a723f39b46ee -size 68 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors deleted file mode 100644 index 88d3106e..00000000 --- a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:683a2038185f3d070e7d7c0c31e4aa75067c11bf798daa41c9fab336f4183fda -size 33400 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/actions.safetensors new file mode 100644 index 00000000..6fec6b22 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179 +size 5104 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/grad_stats.safetensors new file mode 100644 index 00000000..7136a69f --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6 +size 31672 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/output_dict.safetensors new file mode 100644 index 00000000..864feebe --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb +size 68 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/param_stats.safetensors new file mode 100644 index 00000000..bbabade6 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1 +size 33400 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/actions.safetensors new file mode 100644 index 00000000..1093b45d --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a +size 515400 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors new file mode 100644 index 00000000..092e0040 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e +size 31672 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors new file mode 100644 index 00000000..6561116c --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d +size 68 diff --git a/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors new file mode 100644 index 00000000..09772ea3 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842 +size 33400 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/actions.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/actions.safetensors deleted file mode 100644 index 40434950..00000000 --- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e56a5d30778395534a06ad1742843700424614168fc26d1098558012a5df90c6 -size 5104 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/grad_stats.safetensors deleted file mode 100644 index a8c15716..00000000 --- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c9007dd51c748db4ecd6d75e70bdcabf8c312454ac97bf6710895a12e7288557 -size 31672 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/output_dict.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/output_dict.safetensors deleted file mode 100644 index 95c598c7..00000000 --- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:170bd8365dfd1e36e8f56814bf8bc2057aa0d035c41212b7ddd7e4b9feee1633 -size 68 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/param_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/param_stats.safetensors deleted file mode 100644 index 09a11d73..00000000 --- a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_aloha_real/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:11884346b41ca102c672bb0f361ea9699d2f8b33bb503038b53cc7e7fafd281b -size 34920 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors deleted file mode 100644 index b021f63c..00000000 --- a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0c259ea9c40aab3841ca35b2a2e708d8829b0a9163b2f9e5efd28f1c65848293 -size 4600 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors deleted file mode 100644 index ad0300ca..00000000 --- a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:77cd4127a45ded2f75d85ca9c17537808517614ef16fb3035cebb1b45547acbf -size 47424 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors deleted file mode 100644 index 9c7143e5..00000000 --- a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fcff4b736e95d685d56830b501f4542b081f4334f72d28a7415809f4d9d15d0f -size 68 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors deleted file mode 100644 index 1efb0765..00000000 --- a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:60775e91ed550aae66cb0547ee4b0e38917f29172e942671e9361b3812364df6 -size 49120 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion_/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion_/actions.safetensors new file mode 100644 index 00000000..84e14b97 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/pusht_diffusion_/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c +size 992 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion_/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion_/grad_stats.safetensors new file mode 100644 index 00000000..54229791 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/pusht_diffusion_/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201 +size 47424 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion_/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion_/output_dict.safetensors new file mode 100644 index 00000000..f2930399 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/pusht_diffusion_/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5 +size 68 diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion_/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion_/param_stats.safetensors new file mode 100644 index 00000000..e91cd08b --- /dev/null +++ b/tests/data/save_policy_to_safetensors/pusht_diffusion_/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22 +size 49120 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors new file mode 100644 index 00000000..fa9bf06a --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b +size 200 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors new file mode 100644 index 00000000..8d90a671 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +size 16904 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors new file mode 100644 index 00000000..cde6c6dc --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +size 164 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors new file mode 100644 index 00000000..692377d1 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 +size 36312 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/actions.safetensors new file mode 100644 index 00000000..7a0b165e --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb +size 200 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors new file mode 100644 index 00000000..8d90a671 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +size 16904 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors new file mode 100644 index 00000000..cde6c6dc --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +size 164 diff --git a/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors new file mode 100644 index 00000000..692377d1 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 +size 36312 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/actions.safetensors deleted file mode 100644 index e2fb68ac..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a -size 472 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/grad_stats.safetensors deleted file mode 100644 index cf756229..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0 -size 16904 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/output_dict.safetensors deleted file mode 100644 index f8863cfb..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee -size 240 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/param_stats.safetensors deleted file mode 100644 index 8ce3c4f3..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6 -size 36312 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/actions.safetensors deleted file mode 100644 index 1b3912ed..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1 -size 472 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/grad_stats.safetensors deleted file mode 100644 index cf756229..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0 -size 16904 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/output_dict.safetensors deleted file mode 100644 index f8863cfb..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee -size 240 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/param_stats.safetensors deleted file mode 100644 index 8ce3c4f3..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6 -size 36312 diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index de784db3..03726163 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -27,16 +27,13 @@ from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs): - # TODO(rcadene, aliberts): env_name? +def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): set_seed(1337) - train_cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), policy=make_policy_config(policy_name, **policy_kwargs), device="cpu", - **train_kwargs, ) train_cfg.validate() # Needed for auto-setting some parameters @@ -54,8 +51,11 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa batch = next(iter(dataloader)) loss, output_dict = policy.forward(batch) - output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} - output_dict["loss"] = loss + if output_dict is not None: + output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} + output_dict["loss"] = loss + else: + output_dict = {"loss": loss} loss.backward() grad_stats = {} @@ -101,30 +101,27 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa return output_dict, grad_stats, param_stats, actions -def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra): - env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}" +def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict): + if output_dir.exists(): + print(f"Overwrite existing safetensors in '{output_dir}':") + print(f" - Validate with: `git add {output_dir}`") + print(f" - Revert with: `git checkout -- {output_dir}`") + shutil.rmtree(output_dir) - if env_policy_dir.exists(): - print(f"Overwrite existing safetensors in '{env_policy_dir}':") - print(f" - Validate with: `git add {env_policy_dir}`") - print(f" - Revert with: `git checkout -- {env_policy_dir}`") - shutil.rmtree(env_policy_dir) - - env_policy_dir.mkdir(parents=True, exist_ok=True) - output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs) - save_file(output_dict, env_policy_dir / "output_dict.safetensors") - save_file(grad_stats, env_policy_dir / "grad_stats.safetensors") - save_file(param_stats, env_policy_dir / "param_stats.safetensors") - save_file(actions, env_policy_dir / "actions.safetensors") + output_dir.mkdir(parents=True, exist_ok=True) + output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) + save_file(output_dict, output_dir / "output_dict.safetensors") + save_file(grad_stats, output_dir / "grad_stats.safetensors") + save_file(param_stats, output_dir / "param_stats.safetensors") + save_file(actions, output_dir / "actions.safetensors") if __name__ == "__main__": - env_policies = [ - ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"), - ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"), + artifacts_cfg = [ + ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"), + ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"), ( "lerobot/pusht", - "pusht", "diffusion", { "n_action_steps": 8, @@ -133,18 +130,17 @@ if __name__ == "__main__": }, "", ), - ("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""), + ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""), ( "lerobot/aloha_sim_insertion_human", - "aloha", "act", {"n_action_steps": 1000, "chunk_size": 1000}, - "_1000_steps", + "1000_steps", ), ] - if len(env_policies) == 0: + if len(artifacts_cfg) == 0: raise RuntimeError("No policies were provided!") - for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies: - save_policy_to_safetensors( - "tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra - ) + for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg: + ds_name = ds_repo_id.split("/")[-1] + output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}" + save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs) diff --git a/tests/test_online_buffer.py b/tests/test_online_buffer.py index db53808d..339f6848 100644 --- a/tests/test_online_buffer.py +++ b/tests/test_online_buffer.py @@ -166,7 +166,7 @@ def test_delta_timestamps_within_tolerance(): buffer.tolerance_s = 0.04 item = buffer[2] data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"] - assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values" + torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values") assert not is_pad.any(), "Unexpected padding detected" @@ -236,7 +236,7 @@ def test_compute_sampler_weights_trivial( elif online_sampling_ratio == 1: expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) expected_weights /= expected_weights.sum() - assert torch.allclose(weights, expected_weights) + torch.testing.assert_close(weights, expected_weights) def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): @@ -248,7 +248,7 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p weights = compute_sampler_weights( offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio ) - assert torch.allclose( + torch.testing.assert_close( weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) ) @@ -261,7 +261,7 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase weights = compute_sampler_weights( offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1 ) - assert torch.allclose( + torch.testing.assert_close( weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]) ) @@ -279,4 +279,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp online_sampling_ratio=0.5, online_drop_n_last_frames=1, ) - assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) + torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) diff --git a/tests/test_policies.py b/tests/test_policies.py index 27cf49f8..9dab6176 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -363,37 +363,33 @@ def test_normalize(insert_temporal_dim): @pytest.mark.parametrize( - "ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra", + "ds_repo_id, policy_name, policy_kwargs, file_name_extra", [ # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override # to test with `policy.use_mpc=false`. - ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, {"batch_size": 25}, "use_policy"), - # ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, {}, "use_mpc"), + ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"), + ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"), # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass. # Thus, we deactivate this test for now. - # ( - # "lerobot/pusht", - # "pusht", - # "diffusion", - # { - # "n_action_steps": 8, - # "num_inference_steps": 10, - # "down_dims": [128, 256, 512], - # }, - # {"batch_size": 64}, - # "", - # ), - ("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, {}, ""), + ( + "lerobot/pusht", + "diffusion", + { + "n_action_steps": 8, + "num_inference_steps": 10, + "down_dims": [128, 256, 512], + }, + "", + ), + ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""), ( "lerobot/aloha_sim_insertion_human", - "aloha", "act", {"n_action_steps": 1000, "chunk_size": 1000}, - {}, - "_1000_steps", + "1000_steps", ), ], ) @@ -401,9 +397,7 @@ def test_normalize(insert_temporal_dim): # pass if it's run on another platform due to floating point errors @require_x86_64_kernel @require_cpu -def test_backward_compatibility( - ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra -): +def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str): """ NOTE: If this test does not pass, and you have intentionally changed something in the policy: 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should @@ -416,26 +410,26 @@ def test_backward_compatibility( 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 6. Remember to stage and commit the resulting changes to `tests/data`. """ - env_policy_dir = ( - Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}" + ds_name = ds_repo_id.split("/")[-1] + artifact_dir = ( + Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}" ) - saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors") - saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors") - saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors") - saved_actions = load_file(env_policy_dir / "actions.safetensors") + saved_output_dict = load_file(artifact_dir / "output_dict.safetensors") + saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors") + saved_param_stats = load_file(artifact_dir / "param_stats.safetensors") + saved_actions = load_file(artifact_dir / "actions.safetensors") - output_dict, grad_stats, param_stats, actions = get_policy_stats( - ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs - ) + output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) for key in saved_output_dict: - assert torch.allclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7) + torch.testing.assert_close(output_dict[key], saved_output_dict[key]) for key in saved_grad_stats: - assert torch.allclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7) + torch.testing.assert_close(grad_stats[key], saved_grad_stats[key]) for key in saved_param_stats: - assert torch.allclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7) + torch.testing.assert_close(param_stats[key], saved_param_stats[key]) for key in saved_actions: - assert torch.allclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7) + rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK + torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol) def test_act_temporal_ensembler(): @@ -490,4 +484,4 @@ def test_act_temporal_ensembler(): assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg) assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max")) # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error. - assert torch.allclose(online_avg, offline_avg, atol=1e-4) + torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4) diff --git a/tests/test_robots.py b/tests/test_robots.py index fe440da8..c5734a4c 100644 --- a/tests/test_robots.py +++ b/tests/test_robots.py @@ -114,7 +114,7 @@ def test_robot(tmp_path, request, robot_type, mock): if "image" in name: # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames continue - assert torch.allclose(captured_observation[name], observation[name], atol=1) + torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1) assert captured_observation[name].shape == observation[name].shape # Test send_action can run