diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 1e44c5df..7127b24d 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -83,14 +83,16 @@ def eval_policy( ) # (b, t, *) if save_video: - for stacked_frames in batch_stacked_frames: + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): if episode_counter >= num_episodes: continue video_dir.mkdir(parents=True, exist_ok=True) video_path = video_dir / f"eval_episode_{episode_counter}.mp4" thread = threading.Thread( target=write_video, - args=(str(video_path), stacked_frames, fps), + args=(str(video_path), stacked_frames[:done_index], fps), ) thread.start() threads.append(thread)