Refactoring TTS code

lipku 2024-06-02 22:25:19 +08:00
parent 4e355e9ab9
commit 632409da1e
7 changed files with 311 additions and 281 deletions
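In short: all TTS handling moves out of app.py (and out of the ASR classes) into the new ttsreal.py, which defines a BaseTTS worker plus EdgeTTS, VoitsTTS (gpt-sovits) and XTTS implementations. Callers now queue text with put_msg_txt(); a background thread synthesizes it and feeds 16 kHz / 20 ms PCM chunks back to the renderer through put_audio_frame(). A toy sketch of that pattern (illustrative only, not code from this commit; ToyTTS and _synthesize are made-up names):

```
# Illustrative sketch of the queue-plus-worker pattern introduced by ttsreal.py.
import queue
from threading import Thread

class ToyTTS:
    def __init__(self, parent, chunk=320):        # 320 samples = 20 ms at 16 kHz
        self.parent, self.chunk = parent, chunk
        self.msgqueue = queue.Queue()

    def put_msg_txt(self, msg):
        self.msgqueue.put(msg)                    # producer side: just queue the text

    def render(self, quit_event):
        Thread(target=self._worker, args=(quit_event,), daemon=True).start()

    def _worker(self, quit_event):
        while not quit_event.is_set():
            try:
                msg = self.msgqueue.get(timeout=1)
            except queue.Empty:
                continue
            pcm = self._synthesize(msg)           # stand-in for edge-tts / gpt-sovits / xtts
            for i in range(0, len(pcm) - self.chunk + 1, self.chunk):
                self.parent.put_audio_frame(pcm[i:i + self.chunk])

    def _synthesize(self, msg):
        return [0.0] * (self.chunk * 10)          # dummy silence instead of a real TTS call
```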

README.md

@@ -1,5 +1,5 @@
A streaming digital human based on the Ernerf model, realizing synchronized audio-video dialogue. It can basically achieve commercial-grade results.
A real-time interactive streaming digital human, realizing synchronized audio-video dialogue. It can basically achieve commercial-grade results.
[ernerf demo](https://www.bilibili.com/video/BV1PM4m1y7Q2/) [musetalk demo](https://www.bilibili.com/video/BV1gm421N7vQ/)
@@ -23,17 +23,17 @@ conda create -n nerfstream python=3.10
conda activate nerfstream
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
#If you only use the musetalk model, there is no need to install the libraries below
pip install "git+https://github.com/facebookresearch/pytorch3d.git"
pip install tensorflow-gpu==2.8.0
pip install --upgrade "protobuf<=3.20.1"
pip install --upgrade "edge-tts<=6.1.11"
```
For common installation problems, see the [FAQ](/assets/faq.md)
For setting up a CUDA environment on Linux, you can refer to this article: https://zhuanlan.zhihu.com/p/674972886
## 2. Quick Start
By default, webrtc is used to push the stream to srs.
By default, the ernerf model is used, with webrtc pushing the stream to srs.
### 2.1 Run rtmpserver (srs)
```
export CANDIDATE='<server public IP>'
@@ -212,7 +212,8 @@ The docker version is no longer the latest code; it can serve as a blank environment, with the latest
- [ ] SyncTalk
If this project is helpful to you, please give it a star. Friends who are interested are also welcome to help improve the project.
Email: lipku@foxmail.com
知识星球 (Knowledge Planet): https://t.zsxq.com/7NMyO
WeChat official account: 数字人技术
![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&amp;from=appmsg)
Buy me a coffee
![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyEO2TDmroXibUSeFRCB3ftThHyTgVmVYyVVyvqDxronGvoU7xzkztnwQpnM5lBgx4MSaUUrnRZwCw/640?wx_fmt=jpeg&amp;from=appmsg)

app.py

@@ -23,132 +23,11 @@ import argparse
import shutil
import asyncio
import edge_tts
from typing import Iterator
import requests
app = Flask(__name__)
sockets = Sockets(app)
global nerfreal
global tts_type
global gspeaker
async def main(voicename: str, text: str, render):
communicate = edge_tts.Communicate(text, voicename)
#with open(OUTPUT_FILE, "wb") as file:
first = True
async for chunk in communicate.stream():
if first:
#render.before_push_audio()
first = False
if chunk["type"] == "audio":
render.push_audio(chunk["data"])
#file.write(chunk["data"])
elif chunk["type"] == "WordBoundary":
pass
def get_speaker(ref_audio,server_url):
files = {"wav_file": ("reference.wav", open(ref_audio, "rb"))}
response = requests.post(f"{server_url}/clone_speaker", files=files)
return response.json()
def xtts(text, speaker, language, server_url, stream_chunk_size) -> Iterator[bytes]:
start = time.perf_counter()
speaker["text"] = text
speaker["language"] = language
speaker["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
res = requests.post(
f"{server_url}/tts_stream",
json=speaker,
stream=True,
)
end = time.perf_counter()
print(f"xtts Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=960): #24K*20ms*2
if first:
end = time.perf_counter()
print(f"xtts Time to first chunk: {end-start}s")
first = False
if chunk:
yield chunk
print("xtts response.elapsed:", res.elapsed)
def gpt_sovits(text, character, language, server_url, emotion) -> Iterator[bytes]:
start = time.perf_counter()
req={}
req["text"] = text
req["text_language"] = language
req["character"] = character
req["emotion"] = emotion
#req["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
req["stream"] = True
res = requests.post(
f"{server_url}/tts",
json=req,
stream=True,
)
end = time.perf_counter()
print(f"gpt_sovits Time to make POST: {end-start}s")
if res.status_code != 200:
print("Error:", res.text)
return
first = True
for chunk in res.iter_content(chunk_size=32000): # 1280 32K*20ms*2
if first:
end = time.perf_counter()
print(f"gpt_sovits Time to first chunk: {end-start}s")
first = False
if chunk:
yield chunk
print("gpt_sovits response.elapsed:", res.elapsed)
def stream_tts(audio_stream,render):
for chunk in audio_stream:
if chunk is not None:
render.push_audio(chunk)
def txt_to_audio(text_):
if tts_type == "edgetts":
voicename = "zh-CN-YunxiaNeural"
text = text_
t = time.time()
asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
print(f'-------edge tts time:{time.time()-t:.4f}s')
elif tts_type == "gpt-sovits": #gpt_sovits
stream_tts(
gpt_sovits(
text_,
app.config['CHARACTER'], #"test", #character
"zh", #en args.language,
app.config['TTS_SERVER'], #"http://127.0.0.1:5000", #args.server_url,
app.config['EMOTION'], #emotion
),
nerfreal
)
else: #xtts
stream_tts(
xtts(
text_,
gspeaker,
"zh-cn", #en args.language,
app.config['TTS_SERVER'], #"http://localhost:9000", #args.server_url,
"20" #args.stream_chunk_size
),
nerfreal
)
@sockets.route('/humanecho')
@@ -168,7 +47,7 @@ def echo_socket(ws):
if not message or len(message)==0:
return '输入信息为空'
else:
txt_to_audio(message)
nerfreal.put_msg_txt(message)
def llm_response(message):
@@ -198,42 +77,11 @@ def chat_socket(ws):
return '输入信息为空'
else:
res=llm_response(message)
txt_to_audio(res)
nerfreal.put_msg_txt(res)
#####webrtc###############################
pcs = set()
async def txt_to_audio_async(text_):
if tts_type == "edgetts":
voicename = "zh-CN-YunxiaNeural"
text = text_
t = time.time()
#asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
await main(voicename,text,nerfreal)
print(f'-------edge tts time:{time.time()-t:.4f}s')
elif tts_type == "gpt-sovits": #gpt_sovits
stream_tts(
gpt_sovits(
text_,
app.config['CHARACTER'], #"test", #character
"zh", #en args.language,
app.config['TTS_SERVER'], #"http://127.0.0.1:5000", #args.server_url,
app.config['EMOTION'], #emotion
),
nerfreal
)
else: #xtts
stream_tts(
xtts(
text_,
gspeaker,
"zh-cn", #en args.language,
app.config['TTS_SERVER'], #"http://localhost:9000", #args.server_url,
"20" #args.stream_chunk_size
),
nerfreal
)
#@app.route('/offer', methods=['POST'])
async def offer(request):
params = await request.json()
@@ -271,10 +119,10 @@ async def human(request):
params = await request.json()
if params['type']=='echo':
await txt_to_audio_async(params['text'])
nerfreal.put_msg_txt(params['text'])
elif params['type']=='chat':
res=llm_response(params['text'])
await txt_to_audio_async(res)
res=await asyncio.get_event_loop().run_in_executor(None, llm_response(params['text']))
nerfreal.put_msg_txt(res)
return web.Response(
content_type="application/json",
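With this refactor the HTTP handler only queues text on the renderer; synthesis happens on the TTS worker thread. A hypothetical client call (assuming the human() handler above is registered at /human and the default --listenport 8010 below is used) might look like:

```
import requests

# 'echo' speaks the text directly; 'chat' first goes through llm_response()
resp = requests.post(
    "http://127.0.0.1:8010/human",
    json={"type": "echo", "text": "你好,请介绍一下自己"},
)
print(resp.status_code, resp.text)
```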
@@ -453,14 +301,9 @@ if __name__ == '__main__':
parser.add_argument('--listenport', type=int, default=8010)
opt = parser.parse_args()
app.config.from_object(opt)
#app.config.from_object(opt)
#print(app.config)
tts_type = opt.tts
if tts_type == "xtts":
print("Computing the latents for a new reference...")
gspeaker = get_speaker(opt.REF_FILE, opt.TTS_SERVER)
if opt.model == 'ernerf':
from ernerf.nerf_triplane.provider import NeRFDataset_Test
from ernerf.nerf_triplane.utils import *

asrreal.py

@@ -56,7 +56,6 @@ class ASR:
# create input stream
self.queue = Queue()
self.input_stream = BytesIO()
self.output_queue = Queue()
# start a background process to read frames
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
@@ -204,6 +203,9 @@ class ASR:
# np.save(output_path, unfold_feats.cpu().numpy())
# print(f"[INFO] saved logits to {output_path}")
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
if self.inwarm: # warm up
return np.zeros(self.chunk, dtype=np.float32),1
@@ -260,56 +262,6 @@ class ASR:
return logits[0], None,None #predicted_ids[0], transcription # [N,]
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
def push_audio(self,buffer): #push audio pcm from tts
print(f'[INFO] push_audio {len(buffer)}')
if self.opt.tts == "xtts" or self.opt.tts == "gpt-sovits":
if len(buffer)>0:
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
if self.opt.tts == "xtts":
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
else:
stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.__create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
# if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
else: #edge tts
self.input_stream.write(buffer)
if len(buffer)<=0:
self.input_stream.seek(0)
stream = self.__create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
#if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self.input_stream.seek(0)
self.input_stream.truncate()
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()

museasr.py

@@ -18,7 +18,7 @@ class MuseASR:
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
self.input_stream = BytesIO()
# self.input_stream = BytesIO()
self.output_queue = Queue()
self.audio_processor = audio_processor
@@ -29,62 +29,14 @@
self.warm_up()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
def push_audio(self,buffer):
print(f'[INFO] push_audio {len(buffer)}')
if self.opt.tts == "xtts" or self.opt.tts == "gpt-sovits":
if len(buffer)>0:
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
if self.opt.tts == "xtts":
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
else:
stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.__create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
# if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
else: #edge tts
self.input_stream.write(buffer)
if len(buffer)<=0:
self.input_stream.seek(0)
stream = self.__create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
#if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self.input_stream.seek(0)
self.input_stream.truncate()
def __get_audio_frame(self):
try:
frame = self.queue.get(block=False)
frame = self.queue.get(block=True,timeout=0.02)
type = 0
print(f'[INFO] get frame {frame.shape}')
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1

musereal.py

@@ -21,6 +21,7 @@ from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
from musetalk.utils.utils import load_all_model
from ttsreal import EdgeTTS,VoitsTTS,XTTS
from museasr import MuseASR
import asyncio
@@ -59,6 +60,13 @@ class MuseReal:
self.__loadavatar()
self.asr = MuseASR(opt,self.audio_processor)
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
self.tts = VoitsTTS(opt,self)
elif opt.tts == "xtts":
self.tts = XTTS(opt,self)
#self.__warm_up()
def __loadmodels(self):
# load model weights
@@ -83,8 +91,11 @@ class MuseReal:
self.mask_list_cycle = read_imgs(input_mask_list)
def push_audio(self,buffer):
self.asr.push_audio(buffer)
def put_msg_txt(self,msg):
self.tts.put_msg_txt(msg)
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def __mirror_index(self, index):
size = len(self.coord_list_cycle)
@@ -95,11 +106,35 @@ class MuseReal:
else:
return size - res - 1
def __warm_up(self):
self.asr.run_step()
whisper_chunks = self.asr.get_next_feat()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(self.batch_size):
idx = self.__mirror_index(self.idx+i)
latent = self.input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
print('infer=======')
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
dtype=self.unet.model.dtype)
audio_feature_batch = self.pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
pred_latents = self.unet.model(latent_batch,
self.timesteps,
encoder_hidden_states=audio_feature_batch).sample
recon = self.vae.decode_latents(pred_latents)
def test_step(self,loop=None,audio_track=None,video_track=None):
# gen = datagen(whisper_chunks,
# self.input_latent_list_cycle,
# self.batch_size)
starttime=time.perf_counter()
self.asr.run_step()
whisper_chunks = self.asr.get_next_feat()
is_all_silence=True
@@ -114,7 +149,8 @@ class MuseReal:
self.res_frame_queue.put((None,self.__mirror_index(self.idx),audio_frames[i*2:i*2+2]))
self.idx = self.idx + 1
else:
print('infer=======')
# print('infer=======')
t=time.perf_counter()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(self.batch_size):
@@ -129,16 +165,22 @@ class MuseReal:
dtype=self.unet.model.dtype)
audio_feature_batch = self.pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
# print('prepare time:',time.perf_counter()-t)
# t=time.perf_counter()
pred_latents = self.unet.model(latent_batch,
self.timesteps,
encoder_hidden_states=audio_feature_batch).sample
# print('unet time:',time.perf_counter()-t)
# t=time.perf_counter()
recon = self.vae.decode_latents(pred_latents)
# print('vae time:',time.perf_counter()-t)
#print('diffusion len=',len(recon))
for i,res_frame in enumerate(recon):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
self.res_frame_queue.put((res_frame,self.__mirror_index(self.idx),audio_frames[i*2:i*2+2]))
self.idx = self.idx + 1
print('total batch time:',time.perf_counter()-starttime)
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -176,11 +218,13 @@ class MuseReal:
# if audio_track._queue.qsize()>10:
# time.sleep(0.1)
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
print('musereal process_frames thread stop')
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
# self.asr.warm_up()
self.tts.render(quit_event)
process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
process_thread.start()
@@ -207,4 +251,5 @@ class MuseReal:
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
# time.sleep(delay)
print('musereal thread stop')

nerfreal.py

@@ -10,6 +10,8 @@ import torch.nn.functional as F
import cv2
from asrreal import ASR
from ttsreal import EdgeTTS,VoitsTTS,XTTS
import asyncio
from av import AudioFrame, VideoFrame
@@ -63,6 +65,12 @@ class NeRFReal:
if self.opt.asr:
self.asr = ASR(opt)
self.asr.warm_up()
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
self.tts = VoitsTTS(opt,self)
elif opt.tts == "xtts":
self.tts = XTTS(opt,self)
'''
video_path = 'video_stream'
@@ -110,8 +118,11 @@ class NeRFReal:
if self.opt.asr:
self.asr.stop()
def push_audio(self,chunk):
self.asr.push_audio(chunk)
def put_msg_txt(self,msg):
self.tts.put_msg_txt(msg)
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def mirror_index(self, index):
@@ -231,6 +242,8 @@ class NeRFReal:
totaltime=0
_starttime=time.perf_counter()
_totalframe=0
self.tts.render(quit_event)
while not quit_event.is_set(): #todo
# update texture every frame
# audio stream thread...
@@ -255,5 +268,6 @@ class NeRFReal:
if video_track._queue.qsize()>=5:
#print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.1)
print('nerfreal thread stop')

ttsreal.py (new file)

@@ -0,0 +1,223 @@
import time
import numpy as np
import soundfile as sf
import resampy
import asyncio
import edge_tts
from typing import Iterator
import requests
import queue
from queue import Queue
from io import BytesIO
from threading import Thread, Event
class BaseTTS:
    def __init__(self, opt, parent):
        self.opt=opt
        self.parent = parent

        self.fps = opt.fps # 20 ms per frame
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
        self.input_stream = BytesIO()

        self.msgqueue = Queue()

    def put_msg_txt(self,msg):
        self.msgqueue.put(msg)

    def render(self,quit_event):
        process_thread = Thread(target=self.process_tts, args=(quit_event,))
        process_thread.start()

    def process_tts(self,quit_event):
        while not quit_event.is_set():
            try:
                msg = self.msgqueue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            self.txt_to_audio(msg)
        print('ttsreal thread stop')

    def txt_to_audio(self,msg):
        pass
###########################################################################################
class EdgeTTS(BaseTTS):
    def txt_to_audio(self,msg):
        voicename = "zh-CN-YunxiaNeural"
        text = msg
        t = time.time()
        asyncio.new_event_loop().run_until_complete(self.__main(voicename,text))
        print(f'-------edge tts time:{time.time()-t:.4f}s')

        self.input_stream.seek(0)
        stream = self.__create_bytes_stream(self.input_stream)
        streamlen = stream.shape[0]
        idx=0
        while streamlen >= self.chunk:
            self.parent.put_audio_frame(stream[idx:idx+self.chunk])
            streamlen -= self.chunk
            idx += self.chunk
        #if streamlen>0:  #skip last frame(not 20ms)
        #    self.queue.put(stream[idx:])
        self.input_stream.seek(0)
        self.input_stream.truncate()

    def __create_bytes_stream(self,byte_stream):
        #byte_stream=BytesIO(buffer)
        stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self.sample_rate and stream.shape[0]>0:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)

        return stream

    async def __main(self,voicename: str, text: str):
        communicate = edge_tts.Communicate(text, voicename)

        #with open(OUTPUT_FILE, "wb") as file:
        first = True
        async for chunk in communicate.stream():
            if first:
                first = False
            if chunk["type"] == "audio":
                #self.push_audio(chunk["data"])
                self.input_stream.write(chunk["data"])
                #file.write(chunk["data"])
            elif chunk["type"] == "WordBoundary":
                pass
###########################################################################################
class VoitsTTS(BaseTTS):
    def txt_to_audio(self,msg):
        self.stream_tts(
            self.gpt_sovits(
                msg,
                self.opt.CHARACTER, #"test", #character
                "zh", #en args.language,
                self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
                self.opt.EMOTION, #emotion
            )
        )
    def gpt_sovits(self, text, character, language, server_url, emotion) -> Iterator[bytes]:
        start = time.perf_counter()
        req={}
        req["text"] = text
        req["text_language"] = language
        req["character"] = character
        req["emotion"] = emotion
        #req["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
        req["stream"] = True
        res = requests.post(
            f"{server_url}/tts",
            json=req,
            stream=True,
        )
        end = time.perf_counter()
        print(f"gpt_sovits Time to make POST: {end-start}s")

        if res.status_code != 200:
            print("Error:", res.text)
            return

        first = True
        for chunk in res.iter_content(chunk_size=32000): # 1280 32K*20ms*2
            if first:
                end = time.perf_counter()
                print(f"gpt_sovits Time to first chunk: {end-start}s")
                first = False
            if chunk:
                yield chunk

        print("gpt_sovits response.elapsed:", res.elapsed)
    def stream_tts(self,audio_stream):
        for chunk in audio_stream:
            if chunk is not None and len(chunk)>0:
                stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
                stream = resampy.resample(x=stream, sr_orig=32000, sr_new=self.sample_rate)
                #byte_stream=BytesIO(buffer)
                #stream = self.__create_bytes_stream(byte_stream)
                streamlen = stream.shape[0]
                idx=0
                while streamlen >= self.chunk:
                    self.parent.put_audio_frame(stream[idx:idx+self.chunk])
                    streamlen -= self.chunk
                    idx += self.chunk
###########################################################################################
class XTTS(BaseTTS):
    def __init__(self, opt, parent):
        super().__init__(opt,parent)
        self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER)

    def txt_to_audio(self,msg):
        self.stream_tts(
            self.xtts(
                msg,
                self.speaker,
                "zh-cn", #en args.language,
                self.opt.TTS_SERVER, #"http://localhost:9000", #args.server_url,
                "20" #args.stream_chunk_size
            )
        )

    def get_speaker(self,ref_audio,server_url):
        files = {"wav_file": ("reference.wav", open(ref_audio, "rb"))}
        response = requests.post(f"{server_url}/clone_speaker", files=files)
        return response.json()

    def xtts(self,text, speaker, language, server_url, stream_chunk_size) -> Iterator[bytes]:
        start = time.perf_counter()
        speaker["text"] = text
        speaker["language"] = language
        speaker["stream_chunk_size"] = stream_chunk_size # you can reduce it to get faster response, but degrade quality
        res = requests.post(
            f"{server_url}/tts_stream",
            json=speaker,
            stream=True,
        )
        end = time.perf_counter()
        print(f"xtts Time to make POST: {end-start}s")

        if res.status_code != 200:
            print("Error:", res.text)
            return

        first = True
        for chunk in res.iter_content(chunk_size=960): #24K*20ms*2
            if first:
                end = time.perf_counter()
                print(f"xtts Time to first chunk: {end-start}s")
                first = False
            if chunk:
                yield chunk

        print("xtts response.elapsed:", res.elapsed)

    def stream_tts(self,audio_stream):
        for chunk in audio_stream:
            if chunk is not None and len(chunk)>0:
                stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
                stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
                #byte_stream=BytesIO(buffer)
                #stream = self.__create_bytes_stream(byte_stream)
                streamlen = stream.shape[0]
                idx=0
                while streamlen >= self.chunk:
                    self.parent.put_audio_frame(stream[idx:idx+self.chunk])
                    streamlen -= self.chunk
                    idx += self.chunk
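For reference, a minimal usage sketch of the new ttsreal.py interface (hypothetical wiring, not part of this commit): the parent only needs to expose put_audio_frame() accepting 20 ms, 16 kHz float32 chunks, and opt only needs the fields the chosen class reads (fps here; REF_FILE/TTS_SERVER for XTTS, CHARACTER/EMOTION/TTS_SERVER for VoitsTTS). Running it requires the edge-tts package and network access.

```
from threading import Event
from types import SimpleNamespace
from ttsreal import EdgeTTS

class DummyParent:
    # receives float32 numpy arrays of 320 samples (20 ms at 16 kHz)
    def put_audio_frame(self, chunk):
        print("got audio frame:", chunk.shape)

opt = SimpleNamespace(fps=50)      # 16000 // 50 = 320 samples per chunk
tts = EdgeTTS(opt, DummyParent())
quit_event = Event()
tts.render(quit_event)             # start the TTS worker thread
tts.put_msg_txt("你好,数字人")     # queued text comes back as audio frames
```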