feat: add musereal static img
This commit is contained in:
parent
592312ab8c
commit
5da818b9d9
1
app.py
1
app.py
|
@ -285,6 +285,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
|
||||
parser.add_argument('--customvideo', action='store_true', help="custom video")
|
||||
parser.add_argument('--static_img', action='store_true', help="Use the first photo as a time of rest")
|
||||
parser.add_argument('--customvideo_img', type=str, default='data/customvideo/img')
|
||||
parser.add_argument('--customvideo_imgnum', type=int, default=1)
|
||||
|
||||
|
|
12
musereal.py
12
musereal.py
|
@ -29,6 +29,8 @@ import asyncio
|
|||
from av import AudioFrame, VideoFrame
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def read_imgs(img_list):
|
||||
frames = []
|
||||
print('reading images...')
|
||||
|
@ -37,6 +39,7 @@ def read_imgs(img_list):
|
|||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
|
||||
def __mirror_index(size, index):
|
||||
# size = len(self.coord_list_cycle)
|
||||
turn = index // size
|
||||
|
@ -46,6 +49,7 @@ def __mirror_index(size, index):
|
|||
else:
|
||||
return size - res - 1
|
||||
|
||||
|
||||
def inference(render_event, batch_size, latents_out_path, audio_feat_queue, audio_out_queue, res_frame_queue,
|
||||
): # vae, unet, pe,timesteps
|
||||
|
||||
|
@ -124,6 +128,7 @@ def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_ou
|
|||
time.sleep(1)
|
||||
print('musereal inference processor stop')
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
class MuseReal:
|
||||
def __init__(self, opt):
|
||||
|
@ -135,6 +140,7 @@ class MuseReal:
|
|||
|
||||
#### musetalk
|
||||
self.avatar_id = opt.avatar_id
|
||||
self.static_img = opt.static_img
|
||||
self.video_path = '' # video_path
|
||||
self.bbox_shift = opt.bbox_shift
|
||||
self.avatar_path = f"./data/avatars/{self.avatar_id}"
|
||||
|
@ -193,7 +199,6 @@ class MuseReal:
|
|||
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
self.mask_list_cycle = read_imgs(input_mask_list)
|
||||
|
||||
|
||||
def put_msg_txt(self, msg):
|
||||
self.tts.put_msg_txt(msg)
|
||||
|
||||
|
@ -232,7 +237,6 @@ class MuseReal:
|
|||
encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = self.vae.decode_latents(pred_latents)
|
||||
|
||||
|
||||
def process_frames(self, quit_event, loop=None, audio_track=None, video_track=None):
|
||||
|
||||
while not quit_event.is_set():
|
||||
|
@ -241,6 +245,9 @@ class MuseReal:
|
|||
except queue.Empty:
|
||||
continue
|
||||
if audio_frames[0][1] == 1 and audio_frames[1][1] == 1: # all audio frames are silent, so only the full image is needed
|
||||
if self.static_img:
|
||||
combine_frame = self.frame_list_cycle[0]
|
||||
else:
|
||||
combine_frame = self.frame_list_cycle[idx]
|
||||
else:
|
||||
bbox = self.coord_list_cycle[idx]
|
||||
|
@ -306,4 +313,3 @@ class MuseReal:
|
|||
# time.sleep(delay)
|
||||
self.render_event.clear() # end infer process render
|
||||
print('musereal thread stop')
|
||||
|
Loading…
Reference in New Issue