feat: finish switching to automatic absolute paths; add an avatar-generation API endpoint

Yun 2024-07-04 09:43:56 +08:00
parent 18d7db35a7
commit cd7d5f31b5
14 changed files with 401 additions and 553 deletions


@ -6,11 +6,10 @@ Real time interactive streaming digital human realize audio video synchronous
## Features
1. Supports multiple digital-human models: ernerf, musetalk, wav2lip
2. Supports voice cloning
3. Supports multiple audio-feature drivers: wav2vec, hubert
3. Supports interrupting the digital human while it is speaking
4. Supports full-body video stitching
5. Supports rtmp and webrtc
6. Supports video orchestration: plays a custom video when not speaking
7. Supports conversation with a large language model
## 1. Installation
@ -171,13 +170,11 @@ cd MuseTalk
Edit configs/inference/realtime.yaml and set preparation to True
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
After it finishes, copy the files under results/avatars into this project's data/avatars
```
```bash
# You can also use simple_musetalk.py from the local directory
cd musetalk
python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
# The avatar is generated directly under data/avatars
# Method 2: run
cd musetalk
python simple_musetalk.py --avatar_id 4 --file D:\\ok\\test.mp4
# Works for both video and image input; the result is generated automatically under data/avatars
```
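The same generation step is also exposed as a Python function, which app.py now imports for its new upload endpoint. A minimal sketch of calling it directly (the file name and avatar id are illustrative placeholders):

```python
from musetalk.simple_musetalk import create_musetalk_human

# Generates the avatar data under data/avatars, the same as running simple_musetalk.py from the CLI.
# "test.mp4" and avatar id 4 are placeholder values.
create_musetalk_human("test.mp4", 4)
```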
### 3.10 Using the wav2lip model
@ -185,7 +182,7 @@ python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
- Download the models
Download the models needed to run wav2lip from the network drive: https://drive.uc.cn/s/551be97d7cfa4
Copy s3fd.pth to wav2lip/face_detection/detection/sfd/s3fd.pth in this project, and copy wav2lip.pth into this project's models directory
Digital-human model file wav2lip_avatar1.tar.gz; after extracting, copy the whole folder into this project's data/avatars
Digital-human model file wav2lip_avatar1.tar.gz, network drive: https://drive.uc.cn/s/5bd0cde0b0774; after extracting, copy the whole folder into this project's data/avatars
- Run
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip_avatar1
Open http://serverip:8010/webrtcapi.html in a browser
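Once the server is running, text can also be pushed to the avatar through the /human endpoint that the bundled web pages call; a minimal sketch using the requests package (assumed installed; host and port are the app.py defaults):

```python
import requests

base = "http://127.0.0.1:8010"

# 'echo' speaks the text verbatim; 'interrupt' stops whatever is currently being spoken first.
requests.post(f"{base}/human", json={"type": "echo", "interrupt": True, "text": "Hello there"})

# 'chat' routes the text through the configured LLM before speaking the reply.
requests.post(f"{base}/human", json={"type": "chat", "text": "Tell me a joke"})
```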

app.py

@ -1,31 +1,39 @@
# server.py
import argparse
import asyncio
import json
import multiprocessing
from threading import Thread, Event
import aiohttp
import aiohttp_cors
from aiohttp import web
from aiortc import RTCPeerConnection, RTCSessionDescription
from flask import Flask
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import time
import json
import gevent
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
import os
import re
import numpy as np
from threading import Thread,Event
import multiprocessing
from musetalk.simple_musetalk import create_musetalk_human
from aiohttp import web
import aiohttp
import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription
from webrtc import HumanPlayer
import argparse
import shutil
import asyncio
app = Flask(__name__)
sockets = Sockets(app)
global nerfreal
@sockets.route('/humanecho')
def echo_socket(ws):
# Get the WebSocket object
# ws = request.environ.get('wsgi.websocket')
#ws = request.environ.get('wsgi.websocket')
# If it was not obtained, return an error message
if not ws:
print('未建立连接!')
@ -34,11 +42,11 @@ def echo_socket(ws):
else:
print('建立连接!')
while True:
message = ws.receive()
if not message or len(message) == 0:
message = ws.receive()
if not message or len(message)==0:
return '输入信息为空'
else:
else:
nerfreal.put_msg_txt(message)
@ -46,16 +54,15 @@ def llm_response(message):
from llm.LLM import LLM
# llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
# llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b')
llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
response = llm.chat(message)
print(response)
return response
@sockets.route('/humanchat')
def chat_socket(ws):
# Get the WebSocket object
# ws = request.environ.get('wsgi.websocket')
#ws = request.environ.get('wsgi.websocket')
# If it was not obtained, return an error message
if not ws:
print('未建立连接!')
@ -64,20 +71,18 @@ def chat_socket(ws):
else:
print('建立连接!')
while True:
message = ws.receive()
if len(message) == 0:
message = ws.receive()
if len(message)==0:
return '输入信息为空'
else:
res = llm_response(message)
res=llm_response(message)
nerfreal.put_msg_txt(res)
#####webrtc###############################
pcs = set()
# @app.route('/offer', methods=['POST'])
#@app.route('/offer', methods=['POST'])
async def offer(request):
params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@ -101,7 +106,7 @@ async def offer(request):
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
# return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
#return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
return web.Response(
content_type="application/json",
@ -110,61 +115,39 @@ async def offer(request):
),
)
async def human(request):
params = await request.json()
if params['type'] == 'echo':
if params.get('interrupt'):
nerfreal.pause_talk()
if params['type']=='echo':
nerfreal.put_msg_txt(params['text'])
elif params['type'] == 'chat':
res = await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'])
elif params['type']=='chat':
res=await asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'])
nerfreal.put_msg_txt(res)
return web.Response(
content_type="application/json",
text=json.dumps(
{"code": 0, "data": "ok"}
{"code": 0, "data":"ok"}
),
)
async def handle_create_musetalk(request):
reader = await request.multipart()
# Handle the file part
file_part = await reader.next()
filename = file_part.filename
file_data = await file_part.read()  # read the file contents
# Note: make sure this file path is writable
with open(filename, 'wb') as f:
f.write(file_data)
# Handle the integer part
part = await reader.next()
avatar_id = int(await part.text())
create_musetalk_human(filename, avatar_id)
os.remove(filename)
return web.json_response({
'status': 'success',
'filename': filename,
'int_value': avatar_id,
})
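The handler reads the multipart parts strictly in order: the video file first, then a text part holding the integer avatar id. A hedged client sketch using aiohttp (already a project dependency); the path, id, host and port are illustrative:

```python
import asyncio
import aiohttp

async def create_avatar(video_path="test.mp4", avatar_id=5, base="http://127.0.0.1:8010"):
    form = aiohttp.FormData()
    # Part order matters: the server expects the file first, then the avatar id.
    form.add_field("file", open(video_path, "rb"), filename=video_path)
    form.add_field("avatar_id", str(avatar_id))
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base}/create_musetalk", data=form) as resp:
            print(await resp.json())

asyncio.run(create_avatar())
```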
async def on_shutdown(app):
# close peer connections
coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros)
pcs.clear()
async def post(url, data):
async def post(url,data):
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=data) as response:
async with session.post(url,data=data) as response:
return await response.text()
except aiohttp.ClientError as e:
print(f'Error: {e}')
async def run(push_url):
pc = RTCPeerConnection()
pcs.add(pc)
@ -181,10 +164,8 @@ async def run(push_url):
video_sender = pc.addTrack(player.video)
await pc.setLocalDescription(await pc.createOffer())
answer = await post(push_url, pc.localDescription.sdp)
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
answer = await post(push_url,pc.localDescription.sdp)
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
@ -203,20 +184,14 @@ if __name__ == '__main__':
### training options
parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
parser.add_argument('--num_rays', type=int, default=4096 * 16,
help="num rays sampled per image for each training step")
parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--max_steps', type=int, default=16,
help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=16,
help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=0,
help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16,
help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096,
help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
### loss set
parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
@ -227,35 +202,27 @@ if __name__ == '__main__':
### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--bg_img', type=str, default='white', help="background image")
parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
parser.add_argument('--fix_eye', type=float, default=-1,
help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
parser.add_argument('--torso_shrink', type=float, default=0.8,
help="shrink bg coords to allow more flexibility in deform")
parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")
### dataset options
parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
parser.add_argument('--preload', type=int, default=0,
help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
# (the default value is for the fox dataset)
parser.add_argument('--bound', type=float, default=1,
help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
parser.add_argument('--dt_gamma', type=float, default=1 / 256,
help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
parser.add_argument('--density_thresh', type=float, default=10,
help="threshold for density grid to be occupied (sigma)")
parser.add_argument('--density_thresh_torso', type=float, default=0.01,
help="threshold for density grid to be occupied (alpha)")
parser.add_argument('--patch_size', type=int, default=1,
help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
parser.add_argument('--init_lips', action='store_true', help="init lips region")
parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
@ -273,15 +240,12 @@ if __name__ == '__main__':
parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
### else
parser.add_argument('--att', type=int, default=2,
help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
parser.add_argument('--aud', type=str, default='',
help="audio source (empty will load the default, else should be a path to a npy file)")
parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
parser.add_argument('--ind_num', type=int, default=10000,
help="number of individual codes, should be larger than training dataset size")
parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")
parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
@ -290,8 +254,7 @@ if __name__ == '__main__':
parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
parser.add_argument('--smooth_path', action='store_true',
help="brute-force smooth camera pose trajectory with a window size")
parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")
# asr
@ -299,8 +262,8 @@ if __name__ == '__main__':
parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
parser.add_argument('--asr_play', action='store_true', help="play out the audio")
# parser.add_argument('--asr_model', type=str, default='deepspeech')
parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #
#parser.add_argument('--asr_model', type=str, default='deepspeech')
parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
# parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
@ -319,45 +282,42 @@ if __name__ == '__main__':
parser.add_argument('--fullbody_offset_x', type=int, default=0)
parser.add_argument('--fullbody_offset_y', type=int, default=0)
# musetalk opt
#musetalk opt
parser.add_argument('--avatar_id', type=str, default='avator_1')
parser.add_argument('--bbox_shift', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--customvideo', action='store_true', help="custom video")
parser.add_argument('--static_img', action='store_true', help="use the first frame as the static image while the avatar is not speaking")
parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
parser.add_argument('--customvideo_imgnum', type=int, default=1)
parser.add_argument('--tts', type=str, default='edgetts') # xtts gpt-sovits
parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits
parser.add_argument('--REF_FILE', type=str, default=None)
parser.add_argument('--REF_TEXT', type=str, default=None)
parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
# parser.add_argument('--CHARACTER', type=str, default='test')
# parser.add_argument('--EMOTION', type=str, default='default')
parser.add_argument('--model', type=str, default='ernerf') # musetalk wav2lip
parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip
parser.add_argument('--transport', type=str, default='rtcpush') # rtmp webrtc rtcpush
parser.add_argument('--push_url', type=str,
default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream
parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
parser.add_argument('--listenport', type=int, default=8010)
opt = parser.parse_args()
# app.config.from_object(opt)
# print(app.config)
#app.config.from_object(opt)
#print(app.config)
if opt.model == 'ernerf':
from ernerf.nerf_triplane.provider import NeRFDataset_Test
from ernerf.nerf_triplane.utils import *
from ernerf.nerf_triplane.network import NeRFNetwork
from nerfreal import NeRFReal
# assert test mode
opt.test = True
opt.test_train = False
# opt.train_camera =True
#opt.train_camera =True
# explicit smoothing
opt.smooth_path = True
opt.smooth_lips = True
@ -370,7 +330,7 @@ if __name__ == '__main__':
opt.exp_eye = True
opt.smooth_eye = True
if opt.torso_imgs == '': # no img,use model output
if opt.torso_imgs=='': #no img,use model output
opt.torso = True
# assert opt.cuda_ray, "Only support CUDA ray mode."
@ -386,10 +346,9 @@ if __name__ == '__main__':
model = NeRFNetwork(opt)
criterion = torch.nn.MSELoss(reduction='none')
metrics = [] # use no metric in GUI for faster initialization...
metrics = [] # use no metric in GUI for faster initialization...
print(model)
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16,
metrics=metrics, use_checkpoint=opt.ckpt)
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
test_loader = NeRFDataset_Test(opt, device=device).dataloader()
model.aud_features = test_loader._data.auds
@ -399,19 +358,17 @@ if __name__ == '__main__':
nerfreal = NeRFReal(opt, trainer, test_loader)
elif opt.model == 'musetalk':
from musereal import MuseReal
print(opt)
nerfreal = MuseReal(opt)
elif opt.model == 'wav2lip':
from lipreal import LipReal
print(opt)
nerfreal = LipReal(opt)
# txt_to_audio('我是中国人,我来自北京')
if opt.transport == 'rtmp':
#txt_to_audio('我是中国人,我来自北京')
if opt.transport=='rtmp':
thread_quit = Event()
rendthrd = Thread(target=nerfreal.render, args=(thread_quit,))
rendthrd = Thread(target=nerfreal.render,args=(thread_quit,))
rendthrd.start()
#############################################################################
@ -419,37 +376,35 @@ if __name__ == '__main__':
appasync.on_shutdown.append(on_shutdown)
appasync.router.add_post("/offer", offer)
appasync.router.add_post("/human", human)
appasync.router.add_post("/create_musetalk", handle_create_musetalk)
appasync.router.add_static('/', path='web')
appasync.router.add_static('/',path='web')
# Configure default CORS settings.
cors = aiohttp_cors.setup(appasync, defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_headers="*",
)
})
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_headers="*",
)
})
# Configure CORS on all routes.
for route in list(appasync.router.routes()):
cors.add(route)
def run_server(runner):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(runner.setup())
site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
loop.run_until_complete(site.start())
if opt.transport == 'rtcpush':
if opt.transport=='rtcpush':
loop.run_until_complete(run(opt.push_url))
loop.run_forever()
loop.run_forever()
Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
print('start websocket server')
# app.on_shutdown.append(on_shutdown)
# app.router.add_post("/offer", offer)
#app.on_shutdown.append(on_shutdown)
#app.router.add_post("/offer", offer)
server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
server.serve_forever()


@ -4,29 +4,19 @@ import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
#import pyaudio
import soundfile as sf
import resampy
import queue
from queue import Queue
#from collections import deque
from threading import Thread, Event
from io import BytesIO
class ASR:
from baseasr import BaseASR
class ASR(BaseASR):
def __init__(self, opt):
self.opt = opt
self.play = opt.asr_play #false
super().__init__(opt)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.mode = 'live' if opt.asr_wav == '' else 'file'
if 'esperanto' in self.opt.asr_model:
self.audio_dim = 44
elif 'deepspeech' in self.opt.asr_model:
@ -41,30 +31,11 @@ class ASR:
self.context_size = opt.m
self.stride_left_size = opt.l
self.stride_right_size = opt.r
self.text = '[START]\n'
self.terminated = False
self.frames = []
self.inwarm = False
# pad left frames
if self.stride_left_size > 0:
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
self.exit_event = Event()
#self.audio_instance = pyaudio.PyAudio() #not need
# create input stream
self.queue = Queue()
self.output_queue = Queue()
# start a background process to read frames
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
#self.queue = Queue()
#self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
# current location of audio
self.idx = 0
# create wav2vec model
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
if 'hubert' in self.opt.asr_model:
@ -74,10 +45,6 @@ class ASR:
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
# prepare to save logits
if self.opt.asr_save_feats:
self.all_feats = []
# the extracted features
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
self.feat_buffer_size = 4
@ -93,8 +60,16 @@ class ASR:
# warm up steps needed: mid + right + window_size + attention_size
self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
self.listening = False
self.playing = False
def get_audio_frame(self):
try:
frame = self.queue.get(block=False)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_next_feat(self): #get audio embedding to nerf
# return a [1/8, 16] window, for the next input to nerf side.
@ -136,29 +111,19 @@ class ASR:
def run_step(self):
if self.terminated:
return
# get a frame of audio
frame,type = self.__get_audio_frame()
# the last frame
if frame is None:
# terminate, but always run the network for the left frames
self.terminated = True
else:
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
# context not enough, do not run network.
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
return
frame,type = self.get_audio_frame()
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
# context not enough, do not run network.
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
# discard the old part to save memory
if not self.terminated:
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
#print(f'[INFO] frame_to_text... ')
#t = time.time()
@ -166,10 +131,6 @@ class ASR:
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
feats = logits # better lips-sync than labels
# save feats
if self.opt.asr_save_feats:
self.all_feats.append(feats)
# record the feats efficiently.. (no concat, constant memory)
start = self.feat_buffer_idx * self.context_size
end = start + feats.shape[0]
@ -203,24 +164,6 @@ class ASR:
# np.save(output_path, unfold_feats.cpu().numpy())
# print(f"[INFO] saved logits to {output_path}")
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
if self.inwarm: # warm up
return np.zeros(self.chunk, dtype=np.float32),1
try:
frame = self.queue.get(block=False)
type = 0
print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
self.idx = self.idx + self.chunk
return frame,type
def __frame_to_text(self, frame):
@ -241,8 +184,8 @@ class ASR:
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
# do not cut right if terminated.
if self.terminated:
right = logits.shape[1]
# if self.terminated:
# right = logits.shape[1]
logits = logits[:, left:right]
@ -262,10 +205,23 @@ class ASR:
return logits[0], None,None #predicted_ids[0], transcription # [N,]
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
#for _ in range(self.stride_left_size):
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
for _ in range(self.warm_up_steps):
self.run_step()
#if torch.cuda.is_available():
# torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
#self.clear_queue()
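The expected latency printed here is simply the number of warm-up steps divided by the audio frame rate, with the step count being the attention context plus both strides. For illustration (the m/l/r values below are hypothetical, not taken from this diff):

```python
fps = 50                                               # 20 ms audio frames
context_size, stride_left, stride_right = 50, 10, 10   # hypothetical -m/-l/-r values
warm_up_steps = context_size + stride_left + stride_right
print(warm_up_steps / fps)                             # expected warm-up latency in seconds -> 1.4
```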
#####not used function#####################################
'''
def __init_queue(self):
self.frames = []
self.queue.queue.clear()
@ -290,26 +246,6 @@ class ASR:
if self.play:
self.output_queue.queue.clear()
def warm_up(self):
#self.listen()
self.inwarm = True
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
#for _ in range(self.stride_left_size):
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
for _ in range(self.warm_up_steps):
self.run_step()
#if torch.cuda.is_available():
# torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
self.inwarm = False
#self.clear_queue()
#####not used function#####################################
def listen(self):
# start
if self.mode == 'live' and not self.listening:
@ -404,4 +340,5 @@ if __name__ == '__main__':
raise ValueError("DeepSpeech features should not use this code to extract...")
with ASR(opt) as asr:
asr.run()
asr.run()
'''


@ -6,60 +6,16 @@ import queue
from queue import Queue
import multiprocessing as mp
from baseasr import BaseASR
from wav2lip import audio
class LipASR:
def __init__(self, opt):
self.opt = opt
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
# self.input_stream = BytesIO()
self.output_queue = mp.Queue()
#self.audio_processor = audio_processor
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = opt.l
self.stride_right_size = opt.r
#self.context_size = 10
self.feat_queue = mp.Queue(5)
self.warm_up()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
try:
frame = self.queue.get(block=True,timeout=0.01)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type=self.__get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
for _ in range(self.stride_left_size):
self.output_queue.get()
class LipASR(BaseASR):
def run_step(self):
############################################## extract audio feature ##############################################
# get a frame of audio
for _ in range(self.batch_size*2):
frame,type = self.__get_audio_frame()
frame,type = self.get_audio_frame()
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
@ -89,7 +45,3 @@ class LipASR:
# discard the old part to save memory
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)
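ASR, LipASR and MuseASR now inherit their queue plumbing from a shared BaseASR class that is not part of this diff. The sketch below is a hypothetical reconstruction, pieced together from the duplicated code the subclasses just dropped and the methods they still call (put_audio_frame, get_audio_frame, get_audio_out, get_next_feat, warm_up, pause_talk); the real baseasr.py may differ:

```python
import queue
from queue import Queue
import multiprocessing as mp
import numpy as np

class BaseASR:
    """Hypothetical reconstruction -- the real baseasr.py is not shown in this commit."""
    def __init__(self, opt):
        self.opt = opt
        self.fps = opt.fps                          # 20 ms per audio frame
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps   # 320 samples per 20 ms chunk
        self.queue = Queue()                        # incoming 16 kHz pcm chunks from the TTS
        self.output_queue = mp.Queue()              # (frame, type) pairs consumed by the renderer
        self.batch_size = opt.batch_size
        self.frames = []
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r
        self.feat_queue = mp.Queue(5)               # audio features handed to the video model

    def pause_talk(self):
        # assumed behaviour: drop buffered audio so an interrupt takes effect immediately
        self.queue.queue.clear()

    def put_audio_frame(self, audio_chunk):         # 16 kHz, 20 ms pcm
        self.queue.put(audio_chunk)

    def get_audio_frame(self):
        try:
            frame = self.queue.get(block=True, timeout=0.01)
            type = 0                                # real speech
        except queue.Empty:
            frame = np.zeros(self.chunk, dtype=np.float32)
            type = 1                                # silence filler
        return frame, type

    def get_audio_out(self):                        # original pcm, paired with its silence flag
        return self.output_queue.get()

    def warm_up(self):
        for _ in range(self.stride_left_size + self.stride_right_size):
            frame, type = self.get_audio_frame()
            self.frames.append(frame)
            self.output_queue.put((frame, type))
        for _ in range(self.stride_left_size):
            self.output_queue.get()

    def get_next_feat(self, block, timeout):
        return self.feat_queue.get(block, timeout)
```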


@ -164,6 +164,7 @@ class LipReal:
self.__loadavatar()
self.asr = LipASR(opt)
self.asr.warm_up()
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
@ -199,6 +200,10 @@ class LipReal:
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@ -257,9 +262,12 @@ class LipReal:
t = time.perf_counter()
self.asr.run_step()
if video_track._queue.qsize()>=2*self.opt.batch_size:
# if video_track._queue.qsize()>=2*self.opt.batch_size:
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
if video_track._queue.qsize()>=5:
print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*self.opt.batch_size*1.5)
time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
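The render loop now throttles on actual queue depth: once more than five frames are buffered on the WebRTC video track, it sleeps for roughly 80% of the queued playback time instead of a fixed batch-sized pause. A small sketch of that policy (threshold and scale are taken from the code above):

```python
import time

FRAME_SECONDS = 0.04   # the video track runs at 25 fps

def throttle_sleep(queue_size: int, threshold: int = 5) -> float:
    """Return how long the render loop should nap so the track can drain its backlog."""
    if queue_size < threshold:
        return 0.0
    return FRAME_SECONDS * queue_size * 0.8   # ~80% of the buffered playback time

time.sleep(throttle_sleep(10))   # 10 queued frames -> 0.32 s pause
```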


@ -1,65 +1,22 @@
import time
import torch
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
from baseasr import BaseASR
from musetalk.whisper.audio2feature import Audio2Feature
class MuseASR:
class MuseASR(BaseASR):
def __init__(self, opt, audio_processor:Audio2Feature):
self.opt = opt
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
# self.input_stream = BytesIO()
self.output_queue = mp.Queue()
super().__init__(opt)
self.audio_processor = audio_processor
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = opt.l
self.stride_right_size = opt.r
self.feat_queue = mp.Queue(5)
self.warm_up()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
try:
frame = self.queue.get(block=True,timeout=0.01)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type=self.__get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
for _ in range(self.stride_left_size):
self.output_queue.get()
def run_step(self):
############################################## extract audio feature ##############################################
start_time = time.time()
for _ in range(self.batch_size*2):
audio_frame,type=self.__get_audio_frame()
audio_frame,type=self.get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
@ -77,6 +34,3 @@ class MuseASR:
self.feat_queue.put(whisper_chunks)
# discard the old part to save memory
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)


@ -2,7 +2,7 @@ import math
import torch
import numpy as np
# from .utils import *
#from .utils import *
import subprocess
import os
import time
@ -18,19 +18,17 @@ from threading import Thread, Event
from io import BytesIO
import multiprocessing as mp
from musetalk.utils.utils import get_file_type, get_video_fps, datagen
# from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending
from musetalk.utils.utils import load_all_model, load_diffusion_model, load_audio_model
from ttsreal import EdgeTTS, VoitsTTS, XTTS
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
from ttsreal import EdgeTTS,VoitsTTS,XTTS
from museasr import MuseASR
import asyncio
from av import AudioFrame, VideoFrame
from tqdm import tqdm
def read_imgs(img_list):
frames = []
print('reading images...')
@ -39,146 +37,143 @@ def read_imgs(img_list):
frames.append(frame)
return frames
def __mirror_index(size, index):
# size = len(self.coord_list_cycle)
#size = len(self.coord_list_cycle)
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue,
): # vae, unet, pe,timesteps
return size - res - 1
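__mirror_index walks the avatar frame list forward and then backward, so the idle loop never jumps from the last frame straight back to the first. A quick standalone check of the mapping:

```python
def mirror_index(size, index):
    turn = index // size
    res = index % size
    return res if turn % 2 == 0 else size - res - 1

# With 4 avatar frames the playback order ping-pongs: 0 1 2 3 3 2 1 0 0 1 ...
print([mirror_index(4, i) for i in range(10)])  # [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
```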
def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
): #vae, unet, pe,timesteps
vae, unet, pe = load_diffusion_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
input_latent_list_cycle = torch.load(latents_out_path)
length = len(input_latent_list_cycle)
index = 0
count = 0
counttime = 0
count=0
counttime=0
print('start inference')
while True:
if render_event.is_set():
starttime = time.perf_counter()
starttime=time.perf_counter()
try:
whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence = True
is_all_silence=True
audio_frames = []
for _ in range(batch_size * 2):
frame, type = audio_out_queue.get()
audio_frames.append((frame, type))
if type == 0:
is_all_silence = False
for _ in range(batch_size*2):
frame,type = audio_out_queue.get()
audio_frames.append((frame,type))
if type==0:
is_all_silence=False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
else:
# print('infer=======')
t = time.perf_counter()
t=time.perf_counter()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(batch_size):
idx = __mirror_index(length, index + i)
idx = __mirror_index(length,index+i)
latent = input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype)
dtype=unet.model.dtype)
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
# print('prepare time:',time.perf_counter()-t)
# t=time.perf_counter()
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
# print('unet time:',time.perf_counter()-t)
# t=time.perf_counter()
recon = vae.decode_latents(pred_latents)
# print('vae time:',time.perf_counter()-t)
# print('diffusion len=',len(recon))
#print('diffusion len=',len(recon))
counttime += (time.perf_counter() - t)
count += batch_size
# _totalframe += 1
if count >= 100:
print(f"------actual avg infer fps:{count / counttime:.4f}")
count = 0
counttime = 0
for i, res_frame in enumerate(recon):
# self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(recon):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
# print('total batch time:',time.perf_counter()-starttime)
#print('total batch time:',time.perf_counter()-starttime)
else:
time.sleep(1)
print('musereal inference processor stop')
@torch.no_grad()
class MuseReal:
def __init__(self, opt):
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.fps = opt.fps # 20 ms per frame
self.fps = opt.fps # 20 ms per frame
#### musetalk
self.avatar_id = opt.avatar_id
self.static_img = opt.static_img
self.video_path = '' # video_path
self.video_path = '' #video_path
self.bbox_shift = opt.bbox_shift
self.avatar_path = f"./data/avatars/{self.avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.latents_out_path = f"{self.avatar_path}/latents.pt"
self.latents_out_path= f"{self.avatar_path}/latents.pt"
self.video_out_path = f"{self.avatar_path}/vid_output/"
self.mask_out_path = f"{self.avatar_path}/mask"
self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
self.mask_out_path =f"{self.avatar_path}/mask"
self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
self.avatar_info = {
"avatar_id": self.avatar_id,
"video_path": self.video_path,
"bbox_shift": self.bbox_shift
"avatar_id":self.avatar_id,
"video_path":self.video_path,
"bbox_shift":self.bbox_shift
}
self.batch_size = opt.batch_size
self.idx = 0
self.res_frame_queue = mp.Queue(self.batch_size * 2)
self.res_frame_queue = mp.Queue(self.batch_size*2)
self.__loadmodels()
self.__loadavatar()
self.asr = MuseASR(opt, self.audio_processor)
self.asr = MuseASR(opt,self.audio_processor)
self.asr.warm_up()
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt, self)
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
self.tts = VoitsTTS(opt, self)
self.tts = VoitsTTS(opt,self)
elif opt.tts == "xtts":
self.tts = XTTS(opt, self)
# self.__warm_up()
self.tts = XTTS(opt,self)
#self.__warm_up()
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event, self.batch_size, self.latents_out_path,
self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
)).start() # self.vae, self.unet, self.pe,self.timesteps
mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
)).start() #self.vae, self.unet, self.pe,self.timesteps
def __loadmodels(self):
# load model weights
self.audio_processor = load_audio_model()
self.audio_processor= load_audio_model()
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.timesteps = torch.tensor([0], device=device)
@ -187,7 +182,7 @@ class MuseReal:
# self.unet.model = self.unet.model.half()
def __loadavatar(self):
# self.input_latent_list_cycle = torch.load(self.latents_out_path)
#self.input_latent_list_cycle = torch.load(self.latents_out_path)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@ -198,13 +193,19 @@ class MuseReal:
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.mask_list_cycle = read_imgs(input_mask_list)
def put_msg_txt(self, msg):
def put_msg_txt(self,msg):
self.tts.put_msg_txt(msg)
def put_audio_frame(self, audio_chunk): # 16khz 20ms pcm
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def __mirror_index(self, index):
size = len(self.coord_list_cycle)
turn = index // size
@ -212,15 +213,15 @@ class MuseReal:
if turn % 2 == 0:
return res
else:
return size - res - 1
return size - res - 1
def __warm_up(self):
def __warm_up(self):
self.asr.run_step()
whisper_chunks = self.asr.get_next_feat()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(self.batch_size):
idx = self.__mirror_index(self.idx + i)
idx = self.__mirror_index(self.idx+i)
latent = self.input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
@ -228,88 +229,90 @@ class MuseReal:
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
dtype=self.unet.model.dtype)
dtype=self.unet.model.dtype)
audio_feature_batch = self.pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
pred_latents = self.unet.model(latent_batch,
self.timesteps,
encoder_hidden_states=audio_feature_batch).sample
pred_latents = self.unet.model(latent_batch,
self.timesteps,
encoder_hidden_states=audio_feature_batch).sample
recon = self.vae.decode_latents(pred_latents)
def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None):
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
while not quit_event.is_set():
try:
res_frame, idx, audio_frames = self.res_frame_queue.get(block=True, timeout=1)
res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
except queue.Empty:
continue
if audio_frames[0][1] == 1 and audio_frames[1][1] == 1:  # both frames are silence, just use the full image
if self.static_img:
combine_frame = self.frame_list_cycle[0]
else:
combine_frame = self.frame_list_cycle[idx]
if audio_frames[0][1]==1 and audio_frames[1][1]==1: # both frames are silence, just use the full image
combine_frame = self.frame_list_cycle[idx]
else:
bbox = self.coord_list_cycle[idx]
ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
x1, y1, x2, y2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
continue
mask = self.mask_list_cycle[idx]
mask_crop_box = self.mask_coords_list_cycle[idx]
# combine_frame = get_image(ori_frame,res_frame,bbox)
# t=time.perf_counter()
combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)
# print('blending time:',time.perf_counter()-t)
#combine_frame = get_image(ori_frame,res_frame,bbox)
#t=time.perf_counter()
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
#print('blending time:',time.perf_counter()-t)
image = combine_frame # (outputs['image'] * 255).astype(np.uint8)
image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
for audio_frame in audio_frames:
frame, type = audio_frame
frame,type = audio_frame
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
new_frame.planes[0].update(frame.tobytes())
new_frame.sample_rate = 16000
new_frame.sample_rate=16000
# if audio_track._queue.qsize()>10:
# time.sleep(0.1)
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
print('musereal process_frames thread stop')
def render(self, quit_event, loop=None, audio_track=None, video_track=None):
# if self.opt.asr:
print('musereal process_frames thread stop')
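process_frames hands audio to WebRTC in the same 20 ms units the ASR consumes: each float PCM chunk is scaled to signed 16-bit samples and wrapped in an av.AudioFrame. The conversion in isolation (a silent chunk stands in for real audio):

```python
import numpy as np
from av import AudioFrame

chunk = np.zeros(320, dtype=np.float32)        # 20 ms of audio at 16 kHz
pcm16 = (chunk * 32767).astype(np.int16)       # float [-1, 1] -> signed 16-bit samples
frame = AudioFrame(format='s16', layout='mono', samples=pcm16.shape[0])
frame.planes[0].update(pcm16.tobytes())
frame.sample_rate = 16000
```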
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
# self.asr.warm_up()
self.tts.render(quit_event)
process_thread = Thread(target=self.process_frames, args=(quit_event, loop, audio_track, video_track))
process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
process_thread.start()
self.render_event.set() # start infer process render
count = 0
totaltime = 0
_starttime = time.perf_counter()
# _totalframe=0
while not quit_event.is_set(): # todo
self.render_event.set() #start infer process render
count=0
totaltime=0
_starttime=time.perf_counter()
#_totalframe=0
while not quit_event.is_set(): #todo
# update texture every frame
# audio stream thread...
t = time.perf_counter()
self.asr.run_step()
# self.test_step(loop,audio_track,video_track)
#self.test_step(loop,audio_track,video_track)
# totaltime += (time.perf_counter() - t)
# count += self.opt.batch_size
# if count>=100:
# print(f"------actual avg infer fps:{count/totaltime:.4f}")
# count=0
# totaltime=0
if video_track._queue.qsize() >= 2 * self.opt.batch_size:
print('sleep qsize=', video_track._queue.qsize())
time.sleep(0.04 * self.opt.batch_size * 1.5)
if video_track._queue.qsize()>=1.5*self.opt.batch_size:
print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*video_track._queue.qsize()*0.8)
# if video_track._queue.qsize()>=5:
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
# time.sleep(delay)
self.render_event.clear() # end infer process render
self.render_event.clear() #end infer process render
print('musereal thread stop')


@ -7,14 +7,15 @@ from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms
class FaceParsing():
def __init__(self, resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
model_pth='./models/face-parse-bisent/79999_iter.pth'):
def __init__(self,resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
model_pth='./models/face-parse-bisent/79999_iter.pth'):
self.net = self.model_init(resnet_path,model_pth)
self.preprocess = self.image_preprocess()
def model_init(self,resnet_path, model_pth):
def model_init(self,
resnet_path,
model_pth):
net = BiSeNet(resnet_path)
if torch.cuda.is_available():
net.cuda()
@ -44,13 +45,13 @@ class FaceParsing():
img = torch.unsqueeze(img, 0)
out = self.net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
parsing[np.where(parsing > 13)] = 0
parsing[np.where(parsing >= 1)] = 255
parsing[np.where(parsing>13)] = 0
parsing[np.where(parsing>=1)] = 255
parsing = Image.fromarray(parsing.astype(np.uint8))
return parsing
if __name__ == "__main__":
fp = FaceParsing()
segmap = fp('154_small.png')
segmap.save('res.png')


@ -20,9 +20,6 @@ class NeRFReal:
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.debug = debug
self.training = False
self.step = 0 # training step
self.trainer = trainer
self.data_loader = data_loader
@ -44,7 +41,6 @@ class NeRFReal:
#self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
# playing seq from dataloader, or pause.
self.playing = True #False todo
self.loader = iter(data_loader)
#self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
@ -62,9 +58,8 @@ class NeRFReal:
self.customimg_index = 0
# build asr
if self.opt.asr:
self.asr = ASR(opt)
self.asr.warm_up()
self.asr = ASR(opt)
self.asr.warm_up()
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
@ -122,7 +117,11 @@ class NeRFReal:
self.tts.put_msg_txt(msg)
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
self.asr.put_audio_frame(audio_chunk)
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def mirror_index(self, index):
@ -248,10 +247,9 @@ class NeRFReal:
# update texture every frame
# audio stream thread...
t = time.perf_counter()
if self.opt.asr and self.playing:
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
for _ in range(2):
self.asr.run_step()
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
for _ in range(2):
self.asr.run_step()
self.test_step(loop,audio_track,video_track)
totaltime += (time.perf_counter() - t)
count += 1
@ -267,7 +265,7 @@ class NeRFReal:
else:
if video_track._queue.qsize()>=5:
#print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.1)
time.sleep(0.04*video_track._queue.qsize()*0.8)
print('nerfreal thread stop')


@ -13,6 +13,11 @@ import queue
from queue import Queue
from io import BytesIO
from threading import Thread, Event
from enum import Enum
class State(Enum):
RUNNING=0
PAUSE=1
class BaseTTS:
def __init__(self, opt, parent):
@ -25,6 +30,11 @@ class BaseTTS:
self.input_stream = BytesIO()
self.msgqueue = Queue()
self.state = State.RUNNING
def pause_talk(self):
self.msgqueue.queue.clear()
self.state = State.PAUSE
def put_msg_txt(self,msg):
self.msgqueue.put(msg)
@ -37,6 +47,7 @@ class BaseTTS:
while not quit_event.is_set():
try:
msg = self.msgqueue.get(block=True, timeout=1)
self.state=State.RUNNING
except queue.Empty:
continue
self.txt_to_audio(msg)
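Pulling the next message flips the state back to RUNNING, so a pause only cancels audio already in flight plus anything still queued. A self-contained toy of the gate pattern the streaming loops use (no real TTS involved; names are illustrative):

```python
from enum import Enum

class State(Enum):
    RUNNING = 0
    PAUSE = 1

class ChunkStreamer:
    def __init__(self):
        self.state = State.RUNNING
        self.delivered = []

    def pause_talk(self):                         # called when /human arrives with interrupt=true
        self.state = State.PAUSE

    def stream(self, chunks):
        for chunk in chunks:
            if self.state != State.RUNNING:       # same check as the EdgeTTS/VoitsTTS loops
                break
            self.delivered.append(chunk)

s = ChunkStreamer()
s.stream(range(3))
print(s.delivered)   # [0, 1, 2] -- nothing was interrupted
```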
@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS):
stream = self.__create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
while streamlen >= self.chunk and self.state==State.RUNNING:
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS):
async for chunk in communicate.stream():
if first:
first = False
if chunk["type"] == "audio":
if chunk["type"] == "audio" and self.state==State.RUNNING:
#self.push_audio(chunk["data"])
self.input_stream.write(chunk["data"])
#file.write(chunk["data"])
@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS):
end = time.perf_counter()
print(f"gpt_sovits Time to first chunk: {end-start}s")
first = False
if chunk:
if chunk and self.state==State.RUNNING:
yield chunk
print("gpt_sovits response.elapsed:", res.elapsed)


@ -29,22 +29,22 @@
$(document).ready(function() {
var host = window.location.hostname
var ws = new WebSocket("ws://"+host+":8000/humanchat");
//document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
ws.onopen = function() {
console.log('Connected');
};
ws.onmessage = function(e) {
console.log('Received: ' + e.data);
data = e
var vid = JSON.parse(data.data);
console.log(typeof(vid),vid)
//document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
// var ws = new WebSocket("ws://"+host+":8000/humanecho");
// //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
// ws.onopen = function() {
// console.log('Connected');
// };
// ws.onmessage = function(e) {
// console.log('Received: ' + e.data);
// data = e
// var vid = JSON.parse(data.data);
// console.log(typeof(vid),vid)
// //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
};
ws.onclose = function(e) {
console.log('Closed');
};
// };
// ws.onclose = function(e) {
// console.log('Closed');
// };
flvPlayer = mpegts.createPlayer({type: 'flv', url: "http://"+host+":8080/live/livestream.flv", isLive: true, enableStashBuffer: false});
flvPlayer.attachMediaElement(document.getElementById('video_player'));
@ -55,9 +55,19 @@
e.preventDefault();
var message = $('#message').val();
console.log('Sending: ' + message);
ws.send(message);
fetch('/human', {
body: JSON.stringify({
text: message,
type: 'chat',
}),
headers: {
'Content-Type': 'application/json'
},
method: 'POST'
});
//ws.send(message);
$('#message').val('');
});
});
});
</script>
</html>


@ -51,30 +51,40 @@
<script type="text/javascript" charset="utf-8">
$(document).ready(function() {
var host = window.location.hostname
var ws = new WebSocket("ws://"+host+":8000/humanchat");
//document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
ws.onopen = function() {
console.log('Connected');
};
ws.onmessage = function(e) {
console.log('Received: ' + e.data);
data = e
var vid = JSON.parse(data.data);
console.log(typeof(vid),vid)
//document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
// var host = window.location.hostname
// var ws = new WebSocket("ws://"+host+":8000/humanecho");
// //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
// ws.onopen = function() {
// console.log('Connected');
// };
// ws.onmessage = function(e) {
// console.log('Received: ' + e.data);
// data = e
// var vid = JSON.parse(data.data);
// console.log(typeof(vid),vid)
// //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
};
ws.onclose = function(e) {
console.log('Closed');
};
// };
// ws.onclose = function(e) {
// console.log('Closed');
// };
$('#echo-form').on('submit', function(e) {
e.preventDefault();
var message = $('#message').val();
console.log('Sending: ' + message);
ws.send(message);
$('#message').val('');
e.preventDefault();
var message = $('#message').val();
console.log('Sending: ' + message);
fetch('/human', {
body: JSON.stringify({
text: message,
type: 'chat',
}),
headers: {
'Content-Type': 'application/json'
},
method: 'POST'
});
//ws.send(message);
$('#message').val('');
});
});


@ -79,6 +79,7 @@
body: JSON.stringify({
text: message,
type: 'echo',
interrupt: true,
}),
headers: {
'Content-Type': 'application/json'


@ -53,30 +53,41 @@
<script type="text/javascript" charset="utf-8">
$(document).ready(function() {
var host = window.location.hostname
var ws = new WebSocket("ws://"+host+":8000/humanchat");
//document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
ws.onopen = function() {
console.log('Connected');
};
ws.onmessage = function(e) {
console.log('Received: ' + e.data);
data = e
var vid = JSON.parse(data.data);
console.log(typeof(vid),vid)
//document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
// var host = window.location.hostname
// var ws = new WebSocket("ws://"+host+":8000/humanecho");
// //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
// ws.onopen = function() {
// console.log('Connected');
// };
// ws.onmessage = function(e) {
// console.log('Received: ' + e.data);
// data = e
// var vid = JSON.parse(data.data);
// console.log(typeof(vid),vid)
// //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
};
ws.onclose = function(e) {
console.log('Closed');
};
// };
// ws.onclose = function(e) {
// console.log('Closed');
// };
$('#echo-form').on('submit', function(e) {
e.preventDefault();
var message = $('#message').val();
console.log('Sending: ' + message);
ws.send(message);
$('#message').val('');
e.preventDefault();
var message = $('#message').val();
console.log('Sending: ' + message);
fetch('/human', {
body: JSON.stringify({
text: message,
type: 'chat',
interrupt: true,
}),
headers: {
'Content-Type': 'application/json'
},
method: 'POST'
});
//ws.send(message);
$('#message').val('');
});
});
</script>