diff --git a/asrreal.py b/asrreal.py index 73918d9..3068acd 100644 --- a/asrreal.py +++ b/asrreal.py @@ -1,478 +1,480 @@ -import time -import numpy as np -import torch -import torch.nn.functional as F -from transformers import AutoModelForCTC, AutoProcessor - -#import pyaudio -import soundfile as sf -import resampy - -import queue -from queue import Queue -#from collections import deque -from threading import Thread, Event -from io import BytesIO - - -def _read_frame(stream, exit_event, queue, chunk): - - while True: - if exit_event.is_set(): - print(f'[INFO] read frame thread ends') - break - frame = stream.read(chunk, exception_on_overflow=False) - frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk] - queue.put(frame) - -def _play_frame(stream, exit_event, queue, chunk): - - while True: - if exit_event.is_set(): - print(f'[INFO] play frame thread ends') - break - frame = queue.get() - frame = (frame * 32767).astype(np.int16).tobytes() - stream.write(frame, chunk) - -class ASR: - def __init__(self, opt): - - self.opt = opt - - self.play = opt.asr_play #false - - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.fps = opt.fps # 20 ms per frame - self.sample_rate = 16000 - self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) - self.mode = 'live' if opt.asr_wav == '' else 'file' - - if 'esperanto' in self.opt.asr_model: - self.audio_dim = 44 - elif 'deepspeech' in self.opt.asr_model: - self.audio_dim = 29 - else: - self.audio_dim = 32 - - # prepare context cache - # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms - self.context_size = opt.m - self.stride_left_size = opt.l - self.stride_right_size = opt.r - self.text = '[START]\n' - self.terminated = False - self.frames = [] - self.inwarm = False - - # pad left frames - if self.stride_left_size > 0: - self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) - - - self.exit_event = Event() - #self.audio_instance = pyaudio.PyAudio() #not need - - # create input stream - if self.mode == 'file': #live mode - self.file_stream = self.create_file_stream() - else: - self.queue = Queue() - self.input_stream = BytesIO() - self.output_queue = Queue() - # start a background process to read frames - #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) - #self.queue = Queue() - #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) - - # play out the audio too...? - if self.play: - self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk) - self.output_queue = Queue() - self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk)) - - # current location of audio - self.idx = 0 - - # create wav2vec model - print(f'[INFO] loading ASR model {self.opt.asr_model}...') - self.processor = AutoProcessor.from_pretrained(opt.asr_model) - self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) - - # prepare to save logits - if self.opt.asr_save_feats: - self.all_feats = [] - - # the extracted features - # use a loop queue to efficiently record endless features: [f--t---][-------][-------] - self.feat_buffer_size = 4 - self.feat_buffer_idx = 0 - self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device) - - # TODO: hard coded 16 and 8 window size... - self.front = self.feat_buffer_size * self.context_size - 8 # fake padding - self.tail = 8 - # attention window... - self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding... - - # warm up steps needed: mid + right + window_size + attention_size - self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3 - - self.listening = False - self.playing = False - - def listen(self): - # start - if self.mode == 'live' and not self.listening: - print(f'[INFO] starting read frame thread...') - self.process_read_frame.start() - self.listening = True - - if self.play and not self.playing: - print(f'[INFO] starting play frame thread...') - self.process_play_frame.start() - self.playing = True - - def stop(self): - - self.exit_event.set() - - if self.play: - self.output_stream.stop_stream() - self.output_stream.close() - if self.playing: - self.process_play_frame.join() - self.playing = False - - if self.mode == 'live': - #self.input_stream.stop_stream() todo - self.input_stream.close() - if self.listening: - self.process_read_frame.join() - self.listening = False - - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - - self.stop() - - if self.mode == 'live': - # live mode: also print the result text. - self.text += '\n[END]' - print(self.text) - - def get_next_feat(self): - # return a [1/8, 16] window, for the next input to nerf side. - - while len(self.att_feats) < 8: - # [------f+++t-----] - if self.front < self.tail: - feat = self.feat_queue[self.front:self.tail] - # [++t-----------f+] - else: - feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0) - - self.front = (self.front + 2) % self.feat_queue.shape[0] - self.tail = (self.tail + 2) % self.feat_queue.shape[0] - - # print(self.front, self.tail, feat.shape) - - self.att_feats.append(feat.permute(1, 0)) - - att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] - - # discard old - self.att_feats = self.att_feats[1:] - - return att_feat - - def run_step(self): - - if self.terminated: - return - - # get a frame of audio - frame = self.get_audio_frame() - - # the last frame - if frame is None: - # terminate, but always run the network for the left frames - self.terminated = True - else: - self.frames.append(frame) - # put to output - #if self.play: - self.output_queue.put(frame) - # context not enough, do not run network. - if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: - return - - inputs = np.concatenate(self.frames) # [N * chunk] - - # discard the old part to save memory - if not self.terminated: - self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] - - print(f'[INFO] frame_to_text... ') - logits, labels, text = self.frame_to_text(inputs) - feats = logits # better lips-sync than labels - - # save feats - if self.opt.asr_save_feats: - self.all_feats.append(feats) - - # record the feats efficiently.. (no concat, constant memory) - start = self.feat_buffer_idx * self.context_size - end = start + feats.shape[0] - self.feat_queue[start:end] = feats - self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size - - # very naive, just concat the text output. - if text != '': - self.text = self.text + ' ' + text - - # will only run once at ternimation - if self.terminated: - self.text += '\n[END]' - print(self.text) - if self.opt.asr_save_feats: - print(f'[INFO] save all feats for training purpose... ') - feats = torch.cat(self.all_feats, dim=0) # [N, C] - # print('[INFO] before unfold', feats.shape) - window_size = 16 - padding = window_size // 2 - feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M] - feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1] - unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1] - unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C] - # print('[INFO] after unfold', unfold_feats.shape) - # save to a npy file - if 'esperanto' in self.opt.asr_model: - output_path = self.opt.asr_wav.replace('.wav', '_eo.npy') - else: - output_path = self.opt.asr_wav.replace('.wav', '.npy') - np.save(output_path, unfold_feats.cpu().numpy()) - print(f"[INFO] saved logits to {output_path}") - - def create_file_stream(self): - - stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64 - stream = stream.astype(np.float32) - - if stream.ndim > 1: - print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') - stream = stream[:, 0] - - if sample_rate != self.sample_rate: - print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') - stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) - - print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}') - - return stream - - - def create_pyaudio_stream(self): - - import pyaudio - - print(f'[INFO] creating live audio stream ...') - - audio = pyaudio.PyAudio() - - # get devices - info = audio.get_host_api_info_by_index(0) - n_devices = info.get('deviceCount') - - for i in range(0, n_devices): - if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: - name = audio.get_device_info_by_host_api_device_index(0, i).get('name') - print(f'[INFO] choose audio device {name}, id {i}') - break - - # get stream - stream = audio.open(input_device_index=i, - format=pyaudio.paInt16, - channels=1, - rate=self.sample_rate, - input=True, - frames_per_buffer=self.chunk) - - return audio, stream - - - def get_audio_frame(self): - - if self.inwarm: # warm up - return np.zeros(self.chunk, dtype=np.float32) - - if self.mode == 'file': - - if self.idx < self.file_stream.shape[0]: - frame = self.file_stream[self.idx: self.idx + self.chunk] - self.idx = self.idx + self.chunk - return frame - else: - return None - - else: - try: - frame = self.queue.get(block=False) - print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) - - self.idx = self.idx + self.chunk - - return frame - - - def frame_to_text(self, frame): - # frame: [N * 320], N = (context_size + 2 * stride_size) - - inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) - - with torch.no_grad(): - result = self.model(inputs.input_values.to(self.device)) - logits = result.logits # [1, N - 1, 32] - - # cut off stride - left = max(0, self.stride_left_size) - right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. - - # do not cut right if terminated. - if self.terminated: - right = logits.shape[1] - - logits = logits[:, left:right] - - # print(frame.shape, inputs.input_values.shape, logits.shape) - - predicted_ids = torch.argmax(logits, dim=-1) - transcription = self.processor.batch_decode(predicted_ids)[0].lower() - - - # for esperanto - # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) - - # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) - # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) - # print(predicted_ids[0]) - # print(transcription) - - return logits[0], predicted_ids[0], transcription # [N,] - - def create_bytes_stream(self,byte_stream): - #byte_stream=BytesIO(buffer) - stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 - print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') - stream = stream.astype(np.float32) - - if stream.ndim > 1: - print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') - stream = stream[:, 0] - - if sample_rate != self.sample_rate: - print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') - stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) - - return stream - - def push_audio(self,buffer): - print(f'[INFO] push_audio {len(buffer)}') - # if len(buffer)>0: - # byte_stream=BytesIO(buffer) - # stream = self.create_bytes_stream(byte_stream) - # streamlen = stream.shape[0] - # idx=0 - # while streamlen >= self.chunk: - # self.queue.put(stream[idx:idx+self.chunk]) - # streamlen -= self.chunk - # idx += self.chunk - # if streamlen>0: - # self.queue.put(stream[idx:]) - self.input_stream.write(buffer) - if len(buffer)<=0: - self.input_stream.seek(0) - stream = self.create_bytes_stream(self.input_stream) - streamlen = stream.shape[0] - idx=0 - while streamlen >= self.chunk: - self.queue.put(stream[idx:idx+self.chunk]) - streamlen -= self.chunk - idx += self.chunk - if streamlen>0: - self.queue.put(stream[idx:]) - - def get_audio_out(self): - return self.output_queue.get() - - def run(self): - - self.listen() - - while not self.terminated: - self.run_step() - - def clear_queue(self): - # clear the queue, to reduce potential latency... - print(f'[INFO] clear queue') - if self.mode == 'live': - self.queue.queue.clear() - if self.play: - self.output_queue.queue.clear() - - def warm_up(self): - - #self.listen() - - self.inwarm = True - print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') - t = time.time() - for _ in range(self.warm_up_steps): - self.run_step() - if torch.cuda.is_available(): - torch.cuda.synchronize() - t = time.time() - t - print(f'[INFO] warm-up done, actual latency = {t:.6f}s') - self.inwarm = False - - #self.clear_queue() - - - - -if __name__ == '__main__': - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument('--wav', type=str, default='') - parser.add_argument('--play', action='store_true', help="play out the audio") - - parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') - # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') - - parser.add_argument('--save_feats', action='store_true') - # audio FPS - parser.add_argument('--fps', type=int, default=50) - # sliding window left-middle-right length. - parser.add_argument('-l', type=int, default=10) - parser.add_argument('-m', type=int, default=50) - parser.add_argument('-r', type=int, default=10) - - opt = parser.parse_args() - - # fix - opt.asr_wav = opt.wav - opt.asr_play = opt.play - opt.asr_model = opt.model - opt.asr_save_feats = opt.save_feats - - if 'deepspeech' in opt.asr_model: - raise ValueError("DeepSpeech features should not use this code to extract...") - - with ASR(opt) as asr: +import time +import numpy as np +import torch +import torch.nn.functional as F +from transformers import AutoModelForCTC, AutoProcessor + +#import pyaudio +import soundfile as sf +import resampy + +import queue +from queue import Queue +#from collections import deque +from threading import Thread, Event +from io import BytesIO + + +def _read_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] read frame thread ends') + break + frame = stream.read(chunk, exception_on_overflow=False) + frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk] + queue.put(frame) + +def _play_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] play frame thread ends') + break + frame = queue.get() + frame = (frame * 32767).astype(np.int16).tobytes() + stream.write(frame, chunk) + +class ASR: + def __init__(self, opt): + + self.opt = opt + + self.play = opt.asr_play #false + + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.fps = opt.fps # 20 ms per frame + self.sample_rate = 16000 + self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) + self.mode = 'live' if opt.asr_wav == '' else 'file' + + if 'esperanto' in self.opt.asr_model: + self.audio_dim = 44 + elif 'deepspeech' in self.opt.asr_model: + self.audio_dim = 29 + else: + self.audio_dim = 32 + + # prepare context cache + # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms + self.context_size = opt.m + self.stride_left_size = opt.l + self.stride_right_size = opt.r + self.text = '[START]\n' + self.terminated = False + self.frames = [] + self.inwarm = False + + # pad left frames + if self.stride_left_size > 0: + self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) + + + self.exit_event = Event() + #self.audio_instance = pyaudio.PyAudio() #not need + + # create input stream + if self.mode == 'file': #live mode + self.file_stream = self.create_file_stream() + else: + self.queue = Queue() + self.input_stream = BytesIO() + self.output_queue = Queue() + # start a background process to read frames + #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) + #self.queue = Queue() + #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) + + # play out the audio too...? + if self.play: + self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk) + self.output_queue = Queue() + self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk)) + + # current location of audio + self.idx = 0 + + # create wav2vec model + print(f'[INFO] loading ASR model {self.opt.asr_model}...') + self.processor = AutoProcessor.from_pretrained(opt.asr_model) + self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) + + # prepare to save logits + if self.opt.asr_save_feats: + self.all_feats = [] + + # the extracted features + # use a loop queue to efficiently record endless features: [f--t---][-------][-------] + self.feat_buffer_size = 4 + self.feat_buffer_idx = 0 + self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device) + + # TODO: hard coded 16 and 8 window size... + self.front = self.feat_buffer_size * self.context_size - 8 # fake padding + self.tail = 8 + # attention window... + self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding... + + # warm up steps needed: mid + right + window_size + attention_size + self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3 + + self.listening = False + self.playing = False + + def listen(self): + # start + if self.mode == 'live' and not self.listening: + print(f'[INFO] starting read frame thread...') + self.process_read_frame.start() + self.listening = True + + if self.play and not self.playing: + print(f'[INFO] starting play frame thread...') + self.process_play_frame.start() + self.playing = True + + def stop(self): + + self.exit_event.set() + + if self.play: + self.output_stream.stop_stream() + self.output_stream.close() + if self.playing: + self.process_play_frame.join() + self.playing = False + + if self.mode == 'live': + #self.input_stream.stop_stream() todo + self.input_stream.close() + if self.listening: + self.process_read_frame.join() + self.listening = False + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + + self.stop() + + if self.mode == 'live': + # live mode: also print the result text. + self.text += '\n[END]' + print(self.text) + + def get_next_feat(self): + # return a [1/8, 16] window, for the next input to nerf side. + + while len(self.att_feats) < 8: + # [------f+++t-----] + if self.front < self.tail: + feat = self.feat_queue[self.front:self.tail] + # [++t-----------f+] + else: + feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0) + + self.front = (self.front + 2) % self.feat_queue.shape[0] + self.tail = (self.tail + 2) % self.feat_queue.shape[0] + + # print(self.front, self.tail, feat.shape) + + self.att_feats.append(feat.permute(1, 0)) + + att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] + + # discard old + self.att_feats = self.att_feats[1:] + + return att_feat + + def run_step(self): + + if self.terminated: + return + + # get a frame of audio + frame = self.get_audio_frame() + + # the last frame + if frame is None: + # terminate, but always run the network for the left frames + self.terminated = True + else: + self.frames.append(frame) + # put to output + #if self.play: + self.output_queue.put(frame) + # context not enough, do not run network. + if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: + return + + inputs = np.concatenate(self.frames) # [N * chunk] + + # discard the old part to save memory + if not self.terminated: + self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] + + print(f'[INFO] frame_to_text... ') + logits, labels, text = self.frame_to_text(inputs) + feats = logits # better lips-sync than labels + + # save feats + if self.opt.asr_save_feats: + self.all_feats.append(feats) + + # record the feats efficiently.. (no concat, constant memory) + start = self.feat_buffer_idx * self.context_size + end = start + feats.shape[0] + self.feat_queue[start:end] = feats + self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size + + # very naive, just concat the text output. + if text != '': + self.text = self.text + ' ' + text + + # will only run once at ternimation + if self.terminated: + self.text += '\n[END]' + print(self.text) + if self.opt.asr_save_feats: + print(f'[INFO] save all feats for training purpose... ') + feats = torch.cat(self.all_feats, dim=0) # [N, C] + # print('[INFO] before unfold', feats.shape) + window_size = 16 + padding = window_size // 2 + feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M] + feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1] + unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1] + unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C] + # print('[INFO] after unfold', unfold_feats.shape) + # save to a npy file + if 'esperanto' in self.opt.asr_model: + output_path = self.opt.asr_wav.replace('.wav', '_eo.npy') + else: + output_path = self.opt.asr_wav.replace('.wav', '.npy') + np.save(output_path, unfold_feats.cpu().numpy()) + print(f"[INFO] saved logits to {output_path}") + + def create_file_stream(self): + + stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64 + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self.sample_rate: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) + + print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}') + + return stream + + + def create_pyaudio_stream(self): + + import pyaudio + + print(f'[INFO] creating live audio stream ...') + + audio = pyaudio.PyAudio() + + # get devices + info = audio.get_host_api_info_by_index(0) + n_devices = info.get('deviceCount') + + for i in range(0, n_devices): + if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + name = audio.get_device_info_by_host_api_device_index(0, i).get('name') + print(f'[INFO] choose audio device {name}, id {i}') + break + + # get stream + stream = audio.open(input_device_index=i, + format=pyaudio.paInt16, + channels=1, + rate=self.sample_rate, + input=True, + frames_per_buffer=self.chunk) + + return audio, stream + + + def get_audio_frame(self): + + if self.inwarm: # warm up + return np.zeros(self.chunk, dtype=np.float32) + + if self.mode == 'file': + + if self.idx < self.file_stream.shape[0]: + frame = self.file_stream[self.idx: self.idx + self.chunk] + self.idx = self.idx + self.chunk + return frame + else: + return None + + else: + try: + frame = self.queue.get(block=False) + print(f'[INFO] get frame {frame.shape}') + except queue.Empty: + frame = np.zeros(self.chunk, dtype=np.float32) + + self.idx = self.idx + self.chunk + + return frame + + + def frame_to_text(self, frame): + # frame: [N * 320], N = (context_size + 2 * stride_size) + + inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) + + with torch.no_grad(): + result = self.model(inputs.input_values.to(self.device)) + logits = result.logits # [1, N - 1, 32] + + # cut off stride + left = max(0, self.stride_left_size) + right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. + + # do not cut right if terminated. + if self.terminated: + right = logits.shape[1] + + logits = logits[:, left:right] + + # print(frame.shape, inputs.input_values.shape, logits.shape) + + predicted_ids = torch.argmax(logits, dim=-1) + transcription = self.processor.batch_decode(predicted_ids)[0].lower() + + + # for esperanto + # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) + + # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) + # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) + # print(predicted_ids[0]) + # print(transcription) + + return logits[0], predicted_ids[0], transcription # [N,] + + def create_bytes_stream(self,byte_stream): + #byte_stream=BytesIO(buffer) + stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 + print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self.sample_rate: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) + + return stream + + def push_audio(self,buffer): + print(f'[INFO] push_audio {len(buffer)}') + # if len(buffer)>0: + # byte_stream=BytesIO(buffer) + # stream = self.create_bytes_stream(byte_stream) + # streamlen = stream.shape[0] + # idx=0 + # while streamlen >= self.chunk: + # self.queue.put(stream[idx:idx+self.chunk]) + # streamlen -= self.chunk + # idx += self.chunk + # if streamlen>0: + # self.queue.put(stream[idx:]) + self.input_stream.write(buffer) + if len(buffer)<=0: + self.input_stream.seek(0) + stream = self.create_bytes_stream(self.input_stream) + streamlen = stream.shape[0] + idx=0 + while streamlen >= self.chunk: + self.queue.put(stream[idx:idx+self.chunk]) + streamlen -= self.chunk + idx += self.chunk + if streamlen>0: + self.queue.put(stream[idx:]) + self.input_stream.seek(0) + self.input_stream.truncate() + + def get_audio_out(self): + return self.output_queue.get() + + def run(self): + + self.listen() + + while not self.terminated: + self.run_step() + + def clear_queue(self): + # clear the queue, to reduce potential latency... + print(f'[INFO] clear queue') + if self.mode == 'live': + self.queue.queue.clear() + if self.play: + self.output_queue.queue.clear() + + def warm_up(self): + + #self.listen() + + self.inwarm = True + print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') + t = time.time() + for _ in range(self.warm_up_steps): + self.run_step() + if torch.cuda.is_available(): + torch.cuda.synchronize() + t = time.time() - t + print(f'[INFO] warm-up done, actual latency = {t:.6f}s') + self.inwarm = False + + #self.clear_queue() + + + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--wav', type=str, default='') + parser.add_argument('--play', action='store_true', help="play out the audio") + + parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') + # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') + + parser.add_argument('--save_feats', action='store_true') + # audio FPS + parser.add_argument('--fps', type=int, default=50) + # sliding window left-middle-right length. + parser.add_argument('-l', type=int, default=10) + parser.add_argument('-m', type=int, default=50) + parser.add_argument('-r', type=int, default=10) + + opt = parser.parse_args() + + # fix + opt.asr_wav = opt.wav + opt.asr_play = opt.play + opt.asr_model = opt.model + opt.asr_save_feats = opt.save_feats + + if 'deepspeech' in opt.asr_model: + raise ValueError("DeepSpeech features should not use this code to extract...") + + with ASR(opt) as asr: asr.run() \ No newline at end of file diff --git a/nerfreal.py b/nerfreal.py index 0504d3d..9a400b7 100644 --- a/nerfreal.py +++ b/nerfreal.py @@ -1,178 +1,184 @@ -import math -import torch -import numpy as np - -#from .utils import * -import subprocess -import os - -from asrreal import ASR -from rtmp_streaming import StreamerConfig, Streamer - -class NeRFReal: - def __init__(self, opt, trainer, data_loader, debug=True): - self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. - self.W = opt.W - self.H = opt.H - self.debug = debug - self.training = False - self.step = 0 # training step - - self.trainer = trainer - self.data_loader = data_loader - - # use dataloader's bg - bg_img = data_loader._data.bg_img #.view(1, -1, 3) - if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]: - bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous() - self.bg_color = bg_img.view(1, -1, 3) - - # audio features (from dataloader, only used in non-playing mode) - self.audio_features = data_loader._data.auds # [N, 29, 16] - self.audio_idx = 0 - - # control eye - self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item() - - # playing seq from dataloader, or pause. - self.playing = True #False todo - self.loader = iter(data_loader) - - self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) - self.need_update = True # camera moved, should reset accumulation - self.spp = 1 # sample per pixel - self.mode = 'image' # choose from ['image', 'depth'] - - self.dynamic_resolution = False # assert False! - self.downscale = 1 - self.train_steps = 16 - - self.ind_index = 0 - self.ind_num = trainer.model.individual_codes.shape[0] - - # build asr - if self.opt.asr: - self.asr = ASR(opt) - - fps=25 - #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4' - sc = StreamerConfig() - sc.source_width = self.W - sc.source_height = self.H - sc.stream_width = self.W - sc.stream_height = self.H - sc.stream_fps = fps - sc.stream_bitrate = 1000000 - sc.stream_profile = 'main' #'high444' # 'main' - sc.audio_channel = 1 - sc.sample_rate = 16000 - sc.stream_server = opt.push_url - - self.streamer = Streamer() - self.streamer.init(sc) - self.streamer.enable_av_debug_log() - - ''' - video_path = 'video_stream' - if not os.path.exists(video_path): - os.mkfifo(video_path, mode=0o777) - audio_path = 'audio_stream' - if not os.path.exists(audio_path): - os.mkfifo(audio_path, mode=0o777) - width=450 - height=450 - command = ['ffmpeg', - '-y', #'-an', - #'-re', - '-f', 'rawvideo', - '-vcodec','rawvideo', - '-pix_fmt', 'rgb24', #像素格式 - '-s', "{}x{}".format(width, height), - '-r', str(fps), - '-i', video_path, - '-f', 's16le', - '-acodec','pcm_s16le', - '-ac', '1', - '-ar', '16000', - '-i', audio_path, - #'-fflags', '+genpts', - '-map', '0:v', - '-map', '1:a', - #'-copyts', - '-acodec', 'aac', - '-pix_fmt', 'yuv420p', #'-vcodec', "h264", - #"-rtmp_buffer", "100", - '-f' , 'flv', - push_url] - self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE) - self.fifo_video = open(video_path, 'wb') - self.fifo_audio = open(audio_path, 'wb') - #self.test_step() - ''' - - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self.opt.asr: - self.asr.stop() - - def push_audio(self,chunk): - self.asr.push_audio(chunk) - - def prepare_buffer(self, outputs): - if self.mode == 'image': - return outputs['image'] - else: - return np.expand_dims(outputs['depth'], -1).repeat(3, -1) - - def test_step(self): - - starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - starter.record() - - if self.playing: - try: - data = next(self.loader) - except StopIteration: - self.loader = iter(self.data_loader) - data = next(self.loader) - - if self.opt.asr: - # use the live audio stream - data['auds'] = self.asr.get_next_feat() - - outputs = self.trainer.test_gui_with_data(data, self.W, self.H) - #print(f'[INFO] outputs shape ',outputs['image'].shape) - image = (outputs['image'] * 255).astype(np.uint8) - self.streamer.stream_frame(image) - #self.pipe.stdin.write(image.tostring()) - for _ in range(2): - frame = self.asr.get_audio_out() - #print(f'[INFO] get_audio_out shape ',frame.shape) - self.streamer.stream_frame_audio(frame) - # frame = (frame * 32767).astype(np.int16).tobytes() - # self.fifo_audio.write(frame) - else: - if self.audio_features is not None: - auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx) - else: - auds = None - outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale) - - ender.record() - torch.cuda.synchronize() - t = starter.elapsed_time(ender) - - def render(self): - if self.opt.asr: - self.asr.warm_up() - while True: #todo - # update texture every frame - # audio stream thread... - if self.opt.asr and self.playing: - # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) - for _ in range(2): - self.asr.run_step() - self.test_step() \ No newline at end of file +import math +import torch +import numpy as np + +#from .utils import * +import subprocess +import os +import time + +from asrreal import ASR +from rtmp_streaming import StreamerConfig, Streamer + +class NeRFReal: + def __init__(self, opt, trainer, data_loader, debug=True): + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.W = opt.W + self.H = opt.H + self.debug = debug + self.training = False + self.step = 0 # training step + + self.trainer = trainer + self.data_loader = data_loader + + # use dataloader's bg + bg_img = data_loader._data.bg_img #.view(1, -1, 3) + if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]: + bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous() + self.bg_color = bg_img.view(1, -1, 3) + + # audio features (from dataloader, only used in non-playing mode) + self.audio_features = data_loader._data.auds # [N, 29, 16] + self.audio_idx = 0 + + # control eye + self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item() + + # playing seq from dataloader, or pause. + self.playing = True #False todo + self.loader = iter(data_loader) + + self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # camera moved, should reset accumulation + self.spp = 1 # sample per pixel + self.mode = 'image' # choose from ['image', 'depth'] + + self.dynamic_resolution = False # assert False! + self.downscale = 1 + self.train_steps = 16 + + self.ind_index = 0 + self.ind_num = trainer.model.individual_codes.shape[0] + + # build asr + if self.opt.asr: + self.asr = ASR(opt) + + fps=25 + #push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4' + sc = StreamerConfig() + sc.source_width = self.W + sc.source_height = self.H + sc.stream_width = self.W + sc.stream_height = self.H + sc.stream_fps = fps + sc.stream_bitrate = 1000000 + sc.stream_profile = 'main' #'high444' # 'main' + sc.audio_channel = 1 + sc.sample_rate = 16000 + sc.stream_server = opt.push_url + + self.streamer = Streamer() + self.streamer.init(sc) + self.streamer.enable_av_debug_log() + + ''' + video_path = 'video_stream' + if not os.path.exists(video_path): + os.mkfifo(video_path, mode=0o777) + audio_path = 'audio_stream' + if not os.path.exists(audio_path): + os.mkfifo(audio_path, mode=0o777) + width=450 + height=450 + command = ['ffmpeg', + '-y', #'-an', + #'-re', + '-f', 'rawvideo', + '-vcodec','rawvideo', + '-pix_fmt', 'rgb24', #像素格式 + '-s', "{}x{}".format(width, height), + '-r', str(fps), + '-i', video_path, + '-f', 's16le', + '-acodec','pcm_s16le', + '-ac', '1', + '-ar', '16000', + '-i', audio_path, + #'-fflags', '+genpts', + '-map', '0:v', + '-map', '1:a', + #'-copyts', + '-acodec', 'aac', + '-pix_fmt', 'yuv420p', #'-vcodec', "h264", + #"-rtmp_buffer", "100", + '-f' , 'flv', + push_url] + self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE) + self.fifo_video = open(video_path, 'wb') + self.fifo_audio = open(audio_path, 'wb') + #self.test_step() + ''' + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.opt.asr: + self.asr.stop() + + def push_audio(self,chunk): + self.asr.push_audio(chunk) + + def prepare_buffer(self, outputs): + if self.mode == 'image': + return outputs['image'] + else: + return np.expand_dims(outputs['depth'], -1).repeat(3, -1) + + def test_step(self): + + #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + #starter.record() + + if self.playing: + try: + data = next(self.loader) + except StopIteration: + self.loader = iter(self.data_loader) + data = next(self.loader) + + if self.opt.asr: + # use the live audio stream + data['auds'] = self.asr.get_next_feat() + + outputs = self.trainer.test_gui_with_data(data, self.W, self.H) + #print(f'[INFO] outputs shape ',outputs['image'].shape) + image = (outputs['image'] * 255).astype(np.uint8) + self.streamer.stream_frame(image) + #self.pipe.stdin.write(image.tostring()) + for _ in range(2): + frame = self.asr.get_audio_out() + #print(f'[INFO] get_audio_out shape ',frame.shape) + self.streamer.stream_frame_audio(frame) + # frame = (frame * 32767).astype(np.int16).tobytes() + # self.fifo_audio.write(frame) + else: + if self.audio_features is not None: + auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx) + else: + auds = None + outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale) + + #ender.record() + #torch.cuda.synchronize() + #t = starter.elapsed_time(ender) + + def render(self): + if self.opt.asr: + self.asr.warm_up() + while True: #todo + # update texture every frame + # audio stream thread... + t = time.time() + if self.opt.asr and self.playing: + # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) + for _ in range(2): + self.asr.run_step() + self.test_step() + delay = 0.04 - (time.time() - t) #40ms + if delay > 0: + time.sleep(delay) + \ No newline at end of file