Dial back to atol=1e-7
This commit is contained in:
parent
97ded04b07
commit
1acfd61b88
|
@ -255,10 +255,10 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||||
|
|
||||||
for key in saved_output_dict:
|
for key in saved_output_dict:
|
||||||
assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-8).all()
|
assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
|
||||||
for key in saved_grad_stats:
|
for key in saved_grad_stats:
|
||||||
assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-8).all()
|
assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
|
||||||
for key in saved_param_stats:
|
for key in saved_param_stats:
|
||||||
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-8).all()
|
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7).all()
|
||||||
for key in saved_actions:
|
for key in saved_actions:
|
||||||
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-8).all()
|
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
||||||
|
|
Loading…
Reference in New Issue