diff --git a/README.md b/README.md
index 69cb1f6..a88dd0b 100644
--- a/README.md
+++ b/README.md
@@ -47,25 +47,16 @@ python app.py
 export HF_ENDPOINT=https://hf-mirror.com
 ```
 
-Once it is running, open rtmp://serverip/live/livestream in VLC
+Once it is running, open rtmp://serverip/live/livestream in VLC
 
-### 2.3 Making the avatar speak text entered on a web page
-Install and start nginx
-```
-apt install nginx
-nginx
-```
-Copy echo.html and mpegts-1.7.3.min.js to /var/www/html
-
-Open http://serverip/echo.html in a browser, type any text into the text box and submit; the digital human reads it aloud
+Open http://serverip:8010/echo.html in a browser, type any text into the text box and submit; the digital human reads it aloud
 
 ## 3. More Usage
 ### 3.1 Chatting with the avatar through an LLM
-Following the approach of the dialogue system [LinlyTalker](https://github.com/Kedreamix/Linly-Talker), the supported LLM backends are ChatGPT, Qwen and GeminiPro. Fill in your own api_key in app.py.
-Install and start nginx, then copy chat.html and mpegts-1.7.3.min.js to /var/www/html
+Following the approach of the dialogue system [LinlyTalker](https://github.com/Kedreamix/Linly-Talker), the supported LLM backends are ChatGPT, Qwen and GeminiPro. Fill in your own api_key in app.py.
 
-Open http://serverip/chat.html in a browser
+Open http://serverip:8010/chat.html in a browser
 
 ### 3.2 Using a local TTS service with voice cloning
 Run the xtts service; see https://github.com/coqui-ai/xtts-streaming-server
@@ -105,13 +96,20 @@ python app.py --fullbody --fullbody_img data/fullbody/img --fullbody_offset_x 10
 - --fullbody_width, --fullbody_height: width and height of the full-body video
 - --W, --H: width and height of the training video
 - If the torso from step 3 of ER-NeRF training comes out poorly, a seam can show at the junction. You can add --torso_imgs data/xxx/torso_imgs to the command above so the torso is taken directly from the training dataset instead of being inferred by the model; this may leave slight artifacts around the head and neck.
+
+### 3.6 webrtc
+```
+python app.py --transport webrtc
+```
+Open http://serverip:8010/webrtc.html in a browser
+
 ## 4. Docker Run
 No need for the installation in step 1; run directly:
 ```
 docker run --gpus all -it --network=host --rm registry.cn-hangzhou.aliyuncs.com/lipku/nerfstream:v1.3
 ```
-Run srs and nginx as in 2.1 and 2.3
+Run srs as in 2.1
 
 ## 5. Data flow
 ![](/assets/dataflow.png)
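For readers who want to smoke-test `--transport webrtc` without a browser, a minimal headless viewer can be built against the `/offer` endpoint added in app.py below. This is a sketch, not part of the patch; it assumes the server runs on port 8010 and uses the JSON offer/answer shape shown in the diff:

```python
# Sketch: headless client for the /offer endpoint (assumes serverip:8010
# and the {"sdp", "type"} JSON shape implemented in app.py below).
import asyncio
import aiohttp
from aiortc import RTCPeerConnection, RTCSessionDescription

async def run(server="http://127.0.0.1:8010"):
    pc = RTCPeerConnection()
    pc.addTransceiver("video", direction="recvonly")
    pc.addTransceiver("audio", direction="recvonly")

    @pc.on("track")
    def on_track(track):
        print("receiving", track.kind)

    # aiortc finishes ICE gathering inside setLocalDescription,
    # so localDescription already carries the candidates
    await pc.setLocalDescription(await pc.createOffer())
    async with aiohttp.ClientSession() as session:
        async with session.post(server + "/offer",
                                json={"sdp": pc.localDescription.sdp,
                                      "type": pc.localDescription.type}) as resp:
            answer = await resp.json()
    await pc.setRemoteDescription(
        RTCSessionDescription(sdp=answer["sdp"], type=answer["type"]))

    await asyncio.sleep(10)  # receive media for a few seconds, then hang up
    await pc.close()

asyncio.run(run())
```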
diff --git a/app.py b/app.py
index f074278..0a42c1c 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,5 @@
 # server.py
-from flask import Flask, request, jsonify
+from flask import Flask, render_template,send_from_directory,request, jsonify
 from flask_sockets import Sockets
 import base64
 import time
@@ -10,9 +10,13 @@ from geventwebsocket.handler import WebSocketHandler
 import os
 import re
 import numpy as np
-from threading import Thread
+from threading import Thread,Event
 import multiprocessing
 
+import json
+import asyncio
+from aiohttp import web
+from aiortc import RTCPeerConnection, RTCSessionDescription
+from webrtc import HumanPlayer
+
 import argparse
 from nerf_triplane.provider import NeRFDataset_Test
 from nerf_triplane.utils import *
@@ -153,11 +157,51 @@ def chat_socket(ws):
             return 'input message is empty'
         else:
             res=llm_response(message)
-            txt_to_audio(res)
+            txt_to_audio(res)
 
-def render():
-    nerfreal.render()
-
+#####webrtc###############################
+pcs = set()
+
+#@app.route('/offer', methods=['POST'])
+async def offer(request):
+    params = await request.json()
+    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
+
+    pc = RTCPeerConnection()
+    pcs.add(pc)
+
+    @pc.on("connectionstatechange")
+    async def on_connectionstatechange():
+        print("Connection state is %s" % pc.connectionState)
+        if pc.connectionState == "failed":
+            await pc.close()
+            pcs.discard(pc)
+
+    player = HumanPlayer(nerfreal)
+    audio_sender = pc.addTrack(player.audio)
+    video_sender = pc.addTrack(player.video)
+
+    await pc.setRemoteDescription(offer)
+
+    answer = await pc.createAnswer()
+    await pc.setLocalDescription(answer)
+
+    #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
+
+    return web.Response(
+        content_type="application/json",
+        text=json.dumps(
+            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
+        ),
+    )
+
+
+async def on_shutdown(app):
+    # close peer connections
+    coros = [pc.close() for pc in pcs]
+    await asyncio.gather(*coros)
+    pcs.clear()
+##########################################
 
 if __name__ == '__main__':
@@ -257,6 +301,7 @@ if __name__ == '__main__':
     # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
     # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
 
+    parser.add_argument('--transport', type=str, default='rtmp') #rtmp webrtc
     parser.add_argument('--push_url', type=str, default='rtmp://localhost/live/livestream')
 
     parser.add_argument('--asr_save_feats', action='store_true')
@@ -330,12 +375,29 @@ if __name__ == '__main__':
     # we still need test_loader to provide audio features for testing.
     nerfreal = NeRFReal(opt, trainer, test_loader)
     #txt_to_audio('我是中国人,我来自北京')
-    rendthrd = Thread(target=render)
-    rendthrd.start()
+    if opt.transport=='rtmp':
+        thread_quit = Event()
+        rendthrd = Thread(target=nerfreal.render,args=(thread_quit,))
+        rendthrd.start()
 
     #############################################################################
-    print('start websocket server')
+    appasync = web.Application()
+    appasync.on_shutdown.append(on_shutdown)
+    appasync.router.add_post("/offer", offer)
+    appasync.router.add_static('/',path='web')
+
+    def run_server(runner):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(runner.setup())
+        site = web.TCPSite(runner, '0.0.0.0', 8010)
+        loop.run_until_complete(site.start())
+        loop.run_forever()
+    Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
+
+    print('start websocket server')
+    #app.on_shutdown.append(on_shutdown)
+    #app.router.add_post("/offer", offer)
     server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
     server.serve_forever()
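app.py now imports `HumanPlayer` from a webrtc.py module that is not included in this diff. To make the data path readable, here is a hypothetical minimal sketch consistent with how the rest of the patch uses it: nerfreal.py pushes av frames into each track's `_queue`, and aiortc drains them through `recv()`. The class and attribute names below are assumptions, not the actual file:

```python
# Hypothetical sketch of webrtc.py (the real module is not in this diff):
# queue-backed aiortc tracks that nerfreal.py fills from its render thread.
import asyncio
from aiortc import MediaStreamTrack

class PlayerStreamTrack(MediaStreamTrack):
    def __init__(self, kind):
        super().__init__()
        self.kind = kind                # "audio" or "video"
        self._queue = asyncio.Queue()   # nerfreal.py calls _queue.put(frame)

    async def recv(self):
        # the real implementation must also assign pts/time_base here
        return await self._queue.get()

class HumanPlayer:
    def __init__(self, nerfreal):
        self._nerfreal = nerfreal
        self.audio = PlayerStreamTrack("audio")
        self.video = PlayerStreamTrack("video")
        # presumably it also starts nerfreal.render(quit_event, loop,
        # self.audio, self.video) once a consumer attaches
```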
diff --git a/nerfreal.py b/nerfreal.py
index 426f08f..c6fe831 100644
--- a/nerfreal.py
+++ b/nerfreal.py
@@ -10,7 +10,9 @@ import torch.nn.functional as F
 import cv2
 
 from asrreal import ASR
+import asyncio
 from rtmp_streaming import StreamerConfig, Streamer
+from av import AudioFrame, VideoFrame
 
 class NeRFReal:
     def __init__(self, opt, trainer, data_loader, debug=True):
@@ -118,7 +120,7 @@ class NeRFReal:
         else:
             return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
 
-    def test_step(self):
+    def test_step(self,loop=None,audio_track=None,video_track=None):
         #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
         #starter.record()
@@ -140,7 +142,11 @@ class NeRFReal:
             #print(f'[INFO] outputs shape ',outputs['image'].shape)
             image = (outputs['image'] * 255).astype(np.uint8)
             if not self.opt.fullbody:
-                self.streamer.stream_frame(image)
+                if self.opt.transport=='rtmp':
+                    self.streamer.stream_frame(image)
+                else:
+                    new_frame = VideoFrame.from_ndarray(image, format="rgb24")
+                    asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
             else: #fullbody human
                 #print("frame index:",data['index'])
                 image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
@@ -148,12 +154,23 @@ class NeRFReal:
                 start_x = self.opt.fullbody_offset_x # x offset of the head patch in the composited frame
                 start_y = self.opt.fullbody_offset_y # y offset of the head patch in the composited frame
                 image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image
-                self.streamer.stream_frame(image_fullbody)
+                if self.opt.transport=='rtmp':
+                    self.streamer.stream_frame(image_fullbody)
+                else:
+                    # stream the composited full-body frame, not just the head crop
+                    new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
+                    asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
             #self.pipe.stdin.write(image.tostring())
             for _ in range(2):
                 frame = self.asr.get_audio_out()
                 #print(f'[INFO] get_audio_out shape ',frame.shape)
-                self.streamer.stream_frame_audio(frame)
+                if self.opt.transport=='rtmp':
+                    self.streamer.stream_frame_audio(frame)
+                else:
+                    frame = (frame * 32767).astype(np.int16)
+                    new_frame = AudioFrame(format='s16', layout='mono', samples=320)
+                    new_frame.planes[0].update(frame.tobytes())
+                    new_frame.sample_rate=16000
+                    asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
                 # frame = (frame * 32767).astype(np.int16).tobytes()
                 # self.fifo_audio.write(frame)
         else:
@@ -167,35 +184,36 @@ class NeRFReal:
             #torch.cuda.synchronize()
             #t = starter.elapsed_time(ender)
 
-    def render(self):
+    def render(self,quit_event,loop=None,audio_track=None,video_track=None):
 
         if self.opt.asr:
             self.asr.warm_up()
 
         count=0
         totaltime=0
-        fps=25
-        #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
-        sc = StreamerConfig()
-        sc.source_width = self.W
-        sc.source_height = self.H
-        sc.stream_width = self.W
-        sc.stream_height = self.H
-        if self.opt.fullbody:
-            sc.source_width = self.opt.fullbody_width
-            sc.source_height = self.opt.fullbody_height
-            sc.stream_width = self.opt.fullbody_width
-            sc.stream_height = self.opt.fullbody_height
-        sc.stream_fps = fps
-        sc.stream_bitrate = 1000000
-        sc.stream_profile = 'baseline' #'high444' # 'main'
-        sc.audio_channel = 1
-        sc.sample_rate = 16000
-        sc.stream_server = self.opt.push_url
-        self.streamer = Streamer()
-        self.streamer.init(sc)
-        #self.streamer.enable_av_debug_log()
+        if self.opt.transport=='rtmp':
+            fps=25
+            #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
+            sc = StreamerConfig()
+            sc.source_width = self.W
+            sc.source_height = self.H
+            sc.stream_width = self.W
+            sc.stream_height = self.H
+            if self.opt.fullbody:
+                sc.source_width = self.opt.fullbody_width
+                sc.source_height = self.opt.fullbody_height
+                sc.stream_width = self.opt.fullbody_width
+                sc.stream_height = self.opt.fullbody_height
+            sc.stream_fps = fps
+            sc.stream_bitrate = 1000000
+            sc.stream_profile = 'baseline' #'high444' # 'main'
+            sc.audio_channel = 1
+            sc.sample_rate = 16000
+            sc.stream_server = self.opt.push_url
+            self.streamer = Streamer()
+            self.streamer.init(sc)
+            #self.streamer.enable_av_debug_log()
 
-        while True: #todo
+        while not quit_event.is_set(): #todo
             # update texture every frame
             # audio stream thread...
             t = time.time()
             # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
             for _ in range(2):
                 self.asr.run_step()
-            self.test_step()
+            self.test_step(loop,audio_track,video_track)
             totaltime += (time.time() - t)
             count += 1
             if count==100:
                 print(f"------actual avg fps:{count/totaltime:.4f}")
                 count=0
                 totaltime=0
-            # delay = 0.04 - (time.time() - t) #40ms
-            # if delay > 0:
-            #     time.sleep(delay)
+            delay = 0.04 - (time.time() - t) #40ms
+            if delay > 0:
+                time.sleep(delay)
\ No newline at end of file
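The pacing arithmetic above is worth spelling out: audio leaves the ASR in 320-sample chunks at 16 kHz (20 ms each), and the loop sends two of them per video frame, i.e. 40 ms of audio per frame at 25 fps, which is why the newly enabled 0.04 s sleep keeps the two streams aligned. One thing the queued frames still lack is timestamps; aiortc expects pts/time_base on every frame returned from recv(). A hypothetical helper, not part of this patch:

```python
# Hypothetical timestamping helpers (not in this patch): aiortc needs
# pts/time_base on frames; 320 samples @ 16 kHz = 20 ms of audio, and
# two audio frames cover one 40 ms (25 fps) video frame.
from fractions import Fraction

def stamp_audio(frame, index):
    frame.pts = index * 320                # pts counted in samples
    frame.time_base = Fraction(1, 16000)
    return frame

def stamp_video(frame, index):
    frame.pts = int(index * 0.04 * 90000)  # 90 kHz RTP video clock
    frame.time_base = Fraction(1, 90000)
    return frame
```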
diff --git a/requirements.txt b/requirements.txt
index 12a5079..3fd1921 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,3 +31,4 @@ edge_tts
 flask
 flask_sockets
 opencv-python-headless
+aiortc
diff --git a/chat.html b/web/chat.html
similarity index 100%
rename from chat.html
rename to web/chat.html
diff --git a/web/client.js b/web/client.js
new file mode 100644
index 0000000..c48b064
--- /dev/null
+++ b/web/client.js
@@ -0,0 +1,76 @@
+var pc = null;
+
+function negotiate() {
+    pc.addTransceiver('video', { direction: 'recvonly' });
+    pc.addTransceiver('audio', { direction: 'recvonly' });
+    return pc.createOffer().then((offer) => {
+        return pc.setLocalDescription(offer);
+    }).then(() => {
+        // wait for ICE gathering to complete
+        return new Promise((resolve) => {
+            if (pc.iceGatheringState === 'complete') {
+                resolve();
+            } else {
+                const checkState = () => {
+                    if (pc.iceGatheringState === 'complete') {
+                        pc.removeEventListener('icegatheringstatechange', checkState);
+                        resolve();
+                    }
+                };
+                pc.addEventListener('icegatheringstatechange', checkState);
+            }
+        });
+    }).then(() => {
+        var offer = pc.localDescription;
+        return fetch('/offer', {
+            body: JSON.stringify({
+                sdp: offer.sdp,
+                type: offer.type,
+            }),
+            headers: {
+                'Content-Type': 'application/json'
+            },
+            method: 'POST'
+        });
+    }).then((response) => {
+        return response.json();
+    }).then((answer) => {
+        return pc.setRemoteDescription(answer);
+    }).catch((e) => {
+        alert(e);
+    });
+}
+
+function start() {
+    var config = {
+        sdpSemantics: 'unified-plan'
+    };
+
+    if (document.getElementById('use-stun').checked) {
+        config.iceServers = [{ urls: ['stun:stun.l.google.com:19302'] }];
+    }
+
+    pc = new RTCPeerConnection(config);
+
+    // connect audio / video
+    pc.addEventListener('track', (evt) => {
+        if (evt.track.kind == 'video') {
+            document.getElementById('video').srcObject = evt.streams[0];
+        } else {
+            document.getElementById('audio').srcObject = evt.streams[0];
+        }
+    });
+
+    document.getElementById('start').style.display = 'none';
+    negotiate();
+    document.getElementById('stop').style.display = 'inline-block';
+}
+
+function stop() {
+    document.getElementById('stop').style.display = 'none';
+
+    // close peer connection
+    setTimeout(() => {
+        pc.close();
+    }, 500);
+}
diff --git a/echo.html b/web/echo.html
similarity index 100%
rename from echo.html
rename to web/echo.html
diff --git a/mpegts-1.7.3.min.js b/web/mpegts-1.7.3.min.js
similarity index 100%
rename from mpegts-1.7.3.min.js
rename to web/mpegts-1.7.3.min.js
diff --git a/web/webrtc.html b/web/webrtc.html
new file mode 100644
index 0000000..0103257
--- /dev/null
+++ b/web/webrtc.html
@@ -0,0 +1,83 @@
[83 added lines of webrtc.html not captured in this excerpt; judging from client.js it is the demo page with video/audio elements, start/stop buttons, a use-stun checkbox, and a client.js include]