del nouse code

lipku 2024-06-01 06:58:02 +08:00
parent af1ad0aed8
commit 4e355e9ab9
3 changed files with 132 additions and 215 deletions

View File

@@ -14,27 +14,6 @@ from queue import Queue
from threading import Thread, Event
from io import BytesIO
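# _read_frame (capture thread): read raw int16 PCM chunks from a PyAudio-style input stream, scale to float32 in [-1, 1], and push them onto a queue.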
def _read_frame(stream, exit_event, queue, chunk):
    while True:
        if exit_event.is_set():
            print(f'[INFO] read frame thread ends')
            break
        frame = stream.read(chunk, exception_on_overflow=False)
        frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
        queue.put(frame)
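# _play_frame (playback thread): pop float32 frames off a queue, convert them back to int16 bytes, and write them to the output stream.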
def _play_frame(stream, exit_event, queue, chunk):
    while True:
        if exit_event.is_set():
            print(f'[INFO] play frame thread ends')
            break
        frame = queue.get()
        frame = (frame * 32767).astype(np.int16).tobytes()
        stream.write(frame, chunk)
class ASR:
    def __init__(self, opt):
@@ -76,9 +55,6 @@ class ASR:
        #self.audio_instance = pyaudio.PyAudio() #not need
        # create input stream
        if self.mode == 'file': #live mode
            self.file_stream = self.create_file_stream()
        else:
            self.queue = Queue()
            self.input_stream = BytesIO()
            self.output_queue = Queue()
@@ -87,12 +63,6 @@ class ASR:
        #self.queue = Queue()
        #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
        # play out the audio too...?
        if self.play:
            self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
            self.output_queue = Queue()
            self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
        # current location of audio
        self.idx = 0
@@ -212,40 +182,32 @@ class ASR:
        # self.text = self.text + ' ' + text
        # will only run once at termination
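        # (disabled below) save the accumulated audio features for training: stack all_feats, unfold into sliding windows of 16 frames with stride 2, and dump them to a .npy next to the source wav.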
        # if self.terminated:
        #     self.text += '\n[END]'
        #     print(self.text)
        #     if self.opt.asr_save_feats:
        #         print(f'[INFO] save all feats for training purpose... ')
        #         feats = torch.cat(self.all_feats, dim=0) # [N, C]
        #         # print('[INFO] before unfold', feats.shape)
        #         window_size = 16
        #         padding = window_size // 2
        #         feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
        #         feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
        #         unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
        #         unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
        #         # print('[INFO] after unfold', unfold_feats.shape)
        #         # save to a npy file
        #         if 'esperanto' in self.opt.asr_model:
        #             output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
        #         else:
        #             output_path = self.opt.asr_wav.replace('.wav', '.npy')
        #         np.save(output_path, unfold_feats.cpu().numpy())
        #         print(f"[INFO] saved logits to {output_path}")
    def __get_audio_frame(self):
        if self.inwarm: # warm up
            return np.zeros(self.chunk, dtype=np.float32),1
        if self.mode == 'file':
            if self.idx < self.file_stream.shape[0]:
                frame = self.file_stream[self.idx: self.idx + self.chunk]
                self.idx = self.idx + self.chunk
                return frame,0
            else:
                return None,0
        else:
            try:
                frame = self.queue.get(block=False)
                type = 0
@@ -361,10 +323,6 @@ class ASR:
        # attention window...
        self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
    def before_push_audio(self):
        self.__init_queue()
        self.warm_up()

    def run(self):
        self.listen()
@@ -399,53 +357,6 @@ class ASR:
        #self.clear_queue()
'''
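    # create_file_stream: load a .wav with soundfile, keep only the first channel if multi-channel, and resample to self.sample_rate with resampy.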
    def create_file_stream(self):
        stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
        stream = stream.astype(np.float32)
        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]
        if sample_rate != self.sample_rate:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
        return stream
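    # create_pyaudio_stream: pick the first host-API device that reports input channels and open a mono int16 PyAudio input stream at self.sample_rate.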
    def create_pyaudio_stream(self):
        import pyaudio
        print(f'[INFO] creating live audio stream ...')
        audio = pyaudio.PyAudio()
        # get devices
        info = audio.get_host_api_info_by_index(0)
        n_devices = info.get('deviceCount')
        for i in range(0, n_devices):
            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
                print(f'[INFO] choose audio device {name}, id {i}')
                break
        # get stream
        stream = audio.open(input_device_index=i,
                            format=pyaudio.paInt16,
                            channels=1,
                            rate=self.sample_rate,
                            input=True,
                            frames_per_buffer=self.chunk)
        return audio, stream
    '''
#####not used function#####################################
    def listen(self):
        # start
@@ -489,6 +400,25 @@ class ASR:
        # live mode: also print the result text.
        self.text += '\n[END]'
        print(self.text)
def _read_frame(stream, exit_event, queue, chunk):
    while True:
        if exit_event.is_set():
            print(f'[INFO] read frame thread ends')
            break
        frame = stream.read(chunk, exception_on_overflow=False)
        frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
        queue.put(frame)

def _play_frame(stream, exit_event, queue, chunk):
    while True:
        if exit_event.is_set():
            print(f'[INFO] play frame thread ends')
            break
        frame = queue.get()
        frame = (frame * 32767).astype(np.int16).tobytes()
        stream.write(frame, chunk)
#########################################################
if __name__ == '__main__':

View File

@@ -26,36 +26,36 @@ class NeRFReal:
        self.data_loader = data_loader
        # use dataloader's bg
        #bg_img = data_loader._data.bg_img #.view(1, -1, 3)
        #if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
        # bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
        #self.bg_color = bg_img.view(1, -1, 3)
        # audio features (from dataloader, only used in non-playing mode)
        #self.audio_features = data_loader._data.auds # [N, 29, 16]
        #self.audio_idx = 0
        #self.frame_total_num = data_loader._data.end_index
        #print("frame_total_num:",self.frame_total_num)
        # control eye
        #self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
        # playing seq from dataloader, or pause.
        self.playing = True #False todo
        self.loader = iter(data_loader)
        #self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        #self.need_update = True # camera moved, should reset accumulation
        #self.spp = 1 # sample per pixel
        #self.mode = 'image' # choose from ['image', 'depth']
        #self.dynamic_resolution = False # assert False!
        #self.downscale = 1
        #self.train_steps = 16
        #self.ind_index = 0
        #self.ind_num = trainer.model.individual_codes.shape[0]
        self.customimg_index = 0
@@ -113,8 +113,6 @@ class NeRFReal:
    def push_audio(self,chunk):
        self.asr.push_audio(chunk)
    def before_push_audio(self):
        self.asr.before_push_audio()

    def mirror_index(self, index):
        size = self.opt.customvideo_imgnum
@@ -125,18 +123,11 @@ class NeRFReal:
        else:
            return size - res - 1
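    # prepare_buffer: return the rendered RGB image, or tile the depth map to three channels when self.mode is not 'image'.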
    def prepare_buffer(self, outputs):
        if self.mode == 'image':
            return outputs['image']
        else:
            return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
    def test_step(self,loop=None,audio_track=None,video_track=None):
        #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        #starter.record()
        if self.playing:
            try:
                data = next(self.loader)
            except StopIteration:
@@ -203,12 +194,6 @@ class NeRFReal:
            new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
            #self.pipe.stdin.write(image.tostring())
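        # non-playing branch below: look up the dataloader audio features at self.audio_idx and render one frame through trainer.test_gui.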
        else:
            if self.audio_features is not None:
                auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
            else:
                auds = None
            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
        #ender.record()
        #torch.cuda.synchronize()

View File

@@ -60,6 +60,7 @@ class PlayerStreamTrack(MediaStreamTrack):
            else:
                self._start = time.time()
                self._timestamp = 0
                print('video start:',self._start)
            return self._timestamp, VIDEO_TIME_BASE
        else: #audio
            if hasattr(self, "_timestamp"):
@@ -71,6 +72,7 @@ class PlayerStreamTrack(MediaStreamTrack):
            else:
                self._start = time.time()
                self._timestamp = 0
                print('audio start:',self._start)
            return self._timestamp, AUDIO_TIME_BASE
    async def recv(self) -> Union[Frame, Packet]: