diff --git a/app.py b/app.py
index 396bf6a..83a60ce 100644
--- a/app.py
+++ b/app.py
@@ -118,6 +118,9 @@ async def offer(request):
 async def human(request):
     params = await request.json()
 
+    if params.get('interrupt'):
+        nerfreal.pause_talk()
+
     if params['type']=='echo':
         nerfreal.put_msg_txt(params['text'])
     elif params['type']=='chat':
diff --git a/asrreal.py b/asrreal.py
index b3e4093..62aa15e 100644
--- a/asrreal.py
+++ b/asrreal.py
@@ -4,29 +4,19 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
 
-#import pyaudio
-import soundfile as sf
-import resampy
 
 import queue
 from queue import Queue
 #from collections import deque
 from threading import Thread, Event
-from io import BytesIO
 
-class ASR:
+from baseasr import BaseASR
+
+class ASR(BaseASR):
     def __init__(self, opt):
-
-        self.opt = opt
-
-        self.play = opt.asr_play #false
+        super().__init__(opt)
 
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.mode = 'live' if opt.asr_wav == '' else 'file'
-
         if 'esperanto' in self.opt.asr_model:
             self.audio_dim = 44
         elif 'deepspeech' in self.opt.asr_model:
@@ -41,30 +31,11 @@ class ASR:
         self.context_size = opt.m
         self.stride_left_size = opt.l
         self.stride_right_size = opt.r
-        self.text = '[START]\n'
-        self.terminated = False
-        self.frames = []
-        self.inwarm = False
 
        # pad left frames
        if self.stride_left_size > 0:
            self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
-
-        self.exit_event = Event()
-        #self.audio_instance = pyaudio.PyAudio() #not need
-
-        # create input stream
-        self.queue = Queue()
-        self.output_queue = Queue()
-        # start a background process to read frames
-        #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
-        #self.queue = Queue()
-        #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
-
-        # current location of audio
-        self.idx = 0
-
 
        # create wav2vec model
        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
        if 'hubert' in self.opt.asr_model:
@@ -74,10 +45,6 @@ class ASR:
             self.processor = AutoProcessor.from_pretrained(opt.asr_model)
             self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
 
-        # prepare to save logits
-        if self.opt.asr_save_feats:
-            self.all_feats = []
-
         # the extracted features
         # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
         self.feat_buffer_size = 4
@@ -93,8 +60,16 @@ class ASR:
         # warm up steps needed: mid + right + window_size + attention_size
         self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
 
-        self.listening = False
-        self.playing = False
+    def get_audio_frame(self):
+        try:
+            frame = self.queue.get(block=False)
+            type = 0
+            #print(f'[INFO] get frame {frame.shape}')
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type = 1
+
+        return frame,type
 
     def get_next_feat(self): #get audio embedding to nerf
         # return a [1/8, 16] window, for the next input to nerf side.
@@ -136,29 +111,19 @@ class ASR:
 
     def run_step(self):
 
-        if self.terminated:
-            return
-        # get a frame of audio
-        frame,type = self.__get_audio_frame()
-
-        # the last frame
-        if frame is None:
-            # terminate, but always run the network for the left frames
-            self.terminated = True
-        else:
-            self.frames.append(frame)
-            # put to output
-            self.output_queue.put((frame,type))
-            # context not enough, do not run network.
-            if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
-                return
+        frame,type = self.get_audio_frame()
+        self.frames.append(frame)
+        # put to output
+        self.output_queue.put((frame,type))
+        # context not enough, do not run network.
+        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+            return
 
         inputs = np.concatenate(self.frames) # [N * chunk]
 
         # discard the old part to save memory
-        if not self.terminated:
-            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
 
         #print(f'[INFO] frame_to_text... ')
         #t = time.time()
@@ -166,10 +131,6 @@ class ASR:
         #print(f'-------wav2vec time:{time.time()-t:.4f}s')
         feats = logits # better lips-sync than labels
 
-        # save feats
-        if self.opt.asr_save_feats:
-            self.all_feats.append(feats)
-
         # record the feats efficiently.. (no concat, constant memory)
         start = self.feat_buffer_idx * self.context_size
         end = start + feats.shape[0]
@@ -203,24 +164,6 @@ class ASR:
         # np.save(output_path, unfold_feats.cpu().numpy())
         # print(f"[INFO] saved logits to {output_path}")
 
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        if self.inwarm: # warm up
-            return np.zeros(self.chunk, dtype=np.float32),1
-
-        try:
-            frame = self.queue.get(block=False)
-            type = 0
-            print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        self.idx = self.idx + self.chunk
-
-        return frame,type
 
 
     def __frame_to_text(self, frame):
@@ -241,8 +184,8 @@ class ASR:
         right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
 
         # do not cut right if terminated.
-        if self.terminated:
-            right = logits.shape[1]
+        # if self.terminated:
+        #     right = logits.shape[1]
 
         logits = logits[:, left:right]
 
@@ -262,10 +205,23 @@ class ASR:
         return logits[0], None,None #predicted_ids[0], transcription # [N,]
 
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
+
+    def warm_up(self):
+        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
+        t = time.time()
+        #for _ in range(self.stride_left_size):
+        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
+        for _ in range(self.warm_up_steps):
+            self.run_step()
+        #if torch.cuda.is_available():
+        #    torch.cuda.synchronize()
+        t = time.time() - t
+        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
+
+        #self.clear_queue()
+
+    #####not used function#####################################
+    '''
 
     def __init_queue(self):
         self.frames = []
         self.queue.queue.clear()
@@ -290,26 +246,6 @@ class ASR:
         if self.play:
             self.output_queue.queue.clear()
 
-    def warm_up(self):
-
-        #self.listen()
-
-        self.inwarm = True
-        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
-        t = time.time()
-        #for _ in range(self.stride_left_size):
-        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
-        for _ in range(self.warm_up_steps):
-            self.run_step()
-        #if torch.cuda.is_available():
-        #    torch.cuda.synchronize()
-        t = time.time() - t
-        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
-        self.inwarm = False
-
-        #self.clear_queue()
-
-    #####not used function#####################################
     def listen(self):
         # start
         if self.mode == 'live' and not self.listening:
@@ -404,4 +340,5 @@ if __name__ == '__main__':
         raise ValueError("DeepSpeech features should not use this code to extract...")
 
     with ASR(opt) as asr:
-        asr.run()
\ No newline at end of file
+        asr.run()
+'''
\ No newline at end of file
diff --git a/baseasr.py b/baseasr.py
new file mode 100644
index 0000000..df66873
--- /dev/null
+++ b/baseasr.py
@@ -0,0 +1,61 @@
+import time
+import numpy as np
+
+import queue
+from queue import Queue
+import multiprocessing as mp
+
+
+class BaseASR:
+    def __init__(self, opt):
+        self.opt = opt
+
+        self.fps = opt.fps # 20 ms per frame
+        self.sample_rate = 16000
+        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
+        self.queue = Queue()
+        self.output_queue = mp.Queue()
+
+        self.batch_size = opt.batch_size
+
+        self.frames = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
+        #self.context_size = 10
+        self.feat_queue = mp.Queue(2)
+
+        #self.warm_up()
+
+    def pause_talk(self):
+        self.queue.queue.clear()
+
+    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
+        self.queue.put(audio_chunk)
+
+    def get_audio_frame(self):
+        try:
+            frame = self.queue.get(block=True,timeout=0.01)
+            type = 0
+            #print(f'[INFO] get frame {frame.shape}')
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type = 1
+
+        return frame,type
+
+    def get_audio_out(self): #get origin audio pcm to nerf
+        return self.output_queue.get()
+
+    def warm_up(self):
+        for _ in range(self.stride_left_size + self.stride_right_size):
+            audio_frame,type=self.get_audio_frame()
+            self.frames.append(audio_frame)
+            self.output_queue.put((audio_frame,type))
+        for _ in range(self.stride_left_size):
+            self.output_queue.get()
+
+    def run_step(self):
+        pass
+
+    def get_next_feat(self,block,timeout):
+        return self.feat_queue.get(block,timeout)
\ No newline at end of file
diff --git a/lipasr.py b/lipasr.py
index 5742dd7..29948ac 100644
--- a/lipasr.py
+++ b/lipasr.py
@@ -6,60 +6,16 @@ import queue
 from queue import Queue
 import multiprocessing as mp
 
+from baseasr import BaseASR
 from wav2lip import audio
 
-class LipASR:
-    def __init__(self, opt):
-        self.opt = opt
-
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
-
-        #self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-        #self.context_size = 10
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()
+class LipASR(BaseASR):
 
     def run_step(self):
         ############################################## extract audio feature ##############################################
         # get a frame of audio
         for _ in range(self.batch_size*2):
-            frame,type = self.__get_audio_frame()
+            frame,type = self.get_audio_frame()
             self.frames.append(frame)
             # put to output
             self.output_queue.put((frame,type))
@@ -89,7 +45,3 @@ class LipASR:
 
         # discard the old part to save memory
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
-
-
-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)
\ No newline at end of file
diff --git a/lipreal.py b/lipreal.py
index d69f3dc..9461e7b 100644
--- a/lipreal.py
+++ b/lipreal.py
@@ -164,6 +164,7 @@ class LipReal:
         self.__loadavatar()
 
         self.asr = LipASR(opt)
+        self.asr.warm_up()
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":
@@ -199,6 +200,10 @@ class LipReal:
 
     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
         self.asr.put_audio_frame(audio_chunk)
+
+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
 
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
 
@@ -257,9 +262,12 @@ class LipReal:
             t = time.perf_counter()
             self.asr.run_step()
 
-            if video_track._queue.qsize()>=2*self.opt.batch_size:
+            # if video_track._queue.qsize()>=2*self.opt.batch_size:
+            #     print('sleep qsize=',video_track._queue.qsize())
+            #     time.sleep(0.04*video_track._queue.qsize()*0.8)
+            if video_track._queue.qsize()>=5:
                 print('sleep qsize=',video_track._queue.qsize())
-                time.sleep(0.04*self.opt.batch_size*1.5)
+                time.sleep(0.04*video_track._queue.qsize()*0.8)
 
             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
diff --git a/museasr.py b/museasr.py
index cfcd9ba..cb166a6 100644
--- a/museasr.py
+++ b/museasr.py
@@ -1,65 +1,22 @@
 import time
-import torch
 import numpy as np
 
 import queue
 from queue import Queue
 import multiprocessing as mp
-
+from baseasr import BaseASR
 from musetalk.whisper.audio2feature import Audio2Feature
 
-class MuseASR:
+class MuseASR(BaseASR):
     def __init__(self, opt, audio_processor:Audio2Feature):
-        self.opt = opt
-
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
-
+        super().__init__(opt)
         self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()
 
 
     def run_step(self):
         ############################################## extract audio feature ##############################################
         start_time = time.time()
         for _ in range(self.batch_size*2):
-            audio_frame,type=self.__get_audio_frame()
+            audio_frame,type=self.get_audio_frame()
             self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))
@@ -77,6 +34,3 @@ class MuseASR:
         self.feat_queue.put(whisper_chunks)
         # discard the old part to save memory
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
-
-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)
\ No newline at end of file
diff --git a/musereal.py b/musereal.py
index d92ee85..0f01dcf 100644
--- a/musereal.py
+++ b/musereal.py
@@ -157,6 +157,7 @@ class MuseReal:
         self.__loadavatar()
 
         self.asr = MuseASR(opt,self.audio_processor)
+        self.asr.warm_up()
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":
@@ -200,6 +201,11 @@ class MuseReal:
     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
         self.asr.put_audio_frame(audio_chunk)
 
+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
+
+
     def __mirror_index(self, index):
         size = len(self.coord_list_cycle)
         turn = index // size
@@ -297,9 +303,12 @@ class MuseReal:
             #     print(f"------actual avg infer fps:{count/totaltime:.4f}")
             #     count=0
             #     totaltime=0
-            if video_track._queue.qsize()>=2*self.opt.batch_size:
+            if video_track._queue.qsize()>=1.5*self.opt.batch_size:
                 print('sleep qsize=',video_track._queue.qsize())
-                time.sleep(0.04*self.opt.batch_size*1.5)
+                time.sleep(0.04*video_track._queue.qsize()*0.8)
+            # if video_track._queue.qsize()>=5:
+            #     print('sleep qsize=',video_track._queue.qsize())
+            #     time.sleep(0.04*video_track._queue.qsize()*0.8)
 
             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
diff --git a/nerfreal.py b/nerfreal.py
index 5ee2365..ef04c3e 100644
--- a/nerfreal.py
+++ b/nerfreal.py
@@ -20,9 +20,6 @@ class NeRFReal:
         self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
         self.H = opt.H
-        self.debug = debug
-        self.training = False
-        self.step = 0 # training step
 
         self.trainer = trainer
         self.data_loader = data_loader
@@ -44,7 +41,6 @@ class NeRFReal:
         #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
 
         # playing seq from dataloader, or pause.
-        self.playing = True #False todo
         self.loader = iter(data_loader)
 
         #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
@@ -62,9 +58,8 @@ class NeRFReal:
         self.customimg_index = 0
 
         # build asr
-        if self.opt.asr:
-            self.asr = ASR(opt)
-            self.asr.warm_up()
+        self.asr = ASR(opt)
+        self.asr.warm_up()
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":
@@ -122,7 +117,11 @@ class NeRFReal:
         self.tts.put_msg_txt(msg)
 
     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.asr.put_audio_frame(audio_chunk)
+        self.asr.put_audio_frame(audio_chunk)
+
+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
 
 
     def mirror_index(self, index):
@@ -248,10 +247,9 @@ class NeRFReal:
             # update texture every frame
             # audio stream thread...
             t = time.perf_counter()
-            if self.opt.asr and self.playing:
-                # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
-                for _ in range(2):
-                    self.asr.run_step()
+            # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
+            for _ in range(2):
+                self.asr.run_step()
             self.test_step(loop,audio_track,video_track)
             totaltime += (time.perf_counter() - t)
             count += 1
@@ -267,7 +265,7 @@ class NeRFReal:
             else:
                 if video_track._queue.qsize()>=5:
                     #print('sleep qsize=',video_track._queue.qsize())
-                    time.sleep(0.1)
+                    time.sleep(0.04*video_track._queue.qsize()*0.8)
 
 
         print('nerfreal thread stop')
\ No newline at end of file
diff --git a/ttsreal.py b/ttsreal.py
index 22e2909..e5c2c7f 100644
--- a/ttsreal.py
+++ b/ttsreal.py
@@ -13,6 +13,11 @@ import queue
 from queue import Queue
 from io import BytesIO
 from threading import Thread, Event
+from enum import Enum
+
+class State(Enum):
+    RUNNING=0
+    PAUSE=1
 
 class BaseTTS:
     def __init__(self, opt, parent):
@@ -25,6 +30,11 @@ class BaseTTS:
 
         self.input_stream = BytesIO()
         self.msgqueue = Queue()
+        self.state = State.RUNNING
+
+    def pause_talk(self):
+        self.msgqueue.queue.clear()
+        self.state = State.PAUSE
 
     def put_msg_txt(self,msg):
         self.msgqueue.put(msg)
@@ -37,6 +47,7 @@ class BaseTTS:
         while not quit_event.is_set():
             try:
                 msg = self.msgqueue.get(block=True, timeout=1)
+                self.state=State.RUNNING
             except queue.Empty:
                 continue
             self.txt_to_audio(msg)
@@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS):
         stream = self.__create_bytes_stream(self.input_stream)
         streamlen = stream.shape[0]
         idx=0
-        while streamlen >= self.chunk:
+        while streamlen >= self.chunk and self.state==State.RUNNING:
             self.parent.put_audio_frame(stream[idx:idx+self.chunk])
             streamlen -= self.chunk
             idx += self.chunk
@@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS):
         async for chunk in communicate.stream():
             if first:
                 first = False
-            if chunk["type"] == "audio":
+            if chunk["type"] == "audio" and self.state==State.RUNNING:
                 #self.push_audio(chunk["data"])
                 self.input_stream.write(chunk["data"])
                 #file.write(chunk["data"])
@@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS):
                 end = time.perf_counter()
                 print(f"gpt_sovits Time to first chunk: {end-start}s")
                 first = False
-            if chunk:
+            if chunk and self.state==State.RUNNING:
                 yield chunk
 
         print("gpt_sovits response.elapsed:", res.elapsed)
diff --git a/web/webrtcapi.html b/web/webrtcapi.html
index 16e0860..eff0287 100644
--- a/web/webrtcapi.html
+++ b/web/webrtcapi.html
@@ -79,6 +79,7 @@
                     body: JSON.stringify({
                         text: message,
                         type: 'echo',
+                        interrupt: true,
                     }),
                     headers: {
                         'Content-Type': 'application/json'
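
Usage note (not part of the patch): a minimal client-side sketch of the new interrupt flow. The /human endpoint, the 'echo' message type and the interrupt flag come from app.py and web/webrtcapi.html above; the base URL and port, the use of the requests library, and the say() helper name are assumptions for illustration only.

# Hypothetical sketch: send text to the avatar and cut off whatever it is
# currently saying. Adjust BASE_URL to wherever app.py is actually listening.
import requests

BASE_URL = "http://127.0.0.1:8010"  # assumed host/port, not defined by this patch

def say(text: str, interrupt: bool = True) -> None:
    # interrupt=True makes app.py call nerfreal.pause_talk(), which clears the
    # queued TTS text (BaseTTS.pause_talk) and the pending audio frames
    # (BaseASR.pause_talk) before the new message is enqueued.
    resp = requests.post(
        f"{BASE_URL}/human",
        json={"text": text, "type": "echo", "interrupt": interrupt},
    )
    resp.raise_for_status()

if __name__ == "__main__":
    say("Hello, this replaces whatever the avatar was saying.")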