Merge branch 'lipku:main' into main

This commit is contained in:
yanyuxiyangzk 2024-04-14 19:16:12 +08:00 committed by GitHub
commit ec7f7b5041
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 450 additions and 54 deletions

View File

@ -49,23 +49,14 @@ export HF_ENDPOINT=https://hf-mirror.com
运行成功后用vlc访问rtmp://serverip/live/livestream
### 2.3 网页端数字人播报输入文字
安装并启动nginx
```
apt install nginx
nginx
```
将echo.html和mpegts-1.7.3.min.js拷到/var/www/html下
用浏览器打开http://serverip/echo.html, 在文本框输入任意文字,提交。数字人播报该段文字
用浏览器打开http://serverip:8010/echo.html, 在文本框输入任意文字,提交。数字人播报该段文字
## 3. More Usage
### 3.1 使用LLM模型进行数字人对话
目前借鉴数字人对话系统[LinlyTalker](https://github.com/Kedreamix/Linly-Talker)的方式LLM模型支持Chatgpt,Qwen和GeminiPro。需要在app.py中填入自己的api_key。
安装并启动nginx将chat.html和mpegts-1.7.3.min.js拷到/var/www/html下
用浏览器打开http://serverip/chat.html
用浏览器打开http://serverip:8010/chat.html
### 3.2 使用本地tts服务,支持声音克隆
运行xtts服务参照 https://github.com/coqui-ai/xtts-streaming-server
@ -106,12 +97,19 @@ python app.py --fullbody --fullbody_img data/fullbody/img --fullbody_offset_x 10
- --W、--H 训练视频的宽、高
- ernerf训练第三步torso如果训练的不好在拼接处会有接缝。可以在上面的命令加上--torso_imgs data/xxx/torso_imgstorso不用模型推理直接用训练数据集里的torso图片。这种方式可能头颈处会有些人工痕迹。
### 3.6 webrtc
```
python app.py --transport webrtc
```
用浏览器打开http://serverip:8010/webrtc.html
## 4. Docker Run
不需要第1步的安装直接运行。
```
docker run --gpus all -it --network=host --rm registry.cn-hangzhou.aliyuncs.com/lipku/nerfstream:v1.3
```
srs和nginx的运行同2.1和2.3
srs的运行同2.1
## 5. Data flow
![](/assets/dataflow.png)

76
app.py
View File

@ -1,5 +1,5 @@
# server.py
from flask import Flask, request, jsonify
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import time
@ -10,9 +10,13 @@ from geventwebsocket.handler import WebSocketHandler
import os
import re
import numpy as np
from threading import Thread
from threading import Thread,Event
import multiprocessing
from aiohttp import web
from aiortc import RTCPeerConnection, RTCSessionDescription
from webrtc import HumanPlayer
import argparse
from nerf_triplane.provider import NeRFDataset_Test
from nerf_triplane.utils import *
@ -155,9 +159,49 @@ def chat_socket(ws):
res=llm_response(message)
txt_to_audio(res)
def render():
nerfreal.render()
#####webrtc###############################
pcs = set()
#@app.route('/offer', methods=['POST'])
async def offer(request):
params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
pc = RTCPeerConnection()
pcs.add(pc)
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print("Connection state is %s" % pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
player = HumanPlayer(nerfreal)
audio_sender = pc.addTrack(player.audio)
video_sender = pc.addTrack(player.video)
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
#return jsonify({"sdp": pc.localDescription.sdp, "type": pc.localDescription.type})
return web.Response(
content_type="application/json",
text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
),
)
async def on_shutdown(app):
# close peer connections
coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros)
pcs.clear()
##########################################
if __name__ == '__main__':
@ -257,6 +301,7 @@ if __name__ == '__main__':
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
# parser.add_argument('--asr_model', type=str, default='facebook/hubert-large-ls960-ft')
parser.add_argument('--transport', type=str, default='rtmp') #rtmp webrtc
parser.add_argument('--push_url', type=str, default='rtmp://localhost/live/livestream')
parser.add_argument('--asr_save_feats', action='store_true')
@ -330,12 +375,29 @@ if __name__ == '__main__':
# we still need test_loader to provide audio features for testing.
nerfreal = NeRFReal(opt, trainer, test_loader)
#txt_to_audio('我是中国人,我来自北京')
rendthrd = Thread(target=render)
rendthrd.start()
if opt.transport=='rtmp':
thread_quit = Event()
rendthrd = Thread(target=nerfreal.render,args=(thread_quit,))
rendthrd.start()
#############################################################################
print('start websocket server')
appasync = web.Application()
appasync.on_shutdown.append(on_shutdown)
appasync.router.add_post("/offer", offer)
appasync.router.add_static('/',path='web')
def run_server(runner):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(runner.setup())
site = web.TCPSite(runner, '0.0.0.0', 8010)
loop.run_until_complete(site.start())
loop.run_forever()
Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
print('start websocket server')
#app.on_shutdown.append(on_shutdown)
#app.router.add_post("/offer", offer)
server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
server.serve_forever()

View File

@ -10,7 +10,9 @@ import torch.nn.functional as F
import cv2
from asrreal import ASR
import asyncio
from rtmp_streaming import StreamerConfig, Streamer
from av import AudioFrame, VideoFrame
class NeRFReal:
def __init__(self, opt, trainer, data_loader, debug=True):
@ -118,7 +120,7 @@ class NeRFReal:
else:
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
def test_step(self):
def test_step(self,loop=None,audio_track=None,video_track=None):
#starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
#starter.record()
@ -140,7 +142,11 @@ class NeRFReal:
#print(f'[INFO] outputs shape ',outputs['image'].shape)
image = (outputs['image'] * 255).astype(np.uint8)
if not self.opt.fullbody:
self.streamer.stream_frame(image)
if self.opt.transport=='rtmp':
self.streamer.stream_frame(image)
else:
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
else: #fullbody human
#print("frame index:",data['index'])
image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
@ -148,12 +154,23 @@ class NeRFReal:
start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标
start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标
image_fullbody[start_y:start_y+image.shape[0], start_x:start_x+image.shape[1]] = image
self.streamer.stream_frame(image_fullbody)
if self.opt.transport=='rtmp':
self.streamer.stream_frame(image_fullbody)
else:
new_frame = VideoFrame.from_ndarray(image, format="rgb24")
asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
#self.pipe.stdin.write(image.tostring())
for _ in range(2):
frame = self.asr.get_audio_out()
#print(f'[INFO] get_audio_out shape ',frame.shape)
self.streamer.stream_frame_audio(frame)
if self.opt.transport=='rtmp':
self.streamer.stream_frame_audio(frame)
else:
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=320)
new_frame.planes[0].update(frame.tobytes())
new_frame.sample_rate=16000
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
# frame = (frame * 32767).astype(np.int16).tobytes()
# self.fifo_audio.write(frame)
else:
@ -167,35 +184,36 @@ class NeRFReal:
#torch.cuda.synchronize()
#t = starter.elapsed_time(ender)
def render(self):
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
if self.opt.asr:
self.asr.warm_up()
count=0
totaltime=0
fps=25
#push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
sc = StreamerConfig()
sc.source_width = self.W
sc.source_height = self.H
sc.stream_width = self.W
sc.stream_height = self.H
if self.opt.fullbody:
sc.source_width = self.opt.fullbody_width
sc.source_height = self.opt.fullbody_height
sc.stream_width = self.opt.fullbody_width
sc.stream_height = self.opt.fullbody_height
sc.stream_fps = fps
sc.stream_bitrate = 1000000
sc.stream_profile = 'baseline' #'high444' # 'main'
sc.audio_channel = 1
sc.sample_rate = 16000
sc.stream_server = self.opt.push_url
self.streamer = Streamer()
self.streamer.init(sc)
#self.streamer.enable_av_debug_log()
if self.opt.transport=='rtmp':
fps=25
#push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
sc = StreamerConfig()
sc.source_width = self.W
sc.source_height = self.H
sc.stream_width = self.W
sc.stream_height = self.H
if self.opt.fullbody:
sc.source_width = self.opt.fullbody_width
sc.source_height = self.opt.fullbody_height
sc.stream_width = self.opt.fullbody_width
sc.stream_height = self.opt.fullbody_height
sc.stream_fps = fps
sc.stream_bitrate = 1000000
sc.stream_profile = 'baseline' #'high444' # 'main'
sc.audio_channel = 1
sc.sample_rate = 16000
sc.stream_server = self.opt.push_url
self.streamer = Streamer()
self.streamer.init(sc)
#self.streamer.enable_av_debug_log()
while True: #todo
while not quit_event.is_set(): #todo
# update texture every frame
# audio stream thread...
t = time.time()
@ -203,14 +221,14 @@ class NeRFReal:
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
for _ in range(2):
self.asr.run_step()
self.test_step()
self.test_step(loop,audio_track,video_track)
totaltime += (time.time() - t)
count += 1
if count==100:
print(f"------actual avg fps:{count/totaltime:.4f}")
count=0
totaltime=0
# delay = 0.04 - (time.time() - t) #40ms
# if delay > 0:
# time.sleep(delay)
delay = 0.04 - (time.time() - t) #40ms
if delay > 0:
time.sleep(delay)

View File

@ -31,3 +31,4 @@ edge_tts
flask
flask_sockets
opencv-python-headless
aiortc

76
web/client.js Normal file
View File

@ -0,0 +1,76 @@
var pc = null;
function negotiate() {
pc.addTransceiver('video', { direction: 'recvonly' });
pc.addTransceiver('audio', { direction: 'recvonly' });
return pc.createOffer().then((offer) => {
return pc.setLocalDescription(offer);
}).then(() => {
// wait for ICE gathering to complete
return new Promise((resolve) => {
if (pc.iceGatheringState === 'complete') {
resolve();
} else {
const checkState = () => {
if (pc.iceGatheringState === 'complete') {
pc.removeEventListener('icegatheringstatechange', checkState);
resolve();
}
};
pc.addEventListener('icegatheringstatechange', checkState);
}
});
}).then(() => {
var offer = pc.localDescription;
return fetch('/offer', {
body: JSON.stringify({
sdp: offer.sdp,
type: offer.type,
}),
headers: {
'Content-Type': 'application/json'
},
method: 'POST'
});
}).then((response) => {
return response.json();
}).then((answer) => {
return pc.setRemoteDescription(answer);
}).catch((e) => {
alert(e);
});
}
function start() {
var config = {
sdpSemantics: 'unified-plan'
};
if (document.getElementById('use-stun').checked) {
config.iceServers = [{ urls: ['stun:stun.l.google.com:19302'] }];
}
pc = new RTCPeerConnection(config);
// connect audio / video
pc.addEventListener('track', (evt) => {
if (evt.track.kind == 'video') {
document.getElementById('video').srcObject = evt.streams[0];
} else {
document.getElementById('audio').srcObject = evt.streams[0];
}
});
document.getElementById('start').style.display = 'none';
negotiate();
document.getElementById('stop').style.display = 'inline-block';
}
function stop() {
document.getElementById('stop').style.display = 'none';
// close peer connection
setTimeout(() => {
pc.close();
}, 500);
}

83
web/webrtc.html Normal file
View File

@ -0,0 +1,83 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WebRTC webcam</title>
<style>
button {
padding: 8px 16px;
}
video {
width: 100%;
}
.option {
margin-bottom: 8px;
}
#media {
max-width: 1280px;
}
</style>
</head>
<body>
<div class="option">
<input id="use-stun" type="checkbox"/>
<label for="use-stun">Use STUN server</label>
</div>
<button id="start" onclick="start()">Start</button>
<button id="stop" style="display: none" onclick="stop()">Stop</button>
<form class="form-inline" id="echo-form">
<div class="form-group">
<p>input text</p>
<textarea cols="2" rows="3" style="width:600px;height:50px;" class="form-control" id="message">test</textarea>
</div>
<button type="submit" class="btn btn-default">Send</button>
</form>
<div id="media">
<h2>Media</h2>
<audio id="audio" autoplay="true"></audio>
<video id="video" autoplay="true" playsinline="true"></video>
</div>
<script src="client.js"></script>
<script type="text/javascript" src="http://cdn.sockjs.org/sockjs-0.3.4.js"></script>
<script src="http://code.jquery.com/jquery-2.1.1.min.js"></script>
</body>
<script type="text/javascript" charset="utf-8">
$(document).ready(function() {
var host = window.location.hostname
var ws = new WebSocket("ws://"+host+":8000/humanecho");
//document.getElementsByTagName("video")[0].setAttribute("src", aa["video"]);
ws.onopen = function() {
console.log('Connected');
};
ws.onmessage = function(e) {
console.log('Received: ' + e.data);
data = e
var vid = JSON.parse(data.data);
console.log(typeof(vid),vid)
//document.getElementsByTagName("video")[0].setAttribute("src", vid["video"]);
};
ws.onclose = function(e) {
console.log('Closed');
};
$('#echo-form').on('submit', function(e) {
e.preventDefault();
var message = $('#message').val();
console.log('Sending: ' + message);
ws.send(message);
$('#message').val('');
});
});
</script>
</html>

158
webrtc.py Normal file
View File

@ -0,0 +1,158 @@
import asyncio
import json
import logging
import threading
import time
from typing import Tuple, Dict, Optional, Set, Union
from av.frame import Frame
from av.packet import Packet
import fractions
AUDIO_PTIME = 0.020 # 20ms audio packetization
VIDEO_CLOCK_RATE = 90000
VIDEO_PTIME = 1 / 25 # 30fps
VIDEO_TIME_BASE = fractions.Fraction(1, VIDEO_CLOCK_RATE)
SAMPLE_RATE = 16000
AUDIO_TIME_BASE = fractions.Fraction(1, SAMPLE_RATE)
#from aiortc.contrib.media import MediaPlayer, MediaRelay
#from aiortc.rtcrtpsender import RTCRtpSender
from aiortc import (
MediaStreamTrack,
)
logging.basicConfig()
logger = logging.getLogger(__name__)
class PlayerStreamTrack(MediaStreamTrack):
"""
A video track that returns an animated flag.
"""
def __init__(self, player, kind):
super().__init__() # don't forget this!
self.kind = kind
self._player = player
self._queue = asyncio.Queue()
_start: float
_timestamp: int
async def next_timestamp(self) -> Tuple[int, fractions.Fraction]:
if self.readyState != "live":
raise Exception
if self.kind == 'video':
if hasattr(self, "_timestamp"):
self._timestamp += int(VIDEO_PTIME * VIDEO_CLOCK_RATE)
wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
await asyncio.sleep(wait)
else:
self._start = time.time()
self._timestamp = 0
return self._timestamp, VIDEO_TIME_BASE
else: #audio
if hasattr(self, "_timestamp"):
self._timestamp += int(AUDIO_PTIME * SAMPLE_RATE)
wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time()
await asyncio.sleep(wait)
else:
self._start = time.time()
self._timestamp = 0
return self._timestamp, AUDIO_TIME_BASE
async def recv(self) -> Union[Frame, Packet]:
# frame = self.frames[self.counter % 30]
self._player._start(self)
frame = await self._queue.get()
pts, time_base = await self.next_timestamp()
frame.pts = pts
frame.time_base = time_base
if frame is None:
self.stop()
raise Exception
return frame
def stop(self):
super().stop()
if self._player is not None:
self._player._stop(self)
self._player = None
def player_worker_thread(
quit_event,
loop,
container,
audio_track,
video_track
):
container.render(quit_event,loop,audio_track,video_track)
class HumanPlayer:
def __init__(
self, nerfreal, format=None, options=None, timeout=None, loop=False, decode=True
):
self.__thread: Optional[threading.Thread] = None
self.__thread_quit: Optional[threading.Event] = None
# examine streams
self.__started: Set[PlayerStreamTrack] = set()
self.__audio: Optional[PlayerStreamTrack] = None
self.__video: Optional[PlayerStreamTrack] = None
self.__audio = PlayerStreamTrack(self, kind="audio")
self.__video = PlayerStreamTrack(self, kind="video")
self.__container = nerfreal
@property
def audio(self) -> MediaStreamTrack:
"""
A :class:`aiortc.MediaStreamTrack` instance if the file contains audio.
"""
return self.__audio
@property
def video(self) -> MediaStreamTrack:
"""
A :class:`aiortc.MediaStreamTrack` instance if the file contains video.
"""
return self.__video
def _start(self, track: PlayerStreamTrack) -> None:
self.__started.add(track)
if self.__thread is None:
self.__log_debug("Starting worker thread")
self.__thread_quit = threading.Event()
self.__thread = threading.Thread(
name="media-player",
target=player_worker_thread,
args=(
self.__thread_quit,
asyncio.get_event_loop(),
self.__container,
self.__audio,
self.__video
),
)
self.__thread.start()
def _stop(self, track: PlayerStreamTrack) -> None:
self.__started.discard(track)
if not self.__started and self.__thread is not None:
self.__log_debug("Stopping worker thread")
self.__thread_quit.set()
self.__thread.join()
self.__thread = None
if not self.__started and self.__container is not None:
#self.__container.close()
self.__container = None
def __log_debug(self, msg: str, *args) -> None:
logger.debug(f"HumanPlayer {msg}", *args)