From 39b6fcbe1edee41daa52a910623b8ccd34d9a13e Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 20 May 2024 13:14:00 +0100
Subject: [PATCH] add amp to eval script

---
 lerobot/scripts/eval.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 9c95633a..7e4690d0 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -46,6 +46,7 @@ import json
 import logging
 import threading
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
         raise NotImplementedError()
 
     # Check device is available
-    get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(hydra_cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@
     policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
     policy.eval()
 
-    info = eval_policy(
-        env,
-        policy,
-        hydra_cfg.eval.n_episodes,
-        max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
-        start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
-    )
+    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+        info = eval_policy(
+            env,
+            policy,
+            hydra_cfg.eval.n_episodes,
+            max_episodes_rendered=10,
+            video_dir=Path(out_dir) / "eval",
+            start_seed=hydra_cfg.seed,
+            enable_progbar=True,
+            enable_inner_progbar=True,
+        )
     print(info["aggregated"])
 
     # Save info
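
The core of the patch is the conditional context-manager stack: gradients are always disabled for evaluation, while autocast (mixed precision) is entered only when the config requests AMP, falling back to a no-op context otherwise. As a standalone sketch of that pattern, outside the patch, with a hypothetical helper name (run_inference) and a plain boolean flag (use_amp) standing in for hydra_cfg.use_amp:

    from contextlib import nullcontext

    import torch


    def run_inference(policy, batch, device, use_amp: bool):
        # Hypothetical helper illustrating the pattern added by the patch:
        # always disable gradient tracking, and additionally enable autocast
        # (mixed precision) only when use_amp is True; otherwise use a no-op
        # context so the same `with` statement works on any device/config.
        with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
            return policy(batch)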