move video dir assertion

This commit is contained in:
Wael Karkoub 2024-06-10 16:55:14 +01:00
parent 18b8e95067
commit 1853288ef7
1 changed files with 3 additions and 1 deletions

View File

@ -233,6 +233,9 @@ def eval_policy(
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
assert isinstance(policy, Policy)
start = time.time()
policy.eval()
@ -355,7 +358,6 @@ def eval_policy(
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
batch_stacked_frames = np.stack(ep_frames, axis=1) # (b, t, *)
assert isinstance(videos_dir, Path)
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):