Fix device
This commit is contained in:
parent
28d4122929
commit
344c1653f2
|
@ -36,6 +36,9 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
|||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(DEVICE)
|
||||
|
||||
output_dict = policy.forward(batch)
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
loss = output_dict["loss"]
|
||||
|
@ -60,6 +63,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
|||
|
||||
dataset.delta_timestamps = None
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(DEVICE)
|
||||
obs = {
|
||||
k: batch[k]
|
||||
for k in batch
|
||||
|
|
|
@ -246,10 +246,10 @@ def test_normalize(insert_temporal_dim):
|
|||
)
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
|
||||
saved_actions = load_file(env_policy_dir / "actions.safetensors")
|
||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors", device=DEVICE)
|
||||
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors", device=DEVICE)
|
||||
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors", device=DEVICE)
|
||||
saved_actions = load_file(env_policy_dir / "actions.safetensors", device=DEVICE)
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
|
||||
|
@ -260,4 +260,4 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
|||
for key in saved_param_stats:
|
||||
assert torch.isclose(param_stats[key], saved_param_stats[key]).all()
|
||||
for key in saved_actions:
|
||||
assert torch.isclose(actions[key], saved_actions[key]).all()
|
||||
assert torch.isclose(actions[key], saved_actions[key], rtol=1e-4, atol=1e-7).all()
|
||||
|
|
Loading…
Reference in New Issue