From b54cdc9a0fe9584faa27780b1bb112539f5e435c Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Tue, 19 Mar 2024 19:08:25 +0000
Subject: [PATCH] break_when_any_done==True for batch_size==1

---
 lerobot/scripts/eval.py | 4 ++--
 tests/test_policies.py  | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 2c564da0..86d4158e 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -51,14 +51,14 @@ def eval_policy(
             ep_frames.append(env.render())  # noqa: B023
 
     with torch.inference_mode():
-        # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
+        # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
         # envs are done the first time. But we only use the first rollout. This is a waste of compute.
         rollout = env.rollout(
             max_steps=max_steps,
             policy=policy,
             auto_cast_to_device=True,
             callback=maybe_render_frame,
-            break_when_any_done=False,
+            break_when_any_done=env.batch_size[0] == 1,
         )
         # Figure out where in each rollout sequence the first done condition was encountered (results after this won't
         # be included).
diff --git a/tests/test_policies.py b/tests/test_policies.py
index f2ebcfcc..e6cfdfbc 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -52,7 +52,7 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     offline_buffer = make_offline_buffer(cfg)
     env = make_env(cfg, transform=offline_buffer.transform)
 
-    if policy_name != "aloha":
+    if env_name != "aloha":
         # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
         # seq_length as a list is not supported for now.
         policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
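
For context, a minimal self-contained sketch of the trade-off this patch encodes. This is not lerobot's implementation: `done`, `first_done_idx`, and `valid` are hypothetical names, and the shapes merely assume a TorchRL-style batched rollout that records a per-step done flag for each env.

import torch

batch_size, max_steps = 4, 10

# Simulated per-step done flags for a batched rollout, shape (batch, steps).
done = torch.zeros(batch_size, max_steps, dtype=torch.bool)
done[0, 3] = done[1, 7] = done[2, 5] = done[3, 9] = True

# With break_when_any_done=False the rollout runs for max_steps regardless, so
# we must find, per sequence, where the first done occurred and discard what
# follows. Counting the leading run of False flags gives that index.
first_done_idx = (done.cumsum(dim=1) == 0).sum(dim=1)

# Mask keeping steps up to and including each sequence's first done.
valid = torch.arange(max_steps) <= first_done_idx.unsqueeze(1)

# With a single env (batch_size == 1) the first done ends the only episode, so
# breaking immediately loses nothing and skips the wasted steps; with more
# envs, breaking when *any* env is done would truncate the others' episodes.
# That is the case split break_when_any_done=env.batch_size[0] == 1 selects.
print(first_done_idx)    # tensor([3, 7, 5, 9])
print(valid.sum(dim=1))  # tensor([ 4,  8,  6, 10])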