diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index f1f780d3..8f365644 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -36,6 +36,9 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None): ) batch = next(iter(dataloader)) + for key in batch: + batch[key] = batch[key].to(DEVICE) + output_dict = policy.forward(batch) output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} loss = output_dict["loss"] @@ -60,6 +63,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None): dataset.delta_timestamps = None batch = next(iter(dataloader)) + for key in batch: + batch[key] = batch[key].to(DEVICE) obs = { k: batch[k] for k in batch diff --git a/tests/test_policies.py b/tests/test_policies.py index 95a51c8e..21e69175 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -246,10 +246,10 @@ def test_normalize(insert_temporal_dim): ) def test_backward_compatibility(env_name, policy_name, extra_overrides): env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" - saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors") - saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors") - saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors") - saved_actions = load_file(env_policy_dir / "actions.safetensors") + saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors", device=DEVICE) + saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors", device=DEVICE) + saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors", device=DEVICE) + saved_actions = load_file(env_policy_dir / "actions.safetensors", device=DEVICE) output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) @@ -260,4 +260,4 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides): for key in saved_param_stats: assert torch.isclose(param_stats[key], saved_param_stats[key]).all() for key in saved_actions: - assert torch.isclose(actions[key], saved_actions[key]).all() + assert torch.isclose(actions[key], saved_actions[key], rtol=1e-4, atol=1e-7).all()