66 lines
1.7 KiB
Python
66 lines
1.7 KiB
Python
import logging
|
|
import random
|
|
from datetime import datetime
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
|
match cfg_device:
|
|
case "cuda":
|
|
assert torch.cuda.is_available()
|
|
device = torch.device("cuda")
|
|
case "mps":
|
|
assert torch.backends.mps.is_available()
|
|
device = torch.device("mps")
|
|
case "cpu":
|
|
device = torch.device("cpu")
|
|
if log:
|
|
logging.warning("Using CPU, this will be slow.")
|
|
case _:
|
|
device = torch.device(cfg_device)
|
|
if log:
|
|
logging.warning(f"Using custom {cfg_device} device.")
|
|
|
|
return device
|
|
|
|
|
|
def set_seed(seed):
    """Seed every RNG in use (random, NumPy, torch CPU and all CUDA devices)
    so runs are reproducible."""
    # Same seed, same order as before — just expressed as a single pass
    # over the seeding callables.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # no-op when CUDA is absent
    )
    for seeder in seeders:
        seeder(seed)
|
|
|
|
|
|
def init_logging():
|
|
def custom_format(record):
|
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
fnameline = f"{record.pathname}:{record.lineno}"
|
|
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
|
return message
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
for handler in logging.root.handlers[:]:
|
|
logging.root.removeHandler(handler)
|
|
|
|
formatter = logging.Formatter()
|
|
formatter.format = custom_format
|
|
console_handler = logging.StreamHandler()
|
|
console_handler.setFormatter(formatter)
|
|
logging.getLogger().addHandler(console_handler)
|
|
|
|
|
|
def format_big_number(num):
    """Format *num* with a metric-style suffix: K, M, B, T, Q.

    Args:
        num: Any real number (int or float, may be negative).

    Returns:
        A string such as "999", "2K", or "1M". Values of 1000 quadrillion or
        more are expressed in "Q" units (e.g. "1000Q") — the previous version
        inconsistently returned a raw float in that range.
    """
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    for suffix in suffixes:
        if abs(num) < divisor:
            return f"{num:.0f}{suffix}"
        num /= divisor

    # num was divided once per suffix; undo the final division so the value
    # matches the largest suffix instead of returning an unformatted number.
    return f"{num * divisor:.0f}{suffixes[-1]}"
|