Only save video frames in first rollout

This commit is contained in:
Alexander Soare 2024-03-20 08:32:11 +00:00
parent 4f1955edfd
commit 52e149fbfd
1 changed files with 4 additions and 2 deletions

View File

@ -83,14 +83,16 @@ def eval_policy(
) # (b, t, *)
if save_video:
for stacked_frames in batch_stacked_frames:
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames, fps),
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)