reduce time delay; support audio attention choice

This commit is contained in:
lipku 2024-04-05 08:55:21 +08:00
parent 250cbaa587
commit 9cdd6fcadf
4 changed files with 232 additions and 202 deletions

View File

@ -133,9 +133,9 @@ srs和nginx的运行同2.1和2.3
在Tesla T4显卡上测试整体fps为18左右如果去掉音视频编码推流帧率在20左右。用4090显卡可以达到40多帧/秒。
优化:新开一个线程运行音视频编码推流
2. 延时
整体延时5s多
1tts延时2s左右目前用的edgetts需要将每句话转完后一次性输入可以优化tts改成流式输入
2wav2vec延时1s多需要缓存50帧音频做计算可以通过-m设置context_size来减少延时
整体延时3s左右
1tts延时1.7s左右目前用的edgetts需要将每句话转完后一次性输入可以优化tts改成流式输入
2wav2vec延时0.4s需要缓存18帧音频做计算
3srs转发延时设置srs服务器减少缓冲延时。具体配置可看 https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency, 配置了一个低延时版本
```python
docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/lipku/srs:v1.1

6
app.py
View File

@ -37,7 +37,11 @@ async def main(voicename: str, text: str, render):
communicate = edge_tts.Communicate(text, voicename)
#with open(OUTPUT_FILE, "wb") as file:
first = True
async for chunk in communicate.stream():
if first:
#render.before_push_audio()
first = False
if chunk["type"] == "audio":
render.push_audio(chunk["data"])
#file.write(chunk["data"])
@ -258,7 +262,7 @@ if __name__ == '__main__':
parser.add_argument('--fps', type=int, default=50)
# sliding window left-middle-right length (unit: 20ms)
parser.add_argument('-l', type=int, default=10)
parser.add_argument('-m', type=int, default=50)
parser.add_argument('-m', type=int, default=8)
parser.add_argument('-r', type=int, default=10)
parser.add_argument('--fullbody', action='store_true', help="fullbody human")

View File

@ -122,58 +122,34 @@ class ASR:
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
# warm up steps needed: mid + right + window_size + attention_size
self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3
self.warm_up_steps = self.context_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
self.listening = False
self.playing = False
def listen(self):
# start
if self.mode == 'live' and not self.listening:
print(f'[INFO] starting read frame thread...')
self.process_read_frame.start()
self.listening = True
if self.play and not self.playing:
print(f'[INFO] starting play frame thread...')
self.process_play_frame.start()
self.playing = True
def stop(self):
self.exit_event.set()
if self.play:
self.output_stream.stop_stream()
self.output_stream.close()
if self.playing:
self.process_play_frame.join()
self.playing = False
if self.mode == 'live':
#self.input_stream.stop_stream() todo
self.input_stream.close()
if self.listening:
self.process_read_frame.join()
self.listening = False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
if self.mode == 'live':
# live mode: also print the result text.
self.text += '\n[END]'
print(self.text)
def get_next_feat(self):
def get_next_feat(self): #get audio embedding to nerf
# return a [1/8, 16] window, for the next input to nerf side.
if self.opt.att>0:
while len(self.att_feats) < 8:
# [------f+++t-----]
if self.front < self.tail:
feat = self.feat_queue[self.front:self.tail]
# [++t-----------f+]
else:
feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
while len(self.att_feats) < 8:
self.front = (self.front + 2) % self.feat_queue.shape[0]
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
# print(self.front, self.tail, feat.shape)
self.att_feats.append(feat.permute(1, 0))
att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
# discard old
self.att_feats = self.att_feats[1:]
else:
# [------f+++t-----]
if self.front < self.tail:
feat = self.feat_queue[self.front:self.tail]
@ -184,14 +160,8 @@ class ASR:
self.front = (self.front + 2) % self.feat_queue.shape[0]
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
# print(self.front, self.tail, feat.shape)
att_feat = feat.permute(1, 0).unsqueeze(0)
self.att_feats.append(feat.permute(1, 0))
att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
# discard old
self.att_feats = self.att_feats[1:]
return att_feat
@ -201,7 +171,7 @@ class ASR:
return
# get a frame of audio
frame = self.get_audio_frame()
frame = self.__get_audio_frame()
# the last frame
if frame is None:
@ -223,7 +193,7 @@ class ASR:
print(f'[INFO] frame_to_text... ')
#t = time.time()
logits, labels, text = self.frame_to_text(inputs)
logits, labels, text = self.__frame_to_text(inputs)
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
feats = logits # better lips-sync than labels
@ -264,6 +234,166 @@ class ASR:
np.save(output_path, unfold_feats.cpu().numpy())
print(f"[INFO] saved logits to {output_path}")
def __get_audio_frame(self):
if self.inwarm: # warm up
return np.zeros(self.chunk, dtype=np.float32)
if self.mode == 'file':
if self.idx < self.file_stream.shape[0]:
frame = self.file_stream[self.idx: self.idx + self.chunk]
self.idx = self.idx + self.chunk
return frame
else:
return None
else:
try:
frame = self.queue.get(block=False)
print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
self.idx = self.idx + self.chunk
return frame
def __frame_to_text(self, frame):
# frame: [N * 320], N = (context_size + 2 * stride_size)
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
with torch.no_grad():
result = self.model(inputs.input_values.to(self.device))
if 'hubert' in self.opt.asr_model:
logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024]
else:
logits = result.logits # [1, N - 1, 32]
#print('logits.shape:',logits.shape)
# cut off stride
left = max(0, self.stride_left_size)
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
# do not cut right if terminated.
if self.terminated:
right = logits.shape[1]
logits = logits[:, left:right]
# print(frame.shape, inputs.input_values.shape, logits.shape)
#predicted_ids = torch.argmax(logits, dim=-1)
#transcription = self.processor.batch_decode(predicted_ids)[0].lower()
# for esperanto
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '', 'fi', 'l', 'p', '', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
# print(predicted_ids[0])
# print(transcription)
return logits[0], None,None #predicted_ids[0], transcription # [N,]
def __create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
def push_audio(self,buffer): #push audio pcm from tts
print(f'[INFO] push_audio {len(buffer)}')
if self.opt.tts == "xtts":
if len(buffer)>0:
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.__create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
# if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
else: #edge tts
self.input_stream.write(buffer)
if len(buffer)<=0:
self.input_stream.seek(0)
stream = self.__create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
#if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self.input_stream.seek(0)
self.input_stream.truncate()
def get_audio_out(self): #get origin audio pcm to nerf
return self.output_queue.get()
def __init_queue(self):
self.frames = []
self.queue.queue.clear()
self.output_queue.queue.clear()
self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
self.tail = 8
# attention window...
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
def before_push_audio(self):
self.__init_queue()
self.warm_up()
def run(self):
self.listen()
while not self.terminated:
self.run_step()
def clear_queue(self):
# clear the queue, to reduce potential latency...
print(f'[INFO] clear queue')
if self.mode == 'live':
self.queue.queue.clear()
if self.play:
self.output_queue.queue.clear()
def warm_up(self):
#self.listen()
self.inwarm = True
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
for _ in range(self.stride_left_size):
self.frames.append(np.zeros(self.chunk, dtype=np.float32))
for _ in range(self.warm_up_steps):
self.run_step()
#if torch.cuda.is_available():
# torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
self.inwarm = False
#self.clear_queue()
'''
def create_file_stream(self):
@ -311,157 +441,50 @@ class ASR:
return audio, stream
'''
#####not used function#####################################
def listen(self):
# start
if self.mode == 'live' and not self.listening:
print(f'[INFO] starting read frame thread...')
self.process_read_frame.start()
self.listening = True
def get_audio_frame(self):
if self.play and not self.playing:
print(f'[INFO] starting play frame thread...')
self.process_play_frame.start()
self.playing = True
if self.inwarm: # warm up
return np.zeros(self.chunk, dtype=np.float32)
def stop(self):
if self.mode == 'file':
self.exit_event.set()
if self.idx < self.file_stream.shape[0]:
frame = self.file_stream[self.idx: self.idx + self.chunk]
self.idx = self.idx + self.chunk
return frame
else:
return None
else:
try:
frame = self.queue.get(block=False)
print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
self.idx = self.idx + self.chunk
return frame
def frame_to_text(self, frame):
# frame: [N * 320], N = (context_size + 2 * stride_size)
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
with torch.no_grad():
result = self.model(inputs.input_values.to(self.device))
if 'hubert' in self.opt.asr_model:
logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024]
else:
logits = result.logits # [1, N - 1, 32]
#print('logits.shape:',logits.shape)
# cut off stride
left = max(0, self.stride_left_size)
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
# do not cut right if terminated.
if self.terminated:
right = logits.shape[1]
logits = logits[:, left:right]
# print(frame.shape, inputs.input_values.shape, logits.shape)
#predicted_ids = torch.argmax(logits, dim=-1)
#transcription = self.processor.batch_decode(predicted_ids)[0].lower()
# for esperanto
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '', 'fi', 'l', 'p', '', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
# print(predicted_ids[0])
# print(transcription)
return logits[0], None,None #predicted_ids[0], transcription # [N,]
def create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate and stream.shape[0]>0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
def push_audio(self,buffer):
print(f'[INFO] push_audio {len(buffer)}')
if self.opt.tts == "xtts":
if len(buffer)>0:
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
#byte_stream=BytesIO(buffer)
#stream = self.create_bytes_stream(byte_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
# if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
else: #edge tts
self.input_stream.write(buffer)
if len(buffer)<=0:
self.input_stream.seek(0)
stream = self.create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
#if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self.input_stream.seek(0)
self.input_stream.truncate()
def get_audio_out(self):
return self.output_queue.get()
def run(self):
self.listen()
while not self.terminated:
self.run_step()
def clear_queue(self):
# clear the queue, to reduce potential latency...
print(f'[INFO] clear queue')
if self.mode == 'live':
self.queue.queue.clear()
if self.play:
self.output_queue.queue.clear()
self.output_stream.stop_stream()
self.output_stream.close()
if self.playing:
self.process_play_frame.join()
self.playing = False
def warm_up(self):
#self.listen()
self.inwarm = True
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
for _ in range(self.warm_up_steps):
self.run_step()
if torch.cuda.is_available():
torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
self.inwarm = False
#self.clear_queue()
if self.mode == 'live':
#self.input_stream.stop_stream() todo
self.input_stream.close()
if self.listening:
self.process_read_frame.join()
self.listening = False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
if self.mode == 'live':
# live mode: also print the result text.
self.text += '\n[END]'
print(self.text)
#########################################################
if __name__ == '__main__':
import argparse

View File

@ -109,6 +109,9 @@ class NeRFReal:
def push_audio(self,chunk):
self.asr.push_audio(chunk)
def before_push_audio(self):
self.asr.before_push_audio()
def prepare_buffer(self, outputs):
if self.mode == 'image':
return outputs['image']