diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index b3475167..fa7e1096 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -249,7 +249,7 @@ def eval_policy( threads.append(thread) episode_counter += 1 - videos = batch_stacked_frames.transpose(0, 3, 1, 2) + videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w") for thread in threads: thread.join() @@ -328,6 +328,9 @@ def eval(cfg: dict, out_dir=None, stats_path=None): # Save info with open(Path(out_dir) / "eval_info.json", "w") as f: + # remove pytorch tensors which are not serializable to save the evaluation results only + del info["episodes"] + del info["videos"] json.dump(info, f, indent=2) logging.info("End of eval")