Fix nightly (#775)

Simon Alibert 2025-02-26 16:36:03 +01:00 committed by GitHub
parent da265ca920
commit 659ec4434d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 145 additions and 172 deletions

View File

@@ -43,7 +43,7 @@ jobs:
     needs: get_changed_files
     runs-on:
       group: aws-general-8-plus
-    if: ${{ needs.get_changed_files.outputs.matrix }} != ''
+    if: needs.get_changed_files.outputs.matrix != ''
    strategy:
       fail-fast: false
       matrix:
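
Note on the `if:` fix: when only part of the condition sits inside `${{ }}`, Actions renders the expression first and then treats the whole line as a literal, non-empty string, which is always truthy, so the job ran even when `matrix` was empty. A rough Python sketch of the two behaviors (illustration only, not from the commit):

```python
# Sketch: why `${{ outputs.matrix }} != ''` never skipped the job.
def old_condition(matrix: str) -> bool:
    rendered = f"{matrix} != ''"  # interpolation happens before evaluation
    return bool(rendered)         # any non-empty string is truthy

def new_condition(matrix: str) -> bool:
    return matrix != ""           # the comparison itself is what gets evaluated

assert old_condition("") is True    # old: job ran even with no changed files
assert new_condition("") is False   # new: job is skipped as intended
```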

View File

@@ -1,33 +1,29 @@
 # Configure image
 ARG PYTHON_VERSION=3.10
 FROM python:${PYTHON_VERSION}-slim
 
+# Configure environment variables
 ARG PYTHON_VERSION
-ARG DEBIAN_FRONTEND=noninteractive
+ENV DEBIAN_FRONTEND=noninteractive
+ENV MUJOCO_GL="egl"
+ENV PATH="/opt/venv/bin:$PATH"
 
-# Install apt dependencies
+# Install dependencies and set up Python in a single layer
 RUN apt-get update && apt-get install -y --no-install-recommends \
-    build-essential cmake git git-lfs \
+    build-essential cmake git \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
     speech-dispatcher libgeos-dev \
-    && apt-get clean && rm -rf /var/lib/apt/lists/*
-
-# Create virtual environment
-RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
-RUN python -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
+    && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
+    && python -m venv /opt/venv \
+    && apt-get clean && rm -rf /var/lib/apt/lists/* \
+    && echo "source /opt/venv/bin/activate" >> /root/.bashrc
 
-# Install LeRobot
-RUN git lfs install
-RUN git clone https://github.com/huggingface/lerobot.git /lerobot
+# Clone repository and install LeRobot in a single layer
+COPY . /lerobot
 WORKDIR /lerobot
-RUN pip install --upgrade --no-cache-dir pip
-RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
+RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
+    && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
     --extra-index-url https://download.pytorch.org/whl/cpu
 
-# Set EGL as the rendering backend for MuJoCo
-ENV MUJOCO_GL="egl"
-
 # Execute in bash shell rather than python
 CMD ["/bin/bash"]
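
Note: the rewrite replaces `git lfs install && git clone …` with `COPY . /lerobot`, presumably so nightly images build from the exact commit under test rather than the default branch; with no clone inside the image, `git-lfs` is no longer needed in the apt layer. Folding the environment variables into `ENV` instructions and the venv setup into a single `RUN` also trims image layers.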

View File

@@ -8,7 +8,7 @@ ENV PATH="/opt/venv/bin:$PATH"
 
 # Install dependencies and set up Python in a single layer
 RUN apt-get update && apt-get install -y --no-install-recommends \
-    build-essential cmake git git-lfs \
+    build-essential cmake git \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
     speech-dispatcher libgeos-dev \
     python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
@@ -18,8 +18,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && echo "source /opt/venv/bin/activate" >> /root/.bashrc
 
 # Clone repository and install LeRobot in a single layer
+COPY . /lerobot
 WORKDIR /lerobot
-RUN git lfs install \
-    && git clone https://github.com/huggingface/lerobot.git . \
-    && /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
+RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
     && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eb7b74f919adf8d4478585f65c54997e6f3bccab67eadb4048300108586a4163
-size 5104
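
The three-line stubs in this and the following files are Git LFS pointers: the repository stores only the artifact's SHA-256 and byte size, and `git lfs` fetches the real file on checkout. A minimal integrity check against such a pointer might look like this (hypothetical helper, not part of the commit):

```python
import hashlib
from pathlib import Path

def matches_pointer(artifact: Path, oid: str, size: int) -> bool:
    """Return True if `artifact` has the SHA-256 and byte size recorded in a
    git-lfs pointer file (the `oid sha256:...` and `size ...` lines)."""
    data = artifact.read_bytes()
    return len(data) == size and hashlib.sha256(data).hexdigest() == oid
```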

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dfbc3b1ad5e3b94311edda0f04db002b26117b0719b73dfdb56dd483dc9c409d
-size 31672

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e39afdf1f3db8a72a1095a5a0ffdb7e67f478a28bd73e59cda197687da8d236c
-size 68

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5dd39a554c9c3db537e98c9ceade024d172c46c4fa7ce9e27601b94116445417
-size 33400

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a5ec46abc5a3c85675a5ee4a1bb362eecb3ff4c546082ff309c89fc7821f38bd
-size 515400

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:50303d05caea725c4a240f1389424d6c2361961f2cee729a0010e909ebffed81
-size 31672

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9bb9b195d32e05550af0edd5df88fcc761c829ab8c4b129ba970a723f39b46ee
-size 68

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:683a2038185f3d070e7d7c0c31e4aa75067c11bf798daa41c9fab336f4183fda
-size 33400

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
+size 5104

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
+size 31672

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
+size 68

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
+size 33400

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
+size 515400

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
+size 31672

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
+size 68

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
+size 33400

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e56a5d30778395534a06ad1742843700424614168fc26d1098558012a5df90c6
-size 5104

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c9007dd51c748db4ecd6d75e70bdcabf8c312454ac97bf6710895a12e7288557
-size 31672

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:170bd8365dfd1e36e8f56814bf8bc2057aa0d035c41212b7ddd7e4b9feee1633
-size 68

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:11884346b41ca102c672bb0f361ea9699d2f8b33bb503038b53cc7e7fafd281b
-size 34920

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0c259ea9c40aab3841ca35b2a2e708d8829b0a9163b2f9e5efd28f1c65848293
-size 4600

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:77cd4127a45ded2f75d85ca9c17537808517614ef16fb3035cebb1b45547acbf
-size 47424

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fcff4b736e95d685d56830b501f4542b081f4334f72d28a7415809f4d9d15d0f
-size 68

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:60775e91ed550aae66cb0547ee4b0e38917f29172e942671e9361b3812364df6
-size 49120

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
+size 992

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
+size 47424

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
+size 68

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
+size 49120

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
+size 200

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
+size 16904

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
+size 164

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
+size 36312

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
+size 200

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
+size 16904

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
+size 164

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
+size 36312

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a
-size 472

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
-size 16904

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
-size 240

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
-size 36312

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1
-size 472

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
-size 16904

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
-size 240

View File

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
-size 36312

View File

@@ -27,16 +27,13 @@ from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
 
 
-def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
-    # TODO(rcadene, aliberts): env_name?
+def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
     set_seed(1337)
     train_cfg = TrainPipelineConfig(
         # TODO(rcadene, aliberts): remove dataset download
         dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
         policy=make_policy_config(policy_name, **policy_kwargs),
         device="cpu",
-        **train_kwargs,
     )
     train_cfg.validate()  # Needed for auto-setting some parameters
@@ -54,8 +51,11 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
     batch = next(iter(dataloader))
     loss, output_dict = policy.forward(batch)
-    output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
-    output_dict["loss"] = loss
+    if output_dict is not None:
+        output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
+        output_dict["loss"] = loss
+    else:
+        output_dict = {"loss": loss}
     loss.backward()
 
     grad_stats = {}
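
The `None` guard matters because not every policy's `forward()` returns an auxiliary output dict; without it, the comprehension raised on `None` and broke artifact generation. The logic boils down to something like this sketch (`collect_outputs` is a hypothetical name, not from the commit):

```python
import torch

def collect_outputs(loss: torch.Tensor, output_dict: dict | None) -> dict:
    # Keep only tensor-valued entries so they can be written to safetensors,
    # and always record the scalar loss, even when forward() returns no dict.
    out = {k: v for k, v in (output_dict or {}).items() if isinstance(v, torch.Tensor)}
    out["loss"] = loss
    return out
```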
@@ -101,30 +101,27 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
     return output_dict, grad_stats, param_stats, actions
 
 
-def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
-    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
-
-    if env_policy_dir.exists():
-        print(f"Overwrite existing safetensors in '{env_policy_dir}':")
-        print(f"  - Validate with: `git add {env_policy_dir}`")
-        print(f"  - Revert with: `git checkout -- {env_policy_dir}`")
-        shutil.rmtree(env_policy_dir)
-
-    env_policy_dir.mkdir(parents=True, exist_ok=True)
-    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
-    save_file(output_dict, env_policy_dir / "output_dict.safetensors")
-    save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
-    save_file(param_stats, env_policy_dir / "param_stats.safetensors")
-    save_file(actions, env_policy_dir / "actions.safetensors")
+def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
+    if output_dir.exists():
+        print(f"Overwrite existing safetensors in '{output_dir}':")
+        print(f"  - Validate with: `git add {output_dir}`")
+        print(f"  - Revert with: `git checkout -- {output_dir}`")
+        shutil.rmtree(output_dir)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
+    save_file(output_dict, output_dir / "output_dict.safetensors")
+    save_file(grad_stats, output_dir / "grad_stats.safetensors")
+    save_file(param_stats, output_dir / "param_stats.safetensors")
+    save_file(actions, output_dir / "actions.safetensors")
 
 
 if __name__ == "__main__":
-    env_policies = [
-        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
-        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
+    artifacts_cfg = [
+        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
+        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
         (
             "lerobot/pusht",
-            "pusht",
             "diffusion",
             {
                 "n_action_steps": 8,
@@ -133,18 +130,17 @@ if __name__ == "__main__":
             },
             "",
         ),
-        ("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
+        ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
         (
             "lerobot/aloha_sim_insertion_human",
-            "aloha",
             "act",
             {"n_action_steps": 1000, "chunk_size": 1000},
-            "_1000_steps",
+            "1000_steps",
         ),
     ]
-    if len(env_policies) == 0:
+    if len(artifacts_cfg) == 0:
         raise RuntimeError("No policies were provided!")
-    for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
-        save_policy_to_safetensors(
-            "tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
-        )
+    for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
+        ds_name = ds_repo_id.split("/")[-1]
+        output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
+        save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
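
With `env_name` gone, artifact directories are keyed on the dataset name instead, which is why the LFS pointer files above were deleted and re-added under new paths. Reproducing the naming logic from the script:

```python
from pathlib import Path

ds_repo_id, policy, extra = "lerobot/xarm_lift_medium", "tdmpc", "use_policy"
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{extra}"
print(output_dir)  # tests/data/save_policy_to_safetensors/xarm_lift_medium_tdmpc_use_policy
```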

View File

@@ -166,7 +166,7 @@ def test_delta_timestamps_within_tolerance():
 
     buffer.tolerance_s = 0.04
     item = buffer[2]
     data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
-    assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
+    torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
     assert not is_pad.any(), "Unexpected padding detected"
@@ -236,7 +236,7 @@ def test_compute_sampler_weights_trivial(
     elif online_sampling_ratio == 1:
         expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
     expected_weights /= expected_weights.sum()
-    assert torch.allclose(weights, expected_weights)
+    torch.testing.assert_close(weights, expected_weights)
 
 
 def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
@@ -248,7 +248,7 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
     weights = compute_sampler_weights(
         offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
     )
-    assert torch.allclose(
+    torch.testing.assert_close(
         weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
     )
@@ -261,7 +261,7 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
     weights = compute_sampler_weights(
         offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
     )
-    assert torch.allclose(
+    torch.testing.assert_close(
         weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
     )
@@ -279,4 +279,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
         online_sampling_ratio=0.5,
         online_drop_n_last_frames=1,
     )
-    assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
+    torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
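
The recurring swap from `assert torch.allclose(...)` to `torch.testing.assert_close(...)` trades a bare boolean assert for an error that reports shapes, dtypes, and the largest absolute/relative mismatch. A minimal before/after (my own illustration; values chosen to pass under default float32 tolerances):

```python
import torch

expected = torch.tensor([0.5, 0.0, 0.125])
actual = expected + 1e-9  # tiny numerical noise

# Old style: on failure you only learn that the assert was False.
assert torch.allclose(actual, expected)

# New style: on failure the error message includes the max abs/rel difference
# and the number of mismatched elements; dtype and shape are checked too.
torch.testing.assert_close(actual, expected)
```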

View File

@@ -363,37 +363,33 @@ def test_normalize(insert_temporal_dim):
 
 @pytest.mark.parametrize(
-    "ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra",
+    "ds_repo_id, policy_name, policy_kwargs, file_name_extra",
     [
         # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
         # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
         # to test with `policy.use_mpc=false`.
-        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, {"batch_size": 25}, "use_policy"),
-        # ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, {}, "use_mpc"),
+        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
+        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
         # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
         # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
         # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
        # Thus, we deactivate this test for now.
-        # (
-        #     "lerobot/pusht",
-        #     "pusht",
-        #     "diffusion",
-        #     {
-        #         "n_action_steps": 8,
-        #         "num_inference_steps": 10,
-        #         "down_dims": [128, 256, 512],
-        #     },
-        #     {"batch_size": 64},
-        #     "",
-        # ),
-        ("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, {}, ""),
+        (
+            "lerobot/pusht",
+            "diffusion",
+            {
+                "n_action_steps": 8,
+                "num_inference_steps": 10,
+                "down_dims": [128, 256, 512],
+            },
+            "",
+        ),
+        ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
         (
             "lerobot/aloha_sim_insertion_human",
-            "aloha",
             "act",
             {"n_action_steps": 1000, "chunk_size": 1000},
-            {},
-            "_1000_steps",
+            "1000_steps",
         ),
     ],
 )
@@ -401,9 +397,7 @@ def test_normalize(insert_temporal_dim):
 # pass if it's run on another platform due to floating point errors
 @require_x86_64_kernel
 @require_cpu
-def test_backward_compatibility(
-    ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra
-):
+def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
     """
     NOTE: If this test does not pass, and you have intentionally changed something in the policy:
     1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@@ -416,26 +410,26 @@ def test_backward_compatibility(
     5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
     6. Remember to stage and commit the resulting changes to `tests/data`.
     """
-    env_policy_dir = (
-        Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
+    ds_name = ds_repo_id.split("/")[-1]
+    artifact_dir = (
+        Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}"
     )
-    saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
-    saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
-    saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
-    saved_actions = load_file(env_policy_dir / "actions.safetensors")
+    saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
+    saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
+    saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
+    saved_actions = load_file(artifact_dir / "actions.safetensors")
 
-    output_dict, grad_stats, param_stats, actions = get_policy_stats(
-        ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs
-    )
+    output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
 
     for key in saved_output_dict:
-        assert torch.allclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7)
+        torch.testing.assert_close(output_dict[key], saved_output_dict[key])
     for key in saved_grad_stats:
-        assert torch.allclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7)
+        torch.testing.assert_close(grad_stats[key], saved_grad_stats[key])
     for key in saved_param_stats:
-        assert torch.allclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7)
+        torch.testing.assert_close(param_stats[key], saved_param_stats[key])
     for key in saved_actions:
-        assert torch.allclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7)
+        rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None)  # HACK
+        torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
 
 
 def test_act_temporal_ensembler():
@@ -490,4 +484,4 @@ def test_act_temporal_ensembler():
     assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
     assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
     # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
-    assert torch.allclose(online_avg, offline_avg, atol=1e-4)
+    torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
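
On the `rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None)` line above: `assert_close` requires `rtol` and `atol` to be passed (or omitted) together, and `None` for both falls back to dtype-based defaults (roughly `rtol=1.3e-6`, `atol=1e-5` for float32). So only the diffusion artifacts get a looser, explicit budget. An illustration with values that pass:

```python
import torch

expected = torch.ones(3)

# rtol=None, atol=None -> dtype defaults (float32: rtol~1.3e-6, atol=1e-5).
torch.testing.assert_close(expected + 5e-7, expected, rtol=None, atol=None)

# Explicit looser budget, as used for the diffusion policy above:
# allowed deviation is atol + rtol * |expected| = 5e-6 + 2e-3 * 1.0 here.
torch.testing.assert_close(expected + 1e-3, expected, rtol=2e-3, atol=5e-6)
```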

View File

@@ -114,7 +114,7 @@ def test_robot(tmp_path, request, robot_type, mock):
         if "image" in name:
             # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
             continue
-        assert torch.allclose(captured_observation[name], observation[name], atol=1)
+        torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
        assert captured_observation[name].shape == observation[name].shape
 
     # Test send_action can run
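
The `atol=1` here presumably tolerates off-by-one differences between two consecutive reads of the same sensor (joint positions are integer-quantized) while still flagging real drift. A sketch with hypothetical readings:

```python
import torch

prev = torch.tensor([512.0, 128.0, 1024.0])  # hypothetical joint readings
curr = torch.tensor([513.0, 128.0, 1023.0])  # next frame, off by one LSB

# Passes: each element is within atol + rtol * |expected| of the previous read.
torch.testing.assert_close(curr, prev, rtol=1e-4, atol=1)
```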