feat: refine the switch to automatically resolved absolute paths; add a generation API endpoint
This commit is contained in:
Parent: 18d7db35a7
Commit: cd7d5f31b5
README.md (17 lines changed)

@@ -6,11 +6,10 @@ Real time interactive streaming digital human, realize audio video synchronous
 ## Features
 1. Supports multiple digital-human models: ernerf, musetalk, wav2lip
 2. Supports voice cloning
-3. Supports multiple audio feature extractors: wav2vec, hubert
+3. Supports interrupting the digital human while it is speaking
 4. Supports full-body video stitching
 5. Supports rtmp and webrtc
 6. Supports video orchestration: plays a custom video while not speaking
-7. Supports LLM dialogue

 ## 1. Installation

@@ -171,13 +170,11 @@ cd MuseTalk
 Edit configs/inference/realtime.yaml and set preparation to True
 python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
 After it finishes, copy the files under results/avatars into this project's data/avatars
-```
-```bash
-You can also try simple_musetalk.py in the local directory
-cd musetalk
-python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
-After it runs, the avatar is generated directly under data/avatars
+Method 2
+Run
+cd musetalk
+python simple_musetalk.py --avatar_id 4 --file D:\\ok\\test.mp4
+Both video and image inputs are supported; the avatar is generated automatically under data/avatars
 ```

 ### 3.10 Using the wav2lip model

@@ -185,7 +182,7 @@ python simple_musetalk.py --avatar_id 2 --file D:\\ok\\test.mp4
 - Download the models
 Download the models needed by wav2lip from the network drive: https://drive.uc.cn/s/551be97d7cfa4
 Copy s3fd.pth to wav2lip/face_detection/detection/sfd/s3fd.pth in this project, and copy wav2lip.pth into this project's models directory
-Digital-human model file wav2lip_avatar1.tar.gz: after extracting it, copy the whole folder into this project's data/avatars
+Digital-human model file wav2lip_avatar1.tar.gz, network drive: https://drive.uc.cn/s/5bd0cde0b0774; after extracting it, copy the whole folder into this project's data/avatars
 - Run
 python app.py --transport webrtc --model wav2lip --avatar_id wav2lip_avatar1
 Open http://serverip:8010/webrtcapi.html in a browser
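The installation steps above end with starting app.py and opening the demo page in a browser. For a quick scripted check, the sketch below sends one line of text to the server's /human endpoint, whose `type`/`text` fields appear in the app.py changes later in this commit; the host address, the default port 8010 (from --listenport) and the use of the `requests` package are assumptions for illustration, not part of the repository.

```python
# Hypothetical smoke test: make a running server speak one sentence.
import requests  # assumed to be installed; not a dependency of this repo

SERVER = "http://127.0.0.1:8010"   # --listenport defaults to 8010

resp = requests.post(
    f"{SERVER}/human",
    json={
        "type": "echo",            # 'echo' speaks the text as-is, 'chat' routes it through the LLM
        "text": "你好,欢迎体验数字人",
    },
    timeout=10,
)
print(resp.status_code, resp.text)  # the handler replies {"code": 0, "data": "ok"}
```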
app.py (235 lines changed)

@@ -1,31 +1,39 @@
 # server.py
-import argparse
-import asyncio
-import json
-import multiprocessing
-from threading import Thread, Event
+from flask import Flask, render_template,send_from_directory,request, jsonify

-import aiohttp
-import aiohttp_cors
-from aiohttp import web
-from aiortc import RTCPeerConnection, RTCSessionDescription
-from flask import Flask
 from flask_sockets import Sockets
+import base64
+import time
+import json
+import gevent
 from gevent import pywsgi
 from geventwebsocket.handler import WebSocketHandler
+import os
+import re
+import numpy as np
+from threading import Thread,Event
+import multiprocessing

-from musetalk.simple_musetalk import create_musetalk_human
+from aiohttp import web
+import aiohttp
+import aiohttp_cors
+from aiortc import RTCPeerConnection, RTCSessionDescription
 from webrtc import HumanPlayer

+import argparse
+
+import shutil
+import asyncio


 app = Flask(__name__)
 sockets = Sockets(app)
 global nerfreal


 @sockets.route('/humanecho')
 def echo_socket(ws):
     # 获取WebSocket对象
-    # ws = request.environ.get('wsgi.websocket')
+    #ws = request.environ.get('wsgi.websocket')
     # 如果没有获取到,返回错误信息
     if not ws:
         print('未建立连接!')

@@ -34,11 +42,11 @@ def echo_socket(ws):
     else:
         print('建立连接!')
         while True:
             message = ws.receive()

-            if not message or len(message) == 0:
+            if not message or len(message)==0:
                 return '输入信息为空'
             else:
                 nerfreal.put_msg_txt(message)


@@ -46,16 +54,15 @@ def llm_response(message):
     from llm.LLM import LLM
     # llm = LLM().init_model('Gemini', model_path= 'gemini-pro',api_key='Your API Key', proxy_url=None)
     # llm = LLM().init_model('ChatGPT', model_path= 'gpt-3.5-turbo',api_key='Your API Key')
-    llm = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b')
+    llm = LLM().init_model('VllmGPT', model_path= 'THUDM/chatglm3-6b')
     response = llm.chat(message)
     print(response)
     return response


 @sockets.route('/humanchat')
 def chat_socket(ws):
     # 获取WebSocket对象
-    # ws = request.environ.get('wsgi.websocket')
+    #ws = request.environ.get('wsgi.websocket')
     # 如果没有获取到,返回错误信息
     if not ws:
         print('未建立连接!')

@@ -64,20 +71,18 @@ def chat_socket(ws):
     else:
         print('建立连接!')
         while True:
             message = ws.receive()

-            if len(message) == 0:
+            if len(message)==0:
                 return '输入信息为空'
             else:
-                res = llm_response(message)
+                res=llm_response(message)
                 nerfreal.put_msg_txt(res)


 #####webrtc###############################
 pcs = set()

-# @app.route('/offer', methods=['POST'])
+#@app.route('/offer', methods=['POST'])
 async def offer(request):
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

@@ -101,7 +106,7 @@ async def offer(request):
     answer = await pc.createAnswer()
     await pc.setLocalDescription(answer)

-    # return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
+    #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})

     return web.Response(
         content_type="application/json",

@@ -110,61 +115,39 @@ async def offer(request):
         ),
     )


 async def human(request):
     params = await request.json()

-    if params['type'] == 'echo':
+    if params.get('interrupt'):
+        nerfreal.pause_talk()
+
+    if params['type']=='echo':
         nerfreal.put_msg_txt(params['text'])
-    elif params['type'] == 'chat':
-        res = await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
+    elif params['type']=='chat':
+        res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
         nerfreal.put_msg_txt(res)

     return web.Response(
         content_type="application/json",
         text=json.dumps(
-            {"code": 0, "data": "ok"}
+            {"code": 0, "data":"ok"}
         ),
     )

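The `interrupt` flag added to /human above lets a caller cut the avatar off before queuing new text. A minimal client sketch follows; the endpoint shape is taken from the handler above, while the server address and the use of an aiohttp client session are assumptions for illustration.

```python
import asyncio

import aiohttp


async def interrupt_and_say(text, server="http://127.0.0.1:8010"):
    """Stop the current speech, then queue `text` on the /human endpoint."""
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{server}/human",
                                json={"type": "echo", "text": text, "interrupt": True}) as resp:
            return await resp.json()   # handler responds {"code": 0, "data": "ok"}

# asyncio.run(interrupt_and_say("先停一下,换个话题"))
```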
-async def handle_create_musetalk(request):
-    reader = await request.multipart()
-    # 处理文件部分
-    file_part = await reader.next()
-    filename = file_part.filename
-    file_data = await file_part.read()  # 读取文件的内容
-    # 注意:确保这个文件路径是可写的
-    with open(filename, 'wb') as f:
-        f.write(file_data)
-    # 处理整数部分
-    part = await reader.next()
-    avatar_id = int(await part.text())
-    create_musetalk_human(filename, avatar_id)
-    os.remove(filename)
-    return web.json_response({
-        'status': 'success',
-        'filename': filename,
-        'int_value': avatar_id,
-    })
-
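For reference, the handler removed above accepted a multipart upload — a file part (video or image) followed by an integer avatar id — and passed it to create_musetalk_human. A client for that old /create_musetalk route would have looked roughly like the sketch below; the field names, file path and server address are illustrative only, since the handler read the parts positionally.

```python
import asyncio

import aiohttp


async def create_avatar(video_path, avatar_id, server="http://127.0.0.1:8010"):
    """Upload a source video plus an avatar id to the (removed) /create_musetalk route."""
    form = aiohttp.FormData()
    form.add_field("file", open(video_path, "rb"), filename="source.mp4")  # read first by the handler
    form.add_field("avatar_id", str(avatar_id))                            # read second, cast to int
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{server}/create_musetalk", data=form) as resp:
            return await resp.json()   # {'status': 'success', 'filename': ..., 'int_value': ...}

# asyncio.run(create_avatar("test.mp4", 2))
```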
 async def on_shutdown(app):
     # close peer connections
     coros = [pc.close() for pc in pcs]
     await asyncio.gather(*coros)
     pcs.clear()

-async def post(url, data):
+async def post(url,data):
     try:
         async with aiohttp.ClientSession() as session:
-            async with session.post(url, data=data) as response:
+            async with session.post(url,data=data) as response:
                 return await response.text()
     except aiohttp.ClientError as e:
         print(f'Error: {e}')


 async def run(push_url):
     pc = RTCPeerConnection()
     pcs.add(pc)

@@ -181,10 +164,8 @@ async def run(push_url):
     video_sender = pc.addTrack(player.video)

     await pc.setLocalDescription(await pc.createOffer())
-    answer = await post(push_url, pc.localDescription.sdp)
-    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))
+    answer = await post(push_url,pc.localDescription.sdp)
+    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))


 ##########################################
 # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
 # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'

@@ -203,20 +184,14 @@ if __name__ == '__main__':

     ### training options
     parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')

-    parser.add_argument('--num_rays', type=int, default=4096 * 16,
-                        help="num rays sampled per image for each training step")
+    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
     parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
-    parser.add_argument('--max_steps', type=int, default=16,
-                        help="max num steps sampled per ray (only valid when using --cuda_ray)")
-    parser.add_argument('--num_steps', type=int, default=16,
-                        help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
-    parser.add_argument('--upsample_steps', type=int, default=0,
-                        help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
-    parser.add_argument('--update_extra_interval', type=int, default=16,
-                        help="iter interval to update extra status (only valid when using --cuda_ray)")
-    parser.add_argument('--max_ray_batch', type=int, default=4096,
-                        help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
+    parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
+    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
+    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

     ### loss set
     parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")

@@ -227,35 +202,27 @@ if __name__ == '__main__':

     ### network backbone options
     parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")

     parser.add_argument('--bg_img', type=str, default='white', help="background image")
     parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
     parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
-    parser.add_argument('--fix_eye', type=float, default=-1,
-                        help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
+    parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
     parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")

-    parser.add_argument('--torso_shrink', type=float, default=0.8,
-                        help="shrink bg coords to allow more flexibility in deform")
+    parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")

     ### dataset options
     parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
-    parser.add_argument('--preload', type=int, default=0,
-                        help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
+    parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
     # (the default value is for the fox dataset)
-    parser.add_argument('--bound', type=float, default=1,
-                        help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
+    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
     parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
     parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
-    parser.add_argument('--dt_gamma', type=float, default=1 / 256,
-                        help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+    parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
     parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
-    parser.add_argument('--density_thresh', type=float, default=10,
-                        help="threshold for density grid to be occupied (sigma)")
-    parser.add_argument('--density_thresh_torso', type=float, default=0.01,
-                        help="threshold for density grid to be occupied (alpha)")
-    parser.add_argument('--patch_size', type=int, default=1,
-                        help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
+    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
+    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
+    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

     parser.add_argument('--init_lips', action='store_true', help="init lips region")
     parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")

@@ -273,15 +240,12 @@ if __name__ == '__main__':
     parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

     ### else
-    parser.add_argument('--att', type=int, default=2,
-                        help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
-    parser.add_argument('--aud', type=str, default='',
-                        help="audio source (empty will load the default, else should be a path to a npy file)")
+    parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
+    parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
     parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")

     parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
-    parser.add_argument('--ind_num', type=int, default=10000,
-                        help="number of individual codes, should be larger than training dataset size")
+    parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")

     parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")

@@ -290,8 +254,7 @@ if __name__ == '__main__':
     parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")

     parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
-    parser.add_argument('--smooth_path', action='store_true',
-                        help="brute-force smooth camera pose trajectory with a window size")
+    parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
     parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

     # asr

@@ -299,8 +262,8 @@ if __name__ == '__main__':
     parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
     parser.add_argument('--asr_play', action='store_true', help="play out the audio")

-    # parser.add_argument('--asr_model', type=str, default='deepspeech')
+    #parser.add_argument('--asr_model', type=str, default='deepspeech')
     parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #
     # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
     # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')

@@ -319,45 +282,42 @@ if __name__ == '__main__':
     parser.add_argument('--fullbody_offset_x', type=int, default=0)
     parser.add_argument('--fullbody_offset_y', type=int, default=0)

-    # musetalk opt
+    #musetalk opt
     parser.add_argument('--avatar_id', type=str, default='avator_1')
     parser.add_argument('--bbox_shift', type=int, default=5)
     parser.add_argument('--batch_size', type=int, default=16)

     parser.add_argument('--customvideo', action='store_true', help="custom video")
-    parser.add_argument('--static_img', action='store_true', help="Use the first photo as a time of rest")
     parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
     parser.add_argument('--customvideo_imgnum', type=int, default=1)

-    parser.add_argument('--tts', type=str, default='edgetts') # xtts gpt-sovits
+    parser.add_argument('--tts', type=str, default='edgetts') #xtts gpt-sovits
     parser.add_argument('--REF_FILE', type=str, default=None)
     parser.add_argument('--REF_TEXT', type=str, default=None)
     parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
     # parser.add_argument('--CHARACTER', type=str, default='test')
     # parser.add_argument('--EMOTION', type=str, default='default')

-    parser.add_argument('--model', type=str, default='ernerf') # musetalk wav2lip
+    parser.add_argument('--model', type=str, default='ernerf') #musetalk wav2lip

-    parser.add_argument('--transport', type=str, default='rtcpush') # rtmp webrtc rtcpush
-    parser.add_argument('--push_url', type=str,
-                        default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream
+    parser.add_argument('--transport', type=str, default='rtcpush') #rtmp webrtc rtcpush
+    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream

     parser.add_argument('--listenport', type=int, default=8010)

     opt = parser.parse_args()
-    # app.config.from_object(opt)
-    # print(app.config)
+    #app.config.from_object(opt)
+    #print(app.config)

     if opt.model == 'ernerf':
         from ernerf.nerf_triplane.provider import NeRFDataset_Test
         from ernerf.nerf_triplane.utils import *
         from ernerf.nerf_triplane.network import NeRFNetwork
         from nerfreal import NeRFReal

         # assert test mode
         opt.test = True
         opt.test_train = False
-        # opt.train_camera =True
+        #opt.train_camera =True
         # explicit smoothing
         opt.smooth_path = True
         opt.smooth_lips = True

@@ -370,7 +330,7 @@ if __name__ == '__main__':
         opt.exp_eye = True
         opt.smooth_eye = True

-        if opt.torso_imgs == '': # no img,use model output
+        if opt.torso_imgs=='': #no img,use model output
             opt.torso = True

         # assert opt.cuda_ray, "Only support CUDA ray mode."

@@ -386,10 +346,9 @@ if __name__ == '__main__':
         model = NeRFNetwork(opt)

         criterion = torch.nn.MSELoss(reduction='none')
         metrics = [] # use no metric in GUI for faster initialization...
         print(model)
-        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16,
-                          metrics=metrics, use_checkpoint=opt.ckpt)
+        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

         test_loader = NeRFDataset_Test(opt, device=device).dataloader()
         model.aud_features = test_loader._data.auds

@@ -399,19 +358,17 @@ if __name__ == '__main__':
         nerfreal = NeRFReal(opt, trainer, test_loader)
     elif opt.model == 'musetalk':
         from musereal import MuseReal

         print(opt)
         nerfreal = MuseReal(opt)
     elif opt.model == 'wav2lip':
         from lipreal import LipReal

         print(opt)
         nerfreal = LipReal(opt)

-    # txt_to_audio('我是中国人,我来自北京')
-    if opt.transport == 'rtmp':
+    #txt_to_audio('我是中国人,我来自北京')
+    if opt.transport=='rtmp':
         thread_quit = Event()
-        rendthrd = Thread(target=nerfreal.render, args=(thread_quit,))
+        rendthrd = Thread(target=nerfreal.render,args=(thread_quit,))
         rendthrd.start()

     #############################################################################

@@ -419,37 +376,35 @@ if __name__ == '__main__':
     appasync.on_shutdown.append(on_shutdown)
     appasync.router.add_post("/offer", offer)
     appasync.router.add_post("/human", human)
-    appasync.router.add_post("/create_musetalk", handle_create_musetalk)
-    appasync.router.add_static('/', path='web')
+    appasync.router.add_static('/',path='web')

     # Configure default CORS settings.
     cors = aiohttp_cors.setup(appasync, defaults={
         "*": aiohttp_cors.ResourceOptions(
             allow_credentials=True,
             expose_headers="*",
             allow_headers="*",
         )
     })
     # Configure CORS on all routes.
     for route in list(appasync.router.routes()):
         cors.add(route)


     def run_server(runner):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         loop.run_until_complete(runner.setup())
         site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
         loop.run_until_complete(site.start())
-        if opt.transport == 'rtcpush':
+        if opt.transport=='rtcpush':
             loop.run_until_complete(run(opt.push_url))
         loop.run_forever()

     Thread(target=run_server, args=(web.AppRunner(appasync),)).start()

     print('start websocket server')
-    # app.on_shutdown.append(on_shutdown)
-    # app.router.add_post("/offer", offer)
+    #app.on_shutdown.append(on_shutdown)
+    #app.router.add_post("/offer", offer)
     server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
     server.serve_forever()

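Everything under `if __name__ == '__main__':` above is wiring: an aiohttp app serves /offer and /human (with CORS) on --listenport, while the Flask/gevent WebSocket server keeps /humanecho and /humanchat on port 8000. A browser page normally performs the /offer exchange; the sketch below does the same handshake headlessly with aiortc, assuming the server address and that the JSON answer carries sdp/type as in the commented-out jsonify line.

```python
import asyncio

import aiohttp
from aiortc import RTCPeerConnection, RTCSessionDescription


async def connect(server="http://127.0.0.1:8010"):
    """Complete the SDP offer/answer exchange with the /offer route."""
    pc = RTCPeerConnection()
    # Ask to receive the avatar's audio and video tracks.
    pc.addTransceiver("video", direction="recvonly")
    pc.addTransceiver("audio", direction="recvonly")

    await pc.setLocalDescription(await pc.createOffer())
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{server}/offer",
                                json={"sdp": pc.localDescription.sdp,
                                      "type": pc.localDescription.type}) as resp:
            answer = await resp.json()

    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer["sdp"], type=answer["type"]))
    return pc

# pc = asyncio.run(connect())
```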
asrreal.py (149 lines changed)

@@ -4,29 +4,19 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel

-#import pyaudio
-import soundfile as sf
-import resampy

 import queue
 from queue import Queue
 #from collections import deque
 from threading import Thread, Event
-from io import BytesIO

-class ASR:
+from baseasr import BaseASR
+
+class ASR(BaseASR):
     def __init__(self, opt):
-        self.opt = opt
-
-        self.play = opt.asr_play #false
+        super().__init__(opt)

         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.mode = 'live' if opt.asr_wav == '' else 'file'

         if 'esperanto' in self.opt.asr_model:
             self.audio_dim = 44
         elif 'deepspeech' in self.opt.asr_model:

@@ -41,30 +31,11 @@ class ASR:
         self.context_size = opt.m
         self.stride_left_size = opt.l
         self.stride_right_size = opt.r
-        self.text = '[START]\n'
-        self.terminated = False
-        self.frames = []
-        self.inwarm = False

         # pad left frames
         if self.stride_left_size > 0:
             self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)

-        self.exit_event = Event()
-        #self.audio_instance = pyaudio.PyAudio() #not need
-
-        # create input stream
-        self.queue = Queue()
-        self.output_queue = Queue()
-        # start a background process to read frames
-        #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
-        #self.queue = Queue()
-        #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
-
-        # current location of audio
-        self.idx = 0

         # create wav2vec model
         print(f'[INFO] loading ASR model {self.opt.asr_model}...')
         if 'hubert' in self.opt.asr_model:

@@ -74,10 +45,6 @@ class ASR:
             self.processor = AutoProcessor.from_pretrained(opt.asr_model)
             self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

-        # prepare to save logits
-        if self.opt.asr_save_feats:
-            self.all_feats = []
-
         # the extracted features
         # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
         self.feat_buffer_size = 4

@@ -93,8 +60,16 @@ class ASR:
         # warm up steps needed: mid + right + window_size + attention_size
         self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3

-        self.listening = False
-        self.playing = False
+    def get_audio_frame(self):
+        try:
+            frame = self.queue.get(block=False)
+            type = 0
+            #print(f'[INFO] get frame {frame.shape}')
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type = 1
+
+        return frame,type

     def get_next_feat(self): #get audio embedding to nerf
         # return a [1/8, 16] window, for the next input to nerf side.

@@ -136,29 +111,19 @@ class ASR:

     def run_step(self):

-        if self.terminated:
-            return
-
         # get a frame of audio
-        frame,type = self.__get_audio_frame()
-
-        # the last frame
-        if frame is None:
-            # terminate, but always run the network for the left frames
-            self.terminated = True
-        else:
-            self.frames.append(frame)
-            # put to output
-            self.output_queue.put((frame,type))
-            # context not enough, do not run network.
-            if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
-                return
+        frame,type = self.get_audio_frame()
+        self.frames.append(frame)
+        # put to output
+        self.output_queue.put((frame,type))
+        # context not enough, do not run network.
+        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+            return

         inputs = np.concatenate(self.frames) # [N * chunk]

         # discard the old part to save memory
-        if not self.terminated:
-            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

         #print(f'[INFO] frame_to_text... ')
         #t = time.time()

@@ -166,10 +131,6 @@ class ASR:
         #print(f'-------wav2vec time:{time.time()-t:.4f}s')
         feats = logits # better lips-sync than labels

-        # save feats
-        if self.opt.asr_save_feats:
-            self.all_feats.append(feats)
-
         # record the feats efficiently.. (no concat, constant memory)
         start = self.feat_buffer_idx * self.context_size
         end = start + feats.shape[0]

@@ -203,24 +164,6 @@ class ASR:
         # np.save(output_path, unfold_feats.cpu().numpy())
         # print(f"[INFO] saved logits to {output_path}")

-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        if self.inwarm: # warm up
-            return np.zeros(self.chunk, dtype=np.float32),1
-
-        try:
-            frame = self.queue.get(block=False)
-            type = 0
-            print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        self.idx = self.idx + self.chunk
-
-        return frame,type
-
     def __frame_to_text(self, frame):

@@ -241,8 +184,8 @@ class ASR:
         right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.

         # do not cut right if terminated.
-        if self.terminated:
-            right = logits.shape[1]
+        # if self.terminated:
+        # right = logits.shape[1]

         logits = logits[:, left:right]

@@ -262,10 +205,23 @@ class ASR:

         return logits[0], None,None #predicted_ids[0], transcription # [N,]

-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
+    def warm_up(self):
+        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
+        t = time.time()
+        #for _ in range(self.stride_left_size):
+        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
+        for _ in range(self.warm_up_steps):
+            self.run_step()
+        #if torch.cuda.is_available():
+        #    torch.cuda.synchronize()
+        t = time.time() - t
+        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
+
+        #self.clear_queue()
+
+#####not used function#####################################
+'''
     def __init_queue(self):
         self.frames = []
         self.queue.queue.clear()

@@ -290,26 +246,6 @@ class ASR:
         if self.play:
             self.output_queue.queue.clear()

-    def warm_up(self):
-
-        #self.listen()
-
-        self.inwarm = True
-        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
-        t = time.time()
-        #for _ in range(self.stride_left_size):
-        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
-        for _ in range(self.warm_up_steps):
-            self.run_step()
-        #if torch.cuda.is_available():
-        #    torch.cuda.synchronize()
-        t = time.time() - t
-        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
-        self.inwarm = False
-
-        #self.clear_queue()
-
-    #####not used function#####################################
     def listen(self):
         # start
         if self.mode == 'live' and not self.listening:

@@ -404,4 +340,5 @@ if __name__ == '__main__':
             raise ValueError("DeepSpeech features should not use this code to extract...")

     with ASR(opt) as asr:
         asr.run()
+'''
lipasr.py (54 lines changed)

@@ -6,60 +6,16 @@
 from queue import Queue
 import multiprocessing as mp

+from baseasr import BaseASR
 from wav2lip import audio

-class LipASR:
-    def __init__(self, opt):
-        self.opt = opt
-
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
-
-        #self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-        #self.context_size = 10
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()
-
+class LipASR(BaseASR):
     def run_step(self):
         ############################################## extract audio feature ##############################################
         # get a frame of audio
         for _ in range(self.batch_size*2):
-            frame,type = self.__get_audio_frame()
+            frame,type = self.get_audio_frame()
             self.frames.append(frame)
             # put to output
             self.output_queue.put((frame,type))

@@ -89,7 +45,3 @@ class LipASR:

         # discard the old part to save memory
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)
lipreal.py (12 lines changed)

@@ -164,6 +164,7 @@ class LipReal:
         self.__loadavatar()

         self.asr = LipASR(opt)
+        self.asr.warm_up()
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":

@@ -199,6 +200,10 @@ class LipReal:

     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
         self.asr.put_audio_frame(audio_chunk)

+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
+
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):

@@ -257,9 +262,12 @@ class LipReal:
             t = time.perf_counter()
             self.asr.run_step()

-            if video_track._queue.qsize()>=2*self.opt.batch_size:
+            # if video_track._queue.qsize()>=2*self.opt.batch_size:
+            # print('sleep qsize=',video_track._queue.qsize())
+            # time.sleep(0.04*video_track._queue.qsize()*0.8)
+            if video_track._queue.qsize()>=5:
                 print('sleep qsize=',video_track._queue.qsize())
-                time.sleep(0.04*self.opt.batch_size*1.5)
+                time.sleep(0.04*video_track._queue.qsize()*0.8)

             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
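The lipreal.py change wires the new interrupt support end to end: /human with `interrupt` calls nerfreal.pause_talk(), which fans out to the TTS and ASR pause_talk() methods added here. The runnable sketch below stubs those collaborators just to show the call order; the stub bodies are assumptions, only the method names come from the diff.

```python
class _StubTTS:                      # stands in for ttsreal.EdgeTTS and friends
    def pause_talk(self):
        print("TTS: drop queued sentences")

class _StubASR:                      # stands in for LipASR / MuseASR
    def pause_talk(self):
        print("ASR: drop queued audio frames")

class LipRealSketch:
    """Mirrors the pause_talk() added to LipReal in this commit."""
    def __init__(self):
        self.tts = _StubTTS()
        self.asr = _StubASR()

    def pause_talk(self):
        self.tts.pause_talk()
        self.asr.pause_talk()

def handle_human(params, nerfreal):
    """Mirrors the new branch in app.py's /human handler."""
    if params.get("interrupt"):
        nerfreal.pause_talk()        # cut off current speech before queuing new text
    if params["type"] == "echo":
        print("speak:", params["text"])

handle_human({"type": "echo", "text": "你好", "interrupt": True}, LipRealSketch())
```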
museasr.py (54 lines changed)

@@ -1,65 +1,22 @@
 import time
-import torch
 import numpy as np

 import queue
 from queue import Queue
 import multiprocessing as mp
+from baseasr import BaseASR
 from musetalk.whisper.audio2feature import Audio2Feature

-class MuseASR:
+class MuseASR(BaseASR):
     def __init__(self, opt, audio_processor:Audio2Feature):
-        self.opt = opt
+        super().__init__(opt)

-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
-
         self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()

     def run_step(self):
         ############################################## extract audio feature ##############################################
         start_time = time.time()
         for _ in range(self.batch_size*2):
-            audio_frame,type=self.__get_audio_frame()
+            audio_frame,type=self.get_audio_frame()
             self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))

@@ -77,6 +34,3 @@ class MuseASR:
         self.feat_queue.put(whisper_chunks)
         # discard the old part to save memory
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)
musereal.py (227 lines changed)

@@ -2,7 +2,7 @@ import math
 import torch
 import numpy as np

-# from .utils import *
+#from .utils import *
 import subprocess
 import os
 import time

@@ -18,19 +18,17 @@ from threading import Thread, Event
 from io import BytesIO
 import multiprocessing as mp

-from musetalk.utils.utils import get_file_type, get_video_fps, datagen
-# from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
-from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending
-from musetalk.utils.utils import load_all_model, load_diffusion_model, load_audio_model
-from ttsreal import EdgeTTS, VoitsTTS, XTTS
+from musetalk.utils.utils import get_file_type,get_video_fps,datagen
+#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
+from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
+from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
+from ttsreal import EdgeTTS,VoitsTTS,XTTS

 from museasr import MuseASR
 import asyncio
 from av import AudioFrame, VideoFrame

 from tqdm import tqdm


 def read_imgs(img_list):
     frames = []
     print('reading images...')

@@ -39,146 +37,143 @@ def read_imgs(img_list):
         frames.append(frame)
     return frames


 def __mirror_index(size, index):
-    # size = len(self.coord_list_cycle)
+    #size = len(self.coord_list_cycle)
     turn = index // size
     res = index % size
     if turn % 2 == 0:
         return res
     else:
         return size - res - 1

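__mirror_index above walks the avatar's frame list back and forth (0, 1, ..., N-1, N-1, ..., 0) so idle playback loops without a visible jump. A quick check of the pattern, using the same arithmetic outside the module:

```python
def mirror_index(size, index):
    # same logic as musereal.__mirror_index
    turn = index // size
    res = index % size
    return res if turn % 2 == 0 else size - res - 1

print([mirror_index(4, i) for i in range(10)])
# -> [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
```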
-def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue,
-              ):  # vae, unet, pe,timesteps
+def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
+              ): #vae, unet, pe,timesteps

     vae, unet, pe = load_diffusion_model()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     timesteps = torch.tensor([0], device=device)
     pe = pe.half()
     vae.vae = vae.vae.half()
     unet.model = unet.model.half()

     input_latent_list_cycle = torch.load(latents_out_path)
     length = len(input_latent_list_cycle)
     index = 0
-    count = 0
-    counttime = 0
+    count=0
+    counttime=0
     print('start inference')
     while True:
         if render_event.is_set():
-            starttime = time.perf_counter()
+            starttime=time.perf_counter()
             try:
                 whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
             except queue.Empty:
                 continue
-            is_all_silence = True
+            is_all_silence=True
             audio_frames = []
-            for _ in range(batch_size * 2):
-                frame, type = audio_out_queue.get()
-                audio_frames.append((frame, type))
-                if type == 0:
-                    is_all_silence = False
+            for _ in range(batch_size*2):
+                frame,type = audio_out_queue.get()
+                audio_frames.append((frame,type))
+                if type==0:
+                    is_all_silence=False
             if is_all_silence:
                 for i in range(batch_size):
-                    res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
+                    res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                     index = index + 1
             else:
                 # print('infer=======')
-                t = time.perf_counter()
+                t=time.perf_counter()
                 whisper_batch = np.stack(whisper_chunks)
                 latent_batch = []
                 for i in range(batch_size):
-                    idx = __mirror_index(length, index + i)
+                    idx = __mirror_index(length,index+i)
                     latent = input_latent_list_cycle[idx]
                     latent_batch.append(latent)
                 latent_batch = torch.cat(latent_batch, dim=0)

                 # for i, (whisper_batch,latent_batch) in enumerate(gen):
                 audio_feature_batch = torch.from_numpy(whisper_batch)
                 audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                              dtype=unet.model.dtype)
                 audio_feature_batch = pe(audio_feature_batch)
                 latent_batch = latent_batch.to(dtype=unet.model.dtype)
                 # print('prepare time:',time.perf_counter()-t)
                 # t=time.perf_counter()

                 pred_latents = unet.model(latent_batch,
                                           timesteps,
                                           encoder_hidden_states=audio_feature_batch).sample
                 # print('unet time:',time.perf_counter()-t)
                 # t=time.perf_counter()
                 recon = vae.decode_latents(pred_latents)
                 # print('vae time:',time.perf_counter()-t)
-                # print('diffusion len=',len(recon))
+                #print('diffusion len=',len(recon))
                 counttime += (time.perf_counter() - t)
                 count += batch_size
-                # _totalframe += 1
-                if count >= 100:
-                    print(f"------actual avg infer fps:{count / counttime:.4f}")
-                    count = 0
-                    counttime = 0
-                for i, res_frame in enumerate(recon):
-                    # self.__pushmedia(res_frame,loop,audio_track,video_track)
+                #_totalframe += 1
+                if count>=100:
+                    print(f"------actual avg infer fps:{count/counttime:.4f}")
+                    count=0
+                    counttime=0
+                for i,res_frame in enumerate(recon):
+                    #self.__pushmedia(res_frame,loop,audio_track,video_track)
|
||||||
res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
|
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
|
||||||
index = index + 1
|
index = index + 1
|
||||||
# print('total batch time:',time.perf_counter()-starttime)
|
#print('total batch time:',time.perf_counter()-starttime)
|
||||||
else:
|
else:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
print('musereal inference processor stop')
|
print('musereal inference processor stop')
|
||||||
|
|
||||||
|
|
||||||
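One ratio worth spelling out in the loop above: each batch pulls `batch_size * 2` audio chunks, and every generated frame is queued with `audio_frames[i*2:i*2+2]`, i.e. two 20 ms PCM chunks (40 ms of audio) per video frame, which corresponds to 25 fps output. A small sanity check, with an illustrative batch size:

```python
# Sanity check of the audio/video pairing above. The batch size is
# illustrative; the 20 ms chunk length comes from put_audio_frame
# ("16khz 20ms pcm").
batch_size = 16
audio_chunk_ms = 20                      # one PCM chunk fed to the ASR
chunks_per_batch = batch_size * 2        # how many audio_out_queue.get() calls per batch
audio_ms_per_frame = 2 * audio_chunk_ms  # two chunks ride along with each video frame
print(chunks_per_batch, audio_ms_per_frame, 1000 // audio_ms_per_frame)
# 32 40 25  -> 32 chunks per batch, 40 ms of audio per frame, 25 fps
```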
@torch.no_grad()
class MuseReal:
    def __init__(self, opt):
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H

        self.fps = opt.fps # 20 ms per frame

        #### musetalk
        self.avatar_id = opt.avatar_id
-        self.static_img = opt.static_img
        self.video_path = '' #video_path
        self.bbox_shift = opt.bbox_shift
        self.avatar_path = f"./data/avatars/{self.avatar_id}"
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"
        self.latents_out_path= f"{self.avatar_path}/latents.pt"
        self.video_out_path = f"{self.avatar_path}/vid_output/"
        self.mask_out_path =f"{self.avatar_path}/mask"
        self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
        self.avatar_info = {
            "avatar_id":self.avatar_id,
            "video_path":self.video_path,
            "bbox_shift":self.bbox_shift
        }
        self.batch_size = opt.batch_size
        self.idx = 0
        self.res_frame_queue = mp.Queue(self.batch_size*2)
        self.__loadmodels()
        self.__loadavatar()

        self.asr = MuseASR(opt,self.audio_processor)
+        self.asr.warm_up()
        if opt.tts == "edgetts":
            self.tts = EdgeTTS(opt,self)
        elif opt.tts == "gpt-sovits":
            self.tts = VoitsTTS(opt,self)
        elif opt.tts == "xtts":
            self.tts = XTTS(opt,self)
        #self.__warm_up()

        self.render_event = mp.Event()
        mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
                                           self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
                                           )).start() #self.vae, self.unet, self.pe,self.timesteps

    def __loadmodels(self):
        # load model weights
        self.audio_processor= load_audio_model()
        # self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
        # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # self.timesteps = torch.tensor([0], device=device)
@@ -187,7 +182,7 @@ class MuseReal:
        # self.unet.model = self.unet.model.half()

    def __loadavatar(self):
        #self.input_latent_list_cycle = torch.load(self.latents_out_path)
        with open(self.coords_path, 'rb') as f:
            self.coord_list_cycle = pickle.load(f)
        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@@ -198,13 +193,19 @@ class MuseReal:
        input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
        input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        self.mask_list_cycle = read_imgs(input_mask_list)

    def put_msg_txt(self,msg):
        self.tts.put_msg_txt(msg)

    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
        self.asr.put_audio_frame(audio_chunk)

+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()

    def __mirror_index(self, index):
        size = len(self.coord_list_cycle)
        turn = index // size
@@ -212,15 +213,15 @@ class MuseReal:
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1

    def __warm_up(self):
        self.asr.run_step()
        whisper_chunks = self.asr.get_next_feat()
        whisper_batch = np.stack(whisper_chunks)
        latent_batch = []
        for i in range(self.batch_size):
            idx = self.__mirror_index(self.idx+i)
            latent = self.input_latent_list_cycle[idx]
            latent_batch.append(latent)
        latent_batch = torch.cat(latent_batch, dim=0)
@@ -228,88 +229,90 @@ class MuseReal:
        # for i, (whisper_batch,latent_batch) in enumerate(gen):
        audio_feature_batch = torch.from_numpy(whisper_batch)
        audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
                                                     dtype=self.unet.model.dtype)
        audio_feature_batch = self.pe(audio_feature_batch)
        latent_batch = latent_batch.to(dtype=self.unet.model.dtype)

        pred_latents = self.unet.model(latent_batch,
                                       self.timesteps,
                                       encoder_hidden_states=audio_feature_batch).sample
        recon = self.vae.decode_latents(pred_latents)


    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):

        while not quit_event.is_set():
            try:
                res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            if audio_frames[0][1]==1 and audio_frames[1][1]==1: #all-silence data, just take the full source frame
-                if self.static_img:
-                    combine_frame = self.frame_list_cycle[0]
-                else:
-                    combine_frame = self.frame_list_cycle[idx]
+                combine_frame = self.frame_list_cycle[idx]
            else:
                bbox = self.coord_list_cycle[idx]
                ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
                x1, y1, x2, y2 = bbox
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
                except:
                    continue
                mask = self.mask_list_cycle[idx]
                mask_crop_box = self.mask_coords_list_cycle[idx]
                #combine_frame = get_image(ori_frame,res_frame,bbox)
                #t=time.perf_counter()
                combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
                #print('blending time:',time.perf_counter()-t)

            image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
            new_frame = VideoFrame.from_ndarray(image, format="bgr24")
            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)

            for audio_frame in audio_frames:
                frame,type = audio_frame
                frame = (frame * 32767).astype(np.int16)
                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                new_frame.planes[0].update(frame.tobytes())
                new_frame.sample_rate=16000
                # if audio_track._queue.qsize()>10:
                #     time.sleep(0.1)
                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
        print('musereal process_frames thread stop')

    def render(self,quit_event,loop=None,audio_track=None,video_track=None):
        #if self.opt.asr:
        #     self.asr.warm_up()

        self.tts.render(quit_event)
        process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
        process_thread.start()

        self.render_event.set() #start infer process render
        count=0
        totaltime=0
        _starttime=time.perf_counter()
        #_totalframe=0
        while not quit_event.is_set(): #todo
            # update texture every frame
            # audio stream thread...
            t = time.perf_counter()
            self.asr.run_step()
            #self.test_step(loop,audio_track,video_track)
            # totaltime += (time.perf_counter() - t)
            # count += self.opt.batch_size
            # if count>=100:
            #     print(f"------actual avg infer fps:{count/totaltime:.4f}")
            #     count=0
            #     totaltime=0
-            if video_track._queue.qsize() >= 2 * self.opt.batch_size:
-                print('sleep qsize=', video_track._queue.qsize())
-                time.sleep(0.04 * self.opt.batch_size * 1.5)
+            if video_track._queue.qsize()>=1.5*self.opt.batch_size:
+                print('sleep qsize=',video_track._queue.qsize())
+                time.sleep(0.04*video_track._queue.qsize()*0.8)
+            # if video_track._queue.qsize()>=5:
+            #     print('sleep qsize=',video_track._queue.qsize())
+            #     time.sleep(0.04*video_track._queue.qsize()*0.8)

            # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
            # if delay > 0:
            #     time.sleep(delay)
        self.render_event.clear() #end infer process render
        print('musereal thread stop')

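A rough feel for the backpressure throttle in `render()` above, using illustrative numbers (batch_size is whatever `--batch_size` is set to; 25 fps output assumed): once roughly 1.5×batch_size frames are buffered on the video track, the loop sleeps for a time proportional to the backlog so the player can drain it.

```python
# Illustrative numbers only, not values fixed by the code.
batch_size = 16
threshold = 1.5 * batch_size       # render() starts sleeping at this queue depth
qsize = 24                         # e.g. exactly at the threshold
sleep_s = 0.04 * qsize * 0.8       # the sleep scales with the current backlog
print(threshold, round(qsize * 0.04, 2), round(sleep_s, 2))
# 24.0 0.96 0.77  -> ~1 s of buffered video triggers a ~0.77 s pause
```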
@@ -7,14 +7,15 @@ from PIL import Image
from .model import BiSeNet
import torchvision.transforms as transforms


class FaceParsing():
    def __init__(self,resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='./models/face-parse-bisent/79999_iter.pth'):
        self.net = self.model_init(resnet_path,model_pth)
        self.preprocess = self.image_preprocess()

    def model_init(self,
                   resnet_path,
                   model_pth):
        net = BiSeNet(resnet_path)
        if torch.cuda.is_available():
            net.cuda()
@@ -44,13 +45,13 @@ class FaceParsing():
        img = torch.unsqueeze(img, 0)
        out = self.net(img)[0]
        parsing = out.squeeze(0).cpu().numpy().argmax(0)
        parsing[np.where(parsing>13)] = 0
        parsing[np.where(parsing>=1)] = 255
        parsing = Image.fromarray(parsing.astype(np.uint8))
        return parsing


if __name__ == "__main__":
    fp = FaceParsing()
    segmap = fp('154_small.png')
    segmap.save('res.png')

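The two thresholding lines in `FaceParsing` above turn the BiSeNet label map into a binary face mask: labels above 13 (neck, cloth, hair and hat in the usual CelebAMask-HQ ordering, an assumption about the label set) are zeroed, and everything left nonzero becomes 255. A tiny numpy sketch of just that step, on a made-up label map:

```python
import numpy as np

# Hypothetical 2x4 label map standing in for the argmax output above:
# 0 = background, 1..13 = face parts, >13 = non-face classes (dropped).
parsing = np.array([[0, 1, 12, 13],
                    [17, 5, 0, 16]])
parsing[np.where(parsing > 13)] = 0    # discard non-face classes
parsing[np.where(parsing >= 1)] = 255  # binarize what is left
print(parsing)
# [[  0 255 255 255]
#  [  0 255   0   0]]
```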
24 nerfreal.py
@@ -20,9 +20,6 @@ class NeRFReal:
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
-        self.debug = debug
-        self.training = False
-        self.step = 0 # training step

        self.trainer = trainer
        self.data_loader = data_loader
@@ -44,7 +41,6 @@ class NeRFReal:
        #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()

        # playing seq from dataloader, or pause.
-        self.playing = True #False todo
        self.loader = iter(data_loader)

        #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
@@ -62,9 +58,8 @@ class NeRFReal:
        self.customimg_index = 0

        # build asr
-        if self.opt.asr:
-            self.asr = ASR(opt)
-            self.asr.warm_up()
+        self.asr = ASR(opt)
+        self.asr.warm_up()

        if opt.tts == "edgetts":
            self.tts = EdgeTTS(opt,self)
        elif opt.tts == "gpt-sovits":
@@ -122,7 +117,11 @@ class NeRFReal:
        self.tts.put_msg_txt(msg)

    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
        self.asr.put_audio_frame(audio_chunk)

+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()

    def mirror_index(self, index):
@@ -248,10 +247,9 @@ class NeRFReal:
            # update texture every frame
            # audio stream thread...
            t = time.perf_counter()
-            if self.opt.asr and self.playing:
-                # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
-                for _ in range(2):
-                    self.asr.run_step()
+            # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
+            for _ in range(2):
+                self.asr.run_step()

            self.test_step(loop,audio_track,video_track)
            totaltime += (time.perf_counter() - t)
            count += 1
@@ -267,7 +265,7 @@ class NeRFReal:
            else:
                if video_track._queue.qsize()>=5:
                    #print('sleep qsize=',video_track._queue.qsize())
-                    time.sleep(0.1)
+                    time.sleep(0.04*video_track._queue.qsize()*0.8)
        print('nerfreal thread stop')

17 ttsreal.py
@@ -13,6 +13,11 @@ import queue
from queue import Queue
from io import BytesIO
from threading import Thread, Event
+from enum import Enum
+
+class State(Enum):
+    RUNNING=0
+    PAUSE=1

class BaseTTS:
    def __init__(self, opt, parent):
@@ -25,6 +30,11 @@ class BaseTTS:
        self.input_stream = BytesIO()

        self.msgqueue = Queue()
+        self.state = State.RUNNING
+
+    def pause_talk(self):
+        self.msgqueue.queue.clear()
+        self.state = State.PAUSE

    def put_msg_txt(self,msg):
        self.msgqueue.put(msg)
@@ -37,6 +47,7 @@ class BaseTTS:
        while not quit_event.is_set():
            try:
                msg = self.msgqueue.get(block=True, timeout=1)
+                self.state=State.RUNNING
            except queue.Empty:
                continue
            self.txt_to_audio(msg)
@@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS):
            stream = self.__create_bytes_stream(self.input_stream)
            streamlen = stream.shape[0]
            idx=0
-            while streamlen >= self.chunk:
+            while streamlen >= self.chunk and self.state==State.RUNNING:
                self.parent.put_audio_frame(stream[idx:idx+self.chunk])
                streamlen -= self.chunk
                idx += self.chunk
@@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS):
        async for chunk in communicate.stream():
            if first:
                first = False
-            if chunk["type"] == "audio":
+            if chunk["type"] == "audio" and self.state==State.RUNNING:
                #self.push_audio(chunk["data"])
                self.input_stream.write(chunk["data"])
                #file.write(chunk["data"])
@@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS):
                end = time.perf_counter()
                print(f"gpt_sovits Time to first chunk: {end-start}s")
                first = False
-            if chunk:
+            if chunk and self.state==State.RUNNING:
                yield chunk

        print("gpt_sovits response.elapsed:", res.elapsed)

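Taken together, the ttsreal.py changes form the interrupt path: `pause_talk()` clears the pending text queue and flips the state to `PAUSE`, every streaming loop checks `self.state==State.RUNNING` before pushing more audio, and the state returns to `RUNNING` when the next message is dequeued. A stripped-down sketch of that control flow (not the project's real classes, just the pattern):

```python
from enum import Enum
from queue import Queue, Empty

class State(Enum):
    RUNNING = 0
    PAUSE = 1

class TinyTTS:
    def __init__(self):
        self.msgqueue = Queue()
        self.state = State.RUNNING

    def put_msg_txt(self, msg):
        self.msgqueue.put(msg)

    def pause_talk(self):                  # called when the user interrupts
        self.msgqueue.queue.clear()        # drop text that has not been spoken yet
        self.state = State.PAUSE

    def stream_one(self, msg):
        for chunk in msg.split():          # stand-in for synthesized audio chunks
            if self.state != State.RUNNING:
                break                      # stop pushing audio mid-sentence
            print('push', chunk)

    def render_once(self):
        try:
            msg = self.msgqueue.get(timeout=0.1)
            self.state = State.RUNNING     # a new message re-enables streaming
        except Empty:
            return
        self.stream_one(msg)

tts = TinyTTS()
tts.put_msg_txt("hello digital human")
tts.render_once()
```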
@@ -29,22 +29,22 @@

$(document).ready(function() {
    var host = window.location.hostname
-    var ws = new WebSocket("ws://"+host+":8000/humanchat");
-    //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
-    ws.onopen = function() {
-        console.log('Connected');
-    };
-    ws.onmessage = function(e) {
-        console.log('Received: ' + e.data);
-        data = e
-        var vid = JSON.parse(data.data);
-        console.log(typeof(vid),vid)
-        //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
-    };
-    ws.onclose = function(e) {
-        console.log('Closed');
-    };
+    // var ws = new WebSocket("ws://"+host+":8000/humanecho");
+    // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
+    // ws.onopen = function() {
+    //     console.log('Connected');
+    // };
+    // ws.onmessage = function(e) {
+    //     console.log('Received: ' + e.data);
+    //     data = e
+    //     var vid = JSON.parse(data.data);
+    //     console.log(typeof(vid),vid)
+    //     //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
+    // };
+    // ws.onclose = function(e) {
+    //     console.log('Closed');
+    // };

    flvPlayer = mpegts.createPlayer({type: 'flv', url: "http://"+host+":8080/live/livestream.flv", isLive: true, enableStashBuffer: false});
    flvPlayer.attachMediaElement(document.getElementById('video_player'));
@@ -55,9 +55,19 @@
        e.preventDefault();
        var message = $('#message').val();
        console.log('Sending: ' + message);
-        ws.send(message);
+        fetch('/human', {
+            body: JSON.stringify({
+                text: message,
+                type: 'chat',
+            }),
+            headers: {
+                'Content-Type': 'application/json'
+            },
+            method: 'POST'
+        });
+        //ws.send(message);
        $('#message').val('');
    });
});
</script>
</html>
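For reference, the request the page now sends can be issued from any HTTP client as well; a minimal sketch using the `requests` library (the host/port, the response handling and the optional `interrupt` flag are assumptions pieced together from this commit's templates and README, not a documented API):

```python
import requests

# Assumes the app is reachable on the same address the README uses for the
# web pages; adjust host/port to your deployment.
resp = requests.post(
    "http://127.0.0.1:8010/human",
    json={
        "text": "hello, please introduce yourself",
        "type": "chat",       # the templates in this commit send 'chat' or 'echo'
        "interrupt": True,    # some templates also send this flag to cut off current speech
    },
)
print(resp.status_code)
```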
@@ -51,30 +51,40 @@
<script type="text/javascript" charset="utf-8">

$(document).ready(function() {
-    var host = window.location.hostname
-    var ws = new WebSocket("ws://"+host+":8000/humanchat");
-    //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
-    ws.onopen = function() {
-        console.log('Connected');
-    };
-    ws.onmessage = function(e) {
-        console.log('Received: ' + e.data);
-        data = e
-        var vid = JSON.parse(data.data);
-        console.log(typeof(vid),vid)
-        //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
-    };
-    ws.onclose = function(e) {
-        console.log('Closed');
-    };
+    // var host = window.location.hostname
+    // var ws = new WebSocket("ws://"+host+":8000/humanecho");
+    // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
+    // ws.onopen = function() {
+    //     console.log('Connected');
+    // };
+    // ws.onmessage = function(e) {
+    //     console.log('Received: ' + e.data);
+    //     data = e
+    //     var vid = JSON.parse(data.data);
+    //     console.log(typeof(vid),vid)
+    //     //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
+    // };
+    // ws.onclose = function(e) {
+    //     console.log('Closed');
+    // };

    $('#echo-form').on('submit', function(e) {
        e.preventDefault();
        var message = $('#message').val();
        console.log('Sending: ' + message);
-        ws.send(message);
+        fetch('/human', {
+            body: JSON.stringify({
+                text: message,
+                type: 'chat',
+            }),
+            headers: {
+                'Content-Type': 'application/json'
+            },
+            method: 'POST'
+        });
+        //ws.send(message);
        $('#message').val('');
    });
});

@@ -79,6 +79,7 @@
            body: JSON.stringify({
                text: message,
                type: 'echo',
+                interrupt: true,
            }),
            headers: {
                'Content-Type': 'application/json'

@@ -53,30 +53,41 @@
<script type="text/javascript" charset="utf-8">

$(document).ready(function() {
-    var host = window.location.hostname
-    var ws = new WebSocket("ws://"+host+":8000/humanchat");
-    //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
-    ws.onopen = function() {
-        console.log('Connected');
-    };
-    ws.onmessage = function(e) {
-        console.log('Received: ' + e.data);
-        data = e
-        var vid = JSON.parse(data.data);
-        console.log(typeof(vid),vid)
-        //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
-    };
-    ws.onclose = function(e) {
-        console.log('Closed');
-    };
+    // var host = window.location.hostname
+    // var ws = new WebSocket("ws://"+host+":8000/humanecho");
+    // //document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
+    // ws.onopen = function() {
+    //     console.log('Connected');
+    // };
+    // ws.onmessage = function(e) {
+    //     console.log('Received: ' + e.data);
+    //     data = e
+    //     var vid = JSON.parse(data.data);
+    //     console.log(typeof(vid),vid)
+    //     //document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
+    // };
+    // ws.onclose = function(e) {
+    //     console.log('Closed');
+    // };

    $('#echo-form').on('submit', function(e) {
        e.preventDefault();
        var message = $('#message').val();
        console.log('Sending: ' + message);
-        ws.send(message);
+        fetch('/human', {
+            body: JSON.stringify({
+                text: message,
+                type: 'chat',
+                interrupt: true,
+            }),
+            headers: {
+                'Content-Type': 'application/json'
+            },
+            method: 'POST'
+        });
+        //ws.send(message);
        $('#message').val('');
    });
});
</script>