import logging
import os.path as osp
import random
from datetime import datetime
from pathlib import Path

import hydra
import numpy as np
import torch
from omegaconf import DictConfig


def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available."""
    match cfg_device:
        case "cuda":
            assert torch.cuda.is_available()
            device = torch.device("cuda")
        case "mps":
            assert torch.backends.mps.is_available()
            device = torch.device("mps")
        case "cpu":
            device = torch.device("cpu")
            if log:
                logging.warning("Using CPU, this will be slow.")
        case _:
            device = torch.device(cfg_device)
            if log:
                logging.warning(f"Using custom {cfg_device} device.")

    return device


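# Illustrative usage of get_safe_torch_device, shown only as a sketch: resolve a device
# string coming from a config and place a tensor on it. "cpu" is used so the snippet
# works on any machine; passing "cuda" or "mps" asserts that the accelerator is
# actually available instead of silently falling back.
#
#   device = get_safe_torch_device("cpu", log=True)
#   x = torch.zeros(3, device=device)

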
def set_global_seed(seed):
    """Set the seed of random, numpy and torch (including CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def init_logging():
    """Replace the root logger's handlers with a console handler using a compact custom format."""

    def custom_format(record):
        # Render records as "LEVEL YYYY-MM-DD HH:MM:SS <last 15 chars of path:lineno> message".
        dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
        return message

    logging.basicConfig(level=logging.INFO)

    # Remove any handlers installed by basicConfig (or by a previous call) so ours is the only one.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    formatter = logging.Formatter()
    formatter.format = custom_format
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logging.getLogger().addHandler(console_handler)


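# Illustrative usage of init_logging, a sketch rather than prescribed practice: call it
# once at program start so the custom console formatter above handles all records.
#
#   init_logging()
#   logging.info("training started")  # rendered as "INFO <date> <time> <path:lineno> training started"

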
def format_big_number(num):
    """Format a number with a thousands-style suffix, e.g. 1_234_567 -> "1M"."""
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    for suffix in suffixes:
        if abs(num) < divisor:
            return f"{num:.0f}{suffix}"
        num /= divisor

    # Numbers beyond the largest suffix are returned as the (already divided) float.
    return num


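# Illustrative examples for format_big_number (values chosen arbitrarily):
#
#   format_big_number(42)         -> "42"
#   format_big_number(1_234_567)  -> "1M"
#   format_big_number(3.2e9)      -> "3B"

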
def _relative_path_between(path1: Path, path2: Path) -> Path:
    """Returns path1 relative to path2."""
    path1 = path1.absolute()
    path2 = path2.absolute()
    try:
        return path1.relative_to(path2)
    except ValueError:  # most likely because path1 is not a subpath of path2
        common_parts = Path(osp.commonpath([path1, path2])).parts
        return Path(
            "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
        )


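# Illustrative behavior of _relative_path_between on a POSIX filesystem (the paths are
# hypothetical; both arguments are made absolute before comparison):
#
#   _relative_path_between(Path("/data/configs"), Path("/data"))        -> Path("configs")
#   _relative_path_between(Path("/data/configs"), Path("/data/out/x"))  -> Path("../../configs")

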
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
    """Initialize a Hydra config given only the path to the relevant config file.

    For config resolution, it is assumed that the config file's parent is the Hydra config dir.
    """
    # TODO(alexander-soare): Resolve configs without Hydra initialization.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    # Hydra needs a path relative to this file.
    hydra.initialize(
        str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
    )
    cfg = hydra.compose(Path(config_path).stem, overrides)
    return cfg


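# Illustrative usage of init_hydra_config; the YAML path and the override key are
# hypothetical and depend on the repository's config layout:
#
#   cfg = init_hydra_config("path/to/default.yaml", overrides=["seed=1000"])
#   assert isinstance(cfg, DictConfig)

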
def print_cuda_memory_usage():
    """Use this function to locate and debug memory leaks."""
    import gc

    gc.collect()
    # Also clear the cache if you want to fully release the memory.
    torch.cuda.empty_cache()
    print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
    print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
    print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
    print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
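

if __name__ == "__main__":
    # Minimal, illustrative smoke test of the helpers above. It deliberately sticks to
    # the CPU path and skips the Hydra- and CUDA-dependent utilities so it runs anywhere.
    init_logging()
    set_global_seed(1337)
    device = get_safe_torch_device("cpu", log=True)
    logging.info(f"device={device}, example count={format_big_number(12_345_678)}")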