livetalking/lipasr.py

98 lines
3.3 KiB
Python

import time
import torch
import numpy as np
import soundfile as sf
import resampy
import queue
from queue import Queue
from io import BytesIO
import multiprocessing as mp
from wav2lip import audio
class LipASR:
def __init__(self, opt):
self.opt = opt
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
# self.input_stream = BytesIO()
self.output_queue = mp.Queue()
#self.audio_processor = audio_processor
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = self.stride_right_size = 10
self.context_size = 10
self.audio_feats = []
self.feat_queue = mp.Queue(5)
self.warm_up()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
try:
frame = self.queue.get(block=True,timeout=0.018)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type=self.__get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
for _ in range(self.stride_left_size):
self.output_queue.get()
def run_step(self):
############################################## extract audio feature ##############################################
# get a frame of audio
for _ in range(self.batch_size*2):
frame,type = self.__get_audio_frame()
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
# context not enough, do not run network.
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
mel = audio.melspectrogram(inputs)
#print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
# cut off stride
left = max(0, self.stride_left_size*80/50)
right = min(len(mel[0]), len(mel[0]) - self.stride_right_size*80/50)
mel_idx_multiplier = 80.*2/self.fps
mel_step_size = 16
i = 0
mel_chunks = []
while i < (len(self.frames)-self.stride_left_size-self.stride_right_size)/2:
start_idx = int(left + i * mel_idx_multiplier)
#print(start_idx)
if start_idx + mel_step_size > len(mel[0]):
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
else:
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
i += 1
self.feat_queue.put(mel_chunks)
# discard the old part to save memory
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)