feat: refine path handling to use automatic absolute paths; add a generation API

Yun committed 2024-07-04 09:43:56 +08:00
parent 18d7db35a7
commit cd7d5f31b5
14 changed files with 401 additions and 553 deletions

README.md

@@ -6,11 +6,10 @@ Real time interactive streaming digital human realize audio video synchronous
## Features
1. Supports multiple digital-human models: ernerf, musetalk, wav2lip
2. Supports voice cloning
-3. Supports multiple audio-feature backends: wav2vec, hubert
+3. Supports interrupting the digital human while it is speaking
4. Supports full-body video stitching
5. Supports rtmp and webrtc
6. Supports video orchestration: plays a custom video while not speaking
-7. Supports dialogue with a large language model

## 1. Installation

@@ -171,13 +170,11 @@ cd MuseTalk
Edit configs/inference/realtime.yaml and set preparation to True
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
After running, copy the files under results/avatars into this project's data/avatars directory
-```
-```bash
-You can also try simple_musetalk.py in the local directory
-cd musetalk
-python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
-After running, the avatar is generated directly under data/avatars
+Method 2
+Run
+cd musetalk
+python simple_musetalk.py --avatar_id 4 --file D:\\ok\\test.mp4
+Both video and image input are supported; the avatar is generated automatically under the data/avatars directory
```

### 3.10 Using the wav2lip model

@@ -185,7 +182,7 @@ python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
- Download the models
Download the models needed to run wav2lip; netdisk link: https://drive.uc.cn/s/551be97d7cfa4
Copy s3fd.pth to wav2lip/face_detection/detection/sfd/s3fd.pth in this project, and copy wav2lip.pth into this project's models directory
-Digital-human model file wav2lip_avatar1.tar.gz; after extracting, copy the whole folder into this project's data/avatars
+Digital-human model file wav2lip_avatar1.tar.gz, netdisk link: https://drive.uc.cn/s/5bd0cde0b0774; after extracting, copy the whole folder into this project's data/avatars
- Run
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip_avatar1
Open http://serverip:8010/webrtcapi.html in a browser
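For reference, the avatar folder that the MuseTalk generation step is expected to leave under data/avatars can be sanity-checked with a short script. This is only a sketch: the file names are taken from the paths that MuseReal.__loadavatar reads in musereal.py further down in this diff, and the avatar id used here is just the app.py default.

```python
# Sketch: check that a generated MuseTalk avatar contains the files MuseReal expects.
# File names mirror MuseReal.__loadavatar (musereal.py below); the avatar id is an example.
import os

def check_avatar(avatar_id: str = "avator_1", root: str = "./data/avatars") -> None:
    avatar_path = os.path.join(root, avatar_id)
    expected = [
        "full_imgs",         # full-resolution frames
        "coords.pkl",        # face bounding boxes per frame
        "latents.pt",        # pre-computed VAE latents
        "mask",              # blending masks
        "mask_coords.pkl",   # mask crop boxes
        "avator_info.json",  # avatar metadata (spelling as in the repository)
    ]
    missing = [name for name in expected
               if not os.path.exists(os.path.join(avatar_path, name))]
    if missing:
        print(f"{avatar_path} is incomplete, missing: {missing}")
    else:
        print(f"{avatar_path} looks ready; run app.py with --model musetalk --avatar_id {avatar_id}")

if __name__ == "__main__":
    check_avatar()
```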

app.py

@@ -1,31 +1,39 @@
# server.py
-import argparse
-import asyncio
-import json
-import multiprocessing
-from threading import Thread, Event
-import aiohttp
-import aiohttp_cors
-from aiohttp import web
-from aiortc import RTCPeerConnection, RTCSessionDescription
-from flask import Flask
+from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
+import base64
+import time
+import json
+import gevent
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
-from musetalk.simple_musetalk import create_musetalk_human
+import os
+import re
+import numpy as np
+from threading import Thread,Event
+import multiprocessing
+from aiohttp import web
+import aiohttp
+import aiohttp_cors
+from aiortc import RTCPeerConnection, RTCSessionDescription
from webrtc import HumanPlayer
+import argparse
+import shutil
+import asyncio

app = Flask(__name__)
sockets = Sockets(app)
global nerfreal

@sockets.route('/humanecho')
def echo_socket(ws):
    # 获取WebSocket对象
-    # ws = request.environ.get('wsgi.websocket')
+    #ws = request.environ.get('wsgi.websocket')
    # 如果没有获取到,返回错误信息
    if not ws:
        print('未建立连接!')
@@ -34,11 +42,11 @@ def echo_socket(ws):
    else:
        print('建立连接!')

    while True:
        message = ws.receive()

-        if not message or len(message) == 0:
+        if not message or len(message)==0:
            return '输入信息为空'
        else:
            nerfreal.put_msg_txt(message)
@@ -46,16 +54,15 @@ def llm_response(message):
    from llm.LLM import LLM
    # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
    # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
-    llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b')
+    llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
    response = llm.chat(message)
    print(response)
    return response

@sockets.route('/humanchat')
def chat_socket(ws):
    # 获取WebSocket对象
-    # ws = request.environ.get('wsgi.websocket')
+    #ws = request.environ.get('wsgi.websocket')
    # 如果没有获取到,返回错误信息
    if not ws:
        print('未建立连接!')
@@ -64,20 +71,18 @@ def chat_socket(ws):
    else:
        print('建立连接!')

    while True:
        message = ws.receive()

-        if len(message) == 0:
+        if len(message)==0:
            return '输入信息为空'
        else:
-            res = llm_response(message)
+            res=llm_response(message)
            nerfreal.put_msg_txt(res)

#####webrtc###############################
pcs = set()

-# @app.route('/offer', methods=['POST'])
+#@app.route('/offer', methods=['POST'])
async def offer(request):
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -101,7 +106,7 @@ async def offer(request):
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

-    # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
+    #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})

    return web.Response(
        content_type="application/json",
@@ -110,61 +115,39 @@ async def offer(request):
        ),
    )

async def human(request):
    params = await request.json()

-    if params['type'] == 'echo':
+    if params.get('interrupt'):
+        nerfreal.pause_talk()
+
+    if params['type']=='echo':
        nerfreal.put_msg_txt(params['text'])
-    elif params['type'] == 'chat':
-        res = await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
+    elif params['type']=='chat':
+        res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
        nerfreal.put_msg_txt(res)

    return web.Response(
        content_type="application/json",
        text=json.dumps(
-            {"code": 0, "data": "ok"}
+            {"code": 0, "data":"ok"}
        ),
    )

-async def handle_create_musetalk(request):
-    reader = await request.multipart()
-    # 处理文件部分
-    file_part = await reader.next()
-    filename = file_part.filename
-    file_data = await file_part.read()  # 读取文件的内容
-    # 注意:确保这个文件路径是可写的
-    with open(filename, 'wb') as f:
-        f.write(file_data)
-    # 处理整数部分
-    part = await reader.next()
-    avatar_id = int(await part.text())
-    create_musetalk_human(filename, avatar_id)
-    os.remove(filename)
-    return web.json_response({
-        'status': 'success',
-        'filename': filename,
-        'int_value': avatar_id,
-    })

async def on_shutdown(app):
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()

-async def post(url, data):
+async def post(url,data):
    try:
        async with aiohttp.ClientSession() as session:
-            async with session.post(url, data=data) as response:
+            async with session.post(url,data=data) as response:
                return await response.text()
    except aiohttp.ClientError as e:
        print(f'Error: {e}')

async def run(push_url):
    pc = RTCPeerConnection()
    pcs.add(pc)
@@ -181,10 +164,8 @@ async def run(push_url):
    video_sender = pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
-    answer = await post(push_url, pc.localDescription.sdp)
-    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
+    answer = await post(push_url,pc.localDescription.sdp)
+    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
@@ -203,20 +184,14 @@ if __name__ == '__main__':
    ### training options
    parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
-    parser.add_argument('--num_rays', type=int, default=4096 * 16,
-                        help="num rays sampled per image for each training step")
+    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
-    parser.add_argument('--max_steps', type=int, default=16,
-                        help="max num steps sampled per ray (only valid when using --cuda_ray)")
-    parser.add_argument('--num_steps', type=int, default=16,
-                        help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
-    parser.add_argument('--upsample_steps', type=int, default=0,
-                        help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
-    parser.add_argument('--update_extra_interval', type=int, default=16,
-                        help="iter interval to update extra status (only valid when using --cuda_ray)")
-    parser.add_argument('--max_ray_batch', type=int, default=4096,
-                        help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
+    parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
+    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

    ### loss set
    parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
@@ -227,35 +202,27 @@ if __name__ == '__main__':
    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")

    parser.add_argument('--bg_img', type=str, default='white', help="background image")
    parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
    parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
-    parser.add_argument('--fix_eye', type=float, default=-1,
-                        help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
+    parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
    parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
-    parser.add_argument('--torso_shrink', type=float, default=0.8,
-                        help="shrink bg coords to allow more flexibility in deform")
+    parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
-    parser.add_argument('--preload', type=int, default=0,
-                        help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
+    parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
    # (the default value is for the fox dataset)
-    parser.add_argument('--bound', type=float, default=1,
-                        help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
+    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
-    parser.add_argument('--dt_gamma', type=float, default=1 / 256,
-                        help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+    parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
-    parser.add_argument('--density_thresh', type=float, default=10,
-                        help="threshold for density grid to be occupied (sigma)")
-    parser.add_argument('--density_thresh_torso', type=float, default=0.01,
-                        help="threshold for density grid to be occupied (alpha)")
-    parser.add_argument('--patch_size', type=int, default=1,
-                        help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
+    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
+    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
+    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

    parser.add_argument('--init_lips', action='store_true', help="init lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
@@ -273,15 +240,12 @@ if __name__ == '__main__':
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    ### else
-    parser.add_argument('--att', type=int, default=2,
-                        help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
-    parser.add_argument('--aud', type=str, default='',
-                        help="audio source (empty will load the default, else should be a path to a npy file)")
+    parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
+    parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
    parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")

    parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
-    parser.add_argument('--ind_num', type=int, default=10000,
-                        help="number of individual codes, should be larger than training dataset size")
+    parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")

    parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
@@ -290,8 +254,7 @@ if __name__ == '__main__':
    parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")

    parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
-    parser.add_argument('--smooth_path', action='store_true',
-                        help="brute-force smooth camera pose trajectory with a window size")
+    parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

    # asr
@@ -299,8 +262,8 @@ if __name__ == '__main__':
    parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
    parser.add_argument('--asr_play', action='store_true', help="play out the audio")

-    # parser.add_argument('--asr_model', type=str, default='deepspeech')
+    #parser.add_argument('--asr_model', type=str, default='deepspeech')
    parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #
    # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
    # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
@@ -319,45 +282,42 @@ if __name__ == '__main__':
    parser.add_argument('--fullbody_offset_x', type=int, default=0)
    parser.add_argument('--fullbody_offset_y', type=int, default=0)

-    # musetalk opt
+    #musetalk opt
    parser.add_argument('--avatar_id', type=str, default='avator_1')
    parser.add_argument('--bbox_shift', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=16)

    parser.add_argument('--customvideo', action='store_true', help="custom video")
-    parser.add_argument('--static_img', action='store_true', help="Use the first photo as a time of rest")
    parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
    parser.add_argument('--customvideo_imgnum', type=int, default=1)

-    parser.add_argument('--tts', type=str, default='edgetts') # xtts gpt-sovits
+    parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits
    parser.add_argument('--REF_FILE', type=str, default=None)
    parser.add_argument('--REF_TEXT', type=str, default=None)
    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
    # parser.add_argument('--CHARACTER', type=str, default='test')
    # parser.add_argument('--EMOTION', type=str, default='default')

-    parser.add_argument('--model', type=str, default='ernerf') # musetalk wav2lip
+    parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip

-    parser.add_argument('--transport', type=str, default='rtcpush') # rtmp webrtc rtcpush
-    parser.add_argument('--push_url', type=str,
-                        default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream
+    parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
+    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream

    parser.add_argument('--listenport', type=int, default=8010)

    opt = parser.parse_args()
-    # app.config.from_object(opt)
-    # print(app.config)
+    #app.config.from_object(opt)
+    #print(app.config)

    if opt.model == 'ernerf':
        from ernerf.nerf_triplane.provider import NeRFDataset_Test
        from ernerf.nerf_triplane.utils import *
        from ernerf.nerf_triplane.network import NeRFNetwork
        from nerfreal import NeRFReal
        # assert test mode
        opt.test = True
        opt.test_train = False
-        # opt.train_camera =True
+        #opt.train_camera =True
        # explicit smoothing
        opt.smooth_path = True
        opt.smooth_lips = True
@@ -370,7 +330,7 @@ if __name__ == '__main__':
        opt.exp_eye = True
        opt.smooth_eye = True

-        if opt.torso_imgs == '': # no img,use model output
+        if opt.torso_imgs=='': #no img,use model output
            opt.torso = True

        # assert opt.cuda_ray, "Only support CUDA ray mode."
@@ -386,10 +346,9 @@ if __name__ == '__main__':
        model = NeRFNetwork(opt)
        criterion = torch.nn.MSELoss(reduction='none')
        metrics = [] # use no metric in GUI for faster initialization...
        print(model)
-        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16,
-                          metrics=metrics, use_checkpoint=opt.ckpt)
+        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

        test_loader = NeRFDataset_Test(opt, device=device).dataloader()
        model.aud_features = test_loader._data.auds
@@ -399,19 +358,17 @@ if __name__ == '__main__':
        nerfreal = NeRFReal(opt, trainer, test_loader)
    elif opt.model == 'musetalk':
        from musereal import MuseReal
        print(opt)
        nerfreal = MuseReal(opt)
    elif opt.model == 'wav2lip':
        from lipreal import LipReal
        print(opt)
        nerfreal = LipReal(opt)

-    # txt_to_audio('我是中国人,我来自北京')
+    #txt_to_audio('我是中国人,我来自北京')
-    if opt.transport == 'rtmp':
+    if opt.transport=='rtmp':
        thread_quit = Event()
-        rendthrd = Thread(target=nerfreal.render, args=(thread_quit,))
+        rendthrd = Thread(target=nerfreal.render,args=(thread_quit,))
        rendthrd.start()

    #############################################################################
@@ -419,37 +376,35 @@ if __name__ == '__main__':
    appasync.on_shutdown.append(on_shutdown)
    appasync.router.add_post("/offer", offer)
    appasync.router.add_post("/human", human)
-    appasync.router.add_post("/create_musetalk", handle_create_musetalk)
-    appasync.router.add_static('/', path='web')
+    appasync.router.add_static('/',path='web')

    # Configure default CORS settings.
    cors = aiohttp_cors.setup(appasync, defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
    })
    # Configure CORS on all routes.
    for route in list(appasync.router.routes()):
        cors.add(route)

    def run_server(runner):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
        loop.run_until_complete(site.start())
-        if opt.transport == 'rtcpush':
+        if opt.transport=='rtcpush':
            loop.run_until_complete(run(opt.push_url))
        loop.run_forever()
    Thread(target=run_server, args=(web.AppRunner(appasync),)).start()

    print('start websocket server')
-    # app.on_shutdown.append(on_shutdown)
-    # app.router.add_post("/offer", offer)
+    #app.on_shutdown.append(on_shutdown)
+    #app.router.add_post("/offer", offer)
    server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
    server.serve_forever()
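As a usage sketch for the reworked /human route above: the handler reads a JSON body with a type of 'echo' or 'chat', a text field, and an optional interrupt flag that calls nerfreal.pause_talk() before new text is queued. The field names and the 8010 default port come from app.py; the host address and the example sentences are illustrative only.

```python
# Sketch: exercise the /human endpoint defined above (field names taken from the handler).
import json
import urllib.request

def human(payload: dict, host: str = "http://127.0.0.1:8010") -> dict:
    req = urllib.request.Request(
        f"{host}/human",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())

# Speak a fixed sentence (echo mode).
print(human({"type": "echo", "text": "Hello, I am the streaming avatar."}))

# Cut off whatever is currently playing, then answer via the LLM (chat mode).
print(human({"type": "chat", "text": "Please introduce yourself.", "interrupt": True}))
```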

asrreal.py

@@ -4,29 +4,19 @@ import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel

-#import pyaudio
-import soundfile as sf
-import resampy

import queue
from queue import Queue
#from collections import deque
from threading import Thread, Event
-from io import BytesIO

-class ASR:
+from baseasr import BaseASR
+
+class ASR(BaseASR):
    def __init__(self, opt):
+        super().__init__(opt)

-        self.opt = opt
-        self.play = opt.asr_play #false
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.mode = 'live' if opt.asr_wav == '' else 'file'

        if 'esperanto' in self.opt.asr_model:
            self.audio_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
@@ -41,30 +31,11 @@ class ASR:
        self.context_size = opt.m
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r
-        self.text = '[START]\n'
-        self.terminated = False
-        self.frames = []
-        self.inwarm = False

        # pad left frames
        if self.stride_left_size > 0:
            self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)

-        self.exit_event = Event()
-        #self.audio_instance = pyaudio.PyAudio() #not need
-        # create input stream
-        self.queue = Queue()
-        self.output_queue = Queue()
-        # start a background process to read frames
-        #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
-        #self.queue = Queue()
-        #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
-        # current location of audio
-        self.idx = 0

        # create wav2vec model
        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
        if 'hubert' in self.opt.asr_model:
@@ -74,10 +45,6 @@ class ASR:
            self.processor = AutoProcessor.from_pretrained(opt.asr_model)
            self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

-        # prepare to save logits
-        if self.opt.asr_save_feats:
-            self.all_feats = []

        # the extracted features
        # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
        self.feat_buffer_size = 4
@@ -93,8 +60,16 @@ class ASR:
        # warm up steps needed: mid + right + window_size + attention_size
        self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3

-        self.listening = False
-        self.playing = False
+    def get_audio_frame(self):
+        try:
+            frame = self.queue.get(block=False)
+            type = 0
+            #print(f'[INFO] get frame {frame.shape}')
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type = 1
+
+        return frame,type

    def get_next_feat(self): #get audio embedding to nerf
        # return a [1/8, 16] window, for the next input to nerf side.
@@ -136,29 +111,19 @@ class ASR:

    def run_step(self):

-        if self.terminated:
-            return

        # get a frame of audio
-        frame,type = self.__get_audio_frame()
-
-        # the last frame
-        if frame is None:
-            # terminate, but always run the network for the left frames
-            self.terminated = True
-        else:
-            self.frames.append(frame)
-            # put to output
-            self.output_queue.put((frame,type))
-            # context not enough, do not run network.
-            if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
-                return
+        frame,type = self.get_audio_frame()
+        self.frames.append(frame)
+        # put to output
+        self.output_queue.put((frame,type))
+        # context not enough, do not run network.
+        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+            return

        inputs = np.concatenate(self.frames) # [N * chunk]

        # discard the old part to save memory
-        if not self.terminated:
-            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

        #print(f'[INFO] frame_to_text... ')
        #t = time.time()
@@ -166,10 +131,6 @@ class ASR:
        #print(f'-------wav2vec time:{time.time()-t:.4f}s')
        feats = logits # better lips-sync than labels

-        # save feats
-        if self.opt.asr_save_feats:
-            self.all_feats.append(feats)

        # record the feats efficiently.. (no concat, constant memory)
        start = self.feat_buffer_idx * self.context_size
        end = start + feats.shape[0]
@@ -203,24 +164,6 @@ class ASR:
        # np.save(output_path, unfold_feats.cpu().numpy())
        # print(f"[INFO] saved logits to {output_path}")

-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        if self.inwarm: # warm up
-            return np.zeros(self.chunk, dtype=np.float32),1
-
-        try:
-            frame = self.queue.get(block=False)
-            type = 0
-            print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        self.idx = self.idx + self.chunk
-
-        return frame,type

    def __frame_to_text(self, frame):
@@ -241,8 +184,8 @@ class ASR:
        right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.

        # do not cut right if terminated.
-        if self.terminated:
-            right = logits.shape[1]
+        # if self.terminated:
+        #     right = logits.shape[1]

        logits = logits[:, left:right]
@@ -262,10 +205,23 @@ class ASR:

        return logits[0], None,None #predicted_ids[0], transcription # [N,]

-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
+    def warm_up(self):
+        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
+        t = time.time()
+        #for _ in range(self.stride_left_size):
+        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
+        for _ in range(self.warm_up_steps):
+            self.run_step()
+        #if torch.cuda.is_available():
+        #    torch.cuda.synchronize()
+        t = time.time() - t
+        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
+        #self.clear_queue()

+#####not used function#####################################
+'''
    def __init_queue(self):
        self.frames = []
        self.queue.queue.clear()
@@ -290,26 +246,6 @@ class ASR:
        if self.play:
            self.output_queue.queue.clear()

-    def warm_up(self):
-        #self.listen()
-        self.inwarm = True
-        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
-        t = time.time()
-        #for _ in range(self.stride_left_size):
-        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
-        for _ in range(self.warm_up_steps):
-            self.run_step()
-        #if torch.cuda.is_available():
-        #    torch.cuda.synchronize()
-        t = time.time() - t
-        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
-        self.inwarm = False
-        #self.clear_queue()

-#####not used function#####################################
    def listen(self):
        # start
        if self.mode == 'live' and not self.listening:
@@ -404,4 +340,5 @@ if __name__ == '__main__':
        raise ValueError("DeepSpeech features should not use this code to extract...")

    with ASR(opt) as asr:
        asr.run()
+'''

lipasr.py

@@ -6,60 +6,16 @@ import queue
from queue import Queue
import multiprocessing as mp

+from baseasr import BaseASR
from wav2lip import audio

-class LipASR:
-    def __init__(self, opt):
-        self.opt = opt
-
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
-
-        #self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-        #self.context_size = 10
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()
+class LipASR(BaseASR):

    def run_step(self):
        ############################################## extract audio feature ##############################################
        # get a frame of audio
        for _ in range(self.batch_size*2):
-            frame,type = self.__get_audio_frame()
+            frame,type = self.get_audio_frame()
            self.frames.append(frame)
            # put to output
            self.output_queue.put((frame,type))
@@ -89,7 +45,3 @@ class LipASR:

        # discard the old part to save memory
        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
-
-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)

lipreal.py

@@ -164,6 +164,7 @@ class LipReal:
        self.__loadavatar()

        self.asr = LipASR(opt)
+        self.asr.warm_up()
        if opt.tts == "edgetts":
            self.tts = EdgeTTS(opt,self)
        elif opt.tts == "gpt-sovits":
@@ -199,6 +200,10 @@ class LipReal:
    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
        self.asr.put_audio_frame(audio_chunk)

+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
+
    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -257,9 +262,12 @@ class LipReal:
            t = time.perf_counter()
            self.asr.run_step()

-            if video_track._queue.qsize()>=2*self.opt.batch_size:
+            # if video_track._queue.qsize()>=2*self.opt.batch_size:
+            #     print('sleep qsize=',video_track._queue.qsize())
+            #     time.sleep(0.04*video_track._queue.qsize()*0.8)
+            if video_track._queue.qsize()>=5:
                print('sleep qsize=',video_track._queue.qsize())
-                time.sleep(0.04*self.opt.batch_size*1.5)
+                time.sleep(0.04*video_track._queue.qsize()*0.8)

            # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
            # if delay > 0:
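The last hunk above replaces the fixed back-off (sleep 0.04 * batch_size * 1.5 once the video queue reaches twice the batch size) with one proportional to the queue depth: as soon as five or more frames are waiting, the render loop sleeps for roughly 80% of the queued playback time, at 40 ms per frame (see the #40ms comment nearby). A small illustration of that arithmetic, with the constants copied from the diff and the queue sizes chosen as examples:

```python
# Sketch: the backpressure rule introduced in lipreal.py above.
FRAME_SECONDS = 0.04  # one video frame is 40 ms (per the "#40ms" comment in the diff)
BACKOFF = 0.8         # sleep for ~80% of the queued playback time

def throttle_sleep(qsize: int) -> float:
    """How long the render loop pauses for a given video queue depth."""
    return FRAME_SECONDS * qsize * BACKOFF if qsize >= 5 else 0.0

for qsize in (2, 5, 10, 25):
    print(qsize, f"{throttle_sleep(qsize):.2f}s")  # 0.00s, 0.16s, 0.32s, 0.80s
```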

museasr.py

@@ -1,65 +1,22 @@
import time
-import torch
import numpy as np

import queue
from queue import Queue
import multiprocessing as mp

+from baseasr import BaseASR
from musetalk.whisper.audio2feature import Audio2Feature

-class MuseASR:
+class MuseASR(BaseASR):
    def __init__(self, opt, audio_processor:Audio2Feature):
-        self.opt = opt
-
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
+        super().__init__(opt)
        self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()

    def run_step(self):
        ############################################## extract audio feature ##############################################
        start_time = time.time()

        for _ in range(self.batch_size*2):
-            audio_frame,type=self.__get_audio_frame()
+            audio_frame,type=self.get_audio_frame()
            self.frames.append(audio_frame)
            self.output_queue.put((audio_frame,type))
@@ -77,6 +34,3 @@ class MuseASR:
        self.feat_queue.put(whisper_chunks)

        # discard the old part to save memory
        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
-
-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)
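ASR, LipASR and MuseASR above now all derive from a BaseASR class imported from baseasr, which is not part of the hunks shown here. Judging by the attributes and methods the subclasses dropped yet still use (the input/output queues, the frame buffer, put_audio_frame, get_audio_frame, get_audio_out, warm_up, get_next_feat, and the new pause_talk), it plausibly looks like the sketch below. This is a reconstruction under those assumptions, not the repository's actual baseasr.py.

```python
# Hypothetical reconstruction of baseasr.py, inferred from what the refactored
# subclasses in this diff still call. The real file may differ in details.
import queue
from queue import Queue
import multiprocessing as mp
import numpy as np

class BaseASR:
    def __init__(self, opt):
        self.opt = opt
        self.fps = opt.fps                          # 20 ms per frame at fps=50
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps   # samples per 20 ms chunk
        self.queue = Queue()                        # incoming PCM chunks from the TTS side
        self.output_queue = mp.Queue()              # (frame, type) handed to the renderer
        self.batch_size = opt.batch_size
        self.frames = []
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r
        self.feat_queue = mp.Queue(5)               # extracted audio features

    def pause_talk(self):
        # Used by the new interrupt feature; dropping pending audio is one plausible behaviour.
        self.queue.queue.clear()

    def put_audio_frame(self, audio_chunk):         # 16 kHz, 20 ms PCM
        self.queue.put(audio_chunk)

    def get_audio_frame(self):
        try:
            frame = self.queue.get(block=True, timeout=0.01)
            type = 0                                # real speech
        except queue.Empty:
            frame = np.zeros(self.chunk, dtype=np.float32)
            type = 1                                # silence filler
        return frame, type

    def get_audio_out(self):                        # original PCM for the audio track
        return self.output_queue.get()

    def warm_up(self):
        # Pre-fill the stride window so run_step() has enough left/right context.
        for _ in range(self.stride_left_size + self.stride_right_size):
            frame, type = self.get_audio_frame()
            self.frames.append(frame)
            self.output_queue.put((frame, type))
        for _ in range(self.stride_left_size):
            self.output_queue.get()

    def get_next_feat(self, block, timeout):
        return self.feat_queue.get(block, timeout)

    def run_step(self):
        raise NotImplementedError                   # implemented by ASR / LipASR / MuseASR
```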

musereal.py

@ -2,7 +2,7 @@ import math
import torch import torch
import numpy as np import numpy as np
# from .utils import * #from .utils import *
import subprocess import subprocess
import os import os
import time import time
@ -18,19 +18,17 @@ from threading import Thread, Event
from io import BytesIO from io import BytesIO
import multiprocessing as mp import multiprocessing as mp
from musetalk.utils.utils import get_file_type, get_video_fps, datagen from musetalk.utils.utils import get_file_type,get_video_fps,datagen
# from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder #from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
from musetalk.utils.utils import load_all_model, load_diffusion_model, load_audio_model from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
from ttsreal import EdgeTTS, VoitsTTS, XTTS from ttsreal import EdgeTTS,VoitsTTS,XTTS
from museasr import MuseASR from museasr import MuseASR
import asyncio import asyncio
from av import AudioFrame, VideoFrame from av import AudioFrame, VideoFrame
from tqdm import tqdm from tqdm import tqdm
def read_imgs(img_list): def read_imgs(img_list):
frames = [] frames = []
print('reading images...') print('reading images...')
@ -39,146 +37,143 @@ def read_imgs(img_list):
frames.append(frame) frames.append(frame)
return frames return frames
def __mirror_index(size, index): def __mirror_index(size, index):
# size = len(self.coord_list_cycle) #size = len(self.coord_list_cycle)
turn = index // size turn = index // size
res = index % size res = index % size
if turn % 2 == 0: if turn % 2 == 0:
return res return res
else: else:
return size - res - 1 return size - res - 1
def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue,
): # vae, unet, pe,timesteps
def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
): #vae, unet, pe,timesteps
vae, unet, pe = load_diffusion_model() vae, unet, pe = load_diffusion_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device) timesteps = torch.tensor([0], device=device)
pe = pe.half() pe = pe.half()
vae.vae = vae.vae.half() vae.vae = vae.vae.half()
unet.model = unet.model.half() unet.model = unet.model.half()
input_latent_list_cycle = torch.load(latents_out_path) input_latent_list_cycle = torch.load(latents_out_path)
length = len(input_latent_list_cycle) length = len(input_latent_list_cycle)
index = 0 index = 0
count = 0 count=0
counttime = 0 counttime=0
print('start inference') print('start inference')
while True: while True:
if render_event.is_set(): if render_event.is_set():
starttime = time.perf_counter() starttime=time.perf_counter()
try: try:
whisper_chunks = audio_feat_queue.get(block=True, timeout=1) whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty: except queue.Empty:
continue continue
is_all_silence = True is_all_silence=True
audio_frames = [] audio_frames = []
for _ in range(batch_size * 2): for _ in range(batch_size*2):
frame, type = audio_out_queue.get() frame,type = audio_out_queue.get()
audio_frames.append((frame, type)) audio_frames.append((frame,type))
if type == 0: if type==0:
is_all_silence = False is_all_silence=False
if is_all_silence: if is_all_silence:
for i in range(batch_size): for i in range(batch_size):
res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1 index = index + 1
else: else:
# print('infer=======') # print('infer=======')
t = time.perf_counter() t=time.perf_counter()
whisper_batch = np.stack(whisper_chunks) whisper_batch = np.stack(whisper_chunks)
latent_batch = [] latent_batch = []
for i in range(batch_size): for i in range(batch_size):
idx = __mirror_index(length, index + i) idx = __mirror_index(length,index+i)
latent = input_latent_list_cycle[idx] latent = input_latent_list_cycle[idx]
latent_batch.append(latent) latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0) latent_batch = torch.cat(latent_batch, dim=0)
# for i, (whisper_batch,latent_batch) in enumerate(gen): # for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch) audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device, audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype) dtype=unet.model.dtype)
audio_feature_batch = pe(audio_feature_batch) audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype) latent_batch = latent_batch.to(dtype=unet.model.dtype)
# print('prepare time:',time.perf_counter()-t) # print('prepare time:',time.perf_counter()-t)
# t=time.perf_counter() # t=time.perf_counter()
pred_latents = unet.model(latent_batch, pred_latents = unet.model(latent_batch,
timesteps, timesteps,
encoder_hidden_states=audio_feature_batch).sample encoder_hidden_states=audio_feature_batch).sample
# print('unet time:',time.perf_counter()-t) # print('unet time:',time.perf_counter()-t)
# t=time.perf_counter() # t=time.perf_counter()
recon = vae.decode_latents(pred_latents) recon = vae.decode_latents(pred_latents)
# print('vae time:',time.perf_counter()-t) # print('vae time:',time.perf_counter()-t)
# print('diffusion len=',len(recon)) #print('diffusion len=',len(recon))
counttime += (time.perf_counter() - t) counttime += (time.perf_counter() - t)
count += batch_size count += batch_size
# _totalframe += 1 #_totalframe += 1
if count >= 100: if count>=100:
print(f"------actual avg infer fps:{count / counttime:.4f}") print(f"------actual avg infer fps:{count/counttime:.4f}")
count = 0 count=0
counttime = 0 counttime=0
for i, res_frame in enumerate(recon): for i,res_frame in enumerate(recon):
# self.__pushmedia(res_frame,loop,audio_track,video_track) #self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])) res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1 index = index + 1
# print('total batch time:',time.perf_counter()-starttime) #print('total batch time:',time.perf_counter()-starttime)
else: else:
time.sleep(1) time.sleep(1)
print('musereal inference processor stop') print('musereal inference processor stop')
@torch.no_grad() @torch.no_grad()
class MuseReal: class MuseReal:
def __init__(self, opt): def __init__(self, opt):
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W self.W = opt.W
self.H = opt.H self.H = opt.H
self.fps = opt.fps # 20 ms per frame self.fps = opt.fps # 20 ms per frame
#### musetalk #### musetalk
self.avatar_id = opt.avatar_id self.avatar_id = opt.avatar_id
self.static_img = opt.static_img self.video_path = '' #video_path
self.video_path = '' # video_path
self.bbox_shift = opt.bbox_shift self.bbox_shift = opt.bbox_shift
self.avatar_path = f"./data/avatars/{self.avatar_id}" self.avatar_path = f"./data/avatars/{self.avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs" self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl" self.coords_path = f"{self.avatar_path}/coords.pkl"
self.latents_out_path = f"{self.avatar_path}/latents.pt" self.latents_out_path= f"{self.avatar_path}/latents.pt"
self.video_out_path = f"{self.avatar_path}/vid_output/" self.video_out_path = f"{self.avatar_path}/vid_output/"
self.mask_out_path = f"{self.avatar_path}/mask" self.mask_out_path =f"{self.avatar_path}/mask"
self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl" self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
self.avatar_info_path = f"{self.avatar_path}/avator_info.json" self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
self.avatar_info = { self.avatar_info = {
"avatar_id": self.avatar_id, "avatar_id":self.avatar_id,
"video_path": self.video_path, "video_path":self.video_path,
"bbox_shift": self.bbox_shift "bbox_shift":self.bbox_shift
} }
self.batch_size = opt.batch_size self.batch_size = opt.batch_size
self.idx = 0 self.idx = 0
self.res_frame_queue = mp.Queue(self.batch_size * 2) self.res_frame_queue = mp.Queue(self.batch_size*2)
self.__loadmodels() self.__loadmodels()
self.__loadavatar() self.__loadavatar()
self.asr = MuseASR(opt, self.audio_processor) self.asr = MuseASR(opt,self.audio_processor)
self.asr.warm_up()
if opt.tts == "edgetts": if opt.tts == "edgetts":
self.tts = EdgeTTS(opt, self) self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits": elif opt.tts == "gpt-sovits":
self.tts = VoitsTTS(opt, self) self.tts = VoitsTTS(opt,self)
elif opt.tts == "xtts": elif opt.tts == "xtts":
self.tts = XTTS(opt, self) self.tts = XTTS(opt,self)
# self.__warm_up() #self.__warm_up()
self.render_event = mp.Event() self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event, self.batch_size, self.latents_out_path, mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue, self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
)).start() # self.vae, self.unet, self.pe,self.timesteps )).start() #self.vae, self.unet, self.pe,self.timesteps
def __loadmodels(self): def __loadmodels(self):
# load model weights # load model weights
self.audio_processor = load_audio_model() self.audio_processor= load_audio_model()
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model() # self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.timesteps = torch.tensor([0], device=device) # self.timesteps = torch.tensor([0], device=device)
@ -187,7 +182,7 @@ class MuseReal:
# self.unet.model = self.unet.model.half() # self.unet.model = self.unet.model.half()
def __loadavatar(self): def __loadavatar(self):
# self.input_latent_list_cycle = torch.load(self.latents_out_path) #self.input_latent_list_cycle = torch.load(self.latents_out_path)
with open(self.coords_path, 'rb') as f: with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f) self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@ -198,13 +193,19 @@ class MuseReal:
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.mask_list_cycle = read_imgs(input_mask_list) self.mask_list_cycle = read_imgs(input_mask_list)
def put_msg_txt(self, msg):
def put_msg_txt(self,msg):
self.tts.put_msg_txt(msg) self.tts.put_msg_txt(msg)
def put_audio_frame(self, audio_chunk): # 16khz 20ms pcm def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk) self.asr.put_audio_frame(audio_chunk)
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def __mirror_index(self, index): def __mirror_index(self, index):
size = len(self.coord_list_cycle) size = len(self.coord_list_cycle)
turn = index // size turn = index // size
@ -212,15 +213,15 @@ class MuseReal:
if turn % 2 == 0: if turn % 2 == 0:
return res return res
else: else:
return size - res - 1 return size - res - 1
def __warm_up(self): def __warm_up(self):
self.asr.run_step() self.asr.run_step()
whisper_chunks = self.asr.get_next_feat() whisper_chunks = self.asr.get_next_feat()
whisper_batch = np.stack(whisper_chunks) whisper_batch = np.stack(whisper_chunks)
latent_batch = [] latent_batch = []
for i in range(self.batch_size): for i in range(self.batch_size):
idx = self.__mirror_index(self.idx + i) idx = self.__mirror_index(self.idx+i)
latent = self.input_latent_list_cycle[idx] latent = self.input_latent_list_cycle[idx]
latent_batch.append(latent) latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0) latent_batch = torch.cat(latent_batch, dim=0)
@@ -228,88 +229,90 @@ class MuseReal:
        # for i, (whisper_batch,latent_batch) in enumerate(gen):
        audio_feature_batch = torch.from_numpy(whisper_batch)
        audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
                                                     dtype=self.unet.model.dtype)
        audio_feature_batch = self.pe(audio_feature_batch)
        latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
        pred_latents = self.unet.model(latent_batch,
                                       self.timesteps,
                                       encoder_hidden_states=audio_feature_batch).sample
        recon = self.vae.decode_latents(pred_latents)
    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
        while not quit_event.is_set():
            try:
                res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            if audio_frames[0][1]==1 and audio_frames[1][1]==1: #全为静音数据只需要取fullimg
-               if self.static_img:
-                   combine_frame = self.frame_list_cycle[0]
-               else:
-                   combine_frame = self.frame_list_cycle[idx]
+               combine_frame = self.frame_list_cycle[idx]
            else:
                bbox = self.coord_list_cycle[idx]
                ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
                x1, y1, x2, y2 = bbox
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
                except:
                    continue
                mask = self.mask_list_cycle[idx]
                mask_crop_box = self.mask_coords_list_cycle[idx]
                #combine_frame = get_image(ori_frame,res_frame,bbox)
                #t=time.perf_counter()
                combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
                #print('blending time:',time.perf_counter()-t)

            image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
            new_frame = VideoFrame.from_ndarray(image, format="bgr24")
            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)

            for audio_frame in audio_frames:
                frame,type = audio_frame
                frame = (frame * 32767).astype(np.int16)
                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                new_frame.planes[0].update(frame.tobytes())
                new_frame.sample_rate=16000
                # if audio_track._queue.qsize()>10:
                #     time.sleep(0.1)
                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
        print('musereal process_frames thread stop')
    def render(self,quit_event,loop=None,audio_track=None,video_track=None):
        #if self.opt.asr:
        #     self.asr.warm_up()
        self.tts.render(quit_event)
        process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
        process_thread.start()

        self.render_event.set() #start infer process render
        count=0
        totaltime=0
        _starttime=time.perf_counter()
        #_totalframe=0
        while not quit_event.is_set(): #todo
            # update texture every frame
            # audio stream thread...
            t = time.perf_counter()
            self.asr.run_step()
            #self.test_step(loop,audio_track,video_track)
            # totaltime += (time.perf_counter() - t)
            # count += self.opt.batch_size
            # if count>=100:
            #     print(f"------actual avg infer fps:{count/totaltime:.4f}")
            #     count=0
            #     totaltime=0
-           if video_track._queue.qsize() >= 2 * self.opt.batch_size:
+           if video_track._queue.qsize()>=1.5*self.opt.batch_size:
                print('sleep qsize=',video_track._queue.qsize())
-               time.sleep(0.04 * self.opt.batch_size * 1.5)
+               time.sleep(0.04*video_track._queue.qsize()*0.8)
+           # if video_track._queue.qsize()>=5:
+           #     print('sleep qsize=',video_track._queue.qsize())
+           #     time.sleep(0.04*video_track._queue.qsize()*0.8)
            # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
            # if delay > 0:
            #     time.sleep(delay)
        self.render_event.clear() #end infer process render
        print('musereal thread stop')
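The `render` loop above applies simple backpressure: once the WebRTC video queue holds more than roughly 1.5 batches of frames, inference sleeps for a time proportional to the backlog (one frame covers 40 ms at 25 fps). A minimal sketch of that rule with stand-in values, not the project's objects:

```python
# Stand-in sketch of the queue backpressure rule used in render() above.
import time

FRAME_INTERVAL = 0.04  # 25 fps video -> 40 ms per frame

def backpressure_delay(qsize: int, batch_size: int) -> float:
    """How long the inference loop should sleep for a given video-queue depth."""
    if qsize >= 1.5 * batch_size:
        return FRAME_INTERVAL * qsize * 0.8  # sleep a bit less than the queued playback time
    return 0.0

# Example: batch_size=16 and 30 queued frames -> sleep ~0.96 s before inferring more.
delay = backpressure_delay(qsize=30, batch_size=16)
if delay:
    time.sleep(delay)
```
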


@@ -7,14 +7,15 @@ from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms

class FaceParsing():
    def __init__(self,resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='./models/face-parse-bisent/79999_iter.pth'):
        self.net = self.model_init(resnet_path,model_pth)
        self.preprocess = self.image_preprocess()

    def model_init(self,
                   resnet_path,
                   model_pth):
        net = BiSeNet(resnet_path)
        if torch.cuda.is_available():
            net.cuda()
@@ -44,13 +45,13 @@ class FaceParsing():
        img = torch.unsqueeze(img, 0)
        out = self.net(img)[0]
        parsing = out.squeeze(0).cpu().numpy().argmax(0)
        parsing[np.where(parsing>13)] = 0
        parsing[np.where(parsing>=1)] = 255
        parsing = Image.fromarray(parsing.astype(np.uint8))
        return parsing

if __name__ == "__main__":
    fp = FaceParsing()
    segmap = fp('154_small.png')
    segmap.save('res.png')
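The thresholding in the hunk above turns BiSeNet's per-pixel class labels into a binary mask: any label above 13 (neck/cloth/hair-type classes in the usual face-parsing label set, an assumption not stated in this diff) is cleared, and every remaining labelled pixel becomes white. A tiny sketch of that post-processing on fake data:

```python
# Sketch of the label-map -> binary-mask step above, run on a fake 2x3 label map.
import numpy as np
from PIL import Image

parsing = np.array([[0, 1, 5], [13, 14, 17]])  # pretend argmax output from BiSeNet
parsing[np.where(parsing > 13)] = 0    # drop classes above 13 (assumed non-face regions)
parsing[np.where(parsing >= 1)] = 255  # any remaining labelled pixel -> white
mask = Image.fromarray(parsing.astype(np.uint8))
print(np.array(mask))  # [[  0 255 255] [255   0   0]]
```
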


@@ -20,9 +20,6 @@ class NeRFReal:
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
-       self.debug = debug
-       self.training = False
-       self.step = 0 # training step

        self.trainer = trainer
        self.data_loader = data_loader
@@ -44,7 +41,6 @@ class NeRFReal:
        #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()

        # playing seq from dataloader, or pause.
-       self.playing = True #False todo
        self.loader = iter(data_loader)

        #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
@@ -62,9 +58,8 @@ class NeRFReal:
        self.customimg_index = 0

        # build asr
-       if self.opt.asr:
-           self.asr = ASR(opt)
-           self.asr.warm_up()
+       self.asr = ASR(opt)
+       self.asr.warm_up()

        if opt.tts == "edgetts":
            self.tts = EdgeTTS(opt,self)
        elif opt.tts == "gpt-sovits":
@@ -122,7 +117,11 @@ class NeRFReal:
        self.tts.put_msg_txt(msg)

    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
        self.asr.put_audio_frame(audio_chunk)

+   def pause_talk(self):
+       self.tts.pause_talk()
+       self.asr.pause_talk()

    def mirror_index(self, index):
@@ -248,10 +247,9 @@ class NeRFReal:
            # update texture every frame
            # audio stream thread...
            t = time.perf_counter()
-           if self.opt.asr and self.playing:
-               # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
-               for _ in range(2):
-                   self.asr.run_step()
+           # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
+           for _ in range(2):
+               self.asr.run_step()
            self.test_step(loop,audio_track,video_track)
            totaltime += (time.perf_counter() - t)
            count += 1
@@ -267,7 +265,7 @@ class NeRFReal:
            else:
                if video_track._queue.qsize()>=5:
                    #print('sleep qsize=',video_track._queue.qsize())
-                   time.sleep(0.1)
+                   time.sleep(0.04*video_track._queue.qsize()*0.8)
        print('nerfreal thread stop')
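As the comment in the hunk above notes, ASR features arrive at 50 Hz (16 kHz PCM in 20 ms chunks) while video renders at 25 fps, which is why the loop calls `self.asr.run_step()` twice per frame; a quick sanity check of that ratio with stand-in constants:

```python
# Stand-in arithmetic behind "run 2 ASR steps (audio is at 50FPS, video is at 25FPS)".
AUDIO_CHUNK_MS = 20   # 16 kHz PCM delivered as 20 ms chunks -> 50 chunks per second
VIDEO_FRAME_MS = 40   # 25 fps -> 40 ms per rendered frame
steps_per_frame = VIDEO_FRAME_MS // AUDIO_CHUNK_MS
assert steps_per_frame == 2
```
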


@@ -13,6 +13,11 @@ import queue
from queue import Queue
from io import BytesIO
from threading import Thread, Event
+from enum import Enum
+
+class State(Enum):
+    RUNNING=0
+    PAUSE=1

class BaseTTS:
    def __init__(self, opt, parent):
@@ -25,6 +30,11 @@ class BaseTTS:
        self.input_stream = BytesIO()
        self.msgqueue = Queue()
+       self.state = State.RUNNING

+   def pause_talk(self):
+       self.msgqueue.queue.clear()
+       self.state = State.PAUSE

    def put_msg_txt(self,msg):
        self.msgqueue.put(msg)
@@ -37,6 +47,7 @@ class BaseTTS:
        while not quit_event.is_set():
            try:
                msg = self.msgqueue.get(block=True, timeout=1)
+               self.state=State.RUNNING
            except queue.Empty:
                continue
            self.txt_to_audio(msg)
@@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS):
        stream = self.__create_bytes_stream(self.input_stream)
        streamlen = stream.shape[0]
        idx=0
-       while streamlen >= self.chunk:
+       while streamlen >= self.chunk and self.state==State.RUNNING:
            self.parent.put_audio_frame(stream[idx:idx+self.chunk])
            streamlen -= self.chunk
            idx += self.chunk
@@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS):
        async for chunk in communicate.stream():
            if first:
                first = False
-           if chunk["type"] == "audio":
+           if chunk["type"] == "audio" and self.state==State.RUNNING:
                #self.push_audio(chunk["data"])
                self.input_stream.write(chunk["data"])
                #file.write(chunk["data"])
@@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS):
                end = time.perf_counter()
                print(f"gpt_sovits Time to first chunk: {end-start}s")
                first = False
-           if chunk:
+           if chunk and self.state==State.RUNNING:
                yield chunk
        print("gpt_sovits response.elapsed:", res.elapsed)


@@ -29,22 +29,22 @@
        $(document).ready(function() {
            var host = window.location.hostname
-           var ws = new WebSocket("ws://"+host+":8000/humanchat");
-           //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
-           ws.onopen = function() {
-               console.log('Connected');
-           };
-           ws.onmessage = function(e) {
-               console.log('Received: ' + e.data);
-               data = e
-               var vid = JSON.parse(data.data);
-               console.log(typeof(vid),vid)
-               //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
-           };
-           ws.onclose = function(e) {
-               console.log('Closed');
-           };
+           // var ws = new WebSocket("ws://"+host+":8000/humanecho");
+           // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
+           // ws.onopen = function() {
+           //     console.log('Connected');
+           // };
+           // ws.onmessage = function(e) {
+           //     console.log('Received: ' + e.data);
+           //     data = e
+           //     var vid = JSON.parse(data.data);
+           //     console.log(typeof(vid),vid)
+           //     //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
+           // };
+           // ws.onclose = function(e) {
+           //     console.log('Closed');
+           // };

            flvPlayer = mpegts.createPlayer({type: 'flv', url: "http://"+host+":8080/live/livestream.flv", isLive: true, enableStashBuffer: false});
            flvPlayer.attachMediaElement(document.getElementById('video_player'));
@@ -55,9 +55,19 @@
                e.preventDefault();
                var message = $('#message').val();
                console.log('Sending: ' + message);
-               ws.send(message);
+               fetch('/human', {
+                   body: JSON.stringify({
+                       text: message,
+                       type: 'chat',
+                   }),
+                   headers: {
+                       'Content-Type': 'application/json'
+                   },
+                   method: 'POST'
+               });
+               //ws.send(message);
                $('#message').val('');
            });
        });
    </script>
</html>
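All of the pages above now POST to a `/human` route instead of pushing text over the old WebSocket. That handler lives in app.py and is not part of this excerpt, so the following is only a hypothetical sketch of what such an aiohttp route could look like, wired to the `pause_talk`/`put_msg_txt` methods added earlier; `llm_response` is an invented placeholder and the response shape is an assumption:

```python
# Hypothetical sketch of a /human route; not the actual app.py implementation.
import json
from aiohttp import web

def llm_response(text: str) -> str:
    return text  # placeholder: a real setup would call a dialogue/LLM backend here

async def human(request):
    params = await request.json()
    if params.get('interrupt'):
        nerfreal.pause_talk()                  # nerfreal: the global real-time instance in app.py
    if params.get('type') == 'echo':
        nerfreal.put_msg_txt(params['text'])   # speak the text verbatim
    else:  # 'chat'
        nerfreal.put_msg_txt(llm_response(params['text']))
    return web.Response(content_type="application/json",
                        text=json.dumps({"code": 0, "msg": "ok"}))

# registered alongside the other aiohttp routes, e.g.:
# <aiohttp_app>.router.add_post("/human", human)
```
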


@@ -51,30 +51,40 @@
    <script type="text/javascript" charset="utf-8">
        $(document).ready(function() {
-           var host = window.location.hostname
-           var ws = new WebSocket("ws://"+host+":8000/humanchat");
-           //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
-           ws.onopen = function() {
-               console.log('Connected');
-           };
-           ws.onmessage = function(e) {
-               console.log('Received: ' + e.data);
-               data = e
-               var vid = JSON.parse(data.data);
-               console.log(typeof(vid),vid)
-               //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
-           };
-           ws.onclose = function(e) {
-               console.log('Closed');
-           };
+           // var host = window.location.hostname
+           // var ws = new WebSocket("ws://"+host+":8000/humanecho");
+           // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
+           // ws.onopen = function() {
+           //     console.log('Connected');
+           // };
+           // ws.onmessage = function(e) {
+           //     console.log('Received: ' + e.data);
+           //     data = e
+           //     var vid = JSON.parse(data.data);
+           //     console.log(typeof(vid),vid)
+           //     //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
+           // };
+           // ws.onclose = function(e) {
+           //     console.log('Closed');
+           // };

            $('#echo-form').on('submit', function(e) {
                e.preventDefault();
                var message = $('#message').val();
                console.log('Sending: ' + message);
-               ws.send(message);
-               $('#message').val('');
+               fetch('/human', {
+                   body: JSON.stringify({
+                       text: message,
+                       type: 'chat',
+                   }),
+                   headers: {
+                       'Content-Type': 'application/json'
+                   },
+                   method: 'POST'
+               });
+               //ws.send(message);
+               $('#message').val('');
            });
        });


@@ -79,6 +79,7 @@
                body: JSON.stringify({
                    text: message,
                    type: 'echo',
+                   interrupt: true,
                }),
                headers: {
                    'Content-Type': 'application/json'


@@ -53,30 +53,41 @@
    <script type="text/javascript" charset="utf-8">
        $(document).ready(function() {
-           var host = window.location.hostname
-           var ws = new WebSocket("ws://"+host+":8000/humanchat");
-           //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
-           ws.onopen = function() {
-               console.log('Connected');
-           };
-           ws.onmessage = function(e) {
-               console.log('Received: ' + e.data);
-               data = e
-               var vid = JSON.parse(data.data);
-               console.log(typeof(vid),vid)
-               //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
-           };
-           ws.onclose = function(e) {
-               console.log('Closed');
-           };
+           // var host = window.location.hostname
+           // var ws = new WebSocket("ws://"+host+":8000/humanecho");
+           // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
+           // ws.onopen = function() {
+           //     console.log('Connected');
+           // };
+           // ws.onmessage = function(e) {
+           //     console.log('Received: ' + e.data);
+           //     data = e
+           //     var vid = JSON.parse(data.data);
+           //     console.log(typeof(vid),vid)
+           //     //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
+           // };
+           // ws.onclose = function(e) {
+           //     console.log('Closed');
+           // };

            $('#echo-form').on('submit', function(e) {
                e.preventDefault();
                var message = $('#message').val();
                console.log('Sending: ' + message);
-               ws.send(message);
-               $('#message').val('');
+               fetch('/human', {
+                   body: JSON.stringify({
+                       text: message,
+                       type: 'chat',
+                       interrupt: true,
+                   }),
+                   headers: {
+                       'Content-Type': 'application/json'
+                   },
+                   method: 'POST'
+               });
+               //ws.send(message);
+               $('#message').val('');
            });
        });
    </script>
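For completeness, the same `/human` API these pages call can also be driven from Python; a small client sketch using `requests`, where the host, port, and response handling are assumptions rather than anything defined in this diff:

```python
# Minimal client sketch for the /human endpoint used by the pages above (assumed host/port).
import requests

SERVER = "http://serverip:8010"  # placeholder, as in the README examples

def send_human(text: str, msg_type: str = "chat", interrupt: bool = True) -> None:
    resp = requests.post(f"{SERVER}/human",
                         json={"text": text, "type": msg_type, "interrupt": interrupt})
    resp.raise_for_status()

# 'echo' speaks the text verbatim, 'chat' goes through the dialogue path,
# and interrupt=True asks the server to cut off whatever is currently being spoken.
send_human("你好,请介绍一下你自己", msg_type="chat", interrupt=True)
```
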