This commit is contained in:
Simon Alibert 2024-05-02 16:01:05 +02:00
parent 1acfd61b88
commit 6f11b0afaf
2 changed files with 2 additions and 6 deletions

View File

@ -25,7 +25,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats) policy = make_policy(cfg, dataset_stats=dataset.stats)
policy.train() policy.train()
# policy.to(DEVICE)
optimizer, _ = make_optimizer(cfg, policy) optimizer, _ = make_optimizer(cfg, policy)
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
@ -36,9 +35,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
) )
batch = next(iter(dataloader)) batch = next(iter(dataloader))
# for key in batch:
# batch[key] = batch[key].to(DEVICE)
output_dict = policy.forward(batch) output_dict = policy.forward(batch)
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
loss = output_dict["loss"] loss = output_dict["loss"]
@ -64,8 +60,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
dataset.delta_timestamps = None dataset.delta_timestamps = None
batch = next(iter(dataloader)) batch = next(iter(dataloader))
# for key in batch:
# batch[key] = batch[key].to(DEVICE)
obs = { obs = {
k: batch[k] k: batch[k]
for k in batch for k in batch

View File

@ -244,6 +244,8 @@ def test_normalize(insert_temporal_dim):
("aloha", "act", []), ("aloha", "act", []),
], ],
) )
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel @require_x86_64_kernel
def test_backward_compatibility(env_name, policy_name, extra_overrides): def test_backward_compatibility(env_name, policy_name, extra_overrides):
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"