diff --git a/README.md b/README.md index 783fc06..69cb1f6 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,12 @@ A streaming digital human based on the Ernerf model, realize audio video synch [![Watch the video]](/assets/demo.mp4) +## Features +1. 支持声音克隆 +2. 支持大模型对话 +3. 支持多种音频特征驱动:wav2vec、hubert +4. 支持全身视频拼接 + ## 1. Installation Tested on Ubuntu 20.04, Python3.10, Pytorch 1.12 and CUDA 11.3 @@ -53,7 +59,7 @@ nginx 用浏览器打开http://serverip/echo.html, 在文本框输入任意文字,提交。数字人播报该段文字 -## 3. 更多使用 +## 3. More Usage ### 3.1 使用LLM模型进行数字人对话 目前借鉴数字人对话系统[LinlyTalker](https://github.com/Kedreamix/Linly-Talker)的方式,LLM模型支持Chatgpt,Qwen和GeminiPro。需要在app.py中填入自己的api_key。 @@ -97,7 +103,8 @@ ffmpeg -i fullbody.mp4 -vf fps=25 -qmin 1 -q:v 1 -start_number 0 data/fullbody/i python app.py --fullbody --fullbody_img data/fullbody/img --fullbody_offset_x 100 --fullbody_offset_y 5 --fullbody_width 580 --fullbody_height 1080 --W 400 --H 400 ``` - --fullbody_width、--fullbody_height 全身视频的宽、高 -- --W、--H 训练视频的宽、高 +- --W、--H 训练视频的宽、高 +- ernerf训练第三步torso如果训练的不好,在拼接处会有接缝。可以在上面的命令加上--torso_imgs data/xxx/torso_imgs,torso不用模型推理,直接用训练数据集里的torso图片。这种方式可能头颈处会有些人工痕迹。 ## 4. Docker Run 不需要第1步的安装,直接运行。 @@ -126,9 +133,9 @@ srs和nginx的运行同2.1和2.3 在Tesla T4显卡上测试整体fps为18左右,如果去掉音视频编码推流,帧率在20左右。用4090显卡可以达到40多帧/秒。 优化:新开一个线程运行音视频编码推流 2. 延时 -整体延时5s多 -(1)tts延时2s左右,目前用的edgetts,需要将每句话转完后一次性输入,可以优化tts改成流式输入 -(2)wav2vec延时1s多,需要缓存50帧音频做计算,可以通过-m设置context_size来减少延时 +整体延时3s左右 +(1)tts延时1.7s左右,目前用的edgetts,需要将每句话转完后一次性输入,可以优化tts改成流式输入 +(2)wav2vec延时0.4s,需要缓存18帧音频做计算 (3)srs转发延时,设置srs服务器减少缓冲延时。具体配置可看 https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency, 配置了一个低延时版本 ```python docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/lipku/srs:v1.1 diff --git a/app.py b/app.py index a811bbc..c039db1 100644 --- a/app.py +++ b/app.py @@ -37,7 +37,11 @@ async def main(voicename: str, text: str, render): communicate = edge_tts.Communicate(text, voicename) #with open(OUTPUT_FILE, "wb") as file: + first = True async for chunk in communicate.stream(): + if first: + #render.before_push_audio() + first = False if chunk["type"] == "audio": render.push_audio(chunk["data"]) #file.write(chunk["data"]) @@ -160,6 +164,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source") parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area") + parser.add_argument('--torso_imgs', type=str, default="", help="torso images path") parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye") @@ -259,7 +264,7 @@ if __name__ == '__main__': parser.add_argument('--fps', type=int, default=50) # sliding window left-middle-right length (unit: 20ms) parser.add_argument('-l', type=int, default=10) - parser.add_argument('-m', type=int, default=50) + parser.add_argument('-m', type=int, default=8) parser.add_argument('-r', type=int, default=10) parser.add_argument('--fullbody', action='store_true', help="fullbody human") @@ -298,7 +303,8 @@ if __name__ == '__main__': opt.exp_eye = True opt.smooth_eye = True - opt.torso = True + if opt.torso_imgs=='': #no img,use model output + opt.torso = True # assert opt.cuda_ray, "Only support CUDA ray mode." opt.asr = True @@ -307,6 +313,7 @@ if __name__ == '__main__': # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." seed_everything(opt.seed) + print(opt) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = NeRFNetwork(opt) diff --git a/asrreal.py b/asrreal.py index 595eda0..5d8fea8 100644 --- a/asrreal.py +++ b/asrreal.py @@ -122,58 +122,34 @@ class ASR: self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding... # warm up steps needed: mid + right + window_size + attention_size - self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3 + self.warm_up_steps = self.context_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3 self.listening = False self.playing = False - def listen(self): - # start - if self.mode == 'live' and not self.listening: - print(f'[INFO] starting read frame thread...') - self.process_read_frame.start() - self.listening = True - - if self.play and not self.playing: - print(f'[INFO] starting play frame thread...') - self.process_play_frame.start() - self.playing = True - - def stop(self): - - self.exit_event.set() - - if self.play: - self.output_stream.stop_stream() - self.output_stream.close() - if self.playing: - self.process_play_frame.join() - self.playing = False - - if self.mode == 'live': - #self.input_stream.stop_stream() todo - self.input_stream.close() - if self.listening: - self.process_read_frame.join() - self.listening = False - - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - - self.stop() - - if self.mode == 'live': - # live mode: also print the result text. - self.text += '\n[END]' - print(self.text) - - def get_next_feat(self): + def get_next_feat(self): #get audio embedding to nerf # return a [1/8, 16] window, for the next input to nerf side. - - while len(self.att_feats) < 8: + if self.opt.att>0: + while len(self.att_feats) < 8: + # [------f+++t-----] + if self.front < self.tail: + feat = self.feat_queue[self.front:self.tail] + # [++t-----------f+] + else: + feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0) + + self.front = (self.front + 2) % self.feat_queue.shape[0] + self.tail = (self.tail + 2) % self.feat_queue.shape[0] + + # print(self.front, self.tail, feat.shape) + + self.att_feats.append(feat.permute(1, 0)) + + att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] + + # discard old + self.att_feats = self.att_feats[1:] + else: # [------f+++t-----] if self.front < self.tail: feat = self.feat_queue[self.front:self.tail] @@ -184,14 +160,8 @@ class ASR: self.front = (self.front + 2) % self.feat_queue.shape[0] self.tail = (self.tail + 2) % self.feat_queue.shape[0] - # print(self.front, self.tail, feat.shape) + att_feat = feat.permute(1, 0).unsqueeze(0) - self.att_feats.append(feat.permute(1, 0)) - - att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] - - # discard old - self.att_feats = self.att_feats[1:] return att_feat @@ -201,7 +171,7 @@ class ASR: return # get a frame of audio - frame = self.get_audio_frame() + frame = self.__get_audio_frame() # the last frame if frame is None: @@ -223,7 +193,7 @@ class ASR: print(f'[INFO] frame_to_text... ') #t = time.time() - logits, labels, text = self.frame_to_text(inputs) + logits, labels, text = self.__frame_to_text(inputs) #print(f'-------wav2vec time:{time.time()-t:.4f}s') feats = logits # better lips-sync than labels @@ -264,6 +234,166 @@ class ASR: np.save(output_path, unfold_feats.cpu().numpy()) print(f"[INFO] saved logits to {output_path}") + def __get_audio_frame(self): + if self.inwarm: # warm up + return np.zeros(self.chunk, dtype=np.float32) + + if self.mode == 'file': + if self.idx < self.file_stream.shape[0]: + frame = self.file_stream[self.idx: self.idx + self.chunk] + self.idx = self.idx + self.chunk + return frame + else: + return None + else: + try: + frame = self.queue.get(block=False) + print(f'[INFO] get frame {frame.shape}') + except queue.Empty: + frame = np.zeros(self.chunk, dtype=np.float32) + + self.idx = self.idx + self.chunk + + return frame + + + def __frame_to_text(self, frame): + # frame: [N * 320], N = (context_size + 2 * stride_size) + + inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) + + with torch.no_grad(): + result = self.model(inputs.input_values.to(self.device)) + if 'hubert' in self.opt.asr_model: + logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024] + else: + logits = result.logits # [1, N - 1, 32] + #print('logits.shape:',logits.shape) + + # cut off stride + left = max(0, self.stride_left_size) + right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. + + # do not cut right if terminated. + if self.terminated: + right = logits.shape[1] + + logits = logits[:, left:right] + + # print(frame.shape, inputs.input_values.shape, logits.shape) + + #predicted_ids = torch.argmax(logits, dim=-1) + #transcription = self.processor.batch_decode(predicted_ids)[0].lower() + + + # for esperanto + # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) + + # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) + # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) + # print(predicted_ids[0]) + # print(transcription) + + return logits[0], None,None #predicted_ids[0], transcription # [N,] + + def __create_bytes_stream(self,byte_stream): + #byte_stream=BytesIO(buffer) + stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 + print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self.sample_rate and stream.shape[0]>0: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) + + return stream + + def push_audio(self,buffer): #push audio pcm from tts + print(f'[INFO] push_audio {len(buffer)}') + if self.opt.tts == "xtts": + if len(buffer)>0: + stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767 + stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) + #byte_stream=BytesIO(buffer) + #stream = self.__create_bytes_stream(byte_stream) + streamlen = stream.shape[0] + idx=0 + while streamlen >= self.chunk: + self.queue.put(stream[idx:idx+self.chunk]) + streamlen -= self.chunk + idx += self.chunk + # if streamlen>0: #skip last frame(not 20ms) + # self.queue.put(stream[idx:]) + else: #edge tts + self.input_stream.write(buffer) + if len(buffer)<=0: + self.input_stream.seek(0) + stream = self.__create_bytes_stream(self.input_stream) + streamlen = stream.shape[0] + idx=0 + while streamlen >= self.chunk: + self.queue.put(stream[idx:idx+self.chunk]) + streamlen -= self.chunk + idx += self.chunk + #if streamlen>0: #skip last frame(not 20ms) + # self.queue.put(stream[idx:]) + self.input_stream.seek(0) + self.input_stream.truncate() + + def get_audio_out(self): #get origin audio pcm to nerf + return self.output_queue.get() + + def __init_queue(self): + self.frames = [] + self.queue.queue.clear() + self.output_queue.queue.clear() + self.front = self.feat_buffer_size * self.context_size - 8 # fake padding + self.tail = 8 + # attention window... + self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 + + def before_push_audio(self): + self.__init_queue() + self.warm_up() + + def run(self): + + self.listen() + + while not self.terminated: + self.run_step() + + def clear_queue(self): + # clear the queue, to reduce potential latency... + print(f'[INFO] clear queue') + if self.mode == 'live': + self.queue.queue.clear() + if self.play: + self.output_queue.queue.clear() + + def warm_up(self): + + #self.listen() + + self.inwarm = True + print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') + t = time.time() + for _ in range(self.stride_left_size): + self.frames.append(np.zeros(self.chunk, dtype=np.float32)) + for _ in range(self.warm_up_steps): + self.run_step() + #if torch.cuda.is_available(): + # torch.cuda.synchronize() + t = time.time() - t + print(f'[INFO] warm-up done, actual latency = {t:.6f}s') + self.inwarm = False + + #self.clear_queue() + ''' def create_file_stream(self): @@ -311,157 +441,50 @@ class ASR: return audio, stream ''' - - def get_audio_frame(self): - - if self.inwarm: # warm up - return np.zeros(self.chunk, dtype=np.float32) + #####not used function##################################### + def listen(self): + # start + if self.mode == 'live' and not self.listening: + print(f'[INFO] starting read frame thread...') + self.process_read_frame.start() + self.listening = True - if self.mode == 'file': + if self.play and not self.playing: + print(f'[INFO] starting play frame thread...') + self.process_play_frame.start() + self.playing = True - if self.idx < self.file_stream.shape[0]: - frame = self.file_stream[self.idx: self.idx + self.chunk] - self.idx = self.idx + self.chunk - return frame - else: - return None - - else: - try: - frame = self.queue.get(block=False) - print(f'[INFO] get frame {frame.shape}') - except queue.Empty: - frame = np.zeros(self.chunk, dtype=np.float32) + def stop(self): - self.idx = self.idx + self.chunk + self.exit_event.set() - return frame - - - def frame_to_text(self, frame): - # frame: [N * 320], N = (context_size + 2 * stride_size) - - inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) - - with torch.no_grad(): - result = self.model(inputs.input_values.to(self.device)) - if 'hubert' in self.opt.asr_model: - logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024] - else: - logits = result.logits # [1, N - 1, 32] - #print('logits.shape:',logits.shape) - - # cut off stride - left = max(0, self.stride_left_size) - right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. - - # do not cut right if terminated. - if self.terminated: - right = logits.shape[1] - - logits = logits[:, left:right] - - # print(frame.shape, inputs.input_values.shape, logits.shape) - - #predicted_ids = torch.argmax(logits, dim=-1) - #transcription = self.processor.batch_decode(predicted_ids)[0].lower() - - - # for esperanto - # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) - - # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) - # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) - # print(predicted_ids[0]) - # print(transcription) - - return logits[0], None,None #predicted_ids[0], transcription # [N,] - - def create_bytes_stream(self,byte_stream): - #byte_stream=BytesIO(buffer) - stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 - print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') - stream = stream.astype(np.float32) - - if stream.ndim > 1: - print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') - stream = stream[:, 0] - - if sample_rate != self.sample_rate and stream.shape[0]>0: - print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') - stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) - - return stream - - def push_audio(self,buffer): - print(f'[INFO] push_audio {len(buffer)}') - if self.opt.tts == "xtts": - if len(buffer)>0: - stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767 - stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) - #byte_stream=BytesIO(buffer) - #stream = self.create_bytes_stream(byte_stream) - streamlen = stream.shape[0] - idx=0 - while streamlen >= self.chunk: - self.queue.put(stream[idx:idx+self.chunk]) - streamlen -= self.chunk - idx += self.chunk - # if streamlen>0: #skip last frame(not 20ms) - # self.queue.put(stream[idx:]) - else: #edge tts - self.input_stream.write(buffer) - if len(buffer)<=0: - self.input_stream.seek(0) - stream = self.create_bytes_stream(self.input_stream) - streamlen = stream.shape[0] - idx=0 - while streamlen >= self.chunk: - self.queue.put(stream[idx:idx+self.chunk]) - streamlen -= self.chunk - idx += self.chunk - #if streamlen>0: #skip last frame(not 20ms) - # self.queue.put(stream[idx:]) - self.input_stream.seek(0) - self.input_stream.truncate() - - def get_audio_out(self): - return self.output_queue.get() - - def run(self): - - self.listen() - - while not self.terminated: - self.run_step() - - def clear_queue(self): - # clear the queue, to reduce potential latency... - print(f'[INFO] clear queue') - if self.mode == 'live': - self.queue.queue.clear() if self.play: - self.output_queue.queue.clear() + self.output_stream.stop_stream() + self.output_stream.close() + if self.playing: + self.process_play_frame.join() + self.playing = False - def warm_up(self): + if self.mode == 'live': + #self.input_stream.stop_stream() todo + self.input_stream.close() + if self.listening: + self.process_read_frame.join() + self.listening = False - #self.listen() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): - self.inwarm = True - print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') - t = time.time() - for _ in range(self.warm_up_steps): - self.run_step() - if torch.cuda.is_available(): - torch.cuda.synchronize() - t = time.time() - t - print(f'[INFO] warm-up done, actual latency = {t:.6f}s') - self.inwarm = False - - #self.clear_queue() - - + self.stop() + if self.mode == 'live': + # live mode: also print the result text. + self.text += '\n[END]' + print(self.text) + ######################################################### if __name__ == '__main__': import argparse diff --git a/main.py b/main.py index 6e432ea..cc8a0b5 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ import torch import argparse -from nerf_triplane.provider import NeRFDataset +from nerf_triplane.provider import NeRFDataset,NeRFDataset_Test from nerf_triplane.utils import * from nerf_triplane.network import NeRFNetwork @@ -24,6 +24,9 @@ if __name__ == '__main__': parser.add_argument('--workspace', type=str, default='workspace') parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source") + parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area") + ### training options parser.add_argument('--iters', type=int, default=200000, help="training iters") parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate") @@ -47,7 +50,7 @@ if __name__ == '__main__': ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") - parser.add_argument('--bg_img', type=str, default='', help="background image") + parser.add_argument('--bg_img', type=str, default='white', help="background image") parser.add_argument('--fbg', action='store_true', help="frame-wise bg") parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") @@ -182,7 +185,7 @@ if __name__ == '__main__': trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) if opt.test_train: - test_set = NeRFDataset(opt, device=device, type='train') + test_set = NeRFDataset(opt, device=device, type='train') # a manual fix to test on the training dataset test_set.training = False test_set.num_rays = -1 diff --git a/nerf_triplane/provider.py b/nerf_triplane/provider.py index 727836c..4f0a704 100644 --- a/nerf_triplane/provider.py +++ b/nerf_triplane/provider.py @@ -98,6 +98,7 @@ class NeRFDataset_Test: self.training = False self.num_rays = -1 + self.preload = opt.preload # 0 = disk, 1 = cpu, 2 = gpu # load nerf-compatible format data. @@ -148,6 +149,7 @@ class NeRFDataset_Test: self.poses = [] self.auds = [] self.eye_area = [] + self.torso_img = [] for f in tqdm.tqdm(frames, desc=f'Loading data'): @@ -172,6 +174,29 @@ class NeRFDataset_Test: # area = area + np.random.rand() / 10 self.eye_area.append(area) + + # load frame-wise bg + + if self.opt.torso_imgs!='': + torso_img_path = os.path.join(self.opt.torso_imgs, str(f['img_id']) + '.png') + + if self.preload > 0: + torso_img = cv2.imread(torso_img_path, cv2.IMREAD_UNCHANGED) # [H, W, 4] + torso_img = cv2.cvtColor(torso_img, cv2.COLOR_BGRA2RGBA) + torso_img = torso_img.astype(np.float32) / 255 # [H, W, 3/4] + + self.torso_img.append(torso_img) + else: + self.torso_img.append(torso_img_path) + + if self.opt.torso_imgs!='': + if self.preload > 0: + self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0)) # [N, H, W, C] + else: + self.torso_img = np.array(self.torso_img) + if self.preload > 1: #gpu + self.torso_img = self.torso_img.to(torch.half).to(self.device) + # load pre-extracted background image (should be the same size as training image...) @@ -209,6 +234,9 @@ class NeRFDataset_Test: self.bg_img = torch.from_numpy(self.bg_img) + if self.preload > 1 or self.opt.torso_imgs=='': #gpu + self.bg_img = self.bg_img.to(torch.half).to(self.device) + if self.opt.exp_eye: self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N] print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}') @@ -229,8 +257,6 @@ class NeRFDataset_Test: if self.auds is not None: self.auds = self.auds.to(self.device) - - self.bg_img = self.bg_img.to(torch.half).to(self.device) if self.opt.exp_eye: self.eye_area = self.eye_area.to(self.device) @@ -285,8 +311,23 @@ class NeRFDataset_Test: results['eye'] = self.eye_area[index].to(self.device) # [1] else: results['eye'] = None - - bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device) + + # load bg + if self.opt.torso_imgs!='': + bg_torso_img = self.torso_img[index] + if self.preload == 0: # on the fly loading + bg_torso_img = cv2.imread(bg_torso_img[0], cv2.IMREAD_UNCHANGED) # [H, W, 4] + bg_torso_img = cv2.cvtColor(bg_torso_img, cv2.COLOR_BGRA2RGBA) + bg_torso_img = bg_torso_img.astype(np.float32) / 255 # [H, W, 3/4] + bg_torso_img = torch.from_numpy(bg_torso_img).unsqueeze(0) + bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:]) + bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device) + if not self.opt.torso: + bg_img = bg_torso_img + else: + bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device) + else: + bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device) results['bg_color'] = bg_img @@ -341,8 +382,30 @@ class NeRFDataset: # load nerf-compatible format data. - with open(opt.pose, 'r') as f: - transform = json.load(f) + # load all splits (train/valid/test) + if type == 'all': + transform_paths = glob.glob(os.path.join(self.root_path, '*.json')) + transform = None + for transform_path in transform_paths: + with open(transform_path, 'r') as f: + tmp_transform = json.load(f) + if transform is None: + transform = tmp_transform + else: + transform['frames'].extend(tmp_transform['frames']) + # load train and val split + elif type == 'trainval': + with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f: + transform = json.load(f) + with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f: + transform_val = json.load(f) + transform['frames'].extend(transform_val['frames']) + # only load one specified split + else: + # no test, use val as test + _split = 'val' if type == 'test' else type + with open(os.path.join(self.root_path, f'transforms_{_split}.json'), 'r') as f: + transform = json.load(f) # load image size if 'h' in transform and 'w' in transform: @@ -371,6 +434,10 @@ class NeRFDataset: aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy')) elif 'deepspeech' in self.opt.asr_model: aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy')) + # elif 'hubert_cn' in self.opt.asr_model: + # aud_features = np.load(os.path.join(self.root_path, 'aud_hu_cn.npy')) + elif 'hubert' in self.opt.asr_model: + aud_features = np.load(os.path.join(self.root_path, 'aud_hu.npy')) else: aud_features = np.load(os.path.join(self.root_path, 'aud.npy')) # cross-driven extracted features. diff --git a/nerfreal.py b/nerfreal.py index 08c3d9b..426f08f 100644 --- a/nerfreal.py +++ b/nerfreal.py @@ -34,9 +34,8 @@ class NeRFReal: self.audio_features = data_loader._data.auds # [N, 29, 16] self.audio_idx = 0 - self.frame_total_num = data_loader._data.end_index - print("frame_total_num:",self.frame_total_num) - self.frame_index=0 + #self.frame_total_num = data_loader._data.end_index + #print("frame_total_num:",self.frame_total_num) # control eye self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item() @@ -109,6 +108,9 @@ class NeRFReal: def push_audio(self,chunk): self.asr.push_audio(chunk) + + def before_push_audio(self): + self.asr.before_push_audio() def prepare_buffer(self, outputs): if self.mode == 'image': @@ -140,7 +142,8 @@ class NeRFReal: if not self.opt.fullbody: self.streamer.stream_frame(image) else: #fullbody human - image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(self.frame_index%self.frame_total_num)+'.jpg')) + #print("frame index:",data['index']) + image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg')) image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB) start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标 start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标 @@ -201,7 +204,6 @@ class NeRFReal: for _ in range(2): self.asr.run_step() self.test_step() - self.frame_index = (self.frame_index+1)%self.frame_total_num totaltime += (time.time() - t) count += 1 if count==100: