From df6b9d3c972b11699c7ecca3bf4d461021ed7e56 Mon Sep 17 00:00:00 2001
From: lipku
Date: Sat, 23 Mar 2024 18:15:35 +0800
Subject: [PATCH] support hubert model

---
 README.md                | 21 ++++++++++++++-------
 app.py                   |  4 ++--
 asrreal.py               | 18 ++++++++++++++----
 nerf_triplane/network.py |  2 ++
 4 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index 9580706..8616e7b 100644
--- a/README.md
+++ b/README.md
@@ -53,14 +53,15 @@ nginx
 Open http://serverip/echo.html in a browser, type any text into the text box and submit it; the digital human will speak that text.

-### 2.4 Digital human dialogue with an LLM
+## 3. More usage
+### 3.1 Digital human dialogue with an LLM
 Currently follows the approach of the digital human dialogue system [LinlyTalker](https://github.com/Kedreamix/Linly-Talker); the supported LLMs are Chatgpt, Qwen and GeminiPro. Fill in your own api_key in app.py.
 Install and start nginx, then copy chat.html and mpegts-1.7.3.min.js to /var/www/html.
 Open http://serverip/chat.html in a browser.

-### 2.5 Use a local tts service with voice cloning support
+### 3.2 Use a local tts service with voice cloning support
 Run the xtts service; see https://github.com/coqui-ai/xtts-streaming-server
 ```
 docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 9000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest
 ```
@@ -69,18 +70,24 @@ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 9000:80 ghcr.io/coqui-ai/xtt
 ```
 python app.py --tts xtts --ref_file data/ref.wav
 ```
+
+### 3.3 Use hubert for audio features
+If the model was trained with audio features extracted by hubert, start the digital human with the following command:
+```
+python app.py --asr_model facebook/hubert-large-ls960-ft
+```

-## 3. Docker Run
+## 4. Docker Run
 No need for the installation in step 1; run directly.
 ```
 docker run --gpus all -it --network=host --rm registry.cn-hangzhou.aliyuncs.com/lipku/nerfstream:v1.3
 ```
 Run srs and nginx as in 2.1 and 2.3.

-## 4. Data flow
+## 5. Data flow
 ![](/assets/dataflow.png)

-## 5. Digital human model files
+## 6. Digital human model files
 Can be replaced with a model you trained yourself (https://github.com/Fictionarry/ER-NeRF)
 ```python
 .
@@ -92,7 +99,7 @@
 ```


-## 6. Performance analysis
+## 7. Performance analysis
 1. Frame rate
 Tested on a Tesla T4, the overall fps is about 18; without audio/video encoding and streaming it is about 20. A 4090 GPU can reach 40+ fps.
 Optimization: run audio/video encoding and streaming in a separate thread.
@@ -105,7 +112,7 @@
 docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/lipku/srs:v1.1
 ```

-## 7. TODO
+## 8. TODO
 - [x] Add chatgpt for digital human dialogue
 - [x] Voice cloning
 - [ ] Play a video clip while the digital human is silent
diff --git a/app.py b/app.py
index b9cb8e0..2e5c5b2 100644
--- a/app.py
+++ b/app.py
@@ -116,7 +116,7 @@ def echo_socket(ws):
     while True:
         message = ws.receive()
-        if len(message)==0:
+        if not message or len(message)==0:
            return '输入信息为空'
         else:
            txt_to_audio(message)
@@ -247,7 +247,7 @@ if __name__ == '__main__':
     parser.add_argument('--asr_play', action='store_true', help="play out the audio")

     #parser.add_argument('--asr_model', type=str, default='deepspeech')
-    parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
+    parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') #facebook/hubert-large-ls960-ft
     # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')

     parser.add_argument('--push_url', type=str, default='rtmp://localhost/live/livestream')
diff --git a/asrreal.py b/asrreal.py
index d9fb3e5..3020021 100644
--- a/asrreal.py
+++ b/asrreal.py
@@ -2,7 +2,7 @@ import time
 import numpy as np
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCTC, AutoProcessor
+from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
 #import pyaudio
 import soundfile as sf
@@ -52,6 +52,8 @@ class ASR:
             self.audio_dim = 44
         elif 'deepspeech' in self.opt.asr_model:
             self.audio_dim = 29
+        elif 'hubert' in self.opt.asr_model:
+            self.audio_dim = 1024
         else:
             self.audio_dim = 32
@@ -96,8 +98,12 @@
         # create wav2vec model
         print(f'[INFO] loading ASR model {self.opt.asr_model}...')
-        self.processor = AutoProcessor.from_pretrained(opt.asr_model)
-        self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
+        if 'hubert' in self.opt.asr_model:
+            self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+            self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
+        else:
+            self.processor = AutoProcessor.from_pretrained(opt.asr_model)
+            self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

         # prepare to save logits
         if self.opt.asr_save_feats:
@@ -339,7 +345,11 @@
         with torch.no_grad():
             result = self.model(inputs.input_values.to(self.device))
-            logits = result.logits # [1, N - 1, 32]
+            if 'hubert' in self.opt.asr_model:
+                logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024]
+            else:
+                logits = result.logits # [1, N - 1, 32]
+
        #print('logits.shape:',logits.shape)

        # cut off stride
        left = max(0, self.stride_left_size)
diff --git a/nerf_triplane/network.py b/nerf_triplane/network.py
index 7db0cf2..0eb3f24 100644
--- a/nerf_triplane/network.py
+++ b/nerf_triplane/network.py
@@ -104,6 +104,8 @@ class NeRFNetwork(NeRFRenderer):
             self.audio_in_dim = 44
         elif 'deepspeech' in self.opt.asr_model:
             self.audio_in_dim = 29
+        elif 'hubert' in self.opt.asr_model:
+            self.audio_in_dim = 1024
         else:
             self.audio_in_dim = 32
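
For reference, below is a minimal standalone sketch of the hubert feature-extraction path this patch wires into asrreal.py. The model id `facebook/hubert-large-ls960-ft` comes from the patch itself; the helper name `extract_hubert_features` and the silent one-second test input are illustrative only, not part of the repository.

```python
# Minimal sketch of the hubert audio-feature path added by this patch.
# Assumes torch and transformers are installed; the helper name and the
# silent test waveform are illustrative and not part of the repository.
import numpy as np
import torch
from transformers import HubertModel, Wav2Vec2Processor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_id = 'facebook/hubert-large-ls960-ft'

processor = Wav2Vec2Processor.from_pretrained(model_id)
model = HubertModel.from_pretrained(model_id).to(device).eval()

def extract_hubert_features(chunk: np.ndarray) -> torch.Tensor:
    """chunk: mono float32 audio at 16 kHz, as asrreal.py feeds the model."""
    inputs = processor(chunk, sampling_rate=16000, return_tensors='pt', padding=True)
    with torch.no_grad():
        result = model(inputs.input_values.to(device))
    # HubertModel has no CTC head, so the patch reads last_hidden_state
    # instead of result.logits: shape [1, T, 1024] with T roughly samples / 320.
    return result.last_hidden_state

if __name__ == '__main__':
    one_second = np.zeros(16000, dtype=np.float32)  # 1 s of silence
    print(extract_hubert_features(one_second).shape)  # torch.Size([1, 49, 1024])
```

Because the features are 1024-dimensional hidden states rather than 32-dimensional CTC logits, the patch also sets audio_dim in asrreal.py and audio_in_dim in nerf_triplane/network.py to 1024 whenever 'hubert' appears in --asr_model.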