import torch
import time
import os
import cv2
import numpy as np
from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms


class FaceParsing:
    """Binary face-mask extractor built on a BiSeNet parsing network.

    Calling an instance runs the network on an image and returns a
    single-channel ``uint8`` PIL image in which pixels whose predicted label
    lies in ``1..13`` are 255 and every other pixel is 0.
    """

    def __init__(self,
                 resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='./models/face-parse-bisent/79999_iter.pth'):
        """Build the network and the input preprocessing pipeline.

        Args:
            resnet_path: path handed to the ``BiSeNet`` constructor
                (presumably backbone weights — confirm in ``.model``).
            model_pth: path to the BiSeNet ``state_dict`` checkpoint.
        """
        self.net = self.model_init(resnet_path, model_pth)
        self.preprocess = self.image_preprocess()

    def model_init(self, resnet_path, model_pth):
        """Create a BiSeNet, load its weights and put it in eval mode.

        The model is moved to CUDA when a GPU is available; otherwise the
        checkpoint is remapped onto the CPU so GPU-saved weights still load.

        Returns:
            The initialised ``BiSeNet`` module in eval mode.
        """
        net = BiSeNet(resnet_path)
        if torch.cuda.is_available():
            net.cuda()
            state_dict = torch.load(model_pth)
        else:
            state_dict = torch.load(model_pth, map_location=torch.device('cpu'))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        net.load_state_dict(state_dict)
        net.eval()
        return net

    def image_preprocess(self):
        """Return the tensor conversion + per-channel normalisation transform."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __call__(self, image, size=(512, 512)):
        """Compute a binary face mask for *image*.

        Args:
            image: path (``str`` or ``os.PathLike``) to an image file, or a
                PIL image. Non-RGB images are converted to RGB because the
                normalisation transform expects exactly three channels.
            size: ``(width, height)`` the image is resized to before
                inference; the returned mask has this size.

        Returns:
            ``PIL.Image`` built from a 2-D ``uint8`` array: 255 where the
            predicted label is in ``1..13``, 0 elsewhere.
        """
        if isinstance(image, (str, os.PathLike)):
            # Context manager releases the file handle; convert() forces the
            # pixel data to load before the file is closed.
            with Image.open(image) as src:
                image = src.convert('RGB')
        elif image.mode != 'RGB':
            # RGBA / grayscale / palette images would break the 3-channel Normalize.
            image = image.convert('RGB')

        with torch.no_grad():
            resized = image.resize(size, Image.BILINEAR)
            batch = self.preprocess(resized).unsqueeze(0)  # add batch dimension
            if torch.cuda.is_available():
                batch = batch.cuda()
            # The network appears to return a sequence; element 0 is used as
            # the main prediction (confirm against .model.BiSeNet).
            logits = self.net(batch)[0]
            labels = logits.squeeze(0).cpu().numpy().argmax(0)

        # Keep labels 1..13 as foreground (presumably the face-region classes
        # of the training label set — TODO confirm); everything else is background.
        labels[labels > 13] = 0
        labels[labels >= 1] = 255
        return Image.fromarray(labels.astype(np.uint8))


if __name__ == "__main__":
    fp = FaceParsing()
    segmap = fp('154_small.png')
    segmap.save('res.png')