livetalking/nerfasr.py

348 lines
14 KiB
Python
Raw Normal View History

2024-01-09 10:01:50 +08:00
import time
import numpy as np
import torch
import torch.nn.functional as F
2024-03-23 18:15:35 +08:00
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
2024-01-09 10:01:50 +08:00
import queue
from queue import Queue
#from collections import deque
from threading import Thread, Event
from baseasr import BaseASR
2024-01-09 10:01:50 +08:00
2024-08-03 12:58:49 +08:00
class NerfASR(BaseASR):
    """Streaming audio-feature extractor that feeds audio embeddings to the NeRF side.

    Each call to ``run_step`` pulls one audio chunk. The chunks are grouped into a
    sliding window of ``stride_left + context + stride_right`` chunks and run
    through a wav2vec2-style (``AutoModelForCTC``) or HuBERT model. The
    stride-trimmed outputs are written into the ring buffer ``feat_queue``.
    ``get_next_feat`` then reads fixed-size windows from that ring buffer.
    """

    def __init__(self, opt, parent):
        """Load the ASR model and allocate the feature buffers.

        Args:
            opt: options object. Reads ``asr_model`` (model name/path; it also
                selects the feature width), ``m`` (context chunks), ``l`` (left
                stride chunks) and ``r`` (right stride chunks). ``att`` is read
                later by ``get_next_feat``.
            parent: owner object passed to ``BaseASR``. ``get_audio_frame``
                reads its ``curr_state`` and calls its ``get_audio_stream``.
        """
        super().__init__(opt, parent)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Width of one feature vector, chosen by which model produces it.
        if 'esperanto' in self.opt.asr_model:
            self.audio_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
            self.audio_dim = 29
        elif 'hubert' in self.opt.asr_model:
            self.audio_dim = 1024
        else:
            self.audio_dim = 32

        # prepare context cache
        # each segment is (stride_left + ctx + stride_right) * 20ms,
        # latency should be (ctx + stride_right) * 20ms
        self.context_size = opt.m
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r

        # Pad the left side with zero (silent) chunks so that the first window
        # already has its left context.
        if self.stride_left_size > 0:
            self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)

        # create the wav2vec / hubert model
        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
        if 'hubert' in self.opt.asr_model:
            self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
            self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
        else:
            self.processor = AutoProcessor.from_pretrained(opt.asr_model)
            self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

        # The extracted features are kept in a loop queue (ring buffer) of
        # feat_buffer_size slots, each context_size rows long, so endless
        # features use constant memory: [f--t---][-------][-------]
        self.feat_buffer_size = 4
        self.feat_buffer_idx = 0
        self.feat_queue = torch.zeros(
            self.feat_buffer_size * self.context_size, self.audio_dim,
            dtype=torch.float32, device=self.device,
        )

        # TODO: hard coded 16 and 8 window size...
        # The read window [front, tail) starts 8 rows before row 0 (wrapping to
        # the end of the ring buffer) as fake padding.
        self.front = self.feat_buffer_size * self.context_size - 8
        self.tail = 8
        # Attention window, pre-filled with 4 zero windows. All four entries are
        # the same tensor object; this class never modifies it in place.
        self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4

        # warm up steps needed: left + mid + right
        self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size

    def get_audio_frame(self):
        """Return the next audio chunk together with its source type.

        Returns:
            ``(frame, frame_type)``. ``frame_type`` is:
              - 0 for audio taken from ``self.queue``;
              - ``parent.curr_state`` (> 1) for custom audio supplied by the parent;
              - 1 for a silent (all-zero) chunk when no audio is available.
        """
        try:
            frame = self.queue.get(block=False)
        except queue.Empty:
            if self.parent and self.parent.curr_state > 1:  # play custom audio
                frame = self.parent.get_audio_stream(self.parent.curr_state)
                frame_type = self.parent.curr_state
            else:
                frame = np.zeros(self.chunk, dtype=np.float32)
                frame_type = 1
        else:
            frame_type = 0
        return frame, frame_type

    def _next_window(self):
        """Read rows [front, tail) of the feature ring buffer and advance the window by 2 rows.

        Returns:
            A ``[tail - front, audio_dim]`` tensor (16 rows with the initial
            offsets). It wraps around the end of ``feat_queue`` when needed.
        """
        # [------f+++t-----]
        if self.front < self.tail:
            feat = self.feat_queue[self.front:self.tail]
        # [++t-----------f+]
        else:
            feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
        ring_size = self.feat_queue.shape[0]
        self.front = (self.front + 2) % ring_size
        self.tail = (self.tail + 2) % ring_size
        return feat

    def get_next_feat(self):  # get audio embedding to nerf
        """Return the next audio-feature input for the NeRF side.

        Returns:
            - With ``opt.att > 0``: the last 8 windows stacked, shape
              ``[8, audio_dim, 16]``.
            - Otherwise: a single window, shape ``[1, audio_dim, 16]``.
        """
        if self.opt.att > 0:
            while len(self.att_feats) < 8:
                self.att_feats.append(self._next_window().permute(1, 0))
            att_feat = torch.stack(self.att_feats, dim=0)  # [8, audio_dim, 16]
            # discard the oldest window
            self.att_feats = self.att_feats[1:]
        else:
            att_feat = self._next_window().permute(1, 0).unsqueeze(0)
        return att_feat

    def run_step(self):
        """Consume one audio chunk and, once enough context is buffered, extract features.

        Every chunk is forwarded to ``output_queue`` as ``(frame, frame_type)``.

        When ``frames`` holds ``stride_left + context + stride_right`` chunks:
          - the model runs on their concatenation;
          - the stride-trimmed outputs are written to the next slot of ``feat_queue``.
        """
        frame, frame_type = self.get_audio_frame()
        self.frames.append(frame)
        # put to output
        self.output_queue.put((frame, frame_type))
        # context not enough, do not run network.
        if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
            return

        inputs = np.concatenate(self.frames)  # [N * chunk]
        # Discard the old part to save memory. Keep only the stride chunks that
        # overlap with the next window. keep == 0 must be special-cased, because
        # frames[-0:] would keep the WHOLE list instead of emptying it.
        keep = self.stride_left_size + self.stride_right_size
        self.frames = self.frames[-keep:] if keep > 0 else []

        logits, _labels, _text = self.__frame_to_text(inputs)
        feats = logits  # better lips-sync than labels

        # record the feats efficiently (no concat, constant memory)
        start = self.feat_buffer_idx * self.context_size
        end = start + feats.shape[0]
        self.feat_queue[start:end] = feats
        self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size

    def __frame_to_text(self, frame):
        """Run the ASR model on one concatenated audio window.

        Args:
            frame: 1-D audio samples, ``N * chunk`` long, where
                ``N = stride_left + context + stride_right``.

        Returns:
            ``(features, None, None)``. ``features`` holds the model outputs of
            the single batch item, with the left/right stride frames cut off.
            The two ``None`` slots stand in for predicted ids and transcription;
            text decoding is disabled.
        """
        inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
        with torch.no_grad():
            result = self.model(inputs.input_values.to(self.device))
        if 'hubert' in self.opt.asr_model:
            logits = result.last_hidden_state  # [B=1, T=pts//320, hid=1024]
        else:
            logits = result.logits  # [1, N - 1, 32]

        # cut off stride
        left = max(0, self.stride_left_size)
        # +1 to make sure output is the same length as input.
        right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1)
        logits = logits[:, left:right]

        return logits[0], None, None

    def warm_up(self):
        """Run ``warm_up_steps`` steps to prime the pipeline.

        Prints the expected latency and the measured wall-clock time.
        """
        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
        started = time.time()
        for _ in range(self.warm_up_steps):
            self.run_step()
        elapsed = time.time() - started
        print(f'[INFO] warm-up done, actual latency = {elapsed:.6f}s')
# NOTE: An unused block of legacy code used to sit here inside a triple-quoted
# string. It contained:
#   - the helpers __init_queue, run, clear_queue, listen, stop,
#     __enter__/__exit__, _read_frame and _play_frame;
#   - a standalone `__main__` CLI.
# It was never executed and referred to names this file does not define
# (e.g. the `ASR` class), so it was removed. Recover it from version control
# history if it is needed again.