231 lines
7.1 KiB
Python
231 lines
7.1 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import os
|
|
import os.path as osp
|
|
import platform
|
|
import subprocess
|
|
from copy import copy
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def none_or_int(value):
    """Convert ``value`` to an ``int``, mapping the literal string ``"None"`` to ``None``.

    Any other value is handed to ``int()``, so invalid input raises ``ValueError``
    (or ``TypeError``) exactly as ``int()`` would.
    """
    return None if value == "None" else int(value)
|
|
|
|
|
|
def inside_slurm():
    """Return True when the ``SLURM_JOB_ID`` environment variable is set.

    This is used as the signal that the python process was launched through slurm.
    """
    # TODO(rcadene): return False for interactive mode `--pty bash`
    job_id = os.environ.get("SLURM_JOB_ID")
    return job_id is not None
|
|
|
|
|
|
def auto_select_torch_device() -> torch.device:
    """Select a torch device automatically, preferring accelerators.

    Order of preference: CUDA, then Apple Metal (MPS), then CPU.

    Returns:
        ``torch.device("cuda")`` if ``torch.cuda.is_available()``, otherwise
        ``torch.device("mps")`` if ``torch.backends.mps.is_available()``,
        otherwise ``torch.device("cpu")`` (with a warning that it will be slow).
    """
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        logging.info("Metal backend detected, using mps.")
        return torch.device("mps")
    logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
    return torch.device("cpu")
|
|
|
|
|
|
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
|
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
|
"""Given a string, return a torch.device with checks on whether the device is available."""
|
|
try_device = str(try_device)
|
|
match try_device:
|
|
case "cuda":
|
|
assert torch.cuda.is_available()
|
|
device = torch.device("cuda")
|
|
case "mps":
|
|
assert torch.backends.mps.is_available()
|
|
device = torch.device("mps")
|
|
case "cpu":
|
|
device = torch.device("cpu")
|
|
if log:
|
|
logging.warning("Using CPU, this will be slow.")
|
|
case _:
|
|
device = torch.device(try_device)
|
|
if log:
|
|
logging.warning(f"Using custom {try_device} device.")
|
|
|
|
return device
|
|
|
|
|
|
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
|
"""
|
|
mps is currently not compatible with float64
|
|
"""
|
|
if isinstance(device, torch.device):
|
|
device = device.type
|
|
if device == "mps" and dtype == torch.float64:
|
|
return torch.float32
|
|
else:
|
|
return dtype
|
|
|
|
|
|
def is_torch_device_available(try_device: str) -> bool:
|
|
try_device = str(try_device) # Ensure try_device is a string
|
|
if try_device == "cuda":
|
|
return torch.cuda.is_available()
|
|
elif try_device == "mps":
|
|
return torch.backends.mps.is_available()
|
|
elif try_device == "cpu":
|
|
return True
|
|
else:
|
|
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
|
|
|
|
|
def is_amp_available(device: str):
|
|
if device in ["cuda", "cpu"]:
|
|
return True
|
|
elif device == "mps":
|
|
return False
|
|
else:
|
|
raise ValueError(f"Unknown device '{device}.")
|
|
|
|
|
|
def init_logging():
    """Configure the root logger to print INFO-and-above records to stderr.

    Every handler already attached to the root logger is removed and replaced by
    a single ``StreamHandler``, so repeated calls do not duplicate output. Lines
    look like ``LEVEL YYYY-mm-dd HH:MM:SS <last 15 chars of path:line> message``.
    """
    formatter = logging.Formatter()

    def custom_format(record: logging.LogRecord) -> str:
        # Timestamp of when the event was logged, not of when it is formatted.
        dt = datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        # getMessage() applies %-style args (logging.info("x=%s", x)); record.msg
        # alone would print the raw, un-interpolated template.
        message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
        # Replacing Formatter.format bypasses its traceback handling, so append
        # exception / stack info explicitly (logging.exception, exc_info=True, stack_info=True).
        if record.exc_info:
            message = f"{message}\n{formatter.formatException(record.exc_info)}"
        if record.stack_info:
            message = f"{message}\n{formatter.formatStack(record.stack_info)}"
        return message

    formatter.format = custom_format

    root = logging.getLogger()
    # Set the level directly: logging.basicConfig() does nothing when the root logger
    # already has handlers, so it would not reliably apply INFO.
    root.setLevel(logging.INFO)

    for handler in root.handlers[:]:
        root.removeHandler(handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    root.addHandler(console_handler)
|
|
|
|
|
|
def format_big_number(num, precision=0):
    """Format ``num`` compactly with a power-of-1000 suffix.

    Suffixes are "" (units), "K", "M", "B", "T" and "Q". Values of 1000Q or more
    keep the "Q" suffix, so the result is always a string.

    Args:
        num: Number to format (int or float, may be negative).
        precision: Number of digits after the decimal point.

    Returns:
        e.g. ``"999"``, ``"1.5K"`` (with ``precision=1``) or ``"2M"``.
    """
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    for suffix in suffixes[:-1]:
        # Compare the value as it will be displayed, so 999_999 becomes "1M"
        # rather than rounding up to "1000K".
        if abs(round(num, precision)) < divisor:
            return f"{num:.{precision}f}{suffix}"
        num /= divisor

    # Largest suffix: never fall through to returning a bare number.
    return f"{num:.{precision}f}{suffixes[-1]}"
|
|
|
|
|
|
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
|
"""Returns path1 relative to path2."""
|
|
path1 = path1.absolute()
|
|
path2 = path2.absolute()
|
|
try:
|
|
return path1.relative_to(path2)
|
|
except ValueError: # most likely because path1 is not a subpath of path2
|
|
common_parts = Path(osp.commonpath([path1, path2])).parts
|
|
return Path(
|
|
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
|
)
|
|
|
|
|
|
def print_cuda_memory_usage():
    """Print current and peak allocated/reserved CUDA memory for device 0, in MB.

    Use this function to locate and debug memory leaks. The Python garbage
    collector runs and ``torch.cuda.empty_cache()`` is called before the
    counters are read.
    """
    import gc

    gc.collect()
    # Also clear the cache if you want to fully release the memory
    torch.cuda.empty_cache()

    bytes_per_mb = 1024**2
    report = (
        ("Current GPU Memory Allocated: {:.2f} MB", torch.cuda.memory_allocated),
        ("Maximum GPU Memory Allocated: {:.2f} MB", torch.cuda.max_memory_allocated),
        ("Current GPU Memory Reserved: {:.2f} MB", torch.cuda.memory_reserved),
        ("Maximum GPU Memory Reserved: {:.2f} MB", torch.cuda.max_memory_reserved),
    )
    for template, query in report:
        print(template.format(query(0) / bytes_per_mb))
|
|
|
|
|
|
def capture_timestamp_utc():
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc
|
|
|
|
|
|
def say(text, blocking=False):
    """Speak ``text`` aloud with the operating system's text-to-speech command.

    Uses ``say`` on macOS (Darwin), ``spd-say`` on Linux and PowerShell's
    ``System.Speech`` synthesizer on Windows.

    Args:
        text: The text to speak.
        blocking: If True, wait for the command to finish via ``subprocess.run``
            (``check=True``, so a non-zero exit raises ``CalledProcessError``).
            Otherwise the command is started in the background with ``Popen``.

    Raises:
        RuntimeError: If the operating system is not Darwin, Linux or Windows.
    """
    system = platform.system()
    env = None  # None: the child process inherits the current environment

    if system == "Darwin":
        cmd = ["say", text]

    elif system == "Linux":
        cmd = ["spd-say", text]
        if blocking:
            cmd.append("--wait")

    elif system == "Windows":
        # Hand the text to PowerShell through an environment variable instead of
        # splicing it into the script: a quote in `text` (e.g. "don't") would
        # otherwise end the string literal, breaking the command or injecting code.
        env = {**os.environ, "SAY_TTS_TEXT": str(text)}
        cmd = [
            "PowerShell",
            "-Command",
            "Add-Type -AssemblyName System.Speech; "
            "(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak($env:SAY_TTS_TEXT)",
        ]

    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True, env=env)
    else:
        # CREATE_NO_WINDOW only exists on Windows; the conditional keeps other OSes safe.
        subprocess.Popen(cmd, env=env, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
|
|
|
|
|
def log_say(text, play_sounds, blocking=False):
    """Log ``text`` at INFO level and, if ``play_sounds`` is truthy, also speak it.

    Args:
        text: Message to log (and optionally speak).
        play_sounds: Whether to pass the message on to ``say``.
        blocking: Forwarded to ``say``.
    """
    logging.info(text)
    if not play_sounds:
        return
    say(text, blocking)
|
|
|
|
|
|
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
    """Return an image shape in channel-first ``(c, h, w)`` order.

    The channel axis is taken to be the strictly smallest of the first three
    dimensions:
    - if it is last, ``(h, w, c)`` is reordered to ``(c, h, w)``;
    - if it is already first, the shape is returned as is.

    Raises:
        ValueError: If neither the first nor the third dimension is strictly
            smaller than the other two.
    """
    dims = copy(image_shape)
    first, second, third = dims[0], dims[1], dims[2]

    if third < first and third < second:  # (h, w, c) -> (c, h, w)
        return (third, first, second)
    if first < second and first < third:  # already (c, h, w)
        return dims
    raise ValueError(image_shape)
|
|
|
|
|
|
def has_method(cls: object, method_name: str) -> bool:
    """Return True if ``cls`` has an attribute named ``method_name`` that is callable."""
    missing = object()
    attribute = getattr(cls, method_name, missing)
    return attribute is not missing and callable(attribute)
|
|
|
|
|
|
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
    """Return True if ``np.dtype(dtype_str)`` succeeds.

    Only ``TypeError`` (numpy's "data type not understood") is treated as
    "invalid"; any other exception propagates.
    """
    try:
        np.dtype(dtype_str)
    except TypeError:
        return False
    return True
|