Merge branch 'lipku:main' into main

ShelikeSnow 2024-06-22 12:49:03 +08:00 committed by GitHub
commit 994535fe3e
4 changed files with 158 additions and 137 deletions

View File

@@ -1,12 +1,9 @@
 import time
 import torch
 import numpy as np
-import soundfile as sf
-import resampy

 import queue
 from queue import Queue
-from io import BytesIO
 import multiprocessing as mp

 from wav2lip import audio
@@ -26,9 +23,9 @@ class LipASR:
         self.batch_size = opt.batch_size

         self.frames = []
-        self.stride_left_size = self.stride_right_size = 10
-        self.context_size = 10
-        self.audio_feats = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
+        #self.context_size = 10
         self.feat_queue = mp.Queue(5)

         self.warm_up()
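With this change LipASR no longer hard-codes a 10-frame stride on each side (plus a separate context_size); both strides now come from the parsed options opt.l and opt.r. A hypothetical sketch of how such options could be declared with argparse — the flag names, defaults, and help text below are assumptions for illustration, not taken from this diff:

# Hypothetical sketch: expose the ASR stride window on the command line.
# Flags, defaults, and help strings are assumptions, not the project's actual parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-l', type=int, default=10,
                    help='left stride: audio frames of context kept before the current chunk')
parser.add_argument('-r', type=int, default=10,
                    help='right stride: audio frames of context kept after the current chunk')
opt = parser.parse_args()

# LipASR then reads them exactly as in the hunk above:
#     self.stride_left_size = opt.l
#     self.stride_right_size = opt.r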
@@ -38,7 +35,7 @@ class LipASR:

     def __get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.018)
+            frame = self.queue.get(block=True,timeout=0.01)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -67,7 +64,7 @@ class LipASR:
             # put to output
             self.output_queue.put((frame,type))
         # context not enough, do not run network.
-        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
             return

         inputs = np.concatenate(self.frames) # [N * chunk]
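Dropping context_size also simplifies the gate before the network runs: LipASR now only requires the left plus right stride of buffered frames. A small self-contained illustration of that gate, appending one frame per step for simplicity (the real run_step appends batch_size*2 frames per call, and the stride sizes come from opt.l / opt.r):

# Sketch of the new gating rule with the old default strides of 10 frames per side.
stride_left_size = 10
stride_right_size = 10
frames = []  # one entry per incoming audio chunk

def enough_context():
    # run the network only once more than left+right stride frames are buffered
    return len(frames) > stride_left_size + stride_right_size

for step in range(25):
    frames.append(step)  # stand-in for an audio chunk
    if enough_context():
        print(f"step {step}: run network on {len(frames)} buffered frames")
# With 10+10 strides, the first network call happens when the 21st frame arrives.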

View File

@@ -1,12 +1,9 @@
 import time
 import torch
 import numpy as np
-import soundfile as sf
-import resampy

 import queue
 from queue import Queue
-from io import BytesIO
 import multiprocessing as mp

 from musetalk.whisper.audio2feature import Audio2Feature
@@ -25,8 +22,9 @@ class MuseASR:
         self.audio_processor = audio_processor
         self.batch_size = opt.batch_size

-        self.stride_left_size = self.stride_right_size = 6
-        self.audio_feats = []
+        self.frames = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
         self.feat_queue = mp.Queue(5)

         self.warm_up()
@@ -36,7 +34,7 @@ class MuseASR:

     def __get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.018)
+            frame = self.queue.get(block=True,timeout=0.01)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -49,15 +47,10 @@ class MuseASR:
         return self.output_queue.get()

     def warm_up(self):
-        frames = []
         for _ in range(self.stride_left_size + self.stride_right_size):
             audio_frame,type=self.__get_audio_frame()
-            frames.append(audio_frame)
+            self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))
-        inputs = np.concatenate(frames) # [N * chunk]
-        whisper_feature = self.audio_processor.audio2feat(inputs)
-        for feature in whisper_feature:
-            self.audio_feats.append(feature)

         for _ in range(self.stride_left_size):
             self.output_queue.get()
@@ -65,20 +58,25 @@ class MuseASR:
     def run_step(self):
         ############################################## extract audio feature ##############################################
         start_time = time.time()
-        frames = []
         for _ in range(self.batch_size*2):
             audio_frame,type=self.__get_audio_frame()
-            frames.append(audio_frame)
+            self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))
-        inputs = np.concatenate(frames) # [N * chunk]
+
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+            return
+
+        inputs = np.concatenate(self.frames) # [N * chunk]
         whisper_feature = self.audio_processor.audio2feat(inputs)
-        for feature in whisper_feature:
-            self.audio_feats.append(feature)
+        # for feature in whisper_feature:
+        #     self.audio_feats.append(feature)
         #print(f"processing audio costs {(time.time() - start_time) * 1000}ms, inputs shape:{inputs.shape} whisper_feature len:{len(whisper_feature)}")
-        whisper_chunks = self.audio_processor.feature2chunks(feature_array=self.audio_feats,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
+        whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2 )
         #print(f"whisper_chunks len:{len(whisper_chunks)},self.audio_feats len:{len(self.audio_feats)},self.output_queue len:{self.output_queue.qsize()}")
-        self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
+        #self.audio_feats = self.audio_feats[-(self.stride_left_size + self.stride_right_size):]
         self.feat_queue.put(whisper_chunks)
+        # discard the old part to save memory
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

     def get_next_feat(self,block,timeout):
         return self.feat_queue.get(block,timeout)
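MuseASR now mirrors LipASR: instead of accumulating whisper features in self.audio_feats across calls, it keeps the raw audio frames in self.frames, recomputes features from the concatenated buffer on every run_step, hands them straight to feature2chunks, and then trims the buffer back to the stride window so memory stays bounded. A minimal sketch of that rolling-buffer pattern, with a placeholder standing in for the project's Audio2Feature pipeline (audio2feat / feature2chunks):

# Minimal sketch of the rolling-buffer pattern introduced here.
# extract_features is a stand-in for audio_processor.audio2feat + feature2chunks.
import numpy as np

stride_left_size, stride_right_size = 6, 6   # illustrative values; the real ones come from opt.l / opt.r
frames = []

def extract_features(inputs):
    return inputs.mean()                     # placeholder for the whisper feature pipeline

def run_step(new_chunks):
    global frames
    frames.extend(new_chunks)
    if len(frames) <= stride_left_size + stride_right_size:
        return None                          # not enough context yet
    inputs = np.concatenate(frames)          # [N * chunk]
    chunks = extract_features(inputs)
    # discard the old part to save memory: keep only the stride context
    frames = frames[-(stride_left_size + stride_right_size):]
    return chunks

The trade-off is recomputing features for the small stride overlap on every step, in exchange for a buffer whose size no longer grows with the length of the stream.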

View File

@@ -1,100 +1,125 @@
 from PIL import Image
 import numpy as np
 import cv2
 from face_parsing import FaceParsing
+import copy

 fp = FaceParsing()

 def get_crop_box(box, expand):
     x, y, x1, y1 = box
     x_c, y_c = (x+x1)//2, (y+y1)//2
     w, h = x1-x, y1-y
     s = int(max(w, h)//2*expand)
     crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
     return crop_box, s

 def face_seg(image):
     seg_image = fp(image)
     if seg_image is None:
         print("error, no person_segment")
         return None

     seg_image = seg_image.resize(image.size)
     return seg_image

 def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
     #print(image.shape)
     #print(face.shape)

     body = Image.fromarray(image[:,:,::-1])
     face = Image.fromarray(face[:,:,::-1])

     x, y, x1, y1 = face_box
     #print(x1-x,y1-y)
     crop_box, s = get_crop_box(face_box, expand)
     x_s, y_s, x_e, y_e = crop_box
     face_position = (x, y)

     face_large = body.crop(crop_box)
     ori_shape = face_large.size

     mask_image = face_seg(face_large)
     mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
     mask_image = Image.new('L', ori_shape, 0)
     mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))

     # keep upper_boundary_ratio of talking area
     width, height = mask_image.size
     top_boundary = int(height * upper_boundary_ratio)
     modified_mask_image = Image.new('L', ori_shape, 0)
     modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

     blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
     mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
     mask_image = Image.fromarray(mask_array)

     face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
     body.paste(face_large, crop_box[:2], mask_image)
     body = np.array(body)
     return body[:,:,::-1]

 def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
     body = Image.fromarray(image[:,:,::-1])

     x, y, x1, y1 = face_box
     #print(x1-x,y1-y)
     crop_box, s = get_crop_box(face_box, expand)
     x_s, y_s, x_e, y_e = crop_box

     face_large = body.crop(crop_box)
     ori_shape = face_large.size

     mask_image = face_seg(face_large)
     mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
     mask_image = Image.new('L', ori_shape, 0)
     mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))

     # keep upper_boundary_ratio of talking area
     width, height = mask_image.size
     top_boundary = int(height * upper_boundary_ratio)
     modified_mask_image = Image.new('L', ori_shape, 0)
     modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

     blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
     mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
     return mask_array,crop_box

-def get_image_blending(image,face,face_box,mask_array,crop_box):
-    body = Image.fromarray(image[:,:,::-1])
-    face = Image.fromarray(face[:,:,::-1])
-
-    x, y, x1, y1 = face_box
-    x_s, y_s, x_e, y_e = crop_box
-    face_large = body.crop(crop_box)
-
-    mask_image = Image.fromarray(mask_array)
-    mask_image = mask_image.convert("L")
-    face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
-    body.paste(face_large, crop_box[:2], mask_image)
-    body = np.array(body)
-    return body[:,:,::-1]
+# def get_image_blending(image,face,face_box,mask_array,crop_box):
+#     body = Image.fromarray(image[:,:,::-1])
+#     face = Image.fromarray(face[:,:,::-1])
+#
+#     x, y, x1, y1 = face_box
+#     x_s, y_s, x_e, y_e = crop_box
+#     face_large = body.crop(crop_box)
+#
+#     mask_image = Image.fromarray(mask_array)
+#     mask_image = mask_image.convert("L")
+#     face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
+#     body.paste(face_large, crop_box[:2], mask_image)
+#     body = np.array(body)
+#     return body[:,:,::-1]
+
+def get_image_blending(image,face,face_box,mask_array,crop_box):
+    body = image
+    x, y, x1, y1 = face_box
+    x_s, y_s, x_e, y_e = crop_box
+    face_large = copy.deepcopy(body[y_s:y_e, x_s:x_e])
+    face_large[y-y_s:y1-y_s, x-x_s:x1-x_s]=face
+
+    mask_image = cv2.cvtColor(mask_array,cv2.COLOR_BGR2GRAY)
+    mask_image = (mask_image/255).astype(np.float32)
+
+    # mask_not = cv2.bitwise_not(mask_array)
+    # prospect_tmp = cv2.bitwise_and(face_large, face_large, mask=mask_array)
+    # background_img = body[y_s:y_e, x_s:x_e]
+    # background_img = cv2.bitwise_and(background_img, background_img, mask=mask_not)
+    # body[y_s:y_e, x_s:x_e] = prospect_tmp + background_img
+
+    #print(mask_image.shape)
+    #print(cv2.minMaxLoc(mask_image))
+    body[y_s:y_e, x_s:x_e] = cv2.blendLinear(face_large,body[y_s:y_e, x_s:x_e],mask_image,1-mask_image)
+    #body.paste(face_large, crop_box[:2], mask_image)
+    return body
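The rewritten get_image_blending stays in NumPy/OpenCV instead of round-tripping through PIL: the generated face is written into a deep copy of the cropped region, and cv2.blendLinear mixes that crop back into the frame using the precomputed, Gaussian-blurred mask as per-pixel weights (the old PIL-based version is kept above, commented out). A small standalone sketch of what the blendLinear call computes, using toy shapes and values:

# Sketch: cv2.blendLinear is a per-pixel weighted average of the face crop and the
# original background crop. Shapes and values here are illustrative only.
import cv2
import numpy as np

face_large = np.full((4, 4, 3), 200, dtype=np.uint8)   # crop with the generated face pasted in
background = np.full((4, 4, 3), 50, dtype=np.uint8)    # same region from the original frame
mask = np.zeros((4, 4), dtype=np.float32)
mask[1:3, 1:3] = 1.0                                    # 1.0 = take the face, 0.0 = keep the background

blended = cv2.blendLinear(face_large, background, mask, 1 - mask)

# Equivalent NumPy formulation of the same weighted sum:
manual = (face_large * mask[..., None] + background * (1 - mask)[..., None]).astype(np.uint8)
print(np.abs(blended.astype(int) - manual.astype(int)).max())   # expect 0 (or 1 from rounding)

In the real function the mask comes from get_image_prepare_material, so its Gaussian-blurred edge gives soft 0-1 weights and the seam between the generated face and the original frame is feathered rather than hard.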

View File

@@ -55,13 +55,13 @@ class PlayerStreamTrack(MediaStreamTrack):
             if hasattr(self, "_timestamp"):
                 #self._timestamp = (time.time()-self._start) * VIDEO_CLOCK_RATE
                 self._timestamp += int(VIDEO_PTIME * VIDEO_CLOCK_RATE)
-                # wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
-                wait = self.timelist[0] + len(self.timelist)*VIDEO_PTIME - time.time()
+                wait = self._start + (self._timestamp / VIDEO_CLOCK_RATE) - time.time()
+                # wait = self.timelist[0] + len(self.timelist)*VIDEO_PTIME - time.time()
                 if wait>0:
                     await asyncio.sleep(wait)
-                self.timelist.append(time.time())
-                if len(self.timelist)>100:
-                    self.timelist.pop(0)
+                # if len(self.timelist)>=100:
+                #     self.timelist.pop(0)
+                # self.timelist.append(time.time())
             else:
                 self._start = time.time()
                 self._timestamp = 0
@@ -72,13 +72,14 @@ class PlayerStreamTrack(MediaStreamTrack):
             if hasattr(self, "_timestamp"):
                 #self._timestamp = (time.time()-self._start) * SAMPLE_RATE
                 self._timestamp += int(AUDIO_PTIME * SAMPLE_RATE)
-                # wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time()
-                wait = self.timelist[0] + len(self.timelist)*AUDIO_PTIME - time.time()
+                wait = self._start + (self._timestamp / SAMPLE_RATE) - time.time()
+                # wait = self.timelist[0] + len(self.timelist)*AUDIO_PTIME - time.time()
                 if wait>0:
                     await asyncio.sleep(wait)
-                self.timelist.append(time.time())
-                if len(self.timelist)>200:
-                    self.timelist.pop(0)
+                # if len(self.timelist)>=200:
+                #     self.timelist.pop(0)
+                #     self.timelist.pop(0)
+                # self.timelist.append(time.time())
             else:
                 self._start = time.time()
                 self._timestamp = 0
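For both the video and the audio track, next_timestamp goes back to pacing frames against the track's own clock: the wait is the ideal presentation time (self._start plus the accumulated timestamp divided by the clock rate) minus the current time, and the timelist bookkeeping is commented out. A minimal standalone sketch of this clock-anchored pacing (not the project's PlayerStreamTrack; the 25 fps / 90 kHz constants are illustrative assumptions):

# Sketch of clock-based pacing as in next_timestamp: sleep until the ideal
# presentation time of the next frame, measured from a fixed start time.
import asyncio
import time

VIDEO_CLOCK_RATE = 90000     # assumed RTP clock rate for video
VIDEO_PTIME = 1 / 25         # assumed seconds per frame at 25 fps

async def pace_frames(n=5):
    start = time.time()
    timestamp = 0
    for _ in range(n):
        timestamp += int(VIDEO_PTIME * VIDEO_CLOCK_RATE)
        # ideal wall-clock time of this frame minus "now" = how long to wait
        wait = start + (timestamp / VIDEO_CLOCK_RATE) - time.time()
        if wait > 0:
            await asyncio.sleep(wait)
        print(f"frame sent at t={time.time() - start:.3f}s (target {timestamp / VIDEO_CLOCK_RATE:.3f}s)")

asyncio.run(pace_frames())

Because the targets are absolute offsets from _start, jitter on one frame does not shift the schedule of later frames, whereas the commented-out timelist variant paced each frame relative to a sliding window of recent send times.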