diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 2c564da0..86d4158e 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -51,14 +51,14 @@ def eval_policy( ep_frames.append(env.render()) # noqa: B023 with torch.inference_mode(): - # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all + # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all # envs are done the first time. But we only use the first rollout. This is a waste of compute. rollout = env.rollout( max_steps=max_steps, policy=policy, auto_cast_to_device=True, callback=maybe_render_frame, - break_when_any_done=False, + break_when_any_done=env.batch_size[0] == 1, ) # Figure out where in each rollout sequence the first done condition was encountered (results after this won't # be included). diff --git a/tests/test_policies.py b/tests/test_policies.py index f2ebcfcc..e6cfdfbc 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -52,7 +52,7 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): offline_buffer = make_offline_buffer(cfg) env = make_env(cfg, transform=offline_buffer.transform) - if policy_name != "aloha": + if env_name != "aloha": # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError: # seq_length as a list is not supported for now. policy.update(offline_buffer, torch.tensor(0, device=DEVICE))