99 lines
3.6 KiB
Python
99 lines
3.6 KiB
Python
|
#!/usr/bin/python
|
||
|
# -*- encoding: utf-8 -*-
|
||
|
import numpy as np
|
||
|
from model import BiSeNet
|
||
|
|
||
|
import torch
|
||
|
|
||
|
import os
|
||
|
import os.path as osp
|
||
|
|
||
|
from PIL import Image
|
||
|
import torchvision.transforms as transforms
|
||
|
import cv2
|
||
|
from pathlib import Path
|
||
|
import configargparse
|
||
|
import tqdm
|
||
|
|
||
|
# import ttach as tta
|
||
|
|
||
|
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
    """Render a per-pixel label map as a 3-channel colour image.

    Each label is mapped to a fixed colour triple (stored in the order given
    here; ``cv2.imwrite`` writes the channels as they are laid out):

    * label 0        -> (255, 255, 255)
    * labels 1-13    -> (255, 0, 0)
    * labels 14-15   -> (0, 255, 0)
    * label 16       -> (0, 0, 255)
    * labels >= 17   -> (255, 0, 0)

    Args:
        im: Unused. Kept only so existing callers keep working.
        parsing_anno: 2-D array of integer class labels. It is cast to
            ``uint8`` before colouring.
        stride: Scale factor applied to the label map on both axes
            (nearest-neighbour) before colouring.
        save_im: When true, the rendered image is written to ``save_path``.
        save_path: Destination file used when ``save_im`` is true.
        img_size: ``(width, height)`` of the returned/saved image.

    Returns:
        The ``uint8`` colour image resized to ``img_size``.
    """
    # ``astype`` already returns a fresh array, so the input is never mutated.
    labels = np.asarray(parsing_anno).astype(np.uint8)
    labels = cv2.resize(
        labels, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)

    # One colour per possible uint8 label. Indexing the table with the label
    # map colours every pixel in a single vectorised step instead of running
    # one np.where pass per class.
    palette = np.empty((256, 3), dtype=np.uint8)
    palette[:] = (255, 0, 0)        # labels 1-13 and everything from 17 up
    palette[0] = (255, 255, 255)    # label 0
    palette[14:16] = (0, 255, 0)    # labels 14 and 15
    palette[16] = (0, 0, 255)       # label 16
    colored = palette[labels]

    # Nearest-neighbour keeps the colours exact (no blending at region edges).
    vis_im = cv2.resize(colored, img_size, interpolation=cv2.INTER_NEAREST)
    if save_im:
        cv2.imwrite(save_path, vis_im)
    return vis_im
|
||
|
|
||
|
|
||
|
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    """Run the BiSeNet parser on every ``.jpg``/``.png`` in a directory.

    For each image a colour-coded parsing map is written to ``respth`` as a
    PNG at the image's original size.

    Args:
        respth: Output directory; created (with parents) if it is missing.
        dspth: Directory that is scanned (non-recursively) for input images.
        cp: Path of the checkpoint passed to ``torch.load`` and then to
            ``load_state_dict``.
    """
    Path(respth).mkdir(parents=True, exist_ok=True)

    print('[INFO] loading model...')
    # Prefer the GPU, but fall back to the CPU so the script still runs on
    # machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    net.load_state_dict(torch.load(cp, map_location=device))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Filter first so the progress bar counts only files that are processed.
    image_names = [name for name in os.listdir(dspth)
                   if name.endswith(('.jpg', '.png'))]

    with torch.no_grad():
        for image_name in tqdm.tqdm(image_names):
            # The context manager releases the file handle once the resized
            # copy has been made.
            with Image.open(osp.join(dspth, image_name)) as img:
                ori_size = img.size
                image = img.resize((512, 512), Image.BILINEAR)
                image = image.convert("RGB")

            inputs = torch.unsqueeze(to_tensor(image), 0)  # [1, 3, 512, 512]
            outputs = net(inputs.to(device))
            # Average over the batch axis, then take the best class per pixel.
            # NOTE(review): assumes the network returns a single tensor of
            # shape [batch, classes, H, W] -- confirm against model.py.
            parsing = outputs.mean(0).cpu().numpy().argmax(0)

            # Numeric stems are normalised through int() (e.g. "0007" -> "7").
            # Any other stem is kept as-is instead of raising ValueError.
            stem = osp.splitext(image_name)[0]
            try:
                out_name = str(int(stem)) + '.png'
            except ValueError:
                out_name = stem + '.png'

            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=osp.join(respth, out_name), img_size=ori_size)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Command-line entry point: read the three path options and hand them to
    # evaluate().
    cli = configargparse.ArgumentParser()
    # (flag, default value, extra keyword arguments for add_argument)
    option_specs = (
        ('--respath', './result/', {'help': 'result path for label'}),
        ('--imgpath', './imgs/', {'help': 'path for input images'}),
        ('--modelpath', 'data_utils/face_parsing/79999_iter.pth', {}),
    )
    for flag, default_value, extras in option_specs:
        cli.add_argument(flag, type=str, default=default_value, **extras)

    opts = cli.parse_args()
    evaluate(respth=opts.respath, dspth=opts.imgpath, cp=opts.modelpath)
|