Fix gpu nightly (#829)
This commit is contained in:
parent
25c63ccf63
commit
074f0ac8fe
|
@ -252,10 +252,11 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||||
}
|
}
|
||||||
policy = policy_cls(policy_cfg)
|
policy = policy_cls(policy_cfg)
|
||||||
|
policy.to(policy_cfg.device)
|
||||||
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||||
policy.save_pretrained(save_dir)
|
policy.save_pretrained(save_dir)
|
||||||
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||||
|
|
Loading…
Reference in New Issue