add amp to eval script
This commit is contained in:
parent
b059759824
commit
39b6fcbe1e
|
@ -46,6 +46,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime as dt
|
from datetime import datetime as dt
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -520,7 +521,7 @@ def eval(
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
get_safe_torch_device(hydra_cfg.device, log=True)
|
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -539,6 +540,7 @@ def eval(
|
||||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||||
info = eval_policy(
|
info = eval_policy(
|
||||||
env,
|
env,
|
||||||
policy,
|
policy,
|
||||||
|
|
Loading…
Reference in New Issue