From 8e842ac00ceb88c9e402fa5099158cdae04f4777 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 2 May 2024 15:27:38 +0200 Subject: [PATCH] rtol=1e-4, atol=1e-7 --- tests/test_policies.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_policies.py b/tests/test_policies.py index 4351c18a..5de95310 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -254,10 +254,10 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides): output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) for key in saved_output_dict: - assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=1e-5, atol=1e-8).all() + assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=1e-4, atol=1e-7).all() for key in saved_grad_stats: - assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=1e-5, atol=1e-8).all() + assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=1e-4, atol=1e-7).all() for key in saved_param_stats: - assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=1e-5, atol=1e-8).all() + assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=1e-4, atol=1e-7).all() for key in saved_actions: - assert torch.isclose(actions[key], saved_actions[key], rtol=1e-5, atol=1e-8).all() + assert torch.isclose(actions[key], saved_actions[key], rtol=1e-4, atol=1e-7).all()