wrapper class baseasr; add talk interrupt

This commit is contained in:
lipku 2024-06-30 09:41:31 +08:00
parent 98eeeb17af
commit 9fe4c7fccf
10 changed files with 161 additions and 227 deletions

3
app.py
View File

@ -118,6 +118,9 @@ async def offer(request):
async def human(request):
params = await request.json()
if params.get('interrupt'):
nerfreal.pause_talk()
if params['type']=='echo':
nerfreal.put_msg_txt(params['text'])
elif params['type']=='chat':

View File

@ -4,29 +4,19 @@ import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
#import pyaudio
import soundfile as sf
import resampy
import queue
from queue import Queue
#from collections import deque
from threading import Thread, Event
from io import BytesIO
class ASR:
from baseasr import BaseASR
class ASR(BaseASR):
def __init__(self, opt):
self.opt = opt
self.play = opt.asr_play #false
super().__init__(opt)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.mode = 'live' if opt.asr_wav == '' else 'file'
if 'esperanto' in self.opt.asr_model:
self.audio_dim = 44
elif 'deepspeech' in self.opt.asr_model:
@ -41,30 +31,11 @@ class ASR:
self.context_size = opt.m
self.stride_left_size = opt.l
self.stride_right_size = opt.r
self.text = '[START]\n'
self.terminated = False
self.frames = []
self.inwarm = False
# pad left frames
if self.stride_left_size > 0:
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
self.exit_event = Event()
#self.audio_instance = pyaudio.PyAudio() #not need
# create input stream
self.queue = Queue()
self.output_queue = Queue()
# start a background process to read frames
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
#self.queue = Queue()
#self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
# current location of audio
self.idx = 0
# create wav2vec model
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
if 'hubert' in self.opt.asr_model:
@ -74,10 +45,6 @@ class ASR:
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
# prepare to save logits
if self.opt.asr_save_feats:
self.all_feats = []
# the extracted features
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
self.feat_buffer_size = 4
@ -93,8 +60,16 @@ class ASR:
# warm up steps needed: mid + right + window_size + attention_size
self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
self.listening = False
self.playing = False
def get_audio_frame(self):
try:
frame = self.queue.get(block=False)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_next_feat(self): #get audio embedding to nerf
# return a [1/8, 16] window, for the next input to nerf side.
@ -136,17 +111,8 @@ class ASR:
def run_step(self):
if self.terminated:
return
# get a frame of audio
frame,type = self.__get_audio_frame()
# the last frame
if frame is None:
# terminate, but always run the network for the left frames
self.terminated = True
else:
frame,type = self.get_audio_frame()
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
@ -157,7 +123,6 @@ class ASR:
inputs = np.concatenate(self.frames) # [N * chunk]
# discard the old part to save memory
if not self.terminated:
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
#print(f'[INFO] frame_to_text... ')
@ -166,10 +131,6 @@ class ASR:
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
feats = logits # better lips-sync than labels
# save feats
if self.opt.asr_save_feats:
self.all_feats.append(feats)
# record the feats efficiently.. (no concat, constant memory)
start = self.feat_buffer_idx * self.context_size
end = start + feats.shape[0]
@ -203,24 +164,6 @@ class ASR:
# np.save(output_path, unfold_feats.cpu().numpy())
# print(f"[INFO] saved logits to {output_path}")
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
if self.inwarm: # warm up
return np.zeros(self.chunk, dtype=np.float32),1
try:
frame = self.queue.get(block=False)
type = 0
print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
self.idx = self.idx + self.chunk
return frame,type
def __frame_to_text(self, frame):
@ -241,8 +184,8 @@ class ASR:
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
# do not cut right if terminated.
if self.terminated:
right = logits.shape[1]
# if self.terminated:
# right = logits.shape[1]
logits = logits[:, left:right]
@ -263,9 +206,22 @@ class ASR:
return logits[0], None,None #predicted_ids[0], transcription # [N,]
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
#for _ in range(self.stride_left_size):
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
for _ in range(self.warm_up_steps):
self.run_step()
#if torch.cuda.is_available():
# torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
#self.clear_queue()
#####not used function#####################################
'''
def __init_queue(self):
self.frames = []
self.queue.queue.clear()
@ -290,26 +246,6 @@ class ASR:
if self.play:
self.output_queue.queue.clear()
def warm_up(self):
#self.listen()
self.inwarm = True
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
#for _ in range(self.stride_left_size):
# self.frames.append(np.zeros(self.chunk, dtype=np.float32))
for _ in range(self.warm_up_steps):
self.run_step()
#if torch.cuda.is_available():
# torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
self.inwarm = False
#self.clear_queue()
#####not used function#####################################
def listen(self):
# start
if self.mode == 'live' and not self.listening:
@ -405,3 +341,4 @@ if __name__ == '__main__':
with ASR(opt) as asr:
asr.run()
'''

61
baseasr.py Normal file
View File

@ -0,0 +1,61 @@
import time
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
class BaseASR:
def __init__(self, opt):
self.opt = opt
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
self.output_queue = mp.Queue()
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = opt.l
self.stride_right_size = opt.r
#self.context_size = 10
self.feat_queue = mp.Queue(2)
#self.warm_up()
def pause_talk(self):
self.queue.queue.clear()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def get_audio_frame(self):
try:
frame = self.queue.get(block=True,timeout=0.01)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type=self.get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
for _ in range(self.stride_left_size):
self.output_queue.get()
def run_step(self):
pass
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)

View File

@ -6,60 +6,16 @@ import queue
from queue import Queue
import multiprocessing as mp
from baseasr import BaseASR
from wav2lip import audio
class LipASR:
def __init__(self, opt):
self.opt = opt
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
# self.input_stream = BytesIO()
self.output_queue = mp.Queue()
#self.audio_processor = audio_processor
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = opt.l
self.stride_right_size = opt.r
#self.context_size = 10
self.feat_queue = mp.Queue(5)
self.warm_up()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
try:
frame = self.queue.get(block=True,timeout=0.01)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type=self.__get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
for _ in range(self.stride_left_size):
self.output_queue.get()
class LipASR(BaseASR):
def run_step(self):
############################################## extract audio feature ##############################################
# get a frame of audio
for _ in range(self.batch_size*2):
frame,type = self.__get_audio_frame()
frame,type = self.get_audio_frame()
self.frames.append(frame)
# put to output
self.output_queue.put((frame,type))
@ -89,7 +45,3 @@ class LipASR:
# discard the old part to save memory
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)

View File

@ -164,6 +164,7 @@ class LipReal:
self.__loadavatar()
self.asr = LipASR(opt)
self.asr.warm_up()
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
@ -200,6 +201,10 @@ class LipReal:
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@ -257,9 +262,12 @@ class LipReal:
t = time.perf_counter()
self.asr.run_step()
if video_track._queue.qsize()>=2*self.opt.batch_size:
# if video_track._queue.qsize()>=2*self.opt.batch_size:
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
if video_track._queue.qsize()>=5:
print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*self.opt.batch_size*1.5)
time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:

View File

@ -1,65 +1,22 @@
import time
import torch
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
from baseasr import BaseASR
from musetalk.whisper.audio2feature import Audio2Feature
class MuseASR:
class MuseASR(BaseASR):
def __init__(self, opt, audio_processor:Audio2Feature):
self.opt = opt
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.queue = Queue()
# self.input_stream = BytesIO()
self.output_queue = mp.Queue()
super().__init__(opt)
self.audio_processor = audio_processor
self.batch_size = opt.batch_size
self.frames = []
self.stride_left_size = opt.l
self.stride_right_size = opt.r
self.feat_queue = mp.Queue(5)
self.warm_up()
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.queue.put(audio_chunk)
def __get_audio_frame(self):
try:
frame = self.queue.get(block=True,timeout=0.01)
type = 0
#print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
return frame,type
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame,type=self.__get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
for _ in range(self.stride_left_size):
self.output_queue.get()
def run_step(self):
############################################## extract audio feature ##############################################
start_time = time.time()
for _ in range(self.batch_size*2):
audio_frame,type=self.__get_audio_frame()
audio_frame,type=self.get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame,type))
@ -77,6 +34,3 @@ class MuseASR:
self.feat_queue.put(whisper_chunks)
# discard the old part to save memory
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
def get_next_feat(self,block,timeout):
return self.feat_queue.get(block,timeout)

View File

@ -157,6 +157,7 @@ class MuseReal:
self.__loadavatar()
self.asr = MuseASR(opt,self.audio_processor)
self.asr.warm_up()
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)
elif opt.tts == "gpt-sovits":
@ -200,6 +201,11 @@ class MuseReal:
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def __mirror_index(self, index):
size = len(self.coord_list_cycle)
turn = index // size
@ -297,9 +303,12 @@ class MuseReal:
# print(f"------actual avg infer fps:{count/totaltime:.4f}")
# count=0
# totaltime=0
if video_track._queue.qsize()>=2*self.opt.batch_size:
if video_track._queue.qsize()>=1.5*self.opt.batch_size:
print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.04*self.opt.batch_size*1.5)
time.sleep(0.04*video_track._queue.qsize()*0.8)
# if video_track._queue.qsize()>=5:
# print('sleep qsize=',video_track._queue.qsize())
# time.sleep(0.04*video_track._queue.qsize()*0.8)
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:

View File

@ -20,9 +20,6 @@ class NeRFReal:
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.debug = debug
self.training = False
self.step = 0 # training step
self.trainer = trainer
self.data_loader = data_loader
@ -44,7 +41,6 @@ class NeRFReal:
#self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
# playing seq from dataloader, or pause.
self.playing = True #False todo
self.loader = iter(data_loader)
#self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
@ -62,7 +58,6 @@ class NeRFReal:
self.customimg_index = 0
# build asr
if self.opt.asr:
self.asr = ASR(opt)
self.asr.warm_up()
if opt.tts == "edgetts":
@ -124,6 +119,10 @@ class NeRFReal:
def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
self.asr.put_audio_frame(audio_chunk)
def pause_talk(self):
self.tts.pause_talk()
self.asr.pause_talk()
def mirror_index(self, index):
size = self.opt.customvideo_imgnum
@ -248,7 +247,6 @@ class NeRFReal:
# update texture every frame
# audio stream thread...
t = time.perf_counter()
if self.opt.asr and self.playing:
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
for _ in range(2):
self.asr.run_step()
@ -267,7 +265,7 @@ class NeRFReal:
else:
if video_track._queue.qsize()>=5:
#print('sleep qsize=',video_track._queue.qsize())
time.sleep(0.1)
time.sleep(0.04*video_track._queue.qsize()*0.8)
print('nerfreal thread stop')

View File

@ -13,6 +13,11 @@ import queue
from queue import Queue
from io import BytesIO
from threading import Thread, Event
from enum import Enum
class State(Enum):
RUNNING=0
PAUSE=1
class BaseTTS:
def __init__(self, opt, parent):
@ -25,6 +30,11 @@ class BaseTTS:
self.input_stream = BytesIO()
self.msgqueue = Queue()
self.state = State.RUNNING
def pause_talk(self):
self.msgqueue.queue.clear()
self.state = State.PAUSE
def put_msg_txt(self,msg):
self.msgqueue.put(msg)
@ -37,6 +47,7 @@ class BaseTTS:
while not quit_event.is_set():
try:
msg = self.msgqueue.get(block=True, timeout=1)
self.state=State.RUNNING
except queue.Empty:
continue
self.txt_to_audio(msg)
@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS):
stream = self.__create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
while streamlen >= self.chunk and self.state==State.RUNNING:
self.parent.put_audio_frame(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS):
async for chunk in communicate.stream():
if first:
first = False
if chunk["type"] == "audio":
if chunk["type"] == "audio" and self.state==State.RUNNING:
#self.push_audio(chunk["data"])
self.input_stream.write(chunk["data"])
#file.write(chunk["data"])
@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS):
end = time.perf_counter()
print(f"gpt_sovits Time to first chunk: {end-start}s")
first = False
if chunk:
if chunk and self.state==State.RUNNING:
yield chunk
print("gpt_sovits response.elapsed:", res.elapsed)

View File

@ -79,6 +79,7 @@
body: JSON.stringify({
text: message,
type: 'echo',
interrupt: true,
}),
headers: {
'Content-Type': 'application/json'