Fix for PR #5
This commit is contained in:
parent
b33ec5a630
commit
b859e89936
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -119,7 +118,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
if cfg.device == "cuda":
|
if cfg.device == "cuda":
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
else:
|
else:
|
||||||
warnings.warn("Using CPU, this will be slow.", UserWarning, stacklevel=1)
|
logging.warning("Using CPU, this will be slow.")
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
Loading…
Reference in New Issue