feat: add simple automatic generation of musetalk digital-human avatars

Yun 2024-06-20 20:21:37 +08:00
parent 5da818b9d9
commit c0682408c5
4 changed files with 454 additions and 84 deletions


@@ -172,6 +172,13 @@ python -m scripts.realtime_inference --inference_config configs/inference/realti
After running, copy the files under results/avatars into this project's data/avatars directory
```
```bash
# Alternatively, you can use simple_musetalk.py in the local directory
cd musetalk
python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
# After running, the avatar is generated directly under data/avatars
```
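The layout below is a sketch of what simple_musetalk.py writes for `--avatar_id 2`, based on the paths in the script (the `avator_` spelling comes from the code itself):
```
data/avatars/avator_2/
├── avator_info.json   # avatar_id, source video path, bbox_shift
├── full_imgs/         # extracted (and mirrored/cycled) full frames, 00000000.png ...
├── mask/              # blurred face-region masks, one per cycled frame
├── coords.pkl         # face bounding boxes for each cycled frame
├── mask_coords.pkl    # crop boxes matching the masks
└── latents.pt         # VAE latents fed to the MuseTalk UNet
```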
### 3.10 Using the wav2lip model
rtmp push is not supported yet
- Download the model

app.py

@@ -24,7 +24,6 @@ import argparse
import shutil
import asyncio

app = Flask(__name__)
sockets = Sockets(app)
global nerfreal
@@ -59,6 +58,7 @@ def llm_response(message):
    print(response)
    return response


@sockets.route('/humanchat')
def chat_socket(ws):
    # get the WebSocket object
@@ -79,9 +79,11 @@ def chat_socket(ws):
            res = llm_response(message)
            nerfreal.put_msg_txt(res)


#####webrtc###############################
pcs = set()


# @app.route('/offer', methods=['POST'])
async def offer(request):
    params = await request.json()
@@ -115,6 +117,7 @@ async def offer(request):
        ),
    )


async def human(request):
    params = await request.json()
@@ -131,12 +134,14 @@ async def human(request):
        ),
    )


async def on_shutdown(app):
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()


async def post(url, data):
    try:
        async with aiohttp.ClientSession() as session:
@@ -145,6 +150,7 @@ async def post(url,data):
    except aiohttp.ClientError as e:
        print(f'Error: {e}')


async def run(push_url):
    pc = RTCPeerConnection()
    pcs.add(pc)
@@ -163,6 +169,8 @@ async def run(push_url):
    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))


##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
@@ -182,13 +190,19 @@ if __name__ == '__main__':
    ### training options
    parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
    parser.add_argument('--num_rays', type=int, default=4096 * 16,
                        help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=16,
                        help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=16,
                        help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0,
                        help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16,
                        help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096,
                        help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

    ### loss set
    parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
@@ -203,23 +217,31 @@ if __name__ == '__main__':
    parser.add_argument('--bg_img', type=str, default='white', help="background image")
    parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
    parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
    parser.add_argument('--fix_eye', type=float, default=-1,
                        help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
    parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
    parser.add_argument('--torso_shrink', type=float, default=0.8,
                        help="shrink bg coords to allow more flexibility in deform")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', type=int, default=0,
                        help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=1,
                        help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1 / 256,
                        help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10,
                        help="threshold for density grid to be occupied (sigma)")
    parser.add_argument('--density_thresh_torso', type=float, default=0.01,
                        help="threshold for density grid to be occupied (alpha)")
    parser.add_argument('--patch_size', type=int, default=1,
                        help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
    parser.add_argument('--init_lips', action='store_true', help="init lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
@@ -237,12 +259,15 @@ if __name__ == '__main__':
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    ### else
    parser.add_argument('--att', type=int, default=2,
                        help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
    parser.add_argument('--aud', type=str, default='',
                        help="audio source (empty will load the default, else should be a path to a npy file)")
    parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
    parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
    parser.add_argument('--ind_num', type=int, default=10000,
                        help="number of individual codes, should be larger than training dataset size")
    parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
@@ -251,7 +276,8 @@ if __name__ == '__main__':
    parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
    parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
    parser.add_argument('--smooth_path', action='store_true',
                        help="brute-force smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

    # asr
@@ -299,7 +325,8 @@ if __name__ == '__main__':
    parser.add_argument('--model', type=str, default='ernerf')  # musetalk wav2lip
    parser.add_argument('--transport', type=str, default='rtcpush')  # rtmp webrtc rtcpush
    parser.add_argument('--push_url', type=str,
                        default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # rtmp://localhost/live/livestream
    parser.add_argument('--listenport', type=int, default=8010)
@@ -312,6 +339,7 @@ if __name__ == '__main__':
        from ernerf.nerf_triplane.utils import *
        from ernerf.nerf_triplane.network import NeRFNetwork
        from nerfreal import NeRFReal

        # assert test mode
        opt.test = True
        opt.test_train = False
@@ -346,7 +374,8 @@ if __name__ == '__main__':
        criterion = torch.nn.MSELoss(reduction='none')
        metrics = []  # use no metric in GUI for faster initialization...
        print(model)
        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16,
                          metrics=metrics, use_checkpoint=opt.ckpt)

        test_loader = NeRFDataset_Test(opt, device=device).dataloader()
        model.aud_features = test_loader._data.auds
@@ -356,10 +385,12 @@ if __name__ == '__main__':
        nerfreal = NeRFReal(opt, trainer, test_loader)
    elif opt.model == 'musetalk':
        from musereal import MuseReal

        print(opt)
        nerfreal = MuseReal(opt)
    elif opt.model == 'wav2lip':
        from lipreal import LipReal

        print(opt)
        nerfreal = LipReal(opt)
@@ -388,6 +419,7 @@ if __name__ == '__main__':
    for route in list(appasync.router.routes()):
        cors.add(route)


    def run_server(runner):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
@@ -397,6 +429,8 @@ if __name__ == '__main__':
        if opt.transport == 'rtcpush':
            loop.run_until_complete(run(opt.push_url))
        loop.run_forever()


    Thread(target=run_server, args=(web.AppRunner(appasync),)).start()

    print('start websocket server')
@@ -404,5 +438,3 @@ if __name__ == '__main__':
    # app.router.add_post("/offer", offer)
    server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
    server.serve_forever()
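For reference, a sketch of how the server can then be started with the MuseTalk pipeline, using only flags visible in this diff (the push URL is the default shown above and assumes a WHIP-capable media server listening on port 1985):
```bash
python app.py --model musetalk --transport rtcpush \
    --push_url 'http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream' \
    --listenport 8010
```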

musetalk/simple_musetalk.py (new file)

@@ -0,0 +1,331 @@
import argparse
import glob
import json
import os
import pickle
import shutil
import sys
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL
from face_alignment import NetworkSize
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples
from tqdm import tqdm
from utils.face_parsing import FaceParsing
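
# Overview of the pipeline below:
#   1. video2imgs (or a directory copy): dump the source video into per-frame PNGs
#   2. get_landmark_and_bbox: DWPose landmarks + the 'sfd' face detector -> a face box per frame
#   3. get_latents_for_unet: crop each face to 256x256 and encode masked + reference VAE latents
#   4. get_image_prepare_material: build blurred face-parsing masks for pasting results back
#   5. __main__: write cycled frames, masks, coords and latents into data/avatars/avator_<id>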
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
    cap = cv2.VideoCapture(vid_path)
    count = 0
    while True:
        if count > cut_frame:
            break
        ret, frame = cap.read()
        if ret:
            cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
            count += 1
        else:
            break
def read_imgs(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames
def get_landmark_and_bbox(img_list, upperbondrange=0):
    frames = read_imgs(img_list)
    batch_size_fa = 1
    batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
    coords_list = []
    landmarks = []
    if upperbondrange != 0:
        print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange)
    else:
        print('get key_landmark and face bounding boxes with the default value')
    average_range_minus = []
    average_range_plus = []
    for fb in tqdm(batches):
        results = inference_topdown(model, np.asarray(fb)[0])
        results = merge_data_samples(results)
        keypoints = results.pred_instances.keypoints
        face_land_mark = keypoints[0][23:91]
        face_land_mark = face_land_mark.astype(np.int32)

        # get bounding boxes by face detection
        bbox = fa.get_detections_for_batch(np.asarray(fb))

        # adjust the bounding box with reference to the landmarks
        # Add the bounding box to a tuple and append it to the coordinates list
        for j, f in enumerate(bbox):
            if f is None:  # no face in the image
                coords_list += [coord_placeholder]
                continue

            half_face_coord = face_land_mark[29]  # np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
            range_minus = (face_land_mark[30] - face_land_mark[29])[1]
            range_plus = (face_land_mark[29] - face_land_mark[28])[1]
            average_range_minus.append(range_minus)
            average_range_plus.append(range_plus)
            if upperbondrange != 0:
                half_face_coord[1] = upperbondrange + half_face_coord[1]  # manual adjustment: + shifts down towards landmark 29, - shifts up towards landmark 28
            half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
            upper_bond = half_face_coord[1] - half_face_dist

            f_landmark = (
                np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]),
                np.max(face_land_mark[:, 1]))
            x1, y1, x2, y2 = f_landmark

            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:  # if the landmark bbox is not suitable, reuse the bbox
                coords_list += [f]
                w, h = f[2] - f[0], f[3] - f[1]
                print("error bbox:", f)
            else:
                coords_list += [f_landmark]
    return coords_list, frames
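
# FaceAlignment is a thin wrapper around the face_detection package: it flips BGR frames to RGB,
# runs the selected detector ('sfd' by default) on the batch, and returns integer (x1, y1, x2, y2)
# boxes, or None for frames where no face is found.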
class FaceAlignment:
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose

        network_size = int(network_size)

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True
            # torch.backends.cuda.matmul.allow_tf32 = False
            # torch.backends.cudnn.benchmark = True
            # torch.backends.cudnn.deterministic = False
            # torch.backends.cudnn.allow_tf32 = True
            print('cuda start')

        # Get the face detector
        face_detector_module = __import__('face_detection.detection.' + face_detector,
                                          globals(), locals(), [face_detector], 0)
        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())
        results = []
        for i, d in enumerate(detected_faces):
            if len(d) == 0:
                results.append(None)
                continue
            d = d[0]
            d = np.clip(d, 0, None)
            x1, y1, x2, y2 = map(int, d[:-1])
            results.append((x1, y1, x2, y2))
        return results
def get_mask_tensor():
    """
    Creates a mask tensor for image processing.
    :return: A mask tensor.
    """
    mask_tensor = torch.zeros((256, 256))
    mask_tensor[:256 // 2, :] = 1
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    return mask_tensor
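
# preprocess_img turns a face crop into a normalized [1, 3, 256, 256] tensor in [-1, 1];
# with half_mask=True the lower half of the image (the mouth region) is zeroed out via
# get_mask_tensor before it is encoded by the VAE.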
def preprocess_img(img_name, half_mask=False):
    window = []
    if isinstance(img_name, str):
        window_fnames = [img_name]
        for fname in window_fnames:
            img = cv2.imread(fname)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (256, 256),
                             interpolation=cv2.INTER_LANCZOS4)
            window.append(img)
    else:
        img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
        window.append(img)

    x = np.asarray(window) / 255.
    x = np.transpose(x, (3, 0, 1, 2))
    x = torch.squeeze(torch.FloatTensor(x))
    if half_mask:
        x = x * (get_mask_tensor() > 0.5)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    x = normalize(x)
    x = x.unsqueeze(0)  # [1, 3, 256, 256] torch tensor
    x = x.to(device)
    return x
def encode_latents(image):
    with torch.no_grad():
        init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
        init_latents = vae.config.scaling_factor * init_latent_dist.sample()
    return init_latents
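
# get_latents_for_unet encodes the same face crop twice, once with the lower half masked and once
# unmasked, and concatenates the two 4-channel latents into a single [1, 8, 32, 32] UNet input.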
def get_latents_for_unet(img):
    ref_image = preprocess_img(img, half_mask=True)  # [1, 3, 256, 256] RGB, torch tensor
    masked_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    ref_image = preprocess_img(img, half_mask=False)  # [1, 3, 256, 256] RGB, torch tensor
    ref_latents = encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
    latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
    return latent_model_input
def get_crop_box(box, expand):
    x, y, x1, y1 = box
    x_c, y_c = (x + x1) // 2, (y + y1) // 2
    w, h = x1 - x, y1 - y
    s = int(max(w, h) // 2 * expand)
    crop_box = [x_c - s, y_c - s, x_c + s, y_c + s]
    return crop_box, s
def face_seg(image):
    seg_image = fp(image)
    if seg_image is None:
        print("error, no person_segment")
        return None
    seg_image = seg_image.resize(image.size)
    return seg_image
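
# get_image_prepare_material builds the blending mask for one frame: crop an expanded box around
# the face, run face parsing on that crop, keep only the part below upper_boundary_ratio, and
# feather the mask edges with a Gaussian blur so the generated face can be pasted back smoothly.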
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2):
    body = Image.fromarray(image[:, :, ::-1])

    x, y, x1, y1 = face_box
    # print(x1-x,y1-y)
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box

    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    mask_image = face_seg(face_large)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new('L', ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))

    # keep upper_boundary_ratio of talking area
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new('L', ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    return mask_array, crop_box
def create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else "cpu"
fa = FaceAlignment(1, flip_input=False, device=device)
config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device)
vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse")
vae.to(device)
fp = FaceParsing()
if __name__ == '__main__':
    # source video file path
    parser = argparse.ArgumentParser()
    parser.add_argument("--file",
                        type=str,
                        default=r'D:\ok\test.mp4',
                        )
    parser.add_argument("--avatar_id",
                        type=str,
                        default='1',
                        )
    args = parser.parse_args()
    file = args.file

    # output path settings; these can usually be left as-is
    save_path = f'../data/avatars/avator_{args.avatar_id}'
    save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs'
    create_dir(save_path)
    create_dir(save_full_path)
    mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask'
    create_dir(mask_out_path)

    # model input/output files
    mask_coords_path = f'{save_path}/mask_coords.pkl'
    coords_path = f'{save_path}/coords.pkl'
    latents_out_path = f'{save_path}/latents.pt'

    with open(f'{save_path}/avator_info.json', "w") as f:
        json.dump({
            "avatar_id": args.avatar_id,
            "video_path": file,
            "bbox_shift": 5
        }, f)

    if os.path.isfile(file):
        video2imgs(file, save_full_path, ext='png')
    else:
        files = os.listdir(file)
        files.sort()
        files = [file for file in files if file.split(".")[-1] == "png"]
        for filename in files:
            shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}")
    input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))

    # placeholder used by get_landmark_and_bbox when no usable face bbox is found
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)

    print("extracting landmarks...")
    coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5)
    input_latent_list = []
    idx = -1
    for bbox, frame in zip(coord_list, frame_list):
        idx = idx + 1
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = get_latents_for_unet(resized_crop_frame)
        input_latent_list.append(latents)

    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
    mask_coords_list_cycle = []
    mask_list_cycle = []

    for i, frame in enumerate(tqdm(frame_list_cycle)):
        cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame)
        face_box = coord_list_cycle[i]
        mask, crop_box = get_image_prepare_material(frame, face_box)
        cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
        mask_coords_list_cycle += [crop_box]
        mask_list_cycle.append(mask)

    with open(mask_coords_path, 'wb') as f:
        pickle.dump(mask_coords_list_cycle, f)
    with open(coords_path, 'wb') as f:
        pickle.dump(coord_list_cycle, f)
    torch.save(input_latent_list_cycle, os.path.join(latents_out_path))


@@ -7,14 +7,14 @@ from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms


class FaceParsing():
    def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='./models/face-parse-bisent/79999_iter.pth'):
        self.net = self.model_init(resnet_path, model_pth)
        self.preprocess = self.image_preprocess()

    def model_init(self, resnet_path, model_pth):
        net = BiSeNet(resnet_path)
        if torch.cuda.is_available():
            net.cuda()

@@ -49,8 +49,8 @@ class FaceParsing():
        parsing = Image.fromarray(parsing.astype(np.uint8))
        return parsing


if __name__ == "__main__":
    fp = FaceParsing()
    segmap = fp('154_small.png')
    segmap.save('res.png')
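
With the constructor change above, the parsing weights can be passed in explicitly instead of relying on the hard-coded defaults; a minimal sketch (the relative paths are hypothetical and mirror the ../models layout used in simple_musetalk.py):
```python
from utils.face_parsing import FaceParsing

# hypothetical weight paths, following the ../models layout used elsewhere in this commit
fp = FaceParsing(resnet_path='../models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='../models/face-parse-bisent/79999_iter.pth')
segmap = fp('154_small.png')  # returns a PIL mask of the parsed face region, or None
```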