livetalking/app.py

483 lines
20 KiB
Python
Raw Normal View History

2023-12-19 09:41:52 +08:00
# server.py
2024-04-14 19:08:25 +08:00
from flask import Flask, render_template,send_from_directory,request, jsonify
2023-12-19 09:41:52 +08:00
from flask_sockets import Sockets
import base64
import time
import json
import gevent
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
import os
import re
import numpy as np
2024-04-14 19:08:25 +08:00
from threading import Thread,Event
2023-12-28 13:11:18 +08:00
import multiprocessing
2023-12-19 09:41:52 +08:00
2024-04-14 19:08:25 +08:00
from aiohttp import web
2024-04-27 18:08:57 +08:00
import aiohttp
2024-04-14 19:08:25 +08:00
from aiortc import RTCPeerConnection, RTCSessionDescription
from webrtc import HumanPlayer
2023-12-19 09:41:52 +08:00
import argparse
2024-05-02 21:05:16 +08:00
from ernerf.nerf_triplane.provider import NeRFDataset_Test
from ernerf.nerf_triplane.utils import *
from ernerf.nerf_triplane.network import NeRFNetwork
2023-12-19 09:41:52 +08:00
from nerfreal import NeRFReal
import shutil
import asyncio
import edge_tts
2024-02-25 18:54:40 +08:00
from typing import Iterator
import requests
2023-12-19 09:41:52 +08:00
app = Flask(__name__)
sockets = Sockets(app)

# Runtime state shared by the request handlers below. These are assigned in
# the ``if __name__ == '__main__':`` block; a module-level ``global``
# statement is a no-op, so the names are bound explicitly here instead.
nerfreal = None  # NeRFReal renderer; TTS audio is pushed into it
tts_type = None  # value of --tts: "edgetts", "gpt-sovits" or "xtts"
gspeaker = None  # xtts speaker data returned by get_speaker()
async def main(voicename: str, text: str, render):
    """Synthesize *text* with edge-tts and stream the audio into *render*.

    Args:
        voicename: edge-tts voice name (e.g. ``"zh-CN-YunxiaNeural"``).
        text: Text to synthesize.
        render: Object exposing ``push_audio(data)``; each audio chunk
            produced by edge-tts is forwarded to it as soon as it arrives.
    """
    communicate = edge_tts.Communicate(text, voicename)
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            render.push_audio(chunk["data"])
        # Any other chunk type (e.g. "WordBoundary" metadata) is ignored.
def get_speaker(ref_audio, server_url):
    """Upload a reference wav to the xtts server and return its JSON reply.

    Args:
        ref_audio: Path of the reference audio file to clone the voice from.
        server_url: Base URL of the xtts server; ``/clone_speaker`` is appended.

    Returns:
        The decoded JSON body of the ``/clone_speaker`` response, which is
        later passed to :func:`xtts` as its ``speaker`` argument.
    """
    # Open the file in a context manager so the handle is closed once the
    # upload finishes (the previous version leaked it).
    with open(ref_audio, "rb") as wav_file:
        files = {"wav_file": ("reference.wav", wav_file)}
        response = requests.post(f"{server_url}/clone_speaker", files=files)
    return response.json()
def xtts(text, speaker, language, server_url, stream_chunk_size) -> Iterator[bytes]:
    """Stream synthesized speech for *text* from an xtts server.

    Args:
        text: Text to synthesize.
        speaker: Speaker dict as returned by :func:`get_speaker`. It is
            copied, not mutated, so one dict can be shared across calls.
        language: Language code sent to the server (e.g. ``"zh-cn"``).
        server_url: Base URL of the xtts server; ``/tts_stream`` is appended.
        stream_chunk_size: Forwarded to the server. Smaller values give a
            faster first response but degrade quality.

    Yields:
        Non-empty byte chunks (at most 960 bytes each) as they arrive. If the
        server does not answer with HTTP 200, the error body is printed and
        nothing is yielded.
    """
    start = time.perf_counter()
    # Build the request on a copy: the caller's speaker dict (the global
    # gspeaker) must not accumulate per-request fields.
    payload = dict(
        speaker,
        text=text,
        language=language,
        stream_chunk_size=stream_chunk_size,
    )
    # `with` releases the streamed connection even when the consumer stops
    # iterating early or the server reports an error.
    with requests.post(f"{server_url}/tts_stream", json=payload, stream=True) as res:
        end = time.perf_counter()
        print(f"xtts Time to make POST: {end-start}s")
        if res.status_code != 200:
            print("Error:", res.text)
            return
        first = True
        for chunk in res.iter_content(chunk_size=960):  # 24K * 20ms * 2 bytes
            if first:
                end = time.perf_counter()
                print(f"xtts Time to first chunk: {end-start}s")
                first = False
            if chunk:
                yield chunk
        print("xtts response.elapsed:", res.elapsed)
def gpt_sovits(text, character, language, server_url, emotion) -> Iterator[bytes]:
    """Stream synthesized speech for *text* from a GPT-SoVITS server.

    Args:
        text: Text to synthesize.
        character: Character name sent as ``character``.
        language: Language code sent as ``text_language`` (e.g. ``"zh"``).
        server_url: Base URL of the server; ``/tts`` is appended.
        emotion: Emotion name sent as ``emotion``.

    Yields:
        Non-empty byte chunks (at most 1280 bytes each) as they arrive. If the
        server does not answer with HTTP 200, the error body is printed and
        nothing is yielded.
    """
    start = time.perf_counter()
    req = {
        "text": text,
        "text_language": language,
        "character": character,
        "emotion": emotion,
        "stream": True,
    }
    # `with` releases the streamed connection even when the consumer stops
    # iterating early or the server reports an error.
    with requests.post(f"{server_url}/tts", json=req, stream=True) as res:
        end = time.perf_counter()
        print(f"gpt_sovits Time to make POST: {end-start}s")
        if res.status_code != 200:
            print("Error:", res.text)
            return
        first = True
        for chunk in res.iter_content(chunk_size=1280):  # 32K * 20ms * 2 bytes
            if first:
                end = time.perf_counter()
                print(f"gpt_sovits Time to first chunk: {end-start}s")
                first = False
            if chunk:
                yield chunk
        print("gpt_sovits response.elapsed:", res.elapsed)
def stream_tts(audio_stream, render):
    """Push every chunk of *audio_stream* into *render*, skipping ``None`` chunks.

    Args:
        audio_stream: Iterable of audio chunks (e.g. the generators returned
            by :func:`xtts` or :func:`gpt_sovits`).
        render: Object exposing ``push_audio(chunk)``.
    """
    for piece in audio_stream:
        if piece is None:
            continue
        render.push_audio(piece)
def txt_to_audio(text_):
    """Synthesize *text_* with the configured TTS backend and feed it to ``nerfreal``.

    The backend is selected by the module-level ``tts_type`` (set from ``--tts``):

    * ``"edgetts"``: edge-tts with voice ``zh-CN-YunxiaNeural``.
    * ``"gpt-sovits"``: streamed from ``app.config['TTS_SERVER']`` using
      ``app.config['CHARACTER']`` and ``app.config['EMOTION']``.
    * ``"xtts"``: streamed from ``app.config['TTS_SERVER']`` using the
      speaker data stored in ``gspeaker``.

    Raises:
        ValueError: if ``tts_type`` is not one of the values above.
    """
    if tts_type == "edgetts":
        voicename = "zh-CN-YunxiaNeural"
        t = time.time()
        # asyncio.run() creates and closes its own event loop; relying on
        # asyncio.get_event_loop() to create one implicitly is deprecated.
        asyncio.run(main(voicename, text_, nerfreal))
        print(f'-------edge tts time:{time.time()-t:.4f}s')
    elif tts_type == "gpt-sovits":
        stream_tts(
            gpt_sovits(
                text_,
                app.config['CHARACTER'],
                "zh",
                app.config['TTS_SERVER'],
                app.config['EMOTION'],
            ),
            nerfreal,
        )
    elif tts_type == "xtts":
        stream_tts(
            xtts(
                text_,
                gspeaker,
                "zh-cn",
                app.config['TTS_SERVER'],
                "20",  # stream_chunk_size
            ),
            nerfreal,
        )
    else:
        # Previously any unknown value fell through to xtts and failed with a
        # NameError on gspeaker; report the real problem instead.
        raise ValueError(f"unsupported tts type: {tts_type!r}")
@sockets.route('/humanecho')
def echo_socket(ws):
    """Speak every text message received on the ``/humanecho`` WebSocket.

    Each incoming message is passed verbatim to :func:`txt_to_audio`. The
    handler returns once ``ws.receive()`` produces an empty or falsy value.
    """
    # Guard clause: without a WebSocket object there is nothing to serve.
    if not ws:
        print('未建立连接!')  # "connection not established"
        return 'Please use WebSocket'

    print('建立连接!')  # "connection established"
    while True:
        incoming = ws.receive()
        if not incoming:
            return '输入信息为空'  # "input message is empty"
        txt_to_audio(incoming)
def llm_response(message):
    """Send *message* to the configured LLM backend, print the reply and return it.

    The backend is constructed on every call via ``LLM().init_model``.
    """
    from llm.LLM import LLM

    # Alternative backends (fill in credentials before switching):
    #   LLM().init_model('Gemini', model_path='gemini-pro', api_key='Your API Key', proxy_url=None)
    #   LLM().init_model('ChatGPT', model_path='gpt-3.5-turbo', api_key='Your API Key')
    backend = LLM().init_model('VllmGPT', model_path='THUDM/chatglm3-6b')
    reply = backend.chat(message)
    print(reply)
    return reply
@sockets.route('/humanchat')
def chat_socket(ws):
    """Answer each ``/humanchat`` message with the LLM and speak the reply.

    Every received message is sent to :func:`llm_response`, and the reply is
    synthesized with :func:`txt_to_audio`. The handler returns once
    ``ws.receive()`` produces an empty or falsy value (presumably ``None``
    after the client disconnects).
    """
    # If there is no WebSocket object, report an error.
    if not ws:
        print('未建立连接!')  # "connection not established"
        return 'Please use WebSocket'

    print('建立连接!')  # "connection established"
    while True:
        message = ws.receive()
        # `not message` also covers None; the former `len(message) == 0`
        # raised TypeError in that case (same guard as echo_socket).
        if not message:
            return '输入信息为空'  # "input message is empty"
        res = llm_response(message)
        txt_to_audio(res)
##### WebRTC ###############################
# Registry of RTCPeerConnection objects created by offer() and run();
# on_shutdown() closes every entry and then clears the set.
pcs: set = set()
async def offer(request):
    """aiohttp ``POST /offer`` handler: answer a client's WebRTC offer.

    Reads ``sdp`` and ``type`` from the JSON request body. It then attaches
    the audio and video tracks of a new ``HumanPlayer(nerfreal)`` to a fresh
    peer connection and replies with the local answer as JSON.
    """
    payload = await request.json()
    remote = RTCSessionDescription(sdp=payload["sdp"], type=payload["type"])

    peer = RTCPeerConnection()
    pcs.add(peer)

    @peer.on("connectionstatechange")
    async def on_connectionstatechange():
        print("Connection state is %s" % peer.connectionState)
        if peer.connectionState == "failed":
            await peer.close()
            pcs.discard(peer)

    player = HumanPlayer(nerfreal)
    peer.addTrack(player.audio)
    peer.addTrack(player.video)

    await peer.setRemoteDescription(remote)
    await peer.setLocalDescription(await peer.createAnswer())

    local = peer.localDescription
    body = json.dumps({"sdp": local.sdp, "type": local.type})
    return web.Response(content_type="application/json", text=body)
async def on_shutdown(app):
    """aiohttp shutdown hook: close all tracked peer connections, then forget them."""
    await asyncio.gather(*(peer.close() for peer in pcs))
    pcs.clear()
async def post(url, data):
    """POST *data* to *url* and return the response body as text.

    Returns:
        The response text, or ``None`` (after printing the error) when the
        request fails at the client/transport level or the server answers
        with an HTTP error status (>= 400).
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                # Without this, an error page would be returned to the caller
                # as if it were a valid (e.g. SDP) body.
                response.raise_for_status()
                return await response.text()
    except aiohttp.ClientError as e:
        print(f'Error: {e}')
        return None
async def run(push_url):
    """Push the rendered audio/video to *push_url* over WebRTC.

    Creates a peer connection carrying the tracks of ``HumanPlayer(nerfreal)``,
    POSTs the local SDP offer to *push_url* and applies the returned SDP
    answer.

    Raises:
        RuntimeError: if no SDP answer could be obtained from *push_url*.
    """
    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        print("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url, pc.localDescription.sdp)
    if not answer:
        # post() already printed the cause; don't hand None/"" to aiortc as
        # SDP (that failed later with an unrelated-looking error).
        await pc.close()
        pcs.discard(pc)
        raise RuntimeError(f"push to {push_url} failed: no SDP answer received")
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer, type='answer'))

##########################################
def _build_arg_parser():
    """Build the command-line parser for the talking-head streaming server.

    Kept separate from the ``__main__`` block so the option set can be
    inspected or reused without starting any servers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
    parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
    parser.add_argument('--torso_imgs', type=str, default="", help="torso images path")
    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
    parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
    parser.add_argument('--workspace', type=str, default='data/video')
    parser.add_argument('--seed', type=int, default=0)

    ### training options
    parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

    ### loss set
    parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
    parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
    parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
    parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
    parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")

    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
    parser.add_argument('--bg_img', type=str, default='white', help="background image")
    parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
    parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
    parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
    parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
    parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
    parser.add_argument('--init_lips', action='store_true', help="init lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
    parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...")
    parser.add_argument('--torso', action='store_true', help="fix head and train torso")
    parser.add_argument('--head_ckpt', type=str, default='', help="head model")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")
    parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    ### else
    parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
    parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
    parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
    parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
    parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")
    parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
    parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
    parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
    parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
    parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
    parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

    # asr
    parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
    parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
    parser.add_argument('--asr_play', action='store_true', help="play out the audio")
    # Other models that have been used here: 'deepspeech',
    # 'facebook/wav2vec2-large-960h-lv60-self', 'facebook/hubert-large-ls960-ft'.
    parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')

    parser.add_argument('--transport', type=str, default='rtcpush')  # rtmp | webrtc | rtcpush
    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream')  # or e.g. rtmp://localhost/live/livestream

    parser.add_argument('--asr_save_feats', action='store_true')
    # audio FPS
    parser.add_argument('--fps', type=int, default=50)
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=8)
    parser.add_argument('-r', type=int, default=10)

    parser.add_argument('--fullbody', action='store_true', help="fullbody human")
    parser.add_argument('--fullbody_img', type=str, default='data/fullbody/img')
    parser.add_argument('--fullbody_width', type=int, default=580)
    parser.add_argument('--fullbody_height', type=int, default=1080)
    parser.add_argument('--fullbody_offset_x', type=int, default=0)
    parser.add_argument('--fullbody_offset_y', type=int, default=0)

    parser.add_argument('--customvideo', action='store_true', help="custom video")
    parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
    parser.add_argument('--customvideo_imgnum', type=int, default=1)

    parser.add_argument('--tts', type=str, default='edgetts')  # edgetts | xtts | gpt-sovits
    parser.add_argument('--REF_FILE', type=str, default=None)
    parser.add_argument('--TTS_SERVER', type=str, default='http://localhost:9000')  # gpt-sovits e.g. http://127.0.0.1:5000
    parser.add_argument('--CHARACTER', type=str, default='test')
    parser.add_argument('--EMOTION', type=str, default='default')
    return parser


if __name__ == '__main__':
    parser = _build_arg_parser()
    opt = parser.parse_args()
    # Flask's from_object() copies only UPPERCASE attributes, which is how
    # REF_FILE / TTS_SERVER / CHARACTER / EMOTION become app.config keys.
    app.config.from_object(opt)
    print(app.config)

    tts_type = opt.tts
    if tts_type == "xtts":
        # get_speaker() opens the reference wav; without --REF_FILE it used to
        # crash with a TypeError from open(None).
        if not opt.REF_FILE:
            parser.error("--REF_FILE is required when --tts is xtts")
        print("Computing the latents for a new reference...")
        gspeaker = get_speaker(opt.REF_FILE, opt.TTS_SERVER)

    # force test mode
    opt.test = True
    opt.test_train = False
    # explicit smoothing
    opt.smooth_path = True
    opt.smooth_lips = True

    # parser.error instead of assert: asserts are stripped under `python -O`.
    if not opt.pose:
        parser.error('Must provide a pose source')

    # The `-O` gate is intentionally disabled: these flags are always forced on.
    opt.fp16 = True
    opt.cuda_ray = True
    opt.exp_eye = True
    opt.smooth_eye = True

    if opt.torso_imgs == '':  # no torso images given: use the model output
        opt.torso = True

    opt.asr = True
    if opt.patch_size > 1 and opt.num_rays % (opt.patch_size ** 2) != 0:
        parser.error("num_rays should be divisible by patch_size ** 2.")

    # NOTE(review): seed_everything, Trainer and torch are presumably provided
    # by `from ernerf.nerf_triplane.utils import *` — confirm.
    seed_everything(opt.seed)
    print(opt)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = NeRFNetwork(opt)
    criterion = torch.nn.MSELoss(reduction='none')
    metrics = []  # use no metric in GUI for faster initialization...
    print(model)
    trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace,
                      criterion=criterion, fp16=opt.fp16, metrics=metrics,
                      use_checkpoint=opt.ckpt)

    # we still need test_loader to provide audio features for testing.
    test_loader = NeRFDataset_Test(opt, device=device).dataloader()
    model.aud_features = test_loader._data.auds
    model.eye_areas = test_loader._data.eye_area
    nerfreal = NeRFReal(opt, trainer, test_loader)

    if opt.transport == 'rtmp':
        thread_quit = Event()
        rendthrd = Thread(target=nerfreal.render, args=(thread_quit,))
        rendthrd.start()

    ##### aiohttp server: WebRTC signalling (/offer) and static files (web/) #####
    appasync = web.Application()
    appasync.on_shutdown.append(on_shutdown)
    appasync.router.add_post("/offer", offer)
    appasync.router.add_static('/', path='web')

    def run_server(runner):
        """Serve *runner* on 0.0.0.0:8010 in this thread's own event loop.

        With ``--transport rtcpush`` it also starts pushing to ``opt.push_url``.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', 8010)
        loop.run_until_complete(site.start())
        if opt.transport == 'rtcpush':
            loop.run_until_complete(run(opt.push_url))
        loop.run_forever()

    Thread(target=run_server, args=(web.AppRunner(appasync),)).start()

    # Flask/gevent server for the /humanecho and /humanchat WebSockets.
    print('start websocket server')
    server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
    server.serve_forever()