add amp to eval script
parent b059759824
commit 39b6fcbe1e
@@ -46,6 +46,7 @@ import json
 import logging
 import threading
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
         raise NotImplementedError()

     # Check device is available
-    get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(hydra_cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@ def eval(
     policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
     policy.eval()

-    info = eval_policy(
-        env,
-        policy,
-        hydra_cfg.eval.n_episodes,
-        max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
-        start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
-    )
+    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+        info = eval_policy(
+            env,
+            policy,
+            hydra_cfg.eval.n_episodes,
+            max_episodes_rendered=10,
+            video_dir=Path(out_dir) / "eval",
+            start_seed=hydra_cfg.seed,
+            enable_progbar=True,
+            enable_inner_progbar=True,
+        )
     print(info["aggregated"])

     # Save info
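The change gates automatic mixed precision behind hydra_cfg.use_amp by stacking context managers: gradient tracking stays disabled for evaluation, torch.autocast is entered only when AMP is requested, and contextlib.nullcontext serves as the no-op fallback. Below is a minimal standalone sketch of the same pattern; run_inference, model, and batch are placeholder names, not part of the commit.

import torch
from contextlib import nullcontext

def run_inference(model: torch.nn.Module, batch: torch.Tensor, device: torch.device, use_amp: bool) -> torch.Tensor:
    # Evaluation mode: no gradient tracking; enter autocast only if AMP is requested,
    # otherwise fall back to a do-nothing context manager.
    model.eval()
    with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
        return model(batch.to(device))

# Hypothetical usage:
# out = run_inference(policy, observation_batch, torch.device("cuda"), use_amp=True)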