Fix nightly (#775)
This commit is contained in:
parent
da265ca920
commit
659ec4434d
|
@ -43,7 +43,7 @@ jobs:
|
|||
needs: get_changed_files
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
if: needs.get_changed_files.outputs.matrix != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
|
|
@ -1,33 +1,29 @@
|
|||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
ARG PYTHON_VERSION
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install apt dependencies
|
||||
# Configure environment variables
|
||||
ARG PYTHON_VERSION
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV MUJOCO_GL="egl"
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git git-lfs \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
|
||||
&& python -m venv /opt/venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
RUN git lfs install
|
||||
RUN git clone https://github.com/huggingface/lerobot.git /lerobot
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Execute in bash shell rather than python
|
||||
CMD ["/bin/bash"]
|
||||
|
|
|
@ -8,7 +8,7 @@ ENV PATH="/opt/venv/bin:$PATH"
|
|||
|
||||
# Install dependencies and set up Python in a single layer
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake git git-lfs \
|
||||
build-essential cmake git \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher libgeos-dev \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
|
@ -18,8 +18,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Clone repository and install LeRobot in a single layer
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN git lfs install \
|
||||
&& git clone https://github.com/huggingface/lerobot.git . \
|
||||
&& /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:eb7b74f919adf8d4478585f65c54997e6f3bccab67eadb4048300108586a4163
|
||||
size 5104
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dfbc3b1ad5e3b94311edda0f04db002b26117b0719b73dfdb56dd483dc9c409d
|
||||
size 31672
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e39afdf1f3db8a72a1095a5a0ffdb7e67f478a28bd73e59cda197687da8d236c
|
||||
size 68
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5dd39a554c9c3db537e98c9ceade024d172c46c4fa7ce9e27601b94116445417
|
||||
size 33400
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a5ec46abc5a3c85675a5ee4a1bb362eecb3ff4c546082ff309c89fc7821f38bd
|
||||
size 515400
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:50303d05caea725c4a240f1389424d6c2361961f2cee729a0010e909ebffed81
|
||||
size 31672
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9bb9b195d32e05550af0edd5df88fcc761c829ab8c4b129ba970a723f39b46ee
|
||||
size 68
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:683a2038185f3d070e7d7c0c31e4aa75067c11bf798daa41c9fab336f4183fda
|
||||
size 33400
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
|
||||
size 5104
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
|
||||
size 31672
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
|
||||
size 68
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
|
||||
size 33400
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
|
||||
size 515400
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
|
||||
size 31672
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
|
||||
size 68
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
|
||||
size 33400
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e56a5d30778395534a06ad1742843700424614168fc26d1098558012a5df90c6
|
||||
size 5104
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c9007dd51c748db4ecd6d75e70bdcabf8c312454ac97bf6710895a12e7288557
|
||||
size 31672
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:170bd8365dfd1e36e8f56814bf8bc2057aa0d035c41212b7ddd7e4b9feee1633
|
||||
size 68
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:11884346b41ca102c672bb0f361ea9699d2f8b33bb503038b53cc7e7fafd281b
|
||||
size 34920
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0c259ea9c40aab3841ca35b2a2e708d8829b0a9163b2f9e5efd28f1c65848293
|
||||
size 4600
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:77cd4127a45ded2f75d85ca9c17537808517614ef16fb3035cebb1b45547acbf
|
||||
size 47424
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fcff4b736e95d685d56830b501f4542b081f4334f72d28a7415809f4d9d15d0f
|
||||
size 68
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:60775e91ed550aae66cb0547ee4b0e38917f29172e942671e9361b3812364df6
|
||||
size 49120
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
|
||||
size 992
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
|
||||
size 47424
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
|
||||
size 68
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
|
||||
size 49120
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
|
||||
size 200
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
size 16904
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
size 164
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
|
||||
size 200
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
size 16904
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
size 164
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a
|
||||
size 472
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
|
||||
size 16904
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
|
||||
size 240
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
|
||||
size 36312
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1
|
||||
size 472
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
|
||||
size 16904
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
|
||||
size 240
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
|
||||
size 36312
|
|
@ -27,16 +27,13 @@ from lerobot.configs.default import DatasetConfig
|
|||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
|
||||
# TODO(rcadene, aliberts): env_name?
|
||||
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
set_seed(1337)
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
device="cpu",
|
||||
**train_kwargs,
|
||||
)
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
|
@ -54,8 +51,11 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
|
|||
|
||||
batch = next(iter(dataloader))
|
||||
loss, output_dict = policy.forward(batch)
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
if output_dict is not None:
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
else:
|
||||
output_dict = {"loss": loss}
|
||||
|
||||
loss.backward()
|
||||
grad_stats = {}
|
||||
|
@ -101,30 +101,27 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
|
|||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
|
||||
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
if output_dir.exists():
|
||||
print(f"Overwrite existing safetensors in '{output_dir}':")
|
||||
print(f" - Validate with: `git add {output_dir}`")
|
||||
print(f" - Revert with: `git checkout -- {output_dir}`")
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
if env_policy_dir.exists():
|
||||
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
|
||||
print(f" - Validate with: `git add {env_policy_dir}`")
|
||||
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
|
||||
shutil.rmtree(env_policy_dir)
|
||||
|
||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
|
||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||
save_file(actions, env_policy_dir / "actions.safetensors")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
||||
save_file(output_dict, output_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, output_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, output_dir / "param_stats.safetensors")
|
||||
save_file(actions, output_dir / "actions.safetensors")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_policies = [
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
artifacts_cfg = [
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
(
|
||||
"lerobot/pusht",
|
||||
"pusht",
|
||||
"diffusion",
|
||||
{
|
||||
"n_action_steps": 8,
|
||||
|
@ -133,18 +130,17 @@ if __name__ == "__main__":
|
|||
},
|
||||
"",
|
||||
),
|
||||
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
|
||||
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
||||
(
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"aloha",
|
||||
"act",
|
||||
{"n_action_steps": 1000, "chunk_size": 1000},
|
||||
"_1000_steps",
|
||||
"1000_steps",
|
||||
),
|
||||
]
|
||||
if len(env_policies) == 0:
|
||||
if len(artifacts_cfg) == 0:
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
|
||||
)
|
||||
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
|
||||
|
|
|
@ -166,7 +166,7 @@ def test_delta_timestamps_within_tolerance():
|
|||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
|
||||
assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
|
||||
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
||||
|
||||
|
@ -236,7 +236,7 @@ def test_compute_sampler_weights_trivial(
|
|||
elif online_sampling_ratio == 1:
|
||||
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
|
||||
expected_weights /= expected_weights.sum()
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
torch.testing.assert_close(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
|
||||
|
@ -248,7 +248,7 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
|
|||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
)
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
|
||||
)
|
||||
|
||||
|
@ -261,7 +261,7 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
|
|||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
|
||||
)
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
|
||||
)
|
||||
|
||||
|
@ -279,4 +279,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
|
|||
online_sampling_ratio=0.5,
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
|
||||
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
|
||||
|
|
|
@ -363,37 +363,33 @@ def test_normalize(insert_temporal_dim):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra",
|
||||
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
|
||||
[
|
||||
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
|
||||
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
|
||||
# to test with `policy.use_mpc=false`.
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, {"batch_size": 25}, "use_policy"),
|
||||
# ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, {}, "use_mpc"),
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
|
||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
|
||||
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
|
||||
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
|
||||
# Thus, we deactivate this test for now.
|
||||
# (
|
||||
# "lerobot/pusht",
|
||||
# "pusht",
|
||||
# "diffusion",
|
||||
# {
|
||||
# "n_action_steps": 8,
|
||||
# "num_inference_steps": 10,
|
||||
# "down_dims": [128, 256, 512],
|
||||
# },
|
||||
# {"batch_size": 64},
|
||||
# "",
|
||||
# ),
|
||||
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, {}, ""),
|
||||
(
|
||||
"lerobot/pusht",
|
||||
"diffusion",
|
||||
{
|
||||
"n_action_steps": 8,
|
||||
"num_inference_steps": 10,
|
||||
"down_dims": [128, 256, 512],
|
||||
},
|
||||
"",
|
||||
),
|
||||
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
||||
(
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"aloha",
|
||||
"act",
|
||||
{"n_action_steps": 1000, "chunk_size": 1000},
|
||||
{},
|
||||
"_1000_steps",
|
||||
"1000_steps",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -401,9 +397,7 @@ def test_normalize(insert_temporal_dim):
|
|||
# pass if it's run on another platform due to floating point errors
|
||||
@require_x86_64_kernel
|
||||
@require_cpu
|
||||
def test_backward_compatibility(
|
||||
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra
|
||||
):
|
||||
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
|
||||
"""
|
||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||
|
@ -416,26 +410,26 @@ def test_backward_compatibility(
|
|||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||
"""
|
||||
env_policy_dir = (
|
||||
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
artifact_dir = (
|
||||
Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
)
|
||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
|
||||
saved_actions = load_file(env_policy_dir / "actions.safetensors")
|
||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
|
||||
saved_actions = load_file(artifact_dir / "actions.safetensors")
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs
|
||||
)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
||||
|
||||
for key in saved_output_dict:
|
||||
assert torch.allclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7)
|
||||
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
|
||||
for key in saved_grad_stats:
|
||||
assert torch.allclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7)
|
||||
torch.testing.assert_close(grad_stats[key], saved_grad_stats[key])
|
||||
for key in saved_param_stats:
|
||||
assert torch.allclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7)
|
||||
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
|
||||
for key in saved_actions:
|
||||
assert torch.allclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7)
|
||||
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
|
||||
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def test_act_temporal_ensembler():
|
||||
|
@ -490,4 +484,4 @@ def test_act_temporal_ensembler():
|
|||
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
|
||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
||||
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
|
||||
|
|
|
@ -114,7 +114,7 @@ def test_robot(tmp_path, request, robot_type, mock):
|
|||
if "image" in name:
|
||||
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
|
||||
continue
|
||||
assert torch.allclose(captured_observation[name], observation[name], atol=1)
|
||||
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
|
||||
assert captured_observation[name].shape == observation[name].shape
|
||||
|
||||
# Test send_action can run
|
||||
|
|
Loading…
Reference in New Issue