break_when_any_done==True for batch_size==1

Alexander Soare 2024-03-19 19:08:25 +00:00
parent 46ac87d2a6
commit b54cdc9a0f
2 changed files with 3 additions and 3 deletions


@@ -51,14 +51,14 @@ def eval_policy(
             ep_frames.append(env.render())  # noqa: B023
     with torch.inference_mode():
-        # TODO(alexander-soare): Due the `break_when_any_done == False` this rolls out for max_steps even when all
+        # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
         # envs are done the first time. But we only use the first rollout. This is a waste of compute.
         rollout = env.rollout(
             max_steps=max_steps,
             policy=policy,
             auto_cast_to_device=True,
             callback=maybe_render_frame,
-            break_when_any_done=False,
+            break_when_any_done=env.batch_size[0] == 1,
         )
     # Figure out where in each rollout sequence the first done condition was encountered (results after this won't
     # be included).
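
For context, a minimal sketch of what the new flag value does, assuming a TorchRL-style batched env (the env name and step count below are illustrative, not taken from the repo):

import torch
from torchrl.envs import SerialEnv
from torchrl.envs.libs.gym import GymEnv

# A batch of one env: env.batch_size == torch.Size([1]), so the commit's
# expression evaluates to True and the rollout stops at the first done flag
# instead of padding out to max_steps.
env = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
break_when_any_done = env.batch_size[0] == 1  # True for this single-env batch

with torch.inference_mode():
    rollout = env.rollout(max_steps=500, break_when_any_done=break_when_any_done)

# With more than one env in the batch the expression is False, every env rolls
# out for max_steps, and the caller must locate the first done step per
# sequence (the post-processing the trailing comment above refers to), e.g.
# by inspecting rollout["next", "done"].

With a batch of one there is nothing to keep in lock-step, so stopping at the first done avoids the wasted compute flagged in the TODO; with a larger batch the flag stays False and every sequence comes back with the same length.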


@@ -52,7 +52,7 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     offline_buffer = make_offline_buffer(cfg)
     env = make_env(cfg, transform=offline_buffer.transform)
-    if policy_name != "aloha":
+    if env_name != "aloha":
         # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
         # seq_length as a list is not supported for now.
         policy.update(offline_buffer, torch.tensor(0, device=DEVICE))