improve musetalk infer speed

This commit is contained in:
lipku 2024-06-09 09:04:04 +08:00
parent 016442272e
commit d01860176e
5 changed files with 143 additions and 89 deletions

README.md

@@ -177,12 +177,12 @@ docker run --gpus all -it --network=host --rm registry.cn-hangzhou.aliyuncs.com
 ```
 The docker version no longer has the latest code; it can be used as a clean environment into which you copy the latest code to run.
-An autodl tutorial is also provided:
+An autodl image is also provided:
+https://www.codewithgpu.com/i/lipku/metahuman-stream/base
 [autodl tutorial](autodl/README.md)
-## 5. Data flow
-![](/assets/dataflow.png)
-## 6. Digital human model files
+## 5. Digital human model files
 You can swap in a model you trained yourself (https://github.com/Fictionarry/ER-NeRF)
 ```python
 .
@@ -194,7 +194,7 @@ The docker version no longer has the latest code; it can be used as a clean environment
 ```
-## 7. Performance analysis
+## 6. Performance analysis
 1. Frame rate
 Measured on a Tesla T4, overall fps is around 18; without audio/video encoding and streaming it is around 20. A 4090 can reach 40+ frames per second.
 Optimization: run audio/video encoding and streaming in a separate thread.
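Aside (not from this commit): a minimal sketch of the threading optimization mentioned above, with a queue decoupling the inference loop from encoding/streaming. The encode-and-push body is a placeholder, not the project's real streamer:

```python
import queue
import threading
import time

frame_queue = queue.Queue(maxsize=50)      # frames waiting to be encoded/pushed

def encode_and_push():
    # Placeholder for the real encode + rtmp/webrtc push work.
    while True:
        frame = frame_queue.get()
        if frame is None:                  # sentinel: stop the streamer
            break
        time.sleep(0.005)                  # stand-in for encode+send cost

streamer = threading.Thread(target=encode_and_push, daemon=True)
streamer.start()

for i in range(100):                       # stand-in for the inference loop
    frame_queue.put(f'frame-{i}')          # inference no longer blocks on encoding
frame_queue.put(None)
streamer.join()
```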
@@ -204,7 +204,7 @@ The docker version no longer has the latest code; it can be used as a clean environment
 2) wav2vec latency: 0.4s; 18 audio frames have to be buffered for the computation.
 3) srs forwarding latency: configure the srs server to reduce buffering latency. See https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency
-## 8. TODO
+## 7. TODO
 - [x] Add chatgpt for digital-human dialogue
 - [x] Voice cloning
 - [x] Play a stand-in video clip while the digital human is silent
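Aside (not from this commit): for item 3) above, the linked srs low-latency guide revolves around disabling GOP caching and shrinking the merged-write window. A sketch of the kind of vhost settings it describes; the directive names below are recalled from the srs documentation, not taken from this repo, so verify them against the link:

```
vhost __defaultVhost__ {
    tcp_nodelay on;
    min_latency on;
    play {
        gop_cache off;       # do not replay a cached GOP to new players
        queue_length 10;
        mw_latency 100;      # merged-write wait, in milliseconds
    }
    publish {
        mr off;              # disable merged-read on ingest
    }
}
```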
@@ -215,5 +215,4 @@ The docker version no longer has the latest code; it can be used as a clean environment
 知识星球: https://t.zsxq.com/7NMyO
 WeChat official account: 数字人技术
 ![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyfaiaLZGuMGQXnhLWxibpJUS2gfs8Dje6JuMY8zu2tVyU9n8Zx1yaNncvKHBMibX0ocehoITy5qQEZg/640?wxfrom=12&tp=wxpic&usePicPrefetch=1&wx_fmt=jpeg&from=appmsg)
-Buy me a coffee
 ![](https://mmbiz.qpic.cn/sz_mmbiz_jpg/l3ZibgueFiaeyEO2TDmroXibUSeFRCB3ftThHyTgVmVYyVVyvqDxronGvoU7xzkztnwQpnM5lBgx4MSaUUrnRZwCw/640?wx_fmt=jpeg&from=appmsg)

app.py

@@ -164,9 +164,10 @@ async def run(push_url):
     answer = await post(push_url,pc.localDescription.sdp)
     await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
 ##########################################
+# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
+# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
 if __name__ == '__main__':
+    multiprocessing.set_start_method('spawn')
     parser = argparse.ArgumentParser()
     parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
     parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
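Aside (not from this commit): CUDA contexts generally do not survive fork(), so a child process that runs GPU inference has to be started with 'spawn'; that is what the added set_start_method call guarantees before MuseReal spawns its inference process. A minimal runnable illustration, assuming PyTorch is installed:

```python
import multiprocessing as mp

def gpu_worker(q):
    import torch                      # CUDA initializes inside the spawned child
    q.put(torch.cuda.is_available())

if __name__ == '__main__':
    # With 'fork', the child would inherit the parent's already-initialized
    # CUDA context and typically crash; 'spawn' starts from a clean state.
    mp.set_start_method('spawn')
    q = mp.Queue()
    p = mp.Process(target=gpu_worker, args=(q,))
    p.start()
    print('CUDA available in child:', q.get())
    p.join()
```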

autodl/README.md

@@ -52,7 +52,6 @@ var url = "http://公网ip:1985/rtc/v1/whep/?app=live&stream=livestream"
 ## Notes
 1. Individual autodl users need the official ssh proxy tool to forward the port before they can access 6006
-2. If you want the musetalk environment on top of the base image, you have to set it up yourself
-3. Audio delay requires backend tuning of srs
-4. musetalk does not support rtmp streaming yet, but rtcpush works
-5. The musetalk tutorial is coming soon
+2. Audio delay requires backend tuning of srs
+3. musetalk does not support rtmp streaming yet, but rtcpush works
+4. The musetalk tutorial is coming soon

museasr.py

@@ -7,6 +7,7 @@ import resampy
 import queue
 from queue import Queue
 from io import BytesIO
+import multiprocessing as mp
 from musetalk.whisper.audio2feature import Audio2Feature
@@ -19,13 +20,14 @@ class MuseASR:
         self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
         self.queue = Queue()
         # self.input_stream = BytesIO()
-        self.output_queue = Queue()
+        self.output_queue = mp.Queue()
         self.audio_processor = audio_processor
         self.batch_size = opt.batch_size
         self.stride_left_size = self.stride_right_size = 6
         self.audio_feats = []
+        self.feat_queue = mp.Queue(5)
         self.warm_up()
@@ -34,7 +36,7 @@ class MuseASR:
     def __get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.02)
+            frame = self.queue.get(block=True,timeout=0.018)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -72,11 +74,11 @@ class MuseASR:
         whisper_feature = self.audio_processor.audio2feat(inputs)
         for feature in whisper_feature:
             self.audio_feats.append(feature)
         #print(f"processing audio costs {(time.time() - start_time) * 1000}ms, inputs shape:{inputs.shape} whisper_feature len:{len(whisper_feature)}")
-    def get_next_feat(self):
         whisper_chunks = self.audio_processor.feature2chunks(feature_array=self.audio_feats,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
         #print(f"whisper_chunks len:{len(whisper_chunks)},self.audio_feats len:{len(self.audio_feats)},self.output_queue len:{self.output_queue.qsize()}")
         self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
-        return whisper_chunks
+        self.feat_queue.put(whisper_chunks)
+
+    def get_next_feat(self,block,timeout):
+        return self.feat_queue.get(block,timeout)
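Aside (not from this commit): run_step now pushes whisper feature chunks into the bounded feat_queue, and the separate inference process pulls them back out with get_next_feat, so feature extraction and UNet inference overlap instead of running serially. The shape of that handoff, with hypothetical stand-ins for the real feature and inference code:

```python
import multiprocessing as mp
import queue

def compute_whisper_chunks():
    return [b'feature-chunk']         # hypothetical stand-in for feature2chunks(...)

def consumer(feat_queue):
    while True:
        try:
            chunks = feat_queue.get(block=True, timeout=1)
        except queue.Empty:           # mp.Queue raises queue.Empty on timeout
            break                     # the real loop would keep polling instead
        print('inferring on', chunks) # stand-in for the unet/vae batch

if __name__ == '__main__':
    mp.set_start_method('spawn')
    feat_queue = mp.Queue(5)          # bounded: backpressure if inference lags
    p = mp.Process(target=consumer, args=(feat_queue,))
    p.start()
    for _ in range(10):               # stand-in for repeated run_step calls
        feat_queue.put(compute_whisper_chunks())
    p.join()
```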

musereal.py

@@ -16,9 +16,10 @@ import queue
 from queue import Queue
 from threading import Thread, Event
 from io import BytesIO
+import multiprocessing as mp
 from musetalk.utils.utils import get_file_type,get_video_fps,datagen
-from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
+#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
 from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
 from musetalk.utils.utils import load_all_model
 from ttsreal import EdgeTTS,VoitsTTS,XTTS
@@ -27,6 +28,102 @@ from museasr import MuseASR
 import asyncio
 from av import AudioFrame, VideoFrame
+from tqdm import tqdm
+
+def read_imgs(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+def __mirror_index(size, index):
+    #size = len(self.coord_list_cycle)
+    turn = index // size
+    res = index % size
+    if turn % 2 == 0:
+        return res
+    else:
+        return size - res - 1
+
+def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,
+              vae, unet, pe,timesteps):
+    # _, vae, unet, pe = load_all_model()
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # timesteps = torch.tensor([0], device=device)
+    # pe = pe.half()
+    # vae.vae = vae.vae.half()
+    # unet.model = unet.model.half()
+    #input_latent_list_cycle = torch.load(latents_out_path)
+    length = len(input_latent_list_cycle)
+    index = 0
+    count=0
+    counttime=0
+    print('start inference')
+    while True:
+        if render_event.is_set():
+            starttime=time.perf_counter()
+            try:
+                whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
+            except queue.Empty:
+                continue
+            is_all_silence=True
+            audio_frames = []
+            for _ in range(batch_size*2):
+                frame,type = audio_out_queue.get()
+                audio_frames.append((frame,type))
+                if type==0:
+                    is_all_silence=False
+            if is_all_silence:
+                for i in range(batch_size):
+                    res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                    index = index + 1
+            else:
+                # print('infer=======')
+                t=time.perf_counter()
+                whisper_batch = np.stack(whisper_chunks)
+                latent_batch = []
+                for i in range(batch_size):
+                    idx = __mirror_index(length,index+i)
+                    latent = input_latent_list_cycle[idx]
+                    latent_batch.append(latent)
+                latent_batch = torch.cat(latent_batch, dim=0)
+
+                # for i, (whisper_batch,latent_batch) in enumerate(gen):
+                audio_feature_batch = torch.from_numpy(whisper_batch)
+                audio_feature_batch = audio_feature_batch.to(device=unet.device,
+                                                             dtype=unet.model.dtype)
+                audio_feature_batch = pe(audio_feature_batch)
+                latent_batch = latent_batch.to(dtype=unet.model.dtype)
+                # print('prepare time:',time.perf_counter()-t)
+                # t=time.perf_counter()
+                pred_latents = unet.model(latent_batch,
+                                          timesteps,
+                                          encoder_hidden_states=audio_feature_batch).sample
+                # print('unet time:',time.perf_counter()-t)
+                # t=time.perf_counter()
+                recon = vae.decode_latents(pred_latents)
+                # print('vae time:',time.perf_counter()-t)
+                #print('diffusion len=',len(recon))
+                counttime += (time.perf_counter() - t)
+                count += batch_size
+                #_totalframe += 1
+                if count>=100:
+                    print(f"------actual avg infer fps:{count/counttime:.4f}")
+                    count=0
+                    counttime=0
+                for i,res_frame in enumerate(recon):
+                    #self.__pushmedia(res_frame,loop,audio_track,video_track)
+                    res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                    index = index + 1
+                print('total batch time:',time.perf_counter()-starttime)
+        else:
+            time.sleep(1)
+    print('musereal inference processor stop')
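Aside (not from this commit): __mirror_index walks the avatar frame cycle back and forth, so playback ping-pongs through the cycle instead of jumping from the last frame back to the first. A quick check of the pattern:

```python
def mirror_index(size, index):        # same logic as __mirror_index above
    turn, res = divmod(index, size)
    return res if turn % 2 == 0 else size - res - 1

print([mirror_index(4, i) for i in range(12)])
# -> [0, 1, 2, 3, 3, 2, 1, 0, 0, 1, 2, 3]
```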
 @torch.no_grad()
 class MuseReal:
     def __init__(self, opt):
@@ -55,7 +152,7 @@ class MuseReal:
         }
         self.batch_size = opt.batch_size
         self.idx = 0
-        self.res_frame_queue = Queue()
+        self.res_frame_queue = mp.Queue(self.batch_size*2)
         self.__loadmodels()
         self.__loadavatar()
@@ -68,6 +165,11 @@ class MuseReal:
             self.tts = XTTS(opt,self)
         #self.__warm_up()
+        self.render_event = mp.Event()
+        mp.Process(target=inference, args=(self.render_event,self.batch_size,self.input_latent_list_cycle,
+                                           self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
+                                           self.vae, self.unet, self.pe,self.timesteps)).start()
+
     def __loadmodels(self):
         # load model weights
         self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
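Aside (not from this commit): because app.py now selects the 'spawn' start method, these Process arguments (models, latent list, queues) are pickled into a fresh child process, and render_event acts as a run/pause switch: render() sets it to start producing frames and clears it to park the child. The gating pattern in isolation, as a minimal runnable sketch:

```python
import multiprocessing as mp
import time

def worker(render_event):
    while True:
        if render_event.is_set():
            print('inferring one batch...')   # the real loop runs unet/vae here
            time.sleep(0.2)
        else:
            time.sleep(1)                     # parked: re-check the switch later

if __name__ == '__main__':
    mp.set_start_method('spawn')
    render_event = mp.Event()
    mp.Process(target=worker, args=(render_event,), daemon=True).start()
    render_event.set()      # start rendering
    time.sleep(1)
    render_event.clear()    # pause the worker without killing it
```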
@@ -129,59 +231,6 @@ class MuseReal:
                                       encoder_hidden_states=audio_feature_batch).sample
         recon = self.vae.decode_latents(pred_latents)
-
-    def test_step(self,loop=None,audio_track=None,video_track=None):
-        # gen = datagen(whisper_chunks,
-        #               self.input_latent_list_cycle,
-        #               self.batch_size)
-        starttime=time.perf_counter()
-        self.asr.run_step()
-        whisper_chunks = self.asr.get_next_feat()
-        is_all_silence=True
-        audio_frames = []
-        for _ in range(self.batch_size*2):
-            frame,type = self.asr.get_audio_out()
-            audio_frames.append((frame,type))
-            if type==0:
-                is_all_silence=False
-        if is_all_silence:
-            for i in range(self.batch_size):
-                self.res_frame_queue.put((None,self.__mirror_index(self.idx),audio_frames[i*2:i*2+2]))
-                self.idx = self.idx + 1
-        else:
-            # print('infer=======')
-            t=time.perf_counter()
-            whisper_batch = np.stack(whisper_chunks)
-            latent_batch = []
-            for i in range(self.batch_size):
-                idx = self.__mirror_index(self.idx+i)
-                latent = self.input_latent_list_cycle[idx]
-                latent_batch.append(latent)
-            latent_batch = torch.cat(latent_batch, dim=0)
-
-            # for i, (whisper_batch,latent_batch) in enumerate(gen):
-            audio_feature_batch = torch.from_numpy(whisper_batch)
-            audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
-                                                         dtype=self.unet.model.dtype)
-            audio_feature_batch = self.pe(audio_feature_batch)
-            latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
-            # print('prepare time:',time.perf_counter()-t)
-            # t=time.perf_counter()
-            pred_latents = self.unet.model(latent_batch,
-                                           self.timesteps,
-                                           encoder_hidden_states=audio_feature_batch).sample
-            # print('unet time:',time.perf_counter()-t)
-            # t=time.perf_counter()
-            recon = self.vae.decode_latents(pred_latents)
-            # print('vae time:',time.perf_counter()-t)
-            #print('diffusion len=',len(recon))
-            for i,res_frame in enumerate(recon):
-                #self.__pushmedia(res_frame,loop,audio_track,video_track)
-                self.res_frame_queue.put((res_frame,self.__mirror_index(self.idx),audio_frames[i*2:i*2+2]))
-                self.idx = self.idx + 1
-            print('total batch time:',time.perf_counter()-starttime)
-
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -203,7 +252,9 @@ class MuseReal:
             mask = self.mask_list_cycle[idx]
             mask_crop_box = self.mask_coords_list_cycle[idx]
             #combine_frame = get_image(ori_frame,res_frame,bbox)
+            #t=time.perf_counter()
             combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
+            #print('blending time:',time.perf_counter()-t)

             image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
             new_frame = VideoFrame.from_ndarray(image, format="bgr24")
@@ -228,6 +279,7 @@ class MuseReal:
         process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
         process_thread.start()

+        self.render_event.set() #start infer process render
         count=0
         totaltime=0
         _starttime=time.perf_counter()
@@ -236,20 +288,21 @@ class MuseReal:
             # update texture every frame
             # audio stream thread...
             t = time.perf_counter()
-            self.test_step(loop,audio_track,video_track)
-            totaltime += (time.perf_counter() - t)
-            count += self.opt.batch_size
-            #_totalframe += 1
-            if count>=100:
-                print(f"------actual avg infer fps:{count/totaltime:.4f}")
-                count=0
-                totaltime=0
+            self.asr.run_step()
+            #self.test_step(loop,audio_track,video_track)
+            # totaltime += (time.perf_counter() - t)
+            # count += self.opt.batch_size
+            # if count>=100:
+            #     print(f"------actual avg infer fps:{count/totaltime:.4f}")
+            #     count=0
+            #     totaltime=0
             if video_track._queue.qsize()>=2*self.opt.batch_size:
-                #print('sleep qsize=',video_track._queue.qsize())
+                print('sleep qsize=',video_track._queue.qsize())
                 time.sleep(0.04*self.opt.batch_size*1.5)
             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
             #     time.sleep(delay)
+        self.render_event.clear() #end infer process render
         print('musereal thread stop')
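Aside (not from this commit): the pacing numbers above come from the 25 fps output rate, where each frame covers 40 ms of playout. Once the video track queue holds two batches of ready frames, the render loop sleeps for 1.5 batches' worth of playout time so the consumer can drain. With a hypothetical batch size of 16:

```python
fps = 25
frame_time = 1 / fps            # 0.04 s of playout per frame
batch_size = 16                 # hypothetical; set by --batch_size
threshold = 2 * batch_size      # pause once 32 frames are buffered
sleep_s = frame_time * batch_size * 1.5
print(f'sleep {sleep_s:.2f}s when >= {threshold} frames queued')   # 0.96s
```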