From c0682408c5487f36ebc643b1b55f5111c76131a2 Mon Sep 17 00:00:00 2001 From: Yun <2289128964@qq.com> Date: Thu, 20 Jun 2024 20:21:37 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20=E7=AE=80=E5=8D=95?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E7=94=9F=E6=88=90musetalk=E6=95=B0=E5=AD=97?= =?UTF-8?q?=E4=BA=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 7 + app.py | 182 +++++++------ musetalk/simple_musetalk.py | 331 ++++++++++++++++++++++++ musetalk/utils/face_parsing/__init__.py | 18 +- 4 files changed, 454 insertions(+), 84 deletions(-) create mode 100644 musetalk/simple_musetalk.py diff --git a/README.md b/README.md index 21e6f8a..a67823e 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,13 @@ python -m scripts.realtime_inference --inference_config configs/inference/realti 运行后将results/avatars下文件拷到本项目的data/avatars下 ``` +```bash +也可以试用本地目录下的 simple_musetalk.py +cd musetalk +python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4 +运行后将直接生成在data/avatars下 +``` + ### 3.10 模型用wav2lip 暂不支持rtmp推送 - 下载模型 diff --git a/app.py b/app.py index bc111ad..a916733 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,5 @@ # server.py -from flask import Flask, render_template,send_from_directory,request, jsonify +from flask import Flask, render_template, send_from_directory, request, jsonify from flask_sockets import Sockets import base64 import time @@ -10,7 +10,7 @@ from geventwebsocket.handler import WebSocketHandler import os import re import numpy as np -from threading import Thread,Event +from threading import Thread, Event import multiprocessing from aiohttp import web @@ -24,16 +24,15 @@ import argparse import shutil import asyncio - app = Flask(__name__) sockets = Sockets(app) global nerfreal - + @sockets.route('/humanecho') def echo_socket(ws): # 获取WebSocket对象 - #ws = request.environ.get('wsgi.websocket') + # ws = request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -42,11 +41,11 @@ def echo_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if not message or len(message)==0: + message = ws.receive() + + if not message or len(message) == 0: return '输入信息为空' - else: + else: nerfreal.put_msg_txt(message) @@ -54,15 +53,16 @@ def llm_response(message): from llm.LLM import LLM # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None) # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key') - llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b') + llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b') response = llm.chat(message) print(response) return response + @sockets.route('/humanchat') def chat_socket(ws): # 获取WebSocket对象 - #ws = request.environ.get('wsgi.websocket') + # ws = request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -71,18 +71,20 @@ def chat_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if len(message)==0: + message = ws.receive() + + if len(message) == 0: return '输入信息为空' else: - res=llm_response(message) + res = llm_response(message) nerfreal.put_msg_txt(res) + #####webrtc############################### pcs = set() -#@app.route('/offer', methods=['POST']) + +# @app.route('/offer', methods=['POST']) async def offer(request): params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) @@ -106,7 +108,7 @@ async def offer(request): answer = await pc.createAnswer() await pc.setLocalDescription(answer) - #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) + # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) return web.Response( content_type="application/json", @@ -115,36 +117,40 @@ async def offer(request): ), ) + async def human(request): params = await request.json() - if params['type']=='echo': + if params['type'] == 'echo': nerfreal.put_msg_txt(params['text']) - elif params['type']=='chat': - res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) + elif params['type'] == 'chat': + res = await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) nerfreal.put_msg_txt(res) return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "data":"ok"} + {"code": 0, "data": "ok"} ), ) + async def on_shutdown(app): # close peer connections coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() -async def post(url,data): + +async def post(url, data): try: async with aiohttp.ClientSession() as session: - async with session.post(url,data=data) as response: + async with session.post(url, data=data) as response: return await response.text() except aiohttp.ClientError as e: print(f'Error: {e}') + async def run(push_url): pc = RTCPeerConnection() pcs.add(pc) @@ -161,8 +167,10 @@ async def run(push_url): video_sender = pc.addTrack(player.video) await pc.setLocalDescription(await pc.createOffer()) - answer = await post(push_url,pc.localDescription.sdp) - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer')) + answer = await post(push_url, pc.localDescription.sdp) + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer')) + + ########################################## # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver' @@ -181,14 +189,20 @@ if __name__ == '__main__': ### training options parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth') - - parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step") + + parser.add_argument('--num_rays', type=int, default=4096 * 16, + help="num rays sampled per image for each training step") parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") - parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") - parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") - parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--max_steps', type=int, default=16, + help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=16, + help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, + help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, + help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, + help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") ### loss set parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps") @@ -199,27 +213,35 @@ if __name__ == '__main__': ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") - + parser.add_argument('--bg_img', type=str, default='white', help="background image") parser.add_argument('--fbg', action='store_true', help="frame-wise bg") parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") - parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") + parser.add_argument('--fix_eye', type=float, default=-1, + help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") - parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") + parser.add_argument('--torso_shrink', type=float, default=0.8, + help="shrink bg coords to allow more flexibility in deform") ### dataset options parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") - parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") + parser.add_argument('--preload', type=int, default=0, + help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") # (the default value is for the fox dataset) - parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--bound', type=float, default=1, + help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") - parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--dt_gamma', type=float, default=1 / 256, + help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") - parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") - parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") - parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + parser.add_argument('--density_thresh', type=float, default=10, + help="threshold for density grid to be occupied (sigma)") + parser.add_argument('--density_thresh_torso', type=float, default=0.01, + help="threshold for density grid to be occupied (alpha)") + parser.add_argument('--patch_size', type=int, default=1, + help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") parser.add_argument('--init_lips', action='store_true', help="init lips region") parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region") @@ -237,12 +259,15 @@ if __name__ == '__main__': parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") ### else - parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") - parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)") + parser.add_argument('--att', type=int, default=2, + help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") + parser.add_argument('--aud', type=str, default='', + help="audio source (empty will load the default, else should be a path to a npy file)") parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") - parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") + parser.add_argument('--ind_num', type=int, default=10000, + help="number of individual codes, should be larger than training dataset size") parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") @@ -251,7 +276,8 @@ if __name__ == '__main__': parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") - parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size") + parser.add_argument('--smooth_path', action='store_true', + help="brute-force smooth camera pose trajectory with a window size") parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") # asr @@ -259,8 +285,8 @@ if __name__ == '__main__': parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") parser.add_argument('--asr_play', action='store_true', help="play out the audio") - #parser.add_argument('--asr_model', type=str, default='deepspeech') - parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # + # parser.add_argument('--asr_model', type=str, default='deepspeech') + parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft') @@ -279,7 +305,7 @@ if __name__ == '__main__': parser.add_argument('--fullbody_offset_x', type=int, default=0) parser.add_argument('--fullbody_offset_y', type=int, default=0) - #musetalk opt + # musetalk opt parser.add_argument('--avatar_id', type=str, default='avator_1') parser.add_argument('--bbox_shift', type=int, default=5) parser.add_argument('--batch_size', type=int, default=16) @@ -289,33 +315,35 @@ if __name__ == '__main__': parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') parser.add_argument('--customvideo_imgnum', type=int, default=1) - parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits + parser.add_argument('--tts', type=str, default='edgetts') # xtts gpt-sovits parser.add_argument('--REF_FILE', type=str, default=None) parser.add_argument('--REF_TEXT', type=str, default=None) - parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 + parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 # parser.add_argument('--CHARACTER', type=str, default='test') # parser.add_argument('--EMOTION', type=str, default='default') - parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip + parser.add_argument('--model', type=str, default='ernerf') # musetalk wav2lip - parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush - parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream + parser.add_argument('--transport', type=str, default='rtcpush') # rtmp webrtc rtcpush + parser.add_argument('--push_url', type=str, + default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream parser.add_argument('--listenport', type=int, default=8010) opt = parser.parse_args() - #app.config.from_object(opt) - #print(app.config) + # app.config.from_object(opt) + # print(app.config) if opt.model == 'ernerf': from ernerf.nerf_triplane.provider import NeRFDataset_Test from ernerf.nerf_triplane.utils import * from ernerf.nerf_triplane.network import NeRFNetwork from nerfreal import NeRFReal + # assert test mode opt.test = True opt.test_train = False - #opt.train_camera =True + # opt.train_camera =True # explicit smoothing opt.smooth_path = True opt.smooth_lips = True @@ -328,7 +356,7 @@ if __name__ == '__main__': opt.exp_eye = True opt.smooth_eye = True - if opt.torso_imgs=='': #no img,use model output + if opt.torso_imgs == '': # no img,use model output opt.torso = True # assert opt.cuda_ray, "Only support CUDA ray mode." @@ -344,9 +372,10 @@ if __name__ == '__main__': model = NeRFNetwork(opt) criterion = torch.nn.MSELoss(reduction='none') - metrics = [] # use no metric in GUI for faster initialization... + metrics = [] # use no metric in GUI for faster initialization... print(model) - trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, + metrics=metrics, use_checkpoint=opt.ckpt) test_loader = NeRFDataset_Test(opt, device=device).dataloader() model.aud_features = test_loader._data.auds @@ -356,17 +385,19 @@ if __name__ == '__main__': nerfreal = NeRFReal(opt, trainer, test_loader) elif opt.model == 'musetalk': from musereal import MuseReal + print(opt) nerfreal = MuseReal(opt) elif opt.model == 'wav2lip': from lipreal import LipReal + print(opt) nerfreal = LipReal(opt) - #txt_to_audio('我是中国人,我来自北京') - if opt.transport=='rtmp': + # txt_to_audio('我是中国人,我来自北京') + if opt.transport == 'rtmp': thread_quit = Event() - rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) + rendthrd = Thread(target=nerfreal.render, args=(thread_quit,)) rendthrd.start() ############################################################################# @@ -374,35 +405,36 @@ if __name__ == '__main__': appasync.on_shutdown.append(on_shutdown) appasync.router.add_post("/offer", offer) appasync.router.add_post("/human", human) - appasync.router.add_static('/',path='web') + appasync.router.add_static('/', path='web') # Configure default CORS settings. cors = aiohttp_cors.setup(appasync, defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - ) - }) + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + ) + }) # Configure CORS on all routes. for route in list(appasync.router.routes()): cors.add(route) + def run_server(runner): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(runner.setup()) site = web.TCPSite(runner, '0.0.0.0', opt.listenport) loop.run_until_complete(site.start()) - if opt.transport=='rtcpush': + if opt.transport == 'rtcpush': loop.run_until_complete(run(opt.push_url)) - loop.run_forever() + loop.run_forever() + + Thread(target=run_server, args=(web.AppRunner(appasync),)).start() print('start websocket server') - #app.on_shutdown.append(on_shutdown) - #app.router.add_post("/offer", offer) + # app.on_shutdown.append(on_shutdown) + # app.router.add_post("/offer", offer) server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler) server.serve_forever() - - \ No newline at end of file diff --git a/musetalk/simple_musetalk.py b/musetalk/simple_musetalk.py new file mode 100644 index 0000000..1a6654e --- /dev/null +++ b/musetalk/simple_musetalk.py @@ -0,0 +1,331 @@ +import argparse +import glob +import json +import os +import pickle +import shutil +import sys + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from diffusers import AutoencoderKL +from face_alignment import NetworkSize +from mmpose.apis import inference_topdown, init_model +from mmpose.structures import merge_data_samples +from tqdm import tqdm + +from utils.face_parsing import FaceParsing + + +def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000): + cap = cv2.VideoCapture(vid_path) + count = 0 + while True: + if count > cut_frame: + break + ret, frame = cap.read() + if ret: + cv2.imwrite(f"{save_path}/{count:08d}.png", frame) + count += 1 + else: + break + + +def read_imgs(img_list): + frames = [] + print('reading images...') + for img_path in tqdm(img_list): + frame = cv2.imread(img_path) + frames.append(frame) + return frames + + +def get_landmark_and_bbox(img_list, upperbondrange=0): + frames = read_imgs(img_list) + batch_size_fa = 1 + batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)] + coords_list = [] + landmarks = [] + if upperbondrange != 0: + print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange) + else: + print('get key_landmark and face bounding boxes with the default value') + average_range_minus = [] + average_range_plus = [] + for fb in tqdm(batches): + results = inference_topdown(model, np.asarray(fb)[0]) + results = merge_data_samples(results) + keypoints = results.pred_instances.keypoints + face_land_mark = keypoints[0][23:91] + face_land_mark = face_land_mark.astype(np.int32) + + # get bounding boxes by face detetion + bbox = fa.get_detections_for_batch(np.asarray(fb)) + + # adjust the bounding box refer to landmark + # Add the bounding box to a tuple and append it to the coordinates list + for j, f in enumerate(bbox): + if f is None: # no face in the image + coords_list += [coord_placeholder] + continue + + half_face_coord = face_land_mark[29] # np.mean([face_land_mark[28], face_land_mark[29]], axis=0) + range_minus = (face_land_mark[30] - face_land_mark[29])[1] + range_plus = (face_land_mark[29] - face_land_mark[28])[1] + average_range_minus.append(range_minus) + average_range_plus.append(range_plus) + if upperbondrange != 0: + half_face_coord[1] = upperbondrange + half_face_coord[1] # 手动调整 + 向下(偏29) - 向上(偏28) + half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1] + upper_bond = half_face_coord[1] - half_face_dist + + f_landmark = ( + np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]), + np.max(face_land_mark[:, 1])) + x1, y1, x2, y2 = f_landmark + + if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0: # if the landmark bbox is not suitable, reuse the bbox + coords_list += [f] + w, h = f[2] - f[0], f[3] - f[1] + print("error bbox:", f) + else: + coords_list += [f_landmark] + return coords_list, frames + + +class FaceAlignment: + def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, + device='cuda', flip_input=False, face_detector='sfd', verbose=False): + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + # torch.backends.cuda.matmul.allow_tf32 = False + # torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = False + # torch.backends.cudnn.allow_tf32 = True + print('cuda start') + + # Get the face detector + face_detector_module = __import__('face_detection.detection.' + face_detector, + globals(), locals(), [face_detector], 0) + + self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) + + def get_detections_for_batch(self, images): + images = images[..., ::-1] + detected_faces = self.face_detector.detect_from_batch(images.copy()) + results = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + results.append(None) + continue + d = d[0] + d = np.clip(d, 0, None) + + x1, y1, x2, y2 = map(int, d[:-1]) + results.append((x1, y1, x2, y2)) + return results + + +def get_mask_tensor(): + """ + Creates a mask tensor for image processing. + :return: A mask tensor. + """ + mask_tensor = torch.zeros((256, 256)) + mask_tensor[:256 // 2, :] = 1 + mask_tensor[mask_tensor < 0.5] = 0 + mask_tensor[mask_tensor >= 0.5] = 1 + return mask_tensor + + +def preprocess_img(img_name, half_mask=False): + window = [] + if isinstance(img_name, str): + window_fnames = [img_name] + for fname in window_fnames: + img = cv2.imread(fname) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (256, 256), + interpolation=cv2.INTER_LANCZOS4) + window.append(img) + else: + img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB) + window.append(img) + x = np.asarray(window) / 255. + x = np.transpose(x, (3, 0, 1, 2)) + x = torch.squeeze(torch.FloatTensor(x)) + if half_mask: + x = x * (get_mask_tensor() > 0.5) + normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + x = normalize(x) + x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor + x = x.to(device) + return x + + +def encode_latents(image): + with torch.no_grad(): + init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist + init_latents = vae.config.scaling_factor * init_latent_dist.sample() + return init_latents + + +def get_latents_for_unet(img): + ref_image = preprocess_img(img, half_mask=True) # [1, 3, 256, 256] RGB, torch tensor + masked_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor + ref_image = preprocess_img(img, half_mask=False) # [1, 3, 256, 256] RGB, torch tensor + ref_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor + latent_model_input = torch.cat([masked_latents, ref_latents], dim=1) + return latent_model_input + + +def get_crop_box(box, expand): + x, y, x1, y1 = box + x_c, y_c = (x + x1) // 2, (y + y1) // 2 + w, h = x1 - x, y1 - y + s = int(max(w, h) // 2 * expand) + crop_box = [x_c - s, y_c - s, x_c + s, y_c + s] + return crop_box, s + + +def face_seg(image): + seg_image = fp(image) + if seg_image is None: + print("error, no person_segment") + return None + + seg_image = seg_image.resize(image.size) + return seg_image + + +def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2): + body = Image.fromarray(image[:, :, ::-1]) + + x, y, x1, y1 = face_box + # print(x1-x,y1-y) + crop_box, s = get_crop_box(face_box, expand) + x_s, y_s, x_e, y_e = crop_box + + face_large = body.crop(crop_box) + ori_shape = face_large.size + + mask_image = face_seg(face_large) + mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) + mask_image = Image.new('L', ori_shape, 0) + mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) + + # keep upper_boundary_ratio of talking area + width, height = mask_image.size + top_boundary = int(height * upper_boundary_ratio) + modified_mask_image = Image.new('L', ori_shape, 0) + modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) + + blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1 + mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) + return mask_array, crop_box + + +def create_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +# initialize the mmpose model +device = "cuda" if torch.cuda.is_available() else "cpu" +fa = FaceAlignment(1, flip_input=False, device=device) +config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' +checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth' +model = init_model(config_file, checkpoint_file, device=device) +vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse") +vae.to(device) +fp = FaceParsing() +if __name__ == '__main__': + # 视频文件地址 + parser = argparse.ArgumentParser() + parser.add_argument("--file", + type=str, + default=r'D:\ok\test.mp4', + ) + parser.add_argument("--avatar_id", + type=str, + default='1', + ) + args = parser.parse_args() + file = args.file + # 保存文件设置 可以不动 + save_path = f'../data/avatars/avator_{args.avatar_id}' + save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs' + create_dir(save_path) + create_dir(save_full_path) + mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask' + create_dir(mask_out_path) + + # 模型 + mask_coords_path = f'{save_path}/mask_coords.pkl' + coords_path = f'{save_path}/coords.pkl' + latents_out_path = f'{save_path}/latents.pt' + + with open(f'{save_path}/avator_info.json', "w") as f: + json.dump({ + "avatar_id": args.avatar_id, + "video_path": file, + "bbox_shift": 5 + }, f) + + if os.path.isfile(file): + video2imgs(file, save_full_path, ext='png') + else: + files = os.listdir(file) + files.sort() + files = [file for file in files if file.split(".")[-1] == "png"] + for filename in files: + shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}") + input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]'))) + print("extracting landmarks...") + coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5) + input_latent_list = [] + idx = -1 + # maker if the bbox is not sufficient + coord_placeholder = (0.0, 0.0, 0.0, 0.0) + for bbox, frame in zip(coord_list, frame_list): + idx = idx + 1 + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + crop_frame = frame[y1:y2, x1:x2] + resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4) + latents = get_latents_for_unet(resized_crop_frame) + input_latent_list.append(latents) + + frame_list_cycle = frame_list + frame_list[::-1] + coord_list_cycle = coord_list + coord_list[::-1] + input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + mask_coords_list_cycle = [] + mask_list_cycle = [] + for i, frame in enumerate(tqdm(frame_list_cycle)): + cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame) + + face_box = coord_list_cycle[i] + mask, crop_box = get_image_prepare_material(frame, face_box) + cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask) + mask_coords_list_cycle += [crop_box] + mask_list_cycle.append(mask) + + with open(mask_coords_path, 'wb') as f: + pickle.dump(mask_coords_list_cycle, f) + + with open(coords_path, 'wb') as f: + pickle.dump(coord_list_cycle, f) + torch.save(input_latent_list_cycle, os.path.join(latents_out_path)) diff --git a/musetalk/utils/face_parsing/__init__.py b/musetalk/utils/face_parsing/__init__.py index fc963a3..003147f 100755 --- a/musetalk/utils/face_parsing/__init__.py +++ b/musetalk/utils/face_parsing/__init__.py @@ -7,18 +7,18 @@ from PIL import Image from .model import BiSeNet import torchvision.transforms as transforms + class FaceParsing(): - def __init__(self): - self.net = self.model_init() + def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', + model_pth='./models/face-parse-bisent/79999_iter.pth'): + self.net = self.model_init(resnet_path,model_pth) self.preprocess = self.image_preprocess() - def model_init(self, - resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', - model_pth='./models/face-parse-bisent/79999_iter.pth'): + def model_init(self,resnet_path, model_pth): net = BiSeNet(resnet_path) if torch.cuda.is_available(): net.cuda() - net.load_state_dict(torch.load(model_pth)) + net.load_state_dict(torch.load(model_pth)) else: net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu'))) net.eval() @@ -44,13 +44,13 @@ class FaceParsing(): img = torch.unsqueeze(img, 0) out = self.net(img)[0] parsing = out.squeeze(0).cpu().numpy().argmax(0) - parsing[np.where(parsing>13)] = 0 - parsing[np.where(parsing>=1)] = 255 + parsing[np.where(parsing > 13)] = 0 + parsing[np.where(parsing >= 1)] = 255 parsing = Image.fromarray(parsing.astype(np.uint8)) return parsing + if __name__ == "__main__": fp = FaceParsing() segmap = fp('154_small.png') segmap.save('res.png') -