This commit is contained in:
Simon Alibert 2024-05-02 16:01:05 +02:00
parent 1acfd61b88
commit 6f11b0afaf
2 changed files with 2 additions and 6 deletions

View File

@ -25,7 +25,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy.train()
# policy.to(DEVICE)
optimizer, _ = make_optimizer(cfg, policy)
dataloader = torch.utils.data.DataLoader(
@ -36,9 +35,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
)
batch = next(iter(dataloader))
# for key in batch:
# batch[key] = batch[key].to(DEVICE)
output_dict = policy.forward(batch)
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
loss = output_dict["loss"]
@ -64,8 +60,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
dataset.delta_timestamps = None
batch = next(iter(dataloader))
# for key in batch:
# batch[key] = batch[key].to(DEVICE)
obs = {
k: batch[k]
for k in batch

View File

@ -244,6 +244,8 @@ def test_normalize(insert_temporal_dim):
("aloha", "act", []),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
def test_backward_compatibility(env_name, policy_name, extra_overrides):
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"