fix audiostream buffer
This commit is contained in:
parent
6a7af5d006
commit
46b0f5abab
956
asrreal.py
956
asrreal.py
|
@ -1,478 +1,480 @@
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import AutoModelForCTC, AutoProcessor
|
from transformers import AutoModelForCTC, AutoProcessor
|
||||||
|
|
||||||
#import pyaudio
|
#import pyaudio
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import resampy
|
import resampy
|
||||||
|
|
||||||
import queue
|
import queue
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
#from collections import deque
|
#from collections import deque
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
def _read_frame(stream, exit_event, queue, chunk):
|
def _read_frame(stream, exit_event, queue, chunk):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if exit_event.is_set():
|
if exit_event.is_set():
|
||||||
print(f'[INFO] read frame thread ends')
|
print(f'[INFO] read frame thread ends')
|
||||||
break
|
break
|
||||||
frame = stream.read(chunk, exception_on_overflow=False)
|
frame = stream.read(chunk, exception_on_overflow=False)
|
||||||
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
|
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
|
||||||
queue.put(frame)
|
queue.put(frame)
|
||||||
|
|
||||||
def _play_frame(stream, exit_event, queue, chunk):
|
def _play_frame(stream, exit_event, queue, chunk):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if exit_event.is_set():
|
if exit_event.is_set():
|
||||||
print(f'[INFO] play frame thread ends')
|
print(f'[INFO] play frame thread ends')
|
||||||
break
|
break
|
||||||
frame = queue.get()
|
frame = queue.get()
|
||||||
frame = (frame * 32767).astype(np.int16).tobytes()
|
frame = (frame * 32767).astype(np.int16).tobytes()
|
||||||
stream.write(frame, chunk)
|
stream.write(frame, chunk)
|
||||||
|
|
||||||
class ASR:
|
class ASR:
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
|
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
|
|
||||||
self.play = opt.asr_play #false
|
self.play = opt.asr_play #false
|
||||||
|
|
||||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
self.fps = opt.fps # 20 ms per frame
|
self.fps = opt.fps # 20 ms per frame
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
|
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
|
||||||
self.mode = 'live' if opt.asr_wav == '' else 'file'
|
self.mode = 'live' if opt.asr_wav == '' else 'file'
|
||||||
|
|
||||||
if 'esperanto' in self.opt.asr_model:
|
if 'esperanto' in self.opt.asr_model:
|
||||||
self.audio_dim = 44
|
self.audio_dim = 44
|
||||||
elif 'deepspeech' in self.opt.asr_model:
|
elif 'deepspeech' in self.opt.asr_model:
|
||||||
self.audio_dim = 29
|
self.audio_dim = 29
|
||||||
else:
|
else:
|
||||||
self.audio_dim = 32
|
self.audio_dim = 32
|
||||||
|
|
||||||
# prepare context cache
|
# prepare context cache
|
||||||
# each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
|
# each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
|
||||||
self.context_size = opt.m
|
self.context_size = opt.m
|
||||||
self.stride_left_size = opt.l
|
self.stride_left_size = opt.l
|
||||||
self.stride_right_size = opt.r
|
self.stride_right_size = opt.r
|
||||||
self.text = '[START]\n'
|
self.text = '[START]\n'
|
||||||
self.terminated = False
|
self.terminated = False
|
||||||
self.frames = []
|
self.frames = []
|
||||||
self.inwarm = False
|
self.inwarm = False
|
||||||
|
|
||||||
# pad left frames
|
# pad left frames
|
||||||
if self.stride_left_size > 0:
|
if self.stride_left_size > 0:
|
||||||
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
|
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
|
||||||
|
|
||||||
|
|
||||||
self.exit_event = Event()
|
self.exit_event = Event()
|
||||||
#self.audio_instance = pyaudio.PyAudio() #not need
|
#self.audio_instance = pyaudio.PyAudio() #not need
|
||||||
|
|
||||||
# create input stream
|
# create input stream
|
||||||
if self.mode == 'file': #live mode
|
if self.mode == 'file': #live mode
|
||||||
self.file_stream = self.create_file_stream()
|
self.file_stream = self.create_file_stream()
|
||||||
else:
|
else:
|
||||||
self.queue = Queue()
|
self.queue = Queue()
|
||||||
self.input_stream = BytesIO()
|
self.input_stream = BytesIO()
|
||||||
self.output_queue = Queue()
|
self.output_queue = Queue()
|
||||||
# start a background process to read frames
|
# start a background process to read frames
|
||||||
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
|
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
|
||||||
#self.queue = Queue()
|
#self.queue = Queue()
|
||||||
#self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
|
#self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
|
||||||
|
|
||||||
# play out the audio too...?
|
# play out the audio too...?
|
||||||
if self.play:
|
if self.play:
|
||||||
self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
|
self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
|
||||||
self.output_queue = Queue()
|
self.output_queue = Queue()
|
||||||
self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
|
self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
|
||||||
|
|
||||||
# current location of audio
|
# current location of audio
|
||||||
self.idx = 0
|
self.idx = 0
|
||||||
|
|
||||||
# create wav2vec model
|
# create wav2vec model
|
||||||
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
|
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
|
||||||
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
|
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
|
||||||
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
|
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
|
||||||
|
|
||||||
# prepare to save logits
|
# prepare to save logits
|
||||||
if self.opt.asr_save_feats:
|
if self.opt.asr_save_feats:
|
||||||
self.all_feats = []
|
self.all_feats = []
|
||||||
|
|
||||||
# the extracted features
|
# the extracted features
|
||||||
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
|
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
|
||||||
self.feat_buffer_size = 4
|
self.feat_buffer_size = 4
|
||||||
self.feat_buffer_idx = 0
|
self.feat_buffer_idx = 0
|
||||||
self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
|
self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
|
||||||
|
|
||||||
# TODO: hard coded 16 and 8 window size...
|
# TODO: hard coded 16 and 8 window size...
|
||||||
self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
|
self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
|
||||||
self.tail = 8
|
self.tail = 8
|
||||||
# attention window...
|
# attention window...
|
||||||
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
|
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
|
||||||
|
|
||||||
# warm up steps needed: mid + right + window_size + attention_size
|
# warm up steps needed: mid + right + window_size + attention_size
|
||||||
self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
|
self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
|
||||||
|
|
||||||
self.listening = False
|
self.listening = False
|
||||||
self.playing = False
|
self.playing = False
|
||||||
|
|
||||||
def listen(self):
|
def listen(self):
|
||||||
# start
|
# start
|
||||||
if self.mode == 'live' and not self.listening:
|
if self.mode == 'live' and not self.listening:
|
||||||
print(f'[INFO] starting read frame thread...')
|
print(f'[INFO] starting read frame thread...')
|
||||||
self.process_read_frame.start()
|
self.process_read_frame.start()
|
||||||
self.listening = True
|
self.listening = True
|
||||||
|
|
||||||
if self.play and not self.playing:
|
if self.play and not self.playing:
|
||||||
print(f'[INFO] starting play frame thread...')
|
print(f'[INFO] starting play frame thread...')
|
||||||
self.process_play_frame.start()
|
self.process_play_frame.start()
|
||||||
self.playing = True
|
self.playing = True
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
|
||||||
self.exit_event.set()
|
self.exit_event.set()
|
||||||
|
|
||||||
if self.play:
|
if self.play:
|
||||||
self.output_stream.stop_stream()
|
self.output_stream.stop_stream()
|
||||||
self.output_stream.close()
|
self.output_stream.close()
|
||||||
if self.playing:
|
if self.playing:
|
||||||
self.process_play_frame.join()
|
self.process_play_frame.join()
|
||||||
self.playing = False
|
self.playing = False
|
||||||
|
|
||||||
if self.mode == 'live':
|
if self.mode == 'live':
|
||||||
#self.input_stream.stop_stream() todo
|
#self.input_stream.stop_stream() todo
|
||||||
self.input_stream.close()
|
self.input_stream.close()
|
||||||
if self.listening:
|
if self.listening:
|
||||||
self.process_read_frame.join()
|
self.process_read_frame.join()
|
||||||
self.listening = False
|
self.listening = False
|
||||||
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
if self.mode == 'live':
|
if self.mode == 'live':
|
||||||
# live mode: also print the result text.
|
# live mode: also print the result text.
|
||||||
self.text += '\n[END]'
|
self.text += '\n[END]'
|
||||||
print(self.text)
|
print(self.text)
|
||||||
|
|
||||||
def get_next_feat(self):
|
def get_next_feat(self):
|
||||||
# return a [1/8, 16] window, for the next input to nerf side.
|
# return a [1/8, 16] window, for the next input to nerf side.
|
||||||
|
|
||||||
while len(self.att_feats) < 8:
|
while len(self.att_feats) < 8:
|
||||||
# [------f+++t-----]
|
# [------f+++t-----]
|
||||||
if self.front < self.tail:
|
if self.front < self.tail:
|
||||||
feat = self.feat_queue[self.front:self.tail]
|
feat = self.feat_queue[self.front:self.tail]
|
||||||
# [++t-----------f+]
|
# [++t-----------f+]
|
||||||
else:
|
else:
|
||||||
feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
|
feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
|
||||||
|
|
||||||
self.front = (self.front + 2) % self.feat_queue.shape[0]
|
self.front = (self.front + 2) % self.feat_queue.shape[0]
|
||||||
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
|
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
|
||||||
|
|
||||||
# print(self.front, self.tail, feat.shape)
|
# print(self.front, self.tail, feat.shape)
|
||||||
|
|
||||||
self.att_feats.append(feat.permute(1, 0))
|
self.att_feats.append(feat.permute(1, 0))
|
||||||
|
|
||||||
att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
|
att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
|
||||||
|
|
||||||
# discard old
|
# discard old
|
||||||
self.att_feats = self.att_feats[1:]
|
self.att_feats = self.att_feats[1:]
|
||||||
|
|
||||||
return att_feat
|
return att_feat
|
||||||
|
|
||||||
def run_step(self):
|
def run_step(self):
|
||||||
|
|
||||||
if self.terminated:
|
if self.terminated:
|
||||||
return
|
return
|
||||||
|
|
||||||
# get a frame of audio
|
# get a frame of audio
|
||||||
frame = self.get_audio_frame()
|
frame = self.get_audio_frame()
|
||||||
|
|
||||||
# the last frame
|
# the last frame
|
||||||
if frame is None:
|
if frame is None:
|
||||||
# terminate, but always run the network for the left frames
|
# terminate, but always run the network for the left frames
|
||||||
self.terminated = True
|
self.terminated = True
|
||||||
else:
|
else:
|
||||||
self.frames.append(frame)
|
self.frames.append(frame)
|
||||||
# put to output
|
# put to output
|
||||||
#if self.play:
|
#if self.play:
|
||||||
self.output_queue.put(frame)
|
self.output_queue.put(frame)
|
||||||
# context not enough, do not run network.
|
# context not enough, do not run network.
|
||||||
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
|
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
|
||||||
return
|
return
|
||||||
|
|
||||||
inputs = np.concatenate(self.frames) # [N * chunk]
|
inputs = np.concatenate(self.frames) # [N * chunk]
|
||||||
|
|
||||||
# discard the old part to save memory
|
# discard the old part to save memory
|
||||||
if not self.terminated:
|
if not self.terminated:
|
||||||
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
||||||
|
|
||||||
print(f'[INFO] frame_to_text... ')
|
print(f'[INFO] frame_to_text... ')
|
||||||
logits, labels, text = self.frame_to_text(inputs)
|
logits, labels, text = self.frame_to_text(inputs)
|
||||||
feats = logits # better lips-sync than labels
|
feats = logits # better lips-sync than labels
|
||||||
|
|
||||||
# save feats
|
# save feats
|
||||||
if self.opt.asr_save_feats:
|
if self.opt.asr_save_feats:
|
||||||
self.all_feats.append(feats)
|
self.all_feats.append(feats)
|
||||||
|
|
||||||
# record the feats efficiently.. (no concat, constant memory)
|
# record the feats efficiently.. (no concat, constant memory)
|
||||||
start = self.feat_buffer_idx * self.context_size
|
start = self.feat_buffer_idx * self.context_size
|
||||||
end = start + feats.shape[0]
|
end = start + feats.shape[0]
|
||||||
self.feat_queue[start:end] = feats
|
self.feat_queue[start:end] = feats
|
||||||
self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size
|
self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size
|
||||||
|
|
||||||
# very naive, just concat the text output.
|
# very naive, just concat the text output.
|
||||||
if text != '':
|
if text != '':
|
||||||
self.text = self.text + ' ' + text
|
self.text = self.text + ' ' + text
|
||||||
|
|
||||||
# will only run once at ternimation
|
# will only run once at ternimation
|
||||||
if self.terminated:
|
if self.terminated:
|
||||||
self.text += '\n[END]'
|
self.text += '\n[END]'
|
||||||
print(self.text)
|
print(self.text)
|
||||||
if self.opt.asr_save_feats:
|
if self.opt.asr_save_feats:
|
||||||
print(f'[INFO] save all feats for training purpose... ')
|
print(f'[INFO] save all feats for training purpose... ')
|
||||||
feats = torch.cat(self.all_feats, dim=0) # [N, C]
|
feats = torch.cat(self.all_feats, dim=0) # [N, C]
|
||||||
# print('[INFO] before unfold', feats.shape)
|
# print('[INFO] before unfold', feats.shape)
|
||||||
window_size = 16
|
window_size = 16
|
||||||
padding = window_size // 2
|
padding = window_size // 2
|
||||||
feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
|
feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
|
||||||
feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
|
feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
|
||||||
unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
|
unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
|
||||||
unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
|
unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
|
||||||
# print('[INFO] after unfold', unfold_feats.shape)
|
# print('[INFO] after unfold', unfold_feats.shape)
|
||||||
# save to a npy file
|
# save to a npy file
|
||||||
if 'esperanto' in self.opt.asr_model:
|
if 'esperanto' in self.opt.asr_model:
|
||||||
output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
|
output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
|
||||||
else:
|
else:
|
||||||
output_path = self.opt.asr_wav.replace('.wav', '.npy')
|
output_path = self.opt.asr_wav.replace('.wav', '.npy')
|
||||||
np.save(output_path, unfold_feats.cpu().numpy())
|
np.save(output_path, unfold_feats.cpu().numpy())
|
||||||
print(f"[INFO] saved logits to {output_path}")
|
print(f"[INFO] saved logits to {output_path}")
|
||||||
|
|
||||||
def create_file_stream(self):
|
def create_file_stream(self):
|
||||||
|
|
||||||
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
|
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
|
||||||
stream = stream.astype(np.float32)
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
if stream.ndim > 1:
|
if stream.ndim > 1:
|
||||||
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||||
stream = stream[:, 0]
|
stream = stream[:, 0]
|
||||||
|
|
||||||
if sample_rate != self.sample_rate:
|
if sample_rate != self.sample_rate:
|
||||||
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
||||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
||||||
|
|
||||||
print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
|
print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
|
|
||||||
def create_pyaudio_stream(self):
|
def create_pyaudio_stream(self):
|
||||||
|
|
||||||
import pyaudio
|
import pyaudio
|
||||||
|
|
||||||
print(f'[INFO] creating live audio stream ...')
|
print(f'[INFO] creating live audio stream ...')
|
||||||
|
|
||||||
audio = pyaudio.PyAudio()
|
audio = pyaudio.PyAudio()
|
||||||
|
|
||||||
# get devices
|
# get devices
|
||||||
info = audio.get_host_api_info_by_index(0)
|
info = audio.get_host_api_info_by_index(0)
|
||||||
n_devices = info.get('deviceCount')
|
n_devices = info.get('deviceCount')
|
||||||
|
|
||||||
for i in range(0, n_devices):
|
for i in range(0, n_devices):
|
||||||
if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
|
if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
|
||||||
name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
|
name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
|
||||||
print(f'[INFO] choose audio device {name}, id {i}')
|
print(f'[INFO] choose audio device {name}, id {i}')
|
||||||
break
|
break
|
||||||
|
|
||||||
# get stream
|
# get stream
|
||||||
stream = audio.open(input_device_index=i,
|
stream = audio.open(input_device_index=i,
|
||||||
format=pyaudio.paInt16,
|
format=pyaudio.paInt16,
|
||||||
channels=1,
|
channels=1,
|
||||||
rate=self.sample_rate,
|
rate=self.sample_rate,
|
||||||
input=True,
|
input=True,
|
||||||
frames_per_buffer=self.chunk)
|
frames_per_buffer=self.chunk)
|
||||||
|
|
||||||
return audio, stream
|
return audio, stream
|
||||||
|
|
||||||
|
|
||||||
def get_audio_frame(self):
|
def get_audio_frame(self):
|
||||||
|
|
||||||
if self.inwarm: # warm up
|
if self.inwarm: # warm up
|
||||||
return np.zeros(self.chunk, dtype=np.float32)
|
return np.zeros(self.chunk, dtype=np.float32)
|
||||||
|
|
||||||
if self.mode == 'file':
|
if self.mode == 'file':
|
||||||
|
|
||||||
if self.idx < self.file_stream.shape[0]:
|
if self.idx < self.file_stream.shape[0]:
|
||||||
frame = self.file_stream[self.idx: self.idx + self.chunk]
|
frame = self.file_stream[self.idx: self.idx + self.chunk]
|
||||||
self.idx = self.idx + self.chunk
|
self.idx = self.idx + self.chunk
|
||||||
return frame
|
return frame
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
frame = self.queue.get(block=False)
|
frame = self.queue.get(block=False)
|
||||||
print(f'[INFO] get frame {frame.shape}')
|
print(f'[INFO] get frame {frame.shape}')
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
frame = np.zeros(self.chunk, dtype=np.float32)
|
frame = np.zeros(self.chunk, dtype=np.float32)
|
||||||
|
|
||||||
self.idx = self.idx + self.chunk
|
self.idx = self.idx + self.chunk
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
|
||||||
def frame_to_text(self, frame):
|
def frame_to_text(self, frame):
|
||||||
# frame: [N * 320], N = (context_size + 2 * stride_size)
|
# frame: [N * 320], N = (context_size + 2 * stride_size)
|
||||||
|
|
||||||
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
|
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
result = self.model(inputs.input_values.to(self.device))
|
result = self.model(inputs.input_values.to(self.device))
|
||||||
logits = result.logits # [1, N - 1, 32]
|
logits = result.logits # [1, N - 1, 32]
|
||||||
|
|
||||||
# cut off stride
|
# cut off stride
|
||||||
left = max(0, self.stride_left_size)
|
left = max(0, self.stride_left_size)
|
||||||
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
|
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
|
||||||
|
|
||||||
# do not cut right if terminated.
|
# do not cut right if terminated.
|
||||||
if self.terminated:
|
if self.terminated:
|
||||||
right = logits.shape[1]
|
right = logits.shape[1]
|
||||||
|
|
||||||
logits = logits[:, left:right]
|
logits = logits[:, left:right]
|
||||||
|
|
||||||
# print(frame.shape, inputs.input_values.shape, logits.shape)
|
# print(frame.shape, inputs.input_values.shape, logits.shape)
|
||||||
|
|
||||||
predicted_ids = torch.argmax(logits, dim=-1)
|
predicted_ids = torch.argmax(logits, dim=-1)
|
||||||
transcription = self.processor.batch_decode(predicted_ids)[0].lower()
|
transcription = self.processor.batch_decode(predicted_ids)[0].lower()
|
||||||
|
|
||||||
|
|
||||||
# for esperanto
|
# for esperanto
|
||||||
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
|
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
|
||||||
|
|
||||||
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
|
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
|
||||||
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
|
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
|
||||||
# print(predicted_ids[0])
|
# print(predicted_ids[0])
|
||||||
# print(transcription)
|
# print(transcription)
|
||||||
|
|
||||||
return logits[0], predicted_ids[0], transcription # [N,]
|
return logits[0], predicted_ids[0], transcription # [N,]
|
||||||
|
|
||||||
def create_bytes_stream(self,byte_stream):
|
def create_bytes_stream(self,byte_stream):
|
||||||
#byte_stream=BytesIO(buffer)
|
#byte_stream=BytesIO(buffer)
|
||||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||||
stream = stream.astype(np.float32)
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
if stream.ndim > 1:
|
if stream.ndim > 1:
|
||||||
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||||
stream = stream[:, 0]
|
stream = stream[:, 0]
|
||||||
|
|
||||||
if sample_rate != self.sample_rate:
|
if sample_rate != self.sample_rate:
|
||||||
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
||||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
def push_audio(self,buffer):
|
def push_audio(self,buffer):
|
||||||
print(f'[INFO] push_audio {len(buffer)}')
|
print(f'[INFO] push_audio {len(buffer)}')
|
||||||
# if len(buffer)>0:
|
# if len(buffer)>0:
|
||||||
# byte_stream=BytesIO(buffer)
|
# byte_stream=BytesIO(buffer)
|
||||||
# stream = self.create_bytes_stream(byte_stream)
|
# stream = self.create_bytes_stream(byte_stream)
|
||||||
# streamlen = stream.shape[0]
|
# streamlen = stream.shape[0]
|
||||||
# idx=0
|
# idx=0
|
||||||
# while streamlen >= self.chunk:
|
# while streamlen >= self.chunk:
|
||||||
# self.queue.put(stream[idx:idx+self.chunk])
|
# self.queue.put(stream[idx:idx+self.chunk])
|
||||||
# streamlen -= self.chunk
|
# streamlen -= self.chunk
|
||||||
# idx += self.chunk
|
# idx += self.chunk
|
||||||
# if streamlen>0:
|
# if streamlen>0:
|
||||||
# self.queue.put(stream[idx:])
|
# self.queue.put(stream[idx:])
|
||||||
self.input_stream.write(buffer)
|
self.input_stream.write(buffer)
|
||||||
if len(buffer)<=0:
|
if len(buffer)<=0:
|
||||||
self.input_stream.seek(0)
|
self.input_stream.seek(0)
|
||||||
stream = self.create_bytes_stream(self.input_stream)
|
stream = self.create_bytes_stream(self.input_stream)
|
||||||
streamlen = stream.shape[0]
|
streamlen = stream.shape[0]
|
||||||
idx=0
|
idx=0
|
||||||
while streamlen >= self.chunk:
|
while streamlen >= self.chunk:
|
||||||
self.queue.put(stream[idx:idx+self.chunk])
|
self.queue.put(stream[idx:idx+self.chunk])
|
||||||
streamlen -= self.chunk
|
streamlen -= self.chunk
|
||||||
idx += self.chunk
|
idx += self.chunk
|
||||||
if streamlen>0:
|
if streamlen>0:
|
||||||
self.queue.put(stream[idx:])
|
self.queue.put(stream[idx:])
|
||||||
|
self.input_stream.seek(0)
|
||||||
def get_audio_out(self):
|
self.input_stream.truncate()
|
||||||
return self.output_queue.get()
|
|
||||||
|
def get_audio_out(self):
|
||||||
def run(self):
|
return self.output_queue.get()
|
||||||
|
|
||||||
self.listen()
|
def run(self):
|
||||||
|
|
||||||
while not self.terminated:
|
self.listen()
|
||||||
self.run_step()
|
|
||||||
|
while not self.terminated:
|
||||||
def clear_queue(self):
|
self.run_step()
|
||||||
# clear the queue, to reduce potential latency...
|
|
||||||
print(f'[INFO] clear queue')
|
def clear_queue(self):
|
||||||
if self.mode == 'live':
|
# clear the queue, to reduce potential latency...
|
||||||
self.queue.queue.clear()
|
print(f'[INFO] clear queue')
|
||||||
if self.play:
|
if self.mode == 'live':
|
||||||
self.output_queue.queue.clear()
|
self.queue.queue.clear()
|
||||||
|
if self.play:
|
||||||
def warm_up(self):
|
self.output_queue.queue.clear()
|
||||||
|
|
||||||
#self.listen()
|
def warm_up(self):
|
||||||
|
|
||||||
self.inwarm = True
|
#self.listen()
|
||||||
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
|
|
||||||
t = time.time()
|
self.inwarm = True
|
||||||
for _ in range(self.warm_up_steps):
|
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
|
||||||
self.run_step()
|
t = time.time()
|
||||||
if torch.cuda.is_available():
|
for _ in range(self.warm_up_steps):
|
||||||
torch.cuda.synchronize()
|
self.run_step()
|
||||||
t = time.time() - t
|
if torch.cuda.is_available():
|
||||||
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
|
torch.cuda.synchronize()
|
||||||
self.inwarm = False
|
t = time.time() - t
|
||||||
|
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
|
||||||
#self.clear_queue()
|
self.inwarm = False
|
||||||
|
|
||||||
|
#self.clear_queue()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import argparse
|
|
||||||
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
import argparse
|
||||||
parser.add_argument('--wav', type=str, default='')
|
|
||||||
parser.add_argument('--play', action='store_true', help="play out the audio")
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--wav', type=str, default='')
|
||||||
parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
|
parser.add_argument('--play', action='store_true', help="play out the audio")
|
||||||
# parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
|
|
||||||
|
parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
|
||||||
parser.add_argument('--save_feats', action='store_true')
|
# parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
|
||||||
# audio FPS
|
|
||||||
parser.add_argument('--fps', type=int, default=50)
|
parser.add_argument('--save_feats', action='store_true')
|
||||||
# sliding window left-middle-right length.
|
# audio FPS
|
||||||
parser.add_argument('-l', type=int, default=10)
|
parser.add_argument('--fps', type=int, default=50)
|
||||||
parser.add_argument('-m', type=int, default=50)
|
# sliding window left-middle-right length.
|
||||||
parser.add_argument('-r', type=int, default=10)
|
parser.add_argument('-l', type=int, default=10)
|
||||||
|
parser.add_argument('-m', type=int, default=50)
|
||||||
opt = parser.parse_args()
|
parser.add_argument('-r', type=int, default=10)
|
||||||
|
|
||||||
# fix
|
opt = parser.parse_args()
|
||||||
opt.asr_wav = opt.wav
|
|
||||||
opt.asr_play = opt.play
|
# fix
|
||||||
opt.asr_model = opt.model
|
opt.asr_wav = opt.wav
|
||||||
opt.asr_save_feats = opt.save_feats
|
opt.asr_play = opt.play
|
||||||
|
opt.asr_model = opt.model
|
||||||
if 'deepspeech' in opt.asr_model:
|
opt.asr_save_feats = opt.save_feats
|
||||||
raise ValueError("DeepSpeech features should not use this code to extract...")
|
|
||||||
|
if 'deepspeech' in opt.asr_model:
|
||||||
with ASR(opt) as asr:
|
raise ValueError("DeepSpeech features should not use this code to extract...")
|
||||||
|
|
||||||
|
with ASR(opt) as asr:
|
||||||
asr.run()
|
asr.run()
|
362
nerfreal.py
362
nerfreal.py
|
@ -1,178 +1,184 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
#from .utils import *
|
#from .utils import *
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from asrreal import ASR
|
|
||||||
from rtmp_streaming import StreamerConfig, Streamer
|
from asrreal import ASR
|
||||||
|
from rtmp_streaming import StreamerConfig, Streamer
|
||||||
class NeRFReal:
|
|
||||||
def __init__(self, opt, trainer, data_loader, debug=True):
|
class NeRFReal:
|
||||||
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
def __init__(self, opt, trainer, data_loader, debug=True):
|
||||||
self.W = opt.W
|
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
||||||
self.H = opt.H
|
self.W = opt.W
|
||||||
self.debug = debug
|
self.H = opt.H
|
||||||
self.training = False
|
self.debug = debug
|
||||||
self.step = 0 # training step
|
self.training = False
|
||||||
|
self.step = 0 # training step
|
||||||
self.trainer = trainer
|
|
||||||
self.data_loader = data_loader
|
self.trainer = trainer
|
||||||
|
self.data_loader = data_loader
|
||||||
# use dataloader's bg
|
|
||||||
bg_img = data_loader._data.bg_img #.view(1, -1, 3)
|
# use dataloader's bg
|
||||||
if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
|
bg_img = data_loader._data.bg_img #.view(1, -1, 3)
|
||||||
bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
|
if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
|
||||||
self.bg_color = bg_img.view(1, -1, 3)
|
bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
|
||||||
|
self.bg_color = bg_img.view(1, -1, 3)
|
||||||
# audio features (from dataloader, only used in non-playing mode)
|
|
||||||
self.audio_features = data_loader._data.auds # [N, 29, 16]
|
# audio features (from dataloader, only used in non-playing mode)
|
||||||
self.audio_idx = 0
|
self.audio_features = data_loader._data.auds # [N, 29, 16]
|
||||||
|
self.audio_idx = 0
|
||||||
# control eye
|
|
||||||
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
|
# control eye
|
||||||
|
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
|
||||||
# playing seq from dataloader, or pause.
|
|
||||||
self.playing = True #False todo
|
# playing seq from dataloader, or pause.
|
||||||
self.loader = iter(data_loader)
|
self.playing = True #False todo
|
||||||
|
self.loader = iter(data_loader)
|
||||||
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
|
|
||||||
self.need_update = True # camera moved, should reset accumulation
|
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
|
||||||
self.spp = 1 # sample per pixel
|
self.need_update = True # camera moved, should reset accumulation
|
||||||
self.mode = 'image' # choose from ['image', 'depth']
|
self.spp = 1 # sample per pixel
|
||||||
|
self.mode = 'image' # choose from ['image', 'depth']
|
||||||
self.dynamic_resolution = False # assert False!
|
|
||||||
self.downscale = 1
|
self.dynamic_resolution = False # assert False!
|
||||||
self.train_steps = 16
|
self.downscale = 1
|
||||||
|
self.train_steps = 16
|
||||||
self.ind_index = 0
|
|
||||||
self.ind_num = trainer.model.individual_codes.shape[0]
|
self.ind_index = 0
|
||||||
|
self.ind_num = trainer.model.individual_codes.shape[0]
|
||||||
# build asr
|
|
||||||
if self.opt.asr:
|
# build asr
|
||||||
self.asr = ASR(opt)
|
if self.opt.asr:
|
||||||
|
self.asr = ASR(opt)
|
||||||
fps=25
|
|
||||||
#push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
|
fps=25
|
||||||
sc = StreamerConfig()
|
#push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
|
||||||
sc.source_width = self.W
|
sc = StreamerConfig()
|
||||||
sc.source_height = self.H
|
sc.source_width = self.W
|
||||||
sc.stream_width = self.W
|
sc.source_height = self.H
|
||||||
sc.stream_height = self.H
|
sc.stream_width = self.W
|
||||||
sc.stream_fps = fps
|
sc.stream_height = self.H
|
||||||
sc.stream_bitrate = 1000000
|
sc.stream_fps = fps
|
||||||
sc.stream_profile = 'main' #'high444' # 'main'
|
sc.stream_bitrate = 1000000
|
||||||
sc.audio_channel = 1
|
sc.stream_profile = 'main' #'high444' # 'main'
|
||||||
sc.sample_rate = 16000
|
sc.audio_channel = 1
|
||||||
sc.stream_server = opt.push_url
|
sc.sample_rate = 16000
|
||||||
|
sc.stream_server = opt.push_url
|
||||||
self.streamer = Streamer()
|
|
||||||
self.streamer.init(sc)
|
self.streamer = Streamer()
|
||||||
self.streamer.enable_av_debug_log()
|
self.streamer.init(sc)
|
||||||
|
self.streamer.enable_av_debug_log()
|
||||||
'''
|
|
||||||
video_path = 'video_stream'
|
'''
|
||||||
if not os.path.exists(video_path):
|
video_path = 'video_stream'
|
||||||
os.mkfifo(video_path, mode=0o777)
|
if not os.path.exists(video_path):
|
||||||
audio_path = 'audio_stream'
|
os.mkfifo(video_path, mode=0o777)
|
||||||
if not os.path.exists(audio_path):
|
audio_path = 'audio_stream'
|
||||||
os.mkfifo(audio_path, mode=0o777)
|
if not os.path.exists(audio_path):
|
||||||
width=450
|
os.mkfifo(audio_path, mode=0o777)
|
||||||
height=450
|
width=450
|
||||||
command = ['ffmpeg',
|
height=450
|
||||||
'-y', #'-an',
|
command = ['ffmpeg',
|
||||||
#'-re',
|
'-y', #'-an',
|
||||||
'-f', 'rawvideo',
|
#'-re',
|
||||||
'-vcodec','rawvideo',
|
'-f', 'rawvideo',
|
||||||
'-pix_fmt', 'rgb24', #像素格式
|
'-vcodec','rawvideo',
|
||||||
'-s', "{}x{}".format(width, height),
|
'-pix_fmt', 'rgb24', #像素格式
|
||||||
'-r', str(fps),
|
'-s', "{}x{}".format(width, height),
|
||||||
'-i', video_path,
|
'-r', str(fps),
|
||||||
'-f', 's16le',
|
'-i', video_path,
|
||||||
'-acodec','pcm_s16le',
|
'-f', 's16le',
|
||||||
'-ac', '1',
|
'-acodec','pcm_s16le',
|
||||||
'-ar', '16000',
|
'-ac', '1',
|
||||||
'-i', audio_path,
|
'-ar', '16000',
|
||||||
#'-fflags', '+genpts',
|
'-i', audio_path,
|
||||||
'-map', '0:v',
|
#'-fflags', '+genpts',
|
||||||
'-map', '1:a',
|
'-map', '0:v',
|
||||||
#'-copyts',
|
'-map', '1:a',
|
||||||
'-acodec', 'aac',
|
#'-copyts',
|
||||||
'-pix_fmt', 'yuv420p', #'-vcodec', "h264",
|
'-acodec', 'aac',
|
||||||
#"-rtmp_buffer", "100",
|
'-pix_fmt', 'yuv420p', #'-vcodec', "h264",
|
||||||
'-f' , 'flv',
|
#"-rtmp_buffer", "100",
|
||||||
push_url]
|
'-f' , 'flv',
|
||||||
self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE)
|
push_url]
|
||||||
self.fifo_video = open(video_path, 'wb')
|
self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE)
|
||||||
self.fifo_audio = open(audio_path, 'wb')
|
self.fifo_video = open(video_path, 'wb')
|
||||||
#self.test_step()
|
self.fifo_audio = open(audio_path, 'wb')
|
||||||
'''
|
#self.test_step()
|
||||||
|
'''
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
def __enter__(self):
|
||||||
|
return self
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
if self.opt.asr:
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
self.asr.stop()
|
if self.opt.asr:
|
||||||
|
self.asr.stop()
|
||||||
def push_audio(self,chunk):
|
|
||||||
self.asr.push_audio(chunk)
|
def push_audio(self,chunk):
|
||||||
|
self.asr.push_audio(chunk)
|
||||||
def prepare_buffer(self, outputs):
|
|
||||||
if self.mode == 'image':
|
def prepare_buffer(self, outputs):
|
||||||
return outputs['image']
|
if self.mode == 'image':
|
||||||
else:
|
return outputs['image']
|
||||||
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
|
else:
|
||||||
|
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
|
||||||
def test_step(self):
|
|
||||||
|
def test_step(self):
|
||||||
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
|
||||||
starter.record()
|
#starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||||
|
#starter.record()
|
||||||
if self.playing:
|
|
||||||
try:
|
if self.playing:
|
||||||
data = next(self.loader)
|
try:
|
||||||
except StopIteration:
|
data = next(self.loader)
|
||||||
self.loader = iter(self.data_loader)
|
except StopIteration:
|
||||||
data = next(self.loader)
|
self.loader = iter(self.data_loader)
|
||||||
|
data = next(self.loader)
|
||||||
if self.opt.asr:
|
|
||||||
# use the live audio stream
|
if self.opt.asr:
|
||||||
data['auds'] = self.asr.get_next_feat()
|
# use the live audio stream
|
||||||
|
data['auds'] = self.asr.get_next_feat()
|
||||||
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
|
|
||||||
#print(f'[INFO] outputs shape ',outputs['image'].shape)
|
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
|
||||||
image = (outputs['image'] * 255).astype(np.uint8)
|
#print(f'[INFO] outputs shape ',outputs['image'].shape)
|
||||||
self.streamer.stream_frame(image)
|
image = (outputs['image'] * 255).astype(np.uint8)
|
||||||
#self.pipe.stdin.write(image.tostring())
|
self.streamer.stream_frame(image)
|
||||||
for _ in range(2):
|
#self.pipe.stdin.write(image.tostring())
|
||||||
frame = self.asr.get_audio_out()
|
for _ in range(2):
|
||||||
#print(f'[INFO] get_audio_out shape ',frame.shape)
|
frame = self.asr.get_audio_out()
|
||||||
self.streamer.stream_frame_audio(frame)
|
#print(f'[INFO] get_audio_out shape ',frame.shape)
|
||||||
# frame = (frame * 32767).astype(np.int16).tobytes()
|
self.streamer.stream_frame_audio(frame)
|
||||||
# self.fifo_audio.write(frame)
|
# frame = (frame * 32767).astype(np.int16).tobytes()
|
||||||
else:
|
# self.fifo_audio.write(frame)
|
||||||
if self.audio_features is not None:
|
else:
|
||||||
auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
|
if self.audio_features is not None:
|
||||||
else:
|
auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
|
||||||
auds = None
|
else:
|
||||||
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
|
auds = None
|
||||||
|
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
|
||||||
ender.record()
|
|
||||||
torch.cuda.synchronize()
|
#ender.record()
|
||||||
t = starter.elapsed_time(ender)
|
#torch.cuda.synchronize()
|
||||||
|
#t = starter.elapsed_time(ender)
|
||||||
def render(self):
|
|
||||||
if self.opt.asr:
|
def render(self):
|
||||||
self.asr.warm_up()
|
if self.opt.asr:
|
||||||
while True: #todo
|
self.asr.warm_up()
|
||||||
# update texture every frame
|
while True: #todo
|
||||||
# audio stream thread...
|
# update texture every frame
|
||||||
if self.opt.asr and self.playing:
|
# audio stream thread...
|
||||||
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
|
t = time.time()
|
||||||
for _ in range(2):
|
if self.opt.asr and self.playing:
|
||||||
self.asr.run_step()
|
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
|
||||||
self.test_step()
|
for _ in range(2):
|
||||||
|
self.asr.run_step()
|
||||||
|
self.test_step()
|
||||||
|
delay = 0.04 - (time.time() - t) #40ms
|
||||||
|
if delay > 0:
|
||||||
|
time.sleep(delay)
|
||||||
|
|
Loading…
Reference in New Issue