This commit is contained in:
Simon Alibert 2024-05-02 15:43:39 +02:00
parent 2d11199320
commit 55ff23c252
1 changed files with 4 additions and 4 deletions

View File

@ -254,10 +254,10 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides) output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
for key in saved_output_dict: for key in saved_output_dict:
assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=1, atol=1e-6).all() assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.5, atol=1e-6).all()
for key in saved_grad_stats: for key in saved_grad_stats:
assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=1, atol=1e-6).all() assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.5, atol=1e-6).all()
for key in saved_param_stats: for key in saved_param_stats:
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=1, atol=1e-6).all() assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=0.5, atol=1e-6).all()
for key in saved_actions: for key in saved_actions:
assert torch.isclose(actions[key], saved_actions[key], rtol=1, atol=1e-6).all() assert torch.isclose(actions[key], saved_actions[key], rtol=0.5, atol=1e-6).all()