Cleanup
This commit is contained in:
parent
1acfd61b88
commit
6f11b0afaf
|
@ -25,7 +25,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||||
policy.train()
|
policy.train()
|
||||||
# policy.to(DEVICE)
|
|
||||||
optimizer, _ = make_optimizer(cfg, policy)
|
optimizer, _ = make_optimizer(cfg, policy)
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
@ -36,9 +35,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
# for key in batch:
|
|
||||||
# batch[key] = batch[key].to(DEVICE)
|
|
||||||
|
|
||||||
output_dict = policy.forward(batch)
|
output_dict = policy.forward(batch)
|
||||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||||
loss = output_dict["loss"]
|
loss = output_dict["loss"]
|
||||||
|
@ -64,8 +60,6 @@ def get_policy_stats(env_name, policy_name, extra_overrides=None):
|
||||||
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
||||||
dataset.delta_timestamps = None
|
dataset.delta_timestamps = None
|
||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
# for key in batch:
|
|
||||||
# batch[key] = batch[key].to(DEVICE)
|
|
||||||
obs = {
|
obs = {
|
||||||
k: batch[k]
|
k: batch[k]
|
||||||
for k in batch
|
for k in batch
|
||||||
|
|
|
@ -244,6 +244,8 @@ def test_normalize(insert_temporal_dim):
|
||||||
("aloha", "act", []),
|
("aloha", "act", []),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||||
|
# pass if it's run on another platform due to floating point errors
|
||||||
@require_x86_64_kernel
|
@require_x86_64_kernel
|
||||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||||
|
|
Loading…
Reference in New Issue