diff --git a/Makefile b/Makefile index 07aa4e97..a0163f94 100644 --- a/Makefile +++ b/Makefile @@ -22,9 +22,8 @@ test-end-to-end: ${MAKE} test-act-ete-eval ${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-eval - # TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc - # ${MAKE} test-tdmpc-ete-train - # ${MAKE} test-tdmpc-ete-eval + ${MAKE} test-tdmpc-ete-train + ${MAKE} test-tdmpc-ete-eval ${MAKE} test-default-ete-eval test-act-ete-train: @@ -80,7 +79,7 @@ test-tdmpc-ete-train: policy=tdmpc \ env=xarm \ env.task=XarmLift-v0 \ - dataset_repo_id=lerobot/xarm_lift_medium_replay \ + dataset_repo_id=lerobot/xarm_lift_medium \ wandb.enable=False \ training.offline_steps=2 \ training.online_steps=2 \ diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 43e841eb..eefbb303 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,7 +1,7 @@ # @package _global_ seed: 1 -dataset_repo_id: lerobot/xarm_lift_medium_replay +dataset_repo_id: lerobot/xarm_lift_medium training: offline_steps: 50000 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors new file mode 100644 index 00000000..0339ca0e Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors new file mode 100644 index 00000000..5520c643 Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors new file mode 100644 index 00000000..2321f31c Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors differ diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors new file mode 100644 index 00000000..5e8a6947 Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors differ diff --git a/tests/test_policies.py b/tests/test_policies.py index 12beec92..f0fa7c56 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -236,7 +236,7 @@ def test_normalize(insert_temporal_dim): @pytest.mark.parametrize( "env_name, policy_name, extra_overrides", [ - # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]), + ("xarm", "tdmpc", []), ( "pusht", "diffusion",