add amp to eval script
This commit is contained in:
parent
b059759824
commit
39b6fcbe1e
|
@ -46,6 +46,7 @@ import json
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
@ -520,7 +521,7 @@ def eval(
|
|||
raise NotImplementedError()
|
||||
|
||||
# Check device is available
|
||||
get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
@ -539,6 +540,7 @@ def eval(
|
|||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
|
|
Loading…
Reference in New Issue