From 25c63ccf638740f778dff0b5562f81edb074fc20 Mon Sep 17 00:00:00 2001
From: Mathias Wulfman <101942083+mwulfman@users.noreply.github.com>
Date: Fri, 7 Mar 2025 13:21:11 +0100
Subject: [PATCH 1/2] :bug: Remove `map_location=device` that no longer exists
 when loading DiffusionPolicy from_pretrained after commit 5e94738 (#830)

Co-authored-by: Mathias Wulfman
---
 examples/2_evaluate_pretrained_policy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
index bf3c442a..edbbad38 100644
--- a/examples/2_evaluate_pretrained_policy.py
+++ b/examples/2_evaluate_pretrained_policy.py
@@ -44,7 +44,7 @@ pretrained_policy_path = "lerobot/diffusion_pusht"
 # OR a path to a local outputs/train folder.
 # pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
 
-policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
+policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
 
 # Initialize evaluation environment to render two observation types:
 # an image of the scene and state/position of the agent. The environment

From 074f0ac8fec8483b83a04654b09e685274ee0c80 Mon Sep 17 00:00:00 2001
From: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Date: Fri, 7 Mar 2025 13:21:58 +0100
Subject: [PATCH 2/2] Fix gpu nightly (#829)

---
 tests/test_policies.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_policies.py b/tests/test_policies.py
index f8e7359c..7df79c1d 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -252,10 +252,11 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
         key: ft for key, ft in features.items() if key not in policy_cfg.output_features
     }
     policy = policy_cls(policy_cfg)
+    policy.to(policy_cfg.device)
     save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
     policy.save_pretrained(save_dir)
-    policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
-    assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
+    loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
+    torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
 
 
 @pytest.mark.parametrize("insert_temporal_dim", [False, True])
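
Note: below is a minimal usage sketch of the loading flow after PATCH 1/2, not an authoritative lerobot example. It assumes `DiffusionPolicy` behaves like a standard torch.nn.Module (consistent with the `policy.to(policy_cfg.device)` call added in PATCH 2/2) and that the import path matches the version of the example script being patched.

    # Sketch: device placement is handled explicitly, since `map_location`
    # is no longer accepted by `from_pretrained` (see PATCH 1/2).
    # Import path assumed from examples/2_evaluate_pretrained_policy.py.
    import torch
    from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pretrained_policy_path = "lerobot/diffusion_pusht"
    policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
    policy.to(device)  # move the loaded weights to the target device
    policy.eval()      # switch to evaluation mode before rollout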