# livetalking/musetalk/simple_musetalk.py
# (source listing metadata: 332 lines, 12 KiB, Python)

import argparse
import glob
import json
import os
import pickle
import shutil
import sys
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL
from face_alignment import NetworkSize
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
from tqdm import tqdm
from utils.face_parsing import FaceParsing
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    """
    Decode a video and write its frames as numbered image files.

    Frames are written as ``{save_path}/{index:08d}{ext}`` starting at index 0.
    Extraction stops at the end of the stream, or after the frame with index
    ``cut_frame`` has been written (i.e. at most ``cut_frame + 1`` frames).

    :param vid_path: path of the video file to read.
    :param save_path: existing directory that receives the frames.
    :param ext: image extension, with or without the leading dot ('png' and
        '.png' are equivalent); it selects the encoder used by cv2.imwrite.
    :param cut_frame: index of the last frame to write.
    """
    # Callers pass both 'png' and '.png'; normalize so the filename is valid.
    if not ext.startswith('.'):
        ext = '.' + ext
    cap = cv2.VideoCapture(vid_path)
    try:
        count = 0
        while count <= cut_frame:
            ret, frame = cap.read()
            if not ret:  # end of stream or decode failure
                break
            cv2.imwrite(f"{save_path}/{count:08d}{ext}", frame)
            count += 1
    finally:
        # Always free the decoder handle, even if writing a frame fails.
        cap.release()
def read_imgs(img_list):
    """
    Load every image in ``img_list`` with OpenCV, showing a progress bar.

    :param img_list: iterable of image file paths.
    :return: list of images in input order, exactly as returned by cv2.imread
        (an entry is None if OpenCV could not read that file).
    """
    print('reading images...')
    return [cv2.imread(path) for path in tqdm(img_list)]
# Bounding box recorded for frames in which the face detector found no face.
# Defined at module level so get_landmark_and_bbox can use it; previously it
# only existed in the __main__ section *after* this function had been called,
# which raised NameError on the first face-less frame.
coord_placeholder = (0.0, 0.0, 0.0, 0.0)


def get_landmark_and_bbox(img_list, upperbondrange=0):
    """
    Compute one face bounding box per image from pose landmarks and face detection.

    For each image, the module-level pose model ``model`` supplies landmarks and
    the module-level detector ``fa`` decides whether a face is present. The box
    spans the landmarks horizontally. Vertically, its top is the lowest
    landmark's y mirrored around landmark 29's y (optionally shifted by
    ``upperbondrange``), and its bottom is the lowest landmark.

    :param img_list: image file paths, loaded with read_imgs.
    :param upperbondrange: pixel offset added to landmark 29's y before the top
        edge is computed (positive moves the top edge down); 0 means no shift.
    :return: (coords_list, frames). coords_list has one (x1, y1, x2, y2) entry
        per frame: ``coord_placeholder`` when no face was detected, the
        detector's box when the landmark box is degenerate, otherwise the
        landmark box. frames are the loaded images.
    """
    frames = read_imgs(img_list)
    batch_size_fa = 1
    batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
    coords_list = []
    if upperbondrange != 0:
        print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
    else:
        print('get key_landmark and face bounding boxes with the default value')
    for fb in tqdm(batches):
        # Keypoints for the single image in this batch.
        results = merge_data_samples(inference_topdown(model, np.asarray(fb)[0]))
        keypoints = results.pred_instances.keypoints
        # Keypoints 23..90 of the first instance; presumably the 68 face
        # landmarks of the COCO-WholeBody layout — confirm.
        face_land_mark = keypoints[0][23:91].astype(np.int32)

        # Face detection decides whether a face is present and supplies the
        # fallback box when the landmark box is unusable.
        bbox = fa.get_detections_for_batch(np.asarray(fb))
        for f in bbox:
            if f is None:  # no face in the image
                coords_list.append(coord_placeholder)
                continue
            # NOTE: this is a view into face_land_mark, so the shift below also
            # changes the row seen by np.max on the next statement (kept as-is).
            half_face_coord = face_land_mark[29]
            if upperbondrange != 0:
                # Manual adjustment: positive shifts down, negative shifts up.
                half_face_coord[1] = upperbondrange + half_face_coord[1]
            half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
            upper_bond = half_face_coord[1] - half_face_dist
            f_landmark = (
                np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
                np.max(face_land_mark[:, 1]))
            x1, y1, x2, y2 = f_landmark
            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
                # Landmark box is unusable: fall back to the detector's box.
                coords_list.append(f)
                print("error bbox:", f)
            else:
                coords_list.append(f_landmark)
    return coords_list, frames
class FaceAlignment:
    """
    Thin wrapper around a detector from the ``face_detection.detection`` package.

    Only get_detections_for_batch is used by this script. ``landmarks_type`` and
    ``flip_input`` are stored but not used here. ``network_size`` is passed
    through int() and then not used.
    """

    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.landmarks_type = landmarks_type
        self.device = device
        self.flip_input = flip_input
        self.verbose = verbose

        network_size = int(network_size)  # retained from upstream; value unused

        if 'cuda' in device:
            # Let cuDNN benchmark candidate convolution algorithms and keep the fastest.
            torch.backends.cudnn.benchmark = True
            print('cuda start')

        # Dynamically load face_detection.detection.<face_detector> and build its detector.
        detector_module = __import__('face_detection.detection.' + face_detector,
                                     globals(), locals(), [face_detector], 0)
        self.face_detector = detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        """
        Detect faces in a batch of images and keep the first detection per image.

        :param images: images stacked along the first axis; the last axis is
            reversed before detection (presumably (N, H, W, 3) BGR -> RGB — confirm).
        :return: one entry per image: an (x1, y1, x2, y2) tuple of ints taken
            from the first detection with negative values clipped to 0, or None
            when the detector found nothing.
        """
        flipped = images[..., ::-1].copy()
        per_image = self.face_detector.detect_from_batch(flipped)
        boxes = []
        for faces in per_image:
            if len(faces) == 0:
                boxes.append(None)
                continue
            # Drop the trailing element (presumably the score) and keep the corners.
            x1, y1, x2, y2 = map(int, np.clip(faces[0], 0, None)[:-1])
            boxes.append((x1, y1, x2, y2))
        return boxes
def get_mask_tensor():
    """
    Build a 256x256 binary mask that keeps the upper half of an image.

    :return: tensor of shape (256, 256) in torch's default dtype; rows 0-127
        are 1 and rows 128-255 are 0.
    """
    side = 256
    mask = torch.zeros((side, side))
    mask[:side // 2] = 1
    return mask
def preprocess_img(img_name, half_mask=False):
    """
    Convert an image into a normalized single-image batch tensor for the VAE.

    :param img_name: path to an image file, or a BGR image array (presumably
        uint8 HxWx3 — confirm). A path is loaded and resized to 256x256 with
        Lanczos interpolation. An array is used at its own size, so it must
        already be 256x256 when ``half_mask`` is True.
    :param half_mask: if True, zero the lower half of the image (the rows where
        get_mask_tensor is 0) before normalization.
    :return: tensor of shape [1, 3, H, W] on the module-level ``device``,
        normalized with mean 0.5 / std 0.5 (so [0, 1] inputs map to [-1, 1]).
    :raises FileNotFoundError: if ``img_name`` is a path OpenCV cannot read.
    """
    if isinstance(img_name, str):
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread signals failure with None; fail clearly instead of
            # letting cvtColor raise an opaque assertion error.
            raise FileNotFoundError(f"could not read image: {img_name}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_LANCZOS4)
    else:
        img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)

    x = np.asarray([img]) / 255.               # [1, H, W, 3], scaled by 1/255
    x = np.transpose(x, (3, 0, 1, 2))          # [3, 1, H, W]
    x = torch.squeeze(torch.FloatTensor(x))    # [3, H, W]
    if half_mask:
        x = x * (get_mask_tensor() > 0.5)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    x = normalize(x)
    x = x.unsqueeze(0)  # [1, 3, H, W] torch tensor
    return x.to(device)
def encode_latents(image):
    """
    Encode an image tensor into scaled VAE latents using the module-level ``vae``.

    :param image: image tensor (presumably the [1, 3, 256, 256] output of
        preprocess_img); cast to ``vae.dtype`` before encoding.
    :return: a random sample from the VAE posterior multiplied by
        ``vae.config.scaling_factor`` (stochastic: repeated calls differ).
    """
    with torch.no_grad():
        posterior = vae.encode(image.to(vae.dtype)).latent_dist
    scale = vae.config.scaling_factor
    return scale * posterior.sample()
def get_latents_for_unet(img):
    """
    Build the UNet conditioning latent for one face crop.

    The image is encoded twice: once with its lower half masked out and once
    unmasked. The two latents are concatenated along the channel axis
    (each presumably [1, 4, 32, 32], giving [1, 8, 32, 32]).

    :param img: image path or BGR array accepted by preprocess_img.
    :return: torch.cat([masked_latents, ref_latents], dim=1).
    """
    masked_latents = encode_latents(preprocess_img(img, half_mask=True))
    ref_latents = encode_latents(preprocess_img(img, half_mask=False))
    return torch.cat([masked_latents, ref_latents], dim=1)
def get_crop_box(box, expand):
    """
    Compute a square crop centered on ``box`` and enlarged by ``expand``.

    :param box: (x1, y1, x2, y2) bounding box.
    :param expand: factor applied to half of the box's longer side.
    :return: (crop_box, s). crop_box is [cx - s, cy - s, cx + s, cy + s], where
        (cx, cy) is the floor-divided box center and
        s = int(max(width, height) // 2 * expand).
    """
    left, top, right, bottom = box
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    half_side = int(max(right - left, bottom - top) // 2 * expand)
    crop = [center_x - half_side, center_y - half_side,
            center_x + half_side, center_y + half_side]
    return crop, half_side
def face_seg(image):
    """
    Run the module-level face parser ``fp`` on a PIL image.

    :param image: PIL image (the enlarged face crop).
    :return: the parser's output resized to ``image.size``, or None (after
        printing an error) when the parser returns None.
    """
    parsed = fp(image)
    if parsed is not None:
        return parsed.resize(image.size)
    print("error, no person_segment")
    return None
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
    """
    Build the blurred blending mask for one frame.

    The frame is cropped to a square around ``face_box`` (see get_crop_box) and
    face-parsed. Only the segmentation inside ``face_box`` is kept, and of that
    only the rows at or below ``upper_boundary_ratio`` of the crop height. The
    result is then Gaussian-blurred.

    :param image: frame as an array with channels reversed for PIL
        (presumably HxWx3 BGR — confirm).
    :param face_box: (x1, y1, x2, y2) face box in frame coordinates.
    :param upper_boundary_ratio: fraction of the crop height above which the
        mask is zeroed.
    :param expand: crop enlargement factor passed to get_crop_box.
    :return: (mask_array, crop_box). mask_array is a (crop_height, crop_width)
        array taken from the 'L'-mode PIL mask; crop_box is the square crop in
        frame coordinates.
    :raises RuntimeError: if face parsing returns nothing for the crop.
    """
    body = Image.fromarray(image[:, :, ::-1])
    x, y, x1, y1 = face_box
    crop_box, _ = get_crop_box(face_box, expand)
    x_s, y_s = crop_box[0], crop_box[1]
    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    seg_image = face_seg(face_large)
    if seg_image is None:
        # Previously this surfaced as AttributeError on None.crop(...).
        raise RuntimeError(f"face parsing failed for face_box={face_box}")

    # Keep only the segmentation inside the face box (in crop coordinates).
    face_region = (x - x_s, y - y_s, x1 - x_s, y1 - y_s)
    mask_small = seg_image.crop(face_region)
    mask_image = Image.new('L', ori_shape, 0)
    mask_image.paste(mask_small, face_region)

    # keep upper_boundary_ratio of talking area: zero every row above top_boundary
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new('L', ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    # Odd kernel size (~10% of the crop width), as GaussianBlur requires odd sizes.
    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    return mask_array, crop_box
def create_dir(dir_path):
    """
    Create ``dir_path`` (including missing parents) unless it already exists.

    ``exist_ok=True`` avoids the race between checking for the directory and
    creating it.

    :raises FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)
# Make this script's directory importable. NOTE(review): this runs after the
# module-level `from utils.face_parsing import FaceParsing` import, so it cannot
# affect that import. It does run before FaceAlignment (below) loads
# `face_detection.detection.<backend>` via __import__; presumably that package
# sits next to this file — confirm.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Everything below runs on import (it is not under the __main__ guard).
# The relative paths are resolved against the current working directory, so the
# script presumably must be launched from its own directory — confirm.
# initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else "cpu"
# Face detector wrapper; used by get_landmark_and_bbox through the global `fa`.
fa = FaceAlignment(1, flip_input=False, device=device)
# Pose model (DWPose whole-body, per the file names); used by
# get_landmark_and_bbox through the global `model`.
config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device)
# VAE used by encode_latents through the global `vae`.
vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse")
vae.to(device)
# Face parser used by face_seg through the global `fp`.
fp = FaceParsing()
if __name__ == '__main__':
    # Input: a video file, or a directory containing .png frames.
    parser = argparse.ArgumentParser()
    parser.add_argument("--file",
                        type=str,
                        default=r'D:\ok\test.mp4',
                        )
    parser.add_argument("--avatar_id",
                        type=str,
                        default='1',
                        )
    parser.add_argument("--bbox_shift",
                        type=int,
                        default=5,
                        help="pixel offset passed to get_landmark_and_bbox "
                             "(positive moves the face box top down)",
                        )
    args = parser.parse_args()
    file = args.file
    bbox_shift = args.bbox_shift

    # Output locations (normally left unchanged).
    save_path = f'../data/avatars/avator_{args.avatar_id}'
    save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs'
    mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask'
    create_dir(save_path)
    create_dir(save_full_path)
    create_dir(mask_out_path)

    # Artifact paths.
    mask_coords_path = f'{save_path}/mask_coords.pkl'
    coords_path = f'{save_path}/coords.pkl'
    latents_out_path = f'{save_path}/latents.pt'

    with open(f'{save_path}/avator_info.json', "w") as f:
        json.dump({
            "avatar_id": args.avatar_id,
            "video_path": file,
            "bbox_shift": bbox_shift
        }, f)

    if os.path.isfile(file):
        video2imgs(file, save_full_path, ext='png')
    else:
        # Copy the directory's .png frames, in sorted filename order.
        png_names = sorted(name for name in os.listdir(file) if name.split(".")[-1] == "png")
        for filename in png_names:
            shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
    input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))

    # Sentinel that get_landmark_and_bbox records for frames without a face.
    # It must exist as a module global *before* that function runs.
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)
    print("extracting landmarks...")
    coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)

    # Drop face-less frames from every list so that frames, coords, latents and
    # masks stay index-aligned. Previously only the latents skipped them, and
    # mask preparation crashed on the placeholder box.
    kept = [(bbox, frame) for bbox, frame in zip(coord_list, frame_list)
            if bbox != coord_placeholder]
    if not kept:
        raise SystemExit("no face detected in any input frame")
    if len(kept) != len(coord_list):
        print(f"skipping {len(coord_list) - len(kept)} frame(s) without a detected face")
        # NOTE(review): previously extracted frames with index >= 2 * len(kept)
        # are not overwritten below and stay in full_imgs.
    coord_list = [bbox for bbox, _ in kept]
    frame_list = [frame for _, frame in kept]

    input_latent_list = []
    for bbox, frame in zip(coord_list, frame_list):
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        input_latent_list.append(get_latents_for_unet(resized_crop_frame))

    # Forward sequence followed by its reverse (presumably so looped playback
    # has no visible jump).
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
    mask_coords_list_cycle = []
    for i, frame in enumerate(tqdm(frame_list_cycle)):
        name = f"{str(i).zfill(8)}.png"
        cv2.imwrite(f"{save_full_path}/{name}", frame)
        mask, crop_box = get_image_prepare_material(frame, coord_list_cycle[i])
        cv2.imwrite(f"{mask_out_path}/{name}", mask)
        mask_coords_list_cycle.append(crop_box)

    with open(mask_coords_path, 'wb') as f:
        pickle.dump(mask_coords_list_cycle, f)
    with open(coords_path, 'wb') as f:
        pickle.dump(coord_list_cycle, f)
    torch.save(input_latent_list_cycle, latents_out_path)