Optimize code and add some logging

This commit is contained in:
lihengzhong 2024-01-13 17:12:25 +08:00
parent 233d46f45e
commit 49d7446933
5 changed files with 23 additions and 11 deletions

2
app.py
View File

@@ -47,7 +47,7 @@ def txt_to_audio(text_):
text = text_
t = time.time()
asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
print('-------tts time: ',time.time()-t)
print(f'-------tts time:{time.time()-t:.4f}s')
@sockets.route('/humanecho')
def echo_socket(ws):

View File

@@ -116,7 +116,7 @@ class ASR:
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
# warm up steps needed: mid + right + window_size + attention_size
self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3
self.listening = False
self.playing = False
@@ -204,7 +204,6 @@ class ASR:
else:
self.frames.append(frame)
# put to output
#if self.play:
self.output_queue.put(frame)
# context not enough, do not run network.
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
@@ -217,7 +216,9 @@ class ASR:
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
print(f'[INFO] frame_to_text... ')
#t = time.time()
logits, labels, text = self.frame_to_text(inputs)
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
feats = logits # better lips-sync than labels
# save feats
@@ -257,6 +258,7 @@ class ASR:
np.save(output_path, unfold_feats.cpu().numpy())
print(f"[INFO] saved logits to {output_path}")
'''
def create_file_stream(self):
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
@@ -302,7 +304,7 @@ class ASR:
frames_per_buffer=self.chunk)
return audio, stream
'''
def get_audio_frame(self):
@@ -351,8 +353,8 @@ class ASR:
# print(frame.shape, inputs.input_values.shape, logits.shape)
predicted_ids = torch.argmax(logits, dim=-1)
transcription = self.processor.batch_decode(predicted_ids)[0].lower()
#predicted_ids = torch.argmax(logits, dim=-1)
#transcription = self.processor.batch_decode(predicted_ids)[0].lower()
# for esperanto
@@ -363,7 +365,7 @@ class ASR:
# print(predicted_ids[0])
# print(transcription)
return logits[0], predicted_ids[0], transcription # [N,]
return logits[0], None,None #predicted_ids[0], transcription # [N,]
def create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
@@ -404,8 +406,8 @@ class ASR:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
if streamlen>0:
self.queue.put(stream[idx:])
#if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self.input_stream.seek(0)
self.input_stream.truncate()

Binary file not shown.

View File

@@ -22,7 +22,7 @@
<div id="log">
</div>
<video id="video_player" width="40%" autoplay controls></video>
<video id="video_player" width="40%" controls autoplay muted></video>
</div>
</body>
<script type="text/javascript" charset="utf-8">

View File

@@ -144,7 +144,9 @@ class NeRFReal:
# use the live audio stream
data['auds'] = self.asr.get_next_feat()
#t = time.time()
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
#print('-------ernerf time: ',time.time()-t)
#print(f'[INFO] outputs shape ',outputs['image'].shape)
image = (outputs['image'] * 255).astype(np.uint8)
self.streamer.stream_frame(image)
@@ -169,6 +171,8 @@ class NeRFReal:
def render(self):
if self.opt.asr:
self.asr.warm_up()
count=0
totaltime=0
while True: #todo
# update texture every frame
# audio stream thread...
@@ -178,6 +182,12 @@ class NeRFReal:
for _ in range(2):
self.asr.run_step()
self.test_step()
totaltime += (time.time() - t)
count += 1
if count==100:
print(f"------actual avg fps:{count/totaltime:.4f}")
count=0
totaltime=0
# delay = 0.04 - (time.time() - t) #40ms
# if delay > 0:
# time.sleep(delay)