Fix nightly (#775)

This commit is contained in:
Simon Alibert 2025-02-26 16:36:03 +01:00 committed by GitHub
parent da265ca920
commit 659ec4434d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 145 additions and 172 deletions

View File

@ -43,7 +43,7 @@ jobs:
needs: get_changed_files
runs-on:
group: aws-general-8-plus
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
if: needs.get_changed_files.outputs.matrix != ''
strategy:
fail-fast: false
matrix:

View File

@ -1,33 +1,29 @@
# Configure image
ARG PYTHON_VERSION=3.10
FROM python:${PYTHON_VERSION}-slim
ARG PYTHON_VERSION
ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
# Configure environment variables
ARG PYTHON_VERSION
ENV DEBIAN_FRONTEND=noninteractive
ENV MUJOCO_GL="egl"
ENV PATH="/opt/venv/bin:$PATH"
# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git git-lfs \
build-essential cmake git \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher libgeos-dev \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
&& python -m venv /opt/venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Create virtual environment
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot
RUN git lfs install
RUN git clone https://github.com/huggingface/lerobot.git /lerobot
# Clone repository and install LeRobot in a single layer
COPY . /lerobot
WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
--extra-index-url https://download.pytorch.org/whl/cpu
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
--extra-index-url https://download.pytorch.org/whl/cpu
# Execute in bash shell rather than python
CMD ["/bin/bash"]

View File

@ -8,7 +8,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git git-lfs \
build-essential cmake git \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher libgeos-dev \
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
@ -18,8 +18,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Clone repository and install LeRobot in a single layer
COPY . /lerobot
WORKDIR /lerobot
RUN git lfs install \
&& git clone https://github.com/huggingface/lerobot.git . \
&& /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb7b74f919adf8d4478585f65c54997e6f3bccab67eadb4048300108586a4163
size 5104

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfbc3b1ad5e3b94311edda0f04db002b26117b0719b73dfdb56dd483dc9c409d
size 31672

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e39afdf1f3db8a72a1095a5a0ffdb7e67f478a28bd73e59cda197687da8d236c
size 68

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5dd39a554c9c3db537e98c9ceade024d172c46c4fa7ce9e27601b94116445417
size 33400

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5ec46abc5a3c85675a5ee4a1bb362eecb3ff4c546082ff309c89fc7821f38bd
size 515400

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50303d05caea725c4a240f1389424d6c2361961f2cee729a0010e909ebffed81
size 31672

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9bb9b195d32e05550af0edd5df88fcc761c829ab8c4b129ba970a723f39b46ee
size 68

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:683a2038185f3d070e7d7c0c31e4aa75067c11bf798daa41c9fab336f4183fda
size 33400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
size 5104

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
size 31672

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
size 68

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
size 33400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
size 515400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
size 31672

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
size 68

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
size 33400

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e56a5d30778395534a06ad1742843700424614168fc26d1098558012a5df90c6
size 5104

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9007dd51c748db4ecd6d75e70bdcabf8c312454ac97bf6710895a12e7288557
size 31672

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:170bd8365dfd1e36e8f56814bf8bc2057aa0d035c41212b7ddd7e4b9feee1633
size 68

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11884346b41ca102c672bb0f361ea9699d2f8b33bb503038b53cc7e7fafd281b
size 34920

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c259ea9c40aab3841ca35b2a2e708d8829b0a9163b2f9e5efd28f1c65848293
size 4600

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77cd4127a45ded2f75d85ca9c17537808517614ef16fb3035cebb1b45547acbf
size 47424

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcff4b736e95d685d56830b501f4542b081f4334f72d28a7415809f4d9d15d0f
size 68

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:60775e91ed550aae66cb0547ee4b0e38917f29172e942671e9361b3812364df6
size 49120

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
size 992

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
size 47424

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
size 68

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
size 49120

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
size 200

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
size 200

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a
size 472

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
size 16904

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
size 240

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
size 36312

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1
size 472

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
size 16904

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
size 240

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
size 36312

View File

@ -27,16 +27,13 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
# TODO(rcadene, aliberts): env_name?
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
set_seed(1337)
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
device="cpu",
**train_kwargs,
)
train_cfg.validate() # Needed for auto-setting some parameters
@ -54,8 +51,11 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
batch = next(iter(dataloader))
loss, output_dict = policy.forward(batch)
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
if output_dict is not None:
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
else:
output_dict = {"loss": loss}
loss.backward()
grad_stats = {}
@ -101,30 +101,27 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
if output_dir.exists():
print(f"Overwrite existing safetensors in '{output_dir}':")
print(f" - Validate with: `git add {output_dir}`")
print(f" - Revert with: `git checkout -- {output_dir}`")
shutil.rmtree(output_dir)
if env_policy_dir.exists():
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
print(f" - Validate with: `git add {env_policy_dir}`")
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
shutil.rmtree(env_policy_dir)
env_policy_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
save_file(actions, env_policy_dir / "actions.safetensors")
output_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
save_file(output_dict, output_dir / "output_dict.safetensors")
save_file(grad_stats, output_dir / "grad_stats.safetensors")
save_file(param_stats, output_dir / "param_stats.safetensors")
save_file(actions, output_dir / "actions.safetensors")
if __name__ == "__main__":
env_policies = [
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
artifacts_cfg = [
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
(
"lerobot/pusht",
"pusht",
"diffusion",
{
"n_action_steps": 8,
@ -133,18 +130,17 @@ if __name__ == "__main__":
},
"",
),
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
"act",
{"n_action_steps": 1000, "chunk_size": 1000},
"_1000_steps",
"1000_steps",
),
]
if len(env_policies) == 0:
if len(artifacts_cfg) == 0:
raise RuntimeError("No policies were provided!")
for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
)
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)

View File

@ -166,7 +166,7 @@ def test_delta_timestamps_within_tolerance():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
assert not is_pad.any(), "Unexpected padding detected"
@ -236,7 +236,7 @@ def test_compute_sampler_weights_trivial(
elif online_sampling_ratio == 1:
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
expected_weights /= expected_weights.sum()
assert torch.allclose(weights, expected_weights)
torch.testing.assert_close(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
@ -248,7 +248,7 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
)
assert torch.allclose(
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
)
@ -261,7 +261,7 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
)
assert torch.allclose(
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
)
@ -279,4 +279,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))

View File

@ -363,37 +363,33 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra",
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
[
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, {"batch_size": 25}, "use_policy"),
# ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, {}, "use_mpc"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
# Thus, we deactivate this test for now.
# (
# "lerobot/pusht",
# "pusht",
# "diffusion",
# {
# "n_action_steps": 8,
# "num_inference_steps": 10,
# "down_dims": [128, 256, 512],
# },
# {"batch_size": 64},
# "",
# ),
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, {}, ""),
(
"lerobot/pusht",
"diffusion",
{
"n_action_steps": 8,
"num_inference_steps": 10,
"down_dims": [128, 256, 512],
},
"",
),
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
"act",
{"n_action_steps": 1000, "chunk_size": 1000},
{},
"_1000_steps",
"1000_steps",
),
],
)
@ -401,9 +397,7 @@ def test_normalize(insert_temporal_dim):
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra
):
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@ -416,26 +410,26 @@ def test_backward_compatibility(
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.
"""
env_policy_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
saved_actions = load_file(env_policy_dir / "actions.safetensors")
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
saved_actions = load_file(artifact_dir / "actions.safetensors")
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs
)
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
for key in saved_output_dict:
assert torch.allclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7)
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
for key in saved_grad_stats:
assert torch.allclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7)
torch.testing.assert_close(grad_stats[key], saved_grad_stats[key])
for key in saved_param_stats:
assert torch.allclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7)
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
for key in saved_actions:
assert torch.allclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7)
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
def test_act_temporal_ensembler():
@ -490,4 +484,4 @@ def test_act_temporal_ensembler():
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)

View File

@ -114,7 +114,7 @@ def test_robot(tmp_path, request, robot_type, mock):
if "image" in name:
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue
assert torch.allclose(captured_observation[name], observation[name], atol=1)
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
assert captured_observation[name].shape == observation[name].shape
# Test send_action can run