feat: add musereal static img

This commit is contained in:
Yun 2024-06-19 14:47:57 +08:00
parent 592312ab8c
commit 5da818b9d9
2 changed files with 113 additions and 106 deletions

app.py

@@ -285,6 +285,7 @@ if __name__ == '__main__':
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--customvideo', action='store_true', help="custom video")
    parser.add_argument('--static_img', action='store_true', help="use the first image as the idle frame during silence")
    parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
    parser.add_argument('--customvideo_imgnum', type=int, default=1)
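Usage note (hedged, not part of the diff): the new switch is off by default. Starting the server with, for example,

    python app.py --static_img --batch_size 16

makes the avatar hold the first full image whenever the incoming audio is silence, instead of cycling through the recorded frames; all other flags behave as before.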

musereal.py

@@ -2,7 +2,7 @@ import math
import torch
import numpy as np
# from .utils import *
import subprocess
import os
import time
@@ -18,17 +18,19 @@ from threading import Thread, Event
from io import BytesIO
import multiprocessing as mp
from musetalk.utils.utils import get_file_type, get_video_fps, datagen
# from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image, get_image_prepare_material, get_image_blending
from musetalk.utils.utils import load_all_model, load_diffusion_model, load_audio_model
from ttsreal import EdgeTTS, VoitsTTS, XTTS
from museasr import MuseASR
import asyncio
from av import AudioFrame, VideoFrame
from tqdm import tqdm

def read_imgs(img_list):
    frames = []
    print('reading images...')
@@ -37,8 +39,9 @@ def read_imgs(img_list):
        frames.append(frame)
    return frames

def __mirror_index(size, index):
    # size = len(self.coord_list_cycle)
    turn = index // size
    res = index % size
    if turn % 2 == 0:
@@ -46,8 +49,9 @@ def __mirror_index(size, index):
    else:
        return size - res - 1

def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue,
              ):  # vae, unet, pe, timesteps
    vae, unet, pe = load_diffusion_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
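read_imgs and __mirror_index are pulled out as module-level helpers so the spawned inference process can use them. The mirror index is what lets the avatar loop its cached frames back and forth without a visible jump; a self-contained illustration (not part of the diff):

    # Standalone sketch of the "ping-pong" indexing used by __mirror_index above.
    def mirror_index(size, index):
        turn = index // size            # which pass over the frame list we are on
        res = index % size              # position within the current pass
        return res if turn % 2 == 0 else size - res - 1

    # With 4 cached frames the playback order ping-pongs: 0 1 2 3 3 2 1 0 0 1 ...
    print([mirror_index(4, i) for i in range(10)])   # [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]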
@@ -59,34 +63,34 @@ def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_ou
    input_latent_list_cycle = torch.load(latents_out_path)
    length = len(input_latent_list_cycle)
    index = 0
    count = 0
    counttime = 0
    print('start inference')
    while True:
        if render_event.is_set():
            starttime = time.perf_counter()
            try:
                whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            is_all_silence = True
            audio_frames = []
            for _ in range(batch_size * 2):
                frame, type = audio_out_queue.get()
                audio_frames.append((frame, type))
                if type == 0:
                    is_all_silence = False
            if is_all_silence:
                for i in range(batch_size):
                    res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
            else:
                # print('infer=======')
                t = time.perf_counter()
                whisper_batch = np.stack(whisper_chunks)
                latent_batch = []
                for i in range(batch_size):
                    idx = __mirror_index(length, index + i)
                    latent = input_latent_list_cycle[idx]
                    latent_batch.append(latent)
                latent_batch = torch.cat(latent_batch, dim=0)
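Both branches above push items of the same shape onto res_frame_queue, which is the contract process_frames later relies on. A hedged sketch of that item layout (the example values are illustrative, derived from the 16 kHz / 20 ms audio comment elsewhere in this diff):

    import numpy as np

    # One res_frame_queue item: (res_frame, idx, audio_frames)
    #   res_frame    - decoded face crop (H x W x 3 array), or None when the whole batch is silence
    #   idx          - mirror index into the cached frame/coord/mask cycles
    #   audio_frames - the two 20 ms audio chunks paired with this video frame,
    #                  each a (pcm_chunk, type) tuple; type == 0 means speech, nonzero means silence
    silent_item = (None, 0, [(np.zeros(320, dtype=np.float32), 1),
                             (np.zeros(320, dtype=np.float32), 1)])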
@@ -94,85 +98,87 @@ def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_ou
                # for i, (whisper_batch,latent_batch) in enumerate(gen):
                audio_feature_batch = torch.from_numpy(whisper_batch)
                audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                             dtype=unet.model.dtype)
                audio_feature_batch = pe(audio_feature_batch)
                latent_batch = latent_batch.to(dtype=unet.model.dtype)
                # print('prepare time:',time.perf_counter()-t)
                # t=time.perf_counter()
                pred_latents = unet.model(latent_batch,
                                          timesteps,
                                          encoder_hidden_states=audio_feature_batch).sample
                # print('unet time:',time.perf_counter()-t)
                # t=time.perf_counter()
                recon = vae.decode_latents(pred_latents)
                # print('vae time:',time.perf_counter()-t)
                # print('diffusion len=',len(recon))
                counttime += (time.perf_counter() - t)
                count += batch_size
                # _totalframe += 1
                if count >= 100:
                    print(f"------actual avg infer fps:{count / counttime:.4f}")
                    count = 0
                    counttime = 0
                for i, res_frame in enumerate(recon):
                    # self.__pushmedia(res_frame,loop,audio_track,video_track)
                    res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
                # print('total batch time:',time.perf_counter()-starttime)
        else:
            time.sleep(1)
    print('musereal inference processor stop')
@torch.no_grad()
class MuseReal:
    def __init__(self, opt):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.fps = opt.fps  # 20 ms per frame

        #### musetalk
        self.avatar_id = opt.avatar_id
        self.static_img = opt.static_img
        self.video_path = ''  # video_path
        self.bbox_shift = opt.bbox_shift
        self.avatar_path = f"./data/avatars/{self.avatar_id}"
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"
        self.latents_out_path = f"{self.avatar_path}/latents.pt"
        self.video_out_path = f"{self.avatar_path}/vid_output/"
        self.mask_out_path = f"{self.avatar_path}/mask"
        self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
        self.avatar_info = {
            "avatar_id": self.avatar_id,
            "video_path": self.video_path,
            "bbox_shift": self.bbox_shift
        }
        self.batch_size = opt.batch_size
        self.idx = 0
        self.res_frame_queue = mp.Queue(self.batch_size * 2)
        self.__loadmodels()
        self.__loadavatar()

        self.asr = MuseASR(opt, self.audio_processor)
        if opt.tts == "edgetts":
            self.tts = EdgeTTS(opt, self)
        elif opt.tts == "gpt-sovits":
            self.tts = VoitsTTS(opt, self)
        elif opt.tts == "xtts":
            self.tts = XTTS(opt, self)
        # self.__warm_up()
        self.render_event = mp.Event()
        mp.Process(target=inference, args=(self.render_event, self.batch_size, self.latents_out_path,
                                           self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
                                           )).start()  # self.vae, self.unet, self.pe, self.timesteps

    def __loadmodels(self):
        # load model weights
        self.audio_processor = load_audio_model()
        # self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
        # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # self.timesteps = torch.tensor([0], device=device)
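The constructor above wires a producer/consumer split: the GPU-heavy diffusion loop runs in a separate process, fed through the ASR queues and gated by render_event, while the main process only shuttles frames. A stripped-down, runnable sketch of that same pattern, independent of the MuseTalk models (every name below is illustrative, not from the repo):

    import multiprocessing as mp
    import queue
    import time

    def worker(render_event, in_q, out_q):
        while True:
            if render_event.is_set():
                try:
                    item = in_q.get(block=True, timeout=1)
                except queue.Empty:
                    continue
                out_q.put(item * 2)       # stand-in for the UNet/VAE work
            else:
                time.sleep(1)             # idle until rendering is (re)enabled

    if __name__ == '__main__':
        render_event = mp.Event()
        in_q, out_q = mp.Queue(), mp.Queue()
        mp.Process(target=worker, args=(render_event, in_q, out_q), daemon=True).start()
        render_event.set()                # same role as self.render_event.set() in render()
        in_q.put(21)
        print(out_q.get())                # 42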
@@ -181,7 +187,7 @@ class MuseReal:
        # self.unet.model = self.unet.model.half()

    def __loadavatar(self):
        # self.input_latent_list_cycle = torch.load(self.latents_out_path)
        with open(self.coords_path, 'rb') as f:
            self.coord_list_cycle = pickle.load(f)
        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@@ -193,11 +199,10 @@ class MuseReal:
        input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        self.mask_list_cycle = read_imgs(input_mask_list)

    def put_msg_txt(self, msg):
        self.tts.put_msg_txt(msg)

    def put_audio_frame(self, audio_chunk):  # 16khz 20ms pcm
        self.asr.put_audio_frame(audio_chunk)

    def __mirror_index(self, index):
@@ -215,7 +220,7 @@ class MuseReal:
        whisper_batch = np.stack(whisper_chunks)
        latent_batch = []
        for i in range(self.batch_size):
            idx = self.__mirror_index(self.idx + i)
            latent = self.input_latent_list_cycle[idx]
            latent_batch.append(latent)
        latent_batch = torch.cat(latent_batch, dim=0)
@@ -223,87 +228,88 @@ class MuseReal:
        # for i, (whisper_batch,latent_batch) in enumerate(gen):
        audio_feature_batch = torch.from_numpy(whisper_batch)
        audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
                                                     dtype=self.unet.model.dtype)
        audio_feature_batch = self.pe(audio_feature_batch)
        latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
        pred_latents = self.unet.model(latent_batch,
                                       self.timesteps,
                                       encoder_hidden_states=audio_feature_batch).sample
        recon = self.vae.decode_latents(pred_latents)

    def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None):
        while not quit_event.is_set():
            try:
                res_frame, idx, audio_frames = self.res_frame_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            if audio_frames[0][1] == 1 and audio_frames[1][1] == 1:  # all-silence audio: just use the full image
                if self.static_img:
                    combine_frame = self.frame_list_cycle[0]
                else:
                    combine_frame = self.frame_list_cycle[idx]
            else:
                bbox = self.coord_list_cycle[idx]
                ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
                x1, y1, x2, y2 = bbox
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except:
                    continue
                mask = self.mask_list_cycle[idx]
                mask_crop_box = self.mask_coords_list_cycle[idx]
                # combine_frame = get_image(ori_frame,res_frame,bbox)
                # t=time.perf_counter()
                combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)
                # print('blending time:',time.perf_counter()-t)

            image = combine_frame  # (outputs['image'] * 255).astype(np.uint8)
            new_frame = VideoFrame.from_ndarray(image, format="bgr24")
            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)

            for audio_frame in audio_frames:
                frame, type = audio_frame
                frame = (frame * 32767).astype(np.int16)
                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                new_frame.planes[0].update(frame.tobytes())
                new_frame.sample_rate = 16000
                # if audio_track._queue.qsize()>10:
                #     time.sleep(0.1)
                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
        print('musereal process_frames thread stop')
    def render(self, quit_event, loop=None, audio_track=None, video_track=None):
        # if self.opt.asr:
        #     self.asr.warm_up()
        self.tts.render(quit_event)
        process_thread = Thread(target=self.process_frames, args=(quit_event, loop, audio_track, video_track))
        process_thread.start()

        self.render_event.set()  # start infer process render
        count = 0
        totaltime = 0
        _starttime = time.perf_counter()
        # _totalframe = 0
        while not quit_event.is_set():  # todo
            # update texture every frame
            # audio stream thread...
            t = time.perf_counter()
            self.asr.run_step()
            # self.test_step(loop,audio_track,video_track)
            # totaltime += (time.perf_counter() - t)
            # count += self.opt.batch_size
            # if count>=100:
            #     print(f"------actual avg infer fps:{count/totaltime:.4f}")
            #     count=0
            #     totaltime=0
            if video_track._queue.qsize() >= 2 * self.opt.batch_size:
                print('sleep qsize=', video_track._queue.qsize())
                time.sleep(0.04 * self.opt.batch_size * 1.5)
            # delay = _starttime+_totalframe*0.04-time.perf_counter()  # 40ms
            # if delay > 0:
            #     time.sleep(delay)
        self.render_event.clear()  # end infer process render
        print('musereal thread stop')