Cleanup
This commit is contained in:
parent
1acfd61b88
commit
6f11b0afaf
|
@ -25,7 +25,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
|||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
# policy.to(DEVICE)
|
||||
optimizer, _ = make_optimizer(cfg, policy)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
@ -36,9 +35,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
|||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
# for key in batch:
|
||||
# batch[key] = batch[key].to(DEVICE)
|
||||
|
||||
output_dict = policy.forward(batch)
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
loss = output_dict["loss"]
|
||||
|
@ -64,8 +60,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
|||
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
||||
dataset.delta_timestamps = None
|
||||
batch = next(iter(dataloader))
|
||||
# for key in batch:
|
||||
# batch[key] = batch[key].to(DEVICE)
|
||||
obs = {
|
||||
k: batch[k]
|
||||
for k in batch
|
||||
|
|
|
@ -244,6 +244,8 @@ def test_normalize(insert_temporal_dim):
|
|||
("aloha", "act", []),
|
||||
],
|
||||
)
|
||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||
# pass if it's run on another platform due to floating point errors
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||
|
|
Loading…
Reference in New Issue