Only save video frames in first rollout
This commit is contained in:
parent
4f1955edfd
commit
52e149fbfd
|
@ -83,14 +83,16 @@ def eval_policy(
|
||||||
) # (b, t, *)
|
) # (b, t, *)
|
||||||
|
|
||||||
if save_video:
|
if save_video:
|
||||||
for stacked_frames in batch_stacked_frames:
|
for stacked_frames, done_index in zip(
|
||||||
|
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
|
||||||
|
):
|
||||||
if episode_counter >= num_episodes:
|
if episode_counter >= num_episodes:
|
||||||
continue
|
continue
|
||||||
video_dir.mkdir(parents=True, exist_ok=True)
|
video_dir.mkdir(parents=True, exist_ok=True)
|
||||||
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=write_video,
|
target=write_video,
|
||||||
args=(str(video_path), stacked_frames, fps),
|
args=(str(video_path), stacked_frames[:done_index], fps),
|
||||||
)
|
)
|
||||||
thread.start()
|
thread.start()
|
||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
|
|
Loading…
Reference in New Issue