optimize code, add timing logs

This commit is contained in:
lihengzhong 2024-01-13 17:12:25 +08:00
parent 233d46f45e
commit 49d7446933
5 changed files with 23 additions and 11 deletions

app.py

@@ -47,7 +47,7 @@ def txt_to_audio(text_):
     text = text_
     t = time.time()
     asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
-    print('-------tts time: ',time.time()-t)
+    print(f'-------tts time:{time.time()-t:.4f}s')
 
 @sockets.route('/humanecho')
 def echo_socket(ws):
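Note: the logging change above is a simple wall-clock delta formatted to four decimals. A minimal standalone sketch of the pattern (the sleep is a stand-in for the timed TTS round-trip):

import time

t = time.time()
time.sleep(0.1)  # stand-in for the timed call (here, the TTS round-trip)
print(f'-------tts time:{time.time()-t:.4f}s')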

(file name not shown)

@@ -116,7 +116,7 @@ class ASR:
         self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
         # warm up steps needed: mid + right + window_size + attention_size
-        self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
+        self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3
         self.listening = False
         self.playing = False
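Note: the warm-up count is now derived from the configured strides rather than the hard-coded `8 + 2 * 3` (which presumably covered window and attention sizes). A sketch of the arithmetic with hypothetical sizes; the real values come from the ASR options:

# Hypothetical sizes for illustration only; real values come from opt.
context_size = 10        # mid window
stride_left_size = 10    # left stride
stride_right_size = 10   # right stride

warm_up_steps_old = context_size + stride_right_size + 8 + 2 * 3          # 34
warm_up_steps_new = context_size + stride_right_size + stride_left_size   # 30
print(warm_up_steps_old, warm_up_steps_new)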
@@ -204,7 +204,6 @@ class ASR:
         else:
             self.frames.append(frame)
         # put to output
-        #if self.play:
         self.output_queue.put(frame)
         # context not enough, do not run network.
         if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
@@ -217,7 +216,9 @@ class ASR:
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
 
         print(f'[INFO] frame_to_text... ')
+        #t = time.time()
         logits, labels, text = self.frame_to_text(inputs)
+        #print(f'-------wav2vec time:{time.time()-t:.4f}s')
         feats = logits # better lips-sync than labels
 
         # save feats
@@ -257,6 +258,7 @@ class ASR:
             np.save(output_path, unfold_feats.cpu().numpy())
             print(f"[INFO] saved logits to {output_path}")
 
+    '''
     def create_file_stream(self):
         stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
@@ -302,7 +304,7 @@ class ASR:
                         frames_per_buffer=self.chunk)
 
         return audio, stream
+    '''
 
     def get_audio_frame(self):
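Note: the paired `'''` markers added in the two hunks above wrap the file- and pyaudio-stream helpers in a bare string literal, disabling them without deleting the code. A tiny runnable illustration of why this works:

class Demo:
    x = 1
    '''
    def disabled(self):
        return "never defined"
    '''
    def active(self):
        return "still defined"

print(hasattr(Demo, 'disabled'), hasattr(Demo, 'active'))  # False True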
@@ -351,8 +353,8 @@ class ASR:
             # print(frame.shape, inputs.input_values.shape, logits.shape)
-            predicted_ids = torch.argmax(logits, dim=-1)
-            transcription = self.processor.batch_decode(predicted_ids)[0].lower()
+            #predicted_ids = torch.argmax(logits, dim=-1)
+            #transcription = self.processor.batch_decode(predicted_ids)[0].lower()
 
             # for esperanto
@@ -363,7 +365,7 @@ class ASR:
             # print(predicted_ids[0])
             # print(transcription)
 
-        return logits[0], predicted_ids[0], transcription # [N,]
+        return logits[0], None, None #predicted_ids[0], transcription # [N,]
 
     def create_bytes_stream(self,byte_stream):
         #byte_stream=BytesIO(buffer)
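Note: since frame_to_text now returns None for the predicted ids and the transcription, any call site that still unpacks three values must tolerate the placeholders. A stub mirroring the new contract (names and values here are illustrative):

def frame_to_text_stub(inputs):
    logits = [0.2, 0.8]       # fake logits; the real code returns a tensor
    return logits, None, None

feats, _, _ = frame_to_text_stub(None)  # downstream only consumes the logits
print(feats)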
@@ -404,8 +406,8 @@ class ASR:
             self.queue.put(stream[idx:idx+self.chunk])
             streamlen -= self.chunk
             idx += self.chunk
-        if streamlen>0:
-            self.queue.put(stream[idx:])
+        #if streamlen>0: #skip last frame (not 20ms)
+        #    self.queue.put(stream[idx:])
         self.input_stream.seek(0)
         self.input_stream.truncate()
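Note: with the tail push commented out, anything shorter than one 20 ms chunk at the end of the stream is now dropped rather than queued. A standalone sketch of the loop's new behavior (a plain list stands in for the queue; sizes are illustrative):

sample_rate = 16000
chunk = sample_rate // 50                    # 320 samples = 20 ms
stream = [0.0] * (5 * chunk + 100)           # 100 trailing samples

queue, idx, streamlen = [], 0, len(stream)
while streamlen >= chunk:
    queue.append(stream[idx:idx + chunk])
    streamlen -= chunk
    idx += chunk
# the trailing 100 samples (< 20 ms) are intentionally discarded now

print(len(queue), 'chunks queued;', streamlen, 'samples dropped')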

Binary file not shown.

(file name not shown)

@@ -22,7 +22,7 @@
         <div id="log">
         </div>
-        <video id="video_player" width="40%" autoplay controls></video>
+        <video id="video_player" width="40%" controls autoplay muted></video>
     </div>
 </body>
 <script type="text/javascript" charset="utf-8">
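Note: the `muted` attribute matters here. Chrome and most other modern browsers block unmuted video autoplay until the user interacts with the page, so without `muted` the player would not start streaming on page load.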

(file name not shown)

@@ -144,7 +144,9 @@ class NeRFReal:
             # use the live audio stream
             data['auds'] = self.asr.get_next_feat()
 
+            #t = time.time()
             outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
+            #print('-------ernerf time: ',time.time()-t)
             #print(f'[INFO] outputs shape ',outputs['image'].shape)
             image = (outputs['image'] * 255).astype(np.uint8)
             self.streamer.stream_frame(image)
@@ -168,7 +170,9 @@ class NeRFReal:
     def render(self):
         if self.opt.asr:
             self.asr.warm_up()
+        count = 0
+        totaltime = 0
         while True: #todo
             # update texture every frame
             # audio stream thread...
@@ -178,6 +182,12 @@ class NeRFReal:
                 for _ in range(2):
                     self.asr.run_step()
             self.test_step()
+            totaltime += (time.time() - t)
+            count += 1
+            if count == 100:
+                print(f"------actual avg fps:{count/totaltime:.4f}")
+                count = 0
+                totaltime = 0
             # delay = 0.04 - (time.time() - t) #40ms
             # if delay > 0:
             #     time.sleep(delay)
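Note: the new counters compute an average FPS over 100-iteration windows, then reset. The same accumulator pattern as a standalone sketch (the sleep stands in for the asr.run_step() and test_step() work):

import time

count, totaltime = 0, 0.0
for _ in range(250):
    t = time.time()
    time.sleep(0.001)                 # stand-in for run_step() + test_step()
    totaltime += time.time() - t
    count += 1
    if count == 100:                  # report and reset every 100 frames
        print(f"------actual avg fps:{count/totaltime:.4f}")
        count, totaltime = 0, 0.0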