improve musetalk lipsync and speed

lipku 2024-06-22 09:02:01 +08:00
parent 592312ab8c
commit da9ffa9521
4 changed files with 158 additions and 137 deletions

View File

@@ -1,12 +1,9 @@
 import time
 import torch
 import numpy as np
-import soundfile as sf
-import resampy
 import queue
 from queue import Queue
-from io import BytesIO
 import multiprocessing as mp
 from wav2lip import audio
@@ -26,9 +23,9 @@ class LipASR:
         self.batch_size = opt.batch_size
         self.frames = []
-        self.stride_left_size = self.stride_right_size = 10
-        self.context_size = 10
-        self.audio_feats = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
+        #self.context_size = 10
         self.feat_queue = mp.Queue(5)
         self.warm_up()
@@ -38,7 +35,7 @@ class LipASR:
     def __get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.018)
+            frame = self.queue.get(block=True,timeout=0.01)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -67,7 +64,7 @@ class LipASR:
             # put to output
             self.output_queue.put((frame,type))
         # context not enough, do not run network.
-        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
            return
        inputs = np.concatenate(self.frames) # [N * chunk]
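For context on the LipASR hunks above: the left/right stride sizes now come from the command-line options opt.l and opt.r instead of a hardcoded 10, and the network is skipped until the frame buffer holds more than left + right audio chunks. A minimal sketch of that gating, assuming 20 ms float32 chunks at 16 kHz (the chunk size and class name are illustrative, not from the repo):

import numpy as np

class FrameGate:
    # sketch of the context check used in LipASR.run_step
    def __init__(self, stride_left_size, stride_right_size):
        self.stride_left_size = stride_left_size      # opt.l in the diff
        self.stride_right_size = stride_right_size    # opt.r in the diff
        self.frames = []

    def push(self, frame):
        self.frames.append(frame)
        # context not enough, do not run network (same <= test as in the diff)
        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
            return None
        return np.concatenate(self.frames)            # [N * chunk] fed to the feature extractor

gate = FrameGate(stride_left_size=10, stride_right_size=10)
chunk = np.zeros(320, dtype=np.float32)               # one 20 ms chunk
inputs = gate.push(chunk)                             # None until enough context is buffered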

View File

@@ -1,12 +1,9 @@
 import time
 import torch
 import numpy as np
-import soundfile as sf
-import resampy
 import queue
 from queue import Queue
-from io import BytesIO
 import multiprocessing as mp
 from musetalk.whisper.audio2feature import Audio2Feature
@@ -25,8 +22,9 @@ class MuseASR:
         self.audio_processor = audio_processor
         self.batch_size = opt.batch_size
-        self.stride_left_size = self.stride_right_size = 6
-        self.audio_feats = []
+        self.frames = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
         self.feat_queue = mp.Queue(5)
         self.warm_up()
@@ -36,7 +34,7 @@ class MuseASR:
     def __get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.018)
+            frame = self.queue.get(block=True,timeout=0.01)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -49,15 +47,10 @@ class MuseASR:
         return self.output_queue.get()
     def warm_up(self):
-        frames = []
         for _ in range(self.stride_left_size + self.stride_right_size):
             audio_frame,type=self.__get_audio_frame()
-            frames.append(audio_frame)
+            self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))
-        inputs = np.concatenate(frames) # [N * chunk]
-        whisper_feature = self.audio_processor.audio2feat(inputs)
-        for feature in whisper_feature:
-            self.audio_feats.append(feature)
         for _ in range(self.stride_left_size):
             self.output_queue.get()
@@ -65,20 +58,25 @@ class MuseASR:
     def run_step(self):
         ############################################## extract audio feature ##############################################
         start_time = time.time()
-        frames = []
         for _ in range(self.batch_size*2):
             audio_frame,type=self.__get_audio_frame()
-            frames.append(audio_frame)
+            self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))
-        inputs = np.concatenate(frames) # [N * chunk]
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+            return
+        inputs = np.concatenate(self.frames) # [N * chunk]
         whisper_feature = self.audio_processor.audio2feat(inputs)
-        for feature in whisper_feature:
-            self.audio_feats.append(feature)
+        # for feature in whisper_feature:
+        #     self.audio_feats.append(feature)
         #print(f"processing audio costs {(time.time() - start_time) * 1000}ms, inputs shape:{inputs.shape} whisper_feature len:{len(whisper_feature)}")
-        whisper_chunks = self.audio_processor.feature2chunks(feature_array=self.audio_feats,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
+        whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
         #print(f"whisper_chunks len:{len(whisper_chunks)},self.audio_feats len:{len(self.audio_feats)},self.output_queue len:{self.output_queue.qsize()}")
-        self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
+        #self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
         self.feat_queue.put(whisper_chunks)
+        # discard the old part to save memory
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
     def get_next_feat(self,block,timeout):
         return self.feat_queue.get(block,timeout)
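To make the MuseASR rework easier to follow: audio frames now persist across run_step calls (in self.frames), the whisper features are recomputed over that window each step and handed directly to feature2chunks, and only the left/right context is kept afterwards. Below is a self-contained sketch of that flow; the stub processor and all sizes are illustrative stand-ins, not the real Audio2Feature, and the buffer is passed explicitly instead of stored on self:

import numpy as np

class StubProcessor:
    # illustrative stand-in for musetalk's Audio2Feature (shapes are assumptions)
    def audio2feat(self, samples):
        return np.zeros((max(len(samples) // 640, 1), 384), dtype=np.float32)
    def feature2chunks(self, feature_array, fps, batch_size, start):
        return [feature_array for _ in range(batch_size)]

def run_step(frames, new_frames, processor, batch_size, left, right, fps=50):
    frames.extend(new_frames)                        # frames persist across steps
    if len(frames) <= left + right:                  # not enough context yet
        return frames, None
    inputs = np.concatenate(frames)                  # [N * chunk]
    feats = processor.audio2feat(inputs)             # features for this window only
    chunks = processor.feature2chunks(feature_array=feats, fps=fps / 2,
                                      batch_size=batch_size, start=left / 2)
    frames = frames[-(left + right):]                # discard the old part to save memory
    return frames, chunks

# usage with dummy 20 ms frames (values are illustrative)
frames, chunks = [], None
for _ in range(5):
    batch = [np.zeros(320, dtype=np.float32) for _ in range(8)]
    frames, chunks = run_step(frames, batch, StubProcessor(), batch_size=4, left=6, right=6)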

View File

@@ -2,6 +2,7 @@ from PIL import Image
 import numpy as np
 import cv2
 from face_parsing import FaceParsing
+import copy
 fp = FaceParsing()
@@ -84,17 +85,41 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
     mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
     return mask_array,crop_box
-def get_image_blending(image,face,face_box,mask_array,crop_box):
-    body = Image.fromarray(image[:,:,::-1])
-    face = Image.fromarray(face[:,:,::-1])
+# def get_image_blending(image,face,face_box,mask_array,crop_box):
+#     body = Image.fromarray(image[:,:,::-1])
+#     face = Image.fromarray(face[:,:,::-1])
+#     x, y, x1, y1 = face_box
+#     x_s, y_s, x_e, y_e = crop_box
+#     face_large = body.crop(crop_box)
+#     mask_image = Image.fromarray(mask_array)
+#     mask_image = mask_image.convert("L")
+#     face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
+#     body.paste(face_large, crop_box[:2], mask_image)
+#     body = np.array(body)
+#     return body[:,:,::-1]
+def get_image_blending(image,face,face_box,mask_array,crop_box):
+    body = image
     x, y, x1, y1 = face_box
     x_s, y_s, x_e, y_e = crop_box
-    face_large = body.crop(crop_box)
-    mask_image = Image.fromarray(mask_array)
-    mask_image = mask_image.convert("L")
-    face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
-    body.paste(face_large, crop_box[:2], mask_image)
-    body = np.array(body)
-    return body[:,:,::-1]
+    face_large = copy.deepcopy(body[y_s:y_e, x_s:x_e])
+    face_large[y-y_s:y1-y_s, x-x_s:x1-x_s]=face
+    mask_image = cv2.cvtColor(mask_array,cv2.COLOR_BGR2GRAY)
+    mask_image = (mask_image/255).astype(np.float32)
+    # mask_not = cv2.bitwise_not(mask_array)
+    # prospect_tmp = cv2.bitwise_and(face_large, face_large, mask=mask_array)
+    # background_img = body[y_s:y_e, x_s:x_e]
+    # background_img = cv2.bitwise_and(background_img, background_img, mask=mask_not)
+    # body[y_s:y_e, x_s:x_e] = prospect_tmp + background_img
+    #print(mask_image.shape)
+    #print(cv2.minMaxLoc(mask_image))
+    body[y_s:y_e, x_s:x_e] = cv2.blendLinear(face_large,body[y_s:y_e, x_s:x_e],mask_image,1-mask_image)
+    #body.paste(face_large, crop_box[:2], mask_image)
+    return body
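The rewritten get_image_blending works directly on NumPy arrays and blends with cv2.blendLinear instead of round-tripping through PIL, which is where most of the speedup in this file comes from. A hedged usage sketch, assuming the new function above is in scope; the frame size, boxes, and the 3-channel uint8 mask below are assumptions chosen to satisfy the cvtColor call, not values from the repo:

import numpy as np
import cv2

frame = np.zeros((480, 640, 3), dtype=np.uint8)        # full BGR video frame
gen_face = np.full((100, 80, 3), 128, dtype=np.uint8)  # network output for the face box
face_box = (200, 150, 280, 250)                         # x, y, x1, y1
crop_box = (180, 130, 300, 270)                         # expanded crop with context
mask = np.zeros((140, 120, 3), dtype=np.uint8)          # soft mask over the crop region
mask[20:120, 20:100] = 255                              # keep the generated face fully

out = get_image_blending(frame, gen_face, face_box, mask, crop_box)
# cv2.blendLinear computes weights*face_large + (1-weights)*background per pixel,
# so the result stays a uint8 BGR frame and no PIL conversion is needed.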

View File

@@ -55,13 +55,13 @@ class PlayerStreamTrack(MediaStreamTrack):
         if hasattr(self, "_timestamp"):
             #self._timestamp = (time.time()-self._start) * VIDEO_CLOCK_RATE
             self._timestamp += int(VIDEO_PTIME * VIDEO_CLOCK_RATE)
-            # wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
-            wait = self.timelist[0] + len(self.timelist)*VIDEO_PTIME - time.time()
+            wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
+            # wait = self.timelist[0] + len(self.timelist)*VIDEO_PTIME - time.time()
             if wait>0:
                 await asyncio.sleep(wait)
-            self.timelist.append(time.time())
-            if len(self.timelist)>100:
-                self.timelist.pop(0)
+            # if len(self.timelist)>=100:
+            # self.timelist.pop(0)
+            # self.timelist.append(time.time())
         else:
             self._start = time.time()
             self._timestamp = 0
@@ -72,13 +72,14 @@ class PlayerStreamTrack(MediaStreamTrack):
         if hasattr(self, "_timestamp"):
             #self._timestamp = (time.time()-self._start) * SAMPLE_RATE
             self._timestamp += int(AUDIO_PTIME * SAMPLE_RATE)
-            # wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time()
-            wait = self.timelist[0] + len(self.timelist)*AUDIO_PTIME - time.time()
+            wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time()
+            # wait = self.timelist[0] + len(self.timelist)*AUDIO_PTIME - time.time()
             if wait>0:
                 await asyncio.sleep(wait)
-            self.timelist.append(time.time())
-            if len(self.timelist)>200:
-                self.timelist.pop(0)
+            # if len(self.timelist)>=200:
+            # self.timelist.pop(0)
+            # self.timelist.pop(0)
+            # self.timelist.append(time.time())
         else:
             self._start = time.time()
             self._timestamp = 0
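Both hunks above revert the frame pacing from a rolling timelist average back to the absolute clock: each track sleeps until the wall-clock moment its RTP timestamp corresponds to, measured from _start. A minimal sketch of that loop; the constants mirror VIDEO_PTIME / VIDEO_CLOCK_RATE from the diff but the values here are illustrative:

import asyncio
import time

VIDEO_CLOCK_RATE = 90000      # standard RTP video clock
VIDEO_PTIME = 1 / 25          # 40 ms per frame at 25 fps

async def pace(n_frames):
    start = time.time()
    timestamp = 0
    for _ in range(n_frames):
        timestamp += int(VIDEO_PTIME * VIDEO_CLOCK_RATE)
        # sleep until the wall-clock time this timestamp maps to,
        # so a late frame catches up instead of drifting further behind
        wait = start + (timestamp / VIDEO_CLOCK_RATE) - time.time()
        if wait > 0:
            await asyncio.sleep(wait)
        # a real track would emit (frame, timestamp) here

asyncio.run(pace(25))         # paces roughly one second of video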