diff --git a/README.md b/README.md index 69cb1f6..a88dd0b 100644 --- a/README.md +++ b/README.md @@ -47,25 +47,16 @@ python app.py export HF_ENDPOINT=https://hf-mirror.com ``` -运行成功后,用vlc访问rtmp://serverip/live/livestream +运行成功后,用vlc访问rtmp://serverip/live/livestream -### 2.3 网页端数字人播报输入文字 -安装并启动nginx -``` -apt install nginx -nginx -``` -将echo.html和mpegts-1.7.3.min.js拷到/var/www/html下 - -用浏览器打开http://serverip/echo.html, 在文本框输入任意文字,提交。数字人播报该段文字 +用浏览器打开http://serverip:8010/echo.html, 在文本框输入任意文字,提交。数字人播报该段文字 ## 3. More Usage ### 3.1 使用LLM模型进行数字人对话 -目前借鉴数字人对话系统[LinlyTalker](https://github.com/Kedreamix/Linly-Talker)的方式,LLM模型支持Chatgpt,Qwen和GeminiPro。需要在app.py中填入自己的api_key。 -安装并启动nginx,将chat.html和mpegts-1.7.3.min.js拷到/var/www/html下 +目前借鉴数字人对话系统[LinlyTalker](https://github.com/Kedreamix/Linly-Talker)的方式,LLM模型支持Chatgpt,Qwen和GeminiPro。需要在app.py中填入自己的api_key。 -用浏览器打开http://serverip/chat.html +用浏览器打开http://serverip:8010/chat.html ### 3.2 使用本地tts服务,支持声音克隆 运行xtts服务,参照 https://github.com/coqui-ai/xtts-streaming-server @@ -105,13 +96,20 @@ python app.py --fullbody --fullbody_img data/fullbody/img --fullbody_offset_x 10 - --fullbody_width、--fullbody_height 全身视频的宽、高 - --W、--H 训练视频的宽、高 - ernerf训练第三步torso如果训练的不好,在拼接处会有接缝。可以在上面的命令加上--torso_imgs data/xxx/torso_imgs,torso不用模型推理,直接用训练数据集里的torso图片。这种方式可能头颈处会有些人工痕迹。 + +### 3.6 webrtc +``` +python app.py --transport webrtc +``` +用浏览器打开http://serverip:8010/webrtc.html + ## 4. Docker Run 不需要第1步的安装,直接运行。 ``` docker run --gpus all -it --network=host --rm registry.cn-hangzhou.aliyuncs.com/lipku/nerfstream:v1.3 ``` -srs和nginx的运行同2.1和2.3 +srs的运行同2.1 ## 5. Data flow ![](/assets/dataflow.png) diff --git a/app.py b/app.py index f074278..0a42c1c 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,5 @@ # server.py -from flask import Flask, request, jsonify +from flask import Flask, render_template,send_from_directory,request, jsonify from flask_sockets import Sockets import base64 import time @@ -10,9 +10,13 @@ from geventwebsocket.handler import WebSocketHandler import os import re import numpy as np -from threading import Thread +from threading import Thread,Event import multiprocessing +from aiohttp import web +from aiortc import RTCPeerConnection, RTCSessionDescription +from webrtc import HumanPlayer + import argparse from nerf_triplane.provider import NeRFDataset_Test from nerf_triplane.utils import * @@ -153,11 +157,51 @@ def chat_socket(ws): return '输入信息为空' else: res=llm_response(message) - txt_to_audio(res) + txt_to_audio(res) -def render(): - nerfreal.render() - +#####webrtc############################### +pcs = set() + +#@app.route('/offer', methods=['POST']) +async def offer(request): + params = await request.json() + offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) + + pc = RTCPeerConnection() + pcs.add(pc) + + @pc.on("connectionstatechange") + async def on_connectionstatechange(): + print("Connection state is %s" % pc.connectionState) + if pc.connectionState == "failed": + await pc.close() + pcs.discard(pc) + + player = HumanPlayer(nerfreal) + audio_sender = pc.addTrack(player.audio) + video_sender = pc.addTrack(player.video) + + await pc.setRemoteDescription(offer) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + #return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}) + + return web.Response( + content_type="application/json", + text=json.dumps( + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} + ), + ) + + +async def on_shutdown(app): + # close peer connections + coros = [pc.close() for pc in pcs] + await asyncio.gather(*coros) + pcs.clear() +########################################## if __name__ == '__main__': @@ -257,6 +301,7 @@ if __name__ == '__main__': # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') # parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft') + parser.add_argument('--transport', type=str, default='rtmp') #rtmp webrtc parser.add_argument('--push_url', type=str, default='rtmp://localhost/live/livestream') parser.add_argument('--asr_save_feats', action='store_true') @@ -330,12 +375,29 @@ if __name__ == '__main__': # we still need test_loader to provide audio features for testing. nerfreal = NeRFReal(opt, trainer, test_loader) #txt_to_audio('我是中国人,我来自北京') - rendthrd = Thread(target=render) - rendthrd.start() + if opt.transport=='rtmp': + thread_quit = Event() + rendthrd = Thread(target=nerfreal.render,args=(thread_quit,)) + rendthrd.start() ############################################################################# - print('start websocket server') + appasync = web.Application() + appasync.on_shutdown.append(on_shutdown) + appasync.router.add_post("/offer", offer) + appasync.router.add_static('/',path='web') + def run_server(runner): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(runner.setup()) + site = web.TCPSite(runner, '0.0.0.0', 8010) + loop.run_until_complete(site.start()) + loop.run_forever() + Thread(target=run_server, args=(web.AppRunner(appasync),)).start() + + print('start websocket server') + #app.on_shutdown.append(on_shutdown) + #app.router.add_post("/offer", offer) server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler) server.serve_forever() diff --git a/nerfreal.py b/nerfreal.py index 426f08f..c6fe831 100644 --- a/nerfreal.py +++ b/nerfreal.py @@ -10,7 +10,9 @@ import torch.nn.functional as F import cv2 from asrreal import ASR +import asyncio from rtmp_streaming import StreamerConfig, Streamer +from av import AudioFrame, VideoFrame class NeRFReal: def __init__(self, opt, trainer, data_loader, debug=True): @@ -118,7 +120,7 @@ class NeRFReal: else: return np.expand_dims(outputs['depth'], -1).repeat(3, -1) - def test_step(self): + def test_step(self,loop=None,audio_track=None,video_track=None): #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) #starter.record() @@ -140,7 +142,11 @@ class NeRFReal: #print(f'[INFO] outputs shape ',outputs['image'].shape) image = (outputs['image'] * 255).astype(np.uint8) if not self.opt.fullbody: - self.streamer.stream_frame(image) + if self.opt.transport=='rtmp': + self.streamer.stream_frame(image) + else: + new_frame = VideoFrame.from_ndarray(image, format="rgb24") + asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) else: #fullbody human #print("frame index:",data['index']) image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg')) @@ -148,12 +154,23 @@ class NeRFReal: start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标 start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标 image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image - self.streamer.stream_frame(image_fullbody) + if self.opt.transport=='rtmp': + self.streamer.stream_frame(image_fullbody) + else: + new_frame = VideoFrame.from_ndarray(image, format="rgb24") + asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop) #self.pipe.stdin.write(image.tostring()) for _ in range(2): frame = self.asr.get_audio_out() #print(f'[INFO] get_audio_out shape ',frame.shape) - self.streamer.stream_frame_audio(frame) + if self.opt.transport=='rtmp': + self.streamer.stream_frame_audio(frame) + else: + frame = (frame * 32767).astype(np.int16) + new_frame = AudioFrame(format='s16', layout='mono', samples=320) + new_frame.planes[0].update(frame.tobytes()) + new_frame.sample_rate=16000 + asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop) # frame = (frame * 32767).astype(np.int16).tobytes() # self.fifo_audio.write(frame) else: @@ -167,35 +184,36 @@ class NeRFReal: #torch.cuda.synchronize() #t = starter.elapsed_time(ender) - def render(self): + def render(self,quit_event,loop=None,audio_track=None,video_track=None): if self.opt.asr: self.asr.warm_up() count=0 totaltime=0 - fps=25 - #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4' - sc = StreamerConfig() - sc.source_width = self.W - sc.source_height = self.H - sc.stream_width = self.W - sc.stream_height = self.H - if self.opt.fullbody: - sc.source_width = self.opt.fullbody_width - sc.source_height = self.opt.fullbody_height - sc.stream_width = self.opt.fullbody_width - sc.stream_height = self.opt.fullbody_height - sc.stream_fps = fps - sc.stream_bitrate = 1000000 - sc.stream_profile = 'baseline' #'high444' # 'main' - sc.audio_channel = 1 - sc.sample_rate = 16000 - sc.stream_server = self.opt.push_url - self.streamer = Streamer() - self.streamer.init(sc) - #self.streamer.enable_av_debug_log() + if self.opt.transport=='rtmp': + fps=25 + #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4' + sc = StreamerConfig() + sc.source_width = self.W + sc.source_height = self.H + sc.stream_width = self.W + sc.stream_height = self.H + if self.opt.fullbody: + sc.source_width = self.opt.fullbody_width + sc.source_height = self.opt.fullbody_height + sc.stream_width = self.opt.fullbody_width + sc.stream_height = self.opt.fullbody_height + sc.stream_fps = fps + sc.stream_bitrate = 1000000 + sc.stream_profile = 'baseline' #'high444' # 'main' + sc.audio_channel = 1 + sc.sample_rate = 16000 + sc.stream_server = self.opt.push_url + self.streamer = Streamer() + self.streamer.init(sc) + #self.streamer.enable_av_debug_log() - while True: #todo + while not quit_event.is_set(): #todo # update texture every frame # audio stream thread... t = time.time() @@ -203,14 +221,14 @@ class NeRFReal: # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) for _ in range(2): self.asr.run_step() - self.test_step() + self.test_step(loop,audio_track,video_track) totaltime += (time.time() - t) count += 1 if count==100: print(f"------actual avg fps:{count/totaltime:.4f}") count=0 totaltime=0 - # delay = 0.04 - (time.time() - t) #40ms - # if delay > 0: - # time.sleep(delay) + delay = 0.04 - (time.time() - t) #40ms + if delay > 0: + time.sleep(delay) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 12a5079..3fd1921 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,3 +31,4 @@ edge_tts flask flask_sockets opencv-python-headless +aiortc diff --git a/chat.html b/web/chat.html similarity index 100% rename from chat.html rename to web/chat.html diff --git a/web/client.js b/web/client.js new file mode 100644 index 0000000..c48b064 --- /dev/null +++ b/web/client.js @@ -0,0 +1,76 @@ +var pc = null; + +function negotiate() { + pc.addTransceiver('video', { direction: 'recvonly' }); + pc.addTransceiver('audio', { direction: 'recvonly' }); + return pc.createOffer().then((offer) => { + return pc.setLocalDescription(offer); + }).then(() => { + // wait for ICE gathering to complete + return new Promise((resolve) => { + if (pc.iceGatheringState === 'complete') { + resolve(); + } else { + const checkState = () => { + if (pc.iceGatheringState === 'complete') { + pc.removeEventListener('icegatheringstatechange', checkState); + resolve(); + } + }; + pc.addEventListener('icegatheringstatechange', checkState); + } + }); + }).then(() => { + var offer = pc.localDescription; + return fetch('/offer', { + body: JSON.stringify({ + sdp: offer.sdp, + type: offer.type, + }), + headers: { + 'Content-Type': 'application/json' + }, + method: 'POST' + }); + }).then((response) => { + return response.json(); + }).then((answer) => { + return pc.setRemoteDescription(answer); + }).catch((e) => { + alert(e); + }); +} + +function start() { + var config = { + sdpSemantics: 'unified-plan' + }; + + if (document.getElementById('use-stun').checked) { + config.iceServers = [{ urls: ['stun:stun.l.google.com:19302'] }]; + } + + pc = new RTCPeerConnection(config); + + // connect audio / video + pc.addEventListener('track', (evt) => { + if (evt.track.kind == 'video') { + document.getElementById('video').srcObject = evt.streams[0]; + } else { + document.getElementById('audio').srcObject = evt.streams[0]; + } + }); + + document.getElementById('start').style.display = 'none'; + negotiate(); + document.getElementById('stop').style.display = 'inline-block'; +} + +function stop() { + document.getElementById('stop').style.display = 'none'; + + // close peer connection + setTimeout(() => { + pc.close(); + }, 500); +} diff --git a/echo.html b/web/echo.html similarity index 100% rename from echo.html rename to web/echo.html diff --git a/mpegts-1.7.3.min.js b/web/mpegts-1.7.3.min.js similarity index 100% rename from mpegts-1.7.3.min.js rename to web/mpegts-1.7.3.min.js diff --git a/web/webrtc.html b/web/webrtc.html new file mode 100644 index 0000000..0103257 --- /dev/null +++ b/web/webrtc.html @@ -0,0 +1,83 @@ + + + + + + WebRTC webcam + + + + +
+ + +
+ + +
+
+

input text

+ + +
+ +
+ +
+

Media

+ + + +
+ + + + + + + diff --git a/webrtc.py b/webrtc.py new file mode 100644 index 0000000..e6b0a8a --- /dev/null +++ b/webrtc.py @@ -0,0 +1,158 @@ + +import asyncio +import json +import logging +import threading +import time +from typing import Tuple, Dict, Optional, Set, Union +from av.frame import Frame +from av.packet import Packet +import fractions + +AUDIO_PTIME = 0.020 # 20ms audio packetization +VIDEO_CLOCK_RATE = 90000 +VIDEO_PTIME = 1 / 25 # 30fps +VIDEO_TIME_BASE = fractions.Fraction(1, VIDEO_CLOCK_RATE) +SAMPLE_RATE = 16000 +AUDIO_TIME_BASE = fractions.Fraction(1, SAMPLE_RATE) + +#from aiortc.contrib.media import MediaPlayer, MediaRelay +#from aiortc.rtcrtpsender import RTCRtpSender +from aiortc import ( + MediaStreamTrack, +) + +logging.basicConfig() +logger = logging.getLogger(__name__) + + +class PlayerStreamTrack(MediaStreamTrack): + """ + A video track that returns an animated flag. + """ + + def __init__(self, player, kind): + super().__init__() # don't forget this! + self.kind = kind + self._player = player + self._queue = asyncio.Queue() + + _start: float + _timestamp: int + + async def next_timestamp(self) -> Tuple[int, fractions.Fraction]: + if self.readyState != "live": + raise Exception + + if self.kind == 'video': + if hasattr(self, "_timestamp"): + self._timestamp += int(VIDEO_PTIME * VIDEO_CLOCK_RATE) + wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time() + await asyncio.sleep(wait) + else: + self._start = time.time() + self._timestamp = 0 + return self._timestamp, VIDEO_TIME_BASE + else: #audio + if hasattr(self, "_timestamp"): + self._timestamp += int(AUDIO_PTIME * SAMPLE_RATE) + wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time() + await asyncio.sleep(wait) + else: + self._start = time.time() + self._timestamp = 0 + return self._timestamp, AUDIO_TIME_BASE + + async def recv(self) -> Union[Frame, Packet]: + # frame = self.frames[self.counter % 30] + self._player._start(self) + frame = await self._queue.get() + pts, time_base = await self.next_timestamp() + frame.pts = pts + frame.time_base = time_base + if frame is None: + self.stop() + raise Exception + return frame + + def stop(self): + super().stop() + if self._player is not None: + self._player._stop(self) + self._player = None + +def player_worker_thread( + quit_event, + loop, + container, + audio_track, + video_track +): + container.render(quit_event,loop,audio_track,video_track) + +class HumanPlayer: + + def __init__( + self, nerfreal, format=None, options=None, timeout=None, loop=False, decode=True + ): + self.__thread: Optional[threading.Thread] = None + self.__thread_quit: Optional[threading.Event] = None + + # examine streams + self.__started: Set[PlayerStreamTrack] = set() + self.__audio: Optional[PlayerStreamTrack] = None + self.__video: Optional[PlayerStreamTrack] = None + + self.__audio = PlayerStreamTrack(self, kind="audio") + self.__video = PlayerStreamTrack(self, kind="video") + + self.__container = nerfreal + + + @property + def audio(self) -> MediaStreamTrack: + """ + A :class:`aiortc.MediaStreamTrack` instance if the file contains audio. + """ + return self.__audio + + @property + def video(self) -> MediaStreamTrack: + """ + A :class:`aiortc.MediaStreamTrack` instance if the file contains video. + """ + return self.__video + + def _start(self, track: PlayerStreamTrack) -> None: + self.__started.add(track) + if self.__thread is None: + self.__log_debug("Starting worker thread") + self.__thread_quit = threading.Event() + self.__thread = threading.Thread( + name="media-player", + target=player_worker_thread, + args=( + self.__thread_quit, + asyncio.get_event_loop(), + self.__container, + self.__audio, + self.__video + ), + ) + self.__thread.start() + + def _stop(self, track: PlayerStreamTrack) -> None: + self.__started.discard(track) + + if not self.__started and self.__thread is not None: + self.__log_debug("Stopping worker thread") + self.__thread_quit.set() + self.__thread.join() + self.__thread = None + + if not self.__started and self.__container is not None: + #self.__container.close() + self.__container = None + + def __log_debug(self, msg: str, *args) -> None: + logger.debug(f"HumanPlayer {msg}", *args)