diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 6387882c..2e65f468 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,6 +1,7 @@ # @package _global_ seed: 1 +dataset_repo_id: lerobot/xarm_lift_medium_replay training: offline_steps: 25000 diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index 68657350..a8e176c2 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -82,9 +82,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name): if __name__ == "__main__": env_policies = [ - # ("xarm", "tdmpc"), - ("pusht", "diffusion"), - ("aloha", "act"), + ("xarm", "tdmpc"), + # ("pusht", "diffusion"), + # ("aloha", "act"), ] for env, policy in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy) diff --git a/tests/test_policies.py b/tests/test_policies.py index 80991097..242ea46b 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -239,7 +239,7 @@ def test_normalize(insert_temporal_dim): @pytest.mark.parametrize( "env_name,policy_name", [ - # ("xarm", "tdmpc"), + ("xarm", "tdmpc"), ("pusht", "diffusion"), ("aloha", "act"), ],