feat: add musereal static img

This commit is contained in:
Yun 2024-06-19 14:47:57 +08:00
parent 592312ab8c
commit 5da818b9d9
2 changed files with 113 additions and 106 deletions

1
app.py
View File

@ -285,6 +285,7 @@ if __name__ == '__main__':
parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--customvideo', action='store_true', help="custom video") parser.add_argument('--customvideo', action='store_true', help="custom video")
parser.add_argument('--static_img', action='store_true', help="Use the first photo as a time of rest")
parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img') parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
parser.add_argument('--customvideo_imgnum', type=int, default=1) parser.add_argument('--customvideo_imgnum', type=int, default=1)

View File

@ -29,6 +29,8 @@ import asyncio
from av import AudioFrame, VideoFrame from av import AudioFrame, VideoFrame
from tqdm import tqdm from tqdm import tqdm
def read_imgs(img_list): def read_imgs(img_list):
frames = [] frames = []
print('reading images...') print('reading images...')
@ -37,6 +39,7 @@ def read_imgs(img_list):
frames.append(frame) frames.append(frame)
return frames return frames
def __mirror_index(size, index): def __mirror_index(size, index):
# size = len(self.coord_list_cycle) # size = len(self.coord_list_cycle)
turn = index // size turn = index // size
@ -46,6 +49,7 @@ def __mirror_index(size, index):
else: else:
return size - res - 1 return size - res - 1
def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue, def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue,
): # vae, unet, pe,timesteps ): # vae, unet, pe,timesteps
@ -124,6 +128,7 @@ def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_ou
time.sleep(1) time.sleep(1)
print('musereal inference processor stop') print('musereal inference processor stop')
@torch.no_grad() @torch.no_grad()
class MuseReal: class MuseReal:
def __init__(self, opt): def __init__(self, opt):
@ -135,6 +140,7 @@ class MuseReal:
#### musetalk #### musetalk
self.avatar_id = opt.avatar_id self.avatar_id = opt.avatar_id
self.static_img = opt.static_img
self.video_path = '' # video_path self.video_path = '' # video_path
self.bbox_shift = opt.bbox_shift self.bbox_shift = opt.bbox_shift
self.avatar_path = f"./data/avatars/{self.avatar_id}" self.avatar_path = f"./data/avatars/{self.avatar_id}"
@ -193,7 +199,6 @@ class MuseReal:
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.mask_list_cycle = read_imgs(input_mask_list) self.mask_list_cycle = read_imgs(input_mask_list)
def put_msg_txt(self, msg): def put_msg_txt(self, msg):
self.tts.put_msg_txt(msg) self.tts.put_msg_txt(msg)
@ -232,7 +237,6 @@ class MuseReal:
encoder_hidden_states=audio_feature_batch).sample encoder_hidden_states=audio_feature_batch).sample
recon = self.vae.decode_latents(pred_latents) recon = self.vae.decode_latents(pred_latents)
def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None): def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None):
while not quit_event.is_set(): while not quit_event.is_set():
@ -241,6 +245,9 @@ class MuseReal:
except queue.Empty: except queue.Empty:
continue continue
if audio_frames[0][1] == 1 and audio_frames[1][1] == 1: # 全为静音数据只需要取fullimg if audio_frames[0][1] == 1 and audio_frames[1][1] == 1: # 全为静音数据只需要取fullimg
if self.static_img:
combine_frame = self.frame_list_cycle[0]
else:
combine_frame = self.frame_list_cycle[idx] combine_frame = self.frame_list_cycle[idx]
else: else:
bbox = self.coord_list_cycle[idx] bbox = self.coord_list_cycle[idx]
@ -306,4 +313,3 @@ class MuseReal:
# time.sleep(delay) # time.sleep(delay)
self.render_event.clear() # end infer process render self.render_event.clear() # end infer process render
print('musereal thread stop') print('musereal thread stop')