wrapper class baseasr; add talk interrupt

lipku 2024-06-30 09:41:31 +08:00
parent 98eeeb17af
commit 9fe4c7fccf
10 changed files with 161 additions and 227 deletions

app.py
View File

@@ -118,6 +118,9 @@ async def offer(request):
 async def human(request):
     params = await request.json()

+    if params.get('interrupt'):
+        nerfreal.pause_talk()
+
     if params['type']=='echo':
         nerfreal.put_msg_txt(params['text'])
     elif params['type']=='chat':
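For reference, a minimal sketch of how a client can drive the new flag (the host, port and the /human route are assumptions inferred from the handler above, not spelled out in this hunk):

# Hypothetical client: interrupt whatever the avatar is currently saying,
# then queue new text. Host and port are placeholders.
import requests

def say(text, interrupt=True):
    # interrupt=True makes the handler call nerfreal.pause_talk(), which clears
    # the pending TTS messages and buffered audio before the new text is queued.
    return requests.post(
        "http://127.0.0.1:8010/human",
        json={"text": text, "type": "echo", "interrupt": interrupt},
    )

say("Stop the previous sentence and say this instead.")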

View File

@@ -4,29 +4,19 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel

-#import pyaudio
-import soundfile as sf
-import resampy

 import queue
 from queue import Queue
 #from collections import deque
 from threading import Thread, Event
-from io import BytesIO
+from baseasr import BaseASR

-class ASR:
+class ASR(BaseASR):
     def __init__(self, opt):
+        super().__init__(opt)
-        self.opt = opt
-        self.play = opt.asr_play #false
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.mode = 'live' if opt.asr_wav == '' else 'file'

         if 'esperanto' in self.opt.asr_model:
             self.audio_dim = 44
         elif 'deepspeech' in self.opt.asr_model:
@@ -41,30 +31,11 @@ class ASR:
         self.context_size = opt.m
         self.stride_left_size = opt.l
         self.stride_right_size = opt.r
-        self.text = '[START]\n'
-        self.terminated = False
-        self.frames = []
-        self.inwarm = False

         # pad left frames
         if self.stride_left_size > 0:
             self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)

-        self.exit_event = Event()
-        #self.audio_instance = pyaudio.PyAudio() #not need
-
-        # create input stream
-        self.queue = Queue()
-        self.output_queue = Queue()
-        # start a background process to read frames
-        #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
-        #self.queue = Queue()
-        #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
-
-        # current location of audio
-        self.idx = 0

         # create wav2vec model
         print(f'[INFO] loading ASR model {self.opt.asr_model}...')
         if 'hubert' in self.opt.asr_model:

@@ -74,10 +45,6 @@ class ASR:
             self.processor = AutoProcessor.from_pretrained(opt.asr_model)
             self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

-        # prepare to save logits
-        if self.opt.asr_save_feats:
-            self.all_feats = []
-
         # the extracted features
         # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
         self.feat_buffer_size = 4
@@ -93,8 +60,16 @@ class ASR:
         # warm up steps needed: mid + right + window_size + attention_size
         self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3

-        self.listening = False
-        self.playing = False
+    def get_audio_frame(self):
+        try:
+            frame = self.queue.get(block=False)
+            type = 0
+            #print(f'[INFO] get frame {frame.shape}')
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type = 1
+
+        return frame,type

     def get_next_feat(self): #get audio embedding to nerf
         # return a [1/8, 16] window, for the next input to nerf side.
@@ -136,29 +111,19 @@ class ASR:
     def run_step(self):

-        if self.terminated:
-            return
-
         # get a frame of audio
-        frame,type = self.__get_audio_frame()
-
-        # the last frame
-        if frame is None:
-            # terminate, but always run the network for the left frames
-            self.terminated = True
-        else:
-            self.frames.append(frame)
-            # put to output
-            self.output_queue.put((frame,type))
-            # context not enough, do not run network.
-            if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
-                return
+        frame,type = self.get_audio_frame()
+        self.frames.append(frame)
+        # put to output
+        self.output_queue.put((frame,type))
+        # context not enough, do not run network.
+        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+            return

         inputs = np.concatenate(self.frames) # [N * chunk]

         # discard the old part to save memory
-        if not self.terminated:
-            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

         #print(f'[INFO] frame_to_text... ')
         #t = time.time()
@@ -166,10 +131,6 @@ class ASR:
         #print(f'-------wav2vec time:{time.time()-t:.4f}s')
         feats = logits # better lips-sync than labels

-        # save feats
-        if self.opt.asr_save_feats:
-            self.all_feats.append(feats)
-
         # record the feats efficiently.. (no concat, constant memory)
         start = self.feat_buffer_idx * self.context_size
         end = start + feats.shape[0]

@@ -203,24 +164,6 @@ class ASR:
         # np.save(output_path, unfold_feats.cpu().numpy())
         # print(f"[INFO] saved logits to {output_path}")

-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        if self.inwarm: # warm up
-            return np.zeros(self.chunk, dtype=np.float32),1
-
-        try:
-            frame = self.queue.get(block=False)
-            type = 0
-            print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        self.idx = self.idx + self.chunk
-
-        return frame,type
-
     def __frame_to_text(self, frame):
@@ -241,8 +184,8 @@ class ASR:
             right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.

             # do not cut right if terminated.
-            if self.terminated:
-                right = logits.shape[1]
+            # if self.terminated:
+            #     right = logits.shape[1]

             logits = logits[:, left:right]

@@ -263,9 +206,22 @@ class ASR:
         return logits[0], None,None #predicted_ids[0], transcription # [N,]

-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
+    def warm_up(self):
+        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
+        t = time.time()
+        #for _ in range(self.stride_left_size):
+        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
+        for _ in range(self.warm_up_steps):
+            self.run_step()
+        #if torch.cuda.is_available():
+        #    torch.cuda.synchronize()
+        t = time.time() - t
+        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
+        #self.clear_queue()
+
+    #####not used function#####################################
+    '''
     def __init_queue(self):
         self.frames = []
         self.queue.queue.clear()
@@ -290,26 +246,6 @@ class ASR:
         if self.play:
             self.output_queue.queue.clear()

-    def warm_up(self):
-        #self.listen()
-        self.inwarm = True
-        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
-        t = time.time()
-        #for _ in range(self.stride_left_size):
-        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
-        for _ in range(self.warm_up_steps):
-            self.run_step()
-        #if torch.cuda.is_available():
-        #    torch.cuda.synchronize()
-        t = time.time() - t
-        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
-        self.inwarm = False
-        #self.clear_queue()
-
-    #####not used function#####################################
     def listen(self):
         # start
         if self.mode == 'live' and not self.listening:

@@ -405,3 +341,4 @@ if __name__ == '__main__':

     with ASR(opt) as asr:
         asr.run()
+    '''
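As a rough illustration of the expected-latency figure printed by the relocated warm_up(), here is the arithmetic with example option values (l, m, r and fps below are illustrative assumptions, not values pinned down by this diff):

# warm_up() runs run_step() once per warm-up step, and each step consumes one
# 20 ms audio chunk, so the expected latency is simply warm_up_steps / fps.
l, m, r = 10, 50, 10          # stride_left_size, context_size, stride_right_size (example values)
fps = 50                      # 50 audio frames per second -> 20 ms per frame
warm_up_steps = m + l + r     # matches self.warm_up_steps above
print(warm_up_steps / fps)    # 1.4 seconds of audio buffered before the first real inference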

baseasr.py Normal file
View File

@@ -0,0 +1,61 @@
+import time
+import numpy as np
+
+import queue
+from queue import Queue
+import multiprocessing as mp
+
+
+class BaseASR:
+    def __init__(self, opt):
+        self.opt = opt
+
+        self.fps = opt.fps # 20 ms per frame
+        self.sample_rate = 16000
+        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
+        self.queue = Queue()
+        self.output_queue = mp.Queue()
+
+        self.batch_size = opt.batch_size
+
+        self.frames = []
+        self.stride_left_size = opt.l
+        self.stride_right_size = opt.r
+        #self.context_size = 10
+
+        self.feat_queue = mp.Queue(2)
+
+        #self.warm_up()
+
+    def pause_talk(self):
+        self.queue.queue.clear()
+
+    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
+        self.queue.put(audio_chunk)
+
+    def get_audio_frame(self):
+        try:
+            frame = self.queue.get(block=True,timeout=0.01)
+            type = 0
+            #print(f'[INFO] get frame {frame.shape}')
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type = 1
+
+        return frame,type
+
+    def get_audio_out(self): #get origin audio pcm to nerf
+        return self.output_queue.get()
+
+    def warm_up(self):
+        for _ in range(self.stride_left_size + self.stride_right_size):
+            audio_frame,type=self.get_audio_frame()
+            self.frames.append(audio_frame)
+            self.output_queue.put((audio_frame,type))
+        for _ in range(self.stride_left_size):
+            self.output_queue.get()
+
+    def run_step(self):
+        pass
+
+    def get_next_feat(self,block,timeout):
+        return self.feat_queue.get(block,timeout)
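A minimal sketch of how the new base class is meant to be extended (the subclass, the stand-in Opt object and the fake "feature" computation are illustrative only; LipASR, MuseASR and the NeRF ASR in this commit are the real subclasses):

# Hypothetical subclass: pull one batch of 20 ms PCM chunks from the input queue,
# turn them into "features", and hand both audio and features downstream.
import numpy as np
from baseasr import BaseASR

class DummyASR(BaseASR):
    def run_step(self):
        for _ in range(self.batch_size * 2):
            frame, type = self.get_audio_frame()   # (320,) float32; type 1 means silence filler
            self.frames.append(frame)
            self.output_queue.put((frame, type))   # original PCM goes on to the A/V sender
        feats = np.stack(self.frames)              # placeholder for the real feature extractor
        self.feat_queue.put(feats)                 # consumed later via get_next_feat(block, timeout)
        # keep only the left/right stride as context for the next step
        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

class Opt:  # minimal stand-in for the real option object
    fps, batch_size, l, r = 50, 4, 10, 10

asr = DummyASR(Opt())
asr.put_audio_frame(np.zeros(320, dtype=np.float32))
asr.run_step()
print(asr.get_audio_out()[1])             # 0: the frame we actually fed in
print(asr.get_next_feat(True, 1).shape)   # (8, 320) with the placeholder "features"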

View File

@@ -6,60 +6,16 @@ import queue
 from queue import Queue
 import multiprocessing as mp

+from baseasr import BaseASR
 from wav2lip import audio

-class LipASR:
-    def __init__(self, opt):
-        self.opt = opt
-
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
-
-        #self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-        #self.context_size = 10
-
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()
+class LipASR(BaseASR):

     def run_step(self):
         ############################################## extract audio feature ##############################################
         # get a frame of audio
         for _ in range(self.batch_size*2):
-            frame,type = self.__get_audio_frame()
+            frame,type = self.get_audio_frame()
             self.frames.append(frame)
             # put to output
             self.output_queue.put((frame,type))

@@ -89,7 +45,3 @@ class LipASR:
         # discard the old part to save memory
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
-
-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)

View File

@@ -164,6 +164,7 @@ class LipReal:
         self.__loadavatar()

         self.asr = LipASR(opt)
+        self.asr.warm_up()
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":

@@ -200,6 +201,10 @@ class LipReal:
     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
         self.asr.put_audio_frame(audio_chunk)

+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
+
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):

@@ -257,9 +262,12 @@ class LipReal:
             t = time.perf_counter()
             self.asr.run_step()

-            if video_track._queue.qsize()>=2*self.opt.batch_size:
+            # if video_track._queue.qsize()>=2*self.opt.batch_size:
+            #     print('sleep qsize=',video_track._queue.qsize())
+            #     time.sleep(0.04*video_track._queue.qsize()*0.8)
+            if video_track._queue.qsize()>=5:
                 print('sleep qsize=',video_track._queue.qsize())
-                time.sleep(0.04*self.opt.batch_size*1.5)
+                time.sleep(0.04*video_track._queue.qsize()*0.8)

             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
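The replaced sleep turns a fixed-length pause into one that scales with the sender's backlog. A small sketch of the arithmetic (40 ms per frame comes from the 25 fps video; 0.8 is the damping factor used above):

# If the WebRTC video track already has qsize frames buffered, sleep for roughly
# 80% of the playout time those frames represent, instead of the old constant
# 0.04 * batch_size * 1.5 seconds.
def backpressure_sleep_seconds(qsize, frame_time=0.04, damping=0.8):
    return frame_time * qsize * damping

print(backpressure_sleep_seconds(5))    # about 0.16 s when 5 frames are queued
print(backpressure_sleep_seconds(10))   # about 0.32 s when the backlog doubles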

View File

@@ -1,65 +1,22 @@
 import time
-import torch
 import numpy as np

 import queue
 from queue import Queue
 import multiprocessing as mp

+from baseasr import BaseASR
 from musetalk.whisper.audio2feature import Audio2Feature

-class MuseASR:
+class MuseASR(BaseASR):
     def __init__(self, opt, audio_processor:Audio2Feature):
-        self.opt = opt
-
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.queue = Queue()
-        # self.input_stream = BytesIO()
-        self.output_queue = mp.Queue()
-
+        super().__init__(opt)
         self.audio_processor = audio_processor
-        self.batch_size = opt.batch_size
-
-        self.frames = []
-        self.stride_left_size = opt.l
-        self.stride_right_size = opt.r
-
-        self.feat_queue = mp.Queue(5)
-
-        self.warm_up()
-
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=True,timeout=0.01)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        return frame,type
-
-    def get_audio_out(self): #get origin audio pcm to nerf
-        return self.output_queue.get()
-
-    def warm_up(self):
-        for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.__get_audio_frame()
-            self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
-        for _ in range(self.stride_left_size):
-            self.output_queue.get()

     def run_step(self):
         ############################################## extract audio feature ##############################################
         start_time = time.time()
         for _ in range(self.batch_size*2):
-            audio_frame,type=self.__get_audio_frame()
+            audio_frame,type=self.get_audio_frame()
             self.frames.append(audio_frame)
             self.output_queue.put((audio_frame,type))

@@ -77,6 +34,3 @@ class MuseASR:
         self.feat_queue.put(whisper_chunks)
         # discard the old part to save memory
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
-
-    def get_next_feat(self,block,timeout):
-        return self.feat_queue.get(block,timeout)

View File

@@ -157,6 +157,7 @@ class MuseReal:
         self.__loadavatar()

         self.asr = MuseASR(opt,self.audio_processor)
+        self.asr.warm_up()
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":

@@ -200,6 +201,11 @@ class MuseReal:
     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
         self.asr.put_audio_frame(audio_chunk)

+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
+
     def __mirror_index(self, index):
         size = len(self.coord_list_cycle)
         turn = index // size

@@ -297,9 +303,12 @@ class MuseReal:
             #     print(f"------actual avg infer fps:{count/totaltime:.4f}")
             #     count=0
             #     totaltime=0
-            if video_track._queue.qsize()>=2*self.opt.batch_size:
+            if video_track._queue.qsize()>=1.5*self.opt.batch_size:
                 print('sleep qsize=',video_track._queue.qsize())
-                time.sleep(0.04*self.opt.batch_size*1.5)
+                time.sleep(0.04*video_track._queue.qsize()*0.8)
+            # if video_track._queue.qsize()>=5:
+            #     print('sleep qsize=',video_track._queue.qsize())
+            #     time.sleep(0.04*video_track._queue.qsize()*0.8)

             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:

View File

@@ -20,9 +20,6 @@ class NeRFReal:
         self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
         self.H = opt.H
-        self.debug = debug
-        self.training = False
-        self.step = 0 # training step

         self.trainer = trainer
         self.data_loader = data_loader

@@ -44,7 +41,6 @@ class NeRFReal:
         #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()

         # playing seq from dataloader, or pause.
-        self.playing = True #False todo
         self.loader = iter(data_loader)

         #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)

@@ -62,9 +58,8 @@ class NeRFReal:
         self.customimg_index = 0

         # build asr
-        if self.opt.asr:
-            self.asr = ASR(opt)
-            self.asr.warm_up()
+        self.asr = ASR(opt)
+        self.asr.warm_up()

         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
         elif opt.tts == "gpt-sovits":

@@ -124,6 +119,10 @@ class NeRFReal:
     def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
         self.asr.put_audio_frame(audio_chunk)

+    def pause_talk(self):
+        self.tts.pause_talk()
+        self.asr.pause_talk()
+
     def mirror_index(self, index):
         size = self.opt.customvideo_imgnum

@@ -248,10 +247,9 @@ class NeRFReal:
             # update texture every frame
             # audio stream thread...
             t = time.perf_counter()
-            if self.opt.asr and self.playing:
-                # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
-                for _ in range(2):
-                    self.asr.run_step()
+            # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
+            for _ in range(2):
+                self.asr.run_step()
             self.test_step(loop,audio_track,video_track)
             totaltime += (time.perf_counter() - t)
             count += 1

@@ -267,7 +265,7 @@ class NeRFReal:
             else:
                 if video_track._queue.qsize()>=5:
                     #print('sleep qsize=',video_track._queue.qsize())
-                    time.sleep(0.1)
+                    time.sleep(0.04*video_track._queue.qsize()*0.8)
         print('nerfreal thread stop')
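Taken together, the interrupt path added in this commit is: /human with interrupt=true calls pause_talk() on the active renderer, which clears the TTS text queue (and flips it to State.PAUSE so in-flight synthesis stops feeding audio) and drops the ASR's buffered PCM. A condensed, runnable mock of that call chain (the real classes are NeRFReal/MuseReal/LipReal, BaseTTS and BaseASR above; the queues are stubbed with plain lists):

class MockTTS:
    def __init__(self):
        self.msgs = ["pending sentence"]   # stands in for BaseTTS.msgqueue
        self.state = "RUNNING"
    def pause_talk(self):
        self.msgs.clear()                  # msgqueue.queue.clear()
        self.state = "PAUSE"               # State.PAUSE: streaming loops stop pushing audio

class MockASR:
    def __init__(self):
        self.pcm_chunks = [b"\x00" * 640]  # stands in for BaseASR.queue (placeholder bytes)
    def pause_talk(self):
        self.pcm_chunks.clear()            # queue.queue.clear(): drop already-buffered audio

class MockReal:                            # NeRFReal / MuseReal / LipReal
    def __init__(self):
        self.tts, self.asr = MockTTS(), MockASR()
    def pause_talk(self):
        self.tts.pause_talk()              # stop queuing and synthesizing new text
        self.asr.pause_talk()              # drop PCM that is already waiting to be rendered

real = MockReal()
real.pause_talk()
print(real.tts.state, real.tts.msgs, real.asr.pcm_chunks)   # PAUSE [] []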

View File

@@ -13,6 +13,11 @@ import queue
 from queue import Queue
 from io import BytesIO
 from threading import Thread, Event
+from enum import Enum
+
+class State(Enum):
+    RUNNING=0
+    PAUSE=1

 class BaseTTS:
     def __init__(self, opt, parent):

@@ -25,6 +30,11 @@ class BaseTTS:
         self.input_stream = BytesIO()
         self.msgqueue = Queue()
+        self.state = State.RUNNING
+
+    def pause_talk(self):
+        self.msgqueue.queue.clear()
+        self.state = State.PAUSE

     def put_msg_txt(self,msg):
         self.msgqueue.put(msg)

@@ -37,6 +47,7 @@ class BaseTTS:
         while not quit_event.is_set():
             try:
                 msg = self.msgqueue.get(block=True, timeout=1)
+                self.state=State.RUNNING
             except queue.Empty:
                 continue
             self.txt_to_audio(msg)

@@ -59,7 +70,7 @@ class EdgeTTS(BaseTTS):
         stream = self.__create_bytes_stream(self.input_stream)
         streamlen = stream.shape[0]
         idx=0
-        while streamlen >= self.chunk:
+        while streamlen >= self.chunk and self.state==State.RUNNING:
             self.parent.put_audio_frame(stream[idx:idx+self.chunk])
             streamlen -= self.chunk
             idx += self.chunk

@@ -92,7 +103,7 @@ class EdgeTTS(BaseTTS):
             async for chunk in communicate.stream():
                 if first:
                     first = False
-                if chunk["type"] == "audio":
+                if chunk["type"] == "audio" and self.state==State.RUNNING:
                     #self.push_audio(chunk["data"])
                     self.input_stream.write(chunk["data"])
                     #file.write(chunk["data"])

@@ -147,7 +158,7 @@ class VoitsTTS(BaseTTS):
                     end = time.perf_counter()
                     print(f"gpt_sovits Time to first chunk: {end-start}s")
                     first = False
-                if chunk:
+                if chunk and self.state==State.RUNNING:
                     yield chunk
             print("gpt_sovits response.elapsed:", res.elapsed)

View File

@@ -79,6 +79,7 @@
                 body: JSON.stringify({
                     text: message,
                     type: 'echo',
+                    interrupt: true,
                 }),
                 headers: {
                     'Content-Type': 'application/json'