Merge branch 'lipku:main' into main
This commit is contained in:
commit
313d57dfa4
15
README.md
15
README.md
|
@ -3,6 +3,12 @@ A streaming digital human based on the Ernerf model, realize audio video synch
|
||||||
|
|
||||||
[![Watch the video]](/assets/demo.mp4)
|
[![Watch the video]](/assets/demo.mp4)
|
||||||
|
|
||||||
|
## Features
|
||||||
|
1. 支持声音克隆
|
||||||
|
2. 支持大模型对话
|
||||||
|
3. 支持多种音频特征驱动:wav2vec、hubert
|
||||||
|
4. 支持全身视频拼接
|
||||||
|
|
||||||
## 1. Installation
|
## 1. Installation
|
||||||
|
|
||||||
Tested on Ubuntu 20.04, Python3.10, Pytorch 1.12 and CUDA 11.3
|
Tested on Ubuntu 20.04, Python3.10, Pytorch 1.12 and CUDA 11.3
|
||||||
|
@ -53,7 +59,7 @@ nginx
|
||||||
|
|
||||||
用浏览器打开http://serverip/echo.html, 在文本框输入任意文字,提交。数字人播报该段文字
|
用浏览器打开http://serverip/echo.html, 在文本框输入任意文字,提交。数字人播报该段文字
|
||||||
|
|
||||||
## 3. 更多使用
|
## 3. More Usage
|
||||||
### 3.1 使用LLM模型进行数字人对话
|
### 3.1 使用LLM模型进行数字人对话
|
||||||
|
|
||||||
目前借鉴数字人对话系统[LinlyTalker](https://github.com/Kedreamix/Linly-Talker)的方式,LLM模型支持Chatgpt,Qwen和GeminiPro。需要在app.py中填入自己的api_key。
|
目前借鉴数字人对话系统[LinlyTalker](https://github.com/Kedreamix/Linly-Talker)的方式,LLM模型支持Chatgpt,Qwen和GeminiPro。需要在app.py中填入自己的api_key。
|
||||||
|
@ -98,6 +104,7 @@ python app.py --fullbody --fullbody_img data/fullbody/img --fullbody_offset_x 10
|
||||||
```
|
```
|
||||||
- --fullbody_width、--fullbody_height 全身视频的宽、高
|
- --fullbody_width、--fullbody_height 全身视频的宽、高
|
||||||
- --W、--H 训练视频的宽、高
|
- --W、--H 训练视频的宽、高
|
||||||
|
- ernerf训练第三步torso如果训练的不好,在拼接处会有接缝。可以在上面的命令加上--torso_imgs data/xxx/torso_imgs,torso不用模型推理,直接用训练数据集里的torso图片。这种方式可能头颈处会有些人工痕迹。
|
||||||
|
|
||||||
## 4. Docker Run
|
## 4. Docker Run
|
||||||
不需要第1步的安装,直接运行。
|
不需要第1步的安装,直接运行。
|
||||||
|
@ -126,9 +133,9 @@ srs和nginx的运行同2.1和2.3
|
||||||
在Tesla T4显卡上测试整体fps为18左右,如果去掉音视频编码推流,帧率在20左右。用4090显卡可以达到40多帧/秒。
|
在Tesla T4显卡上测试整体fps为18左右,如果去掉音视频编码推流,帧率在20左右。用4090显卡可以达到40多帧/秒。
|
||||||
优化:新开一个线程运行音视频编码推流
|
优化:新开一个线程运行音视频编码推流
|
||||||
2. 延时
|
2. 延时
|
||||||
整体延时5s多
|
整体延时3s左右
|
||||||
(1)tts延时2s左右,目前用的edgetts,需要将每句话转完后一次性输入,可以优化tts改成流式输入
|
(1)tts延时1.7s左右,目前用的edgetts,需要将每句话转完后一次性输入,可以优化tts改成流式输入
|
||||||
(2)wav2vec延时1s多,需要缓存50帧音频做计算,可以通过-m设置context_size来减少延时
|
(2)wav2vec延时0.4s,需要缓存18帧音频做计算
|
||||||
(3)srs转发延时,设置srs服务器减少缓冲延时。具体配置可看 https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency, 配置了一个低延时版本
|
(3)srs转发延时,设置srs服务器减少缓冲延时。具体配置可看 https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency, 配置了一个低延时版本
|
||||||
```python
|
```python
|
||||||
docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/lipku/srs:v1.1
|
docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/lipku/srs:v1.1
|
||||||
|
|
9
app.py
9
app.py
|
@ -37,7 +37,11 @@ async def main(voicename: str, text: str, render):
|
||||||
communicate = edge_tts.Communicate(text, voicename)
|
communicate = edge_tts.Communicate(text, voicename)
|
||||||
|
|
||||||
#with open(OUTPUT_FILE, "wb") as file:
|
#with open(OUTPUT_FILE, "wb") as file:
|
||||||
|
first = True
|
||||||
async for chunk in communicate.stream():
|
async for chunk in communicate.stream():
|
||||||
|
if first:
|
||||||
|
#render.before_push_audio()
|
||||||
|
first = False
|
||||||
if chunk["type"] == "audio":
|
if chunk["type"] == "audio":
|
||||||
render.push_audio(chunk["data"])
|
render.push_audio(chunk["data"])
|
||||||
#file.write(chunk["data"])
|
#file.write(chunk["data"])
|
||||||
|
@ -160,6 +164,7 @@ if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
|
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
|
||||||
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
|
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
|
||||||
|
parser.add_argument('--torso_imgs', type=str, default="", help="torso images path")
|
||||||
|
|
||||||
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
|
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
|
||||||
|
|
||||||
|
@ -259,7 +264,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--fps', type=int, default=50)
|
parser.add_argument('--fps', type=int, default=50)
|
||||||
# sliding window left-middle-right length (unit: 20ms)
|
# sliding window left-middle-right length (unit: 20ms)
|
||||||
parser.add_argument('-l', type=int, default=10)
|
parser.add_argument('-l', type=int, default=10)
|
||||||
parser.add_argument('-m', type=int, default=50)
|
parser.add_argument('-m', type=int, default=8)
|
||||||
parser.add_argument('-r', type=int, default=10)
|
parser.add_argument('-r', type=int, default=10)
|
||||||
|
|
||||||
parser.add_argument('--fullbody', action='store_true', help="fullbody human")
|
parser.add_argument('--fullbody', action='store_true', help="fullbody human")
|
||||||
|
@ -298,6 +303,7 @@ if __name__ == '__main__':
|
||||||
opt.exp_eye = True
|
opt.exp_eye = True
|
||||||
opt.smooth_eye = True
|
opt.smooth_eye = True
|
||||||
|
|
||||||
|
if opt.torso_imgs=='': #no img,use model output
|
||||||
opt.torso = True
|
opt.torso = True
|
||||||
|
|
||||||
# assert opt.cuda_ray, "Only support CUDA ray mode."
|
# assert opt.cuda_ray, "Only support CUDA ray mode."
|
||||||
|
@ -307,6 +313,7 @@ if __name__ == '__main__':
|
||||||
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
|
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
|
||||||
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
|
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
|
||||||
seed_everything(opt.seed)
|
seed_everything(opt.seed)
|
||||||
|
print(opt)
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
model = NeRFNetwork(opt)
|
model = NeRFNetwork(opt)
|
||||||
|
|
403
asrreal.py
403
asrreal.py
|
@ -122,57 +122,14 @@ class ASR:
|
||||||
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
|
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
|
||||||
|
|
||||||
# warm up steps needed: mid + right + window_size + attention_size
|
# warm up steps needed: mid + right + window_size + attention_size
|
||||||
self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3
|
self.warm_up_steps = self.context_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
|
||||||
|
|
||||||
self.listening = False
|
self.listening = False
|
||||||
self.playing = False
|
self.playing = False
|
||||||
|
|
||||||
def listen(self):
|
def get_next_feat(self): #get audio embedding to nerf
|
||||||
# start
|
|
||||||
if self.mode == 'live' and not self.listening:
|
|
||||||
print(f'[INFO] starting read frame thread...')
|
|
||||||
self.process_read_frame.start()
|
|
||||||
self.listening = True
|
|
||||||
|
|
||||||
if self.play and not self.playing:
|
|
||||||
print(f'[INFO] starting play frame thread...')
|
|
||||||
self.process_play_frame.start()
|
|
||||||
self.playing = True
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
|
|
||||||
self.exit_event.set()
|
|
||||||
|
|
||||||
if self.play:
|
|
||||||
self.output_stream.stop_stream()
|
|
||||||
self.output_stream.close()
|
|
||||||
if self.playing:
|
|
||||||
self.process_play_frame.join()
|
|
||||||
self.playing = False
|
|
||||||
|
|
||||||
if self.mode == 'live':
|
|
||||||
#self.input_stream.stop_stream() todo
|
|
||||||
self.input_stream.close()
|
|
||||||
if self.listening:
|
|
||||||
self.process_read_frame.join()
|
|
||||||
self.listening = False
|
|
||||||
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
|
|
||||||
self.stop()
|
|
||||||
|
|
||||||
if self.mode == 'live':
|
|
||||||
# live mode: also print the result text.
|
|
||||||
self.text += '\n[END]'
|
|
||||||
print(self.text)
|
|
||||||
|
|
||||||
def get_next_feat(self):
|
|
||||||
# return a [1/8, 16] window, for the next input to nerf side.
|
# return a [1/8, 16] window, for the next input to nerf side.
|
||||||
|
if self.opt.att>0:
|
||||||
while len(self.att_feats) < 8:
|
while len(self.att_feats) < 8:
|
||||||
# [------f+++t-----]
|
# [------f+++t-----]
|
||||||
if self.front < self.tail:
|
if self.front < self.tail:
|
||||||
|
@ -192,6 +149,19 @@ class ASR:
|
||||||
|
|
||||||
# discard old
|
# discard old
|
||||||
self.att_feats = self.att_feats[1:]
|
self.att_feats = self.att_feats[1:]
|
||||||
|
else:
|
||||||
|
# [------f+++t-----]
|
||||||
|
if self.front < self.tail:
|
||||||
|
feat = self.feat_queue[self.front:self.tail]
|
||||||
|
# [++t-----------f+]
|
||||||
|
else:
|
||||||
|
feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
|
||||||
|
|
||||||
|
self.front = (self.front + 2) % self.feat_queue.shape[0]
|
||||||
|
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
|
||||||
|
|
||||||
|
att_feat = feat.permute(1, 0).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
return att_feat
|
return att_feat
|
||||||
|
|
||||||
|
@ -201,7 +171,7 @@ class ASR:
|
||||||
return
|
return
|
||||||
|
|
||||||
# get a frame of audio
|
# get a frame of audio
|
||||||
frame = self.get_audio_frame()
|
frame = self.__get_audio_frame()
|
||||||
|
|
||||||
# the last frame
|
# the last frame
|
||||||
if frame is None:
|
if frame is None:
|
||||||
|
@ -223,7 +193,7 @@ class ASR:
|
||||||
|
|
||||||
print(f'[INFO] frame_to_text... ')
|
print(f'[INFO] frame_to_text... ')
|
||||||
#t = time.time()
|
#t = time.time()
|
||||||
logits, labels, text = self.frame_to_text(inputs)
|
logits, labels, text = self.__frame_to_text(inputs)
|
||||||
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
|
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
|
||||||
feats = logits # better lips-sync than labels
|
feats = logits # better lips-sync than labels
|
||||||
|
|
||||||
|
@ -264,6 +234,166 @@ class ASR:
|
||||||
np.save(output_path, unfold_feats.cpu().numpy())
|
np.save(output_path, unfold_feats.cpu().numpy())
|
||||||
print(f"[INFO] saved logits to {output_path}")
|
print(f"[INFO] saved logits to {output_path}")
|
||||||
|
|
||||||
|
def __get_audio_frame(self):
|
||||||
|
if self.inwarm: # warm up
|
||||||
|
return np.zeros(self.chunk, dtype=np.float32)
|
||||||
|
|
||||||
|
if self.mode == 'file':
|
||||||
|
if self.idx < self.file_stream.shape[0]:
|
||||||
|
frame = self.file_stream[self.idx: self.idx + self.chunk]
|
||||||
|
self.idx = self.idx + self.chunk
|
||||||
|
return frame
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
frame = self.queue.get(block=False)
|
||||||
|
print(f'[INFO] get frame {frame.shape}')
|
||||||
|
except queue.Empty:
|
||||||
|
frame = np.zeros(self.chunk, dtype=np.float32)
|
||||||
|
|
||||||
|
self.idx = self.idx + self.chunk
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
def __frame_to_text(self, frame):
|
||||||
|
# frame: [N * 320], N = (context_size + 2 * stride_size)
|
||||||
|
|
||||||
|
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
result = self.model(inputs.input_values.to(self.device))
|
||||||
|
if 'hubert' in self.opt.asr_model:
|
||||||
|
logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024]
|
||||||
|
else:
|
||||||
|
logits = result.logits # [1, N - 1, 32]
|
||||||
|
#print('logits.shape:',logits.shape)
|
||||||
|
|
||||||
|
# cut off stride
|
||||||
|
left = max(0, self.stride_left_size)
|
||||||
|
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
|
||||||
|
|
||||||
|
# do not cut right if terminated.
|
||||||
|
if self.terminated:
|
||||||
|
right = logits.shape[1]
|
||||||
|
|
||||||
|
logits = logits[:, left:right]
|
||||||
|
|
||||||
|
# print(frame.shape, inputs.input_values.shape, logits.shape)
|
||||||
|
|
||||||
|
#predicted_ids = torch.argmax(logits, dim=-1)
|
||||||
|
#transcription = self.processor.batch_decode(predicted_ids)[0].lower()
|
||||||
|
|
||||||
|
|
||||||
|
# for esperanto
|
||||||
|
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
|
||||||
|
|
||||||
|
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
|
||||||
|
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
|
||||||
|
# print(predicted_ids[0])
|
||||||
|
# print(transcription)
|
||||||
|
|
||||||
|
return logits[0], None,None #predicted_ids[0], transcription # [N,]
|
||||||
|
|
||||||
|
def __create_bytes_stream(self,byte_stream):
|
||||||
|
#byte_stream=BytesIO(buffer)
|
||||||
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
|
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||||
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
|
if stream.ndim > 1:
|
||||||
|
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||||
|
stream = stream[:, 0]
|
||||||
|
|
||||||
|
if sample_rate != self.sample_rate and stream.shape[0]>0:
|
||||||
|
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def push_audio(self,buffer): #push audio pcm from tts
|
||||||
|
print(f'[INFO] push_audio {len(buffer)}')
|
||||||
|
if self.opt.tts == "xtts":
|
||||||
|
if len(buffer)>0:
|
||||||
|
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
|
||||||
|
#byte_stream=BytesIO(buffer)
|
||||||
|
#stream = self.__create_bytes_stream(byte_stream)
|
||||||
|
streamlen = stream.shape[0]
|
||||||
|
idx=0
|
||||||
|
while streamlen >= self.chunk:
|
||||||
|
self.queue.put(stream[idx:idx+self.chunk])
|
||||||
|
streamlen -= self.chunk
|
||||||
|
idx += self.chunk
|
||||||
|
# if streamlen>0: #skip last frame(not 20ms)
|
||||||
|
# self.queue.put(stream[idx:])
|
||||||
|
else: #edge tts
|
||||||
|
self.input_stream.write(buffer)
|
||||||
|
if len(buffer)<=0:
|
||||||
|
self.input_stream.seek(0)
|
||||||
|
stream = self.__create_bytes_stream(self.input_stream)
|
||||||
|
streamlen = stream.shape[0]
|
||||||
|
idx=0
|
||||||
|
while streamlen >= self.chunk:
|
||||||
|
self.queue.put(stream[idx:idx+self.chunk])
|
||||||
|
streamlen -= self.chunk
|
||||||
|
idx += self.chunk
|
||||||
|
#if streamlen>0: #skip last frame(not 20ms)
|
||||||
|
# self.queue.put(stream[idx:])
|
||||||
|
self.input_stream.seek(0)
|
||||||
|
self.input_stream.truncate()
|
||||||
|
|
||||||
|
def get_audio_out(self): #get origin audio pcm to nerf
|
||||||
|
return self.output_queue.get()
|
||||||
|
|
||||||
|
def __init_queue(self):
|
||||||
|
self.frames = []
|
||||||
|
self.queue.queue.clear()
|
||||||
|
self.output_queue.queue.clear()
|
||||||
|
self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
|
||||||
|
self.tail = 8
|
||||||
|
# attention window...
|
||||||
|
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
|
||||||
|
|
||||||
|
def before_push_audio(self):
|
||||||
|
self.__init_queue()
|
||||||
|
self.warm_up()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
|
||||||
|
self.listen()
|
||||||
|
|
||||||
|
while not self.terminated:
|
||||||
|
self.run_step()
|
||||||
|
|
||||||
|
def clear_queue(self):
|
||||||
|
# clear the queue, to reduce potential latency...
|
||||||
|
print(f'[INFO] clear queue')
|
||||||
|
if self.mode == 'live':
|
||||||
|
self.queue.queue.clear()
|
||||||
|
if self.play:
|
||||||
|
self.output_queue.queue.clear()
|
||||||
|
|
||||||
|
def warm_up(self):
|
||||||
|
|
||||||
|
#self.listen()
|
||||||
|
|
||||||
|
self.inwarm = True
|
||||||
|
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
|
||||||
|
t = time.time()
|
||||||
|
for _ in range(self.stride_left_size):
|
||||||
|
self.frames.append(np.zeros(self.chunk, dtype=np.float32))
|
||||||
|
for _ in range(self.warm_up_steps):
|
||||||
|
self.run_step()
|
||||||
|
#if torch.cuda.is_available():
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
t = time.time() - t
|
||||||
|
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
|
||||||
|
self.inwarm = False
|
||||||
|
|
||||||
|
#self.clear_queue()
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def create_file_stream(self):
|
def create_file_stream(self):
|
||||||
|
|
||||||
|
@ -311,157 +441,50 @@ class ASR:
|
||||||
|
|
||||||
return audio, stream
|
return audio, stream
|
||||||
'''
|
'''
|
||||||
|
##### unused functions ####################################
|
||||||
|
def listen(self):
|
||||||
|
# start
|
||||||
|
if self.mode == 'live' and not self.listening:
|
||||||
|
print(f'[INFO] starting read frame thread...')
|
||||||
|
self.process_read_frame.start()
|
||||||
|
self.listening = True
|
||||||
|
|
||||||
def get_audio_frame(self):
|
if self.play and not self.playing:
|
||||||
|
print(f'[INFO] starting play frame thread...')
|
||||||
|
self.process_play_frame.start()
|
||||||
|
self.playing = True
|
||||||
|
|
||||||
if self.inwarm: # warm up
|
def stop(self):
|
||||||
return np.zeros(self.chunk, dtype=np.float32)
|
|
||||||
|
|
||||||
if self.mode == 'file':
|
self.exit_event.set()
|
||||||
|
|
||||||
if self.idx < self.file_stream.shape[0]:
|
|
||||||
frame = self.file_stream[self.idx: self.idx + self.chunk]
|
|
||||||
self.idx = self.idx + self.chunk
|
|
||||||
return frame
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
frame = self.queue.get(block=False)
|
|
||||||
print(f'[INFO] get frame {frame.shape}')
|
|
||||||
except queue.Empty:
|
|
||||||
frame = np.zeros(self.chunk, dtype=np.float32)
|
|
||||||
|
|
||||||
self.idx = self.idx + self.chunk
|
|
||||||
|
|
||||||
return frame
|
|
||||||
|
|
||||||
|
|
||||||
def frame_to_text(self, frame):
|
|
||||||
# frame: [N * 320], N = (context_size + 2 * stride_size)
|
|
||||||
|
|
||||||
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
result = self.model(inputs.input_values.to(self.device))
|
|
||||||
if 'hubert' in self.opt.asr_model:
|
|
||||||
logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024]
|
|
||||||
else:
|
|
||||||
logits = result.logits # [1, N - 1, 32]
|
|
||||||
#print('logits.shape:',logits.shape)
|
|
||||||
|
|
||||||
# cut off stride
|
|
||||||
left = max(0, self.stride_left_size)
|
|
||||||
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
|
|
||||||
|
|
||||||
# do not cut right if terminated.
|
|
||||||
if self.terminated:
|
|
||||||
right = logits.shape[1]
|
|
||||||
|
|
||||||
logits = logits[:, left:right]
|
|
||||||
|
|
||||||
# print(frame.shape, inputs.input_values.shape, logits.shape)
|
|
||||||
|
|
||||||
#predicted_ids = torch.argmax(logits, dim=-1)
|
|
||||||
#transcription = self.processor.batch_decode(predicted_ids)[0].lower()
|
|
||||||
|
|
||||||
|
|
||||||
# for esperanto
|
|
||||||
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
|
|
||||||
|
|
||||||
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
|
|
||||||
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
|
|
||||||
# print(predicted_ids[0])
|
|
||||||
# print(transcription)
|
|
||||||
|
|
||||||
return logits[0], None,None #predicted_ids[0], transcription # [N,]
|
|
||||||
|
|
||||||
def create_bytes_stream(self,byte_stream):
|
|
||||||
#byte_stream=BytesIO(buffer)
|
|
||||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
|
||||||
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
|
||||||
stream = stream.astype(np.float32)
|
|
||||||
|
|
||||||
if stream.ndim > 1:
|
|
||||||
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
|
||||||
stream = stream[:, 0]
|
|
||||||
|
|
||||||
if sample_rate != self.sample_rate and stream.shape[0]>0:
|
|
||||||
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
|
|
||||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
|
|
||||||
|
|
||||||
return stream
|
|
||||||
|
|
||||||
def push_audio(self,buffer):
|
|
||||||
print(f'[INFO] push_audio {len(buffer)}')
|
|
||||||
if self.opt.tts == "xtts":
|
|
||||||
if len(buffer)>0:
|
|
||||||
stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
|
|
||||||
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
|
|
||||||
#byte_stream=BytesIO(buffer)
|
|
||||||
#stream = self.create_bytes_stream(byte_stream)
|
|
||||||
streamlen = stream.shape[0]
|
|
||||||
idx=0
|
|
||||||
while streamlen >= self.chunk:
|
|
||||||
self.queue.put(stream[idx:idx+self.chunk])
|
|
||||||
streamlen -= self.chunk
|
|
||||||
idx += self.chunk
|
|
||||||
# if streamlen>0: #skip last frame(not 20ms)
|
|
||||||
# self.queue.put(stream[idx:])
|
|
||||||
else: #edge tts
|
|
||||||
self.input_stream.write(buffer)
|
|
||||||
if len(buffer)<=0:
|
|
||||||
self.input_stream.seek(0)
|
|
||||||
stream = self.create_bytes_stream(self.input_stream)
|
|
||||||
streamlen = stream.shape[0]
|
|
||||||
idx=0
|
|
||||||
while streamlen >= self.chunk:
|
|
||||||
self.queue.put(stream[idx:idx+self.chunk])
|
|
||||||
streamlen -= self.chunk
|
|
||||||
idx += self.chunk
|
|
||||||
#if streamlen>0: #skip last frame(not 20ms)
|
|
||||||
# self.queue.put(stream[idx:])
|
|
||||||
self.input_stream.seek(0)
|
|
||||||
self.input_stream.truncate()
|
|
||||||
|
|
||||||
def get_audio_out(self):
|
|
||||||
return self.output_queue.get()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
|
|
||||||
self.listen()
|
|
||||||
|
|
||||||
while not self.terminated:
|
|
||||||
self.run_step()
|
|
||||||
|
|
||||||
def clear_queue(self):
|
|
||||||
# clear the queue, to reduce potential latency...
|
|
||||||
print(f'[INFO] clear queue')
|
|
||||||
if self.mode == 'live':
|
|
||||||
self.queue.queue.clear()
|
|
||||||
if self.play:
|
if self.play:
|
||||||
self.output_queue.queue.clear()
|
self.output_stream.stop_stream()
|
||||||
|
self.output_stream.close()
|
||||||
|
if self.playing:
|
||||||
|
self.process_play_frame.join()
|
||||||
|
self.playing = False
|
||||||
|
|
||||||
def warm_up(self):
|
if self.mode == 'live':
|
||||||
|
#self.input_stream.stop_stream() todo
|
||||||
#self.listen()
|
self.input_stream.close()
|
||||||
|
if self.listening:
|
||||||
self.inwarm = True
|
self.process_read_frame.join()
|
||||||
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
|
self.listening = False
|
||||||
t = time.time()
|
|
||||||
for _ in range(self.warm_up_steps):
|
|
||||||
self.run_step()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
t = time.time() - t
|
|
||||||
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
|
|
||||||
self.inwarm = False
|
|
||||||
|
|
||||||
#self.clear_queue()
|
|
||||||
|
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
if self.mode == 'live':
|
||||||
|
# live mode: also print the result text.
|
||||||
|
self.text += '\n[END]'
|
||||||
|
print(self.text)
|
||||||
|
#########################################################
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import argparse
|
import argparse
|
||||||
|
|
7
main.py
7
main.py
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from nerf_triplane.provider import NeRFDataset
|
from nerf_triplane.provider import NeRFDataset,NeRFDataset_Test
|
||||||
from nerf_triplane.utils import *
|
from nerf_triplane.utils import *
|
||||||
from nerf_triplane.network import NeRFNetwork
|
from nerf_triplane.network import NeRFNetwork
|
||||||
|
|
||||||
|
@ -24,6 +24,9 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--workspace', type=str, default='workspace')
|
parser.add_argument('--workspace', type=str, default='workspace')
|
||||||
parser.add_argument('--seed', type=int, default=0)
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
|
||||||
|
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
|
||||||
|
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
|
||||||
|
|
||||||
### training options
|
### training options
|
||||||
parser.add_argument('--iters', type=int, default=200000, help="training iters")
|
parser.add_argument('--iters', type=int, default=200000, help="training iters")
|
||||||
parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
|
parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
|
||||||
|
@ -47,7 +50,7 @@ if __name__ == '__main__':
|
||||||
### network backbone options
|
### network backbone options
|
||||||
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
|
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
|
||||||
|
|
||||||
parser.add_argument('--bg_img', type=str, default='', help="background image")
|
parser.add_argument('--bg_img', type=str, default='white', help="background image")
|
||||||
parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
|
parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
|
||||||
parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
|
parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
|
||||||
parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
|
parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
|
||||||
|
|
|
@ -98,6 +98,7 @@ class NeRFDataset_Test:
|
||||||
|
|
||||||
self.training = False
|
self.training = False
|
||||||
self.num_rays = -1
|
self.num_rays = -1
|
||||||
|
self.preload = opt.preload # 0 = disk, 1 = cpu, 2 = gpu
|
||||||
|
|
||||||
# load nerf-compatible format data.
|
# load nerf-compatible format data.
|
||||||
|
|
||||||
|
@ -148,6 +149,7 @@ class NeRFDataset_Test:
|
||||||
self.poses = []
|
self.poses = []
|
||||||
self.auds = []
|
self.auds = []
|
||||||
self.eye_area = []
|
self.eye_area = []
|
||||||
|
self.torso_img = []
|
||||||
|
|
||||||
for f in tqdm.tqdm(frames, desc=f'Loading data'):
|
for f in tqdm.tqdm(frames, desc=f'Loading data'):
|
||||||
|
|
||||||
|
@ -173,6 +175,29 @@ class NeRFDataset_Test:
|
||||||
|
|
||||||
self.eye_area.append(area)
|
self.eye_area.append(area)
|
||||||
|
|
||||||
|
# load frame-wise torso image (composited over the background when --torso_imgs is set)
|
||||||
|
|
||||||
|
if self.opt.torso_imgs!='':
|
||||||
|
torso_img_path = os.path.join(self.opt.torso_imgs, str(f['img_id']) + '.png')
|
||||||
|
|
||||||
|
if self.preload > 0:
|
||||||
|
torso_img = cv2.imread(torso_img_path, cv2.IMREAD_UNCHANGED) # [H, W, 4]
|
||||||
|
torso_img = cv2.cvtColor(torso_img, cv2.COLOR_BGRA2RGBA)
|
||||||
|
torso_img = torso_img.astype(np.float32) / 255 # [H, W, 3/4]
|
||||||
|
|
||||||
|
self.torso_img.append(torso_img)
|
||||||
|
else:
|
||||||
|
self.torso_img.append(torso_img_path)
|
||||||
|
|
||||||
|
if self.opt.torso_imgs!='':
|
||||||
|
if self.preload > 0:
|
||||||
|
self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0)) # [N, H, W, C]
|
||||||
|
else:
|
||||||
|
self.torso_img = np.array(self.torso_img)
|
||||||
|
if self.preload > 1: #gpu
|
||||||
|
self.torso_img = self.torso_img.to(torch.half).to(self.device)
|
||||||
|
|
||||||
|
|
||||||
# load pre-extracted background image (should be the same size as training image...)
|
# load pre-extracted background image (should be the same size as training image...)
|
||||||
|
|
||||||
if self.opt.bg_img == 'white': # special
|
if self.opt.bg_img == 'white': # special
|
||||||
|
@ -209,6 +234,9 @@ class NeRFDataset_Test:
|
||||||
|
|
||||||
self.bg_img = torch.from_numpy(self.bg_img)
|
self.bg_img = torch.from_numpy(self.bg_img)
|
||||||
|
|
||||||
|
if self.preload > 1 or self.opt.torso_imgs=='': #gpu
|
||||||
|
self.bg_img = self.bg_img.to(torch.half).to(self.device)
|
||||||
|
|
||||||
if self.opt.exp_eye:
|
if self.opt.exp_eye:
|
||||||
self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N]
|
self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N]
|
||||||
print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')
|
print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')
|
||||||
|
@ -230,8 +258,6 @@ class NeRFDataset_Test:
|
||||||
if self.auds is not None:
|
if self.auds is not None:
|
||||||
self.auds = self.auds.to(self.device)
|
self.auds = self.auds.to(self.device)
|
||||||
|
|
||||||
self.bg_img = self.bg_img.to(torch.half).to(self.device)
|
|
||||||
|
|
||||||
if self.opt.exp_eye:
|
if self.opt.exp_eye:
|
||||||
self.eye_area = self.eye_area.to(self.device)
|
self.eye_area = self.eye_area.to(self.device)
|
||||||
|
|
||||||
|
@ -286,6 +312,21 @@ class NeRFDataset_Test:
|
||||||
else:
|
else:
|
||||||
results['eye'] = None
|
results['eye'] = None
|
||||||
|
|
||||||
|
# load bg
|
||||||
|
if self.opt.torso_imgs!='':
|
||||||
|
bg_torso_img = self.torso_img[index]
|
||||||
|
if self.preload == 0: # on the fly loading
|
||||||
|
bg_torso_img = cv2.imread(bg_torso_img[0], cv2.IMREAD_UNCHANGED) # [H, W, 4]
|
||||||
|
bg_torso_img = cv2.cvtColor(bg_torso_img, cv2.COLOR_BGRA2RGBA)
|
||||||
|
bg_torso_img = bg_torso_img.astype(np.float32) / 255 # [H, W, 3/4]
|
||||||
|
bg_torso_img = torch.from_numpy(bg_torso_img).unsqueeze(0)
|
||||||
|
bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:])
|
||||||
|
bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device)
|
||||||
|
if not self.opt.torso:
|
||||||
|
bg_img = bg_torso_img
|
||||||
|
else:
|
||||||
|
bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)
|
||||||
|
else:
|
||||||
bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)
|
bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)
|
||||||
|
|
||||||
results['bg_color'] = bg_img
|
results['bg_color'] = bg_img
|
||||||
|
@ -341,7 +382,29 @@ class NeRFDataset:
|
||||||
|
|
||||||
# load nerf-compatible format data.
|
# load nerf-compatible format data.
|
||||||
|
|
||||||
with open(opt.pose, 'r') as f:
|
# load all splits (train/valid/test)
|
||||||
|
if type == 'all':
|
||||||
|
transform_paths = glob.glob(os.path.join(self.root_path, '*.json'))
|
||||||
|
transform = None
|
||||||
|
for transform_path in transform_paths:
|
||||||
|
with open(transform_path, 'r') as f:
|
||||||
|
tmp_transform = json.load(f)
|
||||||
|
if transform is None:
|
||||||
|
transform = tmp_transform
|
||||||
|
else:
|
||||||
|
transform['frames'].extend(tmp_transform['frames'])
|
||||||
|
# load train and val split
|
||||||
|
elif type == 'trainval':
|
||||||
|
with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f:
|
||||||
|
transform = json.load(f)
|
||||||
|
with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f:
|
||||||
|
transform_val = json.load(f)
|
||||||
|
transform['frames'].extend(transform_val['frames'])
|
||||||
|
# only load one specified split
|
||||||
|
else:
|
||||||
|
# no test, use val as test
|
||||||
|
_split = 'val' if type == 'test' else type
|
||||||
|
with open(os.path.join(self.root_path, f'transforms_{_split}.json'), 'r') as f:
|
||||||
transform = json.load(f)
|
transform = json.load(f)
|
||||||
|
|
||||||
# load image size
|
# load image size
|
||||||
|
@ -371,6 +434,10 @@ class NeRFDataset:
|
||||||
aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy'))
|
aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy'))
|
||||||
elif 'deepspeech' in self.opt.asr_model:
|
elif 'deepspeech' in self.opt.asr_model:
|
||||||
aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy'))
|
aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy'))
|
||||||
|
# elif 'hubert_cn' in self.opt.asr_model:
|
||||||
|
# aud_features = np.load(os.path.join(self.root_path, 'aud_hu_cn.npy'))
|
||||||
|
elif 'hubert' in self.opt.asr_model:
|
||||||
|
aud_features = np.load(os.path.join(self.root_path, 'aud_hu.npy'))
|
||||||
else:
|
else:
|
||||||
aud_features = np.load(os.path.join(self.root_path, 'aud.npy'))
|
aud_features = np.load(os.path.join(self.root_path, 'aud.npy'))
|
||||||
# cross-driven extracted features.
|
# cross-driven extracted features.
|
||||||
|
|
12
nerfreal.py
12
nerfreal.py
|
@ -34,9 +34,8 @@ class NeRFReal:
|
||||||
self.audio_features = data_loader._data.auds # [N, 29, 16]
|
self.audio_features = data_loader._data.auds # [N, 29, 16]
|
||||||
self.audio_idx = 0
|
self.audio_idx = 0
|
||||||
|
|
||||||
self.frame_total_num = data_loader._data.end_index
|
#self.frame_total_num = data_loader._data.end_index
|
||||||
print("frame_total_num:",self.frame_total_num)
|
#print("frame_total_num:",self.frame_total_num)
|
||||||
self.frame_index=0
|
|
||||||
|
|
||||||
# control eye
|
# control eye
|
||||||
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
|
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
|
||||||
|
@ -110,6 +109,9 @@ class NeRFReal:
|
||||||
def push_audio(self,chunk):
|
def push_audio(self,chunk):
|
||||||
self.asr.push_audio(chunk)
|
self.asr.push_audio(chunk)
|
||||||
|
|
||||||
|
def before_push_audio(self):
|
||||||
|
self.asr.before_push_audio()
|
||||||
|
|
||||||
def prepare_buffer(self, outputs):
|
def prepare_buffer(self, outputs):
|
||||||
if self.mode == 'image':
|
if self.mode == 'image':
|
||||||
return outputs['image']
|
return outputs['image']
|
||||||
|
@ -140,7 +142,8 @@ class NeRFReal:
|
||||||
if not self.opt.fullbody:
|
if not self.opt.fullbody:
|
||||||
self.streamer.stream_frame(image)
|
self.streamer.stream_frame(image)
|
||||||
else: #fullbody human
|
else: #fullbody human
|
||||||
image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(self.frame_index%self.frame_total_num)+'.jpg'))
|
#print("frame index:",data['index'])
|
||||||
|
image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
|
||||||
image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB)
|
image_fullbody = cv2.cvtColor(image_fullbody, cv2.COLOR_BGR2RGB)
|
||||||
start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标
|
start_x = self.opt.fullbody_offset_x # 合并后小图片的起始x坐标
|
||||||
start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标
|
start_y = self.opt.fullbody_offset_y # 合并后小图片的起始y坐标
|
||||||
|
@ -201,7 +204,6 @@ class NeRFReal:
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
self.asr.run_step()
|
self.asr.run_step()
|
||||||
self.test_step()
|
self.test_step()
|
||||||
self.frame_index = (self.frame_index+1)%self.frame_total_num
|
|
||||||
totaltime += (time.time() - t)
|
totaltime += (time.time() - t)
|
||||||
count += 1
|
count += 1
|
||||||
if count==100:
|
if count==100:
|
||||||
|
|
Loading…
Reference in New Issue