diff --git a/README.md b/README.md index b169183..a8a7c44 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,10 @@ Real time interactive streaming digital human, realize audio video synchronous ## Features 1. 支持多种数字人模型: ernerf、musetalk、wav2lip 2. 支持声音克隆 -3. 支持多种音频特征驱动:wav2vec、hubert +3. 支持数字人说话被打断 4. 支持全身视频拼接 5. 支持rtmp和webrtc 6. 支持视频编排:不说话时播放自定义视频 -7. 支持大模型对话 ## 1. Installation @@ -171,13 +170,11 @@ cd MuseTalk 修改configs/inference/realtime.yaml,将preparation改为True python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml 运行后将results/avatars下文件拷到本项目的data/avatars下 -``` - -```bash -也可以试用本地目录下的 simple_musetalk.py -cd musetalk -python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4 -运行后将直接生成在data/avatars下 +方法二 +执行 +cd musetalk +python simple_musetalk.py --avatar_id 4 --file D:\\ok\\test.mp4 +支持视频和图片生成 会自动生成到data的avatars目录下 ``` ### 3.10 模型用wav2lip @@ -185,7 +182,7 @@ python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4 - 下载模型 下载wav2lip运行需要的模型,网盘地址 https://drive.uc.cn/s/551be97d7cfa4 将s3fd.pth拷到本项目wav2lip/face_detection/detection/sfd/s3fd.pth, 将wav2lip.pth拷到本项目的models下 -数字人模型文件 wav2lip_avatar1.tar.gz, 解压后将整个文件夹拷到本项目的data/avatars下 +数字人模型文件 wav2lip_avatar1.tar.gz,网盘地址 https://drive.uc.cn/s/5bd0cde0b0774, 解压后将整个文件夹拷到本项目的data/avatars下 - 运行 python app.py --transport webrtc --model wav2lip --avatar_id wav2lip_avatar1 用浏览器打开http://serverip:8010/webrtcapi.html diff --git a/app.py b/app.py index 1102190..83a60ce 100644 --- a/app.py +++ b/app.py @@ -1,31 +1,39 @@ # server.py -import argparse -import asyncio -import json -import multiprocessing -from threading import Thread, Event - -import aiohttp -import aiohttp_cors -from aiohttp import web -from aiortc import RTCPeerConnection, RTCSessionDescription -from flask import Flask +from flask import Flask, render_template,send_from_directory,request, jsonify from flask_sockets import Sockets +import base64 +import time +import json +import gevent from gevent import pywsgi from geventwebsocket.handler import WebSocketHandler +import os +import re +import numpy as np +from threading import Thread,Event +import multiprocessing -from musetalk.simple_musetalk import create_musetalk_human +from aiohttp import web +import aiohttp +import aiohttp_cors +from aiortc import RTCPeerConnection, RTCSessionDescription from webrtc import HumanPlayer +import argparse + +import shutil +import asyncio + + app = Flask(__name__) sockets = Sockets(app) global nerfreal - + @sockets.route('/humanecho') def echo_socket(ws): # 获取WebSocket对象 - # ws = request.environ.get('wsgi.websocket') + #ws = request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -34,11 +42,11 @@ def echo_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if not message or len(message) == 0: + message = ws.receive() + + if not message or len(message)==0: return '输入信息为空' - else: + else: nerfreal.put_msg_txt(message) @@ -46,16 +54,15 @@ def llm_response(message): from llm.LLM import LLM # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None) # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key') - llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b') + llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b') response = llm.chat(message) print(response) return response - @sockets.route('/humanchat') def chat_socket(ws): # 获取WebSocket对象 - # ws = request.environ.get('wsgi.websocket') + #ws = 
request.environ.get('wsgi.websocket') # 如果没有获取到,返回错误信息 if not ws: print('未建立连接!') @@ -64,20 +71,18 @@ def chat_socket(ws): else: print('建立连接!') while True: - message = ws.receive() - - if len(message) == 0: + message = ws.receive() + + if len(message)==0: return '输入信息为空' else: - res = llm_response(message) + res=llm_response(message) nerfreal.put_msg_txt(res) - #####webrtc############################### pcs = set() - -# @app.route('/offer', methods=['POST']) +#@app.route('/offer', methods=['POST']) async def offer(request): params = await request.json() offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) @@ -101,7 +106,7 @@ async def offer(request): answer = await pc.createAnswer() await pc.setLocalDescription(answer) - # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) + #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) return web.Response( content_type="application/json", @@ -110,61 +115,39 @@ async def offer(request): ), ) - async def human(request): params = await request.json() - if params['type'] == 'echo': + if params.get('interrupt'): + nerfreal.pause_talk() + + if params['type']=='echo': nerfreal.put_msg_txt(params['text']) - elif params['type'] == 'chat': - res = await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) + elif params['type']=='chat': + res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text'])) nerfreal.put_msg_txt(res) return web.Response( content_type="application/json", text=json.dumps( - {"code": 0, "data": "ok"} + {"code": 0, "data":"ok"} ), ) - -async def handle_create_musetalk(request): - reader = await request.multipart() - # 处理文件部分 - file_part = await reader.next() - filename = file_part.filename - file_data = await file_part.read() # 读取文件的内容 - # 注意:确保这个文件路径是可写的 - with open(filename, 'wb') as f: - f.write(file_data) - # 处理整数部分 - part = await reader.next() - avatar_id = int(await part.text()) - create_musetalk_human(filename, avatar_id) - os.remove(filename) - return web.json_response({ - 'status': 'success', - 'filename': filename, - 'int_value': avatar_id, - }) - - async def on_shutdown(app): # close peer connections coros = [pc.close() for pc in pcs] await asyncio.gather(*coros) pcs.clear() - -async def post(url, data): +async def post(url,data): try: async with aiohttp.ClientSession() as session: - async with session.post(url, data=data) as response: + async with session.post(url,data=data) as response: return await response.text() except aiohttp.ClientError as e: print(f'Error: {e}') - async def run(push_url): pc = RTCPeerConnection() pcs.add(pc) @@ -181,10 +164,8 @@ async def run(push_url): video_sender = pc.addTrack(player.video) await pc.setLocalDescription(await pc.createOffer()) - answer = await post(push_url, pc.localDescription.sdp) - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer')) - - + answer = await post(push_url,pc.localDescription.sdp) + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer')) ########################################## # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver' @@ -203,20 +184,14 @@ if __name__ == '__main__': ### training options parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth') - - parser.add_argument('--num_rays', type=int, default=4096 * 16, - help="num rays sampled per image for each training step") + + parser.add_argument('--num_rays', type=int, 
default=4096 * 16, help="num rays sampled per image for each training step") parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") - parser.add_argument('--max_steps', type=int, default=16, - help="max num steps sampled per ray (only valid when using --cuda_ray)") - parser.add_argument('--num_steps', type=int, default=16, - help="num steps sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--upsample_steps', type=int, default=0, - help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") - parser.add_argument('--update_extra_interval', type=int, default=16, - help="iter interval to update extra status (only valid when using --cuda_ray)") - parser.add_argument('--max_ray_batch', type=int, default=4096, - help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") ### loss set parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps") @@ -227,35 +202,27 @@ if __name__ == '__main__': ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") - + parser.add_argument('--bg_img', type=str, default='white', help="background image") parser.add_argument('--fbg', action='store_true', help="frame-wise bg") parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") - parser.add_argument('--fix_eye', type=float, default=-1, - help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") + parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") - parser.add_argument('--torso_shrink', type=float, default=0.8, - help="shrink bg coords to allow more flexibility in deform") + parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") ### dataset options parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") - parser.add_argument('--preload', type=int, default=0, - help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") + parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") # (the default value is for the fox dataset) - parser.add_argument('--bound', type=float, default=1, - help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") 
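The reworked `/human` handler earlier in this patch now accepts an optional `interrupt` flag and calls `nerfreal.pause_talk()` before queuing new text. A minimal client sketch for exercising it is below; it assumes the aiohttp app is listening on the default `--listenport 8010`, and the host, text, and flag values are illustrative, not part of the patch.

```python
# Minimal client sketch for the /human endpoint shown in this patch.
# Assumes the server runs on the default --listenport 8010; the host and
# message strings are placeholders.
import requests

BASE = "http://127.0.0.1:8010"

# Echo mode: the text is fed straight to TTS and spoken by the avatar.
requests.post(f"{BASE}/human", json={"type": "echo", "text": "Hello, this is a test."})

# Chat mode with interrupt: stop whatever is being spoken, then let the LLM answer.
requests.post(f"{BASE}/human", json={"type": "chat", "interrupt": True, "text": "Let's change the topic."})
```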
parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") - parser.add_argument('--dt_gamma', type=float, default=1 / 256, - help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") - parser.add_argument('--density_thresh', type=float, default=10, - help="threshold for density grid to be occupied (sigma)") - parser.add_argument('--density_thresh_torso', type=float, default=0.01, - help="threshold for density grid to be occupied (alpha)") - parser.add_argument('--patch_size', type=int, default=1, - help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") + parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") parser.add_argument('--init_lips', action='store_true', help="init lips region") parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region") @@ -273,15 +240,12 @@ if __name__ == '__main__': parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") ### else - parser.add_argument('--att', type=int, default=2, - help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") - parser.add_argument('--aud', type=str, default='', - help="audio source (empty will load the default, else should be a path to a npy file)") + parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") + parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)") parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") - parser.add_argument('--ind_num', type=int, default=10000, - help="number of individual codes, should be larger than training dataset size") + parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") @@ -290,8 +254,7 @@ if __name__ == '__main__': parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") - parser.add_argument('--smooth_path', action='store_true', - help="brute-force smooth camera pose trajectory with a window size") + parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose 
trajectory with a window size") parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") # asr @@ -299,8 +262,8 @@ if __name__ == '__main__': parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") parser.add_argument('--asr_play', action='store_true', help="play out the audio") - # parser.add_argument('--asr_model', type=str, default='deepspeech') - parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # + #parser.add_argument('--asr_model', type=str, default='deepspeech') + parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') # # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft') @@ -319,45 +282,42 @@ if __name__ == '__main__': parser.add_argument('--fullbody_offset_x', type=int, default=0) parser.add_argument('--fullbody_offset_y', type=int, default=0) - # musetalk opt + #musetalk opt parser.add_argument('--avatar_id', type=str, default='avator_1') parser.add_argument('--bbox_shift', type=int, default=5) parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--customvideo', action='store_true', help="custom video") - parser.add_argument('--static_img', action='store_true', help="Use the first photo as a time of rest") parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') parser.add_argument('--customvideo_imgnum', type=int, default=1) - parser.add_argument('--tts', type=str, default='edgetts') # xtts gpt-sovits + parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits parser.add_argument('--REF_FILE', type=str, default=None) parser.add_argument('--REF_TEXT', type=str, default=None) - parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 + parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000 # parser.add_argument('--CHARACTER', type=str, default='test') # parser.add_argument('--EMOTION', type=str, default='default') - parser.add_argument('--model', type=str, default='ernerf') # musetalk wav2lip + parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip - parser.add_argument('--transport', type=str, default='rtcpush') # rtmp webrtc rtcpush - parser.add_argument('--push_url', type=str, - default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream + parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush + parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream parser.add_argument('--listenport', type=int, default=8010) opt = parser.parse_args() - # app.config.from_object(opt) - # print(app.config) + #app.config.from_object(opt) + #print(app.config) if opt.model == 'ernerf': from ernerf.nerf_triplane.provider import NeRFDataset_Test from ernerf.nerf_triplane.utils import * from ernerf.nerf_triplane.network import NeRFNetwork from nerfreal import NeRFReal - # assert test mode opt.test = True opt.test_train = False - # opt.train_camera =True + #opt.train_camera =True # explicit smoothing opt.smooth_path = True opt.smooth_lips = True @@ -370,7 +330,7 @@ if __name__ == '__main__': opt.exp_eye = True opt.smooth_eye = True - if opt.torso_imgs == '': # 
no img,use model output + if opt.torso_imgs=='': #no img,use model output opt.torso = True # assert opt.cuda_ray, "Only support CUDA ray mode." @@ -386,10 +346,9 @@ if __name__ == '__main__': model = NeRFNetwork(opt) criterion = torch.nn.MSELoss(reduction='none') - metrics = [] # use no metric in GUI for faster initialization... + metrics = [] # use no metric in GUI for faster initialization... print(model) - trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, - metrics=metrics, use_checkpoint=opt.ckpt) + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) test_loader = NeRFDataset_Test(opt, device=device).dataloader() model.aud_features = test_loader._data.auds @@ -399,19 +358,17 @@ if __name__ == '__main__': nerfreal = NeRFReal(opt, trainer, test_loader) elif opt.model == 'musetalk': from musereal import MuseReal - print(opt) nerfreal = MuseReal(opt) elif opt.model == 'wav2lip': from lipreal import LipReal - print(opt) nerfreal = LipReal(opt) - # txt_to_audio('我是中国人,我来自北京') - if opt.transport == 'rtmp': + #txt_to_audio('我是中国人,我来自北京') + if opt.transport=='rtmp': thread_quit = Event() - rendthrd = Thread(target=nerfreal.render, args=(thread_quit,)) + rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) rendthrd.start() ############################################################################# @@ -419,37 +376,35 @@ if __name__ == '__main__': appasync.on_shutdown.append(on_shutdown) appasync.router.add_post("/offer", offer) appasync.router.add_post("/human", human) - appasync.router.add_post("/create_musetalk", handle_create_musetalk) - appasync.router.add_static('/', path='web') + appasync.router.add_static('/',path='web') # Configure default CORS settings. cors = aiohttp_cors.setup(appasync, defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - ) - }) + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + ) + }) # Configure CORS on all routes. 
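The startup code in this hunk runs the aiohttp application (the WebRTC `offer` route, `/human`, and the static `web/` directory) in a background thread with its own event loop, while the main thread serves the Flask websocket app through gevent on port 8000. A stripped-down, self-contained sketch of that dual-server pattern follows; the handler name is a placeholder and only the port numbers mirror the defaults in this patch.

```python
# Stripped-down sketch of the dual-server pattern used in app.py: an aiohttp
# app on its own event loop in a daemon thread, while the main thread blocks
# elsewhere (in app.py, in the gevent WSGIServer on port 8000).
import asyncio
from threading import Thread, Event
from aiohttp import web

async def human_stub(request):
    # placeholder handler standing in for the real /human coroutine
    return web.json_response({"code": 0, "data": "ok"})

def run_aiohttp(runner: web.AppRunner, port: int):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(runner.setup())
    site = web.TCPSite(runner, "0.0.0.0", port)
    loop.run_until_complete(site.start())
    loop.run_forever()

app = web.Application()
app.router.add_post("/human", human_stub)
Thread(target=run_aiohttp, args=(web.AppRunner(app), 8010), daemon=True).start()

Event().wait()  # stand-in for the blocking websocket server in app.py
```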
for route in list(appasync.router.routes()): cors.add(route) - def run_server(runner): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(runner.setup()) site = web.TCPSite(runner, '0.0.0.0', opt.listenport) loop.run_until_complete(site.start()) - if opt.transport == 'rtcpush': + if opt.transport=='rtcpush': loop.run_until_complete(run(opt.push_url)) - loop.run_forever() - - + loop.run_forever() Thread(target=run_server, args=(web.AppRunner(appasync),)).start() print('start websocket server') - # app.on_shutdown.append(on_shutdown) - # app.router.add_post("/offer", offer) + #app.on_shutdown.append(on_shutdown) + #app.router.add_post("/offer", offer) server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler) server.serve_forever() + + \ No newline at end of file diff --git a/asrreal.py b/asrreal.py index b3e4093..62aa15e 100644 --- a/asrreal.py +++ b/asrreal.py @@ -4,29 +4,19 @@ import torch import torch.nn.functional as F from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel -#import pyaudio -import soundfile as sf -import resampy import queue from queue import Queue #from collections import deque from threading import Thread, Event -from io import BytesIO -class ASR: +from baseasr import BaseASR + +class ASR(BaseASR): def __init__(self, opt): - - self.opt = opt - - self.play = opt.asr_play #false + super().__init__(opt) self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.fps = opt.fps # 20 ms per frame - self.sample_rate = 16000 - self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) - self.mode = 'live' if opt.asr_wav == '' else 'file' - if 'esperanto' in self.opt.asr_model: self.audio_dim = 44 elif 'deepspeech' in self.opt.asr_model: @@ -41,30 +31,11 @@ class ASR: self.context_size = opt.m self.stride_left_size = opt.l self.stride_right_size = opt.r - self.text = '[START]\n' - self.terminated = False - self.frames = [] - self.inwarm = False # pad left frames if self.stride_left_size > 0: self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) - - self.exit_event = Event() - #self.audio_instance = pyaudio.PyAudio() #not need - - # create input stream - self.queue = Queue() - self.output_queue = Queue() - # start a background process to read frames - #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) - #self.queue = Queue() - #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) - - # current location of audio - self.idx = 0 - # create wav2vec model print(f'[INFO] loading ASR model {self.opt.asr_model}...') if 'hubert' in self.opt.asr_model: @@ -74,10 +45,6 @@ class ASR: self.processor = AutoProcessor.from_pretrained(opt.asr_model) self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) - # prepare to save logits - if self.opt.asr_save_feats: - self.all_feats = [] - # the extracted features # use a loop queue to efficiently record endless features: [f--t---][-------][-------] self.feat_buffer_size = 4 @@ -93,8 +60,16 @@ class ASR: # warm up steps needed: mid + right + window_size + attention_size self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3 - self.listening = False - self.playing = False + def get_audio_frame(self): + try: + frame = 
self.queue.get(block=False) + type = 0 + #print(f'[INFO] get frame {frame.shape}') + except queue.Empty: + frame = np.zeros(self.chunk, dtype=np.float32) + type = 1 + + return frame,type def get_next_feat(self): #get audio embedding to nerf # return a [1/8, 16] window, for the next input to nerf side. @@ -136,29 +111,19 @@ class ASR: def run_step(self): - if self.terminated: - return - # get a frame of audio - frame,type = self.__get_audio_frame() - - # the last frame - if frame is None: - # terminate, but always run the network for the left frames - self.terminated = True - else: - self.frames.append(frame) - # put to output - self.output_queue.put((frame,type)) - # context not enough, do not run network. - if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: - return + frame,type = self.get_audio_frame() + self.frames.append(frame) + # put to output + self.output_queue.put((frame,type)) + # context not enough, do not run network. + if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: + return inputs = np.concatenate(self.frames) # [N * chunk] # discard the old part to save memory - if not self.terminated: - self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] + self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] #print(f'[INFO] frame_to_text... ') #t = time.time() @@ -166,10 +131,6 @@ class ASR: #print(f'-------wav2vec time:{time.time()-t:.4f}s') feats = logits # better lips-sync than labels - # save feats - if self.opt.asr_save_feats: - self.all_feats.append(feats) - # record the feats efficiently.. (no concat, constant memory) start = self.feat_buffer_idx * self.context_size end = start + feats.shape[0] @@ -203,24 +164,6 @@ class ASR: # np.save(output_path, unfold_feats.cpu().numpy()) # print(f"[INFO] saved logits to {output_path}") - def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.queue.put(audio_chunk) - - def __get_audio_frame(self): - if self.inwarm: # warm up - return np.zeros(self.chunk, dtype=np.float32),1 - - try: - frame = self.queue.get(block=False) - type = 0 - print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) - type = 1 - - self.idx = self.idx + self.chunk - - return frame,type def __frame_to_text(self, frame): @@ -241,8 +184,8 @@ class ASR: right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. # do not cut right if terminated. 
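The ASR classes in this patch (ASR here, and LipASR and MuseASR further down) now inherit from a new `baseasr.BaseASR`, which the diff imports but does not include. Purely as an inference from the attributes and methods those subclasses rely on (`queue`, `output_queue`, `frames`, `chunk`, `feat_queue`, `put_audio_frame`, `get_audio_frame`, `get_audio_out`, `warm_up`, `pause_talk`), a hypothetical minimal sketch of that base class is given below; the real file may differ.

```python
# Hypothetical sketch of baseasr.BaseASR, which this patch imports but does not
# include. Everything here is inferred from how ASR/LipASR/MuseASR use the base
# class elsewhere in the diff; the actual implementation may differ.
import queue
from queue import Queue
import multiprocessing as mp
import numpy as np

class BaseASR:
    def __init__(self, opt):
        self.opt = opt
        self.fps = opt.fps                          # e.g. 50 -> one 20 ms frame per tick
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps   # 320 samples per 20 ms chunk
        self.queue = Queue()                        # incoming 16 kHz PCM chunks
        self.output_queue = mp.Queue()              # (frame, type) pairs for the audio track
        self.batch_size = opt.batch_size
        self.frames = []
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r
        self.feat_queue = mp.Queue(5)               # audio features for the video model

    def pause_talk(self):
        # drop any queued audio so the avatar stops speaking promptly
        self.queue.queue.clear()

    def put_audio_frame(self, audio_chunk):         # 16 kHz, 20 ms PCM
        self.queue.put(audio_chunk)

    def get_audio_frame(self):
        try:
            frame = self.queue.get(block=True, timeout=0.01)
            type = 0                                # real audio
        except queue.Empty:
            frame = np.zeros(self.chunk, dtype=np.float32)
            type = 1                                # silence filler
        return frame, type

    def get_audio_out(self):                        # original PCM, replayed on the audio track
        return self.output_queue.get()

    def warm_up(self):
        for _ in range(self.stride_left_size + self.stride_right_size):
            audio_frame, type = self.get_audio_frame()
            self.frames.append(audio_frame)
            self.output_queue.put((audio_frame, type))
        for _ in range(self.stride_left_size):
            self.output_queue.get()

    def get_next_feat(self, block, timeout):
        return self.feat_queue.get(block, timeout)
```

Note that the ASR class in this hunk overrides `get_audio_frame` with a non-blocking `queue.get(block=False)`, so the blocking variant above is only an assumption about the shared default.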
- if self.terminated: - right = logits.shape[1] + # if self.terminated: + # right = logits.shape[1] logits = logits[:, left:right] @@ -262,10 +205,23 @@ class ASR: return logits[0], None,None #predicted_ids[0], transcription # [N,] - - def get_audio_out(self): #get origin audio pcm to nerf - return self.output_queue.get() - + + def warm_up(self): + print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') + t = time.time() + #for _ in range(self.stride_left_size): + # self.frames.append(np.zeros(self.chunk, dtype=np.float32)) + for _ in range(self.warm_up_steps): + self.run_step() + #if torch.cuda.is_available(): + # torch.cuda.synchronize() + t = time.time() - t + print(f'[INFO] warm-up done, actual latency = {t:.6f}s') + + #self.clear_queue() + + #####not used function##################################### + ''' def __init_queue(self): self.frames = [] self.queue.queue.clear() @@ -290,26 +246,6 @@ class ASR: if self.play: self.output_queue.queue.clear() - def warm_up(self): - - #self.listen() - - self.inwarm = True - print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') - t = time.time() - #for _ in range(self.stride_left_size): - # self.frames.append(np.zeros(self.chunk, dtype=np.float32)) - for _ in range(self.warm_up_steps): - self.run_step() - #if torch.cuda.is_available(): - # torch.cuda.synchronize() - t = time.time() - t - print(f'[INFO] warm-up done, actual latency = {t:.6f}s') - self.inwarm = False - - #self.clear_queue() - - #####not used function##################################### def listen(self): # start if self.mode == 'live' and not self.listening: @@ -404,4 +340,5 @@ if __name__ == '__main__': raise ValueError("DeepSpeech features should not use this code to extract...") with ASR(opt) as asr: - asr.run() \ No newline at end of file + asr.run() +''' \ No newline at end of file diff --git a/lipasr.py b/lipasr.py index 5742dd7..29948ac 100644 --- a/lipasr.py +++ b/lipasr.py @@ -6,60 +6,16 @@ import queue from queue import Queue import multiprocessing as mp +from baseasr import BaseASR from wav2lip import audio -class LipASR: - def __init__(self, opt): - self.opt = opt - - self.fps = opt.fps # 20 ms per frame - self.sample_rate = 16000 - self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) - self.queue = Queue() - # self.input_stream = BytesIO() - self.output_queue = mp.Queue() - - #self.audio_processor = audio_processor - self.batch_size = opt.batch_size - - self.frames = [] - self.stride_left_size = opt.l - self.stride_right_size = opt.r - #self.context_size = 10 - self.feat_queue = mp.Queue(5) - - self.warm_up() - - def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.queue.put(audio_chunk) - - def __get_audio_frame(self): - try: - frame = self.queue.get(block=True,timeout=0.01) - type = 0 - #print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) - type = 1 - - return frame,type - - def get_audio_out(self): #get origin audio pcm to nerf - return self.output_queue.get() - - def warm_up(self): - for _ in range(self.stride_left_size + self.stride_right_size): - audio_frame,type=self.__get_audio_frame() - self.frames.append(audio_frame) - self.output_queue.put((audio_frame,type)) - for _ in range(self.stride_left_size): - self.output_queue.get() +class LipASR(BaseASR): def run_step(self): ############################################## extract audio feature 
############################################## # get a frame of audio for _ in range(self.batch_size*2): - frame,type = self.__get_audio_frame() + frame,type = self.get_audio_frame() self.frames.append(frame) # put to output self.output_queue.put((frame,type)) @@ -89,7 +45,3 @@ class LipASR: # discard the old part to save memory self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] - - - def get_next_feat(self,block,timeout): - return self.feat_queue.get(block,timeout) \ No newline at end of file diff --git a/lipreal.py b/lipreal.py index d69f3dc..9461e7b 100644 --- a/lipreal.py +++ b/lipreal.py @@ -164,6 +164,7 @@ class LipReal: self.__loadavatar() self.asr = LipASR(opt) + self.asr.warm_up() if opt.tts == "edgetts": self.tts = EdgeTTS(opt,self) elif opt.tts == "gpt-sovits": @@ -199,6 +200,10 @@ class LipReal: def put_audio_frame(self,audio_chunk): #16khz 20ms pcm self.asr.put_audio_frame(audio_chunk) + + def pause_talk(self): + self.tts.pause_talk() + self.asr.pause_talk() def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None): @@ -257,9 +262,12 @@ class LipReal: t = time.perf_counter() self.asr.run_step() - if video_track._queue.qsize()>=2*self.opt.batch_size: + # if video_track._queue.qsize()>=2*self.opt.batch_size: + # print('sleep qsize=',video_track._queue.qsize()) + # time.sleep(0.04*video_track._queue.qsize()*0.8) + if video_track._queue.qsize()>=5: print('sleep qsize=',video_track._queue.qsize()) - time.sleep(0.04*self.opt.batch_size*1.5) + time.sleep(0.04*video_track._queue.qsize()*0.8) # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms # if delay > 0: diff --git a/museasr.py b/museasr.py index cfcd9ba..cb166a6 100644 --- a/museasr.py +++ b/museasr.py @@ -1,65 +1,22 @@ import time -import torch import numpy as np import queue from queue import Queue import multiprocessing as mp - +from baseasr import BaseASR from musetalk.whisper.audio2feature import Audio2Feature -class MuseASR: +class MuseASR(BaseASR): def __init__(self, opt, audio_processor:Audio2Feature): - self.opt = opt - - self.fps = opt.fps # 20 ms per frame - self.sample_rate = 16000 - self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) - self.queue = Queue() - # self.input_stream = BytesIO() - self.output_queue = mp.Queue() - + super().__init__(opt) self.audio_processor = audio_processor - self.batch_size = opt.batch_size - - self.frames = [] - self.stride_left_size = opt.l - self.stride_right_size = opt.r - self.feat_queue = mp.Queue(5) - - self.warm_up() - - def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.queue.put(audio_chunk) - - def __get_audio_frame(self): - try: - frame = self.queue.get(block=True,timeout=0.01) - type = 0 - #print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) - type = 1 - - return frame,type - - def get_audio_out(self): #get origin audio pcm to nerf - return self.output_queue.get() - - def warm_up(self): - for _ in range(self.stride_left_size + self.stride_right_size): - audio_frame,type=self.__get_audio_frame() - self.frames.append(audio_frame) - self.output_queue.put((audio_frame,type)) - - for _ in range(self.stride_left_size): - self.output_queue.get() def run_step(self): ############################################## extract audio feature ############################################## start_time = time.time() for _ in range(self.batch_size*2): - audio_frame,type=self.__get_audio_frame() + 
audio_frame,type=self.get_audio_frame() self.frames.append(audio_frame) self.output_queue.put((audio_frame,type)) @@ -77,6 +34,3 @@ class MuseASR: self.feat_queue.put(whisper_chunks) # discard the old part to save memory self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] - - def get_next_feat(self,block,timeout): - return self.feat_queue.get(block,timeout) \ No newline at end of file diff --git a/musereal.py b/musereal.py index 7e58029..0f01dcf 100644 --- a/musereal.py +++ b/musereal.py @@ -2,7 +2,7 @@ import math import torch import numpy as np -# from .utils import * +#from .utils import * import subprocess import os import time @@ -18,19 +18,17 @@ from threading import Thread, Event from io import BytesIO import multiprocessing as mp -from musetalk.utils.utils import get_file_type, get_video_fps, datagen -# from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder -from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending -from musetalk.utils.utils import load_all_model, load_diffusion_model, load_audio_model -from ttsreal import EdgeTTS, VoitsTTS, XTTS +from musetalk.utils.utils import get_file_type,get_video_fps,datagen +#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder +from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending +from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model +from ttsreal import EdgeTTS,VoitsTTS,XTTS from museasr import MuseASR import asyncio from av import AudioFrame, VideoFrame from tqdm import tqdm - - def read_imgs(img_list): frames = [] print('reading images...') @@ -39,146 +37,143 @@ def read_imgs(img_list): frames.append(frame) return frames - def __mirror_index(size, index): - # size = len(self.coord_list_cycle) + #size = len(self.coord_list_cycle) turn = index // size res = index % size if turn % 2 == 0: return res else: - return size - res - 1 - - -def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue, - ): # vae, unet, pe,timesteps + return size - res - 1 +def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue, + ): #vae, unet, pe,timesteps + vae, unet, pe = load_diffusion_model() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") timesteps = torch.tensor([0], device=device) pe = pe.half() vae.vae = vae.vae.half() unet.model = unet.model.half() - + input_latent_list_cycle = torch.load(latents_out_path) length = len(input_latent_list_cycle) index = 0 - count = 0 - counttime = 0 + count=0 + counttime=0 print('start inference') while True: if render_event.is_set(): - starttime = time.perf_counter() + starttime=time.perf_counter() try: whisper_chunks = audio_feat_queue.get(block=True, timeout=1) except queue.Empty: continue - is_all_silence = True + is_all_silence=True audio_frames = [] - for _ in range(batch_size * 2): - frame, type = audio_out_queue.get() - audio_frames.append((frame, type)) - if type == 0: - is_all_silence = False + for _ in range(batch_size*2): + frame,type = audio_out_queue.get() + audio_frames.append((frame,type)) + if type==0: + is_all_silence=False if is_all_silence: for i in range(batch_size): - res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) + res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2])) index = index + 1 else: # 
print('infer=======') - t = time.perf_counter() + t=time.perf_counter() whisper_batch = np.stack(whisper_chunks) latent_batch = [] for i in range(batch_size): - idx = __mirror_index(length, index + i) + idx = __mirror_index(length,index+i) latent = input_latent_list_cycle[idx] latent_batch.append(latent) latent_batch = torch.cat(latent_batch, dim=0) - + # for i, (whisper_batch,latent_batch) in enumerate(gen): audio_feature_batch = torch.from_numpy(whisper_batch) audio_feature_batch = audio_feature_batch.to(device=unet.device, - dtype=unet.model.dtype) + dtype=unet.model.dtype) audio_feature_batch = pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=unet.model.dtype) # print('prepare time:',time.perf_counter()-t) # t=time.perf_counter() - pred_latents = unet.model(latent_batch, - timesteps, - encoder_hidden_states=audio_feature_batch).sample + pred_latents = unet.model(latent_batch, + timesteps, + encoder_hidden_states=audio_feature_batch).sample # print('unet time:',time.perf_counter()-t) # t=time.perf_counter() recon = vae.decode_latents(pred_latents) # print('vae time:',time.perf_counter()-t) - # print('diffusion len=',len(recon)) + #print('diffusion len=',len(recon)) counttime += (time.perf_counter() - t) count += batch_size - # _totalframe += 1 - if count >= 100: - print(f"------actual avg infer fps:{count / counttime:.4f}") - count = 0 - counttime = 0 - for i, res_frame in enumerate(recon): - # self.__pushmedia(res_frame,loop,audio_track,video_track) - res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) + #_totalframe += 1 + if count>=100: + print(f"------actual avg infer fps:{count/counttime:.4f}") + count=0 + counttime=0 + for i,res_frame in enumerate(recon): + #self.__pushmedia(res_frame,loop,audio_track,video_track) + res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2])) index = index + 1 - # print('total batch time:',time.perf_counter()-starttime) + #print('total batch time:',time.perf_counter()-starttime) else: time.sleep(1) print('musereal inference processor stop') - @torch.no_grad() class MuseReal: def __init__(self, opt): - self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. 
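The `__mirror_index` helper defined in this musereal.py hunk (and its method twin inside `MuseReal`) walks the avatar frame list forward and then backward, so playback loops without a visible jump at the wrap-around point. A standalone copy is shown below purely to illustrate the indexing behaviour.

```python
# Standalone demonstration of the mirror-index ("ping-pong") pattern used in
# musereal.py: indices run 0..size-1, then back down, then up again.
def mirror_index(size, index):
    turn = index // size
    res = index % size
    return res if turn % 2 == 0 else size - res - 1

print([mirror_index(4, i) for i in range(10)])
# -> [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
```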
self.W = opt.W self.H = opt.H - self.fps = opt.fps # 20 ms per frame + self.fps = opt.fps # 20 ms per frame #### musetalk self.avatar_id = opt.avatar_id - self.static_img = opt.static_img - self.video_path = '' # video_path + self.video_path = '' #video_path self.bbox_shift = opt.bbox_shift self.avatar_path = f"./data/avatars/{self.avatar_id}" - self.full_imgs_path = f"{self.avatar_path}/full_imgs" + self.full_imgs_path = f"{self.avatar_path}/full_imgs" self.coords_path = f"{self.avatar_path}/coords.pkl" - self.latents_out_path = f"{self.avatar_path}/latents.pt" + self.latents_out_path= f"{self.avatar_path}/latents.pt" self.video_out_path = f"{self.avatar_path}/vid_output/" - self.mask_out_path = f"{self.avatar_path}/mask" - self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl" + self.mask_out_path =f"{self.avatar_path}/mask" + self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl" self.avatar_info_path = f"{self.avatar_path}/avator_info.json" self.avatar_info = { - "avatar_id": self.avatar_id, - "video_path": self.video_path, - "bbox_shift": self.bbox_shift + "avatar_id":self.avatar_id, + "video_path":self.video_path, + "bbox_shift":self.bbox_shift } self.batch_size = opt.batch_size self.idx = 0 - self.res_frame_queue = mp.Queue(self.batch_size * 2) + self.res_frame_queue = mp.Queue(self.batch_size*2) self.__loadmodels() self.__loadavatar() - self.asr = MuseASR(opt, self.audio_processor) + self.asr = MuseASR(opt,self.audio_processor) + self.asr.warm_up() if opt.tts == "edgetts": - self.tts = EdgeTTS(opt, self) + self.tts = EdgeTTS(opt,self) elif opt.tts == "gpt-sovits": - self.tts = VoitsTTS(opt, self) + self.tts = VoitsTTS(opt,self) elif opt.tts == "xtts": - self.tts = XTTS(opt, self) - # self.__warm_up() - + self.tts = XTTS(opt,self) + #self.__warm_up() + self.render_event = mp.Event() - mp.Process(target=inference, args=(self.render_event, self.batch_size, self.latents_out_path, - self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue, - )).start() # self.vae, self.unet, self.pe,self.timesteps + mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path, + self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue, + )).start() #self.vae, self.unet, self.pe,self.timesteps def __loadmodels(self): # load model weights - self.audio_processor = load_audio_model() + self.audio_processor= load_audio_model() # self.audio_processor, self.vae, self.unet, self.pe = load_all_model() # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # self.timesteps = torch.tensor([0], device=device) @@ -187,7 +182,7 @@ class MuseReal: # self.unet.model = self.unet.model.half() def __loadavatar(self): - # self.input_latent_list_cycle = torch.load(self.latents_out_path) + #self.input_latent_list_cycle = torch.load(self.latents_out_path) with open(self.coords_path, 'rb') as f: self.coord_list_cycle = pickle.load(f) input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) @@ -198,13 +193,19 @@ class MuseReal: input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) self.mask_list_cycle = read_imgs(input_mask_list) - - def put_msg_txt(self, msg): + + + def put_msg_txt(self,msg): self.tts.put_msg_txt(msg) - - def put_audio_frame(self, audio_chunk): # 16khz 20ms pcm + + def put_audio_frame(self,audio_chunk): #16khz 20ms pcm self.asr.put_audio_frame(audio_chunk) + def 
pause_talk(self): + self.tts.pause_talk() + self.asr.pause_talk() + + def __mirror_index(self, index): size = len(self.coord_list_cycle) turn = index // size @@ -212,15 +213,15 @@ class MuseReal: if turn % 2 == 0: return res else: - return size - res - 1 + return size - res - 1 - def __warm_up(self): + def __warm_up(self): self.asr.run_step() whisper_chunks = self.asr.get_next_feat() whisper_batch = np.stack(whisper_chunks) latent_batch = [] for i in range(self.batch_size): - idx = self.__mirror_index(self.idx + i) + idx = self.__mirror_index(self.idx+i) latent = self.input_latent_list_cycle[idx] latent_batch.append(latent) latent_batch = torch.cat(latent_batch, dim=0) @@ -228,88 +229,90 @@ class MuseReal: # for i, (whisper_batch,latent_batch) in enumerate(gen): audio_feature_batch = torch.from_numpy(whisper_batch) audio_feature_batch = audio_feature_batch.to(device=self.unet.device, - dtype=self.unet.model.dtype) + dtype=self.unet.model.dtype) audio_feature_batch = self.pe(audio_feature_batch) latent_batch = latent_batch.to(dtype=self.unet.model.dtype) - pred_latents = self.unet.model(latent_batch, - self.timesteps, - encoder_hidden_states=audio_feature_batch).sample + pred_latents = self.unet.model(latent_batch, + self.timesteps, + encoder_hidden_states=audio_feature_batch).sample recon = self.vae.decode_latents(pred_latents) + - def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None): - + def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None): + while not quit_event.is_set(): try: - res_frame, idx, audio_frames = self.res_frame_queue.get(block=True, timeout=1) + res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1) except queue.Empty: continue - if audio_frames[0][1] == 1 and audio_frames[1][1] == 1: # 全为静音数据,只需要取fullimg - if self.static_img: - combine_frame = self.frame_list_cycle[0] - else: - combine_frame = self.frame_list_cycle[idx] + if audio_frames[0][1]==1 and audio_frames[1][1]==1: #全为静音数据,只需要取fullimg + combine_frame = self.frame_list_cycle[idx] else: bbox = self.coord_list_cycle[idx] ori_frame = copy.deepcopy(self.frame_list_cycle[idx]) x1, y1, x2, y2 = bbox try: - res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) + res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) except: continue mask = self.mask_list_cycle[idx] mask_crop_box = self.mask_coords_list_cycle[idx] - # combine_frame = get_image(ori_frame,res_frame,bbox) - # t=time.perf_counter() - combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box) - # print('blending time:',time.perf_counter()-t) + #combine_frame = get_image(ori_frame,res_frame,bbox) + #t=time.perf_counter() + combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box) + #print('blending time:',time.perf_counter()-t) - image = combine_frame # (outputs['image'] * 255).astype(np.uint8) + image = combine_frame #(outputs['image'] * 255).astype(np.uint8) new_frame = VideoFrame.from_ndarray(image, format="bgr24") - asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) + asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) for audio_frame in audio_frames: - frame, type = audio_frame + frame,type = audio_frame frame = (frame * 32767).astype(np.int16) new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) new_frame.planes[0].update(frame.tobytes()) - new_frame.sample_rate = 16000 + new_frame.sample_rate=16000 # if audio_track._queue.qsize()>10: # 
time.sleep(0.1) asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop) - print('musereal process_frames thread stop') - - def render(self, quit_event, loop=None, audio_track=None, video_track=None): - # if self.opt.asr: + print('musereal process_frames thread stop') + + def render(self,quit_event,loop=None,audio_track=None,video_track=None): + #if self.opt.asr: # self.asr.warm_up() self.tts.render(quit_event) - process_thread = Thread(target=self.process_frames, args=(quit_event, loop, audio_track, video_track)) + process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track)) process_thread.start() - self.render_event.set() # start infer process render - count = 0 - totaltime = 0 - _starttime = time.perf_counter() - # _totalframe=0 - while not quit_event.is_set(): # todo + self.render_event.set() #start infer process render + count=0 + totaltime=0 + _starttime=time.perf_counter() + #_totalframe=0 + while not quit_event.is_set(): #todo # update texture every frame # audio stream thread... t = time.perf_counter() self.asr.run_step() - # self.test_step(loop,audio_track,video_track) + #self.test_step(loop,audio_track,video_track) # totaltime += (time.perf_counter() - t) # count += self.opt.batch_size # if count>=100: # print(f"------actual avg infer fps:{count/totaltime:.4f}") # count=0 # totaltime=0 - if video_track._queue.qsize() >= 2 * self.opt.batch_size: - print('sleep qsize=', video_track._queue.qsize()) - time.sleep(0.04 * self.opt.batch_size * 1.5) - + if video_track._queue.qsize()>=1.5*self.opt.batch_size: + print('sleep qsize=',video_track._queue.qsize()) + time.sleep(0.04*video_track._queue.qsize()*0.8) + # if video_track._queue.qsize()>=5: + # print('sleep qsize=',video_track._queue.qsize()) + # time.sleep(0.04*video_track._queue.qsize()*0.8) + # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms # if delay > 0: # time.sleep(delay) - self.render_event.clear() # end infer process render + self.render_event.clear() #end infer process render print('musereal thread stop') + \ No newline at end of file diff --git a/musetalk/utils/face_parsing/__init__.py b/musetalk/utils/face_parsing/__init__.py index 003147f..520593e 100755 --- a/musetalk/utils/face_parsing/__init__.py +++ b/musetalk/utils/face_parsing/__init__.py @@ -7,14 +7,15 @@ from PIL import Image from .model import BiSeNet import torchvision.transforms as transforms - class FaceParsing(): - def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', - model_pth='./models/face-parse-bisent/79999_iter.pth'): + def __init__(self,resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', + model_pth='./models/face-parse-bisent/79999_iter.pth'): self.net = self.model_init(resnet_path,model_pth) self.preprocess = self.image_preprocess() - def model_init(self,resnet_path, model_pth): + def model_init(self, + resnet_path, + model_pth): net = BiSeNet(resnet_path) if torch.cuda.is_available(): net.cuda() @@ -44,13 +45,13 @@ class FaceParsing(): img = torch.unsqueeze(img, 0) out = self.net(img)[0] parsing = out.squeeze(0).cpu().numpy().argmax(0) - parsing[np.where(parsing > 13)] = 0 - parsing[np.where(parsing >= 1)] = 255 + parsing[np.where(parsing>13)] = 0 + parsing[np.where(parsing>=1)] = 255 parsing = Image.fromarray(parsing.astype(np.uint8)) return parsing - if __name__ == "__main__": fp = FaceParsing() segmap = fp('154_small.png') segmap.save('res.png') + diff --git a/nerfreal.py b/nerfreal.py index 5ee2365..ef04c3e 100644 --- a/nerfreal.py 
+++ b/nerfreal.py @@ -20,9 +20,6 @@ class NeRFReal: self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. self.W = opt.W self.H = opt.H - self.debug = debug - self.training = False - self.step = 0 # training step self.trainer = trainer self.data_loader = data_loader @@ -44,7 +41,6 @@ class NeRFReal: #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item() # playing seq from dataloader, or pause. - self.playing = True #False todo self.loader = iter(data_loader) #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) @@ -62,9 +58,8 @@ class NeRFReal: self.customimg_index = 0 # build asr - if self.opt.asr: - self.asr = ASR(opt) - self.asr.warm_up() + self.asr = ASR(opt) + self.asr.warm_up() if opt.tts == "edgetts": self.tts = EdgeTTS(opt,self) elif opt.tts == "gpt-sovits": @@ -122,7 +117,11 @@ class NeRFReal: self.tts.put_msg_txt(msg) def put_audio_frame(self,audio_chunk): #16khz 20ms pcm - self.asr.put_audio_frame(audio_chunk) + self.asr.put_audio_frame(audio_chunk) + + def pause_talk(self): + self.tts.pause_talk() + self.asr.pause_talk() def mirror_index(self, index): @@ -248,10 +247,9 @@ class NeRFReal: # update texture every frame # audio stream thread... t = time.perf_counter() - if self.opt.asr and self.playing: - # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) - for _ in range(2): - self.asr.run_step() + # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) + for _ in range(2): + self.asr.run_step() self.test_step(loop,audio_track,video_track) totaltime += (time.perf_counter() - t) count += 1 @@ -267,7 +265,7 @@ class NeRFReal: else: if video_track._queue.qsize()>=5: #print('sleep qsize=',video_track._queue.qsize()) - time.sleep(0.1) + time.sleep(0.04*video_track._queue.qsize()*0.8) print('nerfreal thread stop') \ No newline at end of file diff --git a/ttsreal.py b/ttsreal.py index 22e2909..e5c2c7f 100644 --- a/ttsreal.py +++ b/ttsreal.py @@ -13,6 +13,11 @@ import queue from queue import Queue from io import BytesIO from threading import Thread, Event +from enum import Enum + +class State(Enum): + RUNNING=0 + PAUSE=1 class BaseTTS: def __init__(self, opt, parent): @@ -25,6 +30,11 @@ class BaseTTS: self.input_stream = BytesIO() self.msgqueue = Queue() + self.state = State.RUNNING + + def pause_talk(self): + self.msgqueue.queue.clear() + self.state = State.PAUSE def put_msg_txt(self,msg): self.msgqueue.put(msg) @@ -37,6 +47,7 @@ class BaseTTS: while not quit_event.is_set(): try: msg = self.msgqueue.get(block=True, timeout=1) + self.state=State.RUNNING except queue.Empty: continue self.txt_to_audio(msg) @@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS): stream = self.__create_bytes_stream(self.input_stream) streamlen = stream.shape[0] idx=0 - while streamlen >= self.chunk: + while streamlen >= self.chunk and self.state==State.RUNNING: self.parent.put_audio_frame(stream[idx:idx+self.chunk]) streamlen -= self.chunk idx += self.chunk @@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS): async for chunk in communicate.stream(): if first: first = False - if chunk["type"] == "audio": + if chunk["type"] == "audio" and self.state==State.RUNNING: #self.push_audio(chunk["data"]) self.input_stream.write(chunk["data"]) #file.write(chunk["data"]) @@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS): end = time.perf_counter() print(f"gpt_sovits Time to first chunk: {end-start}s") first = False - if chunk: + if chunk and self.state==State.RUNNING: yield chunk print("gpt_sovits response.elapsed:", res.elapsed) diff --git 
a/web/chat.html b/web/chat.html index aa189ac..c554aba 100644 --- a/web/chat.html +++ b/web/chat.html @@ -29,22 +29,22 @@ $(document).ready(function() { var host = window.location.hostname - var ws = new WebSocket("ws://"+host+":8000/humanchat"); - //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]); - ws.onopen = function() { - console.log('Connected'); - }; - ws.onmessage = function(e) { - console.log('Received: ' + e.data); - data = e - var vid = JSON.parse(data.data); - console.log(typeof(vid),vid) - //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]); + // var ws = new WebSocket("ws://"+host+":8000/humanecho"); + // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]); + // ws.onopen = function() { + // console.log('Connected'); + // }; + // ws.onmessage = function(e) { + // console.log('Received: ' + e.data); + // data = e + // var vid = JSON.parse(data.data); + // console.log(typeof(vid),vid) + // //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]); - }; - ws.onclose = function(e) { - console.log('Closed'); - }; + // }; + // ws.onclose = function(e) { + // console.log('Closed'); + // }; flvPlayer = mpegts.createPlayer({type: 'flv', url: "http://"+host+":8080/live/livestream.flv", isLive: true, enableStashBuffer: false}); flvPlayer.attachMediaElement(document.getElementById('video_player')); @@ -55,9 +55,19 @@ e.preventDefault(); var message = $('#message').val(); console.log('Sending: ' + message); - ws.send(message); + fetch('/human', { + body: JSON.stringify({ + text: message, + type: 'chat', + }), + headers: { + 'Content-Type': 'application/json' + }, + method: 'POST' + }); + //ws.send(message); $('#message').val(''); - }); + }); }); \ No newline at end of file diff --git a/web/rtcpushchat.html b/web/rtcpushchat.html index 730541d..95e2319 100644 --- a/web/rtcpushchat.html +++ b/web/rtcpushchat.html @@ -51,30 +51,40 @@
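The front-end change above replaces the chat page's raw websocket send with a `fetch` POST to `/human`, and the ttsreal.py hunk earlier introduces a small `State` enum so `pause_talk()` can interrupt speech mid-sentence. The fragment below condenses that server-side pause pattern for illustration only; `TinyTTS` and its methods are stand-ins, not classes from this repository.

```python
# Condensed illustration of the pause/interrupt pattern added in ttsreal.py:
# pause_talk() clears queued text and flips a state flag, and the streaming
# loop checks that flag per chunk so playback stops mid-sentence.
from enum import Enum
from queue import Queue

class State(Enum):
    RUNNING = 0
    PAUSE = 1

class TinyTTS:
    def __init__(self):
        self.msgqueue = Queue()
        self.state = State.RUNNING

    def put_msg_txt(self, msg):
        self.msgqueue.put(msg)

    def pause_talk(self):
        self.msgqueue.queue.clear()
        self.state = State.PAUSE

    def stream_chunks(self, chunks):
        # mirrors the `while streamlen >= self.chunk and self.state == State.RUNNING`
        # loop in EdgeTTS.txt_to_audio: stop pushing audio once paused.
        for chunk in chunks:
            if self.state != State.RUNNING:
                break
            yield chunk

tts = TinyTTS()
tts.put_msg_txt("hello")
tts.pause_talk()
print(list(tts.stream_chunks([b"a", b"b"])))  # -> [] because state is PAUSE
```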