332 lines
12 KiB
Python
332 lines
12 KiB
Python
|
import argparse
|
|||
|
import glob
|
|||
|
import json
|
|||
|
import os
|
|||
|
import pickle
|
|||
|
import shutil
|
|||
|
import sys
|
|||
|
|
|||
|
import cv2
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
import torchvision.transforms as transforms
|
|||
|
from PIL import Image
|
|||
|
from diffusers import AutoencoderKL
|
|||
|
from face_alignment import NetworkSize
|
|||
|
from mmpose.apis import inference_topdown, init_model
|
|||
|
from mmpose.structures import merge_data_samples
|
|||
|
from tqdm import tqdm
|
|||
|
|
|||
|
from utils.face_parsing import FaceParsing
|
|||
|
|
|||
|
|
|||
|
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    """Dump the frames of a video into ``save_path`` as numbered image files.

    Frames are written as ``{index:08d}{ext}`` starting at index 0. Reading
    stops at the first frame that cannot be decoded, or once the frame index
    exceeds ``cut_frame`` (so at most ``cut_frame + 1`` frames are written).

    :param vid_path: Path handed to ``cv2.VideoCapture``.
    :param save_path: Existing directory that receives the frame files.
    :param ext: Image extension, with or without the leading dot
        (``'png'`` and ``'.png'`` are equivalent).
    :param cut_frame: Highest frame index that may be written.
    """
    # Accept both 'png' and '.png' so the extension is honoured either way.
    suffix = ext if ext.startswith('.') else f'.{ext}'
    cap = cv2.VideoCapture(vid_path)
    try:
        count = 0
        while count <= cut_frame:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imwrite(f"{save_path}/{count:08d}{suffix}", frame)
            count += 1
    finally:
        # Always free the capture handle, even if writing a frame fails.
        cap.release()
|
|||
|
|
|||
|
|
|||
|
def read_imgs(img_list):
    """Load every path in ``img_list`` with ``cv2.imread``, preserving order.

    :param img_list: Iterable of image file paths.
    :return: List with one ``cv2.imread`` result per input path.
    """
    print('reading images...')
    return [cv2.imread(path) for path in tqdm(img_list)]
|
|||
|
|
|||
|
|
|||
|
def get_landmark_and_bbox(img_list, upperbondrange=0):
    """Compute one face box per image from pose landmarks and face detection.

    Relies on the module-level ``model`` (pose estimator) and ``fa`` (face
    detector) objects.

    For every frame the keypoints ``[23:91]`` of the first detected instance
    are used as face landmarks. The box spans the landmarks' x-range; its
    bottom is the lowest landmark and its top mirrors that distance around
    landmark 29 (optionally shifted by ``upperbondrange``). When that box is
    degenerate the detector's own box is used instead.

    :param img_list: Image paths, read with ``read_imgs``.
    :param upperbondrange: Vertical shift applied to landmark 29 before the
        upper bound is derived; positive moves it down (towards landmark 29),
        negative up (towards landmark 28). ``0`` disables the shift.
    :return: ``(coords_list, frames)`` where ``coords_list[i]`` is an
        ``(x1, y1, x2, y2)`` tuple for ``frames[i]``, or
        ``(0.0, 0.0, 0.0, 0.0)`` when no face was detected.
    """
    # Defined locally so the function does not depend on a global that the
    # calling script may not have assigned yet.
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)

    frames = read_imgs(img_list)
    batch_size_fa = 1
    batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
    coords_list = []
    if upperbondrange != 0:
        print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
    else:
        print('get key_landmark and face bounding boxes with the default value')

    for fb in tqdm(batches):
        batch_array = np.asarray(fb)
        results = inference_topdown(model, batch_array[0])
        results = merge_data_samples(results)
        keypoints = results.pred_instances.keypoints
        # presumably the 68 facial keypoints of the whole-body layout -- TODO confirm
        face_land_mark = keypoints[0][23:91].astype(np.int32)

        # get bounding boxes by face detection
        bbox = fa.get_detections_for_batch(batch_array)

        # adjust the bounding box with reference to the landmarks and
        # append the resulting tuple to the coordinates list
        for f in bbox:
            if f is None:  # no face in the image
                coords_list.append(coord_placeholder)
                continue

            # Copy so the shift below does not write into the landmark array.
            half_face_coord = face_land_mark[29].copy()
            if upperbondrange != 0:
                # manual adjustment: + moves down (towards 29), - moves up (towards 28)
                half_face_coord[1] = upperbondrange + half_face_coord[1]
            lowest_y = np.max(face_land_mark[:, 1])
            half_face_dist = lowest_y - half_face_coord[1]
            upper_bond = half_face_coord[1] - half_face_dist

            f_landmark = (
                np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
                lowest_y)
            x1, y1, x2, y2 = f_landmark

            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
                # the landmark box is not usable, fall back to the detector box
                coords_list.append(f)
                print("error bbox:", f)
            else:
                coords_list.append(f_landmark)
    return coords_list, frames
|
|||
|
|
|||
|
|
|||
|
class FaceAlignment:
    """Thin wrapper that loads a face detector and runs it on image batches."""

    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        """Store the settings and instantiate the requested face detector.

        :param landmarks_type: Stored on the instance as ``landmarks_type``.
        :param network_size: Converted with ``int()``; not used further here.
        :param device: Device string forwarded to the detector.
        :param flip_input: Stored on the instance as ``flip_input``.
        :param face_detector: Name of the sub-module of
            ``face_detection.detection`` that provides ``FaceDetector``.
        :param verbose: Forwarded to the detector and stored on the instance.
        """
        self.landmarks_type = landmarks_type
        self.device = device
        self.flip_input = flip_input
        self.verbose = verbose

        # Kept from the original: the conversion fails early on a value
        # that cannot be turned into an int.
        network_size = int(network_size)

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True
            print('cuda start')

        # Import face_detection.detection.<face_detector> by name.
        detector_module = __import__('face_detection.detection.' + face_detector,
                                     globals(), locals(), [face_detector], 0)
        self.face_detector = detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        """Return one face box (or ``None``) per image of the batch.

        :param images: Array of images; the last axis is reversed before
            detection (presumably BGR -> RGB -- TODO confirm).
        :return: List with an ``(x1, y1, x2, y2)`` int tuple for the first
            detection of each image, or ``None`` where nothing was detected.
        """
        flipped = images[..., ::-1]
        boxes = []
        for detections in self.face_detector.detect_from_batch(flipped.copy()):
            if len(detections) == 0:
                boxes.append(None)
            else:
                # Only the first detection is used; negative values are clamped to 0.
                first = np.clip(detections[0], 0, None)
                # The last element is dropped (presumably the score -- TODO confirm).
                left, top, right, bottom = map(int, first[:-1])
                boxes.append((left, top, right, bottom))
        return boxes
|
|||
|
|
|||
|
|
|||
|
def get_mask_tensor():
    """
    Build the 256x256 half mask used for image processing.

    :return: A float tensor whose top 128 rows are 1 and bottom 128 rows are 0.
    """
    half = 256 // 2
    upper = torch.ones((half, 256))
    lower = torch.zeros((256 - half, 256))
    return torch.cat([upper, lower], dim=0)
|
|||
|
|
|||
|
|
|||
|
def preprocess_img(img_name, half_mask=False):
    """Turn an image (path or array) into a normalised batch-of-one tensor.

    A path is read with OpenCV, converted BGR -> RGB and resized to 256x256.
    An array is only colour-converted, so it is expected to already have the
    target size. Pixel values are scaled to [0, 1], optionally multiplied by
    the half mask, normalised with mean 0.5 / std 0.5 per channel and moved
    to the module-level ``device``.

    :param img_name: Image file path, or an image array in BGR channel order.
    :param half_mask: When true, multiply by ``get_mask_tensor() > 0.5``.
    :return: Tensor with a leading batch dimension of 1.
    """
    if isinstance(img_name, str):
        rgb = cv2.cvtColor(cv2.imread(img_name), cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (256, 256),
                         interpolation=cv2.INTER_LANCZOS4)
    else:
        rgb = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)

    # Single-image stack, channels moved to the front, then squeezed.
    scaled = np.asarray([rgb]) / 255.
    tensor = torch.squeeze(torch.FloatTensor(np.transpose(scaled, (3, 0, 1, 2))))

    if half_mask:
        tensor = tensor * (get_mask_tensor() > 0.5)

    to_unit_range = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    tensor = to_unit_range(tensor)
    # Add the batch dimension and move to the target device.
    return tensor.unsqueeze(0).to(device)
|
|||
|
|
|||
|
|
|||
|
def encode_latents(image):
    """Encode an image tensor with the module-level ``vae`` and sample latents.

    :param image: Image tensor; cast to ``vae.dtype`` before encoding.
    :return: A sample from the encoder's latent distribution, multiplied by
        ``vae.config.scaling_factor``.
    """
    with torch.no_grad():
        posterior = vae.encode(image.to(vae.dtype)).latent_dist
        scaled_sample = vae.config.scaling_factor * posterior.sample()
    return scaled_sample
|
|||
|
|
|||
|
|
|||
|
def get_latents_for_unet(img):
    """Build the UNet input from the masked and unmasked latents of ``img``.

    The image is preprocessed twice, with and without the half mask, each
    version is encoded with ``encode_latents`` and the two results are
    concatenated along dimension 1 (masked first).

    :param img: Image path or array accepted by ``preprocess_img``.
    :return: Concatenated latent tensor.
    """
    latents = [
        encode_latents(preprocess_img(img, half_mask=use_mask))
        for use_mask in (True, False)  # masked first, then the full reference
    ]
    return torch.cat(latents, dim=1)
|
|||
|
|
|||
|
|
|||
|
def get_crop_box(box, expand):
    """Return a square box centred on ``box`` and scaled by ``expand``.

    :param box: ``(x, y, x1, y1)`` corners of the source box.
    :param expand: Factor applied to half of the longer side.
    :return: ``(crop_box, s)`` where ``crop_box`` is the list
        ``[x_c - s, y_c - s, x_c + s, y_c + s]`` and ``s`` the half side.
    """
    left, top, right, bottom = box
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    longer_side = max(right - left, bottom - top)
    half_side = int(longer_side // 2 * expand)
    square = [
        center_x - half_side,
        center_y - half_side,
        center_x + half_side,
        center_y + half_side,
    ]
    return square, half_side
|
|||
|
|
|||
|
|
|||
|
def face_seg(image):
    """Run the module-level face parser ``fp`` on ``image``.

    :param image: Image object exposing ``size`` (a PIL image in this file).
    :return: The parser output resized to ``image.size``, or ``None`` when
        the parser returns ``None``.
    """
    parsed = fp(image)
    if parsed is not None:
        return parsed.resize(image.size)

    print("error, no person_segment")
    return None
|
|||
|
|
|||
|
|
|||
|
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
    """Build the blurred blending mask and the crop box for one frame.

    :param image: Frame array; its last axis is reversed before it is turned
        into a PIL image.
    :param face_box: ``(x, y, x1, y1)`` face box in frame coordinates.
    :param upper_boundary_ratio: Fraction of the mask height (from the top)
        that is zeroed out.
    :param expand: Expansion factor forwarded to ``get_crop_box``.
    :return: ``(mask_array, crop_box)``: the Gaussian-blurred mask as an
        array and the expanded square crop box.
    """
    x, y, x1, y1 = face_box
    crop_box, _ = get_crop_box(face_box, expand)
    x_s, y_s, _, _ = crop_box

    pil_frame = Image.fromarray(image[:, :, ::-1])
    face_large = pil_frame.crop(crop_box)
    ori_shape = face_large.size

    # Face box expressed in the coordinates of the expanded crop.
    inner_box = (x - x_s, y - y_s, x1 - x_s, y1 - y_s)

    # Keep the segmentation only inside the face box.
    segmentation = face_seg(face_large)
    face_only = Image.new('L', ori_shape, 0)
    face_only.paste(segmentation.crop(inner_box), inner_box)

    # keep upper_boundary_ratio of talking area
    width, height = face_only.size
    top_boundary = int(height * upper_boundary_ratio)
    lower_part = face_only.crop((0, top_boundary, width, height))
    lower_only = Image.new('L', ori_shape, 0)
    lower_only.paste(lower_part, (0, top_boundary))

    # Odd kernel size derived from the crop width.
    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    kernel = (blur_kernel_size, blur_kernel_size)
    return cv2.GaussianBlur(np.array(lower_only), kernel, 0), crop_box
|
|||
|
|
|||
|
|
|||
|
def create_dir(dir_path):
    """Create ``dir_path`` (and missing parents) if it does not exist yet.

    Uses ``exist_ok=True`` instead of a separate existence check, so there is
    no window between the check and the creation.

    :param dir_path: Directory to create.
    :raises FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)
|
|||
|
|
|||
|
|
|||
|
# Make this file's directory importable.
# NOTE(review): the `utils` import at the top of the file runs before this
# line, so this only affects imports performed later.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Module-level model initialisation. Everything below runs at import time and
# the resulting objects are used as globals by the functions above.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Face detector wrapper, read as `fa` by get_landmark_and_bbox.
fa = FaceAlignment(1, flip_input=False, device=device)
# initialize the mmpose model (paths are relative to the working directory)
config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device)
# VAE read as `vae` by encode_latents.
vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse")
vae.to(device)
# Face parser read as `fp` by face_seg.
fp = FaceParsing()
|
|||
|
if __name__ == '__main__':
    # Prepare an avatar: extract frames, compute face boxes, VAE latents and
    # blending masks, and store everything under ../data/avatars/avator_<id>.

    # Source video file (or directory of PNG frames).
    parser = argparse.ArgumentParser()
    parser.add_argument("--file",
                        type=str,
                        default=r'D:\ok\test.mp4',
                        )
    parser.add_argument("--avatar_id",
                        type=str,
                        default='1',
                        )
    parser.add_argument("--bbox_shift",
                        type=int,
                        default=5,
                        help="vertical shift passed to get_landmark_and_bbox "
                             "and recorded in avator_info.json",
                        )
    args = parser.parse_args()
    file = args.file
    bbox_shift = args.bbox_shift

    # Output locations; normally these do not need to be changed.
    save_path = f'../data/avatars/avator_{args.avatar_id}'
    save_full_path = f'{save_path}/full_imgs'
    mask_out_path = f'{save_path}/mask'
    for out_dir in (save_path, save_full_path, mask_out_path):
        create_dir(out_dir)

    # Files consumed later by the model.
    mask_coords_path = f'{save_path}/mask_coords.pkl'
    coords_path = f'{save_path}/coords.pkl'
    latents_out_path = f'{save_path}/latents.pt'

    with open(f'{save_path}/avator_info.json', "w") as f:
        json.dump({
            "avatar_id": args.avatar_id,
            "video_path": file,
            "bbox_shift": bbox_shift
        }, f)

    if os.path.isfile(file):
        video2imgs(file, save_full_path, ext='png')
    else:
        # `file` is treated as a directory of frames: copy its PNG files.
        png_names = sorted(name for name in os.listdir(file)
                           if name.split(".")[-1] == "png")
        for filename in png_names:
            shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
    input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))

    # Marker for frames without a usable box. It is assigned before the
    # landmark extraction because get_landmark_and_bbox emits the same value.
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)

    print("extracting landmarks...")
    coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)

    # Keep frames, boxes and latents in lock-step: a frame without a face is
    # dropped from all three lists so the saved files share the same indices.
    valid_frame_list = []
    valid_coord_list = []
    input_latent_list = []
    for bbox, frame in zip(coord_list, frame_list):
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        input_latent_list.append(get_latents_for_unet(resized_crop_frame))
        valid_frame_list.append(frame)
        valid_coord_list.append(bbox)

    skipped = len(frame_list) - len(valid_frame_list)
    if skipped:
        print(f"skipped {skipped} frame(s) without a detected face")

    # Forward + backward pass over the frames.
    frame_list_cycle = valid_frame_list + valid_frame_list[::-1]
    coord_list_cycle = valid_coord_list + valid_coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

    mask_coords_list_cycle = []
    written_frames = set()
    for i, frame in enumerate(tqdm(frame_list_cycle)):
        frame_out = f"{save_full_path}/{str(i).zfill(8)}.png"
        cv2.imwrite(frame_out, frame)
        written_frames.add(os.path.normcase(os.path.normpath(frame_out)))

        face_box = coord_list_cycle[i]
        mask, crop_box = get_image_prepare_material(frame, face_box)
        cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
        mask_coords_list_cycle.append(crop_box)

    # Remove extracted/copied inputs that were not overwritten by the cycle,
    # so full_imgs holds exactly the frames that match the saved lists.
    for src in input_img_list:
        if os.path.normcase(os.path.normpath(src)) not in written_frames:
            os.remove(src)

    with open(mask_coords_path, 'wb') as f:
        pickle.dump(mask_coords_list_cycle, f)

    with open(coords_path, 'wb') as f:
        pickle.dump(coord_list_cycle, f)

    torch.save(input_latent_list_cycle, latents_out_path)
|