fix audiostream buffer

lihengzhong committed 2024-01-09 10:01:50 +08:00
parent 6a7af5d006
commit 46b0f5abab
2 changed files with 663 additions and 655 deletions
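The buffer fix itself is small: once the bytes accumulated in the `BytesIO` input stream have been decoded and queued, `push_audio` now rewinds and truncates the stream so the same audio is not decoded again on the next flush. A minimal, standalone sketch of that reset pattern (illustrative only, not part of the changed files):

from io import BytesIO

buf = BytesIO()
buf.write(b'\x00\x01' * 320)      # pretend some pushed audio bytes accumulated

buf.seek(0)
data = buf.read()                 # consume everything buffered so far
# ... decode `data`, split it into fixed-size chunks, enqueue them ...

buf.seek(0)                       # rewind to the start ...
buf.truncate()                    # ... and discard the consumed bytes
assert buf.getvalue() == b''      # next push starts from an empty buffer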


@@ -1,478 +1,480 @@
import time
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor
#import pyaudio
import soundfile as sf
import resampy
import queue
from queue import Queue
#from collections import deque
from threading import Thread, Event
from io import BytesIO

def _read_frame(stream, exit_event, queue, chunk):
    while True:
        if exit_event.is_set():
            print(f'[INFO] read frame thread ends')
            break
        frame = stream.read(chunk, exception_on_overflow=False)
        frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
        queue.put(frame)

def _play_frame(stream, exit_event, queue, chunk):
    while True:
        if exit_event.is_set():
            print(f'[INFO] play frame thread ends')
            break
        frame = queue.get()
        frame = (frame * 32767).astype(np.int16).tobytes()
        stream.write(frame, chunk)

class ASR:
    def __init__(self, opt):

        self.opt = opt

        self.play = opt.asr_play #false

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fps = opt.fps # 20 ms per frame
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
        self.mode = 'live' if opt.asr_wav == '' else 'file'

        if 'esperanto' in self.opt.asr_model:
            self.audio_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
            self.audio_dim = 29
        else:
            self.audio_dim = 32

        # prepare context cache
        # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
        self.context_size = opt.m
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r
        self.text = '[START]\n'
        self.terminated = False
        self.frames = []
        self.inwarm = False

        # pad left frames
        if self.stride_left_size > 0:
            self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)

        self.exit_event = Event()
        #self.audio_instance = pyaudio.PyAudio() #not need

        # create input stream
        if self.mode == 'file': #live mode
            self.file_stream = self.create_file_stream()
        else:
            self.queue = Queue()
            self.input_stream = BytesIO()
            self.output_queue = Queue()
            # start a background process to read frames
            #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
            #self.queue = Queue()
            #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))

        # play out the audio too...?
        if self.play:
            self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
            self.output_queue = Queue()
            self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))

        # current location of audio
        self.idx = 0

        # create wav2vec model
        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
        self.processor = AutoProcessor.from_pretrained(opt.asr_model)
        self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

        # prepare to save logits
        if self.opt.asr_save_feats:
            self.all_feats = []

        # the extracted features
        # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
        self.feat_buffer_size = 4
        self.feat_buffer_idx = 0
        self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)

        # TODO: hard coded 16 and 8 window size...
        self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
        self.tail = 8
        # attention window...
        self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...

        # warm up steps needed: mid + right + window_size + attention_size
        self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3

        self.listening = False
        self.playing = False

    def listen(self):
        # start
        if self.mode == 'live' and not self.listening:
            print(f'[INFO] starting read frame thread...')
            self.process_read_frame.start()
            self.listening = True

        if self.play and not self.playing:
            print(f'[INFO] starting play frame thread...')
            self.process_play_frame.start()
            self.playing = True

    def stop(self):

        self.exit_event.set()

        if self.play:
            self.output_stream.stop_stream()
            self.output_stream.close()
            if self.playing:
                self.process_play_frame.join()
                self.playing = False

        if self.mode == 'live':
            #self.input_stream.stop_stream() todo
            self.input_stream.close()
            if self.listening:
                self.process_read_frame.join()
                self.listening = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()
        if self.mode == 'live':
            # live mode: also print the result text.
            self.text += '\n[END]'
            print(self.text)

    def get_next_feat(self):
        # return a [1/8, 16] window, for the next input to nerf side.

        while len(self.att_feats) < 8:
            # [------f+++t-----]
            if self.front < self.tail:
                feat = self.feat_queue[self.front:self.tail]
            # [++t-----------f+]
            else:
                feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)

            self.front = (self.front + 2) % self.feat_queue.shape[0]
            self.tail = (self.tail + 2) % self.feat_queue.shape[0]

            # print(self.front, self.tail, feat.shape)

            self.att_feats.append(feat.permute(1, 0))

        att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]

        # discard old
        self.att_feats = self.att_feats[1:]

        return att_feat

    def run_step(self):

        if self.terminated:
            return

        # get a frame of audio
        frame = self.get_audio_frame()

        # the last frame
        if frame is None:
            # terminate, but always run the network for the left frames
            self.terminated = True
        else:
            self.frames.append(frame)
            # put to output
            #if self.play:
            self.output_queue.put(frame)
            # context not enough, do not run network.
            if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
                return

        inputs = np.concatenate(self.frames) # [N * chunk]

        # discard the old part to save memory
        if not self.terminated:
            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

        print(f'[INFO] frame_to_text... ')
        logits, labels, text = self.frame_to_text(inputs)
        feats = logits # better lips-sync than labels

        # save feats
        if self.opt.asr_save_feats:
            self.all_feats.append(feats)

        # record the feats efficiently.. (no concat, constant memory)
        start = self.feat_buffer_idx * self.context_size
        end = start + feats.shape[0]
        self.feat_queue[start:end] = feats
        self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size

        # very naive, just concat the text output.
        if text != '':
            self.text = self.text + ' ' + text

        # will only run once at ternimation
        if self.terminated:
            self.text += '\n[END]'
            print(self.text)

            if self.opt.asr_save_feats:
                print(f'[INFO] save all feats for training purpose... ')
                feats = torch.cat(self.all_feats, dim=0) # [N, C]
                # print('[INFO] before unfold', feats.shape)
                window_size = 16
                padding = window_size // 2
                feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
                feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
                unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
                unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
                # print('[INFO] after unfold', unfold_feats.shape)
                # save to a npy file
                if 'esperanto' in self.opt.asr_model:
                    output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
                else:
                    output_path = self.opt.asr_wav.replace('.wav', '.npy')
                np.save(output_path, unfold_feats.cpu().numpy())
                print(f"[INFO] saved logits to {output_path}")

    def create_file_stream(self):

        stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self.sample_rate:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)

        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')

        return stream

    def create_pyaudio_stream(self):

        import pyaudio

        print(f'[INFO] creating live audio stream ...')

        audio = pyaudio.PyAudio()

        # get devices
        info = audio.get_host_api_info_by_index(0)
        n_devices = info.get('deviceCount')

        for i in range(0, n_devices):
            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
                print(f'[INFO] choose audio device {name}, id {i}')
                break

        # get stream
        stream = audio.open(input_device_index=i,
                            format=pyaudio.paInt16,
                            channels=1,
                            rate=self.sample_rate,
                            input=True,
                            frames_per_buffer=self.chunk)

        return audio, stream

    def get_audio_frame(self):

        if self.inwarm: # warm up
            return np.zeros(self.chunk, dtype=np.float32)

        if self.mode == 'file':

            if self.idx < self.file_stream.shape[0]:
                frame = self.file_stream[self.idx: self.idx + self.chunk]
                self.idx = self.idx + self.chunk
                return frame
            else:
                return None

        else:

            try:
                frame = self.queue.get(block=False)
                print(f'[INFO] get frame {frame.shape}')
            except queue.Empty:
                frame = np.zeros(self.chunk, dtype=np.float32)

            self.idx = self.idx + self.chunk

            return frame

    def frame_to_text(self, frame):
        # frame: [N * 320], N = (context_size + 2 * stride_size)

        inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)

        with torch.no_grad():
            result = self.model(inputs.input_values.to(self.device))
            logits = result.logits # [1, N - 1, 32]

        # cut off stride
        left = max(0, self.stride_left_size)
        right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.

        # do not cut right if terminated.
        if self.terminated:
            right = logits.shape[1]

        logits = logits[:, left:right]

        # print(frame.shape, inputs.input_values.shape, logits.shape)

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0].lower()

        # for esperanto
        # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '', 'fi', 'l', 'p', '', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])

        # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
        # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
        # print(predicted_ids[0])
        # print(transcription)

        return logits[0], predicted_ids[0], transcription # [N,]

    def create_bytes_stream(self,byte_stream):
        #byte_stream=BytesIO(buffer)
        stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self.sample_rate:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)

        return stream

    def push_audio(self,buffer):
        print(f'[INFO] push_audio {len(buffer)}')
        # if len(buffer)>0:
        #     byte_stream=BytesIO(buffer)
        #     stream = self.create_bytes_stream(byte_stream)
        #     streamlen = stream.shape[0]
        #     idx=0
        #     while streamlen >= self.chunk:
        #         self.queue.put(stream[idx:idx+self.chunk])
        #         streamlen -= self.chunk
        #         idx += self.chunk
        #     if streamlen>0:
        #         self.queue.put(stream[idx:])
        self.input_stream.write(buffer)
        if len(buffer)<=0:
            self.input_stream.seek(0)
            stream = self.create_bytes_stream(self.input_stream)
            streamlen = stream.shape[0]
            idx=0
            while streamlen >= self.chunk:
                self.queue.put(stream[idx:idx+self.chunk])
                streamlen -= self.chunk
                idx += self.chunk
            if streamlen>0:
                self.queue.put(stream[idx:])
            self.input_stream.seek(0)
            self.input_stream.truncate()

    def get_audio_out(self):
        return self.output_queue.get()

    def run(self):

        self.listen()

        while not self.terminated:
            self.run_step()

    def clear_queue(self):
        # clear the queue, to reduce potential latency...
        print(f'[INFO] clear queue')
        if self.mode == 'live':
            self.queue.queue.clear()
        if self.play:
            self.output_queue.queue.clear()

    def warm_up(self):

        #self.listen()

        self.inwarm = True
        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
        t = time.time()
        for _ in range(self.warm_up_steps):
            self.run_step()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t = time.time() - t
        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')

        self.inwarm = False

        #self.clear_queue()

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--wav', type=str, default='')
    parser.add_argument('--play', action='store_true', help="play out the audio")

    parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
    # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')

    parser.add_argument('--save_feats', action='store_true')
    # audio FPS
    parser.add_argument('--fps', type=int, default=50)
    # sliding window left-middle-right length.
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=50)
    parser.add_argument('-r', type=int, default=10)

    opt = parser.parse_args()

    # fix
    opt.asr_wav = opt.wav
    opt.asr_play = opt.play
    opt.asr_model = opt.model
    opt.asr_save_feats = opt.save_feats

    if 'deepspeech' in opt.asr_model:
        raise ValueError("DeepSpeech features should not use this code to extract...")

    with ASR(opt) as asr:
        asr.run()
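For context, a hedged usage sketch of the `push_audio` path this commit touches: `push_audio` only accumulates while `len(buffer) > 0`; pushing an empty buffer triggers decoding, resampling to 16 kHz, splitting into 320-sample (20 ms) chunks on `self.queue`, and then the new rewind-and-truncate of the `BytesIO` buffer. The file name `tts_output.wav` and the 4096-byte chunking below are illustrative assumptions, not part of the commit.

# Illustrative only: drive ASR.push_audio the way a TTS/network producer might.
# Assumes `opt` is configured as in the __main__ block above with the default
# --wav '' (live mode), so self.queue and self.input_stream exist.
asr = ASR(opt)
with open('tts_output.wav', 'rb') as f:        # hypothetical TTS output file
    wav_bytes = f.read()
for i in range(0, len(wav_bytes), 4096):       # arbitrary transport-sized chunks
    asr.push_audio(wav_bytes[i:i + 4096])      # len > 0: bytes are just buffered
asr.push_audio(b'')                            # empty push: decode, enqueue 20 ms
                                               # frames, then seek(0) + truncate()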


@@ -1,178 +1,184 @@
import math
import torch
import numpy as np
#from .utils import *
import subprocess
import os
import time

from asrreal import ASR
from rtmp_streaming import StreamerConfig, Streamer

class NeRFReal:
    def __init__(self, opt, trainer, data_loader, debug=True):
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.debug = debug
        self.training = False
        self.step = 0 # training step

        self.trainer = trainer
        self.data_loader = data_loader

        # use dataloader's bg
        bg_img = data_loader._data.bg_img #.view(1, -1, 3)
        if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
            bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
        self.bg_color = bg_img.view(1, -1, 3)

        # audio features (from dataloader, only used in non-playing mode)
        self.audio_features = data_loader._data.auds # [N, 29, 16]
        self.audio_idx = 0

        # control eye
        self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()

        # playing seq from dataloader, or pause.
        self.playing = True #False todo
        self.loader = iter(data_loader)

        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True # camera moved, should reset accumulation
        self.spp = 1 # sample per pixel
        self.mode = 'image' # choose from ['image', 'depth']

        self.dynamic_resolution = False # assert False!
        self.downscale = 1
        self.train_steps = 16

        self.ind_index = 0
        self.ind_num = trainer.model.individual_codes.shape[0]

        # build asr
        if self.opt.asr:
            self.asr = ASR(opt)

        fps=25
        #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
        sc = StreamerConfig()
        sc.source_width = self.W
        sc.source_height = self.H
        sc.stream_width = self.W
        sc.stream_height = self.H
        sc.stream_fps = fps
        sc.stream_bitrate = 1000000
        sc.stream_profile = 'main' #'high444' # 'main'
        sc.audio_channel = 1
        sc.sample_rate = 16000
        sc.stream_server = opt.push_url

        self.streamer = Streamer()
        self.streamer.init(sc)
        self.streamer.enable_av_debug_log()

        '''
        video_path = 'video_stream'
        if not os.path.exists(video_path):
            os.mkfifo(video_path, mode=0o777)
        audio_path = 'audio_stream'
        if not os.path.exists(audio_path):
            os.mkfifo(audio_path, mode=0o777)
        width=450
        height=450
        command = ['ffmpeg',
                   '-y', #'-an',
                   #'-re',
                   '-f', 'rawvideo',
                   '-vcodec','rawvideo',
                   '-pix_fmt', 'rgb24', # pixel format
                   '-s', "{}x{}".format(width, height),
                   '-r', str(fps),
                   '-i', video_path,
                   '-f', 's16le',
                   '-acodec','pcm_s16le',
                   '-ac', '1',
                   '-ar', '16000',
                   '-i', audio_path,
                   #'-fflags', '+genpts',
                   '-map', '0:v',
                   '-map', '1:a',
                   #'-copyts',
                   '-acodec', 'aac',
                   '-pix_fmt', 'yuv420p', #'-vcodec', "h264",
                   #"-rtmp_buffer", "100",
                   '-f' , 'flv',
                   push_url]
        self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE)
        self.fifo_video = open(video_path, 'wb')
        self.fifo_audio = open(audio_path, 'wb')
        #self.test_step()
        '''

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.opt.asr:
            self.asr.stop()

    def push_audio(self,chunk):
        self.asr.push_audio(chunk)

    def prepare_buffer(self, outputs):
        if self.mode == 'image':
            return outputs['image']
        else:
            return np.expand_dims(outputs['depth'], -1).repeat(3, -1)

    def test_step(self):

        #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        #starter.record()

        if self.playing:
            try:
                data = next(self.loader)
            except StopIteration:
                self.loader = iter(self.data_loader)
                data = next(self.loader)

            if self.opt.asr:
                # use the live audio stream
                data['auds'] = self.asr.get_next_feat()

            outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
            #print(f'[INFO] outputs shape ',outputs['image'].shape)
            image = (outputs['image'] * 255).astype(np.uint8)
            self.streamer.stream_frame(image)
            #self.pipe.stdin.write(image.tostring())
            for _ in range(2):
                frame = self.asr.get_audio_out()
                #print(f'[INFO] get_audio_out shape ',frame.shape)
                self.streamer.stream_frame_audio(frame)
            # frame = (frame * 32767).astype(np.int16).tobytes()
            # self.fifo_audio.write(frame)
        else:
            if self.audio_features is not None:
                auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
            else:
                auds = None
            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)

        #ender.record()
        #torch.cuda.synchronize()
        #t = starter.elapsed_time(ender)

    def render(self):

        if self.opt.asr:
            self.asr.warm_up()

        while True: #todo
            # update texture every frame
            # audio stream thread...
            t = time.time()
            if self.opt.asr and self.playing:
                # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
                for _ in range(2):
                    self.asr.run_step()
            self.test_step()
            delay = 0.04 - (time.time() - t) #40ms
            if delay > 0:
                time.sleep(delay)
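The added pacing in `render` follows from the rates already noted in the comments: video is produced at 25 FPS while the ASR consumes 20 ms chunks at 50 FPS, so each loop iteration runs two ASR steps plus one rendered frame and then sleeps whatever remains of the 40 ms budget. A small sketch of that arithmetic (illustrative, not part of the file):

video_fps = 25                                   # one rendered frame per loop iteration
audio_fps = 50                                   # one 20 ms ASR chunk per run_step()
asr_steps_per_frame = audio_fps // video_fps     # = 2, matching the for-loop above
frame_budget = 1.0 / video_fps                   # = 0.04 s; `delay = 0.04 - elapsed`
                                                 # sleeps the remainder to hold real-time rate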