From 6f11b0afaf3c70e9742e3441d62e6534f6113c19 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 2 May 2024 16:01:05 +0200 Subject: [PATCH] Cleanup --- tests/scripts/save_policy_to_safetensor.py | 6 ------ tests/test_policies.py | 2 ++ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index d23f2896..f1708c31 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -25,7 +25,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None): dataset = make_dataset(cfg) policy = make_policy(cfg, dataset_stats=dataset.stats) policy.train() - # policy.to(DEVICE) optimizer, _ = make_optimizer(cfg, policy) dataloader = torch.utils.data.DataLoader( @@ -36,9 +35,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None): ) batch = next(iter(dataloader)) - # for key in batch: - # batch[key] = batch[key].to(DEVICE) - output_dict = policy.forward(batch) output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} loss = output_dict["loss"] @@ -64,8 +60,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None): # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension dataset.delta_timestamps = None batch = next(iter(dataloader)) - # for key in batch: - # batch[key] = batch[key].to(DEVICE) obs = { k: batch[k] for k in batch diff --git a/tests/test_policies.py b/tests/test_policies.py index 24125aba..c61a8f26 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -244,6 +244,8 @@ def test_normalize(insert_temporal_dim): ("aloha", "act", []), ], ) +# As artifacts have been generated on an x86_64 kernel, this test won't +# pass if it's run on another platform due to floating point errors @require_x86_64_kernel def test_backward_compatibility(env_name, policy_name, extra_overrides): env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"