From 5da818b9d91081fc9b38e214ba86777abea2b1d6 Mon Sep 17 00:00:00 2001 From: Yun <2289128964@qq.com> Date: Wed, 19 Jun 2024 14:47:57 +0800 Subject: [PATCH 1/5] feat: add musereal static img --- app.py | 1 + musereal.py | 218 +++++++++++++++++++++++++++------------------------- 2 files changed, 113 insertions(+), 106 deletions(-) diff --git a/app.py b/app.py index 396bf6a..bc111ad 100644 --- a/app.py +++ b/app.py @@ -285,6 +285,7 @@ if __name__ == '__main__': parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--customvideo', action='store_true', help="custom video") + parser.add_argument('--static_img', action='store_true', help="Use the first photo as a time of rest") parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') parser.add_argument('--customvideo_imgnum', type=int, default=1) diff --git a/musereal.py b/musereal.py index d92ee85..7e58029 100644 --- a/musereal.py +++ b/musereal.py @@ -2,7 +2,7 @@ import math import torch import numpy as np -#from .utils import * +# from .utils import * import subprocess import os import time @@ -18,17 +18,19 @@ from threading import Thread, Event from io import BytesIO import multiprocessing as mp -from musetalk.utils.utils import get_file_type,get_video_fps,datagen -#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder -from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending -from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model -from ttsreal import EdgeTTS,VoitsTTS,XTTS +from musetalk.utils.utils import get_file_type, get_video_fps, datagen +# from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder +from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending +from musetalk.utils.utils import load_all_model, load_diffusion_model, load_audio_model +from ttsreal import EdgeTTS, VoitsTTS, XTTS from museasr import MuseASR import asyncio from av import AudioFrame, VideoFrame from tqdm import tqdm + + def read_imgs(img_list): frames = [] print('reading images...') @@ -37,142 +39,146 @@ def read_imgs(img_list): frames.append(frame) return frames + def __mirror_index(size, index): - #size = len(self.coord_list_cycle) + # size = len(self.coord_list_cycle) turn = index // size res = index % size if turn % 2 == 0: return res else: - return size - res - 1 + return size - res - 1 + + +def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue, + ): # vae, unet, pe,timesteps -def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue, - ): #vae, unet, pe,timesteps - vae, unet, pe = load_diffusion_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) pe = pe.half() vae.vae = vae.vae.half() unet.model = unet.model.half() - + input_latent_list_cycle = torch.load(latents_out_path) length = len(input_latent_list_cycle) index = 0 - count=0 - counttime=0 + count = 0 + counttime = 0 print('start inference') while True: if render_event.is_set(): - starttime=time.perf_counter() + starttime = time.perf_counter() try: whisper_chunks = audio_feat_queue.get(block=True, timeout=1) except queue.Empty: continue - is_all_silence=True + is_all_silence = True audio_frames = [] - for _ in range(batch_size*2): - frame,type = audio_out_queue.get() - audio_frames.append((frame,type)) - if type==0: 
- is_all_silence=False + for _ in range(batch_size * 2): + frame, type = audio_out_queue.get() + audio_frames.append((frame, type)) + if type == 0: + is_all_silence = False if is_all_silence: for i in range(batch_size): - res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2])) + res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) index = index + 1 else: # print('infer=======') - t=time.perf_counter() + t = time.perf_counter() whisper_batch = np.stack(whisper_chunks) latent_batch = [] for i in range(batch_size): - idx = __mirror_index(length,index+i) + idx = __mirror_index(length, index + i) latent = input_latent_list_cycle[idx] latent_batch.append(latent) latent_batch = torch.cat(latent_batch, dim=0) - + # for i, (whisper_batch,latent_batch) in enumerate(gen): audio_feature_batch = torch.from_numpy(whisper_batch) audio_feature_batch = audio_feature_batch.to(device=unet.device, - dtype=unet.model.dtype) + dtype=unet.model.dtype) audio_feature_batch = pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=unet.model.dtype) # print('prepare time:',time.perf_counter()-t) # t=time.perf_counter() - pred_latents = unet.model(latent_batch, - timesteps, - encoder_hidden_states=audio_feature_batch).sample + pred_latents = unet.model(latent_batch, + timesteps, + encoder_hidden_states=audio_feature_batch).sample # print('unet time:',time.perf_counter()-t) # t=time.perf_counter() recon = vae.decode_latents(pred_latents) # print('vae time:',time.perf_counter()-t) - #print('diffusion len=',len(recon)) + # print('diffusion len=',len(recon)) counttime += (time.perf_counter() - t) count += batch_size - #_totalframe += 1 - if count>=100: - print(f"------actual avg infer fps:{count/counttime:.4f}") - count=0 - counttime=0 - for i,res_frame in enumerate(recon): - #self.__pushmedia(res_frame,loop,audio_track,video_track) - res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2])) + # _totalframe += 1 + if count >= 100: + print(f"------actual avg infer fps:{count / counttime:.4f}") + count = 0 + counttime = 0 + for i, res_frame in enumerate(recon): + # self.__pushmedia(res_frame,loop,audio_track,video_track) + res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) index = index + 1 - #print('total batch time:',time.perf_counter()-starttime) + # print('total batch time:',time.perf_counter()-starttime) else: time.sleep(1) print('musereal inference processor stop') + @torch.no_grad() class MuseReal: def __init__(self, opt): - self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. 
self.W = opt.W self.H = opt.H - self.fps = opt.fps # 20 ms per frame + self.fps = opt.fps # 20 ms per frame #### musetalk self.avatar_id = opt.avatar_id - self.video_path = '' #video_path + self.static_img = opt.static_img + self.video_path = '' # video_path self.bbox_shift = opt.bbox_shift self.avatar_path = f"./data/avatars/{self.avatar_id}" - self.full_imgs_path = f"{self.avatar_path}/full_imgs" + self.full_imgs_path = f"{self.avatar_path}/full_imgs" self.coords_path = f"{self.avatar_path}/coords.pkl" - self.latents_out_path= f"{self.avatar_path}/latents.pt" + self.latents_out_path = f"{self.avatar_path}/latents.pt" self.video_out_path = f"{self.avatar_path}/vid_output/" - self.mask_out_path =f"{self.avatar_path}/mask" - self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl" + self.mask_out_path = f"{self.avatar_path}/mask" + self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl" self.avatar_info_path = f"{self.avatar_path}/avator_info.json" self.avatar_info = { - "avatar_id":self.avatar_id, - "video_path":self.video_path, - "bbox_shift":self.bbox_shift + "avatar_id": self.avatar_id, + "video_path": self.video_path, + "bbox_shift": self.bbox_shift } self.batch_size = opt.batch_size self.idx = 0 - self.res_frame_queue = mp.Queue(self.batch_size*2) + self.res_frame_queue = mp.Queue(self.batch_size * 2) self.__loadmodels() self.__loadavatar() - self.asr = MuseASR(opt,self.audio_processor) + self.asr = MuseASR(opt, self.audio_processor) if opt.tts == "edgetts": - self.tts = EdgeTTS(opt,self) + self.tts = EdgeTTS(opt, self) elif opt.tts == "gpt-sovits": - self.tts = VoitsTTS(opt,self) + self.tts = VoitsTTS(opt, self) elif opt.tts == "xtts": - self.tts = XTTS(opt,self) - #self.__warm_up() - + self.tts = XTTS(opt, self) + # self.__warm_up() + self.render_event = mp.Event() - mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path, - self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue, - )).start() #self.vae, self.unet, self.pe,self.timesteps + mp.Process(target=inference, args=(self.render_event, self.batch_size, self.latents_out_path, + self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue, + )).start() # self.vae, self.unet, self.pe,self.timesteps def __loadmodels(self): # load model weights - self.audio_processor= load_audio_model() + self.audio_processor = load_audio_model() # self.audio_processor, self.vae, self.unet, self.pe = load_all_model() # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # self.timesteps = torch.tensor([0], device=device) @@ -181,7 +187,7 @@ class MuseReal: # self.unet.model = self.unet.model.half() def __loadavatar(self): - #self.input_latent_list_cycle = torch.load(self.latents_out_path) + # self.input_latent_list_cycle = torch.load(self.latents_out_path) with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f) input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) @@ -192,12 +198,11 @@ class MuseReal: input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) self.mask_list_cycle = read_imgs(input_mask_list) - - - def put_msg_txt(self,msg): + + def put_msg_txt(self, msg): self.tts.put_msg_txt(msg) - - def put_audio_frame(self,audio_chunk): #16khz 20ms pcm + + def put_audio_frame(self, audio_chunk): # 16khz 20ms pcm self.asr.put_audio_frame(audio_chunk) def __mirror_index(self, index): @@ -207,15 
+212,15 @@ class MuseReal: if turn % 2 == 0: return res else: - return size - res - 1 + return size - res - 1 - def __warm_up(self): + def __warm_up(self): self.asr.run_step() whisper_chunks = self.asr.get_next_feat() whisper_batch = np.stack(whisper_chunks) latent_batch = [] for i in range(self.batch_size): - idx = self.__mirror_index(self.idx+i) + idx = self.__mirror_index(self.idx + i) latent = self.input_latent_list_cycle[idx] latent_batch.append(latent) latent_batch = torch.cat(latent_batch, dim=0) @@ -223,87 +228,88 @@ class MuseReal: # for i, (whisper_batch,latent_batch) in enumerate(gen): audio_feature_batch = torch.from_numpy(whisper_batch) audio_feature_batch = audio_feature_batch.to(device=self.unet.device, - dtype=self.unet.model.dtype) + dtype=self.unet.model.dtype) audio_feature_batch = self.pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=self.unet.model.dtype) - pred_latents = self.unet.model(latent_batch, - self.timesteps, - encoder_hidden_states=audio_feature_batch).sample + pred_latents = self.unet.model(latent_batch, + self.timesteps, + encoder_hidden_states=audio_feature_batch).sample recon = self.vae.decode_latents(pred_latents) - - def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None): - + def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None): + while not quit_event.is_set(): try: - res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1) + res_frame, idx, audio_frames = self.res_frame_queue.get(block=True, timeout=1) except queue.Empty: continue - if audio_frames[0][1]==1 and audio_frames[1][1]==1: #全为静音数据,只需要取fullimg - combine_frame = self.frame_list_cycle[idx] + if audio_frames[0][1] == 1 and audio_frames[1][1] == 1: # 全为静音数据,只需要取fullimg + if self.static_img: + combine_frame = self.frame_list_cycle[0] + else: + combine_frame = self.frame_list_cycle[idx] else: bbox = self.coord_list_cycle[idx] ori_frame = copy.deepcopy(self.frame_list_cycle[idx]) x1, y1, x2, y2 = bbox try: - res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) + res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) except: continue mask = self.mask_list_cycle[idx] mask_crop_box = self.mask_coords_list_cycle[idx] - #combine_frame = get_image(ori_frame,res_frame,bbox) - #t=time.perf_counter() - combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box) - #print('blending time:',time.perf_counter()-t) + # combine_frame = get_image(ori_frame,res_frame,bbox) + # t=time.perf_counter() + combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box) + # print('blending time:',time.perf_counter()-t) - image = combine_frame #(outputs['image'] * 255).astype(np.uint8) + image = combine_frame # (outputs['image'] * 255).astype(np.uint8) new_frame = VideoFrame.from_ndarray(image, format="bgr24") - asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) + asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) for audio_frame in audio_frames: - frame,type = audio_frame + frame, type = audio_frame frame = (frame * 32767).astype(np.int16) new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) new_frame.planes[0].update(frame.tobytes()) - new_frame.sample_rate=16000 + new_frame.sample_rate = 16000 # if audio_track._queue.qsize()>10: # time.sleep(0.1) asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop) - print('musereal process_frames thread stop') - - def 
render(self,quit_event,loop=None,audio_track=None,video_track=None): - #if self.opt.asr: + print('musereal process_frames thread stop') + + def render(self, quit_event, loop=None, audio_track=None, video_track=None): + # if self.opt.asr: # self.asr.warm_up() self.tts.render(quit_event) - process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track)) + process_thread = Thread(target=self.process_frames, args=(quit_event, loop, audio_track, video_track)) process_thread.start() - self.render_event.set() #start infer process render - count=0 - totaltime=0 - _starttime=time.perf_counter() - #_totalframe=0 - while not quit_event.is_set(): #todo + self.render_event.set() # start infer process render + count = 0 + totaltime = 0 + _starttime = time.perf_counter() + # _totalframe=0 + while not quit_event.is_set(): # todo # update texture every frame # audio stream thread... t = time.perf_counter() self.asr.run_step() - #self.test_step(loop,audio_track,video_track) + # self.test_step(loop,audio_track,video_track) # totaltime += (time.perf_counter() - t) # count += self.opt.batch_size # if count>=100: # print(f"------actual avg infer fps:{count/totaltime:.4f}") # count=0 # totaltime=0 - if video_track._queue.qsize()>=2*self.opt.batch_size: - print('sleep qsize=',video_track._queue.qsize()) - time.sleep(0.04*self.opt.batch_size*1.5) - + if video_track._queue.qsize() >= 2 * self.opt.batch_size: + print('sleep qsize=', video_track._queue.qsize()) + time.sleep(0.04 * self.opt.batch_size * 1.5) + # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms # if delay > 0: # time.sleep(delay) - self.render_event.clear() #end infer process render + self.render_event.clear() # end infer process render print('musereal thread stop') - \ No newline at end of file From c0682408c5487f36ebc643b1b55f5111c76131a2 Mon Sep 17 00:00:00 2001 From: Yun <2289128964@qq.com> Date: Thu, 20 Jun 2024 20:21:37 +0800 Subject: [PATCH 2/5] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20=E7=AE=80?= =?UTF-8?q?=E5=8D=95=E8=87=AA=E5=8A=A8=E7=94=9F=E6=88=90musetalk=E6=95=B0?= =?UTF-8?q?=E5=AD=97=E4=BA=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 7 + app.py | 182 +++++++------ musetalk/simple_musetalk.py | 331 ++++++++++++++++++++++++ musetalk/utils/face_parsing/__init__.py | 18 +- 4 files changed, 454 insertions(+), 84 deletions(-) create mode 100644 musetalk/simple_musetalk.py diff --git a/README.md b/README.md index 21e6f8a..a67823e 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,13 @@ python -m scripts.realtime_inference --inference_config configs/inference/realti 运行后将results/avatars下文件拷到本项目的data/avatars下 ``` +```bash +也可以试用本地目录下的 simple_musetalk.py +cd musetalk +python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4 +运行后将直接生成在data/avatars下 +``` + ### 3.10 模型用wav2lip 暂不支持rtmp推送 - 下载模型 diff --git a/app.py b/app.py index bc111ad..a916733 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,5 @@ # server.py -from flask import Flask, render_template,send_from_directory,request, jsonify +from flask import Flask, render_template, send_from_directory, request, jsonify from flask_sockets import Sockets import base64 import time @@ -10,7 +10,7 @@ from geventwebsocket.handler import WebSocketHandler import os import re import numpy as np -from threading import Thread,Event +from threading import Thread, Event import multiprocessing from aiohttp import web @@ -24,16 +24,15 @@ import argparse import shutil import asyncio - app = 
Flask(__name__) sockets = Sockets(app) global nerfreal - + @sockets.route('/humanecho') def echo_socket(ws): # 获取WebSocket对象 - #ws = request.environ.get('wsgi.websocket') + # ws = request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -42,11 +41,11 @@ def echo_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if not message or len(message)==0: + message = ws.receive() + + if not message or len(message) == 0: return '输入信息为空' - else: + else: nerfreal.put_msg_txt(message) @@ -54,15 +53,16 @@ def llm_response(message): from llm.LLM import LLM # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None) # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key') - llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b') + llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b') response = llm.chat(message) print(response) return response + @sockets.route('/humanchat') def chat_socket(ws): # 获取WebSocket对象 - #ws = request.environ.get('wsgi.websocket') + # ws = request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -71,18 +71,20 @@ def chat_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if len(message)==0: + message = ws.receive() + + if len(message) == 0: return '输入信息为空' else: - res=llm_response(message) + res = llm_response(message) nerfreal.put_msg_txt(res) + #####webrtc############################### pcs = set() -#@app.route('/offer', methods=['POST']) + +# @app.route('/offer', methods=['POST']) async def offer(request): params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) @@ -106,7 +108,7 @@ async def offer(request): answer = await pc.createAnswer() await pc.setLocalDescription(answer) - #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) + # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) return web.Response( content_type="application/json", @@ -115,36 +117,40 @@ async def offer(request): ), ) + async def human(request): params = await request.json() - if params['type']=='echo': + if params['type'] == 'echo': nerfreal.put_msg_txt(params['text']) - elif params['type']=='chat': - res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) + elif params['type'] == 'chat': + res = await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) nerfreal.put_msg_txt(res) return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "data":"ok"} + {"code": 0, "data": "ok"} ), ) + async def on_shutdown(app): # close peer connections coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() -async def post(url,data): + +async def post(url, data): try: async with aiohttp.ClientSession() as session: - async with session.post(url,data=data) as response: + async with session.post(url, data=data) as response: return await response.text() except aiohttp.ClientError as e: print(f'Error: {e}') + async def run(push_url): pc = RTCPeerConnection() pcs.add(pc) @@ -161,8 +167,10 @@ async def run(push_url): video_sender = pc.addTrack(player.video) await pc.setLocalDescription(await pc.createOffer()) - answer = await post(push_url,pc.localDescription.sdp) - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer')) + answer = await post(push_url, pc.localDescription.sdp) + await 
pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer')) + + ########################################## # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver' @@ -181,14 +189,20 @@ if __name__ == '__main__': ### training options parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth') - - parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step") + + parser.add_argument('--num_rays', type=int, default=4096 * 16, + help="num rays sampled per image for each training step") parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") - parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") - parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") - parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--max_steps', type=int, default=16, + help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=16, + help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, + help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, + help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, + help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") ### loss set parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps") @@ -199,27 +213,35 @@ if __name__ == '__main__': ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") - + parser.add_argument('--bg_img', type=str, default='white', help="background image") parser.add_argument('--fbg', action='store_true', help="frame-wise bg") parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") - parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") + parser.add_argument('--fix_eye', type=float, default=-1, + help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") - parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") + parser.add_argument('--torso_shrink', type=float, default=0.8, + help="shrink bg coords to allow more flexibility in deform") ### dataset options parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") - parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 
means GPU.") + parser.add_argument('--preload', type=int, default=0, + help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") # (the default value is for the fox dataset) - parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--bound', type=float, default=1, + help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") - parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--dt_gamma', type=float, default=1 / 256, + help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") - parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") - parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") - parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + parser.add_argument('--density_thresh', type=float, default=10, + help="threshold for density grid to be occupied (sigma)") + parser.add_argument('--density_thresh_torso', type=float, default=0.01, + help="threshold for density grid to be occupied (alpha)") + parser.add_argument('--patch_size', type=int, default=1, + help="[experimental] render patches in training, so as to apply LPIPS loss. 
1 means disabled, use [64, 32, 16] to enable") parser.add_argument('--init_lips', action='store_true', help="init lips region") parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region") @@ -237,12 +259,15 @@ if __name__ == '__main__': parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") ### else - parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") - parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)") + parser.add_argument('--att', type=int, default=2, + help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") + parser.add_argument('--aud', type=str, default='', + help="audio source (empty will load the default, else should be a path to a npy file)") parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") - parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") + parser.add_argument('--ind_num', type=int, default=10000, + help="number of individual codes, should be larger than training dataset size") parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") @@ -251,7 +276,8 @@ if __name__ == '__main__': parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") - parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size") + parser.add_argument('--smooth_path', action='store_true', + help="brute-force smooth camera pose trajectory with a window size") parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") # asr @@ -259,8 +285,8 @@ if __name__ == '__main__': parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") parser.add_argument('--asr_play', action='store_true', help="play out the audio") - #parser.add_argument('--asr_model', type=str, default='deepspeech') - parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # + # parser.add_argument('--asr_model', type=str, default='deepspeech') + parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft') @@ -279,7 +305,7 @@ if __name__ == '__main__': parser.add_argument('--fullbody_offset_x', type=int, default=0) parser.add_argument('--fullbody_offset_y', type=int, default=0) - #musetalk opt + # musetalk opt parser.add_argument('--avatar_id', type=str, default='avator_1') parser.add_argument('--bbox_shift', type=int, default=5) parser.add_argument('--batch_size', type=int, default=16) @@ -289,33 +315,35 @@ if __name__ == '__main__': parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') parser.add_argument('--customvideo_imgnum', type=int, default=1) - parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits + 
parser.add_argument('--tts', type=str, default='edgetts') # xtts gpt-sovits parser.add_argument('--REF_FILE', type=str, default=None) parser.add_argument('--REF_TEXT', type=str, default=None) - parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 + parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 # parser.add_argument('--CHARACTER', type=str, default='test') # parser.add_argument('--EMOTION', type=str, default='default') - parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip + parser.add_argument('--model', type=str, default='ernerf') # musetalk wav2lip - parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush - parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream + parser.add_argument('--transport', type=str, default='rtcpush') # rtmp webrtc rtcpush + parser.add_argument('--push_url', type=str, + default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream parser.add_argument('--listenport', type=int, default=8010) opt = parser.parse_args() - #app.config.from_object(opt) - #print(app.config) + # app.config.from_object(opt) + # print(app.config) if opt.model == 'ernerf': from ernerf.nerf_triplane.provider import NeRFDataset_Test from ernerf.nerf_triplane.utils import * from ernerf.nerf_triplane.network import NeRFNetwork from nerfreal import NeRFReal + # assert test mode opt.test = True opt.test_train = False - #opt.train_camera =True + # opt.train_camera =True # explicit smoothing opt.smooth_path = True opt.smooth_lips = True @@ -328,7 +356,7 @@ if __name__ == '__main__': opt.exp_eye = True opt.smooth_eye = True - if opt.torso_imgs=='': #no img,use model output + if opt.torso_imgs == '': # no img,use model output opt.torso = True # assert opt.cuda_ray, "Only support CUDA ray mode." @@ -344,9 +372,10 @@ if __name__ == '__main__': model = NeRFNetwork(opt) criterion = torch.nn.MSELoss(reduction='none') - metrics = [] # use no metric in GUI for faster initialization... + metrics = [] # use no metric in GUI for faster initialization... 
print(model) - trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, + metrics=metrics, use_checkpoint=opt.ckpt) test_loader = NeRFDataset_Test(opt, device=device).dataloader() model.aud_features = test_loader._data.auds @@ -356,17 +385,19 @@ if __name__ == '__main__': nerfreal = NeRFReal(opt, trainer, test_loader) elif opt.model == 'musetalk': from musereal import MuseReal + print(opt) nerfreal = MuseReal(opt) elif opt.model == 'wav2lip': from lipreal import LipReal + print(opt) nerfreal = LipReal(opt) - #txt_to_audio('我是中国人,我来自北京') - if opt.transport=='rtmp': + # txt_to_audio('我是中国人,我来自北京') + if opt.transport == 'rtmp': thread_quit = Event() - rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) + rendthrd = Thread(target=nerfreal.render, args=(thread_quit,)) rendthrd.start() ############################################################################# @@ -374,35 +405,36 @@ if __name__ == '__main__': appasync.on_shutdown.append(on_shutdown) appasync.router.add_post("/offer", offer) appasync.router.add_post("/human", human) - appasync.router.add_static('/',path='web') + appasync.router.add_static('/', path='web') # Configure default CORS settings. cors = aiohttp_cors.setup(appasync, defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - ) - }) + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + ) + }) # Configure CORS on all routes. for route in list(appasync.router.routes()): cors.add(route) + def run_server(runner): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(runner.setup()) site = web.TCPSite(runner, '0.0.0.0', opt.listenport) loop.run_until_complete(site.start()) - if opt.transport=='rtcpush': + if opt.transport == 'rtcpush': loop.run_until_complete(run(opt.push_url)) - loop.run_forever() + loop.run_forever() + + Thread(target=run_server, args=(web.AppRunner(appasync),)).start() print('start websocket server') - #app.on_shutdown.append(on_shutdown) - #app.router.add_post("/offer", offer) + # app.on_shutdown.append(on_shutdown) + # app.router.add_post("/offer", offer) server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler) server.serve_forever() - - \ No newline at end of file diff --git a/musetalk/simple_musetalk.py b/musetalk/simple_musetalk.py new file mode 100644 index 0000000..1a6654e --- /dev/null +++ b/musetalk/simple_musetalk.py @@ -0,0 +1,331 @@ +import argparse +import glob +import json +import os +import pickle +import shutil +import sys + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from diffusers import AutoencoderKL +from face_alignment import NetworkSize +from mmpose.apis import inference_topdown, init_model +from mmpose.structures import merge_data_samples +from tqdm import tqdm + +from utils.face_parsing import FaceParsing + + +def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000): + cap = cv2.VideoCapture(vid_path) + count = 0 + while True: + if count > cut_frame: + break + ret, frame = cap.read() + if ret: + cv2.imwrite(f"{save_path}/{count:08d}.png", frame) + count += 1 + else: + break + + +def read_imgs(img_list): + frames = [] + print('reading images...') + for img_path in 
tqdm(img_list): + frame = cv2.imread(img_path) + frames.append(frame) + return frames + + +def get_landmark_and_bbox(img_list, upperbondrange=0): + frames = read_imgs(img_list) + batch_size_fa = 1 + batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)] + coords_list = [] + landmarks = [] + if upperbondrange != 0: + print('get key_landmark and face bounding boxes with the bbox_shift:', upperbondrange) + else: + print('get key_landmark and face bounding boxes with the default value') + average_range_minus = [] + average_range_plus = [] + for fb in tqdm(batches): + results = inference_topdown(model, np.asarray(fb)[0]) + results = merge_data_samples(results) + keypoints = results.pred_instances.keypoints + face_land_mark = keypoints[0][23:91] + face_land_mark = face_land_mark.astype(np.int32) + + # get bounding boxes by face detetion + bbox = fa.get_detections_for_batch(np.asarray(fb)) + + # adjust the bounding box refer to landmark + # Add the bounding box to a tuple and append it to the coordinates list + for j, f in enumerate(bbox): + if f is None: # no face in the image + coords_list += [coord_placeholder] + continue + + half_face_coord = face_land_mark[29] # np.mean([face_land_mark[28], face_land_mark[29]], axis=0) + range_minus = (face_land_mark[30] - face_land_mark[29])[1] + range_plus = (face_land_mark[29] - face_land_mark[28])[1] + average_range_minus.append(range_minus) + average_range_plus.append(range_plus) + if upperbondrange != 0: + half_face_coord[1] = upperbondrange + half_face_coord[1] # 手动调整 + 向下(偏29) - 向上(偏28) + half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1] + upper_bond = half_face_coord[1] - half_face_dist + + f_landmark = ( + np.min(face_land_mark[:, 0]), int(upper_bond), np.max(face_land_mark[:, 0]), + np.max(face_land_mark[:, 1])) + x1, y1, x2, y2 = f_landmark + + if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0: # if the landmark bbox is not suitable, reuse the bbox + coords_list += [f] + w, h = f[2] - f[0], f[3] - f[1] + print("error bbox:", f) + else: + coords_list += [f_landmark] + return coords_list, frames + + +class FaceAlignment: + def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, + device='cuda', flip_input=False, face_detector='sfd', verbose=False): + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + # torch.backends.cuda.matmul.allow_tf32 = False + # torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = False + # torch.backends.cudnn.allow_tf32 = True + print('cuda start') + + # Get the face detector + face_detector_module = __import__('face_detection.detection.' + face_detector, + globals(), locals(), [face_detector], 0) + + self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) + + def get_detections_for_batch(self, images): + images = images[..., ::-1] + detected_faces = self.face_detector.detect_from_batch(images.copy()) + results = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + results.append(None) + continue + d = d[0] + d = np.clip(d, 0, None) + + x1, y1, x2, y2 = map(int, d[:-1]) + results.append((x1, y1, x2, y2)) + return results + + +def get_mask_tensor(): + """ + Creates a mask tensor for image processing. + :return: A mask tensor. 
+ """ + mask_tensor = torch.zeros((256, 256)) + mask_tensor[:256 // 2, :] = 1 + mask_tensor[mask_tensor < 0.5] = 0 + mask_tensor[mask_tensor >= 0.5] = 1 + return mask_tensor + + +def preprocess_img(img_name, half_mask=False): + window = [] + if isinstance(img_name, str): + window_fnames = [img_name] + for fname in window_fnames: + img = cv2.imread(fname) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (256, 256), + interpolation=cv2.INTER_LANCZOS4) + window.append(img) + else: + img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB) + window.append(img) + x = np.asarray(window) / 255. + x = np.transpose(x, (3, 0, 1, 2)) + x = torch.squeeze(torch.FloatTensor(x)) + if half_mask: + x = x * (get_mask_tensor() > 0.5) + normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + x = normalize(x) + x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor + x = x.to(device) + return x + + +def encode_latents(image): + with torch.no_grad(): + init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist + init_latents = vae.config.scaling_factor * init_latent_dist.sample() + return init_latents + + +def get_latents_for_unet(img): + ref_image = preprocess_img(img, half_mask=True) # [1, 3, 256, 256] RGB, torch tensor + masked_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor + ref_image = preprocess_img(img, half_mask=False) # [1, 3, 256, 256] RGB, torch tensor + ref_latents = encode_latents(ref_image) # [1, 4, 32, 32], torch tensor + latent_model_input = torch.cat([masked_latents, ref_latents], dim=1) + return latent_model_input + + +def get_crop_box(box, expand): + x, y, x1, y1 = box + x_c, y_c = (x + x1) // 2, (y + y1) // 2 + w, h = x1 - x, y1 - y + s = int(max(w, h) // 2 * expand) + crop_box = [x_c - s, y_c - s, x_c + s, y_c + s] + return crop_box, s + + +def face_seg(image): + seg_image = fp(image) + if seg_image is None: + print("error, no person_segment") + return None + + seg_image = seg_image.resize(image.size) + return seg_image + + +def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.2): + body = Image.fromarray(image[:, :, ::-1]) + + x, y, x1, y1 = face_box + # print(x1-x,y1-y) + crop_box, s = get_crop_box(face_box, expand) + x_s, y_s, x_e, y_e = crop_box + + face_large = body.crop(crop_box) + ori_shape = face_large.size + + mask_image = face_seg(face_large) + mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) + mask_image = Image.new('L', ori_shape, 0) + mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) + + # keep upper_boundary_ratio of talking area + width, height = mask_image.size + top_boundary = int(height * upper_boundary_ratio) + modified_mask_image = Image.new('L', ori_shape, 0) + modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) + + blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1 + mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) + return mask_array, crop_box + + +def create_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +# initialize the mmpose model +device = "cuda" if torch.cuda.is_available() else "cpu" +fa = FaceAlignment(1, flip_input=False, device=device) +config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' +checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth' +model = init_model(config_file, checkpoint_file, device=device) +vae 
= AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse") +vae.to(device) +fp = FaceParsing() +if __name__ == '__main__': + # 视频文件地址 + parser = argparse.ArgumentParser() + parser.add_argument("--file", + type=str, + default=r'D:\ok\test.mp4', + ) + parser.add_argument("--avatar_id", + type=str, + default='1', + ) + args = parser.parse_args() + file = args.file + # 保存文件设置 可以不动 + save_path = f'../data/avatars/avator_{args.avatar_id}' + save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs' + create_dir(save_path) + create_dir(save_full_path) + mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask' + create_dir(mask_out_path) + + # 模型 + mask_coords_path = f'{save_path}/mask_coords.pkl' + coords_path = f'{save_path}/coords.pkl' + latents_out_path = f'{save_path}/latents.pt' + + with open(f'{save_path}/avator_info.json', "w") as f: + json.dump({ + "avatar_id": args.avatar_id, + "video_path": file, + "bbox_shift": 5 + }, f) + + if os.path.isfile(file): + video2imgs(file, save_full_path, ext='png') + else: + files = os.listdir(file) + files.sort() + files = [file for file in files if file.split(".")[-1] == "png"] + for filename in files: + shutil.copyfile(f"{file}/{filename}", f"{save_full_path}/{filename}") + input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]'))) + print("extracting landmarks...") + coord_list, frame_list = get_landmark_and_bbox(input_img_list, 5) + input_latent_list = [] + idx = -1 + # maker if the bbox is not sufficient + coord_placeholder = (0.0, 0.0, 0.0, 0.0) + for bbox, frame in zip(coord_list, frame_list): + idx = idx + 1 + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + crop_frame = frame[y1:y2, x1:x2] + resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4) + latents = get_latents_for_unet(resized_crop_frame) + input_latent_list.append(latents) + + frame_list_cycle = frame_list + frame_list[::-1] + coord_list_cycle = coord_list + coord_list[::-1] + input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + mask_coords_list_cycle = [] + mask_list_cycle = [] + for i, frame in enumerate(tqdm(frame_list_cycle)): + cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame) + + face_box = coord_list_cycle[i] + mask, crop_box = get_image_prepare_material(frame, face_box) + cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask) + mask_coords_list_cycle += [crop_box] + mask_list_cycle.append(mask) + + with open(mask_coords_path, 'wb') as f: + pickle.dump(mask_coords_list_cycle, f) + + with open(coords_path, 'wb') as f: + pickle.dump(coord_list_cycle, f) + torch.save(input_latent_list_cycle, os.path.join(latents_out_path)) diff --git a/musetalk/utils/face_parsing/__init__.py b/musetalk/utils/face_parsing/__init__.py index fc963a3..003147f 100755 --- a/musetalk/utils/face_parsing/__init__.py +++ b/musetalk/utils/face_parsing/__init__.py @@ -7,18 +7,18 @@ from PIL import Image from .model import BiSeNet import torchvision.transforms as transforms + class FaceParsing(): - def __init__(self): - self.net = self.model_init() + def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', + model_pth='./models/face-parse-bisent/79999_iter.pth'): + self.net = self.model_init(resnet_path,model_pth) self.preprocess = self.image_preprocess() - def model_init(self, - resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', - model_pth='./models/face-parse-bisent/79999_iter.pth'): + def model_init(self,resnet_path, model_pth): net = 
BiSeNet(resnet_path) if torch.cuda.is_available(): net.cuda() - net.load_state_dict(torch.load(model_pth)) + net.load_state_dict(torch.load(model_pth)) else: net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu'))) net.eval() @@ -44,13 +44,13 @@ class FaceParsing(): img = torch.unsqueeze(img, 0) out = self.net(img)[0] parsing = out.squeeze(0).cpu().numpy().argmax(0) - parsing[np.where(parsing>13)] = 0 - parsing[np.where(parsing>=1)] = 255 + parsing[np.where(parsing > 13)] = 0 + parsing[np.where(parsing >= 1)] = 255 parsing = Image.fromarray(parsing.astype(np.uint8)) return parsing + if __name__ == "__main__": fp = FaceParsing() segmap = fp('154_small.png') segmap.save('res.png') - From 18d7db35a7e1390b3dc376d41770483c81766a86 Mon Sep 17 00:00:00 2001 From: Yun <2289128964@qq.com> Date: Sun, 23 Jun 2024 14:51:58 +0800 Subject: [PATCH 3/5] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E6=88=90=E8=87=AA=E5=8A=A8=E7=BB=9D=E5=AF=B9=E8=B7=AF?= =?UTF-8?q?=E5=BE=84,=E6=B7=BB=E5=8A=A0=E6=8E=A5=E5=8F=A3=E7=94=9F?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.py | 49 +++++++++++++-------- musetalk/simple_musetalk.py | 87 ++++++++++++++++++++++--------------- 2 files changed, 84 insertions(+), 52 deletions(-) diff --git a/app.py b/app.py index a916733..1102190 100644 --- a/app.py +++ b/app.py @@ -1,29 +1,22 @@ # server.py -from flask import Flask, render_template, send_from_directory, request, jsonify -from flask_sockets import Sockets -import base64 -import time +import argparse +import asyncio import json -import gevent -from gevent import pywsgi -from geventwebsocket.handler import WebSocketHandler -import os -import re -import numpy as np -from threading import Thread, Event import multiprocessing +from threading import Thread, Event -from aiohttp import web import aiohttp import aiohttp_cors +from aiohttp import web from aiortc import RTCPeerConnection, RTCSessionDescription +from flask import Flask +from flask_sockets import Sockets +from gevent import pywsgi +from geventwebsocket.handler import WebSocketHandler + +from musetalk.simple_musetalk import create_musetalk_human from webrtc import HumanPlayer -import argparse - -import shutil -import asyncio - app = Flask(__name__) sockets = Sockets(app) global nerfreal @@ -135,6 +128,27 @@ async def human(request): ) +async def handle_create_musetalk(request): + reader = await request.multipart() + # 处理文件部分 + file_part = await reader.next() + filename = file_part.filename + file_data = await file_part.read() # 读取文件的内容 + # 注意:确保这个文件路径是可写的 + with open(filename, 'wb') as f: + f.write(file_data) + # 处理整数部分 + part = await reader.next() + avatar_id = int(await part.text()) + create_musetalk_human(filename, avatar_id) + os.remove(filename) + return web.json_response({ + 'status': 'success', + 'filename': filename, + 'int_value': avatar_id, + }) + + async def on_shutdown(app): # close peer connections coros = [pc.close() for pc in pcs] @@ -405,6 +419,7 @@ if __name__ == '__main__': appasync.on_shutdown.append(on_shutdown) appasync.router.add_post("/offer", offer) appasync.router.add_post("/human", human) + appasync.router.add_post("/create_musetalk", handle_create_musetalk) appasync.router.add_static('/', path='web') # Configure default CORS settings. 
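
The `/create_musetalk` route registered above expects a multipart body whose first part is the video (or image) file and whose second part is the avatar id; the handler saves the upload, runs `create_musetalk_human`, then removes the temporary file. Below is a minimal client sketch, not part of the patch itself, assuming the server runs with the default `--listenport 8010` and that the third-party `requests` library is available; part order matters because the handler reads the file part before the id.

```python
import requests

# First part: the source video/image file; second part: the avatar id (parsed as int on the server).
parts = [
    ("file", ("test.mp4", open("test.mp4", "rb"), "video/mp4")),
    ("avatar_id", (None, "5")),
]
resp = requests.post("http://127.0.0.1:8010/create_musetalk", files=parts)
print(resp.status_code, resp.text)
```

If the request succeeds, the generated avatar should appear under `data/avatars/avator_5` and can then be used by starting the app with `--model musetalk --avatar_id avator_5`.
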
diff --git a/musetalk/simple_musetalk.py b/musetalk/simple_musetalk.py index 1a6654e..97b8a36 100644 --- a/musetalk/simple_musetalk.py +++ b/musetalk/simple_musetalk.py @@ -4,7 +4,6 @@ import json import os import pickle import shutil -import sys import cv2 import numpy as np @@ -17,7 +16,10 @@ from mmpose.apis import inference_topdown, init_model from mmpose.structures import merge_data_samples from tqdm import tqdm -from utils.face_parsing import FaceParsing +try: + from utils.face_parsing import FaceParsing +except ModuleNotFoundError: + from musetalk.utils.face_parsing import FaceParsing def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000): @@ -55,6 +57,7 @@ def get_landmark_and_bbox(img_list, upperbondrange=0): print('get key_landmark and face bounding boxes with the default value') average_range_minus = [] average_range_plus = [] + coord_placeholder = (0.0, 0.0, 0.0, 0.0) for fb in tqdm(batches): results = inference_topdown(model, np.asarray(fb)[0]) results = merge_data_samples(results) @@ -235,57 +238,47 @@ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand return mask_array, crop_box +##todo 简单根据文件后缀判断 要更精确的可以自己修改 使用 magic +def is_video_file(file_path): + video_exts = ['.mp4', '.mkv', '.flv', '.avi', '.mov'] # 这里列出了一些常见的视频文件扩展名,可以根据需要添加更多 + file_ext = os.path.splitext(file_path)[1].lower() # 获取文件扩展名并转换为小写 + return file_ext in video_exts + + def create_dir(dir_path): if not os.path.exists(dir_path): os.makedirs(dir_path) -sys.path.append(os.path.dirname(os.path.abspath(__file__))) +current_dir = os.path.dirname(os.path.abspath(__file__)) -# initialize the mmpose model -device = "cuda" if torch.cuda.is_available() else "cpu" -fa = FaceAlignment(1, flip_input=False, device=device) -config_file = './utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' -checkpoint_file = '../models/dwpose/dw-ll_ucoco_384.pth' -model = init_model(config_file, checkpoint_file, device=device) -vae = AutoencoderKL.from_pretrained("../models/sd-vae-ft-mse") -vae.to(device) -fp = FaceParsing() -if __name__ == '__main__': - # 视频文件地址 - parser = argparse.ArgumentParser() - parser.add_argument("--file", - type=str, - default=r'D:\ok\test.mp4', - ) - parser.add_argument("--avatar_id", - type=str, - default='1', - ) - args = parser.parse_args() - file = args.file + +def create_musetalk_human(file, avatar_id): # 保存文件设置 可以不动 - save_path = f'../data/avatars/avator_{args.avatar_id}' - save_full_path = f'../data/avatars/avator_{args.avatar_id}/full_imgs' + save_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}') + save_full_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/full_imgs') create_dir(save_path) create_dir(save_full_path) - mask_out_path = f'../data/avatars/avator_{args.avatar_id}/mask' + mask_out_path = os.path.join(current_dir, f'../data/avatars/avator_{avatar_id}/mask') create_dir(mask_out_path) # 模型 - mask_coords_path = f'{save_path}/mask_coords.pkl' - coords_path = f'{save_path}/coords.pkl' - latents_out_path = f'{save_path}/latents.pt' + mask_coords_path = os.path.join(current_dir, f'{save_path}/mask_coords.pkl') + coords_path = os.path.join(current_dir, f'{save_path}/coords.pkl') + latents_out_path = os.path.join(current_dir, f'{save_path}/latents.pt') - with open(f'{save_path}/avator_info.json', "w") as f: + with open(os.path.join(current_dir, f'{save_path}/avator_info.json'), "w") as f: json.dump({ - "avatar_id": args.avatar_id, + "avatar_id": avatar_id, "video_path": file, "bbox_shift": 5 }, f) if 
os.path.isfile(file): - video2imgs(file, save_full_path, ext='png') + if is_video_file(file): + video2imgs(file, save_full_path, ext='png') + else: + shutil.copyfile(file, f"{save_full_path}/{os.path.basename(file)}") else: files = os.listdir(file) files.sort() @@ -316,7 +309,6 @@ if __name__ == '__main__': mask_list_cycle = [] for i, frame in enumerate(tqdm(frame_list_cycle)): cv2.imwrite(f"{save_full_path}/{str(i).zfill(8)}.png", frame) - face_box = coord_list_cycle[i] mask, crop_box = get_image_prepare_material(frame, face_box) cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask) @@ -329,3 +321,28 @@ if __name__ == '__main__': with open(coords_path, 'wb') as f: pickle.dump(coord_list_cycle, f) torch.save(input_latent_list_cycle, os.path.join(latents_out_path)) + + +# initialize the mmpose model +device = "cuda" if torch.cuda.is_available() else "cpu" +fa = FaceAlignment(1, flip_input=False, device=device) +config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py') +checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth')) +model = init_model(config_file, checkpoint_file, device=device) +vae = AutoencoderKL.from_pretrained(os.path.abspath(os.path.join(current_dir, '../models/sd-vae-ft-mse'))) +vae.to(device) +fp = FaceParsing(os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/resnet18-5c106cde.pth')), + os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/79999_iter.pth'))) +if __name__ == '__main__': + # 视频文件地址 + parser = argparse.ArgumentParser() + parser.add_argument("--file", + type=str, + default=r'D:\ok\00000000.png', + ) + parser.add_argument("--avatar_id", + type=str, + default='3', + ) + args = parser.parse_args() + create_musetalk_human(args.file, args.avatar_id) From cd7d5f31b54d298345d2457af6247a0d305adbb4 Mon Sep 17 00:00:00 2001 From: Yun <2289128964@qq.com> Date: Thu, 4 Jul 2024 09:43:56 +0800 Subject: [PATCH 4/5] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E6=88=90=E8=87=AA=E5=8A=A8=E7=BB=9D=E5=AF=B9=E8=B7=AF?= =?UTF-8?q?=E5=BE=84,=E6=B7=BB=E5=8A=A0=E6=8E=A5=E5=8F=A3=E7=94=9F?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 17 +- app.py | 235 ++++++++++-------------- asrreal.py | 149 +++++---------- lipasr.py | 54 +----- lipreal.py | 12 +- museasr.py | 54 +----- musereal.py | 227 ++++++++++++----------- musetalk/utils/face_parsing/__init__.py | 15 +- nerfreal.py | 24 ++- ttsreal.py | 17 +- web/chat.html | 44 +++-- web/rtcpushchat.html | 52 +++--- web/webrtcapi.html | 1 + web/webrtcchat.html | 53 +++--- 14 files changed, 401 insertions(+), 553 deletions(-) diff --git a/README.md b/README.md index b169183..a8a7c44 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,10 @@ Real time interactive streaming digital human, realize audio video synchronous ## Features 1. 支持多种数字人模型: ernerf、musetalk、wav2lip 2. 支持声音克隆 -3. 支持多种音频特征驱动:wav2vec、hubert +3. 支持数字人说话被打断 4. 支持全身视频拼接 5. 支持rtmp和webrtc 6. 支持视频编排:不说话时播放自定义视频 -7. 支持大模型对话 ## 1. 
Installation @@ -171,13 +170,11 @@ cd MuseTalk 修改configs/inference/realtime.yaml,将preparation改为True python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml 运行后将results/avatars下文件拷到本项目的data/avatars下 -``` - -```bash -也可以试用本地目录下的 simple_musetalk.py -cd musetalk -python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4 -运行后将直接生成在data/avatars下 +方法二 +执行 +cd musetalk +python simple_musetalk.py --avatar_id 4 --file D:\\ok\\test.mp4 +支持视频和图片生成 会自动生成到data的avatars目录下 ``` ### 3.10 模型用wav2lip @@ -185,7 +182,7 @@ python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4 - 下载模型 下载wav2lip运行需要的模型,网盘地址 https://drive.uc.cn/s/551be97d7cfa4 将s3fd.pth拷到本项目wav2lip/face_detection/detection/sfd/s3fd.pth, 将wav2lip.pth拷到本项目的models下 -数字人模型文件 wav2lip_avatar1.tar.gz, 解压后将整个文件夹拷到本项目的data/avatars下 +数字人模型文件 wav2lip_avatar1.tar.gz,网盘地址 https://drive.uc.cn/s/5bd0cde0b0774, 解压后将整个文件夹拷到本项目的data/avatars下 - 运行 python app.py --transport webrtc --model wav2lip --avatar_id wav2lip_avatar1 用浏览器打开http://serverip:8010/webrtcapi.html diff --git a/app.py b/app.py index 1102190..83a60ce 100644 --- a/app.py +++ b/app.py @@ -1,31 +1,39 @@ # server.py -import argparse -import asyncio -import json -import multiprocessing -from threading import Thread, Event - -import aiohttp -import aiohttp_cors -from aiohttp import web -from aiortc import RTCPeerConnection, RTCSessionDescription -from flask import Flask +from flask import Flask, render_template,send_from_directory,request, jsonify from flask_sockets import Sockets +import base64 +import time +import json +import gevent from gevent import pywsgi from geventwebsocket.handler import WebSocketHandler +import os +import re +import numpy as np +from threading import Thread,Event +import multiprocessing -from musetalk.simple_musetalk import create_musetalk_human +from aiohttp import web +import aiohttp +import aiohttp_cors +from aiortc import RTCPeerConnection, RTCSessionDescription from webrtc import HumanPlayer +import argparse + +import shutil +import asyncio + + app = Flask(__name__) sockets = Sockets(app) global nerfreal - + @sockets.route('/humanecho') def echo_socket(ws): # 获取WebSocket对象 - # ws = request.environ.get('wsgi.websocket') + #ws = request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -34,11 +42,11 @@ def echo_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if not message or len(message) == 0: + message = ws.receive() + + if not message or len(message)==0: return '输入信息为空' - else: + else: nerfreal.put_msg_txt(message) @@ -46,16 +54,15 @@ def llm_response(message): from llm.LLM import LLM # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None) # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key') - llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b') + llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b') response = llm.chat(message) print(response) return response - @sockets.route('/humanchat') def chat_socket(ws): # 获取WebSocket对象 - # ws = request.environ.get('wsgi.websocket') + #ws = request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -64,20 +71,18 @@ def chat_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if len(message) == 0: + message = ws.receive() + + if len(message)==0: return '输入信息为空' else: - res = llm_response(message) + res=llm_response(message) nerfreal.put_msg_txt(res) - 
#####webrtc############################### pcs = set() - -# @app.route('/offer', methods=['POST']) +#@app.route('/offer', methods=['POST']) async def offer(request): params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) @@ -101,7 +106,7 @@ async def offer(request): answer = await pc.createAnswer() await pc.setLocalDescription(answer) - # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) + #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) return web.Response( content_type="application/json", @@ -110,61 +115,39 @@ async def offer(request): ), ) - async def human(request): params = await request.json() - if params['type'] == 'echo': + if params.get('interrupt'): + nerfreal.pause_talk() + + if params['type']=='echo': nerfreal.put_msg_txt(params['text']) - elif params['type'] == 'chat': - res = await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) + elif params['type']=='chat': + res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) nerfreal.put_msg_txt(res) return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "data": "ok"} + {"code": 0, "data":"ok"} ), ) - -async def handle_create_musetalk(request): - reader = await request.multipart() - # 处理文件部分 - file_part = await reader.next() - filename = file_part.filename - file_data = await file_part.read() # 读取文件的内容 - # 注意:确保这个文件路径是可写的 - with open(filename, 'wb') as f: - f.write(file_data) - # 处理整数部分 - part = await reader.next() - avatar_id = int(await part.text()) - create_musetalk_human(filename, avatar_id) - os.remove(filename) - return web.json_response({ - 'status': 'success', - 'filename': filename, - 'int_value': avatar_id, - }) - - async def on_shutdown(app): # close peer connections coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() - -async def post(url, data): +async def post(url,data): try: async with aiohttp.ClientSession() as session: - async with session.post(url, data=data) as response: + async with session.post(url,data=data) as response: return await response.text() except aiohttp.ClientError as e: print(f'Error: {e}') - async def run(push_url): pc = RTCPeerConnection() pcs.add(pc) @@ -181,10 +164,8 @@ async def run(push_url): video_sender = pc.addTrack(player.video) await pc.setLocalDescription(await pc.createOffer()) - answer = await post(push_url, pc.localDescription.sdp) - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer')) - - + answer = await post(push_url,pc.localDescription.sdp) + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer')) ########################################## # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver' @@ -203,20 +184,14 @@ if __name__ == '__main__': ### training options parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth') - - parser.add_argument('--num_rays', type=int, default=4096 * 16, - help="num rays sampled per image for each training step") + + parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step") parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") - parser.add_argument('--max_steps', type=int, default=16, - help="max num steps sampled per ray (only valid when using --cuda_ray)") - parser.add_argument('--num_steps', type=int, 
default=16, - help="num steps sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--upsample_steps', type=int, default=0, - help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--update_extra_interval', type=int, default=16, - help="iter interval to update extra status (only valid when using --cuda_ray)") - parser.add_argument('--max_ray_batch', type=int, default=4096, - help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") ### loss set parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps") @@ -227,35 +202,27 @@ if __name__ == '__main__': ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") - + parser.add_argument('--bg_img', type=str, default='white', help="background image") parser.add_argument('--fbg', action='store_true', help="frame-wise bg") parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") - parser.add_argument('--fix_eye', type=float, default=-1, - help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") + parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") - parser.add_argument('--torso_shrink', type=float, default=0.8, - help="shrink bg coords to allow more flexibility in deform") + parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") ### dataset options parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") - parser.add_argument('--preload', type=int, default=0, - help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") + parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") # (the default value is for the fox dataset) - parser.add_argument('--bound', type=float, default=1, - help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") - parser.add_argument('--dt_gamma', type=float, default=1 / 256, - help="dt_gamma (>=0) for adaptive ray marching. 
set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") - parser.add_argument('--density_thresh', type=float, default=10, - help="threshold for density grid to be occupied (sigma)") - parser.add_argument('--density_thresh_torso', type=float, default=0.01, - help="threshold for density grid to be occupied (alpha)") - parser.add_argument('--patch_size', type=int, default=1, - help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") + parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") parser.add_argument('--init_lips', action='store_true', help="init lips region") parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region") @@ -273,15 +240,12 @@ if __name__ == '__main__': parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") ### else - parser.add_argument('--att', type=int, default=2, - help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") - parser.add_argument('--aud', type=str, default='', - help="audio source (empty will load the default, else should be a path to a npy file)") + parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") + parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)") parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") - parser.add_argument('--ind_num', type=int, default=10000, - help="number of individual codes, should be larger than training dataset size") + parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") @@ -290,8 +254,7 @@ if __name__ == '__main__': parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") - parser.add_argument('--smooth_path', action='store_true', - help="brute-force smooth camera pose trajectory with a window size") + parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size") parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") # asr @@ -299,8 +262,8 @@ if __name__ == '__main__': parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") parser.add_argument('--asr_play', action='store_true', help="play 
out the audio") - # parser.add_argument('--asr_model', type=str, default='deepspeech') - parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # + #parser.add_argument('--asr_model', type=str, default='deepspeech') + parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft') @@ -319,45 +282,42 @@ if __name__ == '__main__': parser.add_argument('--fullbody_offset_x', type=int, default=0) parser.add_argument('--fullbody_offset_y', type=int, default=0) - # musetalk opt + #musetalk opt parser.add_argument('--avatar_id', type=str, default='avator_1') parser.add_argument('--bbox_shift', type=int, default=5) parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--customvideo', action='store_true', help="custom video") - parser.add_argument('--static_img', action='store_true', help="Use the first photo as a time of rest") parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') parser.add_argument('--customvideo_imgnum', type=int, default=1) - parser.add_argument('--tts', type=str, default='edgetts') # xtts gpt-sovits + parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits parser.add_argument('--REF_FILE', type=str, default=None) parser.add_argument('--REF_TEXT', type=str, default=None) - parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 + parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 # parser.add_argument('--CHARACTER', type=str, default='test') # parser.add_argument('--EMOTION', type=str, default='default') - parser.add_argument('--model', type=str, default='ernerf') # musetalk wav2lip + parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip - parser.add_argument('--transport', type=str, default='rtcpush') # rtmp webrtc rtcpush - parser.add_argument('--push_url', type=str, - default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream + parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush + parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream parser.add_argument('--listenport', type=int, default=8010) opt = parser.parse_args() - # app.config.from_object(opt) - # print(app.config) + #app.config.from_object(opt) + #print(app.config) if opt.model == 'ernerf': from ernerf.nerf_triplane.provider import NeRFDataset_Test from ernerf.nerf_triplane.utils import * from ernerf.nerf_triplane.network import NeRFNetwork from nerfreal import NeRFReal - # assert test mode opt.test = True opt.test_train = False - # opt.train_camera =True + #opt.train_camera =True # explicit smoothing opt.smooth_path = True opt.smooth_lips = True @@ -370,7 +330,7 @@ if __name__ == '__main__': opt.exp_eye = True opt.smooth_eye = True - if opt.torso_imgs == '': # no img,use model output + if opt.torso_imgs=='': #no img,use model output opt.torso = True # assert opt.cuda_ray, "Only support CUDA ray mode." @@ -386,10 +346,9 @@ if __name__ == '__main__': model = NeRFNetwork(opt) criterion = torch.nn.MSELoss(reduction='none') - metrics = [] # use no metric in GUI for faster initialization... 
+ metrics = [] # use no metric in GUI for faster initialization... print(model) - trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, - metrics=metrics, use_checkpoint=opt.ckpt) + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) test_loader = NeRFDataset_Test(opt, device=device).dataloader() model.aud_features = test_loader._data.auds @@ -399,19 +358,17 @@ if __name__ == '__main__': nerfreal = NeRFReal(opt, trainer, test_loader) elif opt.model == 'musetalk': from musereal import MuseReal - print(opt) nerfreal = MuseReal(opt) elif opt.model == 'wav2lip': from lipreal import LipReal - print(opt) nerfreal = LipReal(opt) - # txt_to_audio('我是中国人,我来自北京') - if opt.transport == 'rtmp': + #txt_to_audio('我是中国人,我来自北京') + if opt.transport=='rtmp': thread_quit = Event() - rendthrd = Thread(target=nerfreal.render, args=(thread_quit,)) + rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) rendthrd.start() ############################################################################# @@ -419,37 +376,35 @@ if __name__ == '__main__': appasync.on_shutdown.append(on_shutdown) appasync.router.add_post("/offer", offer) appasync.router.add_post("/human", human) - appasync.router.add_post("/create_musetalk", handle_create_musetalk) - appasync.router.add_static('/', path='web') + appasync.router.add_static('/',path='web') # Configure default CORS settings. cors = aiohttp_cors.setup(appasync, defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - ) - }) + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + ) + }) # Configure CORS on all routes. 
for route in list(appasync.router.routes()): cors.add(route) - def run_server(runner): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(runner.setup()) site = web.TCPSite(runner, '0.0.0.0', opt.listenport) loop.run_until_complete(site.start()) - if opt.transport == 'rtcpush': + if opt.transport=='rtcpush': loop.run_until_complete(run(opt.push_url)) - loop.run_forever() - - + loop.run_forever() Thread(target=run_server, args=(web.AppRunner(appasync),)).start() print('start websocket server') - # app.on_shutdown.append(on_shutdown) - # app.router.add_post("/offer", offer) + #app.on_shutdown.append(on_shutdown) + #app.router.add_post("/offer", offer) server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler) server.serve_forever() + + \ No newline at end of file diff --git a/asrreal.py b/asrreal.py index b3e4093..62aa15e 100644 --- a/asrreal.py +++ b/asrreal.py @@ -4,29 +4,19 @@ import torch import torch.nn.functional as F from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel -#import pyaudio -import soundfile as sf -import resampy import queue from queue import Queue #from collections import deque from threading import Thread, Event -from io import BytesIO -class ASR: +from baseasr import BaseASR + +class ASR(BaseASR): def __init__(self, opt): - - self.opt = opt - - self.play = opt.asr_play #false + super().__init__(opt) self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.fps = opt.fps # 20 ms per frame - self.sample_rate = 16000 - self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) - self.mode = 'live' if opt.asr_wav == '' else 'file' - if 'esperanto' in self.opt.asr_model: self.audio_dim = 44 elif 'deepspeech' in self.opt.asr_model: @@ -41,30 +31,11 @@ class ASR: self.context_size = opt.m self.stride_left_size = opt.l self.stride_right_size = opt.r - self.text = '[START]\n' - self.terminated = False - self.frames = [] - self.inwarm = False # pad left frames if self.stride_left_size > 0: self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) - - self.exit_event = Event() - #self.audio_instance = pyaudio.PyAudio() #not need - - # create input stream - self.queue = Queue() - self.output_queue = Queue() - # start a background process to read frames - #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) - #self.queue = Queue() - #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) - - # current location of audio - self.idx = 0 - # create wav2vec model print(f'[INFO] loading ASR model {self.opt.asr_model}...') if 'hubert' in self.opt.asr_model: @@ -74,10 +45,6 @@ class ASR: self.processor = AutoProcessor.from_pretrained(opt.asr_model) self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) - # prepare to save logits - if self.opt.asr_save_feats: - self.all_feats = [] - # the extracted features # use a loop queue to efficiently record endless features: [f--t---][-------][-------] self.feat_buffer_size = 4 @@ -93,8 +60,16 @@ class ASR: # warm up steps needed: mid + right + window_size + attention_size self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3 - self.listening = False - self.playing = False + def get_audio_frame(self): + try: + frame = 
self.queue.get(block=False) + type = 0 + #print(f'[INFO] get frame {frame.shape}') + except queue.Empty: + frame = np.zeros(self.chunk, dtype=np.float32) + type = 1 + + return frame,type def get_next_feat(self): #get audio embedding to nerf # return a [1/8, 16] window, for the next input to nerf side. @@ -136,29 +111,19 @@ class ASR: def run_step(self): - if self.terminated: - return - # get a frame of audio - frame,type = self.__get_audio_frame() - - # the last frame - if frame is None: - # terminate, but always run the network for the left frames - self.terminated = True - else: - self.frames.append(frame) - # put to output - self.output_queue.put((frame,type)) - # context not enough, do not run network. - if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: - return + frame,type = self.get_audio_frame() + self.frames.append(frame) + # put to output + self.output_queue.put((frame,type)) + # context not enough, do not run network. + if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: + return inputs = np.concatenate(self.frames) # [N * chunk] # discard the old part to save memory - if not self.terminated: - self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] + self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] #print(f'[INFO] frame_to_text... ') #t = time.time() @@ -166,10 +131,6 @@ class ASR: #print(f'-------wav2vec time:{time.time()-t:.4f}s') feats = logits # better lips-sync than labels - # save feats - if self.opt.asr_save_feats: - self.all_feats.append(feats) - # record the feats efficiently.. (no concat, constant memory) start = self.feat_buffer_idx * self.context_size end = start + feats.shape[0] @@ -203,24 +164,6 @@ class ASR: # np.save(output_path, unfold_feats.cpu().numpy()) # print(f"[INFO] saved logits to {output_path}") - def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.queue.put(audio_chunk) - - def __get_audio_frame(self): - if self.inwarm: # warm up - return np.zeros(self.chunk, dtype=np.float32),1 - - try: - frame = self.queue.get(block=False) - type = 0 - print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) - type = 1 - - self.idx = self.idx + self.chunk - - return frame,type def __frame_to_text(self, frame): @@ -241,8 +184,8 @@ class ASR: right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. # do not cut right if terminated. 
- if self.terminated: - right = logits.shape[1] + # if self.terminated: + # right = logits.shape[1] logits = logits[:, left:right] @@ -262,10 +205,23 @@ class ASR: return logits[0], None,None #predicted_ids[0], transcription # [N,] - - def get_audio_out(self): #get origin audio pcm to nerf - return self.output_queue.get() - + + def warm_up(self): + print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') + t = time.time() + #for _ in range(self.stride_left_size): + # self.frames.append(np.zeros(self.chunk, dtype=np.float32)) + for _ in range(self.warm_up_steps): + self.run_step() + #if torch.cuda.is_available(): + # torch.cuda.synchronize() + t = time.time() - t + print(f'[INFO] warm-up done, actual latency = {t:.6f}s') + + #self.clear_queue() + + #####not used function##################################### + ''' def __init_queue(self): self.frames = [] self.queue.queue.clear() @@ -290,26 +246,6 @@ class ASR: if self.play: self.output_queue.queue.clear() - def warm_up(self): - - #self.listen() - - self.inwarm = True - print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') - t = time.time() - #for _ in range(self.stride_left_size): - # self.frames.append(np.zeros(self.chunk, dtype=np.float32)) - for _ in range(self.warm_up_steps): - self.run_step() - #if torch.cuda.is_available(): - # torch.cuda.synchronize() - t = time.time() - t - print(f'[INFO] warm-up done, actual latency = {t:.6f}s') - self.inwarm = False - - #self.clear_queue() - - #####not used function##################################### def listen(self): # start if self.mode == 'live' and not self.listening: @@ -404,4 +340,5 @@ if __name__ == '__main__': raise ValueError("DeepSpeech features should not use this code to extract...") with ASR(opt) as asr: - asr.run() \ No newline at end of file + asr.run() +''' \ No newline at end of file diff --git a/lipasr.py b/lipasr.py index 5742dd7..29948ac 100644 --- a/lipasr.py +++ b/lipasr.py @@ -6,60 +6,16 @@ import queue from queue import Queue import multiprocessing as mp +from baseasr import BaseASR from wav2lip import audio -class LipASR: - def __init__(self, opt): - self.opt = opt - - self.fps = opt.fps # 20 ms per frame - self.sample_rate = 16000 - self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) - self.queue = Queue() - # self.input_stream = BytesIO() - self.output_queue = mp.Queue() - - #self.audio_processor = audio_processor - self.batch_size = opt.batch_size - - self.frames = [] - self.stride_left_size = opt.l - self.stride_right_size = opt.r - #self.context_size = 10 - self.feat_queue = mp.Queue(5) - - self.warm_up() - - def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.queue.put(audio_chunk) - - def __get_audio_frame(self): - try: - frame = self.queue.get(block=True,timeout=0.01) - type = 0 - #print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) - type = 1 - - return frame,type - - def get_audio_out(self): #get origin audio pcm to nerf - return self.output_queue.get() - - def warm_up(self): - for _ in range(self.stride_left_size + self.stride_right_size): - audio_frame,type=self.__get_audio_frame() - self.frames.append(audio_frame) - self.output_queue.put((audio_frame,type)) - for _ in range(self.stride_left_size): - self.output_queue.get() +class LipASR(BaseASR): def run_step(self): ############################################## extract audio feature 
############################################## # get a frame of audio for _ in range(self.batch_size*2): - frame,type = self.__get_audio_frame() + frame,type = self.get_audio_frame() self.frames.append(frame) # put to output self.output_queue.put((frame,type)) @@ -89,7 +45,3 @@ class LipASR: # discard the old part to save memory self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] - - - def get_next_feat(self,block,timeout): - return self.feat_queue.get(block,timeout) \ No newline at end of file diff --git a/lipreal.py b/lipreal.py index d69f3dc..9461e7b 100644 --- a/lipreal.py +++ b/lipreal.py @@ -164,6 +164,7 @@ class LipReal: self.__loadavatar() self.asr = LipASR(opt) + self.asr.warm_up() if opt.tts == "edgetts": self.tts = EdgeTTS(opt,self) elif opt.tts == "gpt-sovits": @@ -199,6 +200,10 @@ class LipReal: def put_audio_frame(self,audio_chunk): #16khz 20ms pcm self.asr.put_audio_frame(audio_chunk) + + def pause_talk(self): + self.tts.pause_talk() + self.asr.pause_talk() def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None): @@ -257,9 +262,12 @@ class LipReal: t = time.perf_counter() self.asr.run_step() - if video_track._queue.qsize()>=2*self.opt.batch_size: + # if video_track._queue.qsize()>=2*self.opt.batch_size: + # print('sleep qsize=',video_track._queue.qsize()) + # time.sleep(0.04*video_track._queue.qsize()*0.8) + if video_track._queue.qsize()>=5: print('sleep qsize=',video_track._queue.qsize()) - time.sleep(0.04*self.opt.batch_size*1.5) + time.sleep(0.04*video_track._queue.qsize()*0.8) # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms # if delay > 0: diff --git a/museasr.py b/museasr.py index cfcd9ba..cb166a6 100644 --- a/museasr.py +++ b/museasr.py @@ -1,65 +1,22 @@ import time -import torch import numpy as np import queue from queue import Queue import multiprocessing as mp - +from baseasr import BaseASR from musetalk.whisper.audio2feature import Audio2Feature -class MuseASR: +class MuseASR(BaseASR): def __init__(self, opt, audio_processor:Audio2Feature): - self.opt = opt - - self.fps = opt.fps # 20 ms per frame - self.sample_rate = 16000 - self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) - self.queue = Queue() - # self.input_stream = BytesIO() - self.output_queue = mp.Queue() - + super().__init__(opt) self.audio_processor = audio_processor - self.batch_size = opt.batch_size - - self.frames = [] - self.stride_left_size = opt.l - self.stride_right_size = opt.r - self.feat_queue = mp.Queue(5) - - self.warm_up() - - def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.queue.put(audio_chunk) - - def __get_audio_frame(self): - try: - frame = self.queue.get(block=True,timeout=0.01) - type = 0 - #print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) - type = 1 - - return frame,type - - def get_audio_out(self): #get origin audio pcm to nerf - return self.output_queue.get() - - def warm_up(self): - for _ in range(self.stride_left_size + self.stride_right_size): - audio_frame,type=self.__get_audio_frame() - self.frames.append(audio_frame) - self.output_queue.put((audio_frame,type)) - - for _ in range(self.stride_left_size): - self.output_queue.get() def run_step(self): ############################################## extract audio feature ############################################## start_time = time.time() for _ in range(self.batch_size*2): - audio_frame,type=self.__get_audio_frame() + 
audio_frame,type=self.get_audio_frame() self.frames.append(audio_frame) self.output_queue.put((audio_frame,type)) @@ -77,6 +34,3 @@ class MuseASR: self.feat_queue.put(whisper_chunks) # discard the old part to save memory self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] - - def get_next_feat(self,block,timeout): - return self.feat_queue.get(block,timeout) \ No newline at end of file diff --git a/musereal.py b/musereal.py index 7e58029..0f01dcf 100644 --- a/musereal.py +++ b/musereal.py @@ -2,7 +2,7 @@ import math import torch import numpy as np -# from .utils import * +#from .utils import * import subprocess import os import time @@ -18,19 +18,17 @@ from threading import Thread, Event from io import BytesIO import multiprocessing as mp -from musetalk.utils.utils import get_file_type, get_video_fps, datagen -# from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder -from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending -from musetalk.utils.utils import load_all_model, load_diffusion_model, load_audio_model -from ttsreal import EdgeTTS, VoitsTTS, XTTS +from musetalk.utils.utils import get_file_type,get_video_fps,datagen +#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder +from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending +from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model +from ttsreal import EdgeTTS,VoitsTTS,XTTS from museasr import MuseASR import asyncio from av import AudioFrame, VideoFrame from tqdm import tqdm - - def read_imgs(img_list): frames = [] print('reading images...') @@ -39,146 +37,143 @@ def read_imgs(img_list): frames.append(frame) return frames - def __mirror_index(size, index): - # size = len(self.coord_list_cycle) + #size = len(self.coord_list_cycle) turn = index // size res = index % size if turn % 2 == 0: return res else: - return size - res - 1 - - -def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue, - ): # vae, unet, pe,timesteps + return size - res - 1 +def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue, + ): #vae, unet, pe,timesteps + vae, unet, pe = load_diffusion_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) pe = pe.half() vae.vae = vae.vae.half() unet.model = unet.model.half() - + input_latent_list_cycle = torch.load(latents_out_path) length = len(input_latent_list_cycle) index = 0 - count = 0 - counttime = 0 + count=0 + counttime=0 print('start inference') while True: if render_event.is_set(): - starttime = time.perf_counter() + starttime=time.perf_counter() try: whisper_chunks = audio_feat_queue.get(block=True, timeout=1) except queue.Empty: continue - is_all_silence = True + is_all_silence=True audio_frames = [] - for _ in range(batch_size * 2): - frame, type = audio_out_queue.get() - audio_frames.append((frame, type)) - if type == 0: - is_all_silence = False + for _ in range(batch_size*2): + frame,type = audio_out_queue.get() + audio_frames.append((frame,type)) + if type==0: + is_all_silence=False if is_all_silence: for i in range(batch_size): - res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) + res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2])) index = index + 1 else: # 
print('infer=======') - t = time.perf_counter() + t=time.perf_counter() whisper_batch = np.stack(whisper_chunks) latent_batch = [] for i in range(batch_size): - idx = __mirror_index(length, index + i) + idx = __mirror_index(length,index+i) latent = input_latent_list_cycle[idx] latent_batch.append(latent) latent_batch = torch.cat(latent_batch, dim=0) - + # for i, (whisper_batch,latent_batch) in enumerate(gen): audio_feature_batch = torch.from_numpy(whisper_batch) audio_feature_batch = audio_feature_batch.to(device=unet.device, - dtype=unet.model.dtype) + dtype=unet.model.dtype) audio_feature_batch = pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=unet.model.dtype) # print('prepare time:',time.perf_counter()-t) # t=time.perf_counter() - pred_latents = unet.model(latent_batch, - timesteps, - encoder_hidden_states=audio_feature_batch).sample + pred_latents = unet.model(latent_batch, + timesteps, + encoder_hidden_states=audio_feature_batch).sample # print('unet time:',time.perf_counter()-t) # t=time.perf_counter() recon = vae.decode_latents(pred_latents) # print('vae time:',time.perf_counter()-t) - # print('diffusion len=',len(recon)) + #print('diffusion len=',len(recon)) counttime += (time.perf_counter() - t) count += batch_size - # _totalframe += 1 - if count >= 100: - print(f"------actual avg infer fps:{count / counttime:.4f}") - count = 0 - counttime = 0 - for i, res_frame in enumerate(recon): - # self.__pushmedia(res_frame,loop,audio_track,video_track) - res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) + #_totalframe += 1 + if count>=100: + print(f"------actual avg infer fps:{count/counttime:.4f}") + count=0 + counttime=0 + for i,res_frame in enumerate(recon): + #self.__pushmedia(res_frame,loop,audio_track,video_track) + res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2])) index = index + 1 - # print('total batch time:',time.perf_counter()-starttime) + #print('total batch time:',time.perf_counter()-starttime) else: time.sleep(1) print('musereal inference processor stop') - @torch.no_grad() class MuseReal: def __init__(self, opt): - self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. 
self.W = opt.W self.H = opt.H - self.fps = opt.fps # 20 ms per frame + self.fps = opt.fps # 20 ms per frame #### musetalk self.avatar_id = opt.avatar_id - self.static_img = opt.static_img - self.video_path = '' # video_path + self.video_path = '' #video_path self.bbox_shift = opt.bbox_shift self.avatar_path = f"./data/avatars/{self.avatar_id}" - self.full_imgs_path = f"{self.avatar_path}/full_imgs" + self.full_imgs_path = f"{self.avatar_path}/full_imgs" self.coords_path = f"{self.avatar_path}/coords.pkl" - self.latents_out_path = f"{self.avatar_path}/latents.pt" + self.latents_out_path= f"{self.avatar_path}/latents.pt" self.video_out_path = f"{self.avatar_path}/vid_output/" - self.mask_out_path = f"{self.avatar_path}/mask" - self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl" + self.mask_out_path =f"{self.avatar_path}/mask" + self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl" self.avatar_info_path = f"{self.avatar_path}/avator_info.json" self.avatar_info = { - "avatar_id": self.avatar_id, - "video_path": self.video_path, - "bbox_shift": self.bbox_shift + "avatar_id":self.avatar_id, + "video_path":self.video_path, + "bbox_shift":self.bbox_shift } self.batch_size = opt.batch_size self.idx = 0 - self.res_frame_queue = mp.Queue(self.batch_size * 2) + self.res_frame_queue = mp.Queue(self.batch_size*2) self.__loadmodels() self.__loadavatar() - self.asr = MuseASR(opt, self.audio_processor) + self.asr = MuseASR(opt,self.audio_processor) + self.asr.warm_up() if opt.tts == "edgetts": - self.tts = EdgeTTS(opt, self) + self.tts = EdgeTTS(opt,self) elif opt.tts == "gpt-sovits": - self.tts = VoitsTTS(opt, self) + self.tts = VoitsTTS(opt,self) elif opt.tts == "xtts": - self.tts = XTTS(opt, self) - # self.__warm_up() - + self.tts = XTTS(opt,self) + #self.__warm_up() + self.render_event = mp.Event() - mp.Process(target=inference, args=(self.render_event, self.batch_size, self.latents_out_path, - self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue, - )).start() # self.vae, self.unet, self.pe,self.timesteps + mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path, + self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue, + )).start() #self.vae, self.unet, self.pe,self.timesteps def __loadmodels(self): # load model weights - self.audio_processor = load_audio_model() + self.audio_processor= load_audio_model() # self.audio_processor, self.vae, self.unet, self.pe = load_all_model() # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # self.timesteps = torch.tensor([0], device=device) @@ -187,7 +182,7 @@ class MuseReal: # self.unet.model = self.unet.model.half() def __loadavatar(self): - # self.input_latent_list_cycle = torch.load(self.latents_out_path) + #self.input_latent_list_cycle = torch.load(self.latents_out_path) with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f) input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) @@ -198,13 +193,19 @@ class MuseReal: input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) self.mask_list_cycle = read_imgs(input_mask_list) - - def put_msg_txt(self, msg): + + + def put_msg_txt(self,msg): self.tts.put_msg_txt(msg) - - def put_audio_frame(self, audio_chunk): # 16khz 20ms pcm + + def put_audio_frame(self,audio_chunk): #16khz 20ms pcm self.asr.put_audio_frame(audio_chunk) + def 
pause_talk(self): + self.tts.pause_talk() + self.asr.pause_talk() + + def __mirror_index(self, index): size = len(self.coord_list_cycle) turn = index // size @@ -212,15 +213,15 @@ class MuseReal: if turn % 2 == 0: return res else: - return size - res - 1 + return size - res - 1 - def __warm_up(self): + def __warm_up(self): self.asr.run_step() whisper_chunks = self.asr.get_next_feat() whisper_batch = np.stack(whisper_chunks) latent_batch = [] for i in range(self.batch_size): - idx = self.__mirror_index(self.idx + i) + idx = self.__mirror_index(self.idx+i) latent = self.input_latent_list_cycle[idx] latent_batch.append(latent) latent_batch = torch.cat(latent_batch, dim=0) @@ -228,88 +229,90 @@ class MuseReal: # for i, (whisper_batch,latent_batch) in enumerate(gen): audio_feature_batch = torch.from_numpy(whisper_batch) audio_feature_batch = audio_feature_batch.to(device=self.unet.device, - dtype=self.unet.model.dtype) + dtype=self.unet.model.dtype) audio_feature_batch = self.pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=self.unet.model.dtype) - pred_latents = self.unet.model(latent_batch, - self.timesteps, - encoder_hidden_states=audio_feature_batch).sample + pred_latents = self.unet.model(latent_batch, + self.timesteps, + encoder_hidden_states=audio_feature_batch).sample recon = self.vae.decode_latents(pred_latents) + - def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None): - + def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None): + while not quit_event.is_set(): try: - res_frame, idx, audio_frames = self.res_frame_queue.get(block=True, timeout=1) + res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1) except queue.Empty: continue - if audio_frames[0][1] == 1 and audio_frames[1][1] == 1: # 全为静音数据,只需要取fullimg - if self.static_img: - combine_frame = self.frame_list_cycle[0] - else: - combine_frame = self.frame_list_cycle[idx] + if audio_frames[0][1]==1 and audio_frames[1][1]==1: #全为静音数据,只需要取fullimg + combine_frame = self.frame_list_cycle[idx] else: bbox = self.coord_list_cycle[idx] ori_frame = copy.deepcopy(self.frame_list_cycle[idx]) x1, y1, x2, y2 = bbox try: - res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) + res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) except: continue mask = self.mask_list_cycle[idx] mask_crop_box = self.mask_coords_list_cycle[idx] - # combine_frame = get_image(ori_frame,res_frame,bbox) - # t=time.perf_counter() - combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box) - # print('blending time:',time.perf_counter()-t) + #combine_frame = get_image(ori_frame,res_frame,bbox) + #t=time.perf_counter() + combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box) + #print('blending time:',time.perf_counter()-t) - image = combine_frame # (outputs['image'] * 255).astype(np.uint8) + image = combine_frame #(outputs['image'] * 255).astype(np.uint8) new_frame = VideoFrame.from_ndarray(image, format="bgr24") - asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) + asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) for audio_frame in audio_frames: - frame, type = audio_frame + frame,type = audio_frame frame = (frame * 32767).astype(np.int16) new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) new_frame.planes[0].update(frame.tobytes()) - new_frame.sample_rate = 16000 + new_frame.sample_rate=16000 # if audio_track._queue.qsize()>10: # 
time.sleep(0.1) asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop) - print('musereal process_frames thread stop') - - def render(self, quit_event, loop=None, audio_track=None, video_track=None): - # if self.opt.asr: + print('musereal process_frames thread stop') + + def render(self,quit_event,loop=None,audio_track=None,video_track=None): + #if self.opt.asr: # self.asr.warm_up() self.tts.render(quit_event) - process_thread = Thread(target=self.process_frames, args=(quit_event, loop, audio_track, video_track)) + process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track)) process_thread.start() - self.render_event.set() # start infer process render - count = 0 - totaltime = 0 - _starttime = time.perf_counter() - # _totalframe=0 - while not quit_event.is_set(): # todo + self.render_event.set() #start infer process render + count=0 + totaltime=0 + _starttime=time.perf_counter() + #_totalframe=0 + while not quit_event.is_set(): #todo # update texture every frame # audio stream thread... t = time.perf_counter() self.asr.run_step() - # self.test_step(loop,audio_track,video_track) + #self.test_step(loop,audio_track,video_track) # totaltime += (time.perf_counter() - t) # count += self.opt.batch_size # if count>=100: # print(f"------actual avg infer fps:{count/totaltime:.4f}") # count=0 # totaltime=0 - if video_track._queue.qsize() >= 2 * self.opt.batch_size: - print('sleep qsize=', video_track._queue.qsize()) - time.sleep(0.04 * self.opt.batch_size * 1.5) - + if video_track._queue.qsize()>=1.5*self.opt.batch_size: + print('sleep qsize=',video_track._queue.qsize()) + time.sleep(0.04*video_track._queue.qsize()*0.8) + # if video_track._queue.qsize()>=5: + # print('sleep qsize=',video_track._queue.qsize()) + # time.sleep(0.04*video_track._queue.qsize()*0.8) + # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms # if delay > 0: # time.sleep(delay) - self.render_event.clear() # end infer process render + self.render_event.clear() #end infer process render print('musereal thread stop') + \ No newline at end of file diff --git a/musetalk/utils/face_parsing/__init__.py b/musetalk/utils/face_parsing/__init__.py index 003147f..520593e 100755 --- a/musetalk/utils/face_parsing/__init__.py +++ b/musetalk/utils/face_parsing/__init__.py @@ -7,14 +7,15 @@ from PIL import Image from .model import BiSeNet import torchvision.transforms as transforms - class FaceParsing(): - def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', - model_pth='./models/face-parse-bisent/79999_iter.pth'): + def __init__(self,resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', + model_pth='./models/face-parse-bisent/79999_iter.pth'): self.net = self.model_init(resnet_path,model_pth) self.preprocess = self.image_preprocess() - def model_init(self,resnet_path, model_pth): + def model_init(self, + resnet_path, + model_pth): net = BiSeNet(resnet_path) if torch.cuda.is_available(): net.cuda() @@ -44,13 +45,13 @@ class FaceParsing(): img = torch.unsqueeze(img, 0) out = self.net(img)[0] parsing = out.squeeze(0).cpu().numpy().argmax(0) - parsing[np.where(parsing > 13)] = 0 - parsing[np.where(parsing >= 1)] = 255 + parsing[np.where(parsing>13)] = 0 + parsing[np.where(parsing>=1)] = 255 parsing = Image.fromarray(parsing.astype(np.uint8)) return parsing - if __name__ == "__main__": fp = FaceParsing() segmap = fp('154_small.png') segmap.save('res.png') + diff --git a/nerfreal.py b/nerfreal.py index 5ee2365..ef04c3e 100644 --- a/nerfreal.py 
+++ b/nerfreal.py @@ -20,9 +20,6 @@ class NeRFReal: self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. self.W = opt.W self.H = opt.H - self.debug = debug - self.training = False - self.step = 0 # training step self.trainer = trainer self.data_loader = data_loader @@ -44,7 +41,6 @@ class NeRFReal: #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item() # playing seq from dataloader, or pause. - self.playing = True #False todo self.loader = iter(data_loader) #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) @@ -62,9 +58,8 @@ class NeRFReal: self.customimg_index = 0 # build asr - if self.opt.asr: - self.asr = ASR(opt) - self.asr.warm_up() + self.asr = ASR(opt) + self.asr.warm_up() if opt.tts == "edgetts": self.tts = EdgeTTS(opt,self) elif opt.tts == "gpt-sovits": @@ -122,7 +117,11 @@ class NeRFReal: self.tts.put_msg_txt(msg) def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.asr.put_audio_frame(audio_chunk) + self.asr.put_audio_frame(audio_chunk) + + def pause_talk(self): + self.tts.pause_talk() + self.asr.pause_talk() def mirror_index(self, index): @@ -248,10 +247,9 @@ class NeRFReal: # update texture every frame # audio stream thread... t = time.perf_counter() - if self.opt.asr and self.playing: - # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) - for _ in range(2): - self.asr.run_step() + # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) + for _ in range(2): + self.asr.run_step() self.test_step(loop,audio_track,video_track) totaltime += (time.perf_counter() - t) count += 1 @@ -267,7 +265,7 @@ class NeRFReal: else: if video_track._queue.qsize()>=5: #print('sleep qsize=',video_track._queue.qsize()) - time.sleep(0.1) + time.sleep(0.04*video_track._queue.qsize()*0.8) print('nerfreal thread stop') \ No newline at end of file diff --git a/ttsreal.py b/ttsreal.py index 22e2909..e5c2c7f 100644 --- a/ttsreal.py +++ b/ttsreal.py @@ -13,6 +13,11 @@ import queue from queue import Queue from io import BytesIO from threading import Thread, Event +from enum import Enum + +class State(Enum): + RUNNING=0 + PAUSE=1 class BaseTTS: def __init__(self, opt, parent): @@ -25,6 +30,11 @@ class BaseTTS: self.input_stream = BytesIO() self.msgqueue = Queue() + self.state = State.RUNNING + + def pause_talk(self): + self.msgqueue.queue.clear() + self.state = State.PAUSE def put_msg_txt(self,msg): self.msgqueue.put(msg) @@ -37,6 +47,7 @@ class BaseTTS: while not quit_event.is_set(): try: msg = self.msgqueue.get(block=True, timeout=1) + self.state=State.RUNNING except queue.Empty: continue self.txt_to_audio(msg) @@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS): stream = self.__create_bytes_stream(self.input_stream) streamlen = stream.shape[0] idx=0 - while streamlen >= self.chunk: + while streamlen >= self.chunk and self.state==State.RUNNING: self.parent.put_audio_frame(stream[idx:idx+self.chunk]) streamlen -= self.chunk idx += self.chunk @@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS): async for chunk in communicate.stream(): if first: first = False - if chunk["type"] == "audio": + if chunk["type"] == "audio" and self.state==State.RUNNING: #self.push_audio(chunk["data"]) self.input_stream.write(chunk["data"]) #file.write(chunk["data"]) @@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS): end = time.perf_counter() print(f"gpt_sovits Time to first chunk: {end-start}s") first = False - if chunk: + if chunk and self.state==State.RUNNING: yield chunk print("gpt_sovits response.elapsed:", res.elapsed) diff --git 
a/web/chat.html b/web/chat.html index aa189ac..c554aba 100644 --- a/web/chat.html +++ b/web/chat.html @@ -29,22 +29,22 @@ $(document).ready(function() { var host = window.location.hostname - var ws = new WebSocket("ws://"+host+":8000/humanchat"); - //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]); - ws.onopen = function() { - console.log('Connected'); - }; - ws.onmessage = function(e) { - console.log('Received: ' + e.data); - data = e - var vid = JSON.parse(data.data); - console.log(typeof(vid),vid) - //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]); + // var ws = new WebSocket("ws://"+host+":8000/humanecho"); + // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]); + // ws.onopen = function() { + // console.log('Connected'); + // }; + // ws.onmessage = function(e) { + // console.log('Received: ' + e.data); + // data = e + // var vid = JSON.parse(data.data); + // console.log(typeof(vid),vid) + // //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]); - }; - ws.onclose = function(e) { - console.log('Closed'); - }; + // }; + // ws.onclose = function(e) { + // console.log('Closed'); + // }; flvPlayer = mpegts.createPlayer({type: 'flv', url: "http://"+host+":8080/live/livestream.flv", isLive: true, enableStashBuffer: false}); flvPlayer.attachMediaElement(document.getElementById('video_player')); @@ -55,9 +55,19 @@ e.preventDefault(); var message = $('#message').val(); console.log('Sending: ' + message); - ws.send(message); + fetch('/human', { + body: JSON.stringify({ + text: message, + type: 'chat', + }), + headers: { + 'Content-Type': 'application/json' + }, + method: 'POST' + }); + //ws.send(message); $('#message').val(''); - }); + }); }); \ No newline at end of file diff --git a/web/rtcpushchat.html b/web/rtcpushchat.html index 730541d..95e2319 100644 --- a/web/rtcpushchat.html +++ b/web/rtcpushchat.html @@ -51,30 +51,40 @@ From 79df82ebea8494007f13bcdfedc9372132143baf Mon Sep 17 00:00:00 2001 From: Yun <2289128964@qq.com> Date: Thu, 4 Jul 2024 09:46:42 +0800 Subject: [PATCH 5/5] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E6=88=90=E8=87=AA=E5=8A=A8=E7=BB=9D=E5=AF=B9=E8=B7=AF?= =?UTF-8?q?=E5=BE=84,=E6=B7=BB=E5=8A=A0=E6=8E=A5=E5=8F=A3=E7=94=9F?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- baseasr.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 baseasr.py diff --git a/baseasr.py b/baseasr.py new file mode 100644 index 0000000..df66873 --- /dev/null +++ b/baseasr.py @@ -0,0 +1,61 @@ +import time +import numpy as np + +import queue +from queue import Queue +import multiprocessing as mp + + +class BaseASR: + def __init__(self, opt): + self.opt = opt + + self.fps = opt.fps # 20 ms per frame + self.sample_rate = 16000 + self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) + self.queue = Queue() + self.output_queue = mp.Queue() + + self.batch_size = opt.batch_size + + self.frames = [] + self.stride_left_size = opt.l + self.stride_right_size = opt.r + #self.context_size = 10 + self.feat_queue = mp.Queue(2) + + #self.warm_up() + + def pause_talk(self): + self.queue.queue.clear() + + def put_audio_frame(self,audio_chunk): #16khz 20ms pcm + self.queue.put(audio_chunk) + + def get_audio_frame(self): + try: + frame = self.queue.get(block=True,timeout=0.01) + type = 0 + #print(f'[INFO] get frame 
{frame.shape}') + except queue.Empty: + frame = np.zeros(self.chunk, dtype=np.float32) + type = 1 + + return frame,type + + def get_audio_out(self): #get origin audio pcm to nerf + return self.output_queue.get() + + def warm_up(self): + for _ in range(self.stride_left_size + self.stride_right_size): + audio_frame,type=self.get_audio_frame() + self.frames.append(audio_frame) + self.output_queue.put((audio_frame,type)) + for _ in range(self.stride_left_size): + self.output_queue.get() + + def run_step(self): + pass + + def get_next_feat(self,block,timeout): + return self.feat_queue.get(block,timeout) \ No newline at end of file
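
A minimal sketch (not part of the patch series) of how the `BaseASR` class added in PATCH 5/5 is meant to be specialized. `MuseASR` and `LipASR` in the earlier patches follow the same pattern: subclass `BaseASR`, override `run_step()` to turn the buffered 20 ms PCM frames into model features, push raw PCM to `output_queue` for the a/v muxer, and push features to `feat_queue` for the render process. The `EchoASR` class and the `SimpleNamespace` options below are illustrative placeholders, not code from the repository.

```python
from types import SimpleNamespace

import numpy as np

from baseasr import BaseASR  # added by PATCH 5/5


class EchoASR(BaseASR):
    """Hypothetical front-end: forwards the raw PCM window as its 'features'."""

    def run_step(self):
        # pull batch_size*2 audio frames (20 ms each), mirroring LipASR/MuseASR
        for _ in range(self.batch_size * 2):
            frame, type = self.get_audio_frame()  # (chunk,) float32; type 1 == silence
            self.frames.append(frame)
            self.output_queue.put((frame, type))  # raw PCM for the audio track
        # hand the buffered window to the consumer side
        self.feat_queue.put(np.stack(self.frames))
        # keep only the stride context, as the real subclasses do
        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]


if __name__ == '__main__':
    # fps=50 gives 20 ms chunks at 16 kHz; l/r are the stride sizes passed via app.py
    opt = SimpleNamespace(fps=50, batch_size=4, l=10, r=10)
    asr = EchoASR(opt)
    asr.warm_up()  # pre-fill the left/right stride frames with silence
    asr.put_audio_frame(np.zeros(asr.chunk, dtype=np.float32))
    asr.run_step()
    print(asr.get_next_feat(block=True, timeout=1).shape)
```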