Fix device

This commit is contained in:
Simon Alibert 2024-05-02 14:17:20 +02:00
parent 28d4122929
commit 344c1653f2
2 changed files with 10 additions and 5 deletions

View File

@ -36,6 +36,9 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
)
batch = next(iter(dataloader))
for key in batch:
batch[key] = batch[key].to(DEVICE)
output_dict = policy.forward(batch)
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
loss = output_dict["loss"]
@ -60,6 +63,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
dataset.delta_timestamps = None
batch = next(iter(dataloader))
for key in batch:
batch[key] = batch[key].to(DEVICE)
obs = {
k: batch[k]
for k in batch

View File

@ -246,10 +246,10 @@ def test_normalize(insert_temporal_dim):
)
def test_backward_compatibility(env_name, policy_name, extra_overrides):
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
saved_actions = load_file(env_policy_dir / "actions.safetensors")
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors", device=DEVICE)
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors", device=DEVICE)
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors", device=DEVICE)
saved_actions = load_file(env_policy_dir / "actions.safetensors", device=DEVICE)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
@ -260,4 +260,4 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
for key in saved_param_stats:
assert torch.isclose(param_stats[key], saved_param_stats[key]).all()
for key in saved_actions:
assert torch.isclose(actions[key], saved_actions[key]).all()
assert torch.isclose(actions[key], saved_actions[key], rtol=1e-4, atol=1e-7).all()