From da9ffa9521b984888cce276ba4ad71fb68ce0822 Mon Sep 17 00:00:00 2001
From: lipku
Date: Sat, 22 Jun 2024 09:02:01 +0800
Subject: [PATCH] improve musetalk lipsync and speed

---
 lipasr.py                  |  13 +--
 museasr.py                 |  36 +++---
 musetalk/utils/blending.py | 225 ++++++++++++++++++++-----------------
 webrtc.py                  |  21 ++--
 4 files changed, 158 insertions(+), 137 deletions(-)

diff --git a/lipasr.py b/lipasr.py
index e1ba811..5742dd7 100644
--- a/lipasr.py
+++ b/lipasr.py
@@ -1,12 +1,9 @@
 import time
 import torch
 import numpy as np
-import soundfile as sf
-import resampy
 import queue
 from queue import Queue
-from io import BytesIO
 import multiprocessing as mp
 
 from wav2lip import audio
 
@@ -26,9 +23,9 @@ class LipASR:
         self.batch_size = opt.batch_size
 
         self.frames = []
-        self.stride_left_size = self.stride_right_size = 10
-        self.context_size = 10
-        self.audio_feats = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
+        #self.context_size = 10
         self.feat_queue = mp.Queue(5)
 
         self.warm_up()
@@ -38,7 +35,7 @@ class LipASR:
 
     def __get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.018)
+            frame = self.queue.get(block=True,timeout=0.01)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -67,7 +64,7 @@ class LipASR:
             # put to output
             self.output_queue.put((frame,type))
         # context not enough, do not run network.
-        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
            return
 
        inputs = np.concatenate(self.frames) # [N * chunk]
diff --git a/museasr.py b/museasr.py
index 251225f..cfcd9ba 100644
--- a/museasr.py
+++ b/museasr.py
@@ -1,12 +1,9 @@
 import time
 import torch
 import numpy as np
-import soundfile as sf
-import resampy
 import queue
 from queue import Queue
-from io import BytesIO
 import multiprocessing as mp
 
 from musetalk.whisper.audio2feature import Audio2Feature
 
@@ -25,8 +22,9 @@ class MuseASR:
         self.audio_processor = audio_processor
         self.batch_size = opt.batch_size
 
-        self.stride_left_size = self.stride_right_size = 6
-        self.audio_feats = []
+        self.frames = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
         self.feat_queue = mp.Queue(5)
 
         self.warm_up()
@@ -36,7 +34,7 @@ class MuseASR:
 
     def __get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.018)
+            frame = self.queue.get(block=True,timeout=0.01)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -49,15 +47,10 @@ class MuseASR:
         return self.output_queue.get()
 
     def warm_up(self):
-        frames = []
         for _ in range(self.stride_left_size + self.stride_right_size):
             audio_frame,type=self.__get_audio_frame()
-            frames.append(audio_frame)
+            self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))
-        inputs = np.concatenate(frames) # [N * chunk]
-        whisper_feature = self.audio_processor.audio2feat(inputs)
-        for feature in whisper_feature:
-            self.audio_feats.append(feature)
         for _ in range(self.stride_left_size):
             self.output_queue.get()
 
@@ -65,20 +58,25 @@
     def run_step(self):
         ############################################## extract audio feature ##############################################
         start_time = time.time()
-        frames = []
         for _ in range(self.batch_size*2):
             audio_frame,type=self.__get_audio_frame()
-            frames.append(audio_frame)
+            self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))
-        inputs = np.concatenate(frames) # [N * chunk]
+
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+            return
+
+        inputs = np.concatenate(self.frames) # [N * chunk]
         whisper_feature = self.audio_processor.audio2feat(inputs)
-        for feature in whisper_feature:
-            self.audio_feats.append(feature)
+        # for feature in whisper_feature:
+        #     self.audio_feats.append(feature)
         #print(f"processing audio costs {(time.time() - start_time) * 1000}ms, inputs shape:{inputs.shape} whisper_feature len:{len(whisper_feature)}")
-        whisper_chunks = self.audio_processor.feature2chunks(feature_array=self.audio_feats,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
+        whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
         #print(f"whisper_chunks len:{len(whisper_chunks)},self.audio_feats len:{len(self.audio_feats)},self.output_queue len:{self.output_queue.qsize()}")
-        self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
+        #self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
         self.feat_queue.put(whisper_chunks)
+        # discard the old part to save memory
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
 
     def get_next_feat(self,block,timeout):
         return self.feat_queue.get(block,timeout)
\ No newline at end of file
diff --git a/musetalk/utils/blending.py b/musetalk/utils/blending.py
index d69e435..108c7c2 100644
--- a/musetalk/utils/blending.py
+++ b/musetalk/utils/blending.py
@@ -1,100 +1,125 @@
-from PIL import Image
-import numpy as np
-import cv2
-from face_parsing import FaceParsing
-
-fp = FaceParsing()
-
-def get_crop_box(box, expand):
-    x, y, x1, y1 = box
-    x_c, y_c = (x+x1)//2, (y+y1)//2
-    w, h = x1-x, y1-y
-    s = int(max(w, h)//2*expand)
-    crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
-    return crop_box, s
-
-def face_seg(image):
-    seg_image = fp(image)
-    if seg_image is None:
-        print("error, no person_segment")
-        return None
-
-    seg_image = seg_image.resize(image.size)
-    return seg_image
-
-def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
-    #print(image.shape)
-    #print(face.shape)
-
-    body = Image.fromarray(image[:,:,::-1])
-    face = Image.fromarray(face[:,:,::-1])
-
-    x, y, x1, y1 = face_box
-    #print(x1-x,y1-y)
-    crop_box, s = get_crop_box(face_box, expand)
-    x_s, y_s, x_e, y_e = crop_box
-    face_position = (x, y)
-
-    face_large = body.crop(crop_box)
-    ori_shape = face_large.size
-
-    mask_image = face_seg(face_large)
-    mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
-    mask_image = Image.new('L', ori_shape, 0)
-    mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
-
-    # keep upper_boundary_ratio of talking area
-    width, height = mask_image.size
-    top_boundary = int(height * upper_boundary_ratio)
-    modified_mask_image = Image.new('L', ori_shape, 0)
-    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
-
-    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
-    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
-    mask_image = Image.fromarray(mask_array)
-
-    face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
-    body.paste(face_large, crop_box[:2], mask_image)
-    body = np.array(body)
-    return body[:,:,::-1]
-
-def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
-    body = Image.fromarray(image[:,:,::-1])
-
-    x, y, x1, y1 = face_box
-    #print(x1-x,y1-y)
-    crop_box, s = get_crop_box(face_box, expand)
-    x_s, y_s, x_e, y_e = crop_box
-
-    face_large = body.crop(crop_box)
-    ori_shape = face_large.size
-
-    mask_image = face_seg(face_large)
-    mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
-    mask_image = Image.new('L', ori_shape, 0)
-    mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
-
-    # keep upper_boundary_ratio of talking area
-    width, height = mask_image.size
-    top_boundary = int(height * upper_boundary_ratio)
-    modified_mask_image = Image.new('L', ori_shape, 0)
-    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
-
-    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
-    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
-    return mask_array,crop_box
-
-def get_image_blending(image,face,face_box,mask_array,crop_box):
-    body = Image.fromarray(image[:,:,::-1])
-    face = Image.fromarray(face[:,:,::-1])
-
-    x, y, x1, y1 = face_box
-    x_s, y_s, x_e, y_e = crop_box
-    face_large = body.crop(crop_box)
-
-    mask_image = Image.fromarray(mask_array)
-    mask_image = mask_image.convert("L")
-    face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
-    body.paste(face_large, crop_box[:2], mask_image)
-    body = np.array(body)
-    return body[:,:,::-1]
\ No newline at end of file
+from PIL import Image
+import numpy as np
+import cv2
+from face_parsing import FaceParsing
+import copy
+
+fp = FaceParsing()
+
+def get_crop_box(box, expand):
+    x, y, x1, y1 = box
+    x_c, y_c = (x+x1)//2, (y+y1)//2
+    w, h = x1-x, y1-y
+    s = int(max(w, h)//2*expand)
+    crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
+    return crop_box, s
+
+def face_seg(image):
+    seg_image = fp(image)
+    if seg_image is None:
+        print("error, no person_segment")
+        return None
+
+    seg_image = seg_image.resize(image.size)
+    return seg_image
+
+def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
+    #print(image.shape)
+    #print(face.shape)
+
+    body = Image.fromarray(image[:,:,::-1])
+    face = Image.fromarray(face[:,:,::-1])
+
+    x, y, x1, y1 = face_box
+    #print(x1-x,y1-y)
+    crop_box, s = get_crop_box(face_box, expand)
+    x_s, y_s, x_e, y_e = crop_box
+    face_position = (x, y)
+
+    face_large = body.crop(crop_box)
+    ori_shape = face_large.size
+
+    mask_image = face_seg(face_large)
+    mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
+    mask_image = Image.new('L', ori_shape, 0)
+    mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
+
+    # keep upper_boundary_ratio of talking area
+    width, height = mask_image.size
+    top_boundary = int(height * upper_boundary_ratio)
+    modified_mask_image = Image.new('L', ori_shape, 0)
+    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
+
+    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
+    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
+    mask_image = Image.fromarray(mask_array)
+
+    face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
+    body.paste(face_large, crop_box[:2], mask_image)
+    body = np.array(body)
+    return body[:,:,::-1]
+
+def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
+    body = Image.fromarray(image[:,:,::-1])
+
+    x, y, x1, y1 = face_box
+    #print(x1-x,y1-y)
+    crop_box, s = get_crop_box(face_box, expand)
+    x_s, y_s, x_e, y_e = crop_box
+
+    face_large = body.crop(crop_box)
+    ori_shape = face_large.size
+
+    mask_image = face_seg(face_large)
+    mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
+    mask_image = Image.new('L', ori_shape, 0)
+    mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
+
+    # keep upper_boundary_ratio of talking area
+    width, height = mask_image.size
+    top_boundary = int(height * upper_boundary_ratio)
+    modified_mask_image = Image.new('L', ori_shape, 0)
+    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
+
+    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
+    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
+    return mask_array,crop_box
+
+# def get_image_blending(image,face,face_box,mask_array,crop_box):
+#     body = Image.fromarray(image[:,:,::-1])
+#     face = Image.fromarray(face[:,:,::-1])
+
+#     x, y, x1, y1 = face_box
+#     x_s, y_s, x_e, y_e = crop_box
+#     face_large = body.crop(crop_box)
+
+#     mask_image = Image.fromarray(mask_array)
+#     mask_image = mask_image.convert("L")
+#     face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
+#     body.paste(face_large, crop_box[:2], mask_image)
+#     body = np.array(body)
+#     return body[:,:,::-1]
+
+def get_image_blending(image,face,face_box,mask_array,crop_box):
+    body = image
+    x, y, x1, y1 = face_box
+    x_s, y_s, x_e, y_e = crop_box
+    face_large = copy.deepcopy(body[y_s:y_e, x_s:x_e])
+    face_large[y-y_s:y1-y_s, x-x_s:x1-x_s]=face
+
+    mask_image = cv2.cvtColor(mask_array,cv2.COLOR_BGR2GRAY)
+    mask_image = (mask_image/255).astype(np.float32)
+
+    # mask_not = cv2.bitwise_not(mask_array)
+    # prospect_tmp = cv2.bitwise_and(face_large, face_large, mask=mask_array)
+    # background_img = body[y_s:y_e, x_s:x_e]
+    # background_img = cv2.bitwise_and(background_img, background_img, mask=mask_not)
+    # body[y_s:y_e, x_s:x_e] = prospect_tmp + background_img
+
+    #print(mask_image.shape)
+    #print(cv2.minMaxLoc(mask_image))
+
+    body[y_s:y_e, x_s:x_e] = cv2.blendLinear(face_large,body[y_s:y_e, x_s:x_e],mask_image,1-mask_image)
+
+    #body.paste(face_large, crop_box[:2], mask_image)
+    return body
\ No newline at end of file
diff --git a/webrtc.py b/webrtc.py
index 5c0048d..d1a37ca 100644
--- a/webrtc.py
+++ b/webrtc.py
@@ -55,13 +55,13 @@ class PlayerStreamTrack(MediaStreamTrack):
             if hasattr(self, "_timestamp"):
                 #self._timestamp = (time.time()-self._start) * VIDEO_CLOCK_RATE
                 self._timestamp += int(VIDEO_PTIME * VIDEO_CLOCK_RATE)
-                # wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
-                wait = self.timelist[0] + len(self.timelist)*VIDEO_PTIME - time.time()
+                wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
+                # wait = self.timelist[0] + len(self.timelist)*VIDEO_PTIME - time.time()
                 if wait>0:
                     await asyncio.sleep(wait)
-                self.timelist.append(time.time())
-                if len(self.timelist)>100:
-                    self.timelist.pop(0)
+                # if len(self.timelist)>=100:
+                #     self.timelist.pop(0)
+                # self.timelist.append(time.time())
             else:
                 self._start = time.time()
                 self._timestamp = 0
@@ -72,13 +72,14 @@
             if hasattr(self, "_timestamp"):
                 #self._timestamp = (time.time()-self._start) * SAMPLE_RATE
                 self._timestamp += int(AUDIO_PTIME * SAMPLE_RATE)
-                # wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time()
-                wait = self.timelist[0] + len(self.timelist)*AUDIO_PTIME - time.time()
+                wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time()
+                # wait = self.timelist[0] + len(self.timelist)*AUDIO_PTIME - time.time()
                 if wait>0:
                     await asyncio.sleep(wait)
-                self.timelist.append(time.time())
-                if len(self.timelist)>200:
-                    self.timelist.pop(0)
+                # if len(self.timelist)>=200:
+                #     self.timelist.pop(0)
+                # self.timelist.pop(0)
+                # self.timelist.append(time.time())
             else:
                 self._start = time.time()
                 self._timestamp = 0
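
The rewritten get_image_blending composites numpy arrays directly with cv2.blendLinear instead of round-tripping through PIL Image.paste, so the per-frame work after get_image_prepare_material is just a crop, an assignment, and one linear blend. The sketch below is illustrative only and not part of the patch: the 256x256 size, the synthetic circular mask, and the variable names are assumptions, while in the repo the mask and crop box come from get_image_prepare_material and the images from the avatar frame and the MuseTalk generator.

import numpy as np
import cv2

# Stand-in data for the cropped body region and the generated face crop.
H, W = 256, 256
background = np.full((H, W, 3), 64, dtype=np.uint8)
generated = np.full((H, W, 3), 200, dtype=np.uint8)

# Soft float32 mask in [0, 1]: 1 keeps the generated face, 0 keeps the background.
mask = np.zeros((H, W), dtype=np.float32)
cv2.circle(mask, (W // 2, H // 2), 80, 1.0, -1)
mask = cv2.GaussianBlur(mask, (31, 31), 0)

# cv2.blendLinear computes generated*mask + background*(1-mask) per pixel,
# which is the operation the new get_image_blending applies to the crop region.
blended = cv2.blendLinear(generated, background, mask, 1.0 - mask)
print(blended.shape, blended.dtype)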
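
On the ASR side, both LipASR and MuseASR now accumulate raw audio frames in self.frames with left/right context strides taken from opt.l and opt.r, skip the network until enough context has arrived, and trim the buffer after each step so it stays bounded. The following is a rough sketch of that bookkeeping outside the real classes; the stride, batch, and chunk sizes are made up, and the feature extraction call is only indicated by a comment.

import numpy as np

stride_left, stride_right = 6, 6   # assumed; the patch uses opt.l / opt.r
batch_size = 4                     # assumed
chunk = 640                        # assumed samples per audio frame
frames = []

def run_step(get_audio_frame):
    global frames
    # pull a fixed number of fresh audio frames per step, as run_step does
    for _ in range(batch_size * 2):
        frames.append(get_audio_frame())
    # not enough left/right context yet: skip feature extraction this step
    if len(frames) <= stride_left + stride_right:
        return None
    inputs = np.concatenate(frames)  # [N * chunk]
    # ... run audio2feat / feature2chunks on `inputs` here ...
    # keep only the trailing context so the buffer does not grow without bound
    frames = frames[-(stride_left + stride_right):]
    return inputs

for _ in range(3):
    features = run_step(lambda: np.zeros(chunk, dtype=np.float32))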