2024-05-25 06:33:59 +08:00
|
|
|
import torch
|
|
|
|
import time
|
|
|
|
import os
|
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
|
|
from .model import BiSeNet
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
|
2024-06-20 20:21:37 +08:00
|
|
|
|
2024-05-25 06:33:59 +08:00
|
|
|
class FaceParsing():
    """Callable wrapper that turns an image into a binary mask via BiSeNet.

    ``FaceParsing()(image)`` runs the segmentation network on the image and
    returns a single-channel PIL image in which pixels whose predicted class
    index lies in ``1..13`` are 255 and every other pixel is 0.
    """

    # Highest class index kept in the mask; larger indices are zeroed.
    # NOTE(review): presumably indices above 13 are non-face regions (neck,
    # clothes, hair, ...) in the model's label map -- confirm against the
    # checkpoint's class list.
    _MAX_KEPT_LABEL = 13

    def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='./models/face-parse-bisent/79999_iter.pth'):
        """Load the network weights and build the input transform.

        Args:
            resnet_path: Path passed to the ``BiSeNet`` constructor.
            model_pth: Path of the state dict loaded into the network.
        """
        self.net = self.model_init(resnet_path, model_pth)
        self.preprocess = self.image_preprocess()

    def model_init(self, resnet_path, model_pth):
        """Build the BiSeNet model, load its weights and put it in eval mode.

        The model is placed on the GPU when CUDA is available, otherwise it
        stays on the CPU.

        Args:
            resnet_path: Path passed to the ``BiSeNet`` constructor.
            model_pth: Path of the state dict to load.

        Returns:
            The initialized network in evaluation mode.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = BiSeNet(resnet_path)
        # Always deserialize onto the CPU: a checkpoint saved from a GPU index
        # that does not exist on this machine would otherwise fail to load.
        # The weights follow the module when it is moved to ``device`` below.
        state_dict = torch.load(model_pth, map_location=torch.device('cpu'))
        net.load_state_dict(state_dict)
        net.to(device)
        net.eval()
        return net

    def image_preprocess(self):
        """Return the transform that maps a PIL image to a normalized tensor.

        The normalization uses three per-channel means and standard
        deviations, so the input image must have exactly three channels.
        """
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    @staticmethod
    def _labels_to_mask(labels, max_label=13):
        """Map an array of class indices to a uint8 mask of 0 and 255.

        Indices in ``1..max_label`` become 255; index 0 and indices above
        ``max_label`` become 0.
        """
        labels = np.asarray(labels)
        keep = (labels >= 1) & (labels <= max_label)
        return np.where(keep, 255, 0).astype(np.uint8)

    def __call__(self, image, size=(512, 512)):
        """Segment ``image`` and return the binary mask as a PIL image.

        Args:
            image: Either a file path (``str``) opened with ``Image.open`` or
                an already loaded PIL image.
            size: Size the image is resized to before it is fed to the
                network.

        Returns:
            A single-channel (uint8) PIL image holding 0 or 255 per pixel.
            The mask is built from the network output for the resized input;
            it is not scaled back to the original image size.
        """
        if isinstance(image, str):
            image = Image.open(image)
        # The preprocessing normalizes exactly three channels, so RGBA,
        # grayscale or palette images must be converted first.
        image = image.convert('RGB')

        resized = image.resize(size, Image.BILINEAR)
        batch = torch.unsqueeze(self.preprocess(resized), 0)
        # Follow the device the network actually lives on instead of
        # re-querying CUDA availability, so input and weights always match.
        batch = batch.to(next(self.net.parameters()).device)

        with torch.no_grad():
            # The network returns several outputs; only the first is used.
            out = self.net(batch)[0]

        # argmax over the class axis yields one class index per pixel.
        labels = out.squeeze(0).cpu().numpy().argmax(0)
        mask = self._labels_to_mask(labels, self._MAX_KEPT_LABEL)
        return Image.fromarray(mask)
|
|
|
|
|
2024-06-20 20:21:37 +08:00
|
|
|
|
2024-05-25 06:33:59 +08:00
|
|
|
if __name__ == "__main__":
    # Manual smoke test: segment a sample image with the default weights and
    # write the resulting mask next to the script.
    FaceParsing()('154_small.png').save('res.png')
|