This commit is contained in:
lihengzhong 2023-12-19 09:41:52 +08:00
parent 6e909475c5
commit e4b2cab164
67 changed files with 231453 additions and 0 deletions

17
.gitignore vendored Normal file

@ -0,0 +1,17 @@
__pycache__/
build/
*.egg-info/
*.so
*.mp4
tmp*
trial*/
data
data_utils/face_tracking/3DMM/*
data_utils/face_parsing/79999_iter.pth
pretrained
*.mp4
.DS_Store
workspace/log_ngp.txt

121
README.md Normal file

@ -0,0 +1,121 @@
# Virtual Human Talking-Head Generation (Real-Time Photo-Driven Avatar)
![](/img/example.gif)
# Get Started
## Installation
Tested on Ubuntu 22.04 with PyTorch 1.12 and CUDA 11.6, or PyTorch 1.12 and CUDA 11.3.
```bash
git clone https://github.com/waityousea/xuniren.git
cd xuniren
```
### Install dependency
```bash
# For Ubuntu, portaudio is needed for pyaudio to work.
sudo apt install portaudio19-dev
pip install -r requirements.txt

# Or create a conda env instead (environment.yml uses PyTorch 1.12 and CUDA 11.3)
conda env create -f environment.yml

# Install pytorch3d (Ubuntu/macOS)
pip install "git+https://github.com/facebookresearch/pytorch3d.git"
```
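After installation, a quick sanity check can help; this is only a minimal sketch, and the printed versions depend on your own environment:
```python
import torch
import pytorch3d

print("torch:", torch.__version__, "cuda:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("pytorch3d:", pytorch3d.__version__)
```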
**Installing pytorch3d on Windows**
- gcc & g++ ≥ 4.9
On Windows you need a gcc compiler; install one that suits your setup (for example MinGW).
The installation steps below come from the official [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) instructions; pick the parts you need.
```bash
conda create -n pytorch3d python=3.9
conda activate pytorch3d
conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
```
The CUB build-time dependency is only needed if your CUDA version is older than 11.7; if you use conda, you can install it with:
```
conda install -c bottler nvidiacub
```
```
# Demos and examples
conda install jupyter
pip install scikit-image matplotlib imageio plotly opencv-python
# Tests/Linting
pip install black usort flake8 flake8-bugbear flake8-comprehensions
```
After applying any necessary patches, open the "x64 Native Tools Command Prompt for VS 2019" and build/install:
```
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d
python setup.py install
```
### Build extension
By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime. However, this may be inconvenient sometimes. Therefore, we also provide the `setup.py` to build each extension:
```
# install all extension modules
# Note: this module must be installed.
# On Windows, it is recommended to run this from the "x64 Native Tools Command Prompt for VS 2019".
bash scripts/install_ext.sh
```
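For reference, the runtime path mentioned above goes through `torch.utils.cpp_extension.load`; the sketch below only illustrates the shape of such a call, with an assumed module name and source list rather than this repo's actual extension files:
```python
from torch.utils.cpp_extension import load

# Hypothetical example: JIT-compile and import a CUDA extension on first use.
_ext = load(
    name="example_ext",                      # assumed module name
    sources=["example.cu", "bindings.cpp"],  # assumed source files
    verbose=True,
)
```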
### **Start (standalone)**
Once the environment is set up, launch the virtual human generator:
```bash
python app.py
```
### **start对接fay在ubuntu 20下完成测试**
环境配置完成后启动fay对接脚本
```python
python fay_connect.py
```
![](img/weplay.png)
Scan the QR code to support the open-source development work; use your payment order number to join the QQ discussion group.
Interface input and output details: [WebSocket.md](https://github.com/waityousea/xuniren/blob/main/WebSocket.md)
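As a minimal sketch (see WebSocket.md for the authoritative description), a client can send a line of text to the `/dighuman` endpoint started by `app.py` and receive a JSON message whose `video` field is a base64-encoded MP4; this assumes the `websocket-client` package is installed:
```python
import base64
import json
from websocket import create_connection  # pip install websocket-client

ws = create_connection("ws://127.0.0.1:8800/dighuman")
ws.send("你好,很高兴认识你")                 # text for the avatar to speak
reply = json.loads(ws.recv())              # {'video': 'data:video/mp4;base64,...'}
video_b64 = reply["video"].split(",", 1)[1]
with open("reply.mp4", "wb") as f:
    f.write(base64.b64decode(video_b64))
ws.close()
```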
Core files for virtual human generation:
```
## Note: the core files need to be trained separately
.
├── data
│   ├── data_kf.json
│   └── pretrained
│       └── ngp_kf.pth
```
### Inference Speed
When running video inference on a desktop RTX A4000 or a laptop RTX 3080 Ti GPU (16 GB VRAM), about 35-43 frames can be rendered per second. At 25 fps output video, one second of inference therefore yields roughly 35/25 ≈ 1.4 to 43/25 ≈ 1.7 seconds (about 1.5 s) of video.
# Acknowledgement
- The data pre-processing part is adapted from [AD-NeRF](https://github.com/YudongGuo/AD-NeRF).
- The NeRF framework is based on [torch-ngp](https://github.com/ashawkey/torch-ngp).
- The core algorithm comes from [RAD-NeRF](https://github.com/ashawkey/RAD-NeRF).
- Usage example [Fay](https://github.com/TheRamU/Fay).
For academic exchange, feel free to email waityousea@126.com.

251
app.py Normal file

@ -0,0 +1,251 @@
# server.py
from flask import Flask, request, jsonify
from flask_sockets import Sockets
import base64
import time
import json
import gevent
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
from tools import audio_pre_process, video_pre_process, generate_video,audio_process
import os
import re
import numpy as np
import argparse
from nerf_triplane.provider import NeRFDataset_Test
from nerf_triplane.utils import *
from nerf_triplane.network import NeRFNetwork
from nerfreal import NeRFReal
import shutil
import asyncio
import edge_tts
app = Flask(__name__)
sockets = Sockets(app)
video_list = []
global nerfreal
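# Stream TTS audio from edge-tts straight into the NeRF renderer: each audio chunk
# yielded by edge_tts.Communicate.stream() is pushed to render.push_audio() as it arrives.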
async def main(voicename: str, text: str, render):
communicate = edge_tts.Communicate(text, voicename)
#with open(OUTPUT_FILE, "wb") as file:
async for chunk in communicate.stream():
if chunk["type"] == "audio":
render.push_audio(chunk["data"])
#file.write(chunk["data"])
elif chunk["type"] == "WordBoundary":
pass
def send_information(path, ws):
print('start sending video data')
#path = video_list[0]
with open(path, 'rb') as f:
video_data = base64.b64encode(f.read()).decode()
data = {
'video': 'data:video/mp4;base64,%s' % video_data,
}
json_data = json.dumps(data)
ws.send(json_data)
def txt_to_audio(text_):
audio_list = []
#audio_path = 'data/audio/aud_0.wav'
voicename = "zh-CN-YunxiaNeural"
text = text_
asyncio.get_event_loop().run_until_complete(main(voicename,text,nerfreal))
#audio_process(audio_path)
@sockets.route('/dighuman')
def echo_socket(ws):
# get the WebSocket object
#ws = request.environ.get('wsgi.websocket')
# if no WebSocket is available, return an error message
if not ws:
print('connection not established')
return 'Please use WebSocket'
# otherwise, loop to receive and send messages
else:
print('connection established')
while True:
message = ws.receive()
if len(message)==0:
return 'input message is empty'
else:
txt_to_audio(message)
audio_path = 'data/audio/aud_0.wav'
audio_path_eo = 'data/audio/aud_0_eo.npy'
video_path = 'data/video/results/ngp_0.mp4'
output_path = 'data/video/results/output_0.mp4'
generate_video(audio_path, audio_path_eo, video_path, output_path)
video_list.append(output_path)
send_information(output_path, ws)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
parser.add_argument('--workspace', type=str, default='data/video')
parser.add_argument('--seed', type=int, default=0)
### training options
parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')
parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
### loss set
parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")
### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--bg_img', type=str, default='white', help="background image")
parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")
### dataset options
parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
# (the default value is for the fox dataset)
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
parser.add_argument('--init_lips', action='store_true', help="init lips region")
parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...")
parser.add_argument('--torso', action='store_true', help="fix head and train torso")
parser.add_argument('--head_ckpt', type=str, default='', help="head model")
### GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
parser.add_argument('--W', type=int, default=450, help="GUI width")
parser.add_argument('--H', type=int, default=450, help="GUI height")
parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
### else
parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")
parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")
# asr
parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
parser.add_argument('--asr_play', action='store_true', help="play out the audio")
#parser.add_argument('--asr_model', type=str, default='deepspeech')
parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
parser.add_argument('--asr_save_feats', action='store_true')
# audio FPS
parser.add_argument('--fps', type=int, default=50)
# sliding window left-middle-right length (unit: 20ms)
parser.add_argument('-l', type=int, default=10)
parser.add_argument('-m', type=int, default=50)
parser.add_argument('-r', type=int, default=10)
opt = parser.parse_args()
# assert test mode
opt.test = True
opt.test_train = False
#opt.train_camera =True
# explicit smoothing
opt.smooth_path = True
opt.smooth_eye = True
opt.smooth_lips = True
assert opt.pose != '', 'Must provide a pose source'
# if opt.O:
opt.fp16 = True
opt.exp_eye = True
opt.cuda_ray = True
opt.torso = True
# assert opt.cuda_ray, "Only support CUDA ray mode."
opt.asr = True
if opt.patch_size > 1:
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
assert opt.num_rays % (opt.patch_size ** 2) == 0, "num_rays should be divisible by patch_size ** 2."
seed_everything(opt.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeRFNetwork(opt)
criterion = torch.nn.MSELoss(reduction='none')
metrics = [] # use no metric in GUI for faster initialization...
print(model)
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
test_loader = NeRFDataset_Test(opt, device=device).dataloader()
model.aud_features = test_loader._data.auds
model.eye_areas = test_loader._data.eye_area
# we still need test_loader to provide audio features for testing.
nerfreal = NeRFReal(opt, trainer, test_loader)
txt_to_audio('我是中国人,我来自北京')
nerfreal.render()
#############################################################################
server = pywsgi.WSGIServer(('127.0.0.1', 8800), app, handler_class=WebSocketHandler)
server.serve_forever()

464
asrreal.py Normal file

@ -0,0 +1,464 @@
import time
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor
import pyaudio
import soundfile as sf
import resampy
from queue import Queue
#from collections import deque
from threading import Thread, Event
from io import BytesIO
def _read_frame(stream, exit_event, queue, chunk):
while True:
if exit_event.is_set():
print(f'[INFO] read frame thread ends')
break
frame = stream.read(chunk, exception_on_overflow=False)
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
queue.put(frame)
def _play_frame(stream, exit_event, queue, chunk):
while True:
if exit_event.is_set():
print(f'[INFO] play frame thread ends')
break
frame = queue.get()
frame = (frame * 32767).astype(np.int16).tobytes()
stream.write(frame, chunk)
class ASR:
def __init__(self, opt):
self.opt = opt
self.play = opt.asr_play #false
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.mode = 'live' if opt.asr_wav == '' else 'file'
if 'esperanto' in self.opt.asr_model:
self.audio_dim = 44
elif 'deepspeech' in self.opt.asr_model:
self.audio_dim = 29
else:
self.audio_dim = 32
# prepare context cache
# each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
self.context_size = opt.m
self.stride_left_size = opt.l
self.stride_right_size = opt.r
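# With the defaults (l=10, m=50, r=10), each segment spans (10 + 50 + 10) * 20 ms = 1.4 s
# of audio and the expected latency is (50 + 10) * 20 ms = 1.2 s, per the note above.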
self.text = '[START]\n'
self.terminated = False
self.frames = []
self.inwarm = False
# pad left frames
if self.stride_left_size > 0:
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
self.exit_event = Event()
#self.audio_instance = pyaudio.PyAudio() #not need
# create input stream
if self.mode == 'file': # file mode: read from a wav file
self.file_stream = self.create_file_stream()
else:
self.queue = Queue()
self.input_stream = BytesIO()
self.output_queue = Queue()
# start a background process to read frames
#self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
#self.queue = Queue()
#self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
# play out the audio too...?
if self.play:
self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
self.output_queue = Queue()
self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
# current location of audio
self.idx = 0
# create wav2vec model
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
# prepare to save logits
if self.opt.asr_save_feats:
self.all_feats = []
# the extracted features
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
self.feat_buffer_size = 4
self.feat_buffer_idx = 0
self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
# TODO: hard coded 16 and 8 window size...
self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
self.tail = 8
# attention window...
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
# warm up steps needed: mid + right + window_size + attention_size
self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
self.listening = False
self.playing = False
def listen(self):
# start
if self.mode == 'live' and not self.listening:
print(f'[INFO] starting read frame thread...')
self.process_read_frame.start()
self.listening = True
if self.play and not self.playing:
print(f'[INFO] starting play frame thread...')
self.process_play_frame.start()
self.playing = True
def stop(self):
self.exit_event.set()
if self.play:
self.output_stream.stop_stream()
self.output_stream.close()
if self.playing:
self.process_play_frame.join()
self.playing = False
if self.mode == 'live':
#self.input_stream.stop_stream() todo
self.input_stream.close()
if self.listening:
self.process_read_frame.join()
self.listening = False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
if self.mode == 'live':
# live mode: also print the result text.
self.text += '\n[END]'
print(self.text)
def get_next_feat(self):
# return a [1/8, 16] window, for the next input to nerf side.
while len(self.att_feats) < 8:
# [------f+++t-----]
if self.front < self.tail:
feat = self.feat_queue[self.front:self.tail]
# [++t-----------f+]
else:
feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
self.front = (self.front + 2) % self.feat_queue.shape[0]
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
# print(self.front, self.tail, feat.shape)
self.att_feats.append(feat.permute(1, 0))
att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
# discard old
self.att_feats = self.att_feats[1:]
return att_feat
def run_step(self):
if self.terminated:
return
# get a frame of audio
frame = self.get_audio_frame()
# the last frame
if frame is None:
# terminate, but always run the network for the left frames
self.terminated = True
else:
self.frames.append(frame)
# put to output
#if self.play:
self.output_queue.put(frame)
# context not enough, do not run network.
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
# discard the old part to save memory
if not self.terminated:
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
print(f'[INFO] frame_to_text... ')
logits, labels, text = self.frame_to_text(inputs)
feats = logits # better lips-sync than labels
# save feats
if self.opt.asr_save_feats:
self.all_feats.append(feats)
# record the feats efficiently.. (no concat, constant memory)
start = self.feat_buffer_idx * self.context_size
end = start + feats.shape[0]
self.feat_queue[start:end] = feats
self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size
# very naive, just concat the text output.
if text != '':
self.text = self.text + ' ' + text
# will only run once at termination
if self.terminated:
self.text += '\n[END]'
print(self.text)
if self.opt.asr_save_feats:
print(f'[INFO] save all feats for training purpose... ')
feats = torch.cat(self.all_feats, dim=0) # [N, C]
# print('[INFO] before unfold', feats.shape)
window_size = 16
padding = window_size // 2
feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
# print('[INFO] after unfold', unfold_feats.shape)
# save to a npy file
if 'esperanto' in self.opt.asr_model:
output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
else:
output_path = self.opt.asr_wav.replace('.wav', '.npy')
np.save(output_path, unfold_feats.cpu().numpy())
print(f"[INFO] saved logits to {output_path}")
def create_file_stream(self):
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
return stream
def create_pyaudio_stream(self):
import pyaudio
print(f'[INFO] creating live audio stream ...')
audio = pyaudio.PyAudio()
# get devices
info = audio.get_host_api_info_by_index(0)
n_devices = info.get('deviceCount')
for i in range(0, n_devices):
if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
print(f'[INFO] choose audio device {name}, id {i}')
break
# get stream
stream = audio.open(input_device_index=i,
format=pyaudio.paInt16,
channels=1,
rate=self.sample_rate,
input=True,
frames_per_buffer=self.chunk)
return audio, stream
def get_audio_frame(self):
if self.inwarm: # warm up
return np.zeros(self.chunk, dtype=np.float32)
if self.mode == 'file':
if self.idx < self.file_stream.shape[0]:
frame = self.file_stream[self.idx: self.idx + self.chunk]
self.idx = self.idx + self.chunk
return frame
else:
return None
else:
frame = self.queue.get()
print(f'[INFO] get frame {frame.shape}')
self.idx = self.idx + self.chunk
return frame
def frame_to_text(self, frame):
# frame: [N * 320], N = (context_size + 2 * stride_size)
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
with torch.no_grad():
result = self.model(inputs.input_values.to(self.device))
logits = result.logits # [1, N - 1, 32]
# cut off stride
left = max(0, self.stride_left_size)
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
# do not cut right if terminated.
if self.terminated:
right = logits.shape[1]
logits = logits[:, left:right]
# print(frame.shape, inputs.input_values.shape, logits.shape)
predicted_ids = torch.argmax(logits, dim=-1)
transcription = self.processor.batch_decode(predicted_ids)[0].lower()
# for esperanto
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '', 'fi', 'l', 'p', '', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
# print(predicted_ids[0])
# print(transcription)
return logits[0], predicted_ids[0], transcription # [N,]
def create_bytes_stream(self,byte_stream):
#byte_stream=BytesIO(buffer)
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
return stream
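# Buffer incoming TTS audio bytes in self.input_stream (a BytesIO). An empty buffer marks
# end-of-stream: the accumulated audio is then decoded, resampled to 16 kHz and split into
# 20 ms chunks that are put on the ASR input queue.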
def push_audio(self,buffer):
print(f'[INFO] push_audio {len(buffer)}')
self.input_stream.write(buffer)
if len(buffer)<=0:
self.input_stream.seek(0)
stream = self.create_bytes_stream(self.input_stream)
streamlen = stream.shape[0]
idx=0
while streamlen >= self.chunk:
self.queue.put(stream[idx:idx+self.chunk])
streamlen -= self.chunk
idx += self.chunk
if streamlen>0:
self.queue.put(stream[idx:])
def get_audio_out(self):
return self.output_queue.get()
def run(self):
self.listen()
while not self.terminated:
self.run_step()
def clear_queue(self):
# clear the queue, to reduce potential latency...
print(f'[INFO] clear queue')
if self.mode == 'live':
self.queue.queue.clear()
if self.play:
self.output_queue.queue.clear()
def warm_up(self):
#self.listen()
self.inwarm = True
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
for _ in range(self.warm_up_steps):
self.run_step()
if torch.cuda.is_available():
torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
self.inwarm = False
#self.clear_queue()
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--wav', type=str, default='')
parser.add_argument('--play', action='store_true', help="play out the audio")
parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
# parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
parser.add_argument('--save_feats', action='store_true')
# audio FPS
parser.add_argument('--fps', type=int, default=50)
# sliding window left-middle-right length.
parser.add_argument('-l', type=int, default=10)
parser.add_argument('-m', type=int, default=50)
parser.add_argument('-r', type=int, default=10)
opt = parser.parse_args()
# fix
opt.asr_wav = opt.wav
opt.asr_play = opt.play
opt.asr_model = opt.model
opt.asr_save_feats = opt.save_feats
if 'deepspeech' in opt.asr_model:
raise ValueError("DeepSpeech features should not use this code to extract...")
with ASR(opt) as asr:
asr.run()

BIN
assets/main.png Normal file

Binary file not shown.


Width:  |  Height:  |  Size: 182 KiB

218167
data/data_kf.json Normal file

File diff suppressed because it is too large.

BIN
data/pretrained/ngp_kf.pth Normal file

Binary file not shown.


@ -0,0 +1,20 @@
# Routines for DeepSpeech features processing
Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model.
## Installation
```
pip3 install -r requirements.txt
```
## Usage
Generate wav files:
```
python3 extract_wav.py --in-video=<your_data_dir>
```
Generate files with DeepSpeech features:
```
python3 extract_ds_features.py --input=<your_data_dir>
```
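The generated `.npy` file holds overlapping windows of DeepSpeech logits; a quick shape check (a sketch, assuming a feature file produced from `speech.wav`):
```python
import numpy as np

feats = np.load("speech.npy")  # assumed output name from extract_ds_features.py
print(feats.shape)             # (num_frames, 16, 29): a 16-step window of 29 DeepSpeech logits per frame
```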


@ -0,0 +1,275 @@
"""
DeepSpeech features processing routines.
NB: Based on VOCA code. See the corresponding license restrictions.
"""
__all__ = ['conv_audios_to_deepspeech']
import numpy as np
import warnings
import resampy
from scipy.io import wavfile
from python_speech_features import mfcc
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
def conv_audios_to_deepspeech(audios,
out_files,
num_frames_info,
deepspeech_pb_path,
audio_window_size=1,
audio_window_stride=1):
"""
Convert list of audio files into files with DeepSpeech features.
Parameters
----------
audios : list of str or list of None
Paths to input audio files.
out_files : list of str
Paths to output files with DeepSpeech features.
num_frames_info : list of int
List of numbers of frames.
deepspeech_pb_path : str
Path to DeepSpeech 0.1.0 frozen model.
audio_window_size : int, default 1
Audio window size.
audio_window_stride : int, default 1
Audio window stride.
"""
# deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
deepspeech_pb_path)
with tf.compat.v1.Session(graph=graph) as sess:
for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
print(audio_file_path)
print(out_file_path)
audio_sample_rate, audio = wavfile.read(audio_file_path)
if audio.ndim != 1:
warnings.warn(
"Audio has multiple channels, the first channel is used")
audio = audio[:, 0]
ds_features = pure_conv_audio_to_deepspeech(
audio=audio,
audio_sample_rate=audio_sample_rate,
audio_window_size=audio_window_size,
audio_window_stride=audio_window_stride,
num_frames=num_frames,
net_fn=lambda x: sess.run(
logits_ph,
feed_dict={
input_node_ph: x[np.newaxis, ...],
input_lengths_ph: [x.shape[0]]}))
net_output = ds_features.reshape(-1, 29)
win_size = 16
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
net_output = np.concatenate(
(zero_pad, net_output, zero_pad), axis=0)
windows = []
for window_index in range(0, net_output.shape[0] - win_size, 2):
windows.append(
net_output[window_index:window_index + win_size])
print(np.array(windows).shape)
np.save(out_file_path, np.array(windows))
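# Example (sketch, assumed paths): extract features for one wav aligned to 100 video frames.
# conv_audios_to_deepspeech(audios=["speech.wav"], out_files=["speech_ds.npy"],
#                           num_frames_info=[100],
#                           deepspeech_pb_path="~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb")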
def prepare_deepspeech_net(deepspeech_pb_path):
"""
Load and prepare DeepSpeech network.
Parameters
----------
deepspeech_pb_path : str
Path to DeepSpeech 0.1.0 frozen model.
Returns
-------
graph : obj
TensorFlow graph.
logits_ph : obj
TensorFlow placeholder for `logits`.
input_node_ph : obj
TensorFlow placeholder for `input_node`.
input_lengths_ph : obj
TensorFlow placeholder for `input_lengths`.
"""
# Load graph and place_holders:
with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.compat.v1.get_default_graph()
tf.import_graph_def(graph_def, name="deepspeech")
logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
return graph, logits_ph, input_node_ph, input_lengths_ph
def pure_conv_audio_to_deepspeech(audio,
audio_sample_rate,
audio_window_size,
audio_window_stride,
num_frames,
net_fn):
"""
Core routine for converting audio into DeepSpeech features.
Parameters
----------
audio : np.array
Audio data.
audio_sample_rate : int
Audio sample rate.
audio_window_size : int
Audio window size.
audio_window_stride : int
Audio window stride.
num_frames : int or None
Numbers of frames.
net_fn : func
Function for DeepSpeech model call.
Returns
-------
np.array
DeepSpeech features.
"""
target_sample_rate = 16000
if audio_sample_rate != target_sample_rate:
resampled_audio = resampy.resample(
x=audio.astype(np.float64),
sr_orig=audio_sample_rate,
sr_new=target_sample_rate)
else:
resampled_audio = audio.astype(np.float64)
input_vector = conv_audio_to_deepspeech_input_vector(
audio=resampled_audio.astype(np.int16),
sample_rate=target_sample_rate,
num_cepstrum=26,
num_context=9)
network_output = net_fn(input_vector)
# print(network_output.shape)
deepspeech_fps = 50
video_fps = 50 # Change this option if video fps is different
audio_len_s = float(audio.shape[0]) / audio_sample_rate
if num_frames is None:
num_frames = int(round(audio_len_s * video_fps))
else:
video_fps = num_frames / audio_len_s
network_output = interpolate_features(
features=network_output[:, 0],
input_rate=deepspeech_fps,
output_rate=video_fps,
output_len=num_frames)
# Make windows:
zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
network_output = np.concatenate(
(zero_pad, network_output, zero_pad), axis=0)
windows = []
for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
windows.append(
network_output[window_index:window_index + audio_window_size])
return np.array(windows)
def conv_audio_to_deepspeech_input_vector(audio,
sample_rate,
num_cepstrum,
num_context):
"""
Convert audio raw data into DeepSpeech input vector.
Parameters
----------
audio : np.array
Audio data.
sample_rate : int
Audio sample rate.
num_cepstrum : int
Number of cepstrum.
num_context : int
Number of context.
Returns
-------
np.array
DeepSpeech input vector.
"""
# Get mfcc coefficients:
features = mfcc(
signal=audio,
samplerate=sample_rate,
numcep=num_cepstrum)
# We only keep every second feature (BiRNN stride = 2):
features = features[::2]
# One stride per time step in the input:
num_strides = len(features)
# Add empty initial and final contexts:
empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
features = np.concatenate((empty_context, features, empty_context))
# Create a view into the array with overlapping strides of size
# numcontext (past) + 1 (present) + numcontext (future):
window_size = 2 * num_context + 1
train_inputs = np.lib.stride_tricks.as_strided(
features,
shape=(num_strides, window_size, num_cepstrum),
strides=(features.strides[0],
features.strides[0], features.strides[1]),
writeable=False)
# Flatten the second and third dimensions:
train_inputs = np.reshape(train_inputs, [num_strides, -1])
train_inputs = np.copy(train_inputs)
train_inputs = (train_inputs - np.mean(train_inputs)) / \
np.std(train_inputs)
return train_inputs
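# With the values used above (num_cepstrum=26, num_context=9), the returned array has shape
# [num_strides, (2 * 9 + 1) * 26] = [num_strides, 494]: one flattened 19-step MFCC context
# window per (stride-2) time step.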
def interpolate_features(features,
input_rate,
output_rate,
output_len):
"""
Interpolate DeepSpeech features.
Parameters
----------
features : np.array
DeepSpeech features.
input_rate : int
input rate (FPS).
output_rate : int
Output rate (FPS).
output_len : int
Output data length.
Returns
-------
np.array
Interpolated data.
"""
input_len = features.shape[0]
num_features = features.shape[1]
input_timestamps = np.arange(input_len) / float(input_rate)
output_timestamps = np.arange(output_len) / float(output_rate)
output_features = np.zeros((output_len, num_features))
for feature_idx in range(num_features):
output_features[:, feature_idx] = np.interp(
x=output_timestamps,
xp=input_timestamps,
fp=features[:, feature_idx])
return output_features


@ -0,0 +1,172 @@
"""
Routines for loading DeepSpeech model.
"""
__all__ = ['get_deepspeech_model_file']
import os
import zipfile
import logging
import hashlib
deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
"""
Return the location of the pretrained model on the local file system. This function will download it from the online
model zoo if the model cannot be found or its hash does not match. The root directory will be created if it doesn't exist.
Parameters
----------
local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
Location for keeping the model parameters.
Returns
-------
file_path
Path to the requested pretrained model file.
"""
sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
file_name = "deepspeech-0_1_0-b90017e8.pb"
local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
file_path = os.path.join(local_model_store_dir_path, file_name)
if os.path.exists(file_path):
if _check_sha1(file_path, sha1_hash):
return file_path
else:
logging.warning("Mismatch in the content of model file detected. Downloading again.")
else:
logging.info("Model file not found. Downloading to {}.".format(file_path))
if not os.path.exists(local_model_store_dir_path):
os.makedirs(local_model_store_dir_path)
zip_file_path = file_path + ".zip"
_download(
url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
repo_url=deepspeech_features_repo_url,
repo_release_tag="v0.0.1",
file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(local_model_store_dir_path)
os.remove(zip_file_path)
if _check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError("Downloaded file has different hash. Please try again.")
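# Example (sketch): fetch (or reuse) the frozen DeepSpeech 0.1.0 graph and get its local path.
# pb_path = get_deepspeech_model_file()  # defaults to ~/.tensorflow/models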
def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""
Download a given URL.
Parameters
----------
url : str
URL to download
path : str, optional
Destination path to store downloaded file. By default stores to the
current directory with same name as in url.
overwrite : bool, optional
Whether to overwrite destination file if already exists.
sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
but doesn't match.
retries : integer, default 5
The number of times to attempt the download in case of failure or non 200 return codes
verify_ssl : bool, default True
Verify SSL certificates.
Returns
-------
str
The file path of the downloaded file.
"""
import warnings
try:
import requests
except ImportError:
class requests_failed_to_import(object):
pass
requests = requests_failed_to_import
if path is None:
fname = url.split("/")[-1]
# Empty filenames are invalid
assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split("/")[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"
if not verify_ssl:
warnings.warn(
"Unverified HTTPS request is being made (verify_ssl=False). "
"Adding certificate verification is strongly advised.")
if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
while retries + 1 > 0:
# Disable pylint too-broad Exception warning
# pylint: disable=W0703
try:
print("Downloading {} from {}...".format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url {}".format(url))
with open(fname, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if sha1_hash and not _check_sha1(fname, sha1_hash):
raise UserWarning("File {} is downloaded but the content hash does not match."
" The repo may be outdated or download may be incomplete. "
"If the `repo_url` is overridden, consider switching to "
"the default repo.".format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, "s" if retries > 1 else ""))
return fname
def _check_sha1(filename, sha1_hash):
"""
Check whether the sha1 hash of the file content matches the expected hash.
Parameters
----------
filename : str
Path to the file.
sha1_hash : str
Expected sha1 hash in hexadecimal digits.
Returns
-------
bool
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, "rb") as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
return sha1.hexdigest() == sha1_hash


@ -0,0 +1,132 @@
"""
Script for extracting DeepSpeech features from audio file.
"""
import os
import argparse
import numpy as np
import pandas as pd
from deepspeech_store import get_deepspeech_model_file
from deepspeech_features import conv_audios_to_deepspeech
def parse_args():
"""
Create python script parameters.
Returns
-------
ArgumentParser
Resulted args.
"""
parser = argparse.ArgumentParser(
description="Extract DeepSpeech features from audio file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--input",
type=str,
required=True,
help="path to input audio file or directory")
parser.add_argument(
"--output",
type=str,
help="path to output file with DeepSpeech features")
parser.add_argument(
"--deepspeech",
type=str,
help="path to DeepSpeech 0.1.0 frozen model")
parser.add_argument(
"--metainfo",
type=str,
help="path to file with meta-information")
args = parser.parse_args()
return args
def extract_features(in_audios,
out_files,
deepspeech_pb_path,
metainfo_file_path=None):
"""
Extract DeepSpeech features from audio files.
Parameters
----------
in_audios : list of str
Paths to input audio files.
out_files : list of str
Paths to output files with DeepSpeech features.
deepspeech_pb_path : str
Path to DeepSpeech 0.1.0 frozen model.
metainfo_file_path : str, default None
Path to file with meta-information.
"""
#deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
if metainfo_file_path is None:
num_frames_info = [None] * len(in_audios)
else:
train_df = pd.read_csv(
metainfo_file_path,
sep="\t",
index_col=False,
dtype={"Id": np.int, "File": np.unicode, "Count": np.int})
num_frames_info = train_df["Count"].values
assert (len(num_frames_info) == len(in_audios))
for i, in_audio in enumerate(in_audios):
if not out_files[i]:
file_stem, _ = os.path.splitext(in_audio)
out_files[i] = file_stem + ".npy"
#print(out_files[i])
conv_audios_to_deepspeech(
audios=in_audios,
out_files=out_files,
num_frames_info=num_frames_info,
deepspeech_pb_path=deepspeech_pb_path)
def main():
"""
Main body of script.
"""
args = parse_args()
in_audio = os.path.expanduser(args.input)
if not os.path.exists(in_audio):
raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
deepspeech_pb_path = args.deepspeech
# Override: always use the default frozen DeepSpeech model path below.
deepspeech_pb_path = True
args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
#deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
if deepspeech_pb_path is None:
deepspeech_pb_path = ""
if deepspeech_pb_path:
deepspeech_pb_path = os.path.expanduser(args.deepspeech)
if not os.path.exists(deepspeech_pb_path):
deepspeech_pb_path = get_deepspeech_model_file()
if os.path.isfile(in_audio):
extract_features(
in_audios=[in_audio],
out_files=[args.output],
deepspeech_pb_path=deepspeech_pb_path,
metainfo_file_path=args.metainfo)
else:
audio_file_paths = []
for file_name in os.listdir(in_audio):
if not os.path.isfile(os.path.join(in_audio, file_name)):
continue
_, file_ext = os.path.splitext(file_name)
if file_ext.lower() == ".wav":
audio_file_path = os.path.join(in_audio, file_name)
audio_file_paths.append(audio_file_path)
audio_file_paths = sorted(audio_file_paths)
out_file_paths = [""] * len(audio_file_paths)
extract_features(
in_audios=audio_file_paths,
out_files=out_file_paths,
deepspeech_pb_path=deepspeech_pb_path,
metainfo_file_path=args.metainfo)
if __name__ == "__main__":
main()


@ -0,0 +1,87 @@
"""
Script for extracting audio (16-bit, mono, 16000 Hz) from a video file.
"""
import os
import argparse
import subprocess
def parse_args():
"""
Create python script parameters.
Returns
-------
ArgumentParser
Resulted args.
"""
parser = argparse.ArgumentParser(
description="Extract audio from video file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--in-video",
type=str,
required=True,
help="path to input video file or directory")
parser.add_argument(
"--out-audio",
type=str,
help="path to output audio file")
args = parser.parse_args()
return args
def extract_audio(in_video,
out_audio):
"""
Extract audio from a video file.
Parameters
----------
in_video : str
Path to input video file.
out_audio : str
Path to output audio file.
"""
if not out_audio:
file_stem, _ = os.path.splitext(in_video)
out_audio = file_stem + ".wav"
# command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
# command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
# command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)
def main():
"""
Main body of script.
"""
args = parse_args()
in_video = os.path.expanduser(args.in_video)
if not os.path.exists(in_video):
raise Exception("Input file/directory doesn't exist: {}".format(in_video))
if os.path.isfile(in_video):
extract_audio(
in_video=in_video,
out_audio=args.out_audio)
else:
video_file_paths = []
for file_name in os.listdir(in_video):
if not os.path.isfile(os.path.join(in_video, file_name)):
continue
_, file_ext = os.path.splitext(file_name)
if file_ext.lower() in (".mp4", ".mkv", ".avi"):
video_file_path = os.path.join(in_video, file_name)
video_file_paths.append(video_file_path)
video_file_paths = sorted(video_file_paths)
for video_file_path in video_file_paths:
extract_audio(
in_video=video_file_path,
out_audio="")
if __name__ == "__main__":
main()


@ -0,0 +1,11 @@
import numpy as np
net_output = np.load('french.ds.npy').reshape(-1, 29)
win_size = 16
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
windows = []
for window_index in range(0, net_output.shape[0] - win_size, 2):
windows.append(net_output[window_index:window_index + win_size])
print(np.array(windows).shape)
np.save('aud_french.npy', np.array(windows))


@ -0,0 +1,23 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp
import time
import sys
import logging
import torch.distributed as dist
def setup_logger(logpth):
logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
logfile = osp.join(logpth, logfile)
FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
log_level = logging.INFO
if dist.is_initialized() and not dist.get_rank()==0:
log_level = logging.ERROR
logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
logging.root.addHandler(logging.StreamHandler())
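# Example (sketch): call setup_logger('<log_dir>') once at startup; subsequent
# logging.info(...) calls then go both to the BiSeNet-<timestamp>.log file and to the console handler.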


@ -0,0 +1,285 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from resnet import Resnet18
# from modules.bn import InPlaceABNSync as BatchNorm2d
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size = ks,
stride = stride,
padding = padding,
bias = False)
self.bn = nn.BatchNorm2d(out_chan)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = F.relu(self.bn(x))
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
self.init_weight()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class ContextPath(nn.Module):
def __init__(self, *args, **kwargs):
super(ContextPath, self).__init__()
self.resnet = Resnet18()
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
H0, W0 = x.size()[2:]
feat8, feat16, feat32 = self.resnet(x)
H8, W8 = feat8.size()[2:]
H16, W16 = feat16.size()[2:]
H32, W32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
### This is not used, since it is replaced by the resnet feature of the same size
class SpatialPath(nn.Module):
def __init__(self, *args, **kwargs):
super(SpatialPath, self).__init__()
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
feat = self.conv1(x)
feat = self.conv2(feat)
feat = self.conv3(feat)
feat = self.conv_out(feat)
return feat
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan,
out_chan//4,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.conv2 = nn.Conv2d(out_chan//4,
out_chan,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
self.init_weight()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
class BiSeNet(nn.Module):
def __init__(self, n_classes, *args, **kwargs):
super(BiSeNet, self).__init__()
self.cp = ContextPath()
## here self.sp is deleted
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, n_classes)
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
self.init_weight()
def forward(self, x):
H, W = x.size()[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
feat_out = self.conv_out(feat_fuse)
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
# return feat_out, feat_out16, feat_out32
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
for name, child in self.named_children():
child_wd_params, child_nowd_params = child.get_params()
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
lr_mul_wd_params += child_wd_params
lr_mul_nowd_params += child_nowd_params
else:
wd_params += child_wd_params
nowd_params += child_nowd_params
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
if __name__ == "__main__":
net = BiSeNet(19)
net.cuda()
net.eval()
in_ten = torch.randn(16, 3, 640, 480).cuda()
out = net(in_ten)
print(out.shape)
net.get_params()


@ -0,0 +1,109 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo
# from modules.bn import InPlaceABNSync as BatchNorm2d
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm2d(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan, out_chan,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = F.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum-1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class Resnet18(nn.Module):
def __init__(self):
super(Resnet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
self.init_weight()
def forward(self, x):
x = self.conv1(x)
x = F.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
def init_weight(self):
state_dict = modelzoo.load_url(resnet18_url)
self_state_dict = self.state_dict()
for k, v in state_dict.items():
if 'fc' in k: continue
self_state_dict.update({k: v})
self.load_state_dict(self_state_dict)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if module.bias is not None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
if __name__ == "__main__":
net = Resnet18()
x = torch.randn(16, 3, 224, 224)
out = net(x)
print(out[0].size())
print(out[1].size())
print(out[2].size())
net.get_params()

View File

@ -0,0 +1,98 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import numpy as np
from model import BiSeNet
import torch
import os
import os.path as osp
from PIL import Image
import torchvision.transforms as transforms
import cv2
from pathlib import Path
import configargparse
import tqdm
# import ttach as tta
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
img_size=(512, 512)):
im = np.array(im)
vis_im = im.copy().astype(np.uint8)
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
vis_parsing_anno = cv2.resize(
vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
vis_parsing_anno_color = np.zeros(
(vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
num_of_class = np.max(vis_parsing_anno)
# print(num_of_class)
for pi in range(1, 14):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
for pi in range(14, 16):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
for pi in range(16, 17):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
for pi in range(17, num_of_class+1):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
index = np.where(vis_parsing_anno == num_of_class-1)
vis_im = cv2.resize(vis_parsing_anno_color, img_size,
interpolation=cv2.INTER_NEAREST)
if save_im:
cv2.imwrite(save_path, vis_im)
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
Path(respth).mkdir(parents=True, exist_ok=True)
print(f'[INFO] loading model...')
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
net.load_state_dict(torch.load(cp))
net.eval()
to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
image_paths = os.listdir(dspth)
with torch.no_grad():
for image_path in tqdm.tqdm(image_paths):
if image_path.endswith('.jpg') or image_path.endswith('.png'):
img = Image.open(osp.join(dspth, image_path))
ori_size = img.size
image = img.resize((512, 512), Image.BILINEAR)
image = image.convert("RGB")
img = to_tensor(image)
# test-time augmentation.
inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512]
outputs = net(inputs.cuda())
parsing = outputs.mean(0).cpu().numpy().argmax(0)
image_path = int(image_path[:-4])
image_path = str(image_path) + '.png'
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)
if __name__ == "__main__":
parser = configargparse.ArgumentParser()
parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
args = parser.parse_args()
evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
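A minimal invocation sketch for this parser; the data/obama paths are illustrative assumptions, and the checkpoint path simply repeats the --modelpath default above. Normally data_utils/process.py (task 4) assembles this command.

```python
# Hedged sketch: run face parsing standalone on one folder of extracted frames.
import os

os.system('python data_utils/face_parsing/test.py '
          '--respath=data/obama/parsing '
          '--imgpath=data/obama/ori_imgs '
          '--modelpath=data_utils/face_parsing/79999_iter.pth')
```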

View File

View File

@ -0,0 +1,39 @@
import numpy as np
from scipy.io import loadmat
original_BFM = loadmat("3DMM/01_MorphableModel.mat")
sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"]
shapePC = original_BFM["shapePC"]
shapeEV = original_BFM["shapeEV"]
shapeMU = original_BFM["shapeMU"]
texPC = original_BFM["texPC"]
texEV = original_BFM["texEV"]
texMU = original_BFM["texMU"]
b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_shape = shapeMU.reshape(-1, 3)
b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_tex = texMU.reshape(-1, 3)
b_shape = b_shape[:, sub_inds, :].reshape(199, -1)
mu_shape = mu_shape[sub_inds, :].reshape(-1)
b_tex = b_tex[:, sub_inds, :].reshape(199, -1)
mu_tex = mu_tex[sub_inds, :].reshape(-1)
exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item()
np.save(
"3DMM/3DMM_info.npy",
{
"mu_shape": mu_shape,
"b_shape": b_shape,
"sig_shape": shapeEV.reshape(-1),
"mu_exp": exp_info["mu_exp"],
"b_exp": exp_info["base_exp"],
"sig_exp": exp_info["sig_exp"],
"mu_tex": mu_tex,
"b_tex": b_tex,
"sig_tex": texEV.reshape(-1),
},
)
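For orientation, the arrays saved above parameterize the usual linear 3DMM; a sketch of the relation, with symbol names of my own choosing (they are not defined in this repository):

```latex
S(\alpha,\beta) = \mu_{\text{shape}} + \mu_{\text{exp}}
  + \textstyle\sum_i \alpha_i\,\sigma^{\text{shape}}_i b^{\text{shape}}_i
  + \textstyle\sum_j \beta_j\,\sigma^{\text{exp}}_j b^{\text{exp}}_j,
\qquad
T(\delta) = \mu_{\text{tex}} + \textstyle\sum_k \delta_k\,\sigma^{\text{tex}}_k b^{\text{tex}}_k
```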

View File

@ -0,0 +1,16 @@
import os
import torch
import numpy as np
def load_dir(path, start, end):
lmss = []
imgs_paths = []
for i in range(start, end):
if os.path.isfile(os.path.join(path, str(i) + ".lms")):
lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32)
lmss.append(lms)
imgs_paths.append(os.path.join(path, str(i) + ".jpg"))
lmss = np.stack(lmss)
lmss = torch.as_tensor(lmss).cuda()
return lmss, imgs_paths
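load_dir assumes index-named frame/landmark pairs produced by the preprocessing pipeline; a hedged usage sketch, where the path and frame count are only examples:

```python
# Hedged sketch: pairs <i>.jpg with <i>.lms (one "x y" row per landmark, 68 points from face_alignment).
from data_loader import load_dir   # run inside data_utils/face_tracking/

lms, img_paths = load_dir('data/obama/ori_imgs', 0, 100)
print(lms.shape)   # roughly [100, 68, 2] on CUDA when every frame in the range exists
```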

View File

@ -0,0 +1,390 @@
import os
import sys
import cv2
import argparse
from pathlib import Path
import torch
import numpy as np
from data_loader import load_dir
from facemodel import Face_3DMM
from util import *
from render_3dmm import Render_3DMM
# torch.autograd.set_detect_anomaly(True)
dir_path = os.path.dirname(os.path.realpath(__file__))
def set_requires_grad(tensor_list):
for tensor in tensor_list:
tensor.requires_grad = True
parser = argparse.ArgumentParser()
parser.add_argument(
"--path", type=str, default="obama/ori_imgs", help="idname of target person"
)
parser.add_argument("--img_h", type=int, default=512, help="image height")
parser.add_argument("--img_w", type=int, default=512, help="image width")
parser.add_argument("--frame_num", type=int, default=11000, help="image number")
args = parser.parse_args()
start_id = 0
end_id = args.frame_num
lms, img_paths = load_dir(args.path, start_id, end_id)
num_frames = lms.shape[0]
h, w = args.img_h, args.img_w
cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda()
id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
model_3dmm = Face_3DMM(
os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num
)
# only use one image every 40 frames to fit the focal length
sel_ids = np.arange(0, num_frames, 40)
sel_num = sel_ids.shape[0]
arg_focal = 1600
arg_landis = 1e5
print(f'[INFO] fitting focal length...')
# fit the focal length
for focal in range(600, 1500, 100):
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
trans = lms.new_zeros((sel_num, 3), requires_grad=True)
trans.data[:, 2] -= 7
focal_length = lms.new_zeros(1, requires_grad=False)
focal_length.data += focal
set_requires_grad([id_para, exp_para, euler_angle, trans])
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)
for iter in range(2000):
id_para_batch = id_para.expand(sel_num, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
loss = loss_lan
optimizer_frame.zero_grad()
loss.backward()
optimizer_frame.step()
# if iter % 100 == 0:
# print(focal, 'pose', iter, loss.item())
for iter in range(2500):
id_para_batch = id_para.expand(sel_num, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
loss_regid = torch.mean(id_para * id_para)
loss_regexp = torch.mean(exp_para * exp_para)
loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
optimizer_idexp.zero_grad()
optimizer_frame.zero_grad()
loss.backward()
optimizer_idexp.step()
optimizer_frame.step()
# if iter % 100 == 0:
# print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
if iter % 1500 == 0 and iter >= 1500:
for param_group in optimizer_idexp.param_groups:
param_group["lr"] *= 0.2
for param_group in optimizer_frame.param_groups:
param_group["lr"] *= 0.2
print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())
if loss_lan.item() < arg_landis:
arg_landis = loss_lan.item()
arg_focal = focal
print("[INFO] find best focal:", arg_focal)
print(f'[INFO] coarse fitting...')
# for all frames, do a coarse fitting ???
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
tex_para = lms.new_zeros(
(1, tex_dim), requires_grad=True
) # not optimized in this block ???
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
trans.data[:, 2] -= 7 # ???
focal_length = lms.new_zeros(1, requires_grad=True)
focal_length.data += arg_focal
set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para])
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1)
for iter in range(1500):
id_para_batch = id_para.expand(num_frames, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
loss = loss_lan
optimizer_frame.zero_grad()
loss.backward()
optimizer_frame.step()
if iter == 1000:
for param_group in optimizer_frame.param_groups:
param_group["lr"] = 0.1
# if iter % 100 == 0:
# print('pose', iter, loss.item())
for param_group in optimizer_frame.param_groups:
param_group["lr"] = 0.1
for iter in range(2000):
id_para_batch = id_para.expand(num_frames, -1)
geometry = model_3dmm.get_3dlandmarks(
id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
loss_regid = torch.mean(id_para * id_para)
loss_regexp = torch.mean(exp_para * exp_para)
loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
optimizer_idexp.zero_grad()
optimizer_frame.zero_grad()
loss.backward()
optimizer_idexp.step()
optimizer_frame.step()
# if iter % 100 == 0:
# print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
if iter % 1000 == 0 and iter >= 1000:
for param_group in optimizer_idexp.param_groups:
param_group["lr"] *= 0.2
for param_group in optimizer_frame.param_groups:
param_group["lr"] *= 0.2
print(loss_lan.item(), torch.mean(trans[:, 2]).item())
print(f'[INFO] fitting light...')
batch_size = 32
device_default = torch.device("cuda:0")
device_render = torch.device("cuda:0")
renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
imgs = []
for sel_id in sel_ids:
imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
imgs = np.stack(imgs)
sel_imgs = torch.as_tensor(imgs).cuda()
sel_lms = lms[sel_ids]
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
set_requires_grad([sel_light])
optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1)
optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01)
for iter in range(71):
sel_exp_para, sel_euler, sel_trans = (
exp_para[sel_ids],
euler_angle[sel_ids],
trans[sel_ids],
)
sel_id_para = id_para.expand(batch_size, -1)
geometry = model_3dmm.get_3dlandmarks(
sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
loss_regid = torch.mean(id_para * id_para)
loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
sel_tex_para = tex_para.expand(batch_size, -1)
sel_texture = model_3dmm.forward_tex(sel_tex_para)
geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
rott_geo = forward_rott(geometry, sel_euler, sel_trans)
render_imgs = renderer(
rott_geo.to(device_render),
sel_texture.to(device_render),
sel_light.to(device_render),
)
render_imgs = render_imgs.to(device_default)
mask = (render_imgs[:, :, :, 3]).detach() > 0.0
render_proj = sel_imgs.clone()
render_proj[mask] = render_imgs[mask][..., :3].byte()
loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
if iter > 50:
loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8
else:
loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0
optimizer_tl.zero_grad()
optimizer_id_frame.zero_grad()
loss.backward()
optimizer_tl.step()
optimizer_id_frame.step()
if iter % 50 == 0 and iter > 0:
for param_group in optimizer_id_frame.param_groups:
param_group["lr"] *= 0.2
for param_group in optimizer_tl.param_groups:
param_group["lr"] *= 0.2
# print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item())
light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1)
light_para.data = light_mean
exp_para = exp_para.detach()
euler_angle = euler_angle.detach()
trans = trans.detach()
light_para = light_para.detach()
print(f'[INFO] fine frame-wise fitting...')
for i in range(int((num_frames - 1) / batch_size + 1)):
if (i + 1) * batch_size > num_frames:
start_n = num_frames - batch_size
sel_ids = np.arange(num_frames - batch_size, num_frames)
else:
start_n = i * batch_size
sel_ids = np.arange(i * batch_size, i * batch_size + batch_size)
imgs = []
for sel_id in sel_ids:
imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
imgs = np.stack(imgs)
sel_imgs = torch.as_tensor(imgs).cuda()
sel_lms = lms[sel_ids]
sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True)
sel_exp_para.data = exp_para[sel_ids].clone()
sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True)
sel_euler.data = euler_angle[sel_ids].clone()
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
sel_trans.data = trans[sel_ids].clone()
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
sel_light.data = light_para[sel_ids].clone()
set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light])
optimizer_cur_batch = torch.optim.Adam(
[sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005
)
sel_id_para = id_para.expand(batch_size, -1).detach()
sel_tex_para = tex_para.expand(batch_size, -1).detach()
pre_num = 5
if i > 0:
pre_ids = np.arange(start_n - pre_num, start_n)
for iter in range(50):
geometry = model_3dmm.get_3dlandmarks(
sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
)
proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
loss_regexp = torch.mean(sel_exp_para * sel_exp_para)
sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
sel_texture = model_3dmm.forward_tex(sel_tex_para)
geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
rott_geo = forward_rott(geometry, sel_euler, sel_trans)
render_imgs = renderer(
rott_geo.to(device_render),
sel_texture.to(device_render),
sel_light.to(device_render),
)
render_imgs = render_imgs.to(device_default)
mask = (render_imgs[:, :, :, 3]).detach() > 0.0
loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)
if i > 0:
geometry_lap = model_3dmm.forward_geo_sub(
id_para.expand(batch_size + pre_num, -1).detach(),
torch.cat((exp_para[pre_ids].detach(), sel_exp_para)),
model_3dmm.rigid_ids,
)
rott_geo_lap = forward_rott(
geometry_lap,
torch.cat((euler_angle[pre_ids].detach(), sel_euler)),
torch.cat((trans[pre_ids].detach(), sel_trans)),
)
loss_lap = cal_lap_loss(
[rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
)
else:
geometry_lap = model_3dmm.forward_geo_sub(
id_para.expand(batch_size, -1).detach(),
sel_exp_para,
model_3dmm.rigid_ids,
)
rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans)
loss_lap = cal_lap_loss(
[rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
)
if iter > 30:
loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0
else:
loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0
optimizer_cur_batch.zero_grad()
loss.backward()
optimizer_cur_batch.step()
# if iter % 10 == 0:
# print(
# i,
# iter,
# loss_col.item(),
# loss_lan.item(),
# loss_lap.item(),
# loss_regexp.item(),
# )
print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done")
render_proj = sel_imgs.clone()
render_proj[mask] = render_imgs[mask][..., :3].byte()
exp_para[sel_ids] = sel_exp_para.clone()
euler_angle[sel_ids] = sel_euler.clone()
trans[sel_ids] = sel_trans.clone()
light_para[sel_ids] = sel_light.clone()
torch.save(
{
"id": id_para.detach().cpu(),
"exp": exp_para.detach().cpu(),
"euler": euler_angle.detach().cpu(),
"trans": trans.detach().cpu(),
"focal": focal_length.detach().cpu(),
},
os.path.join(os.path.dirname(args.path), "track_params.pt"),
)
print("params saved")

View File

@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import numpy as np
import os
from util import *
class Face_3DMM(nn.Module):
def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num):
super(Face_3DMM, self).__init__()
# id_dim = 100
# exp_dim = 79
# tex_dim = 100
self.point_num = point_num
DMM_info = np.load(
os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True
).item()
base_id = DMM_info["b_shape"][:id_dim, :]
mu_id = DMM_info["mu_shape"]
base_exp = DMM_info["b_exp"][:exp_dim, :]
mu_exp = DMM_info["mu_exp"]
mu = mu_id + mu_exp
mu = mu.reshape(-1, 3)
for i in range(3):
mu[:, i] -= np.mean(mu[:, i])
mu = mu.reshape(-1)
self.base_id = torch.as_tensor(base_id).cuda() / 100000.0
self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0
self.mu = torch.as_tensor(mu).cuda() / 100000.0
base_tex = DMM_info["b_tex"][:tex_dim, :]
mu_tex = DMM_info["mu_tex"]
self.base_tex = torch.as_tensor(base_tex).cuda()
self.mu_tex = torch.as_tensor(mu_tex).cuda()
sig_id = DMM_info["sig_shape"][:id_dim]
sig_tex = DMM_info["sig_tex"][:tex_dim]
sig_exp = DMM_info["sig_exp"][:exp_dim]
self.sig_id = torch.as_tensor(sig_id).cuda()
self.sig_tex = torch.as_tensor(sig_tex).cuda()
self.sig_exp = torch.as_tensor(sig_exp).cuda()
keys_info = np.load(
os.path.join(modelpath, "keys_info.npy"), allow_pickle=True
).item()
self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda()
self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda()
self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda()
self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda()
def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy):
id_para = id_para * self.sig_id
exp_para = exp_para * self.sig_exp
batch_size = id_para.shape[0]
num_per_contour = self.left_contours.shape[1]
left_contours_flat = self.left_contours.reshape(-1)
right_contours_flat = self.right_contours.reshape(-1)
sel_index = torch.cat(
(
3 * left_contours_flat.unsqueeze(1),
3 * left_contours_flat.unsqueeze(1) + 1,
3 * left_contours_flat.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
left_geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
left_geometry = left_geometry.view(batch_size, -1, 3)
proj_x = forward_transform(
left_geometry, euler_angle, trans, focal_length, cxy
)[:, :, 0]
proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
arg_min = proj_x.argmin(dim=2)
left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3)
left_3dlands = left_geometry[
torch.arange(batch_size * 8), arg_min.view(-1), :
].view(batch_size, 8, 3)
sel_index = torch.cat(
(
3 * right_contours_flat.unsqueeze(1),
3 * right_contours_flat.unsqueeze(1) + 1,
3 * right_contours_flat.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
right_geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
right_geometry = right_geometry.view(batch_size, -1, 3)
proj_x = forward_transform(
right_geometry, euler_angle, trans, focal_length, cxy
)[:, :, 0]
proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
arg_max = proj_x.argmax(dim=2)
right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3)
right_3dlands = right_geometry[
torch.arange(batch_size * 8), arg_max.view(-1), :
].view(batch_size, 8, 3)
sel_index = torch.cat(
(
3 * self.keyinds.unsqueeze(1),
3 * self.keyinds.unsqueeze(1) + 1,
3 * self.keyinds.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
lands_3d = geometry.view(-1, self.keyinds.shape[0], 3)
lands_3d[:, :8, :] = left_3dlands
lands_3d[:, 9:17, :] = right_3dlands
return lands_3d
def forward_geo_sub(self, id_para, exp_para, sub_index):
id_para = id_para * self.sig_id
exp_para = exp_para * self.sig_exp
sel_index = torch.cat(
(
3 * sub_index.unsqueeze(1),
3 * sub_index.unsqueeze(1) + 1,
3 * sub_index.unsqueeze(1) + 2,
),
dim=1,
).reshape(-1)
geometry = (
torch.mm(id_para, self.base_id[:, sel_index])
+ torch.mm(exp_para, self.base_exp[:, sel_index])
+ self.mu[sel_index]
)
return geometry.reshape(-1, sub_index.shape[0], 3)
def forward_geo(self, id_para, exp_para):
id_para = id_para * self.sig_id
exp_para = exp_para * self.sig_exp
geometry = (
torch.mm(id_para, self.base_id)
+ torch.mm(exp_para, self.base_exp)
+ self.mu
)
return geometry.reshape(-1, self.point_num, 3)
def forward_tex(self, tex_para):
tex_para = tex_para * self.sig_tex
texture = torch.mm(tex_para, self.base_tex) + self.mu_tex
return texture.reshape(-1, self.point_num, 3)
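A small shape-check sketch, assuming the dimensions used by face_tracker.py above (id/exp/tex = 100/79/100, 34650 points); zero parameters simply reproduce the mean face:

```python
# Hedged sketch: instantiate the 3DMM and query geometry/texture for the mean face.
import torch
from facemodel import Face_3DMM   # run inside data_utils/face_tracking/

model = Face_3DMM('3DMM', id_dim=100, exp_dim=79, tex_dim=100, point_num=34650)
id_para  = torch.zeros(1, 100).cuda()
exp_para = torch.zeros(1, 79).cuda()
geometry = model.forward_geo(id_para, exp_para)            # [1, 34650, 3]
texture  = model.forward_tex(torch.zeros(1, 100).cuda())   # [1, 34650, 3]
```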

View File

@ -0,0 +1,69 @@
"""This module contains functions for geometry transform and camera projection"""
import torch
import torch.nn as nn
import numpy as np
def euler2rot(euler_angle):
batch_size = euler_angle.shape[0]
theta = euler_angle[:, 0].reshape(-1, 1, 1)
phi = euler_angle[:, 1].reshape(-1, 1, 1)
psi = euler_angle[:, 2].reshape(-1, 1, 1)
one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
zero = torch.zeros(
(batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device
)
rot_x = torch.cat(
(
torch.cat((one, zero, zero), 1),
torch.cat((zero, theta.cos(), theta.sin()), 1),
torch.cat((zero, -theta.sin(), theta.cos()), 1),
),
2,
)
rot_y = torch.cat(
(
torch.cat((phi.cos(), zero, -phi.sin()), 1),
torch.cat((zero, one, zero), 1),
torch.cat((phi.sin(), zero, phi.cos()), 1),
),
2,
)
rot_z = torch.cat(
(
torch.cat((psi.cos(), -psi.sin(), zero), 1),
torch.cat((psi.sin(), psi.cos(), zero), 1),
torch.cat((zero, zero, one), 1),
),
2,
)
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
def rot_trans_geo(geometry, rot, trans):
rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1)
return rott_geo.permute(0, 2, 1)
def euler_trans_geo(geometry, euler, trans):
rot = euler2rot(euler)
return rot_trans_geo(geometry, rot, trans)
def proj_geo(rott_geo, camera_para):
fx = camera_para[:, 0]
fy = camera_para[:, 0]
cx = camera_para[:, 1]
cy = camera_para[:, 2]
X = rott_geo[:, :, 0]
Y = rott_geo[:, :, 1]
Z = rott_geo[:, :, 2]
fxX = fx[:, None] * X
fyY = fy[:, None] * Y
proj_x = -fxX / Z + cx[:, None]
proj_y = fyY / Z + cy[:, None]
return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
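In equation form, the helpers above compose an Euler rotation with a translation and apply a simple pinhole projection (notation mine; note the sign flip on the x axis):

```latex
R(\theta,\phi,\psi) = R_x(\theta)\,R_y(\phi)\,R_z(\psi),\qquad
\tilde{v} = R\,v + t,\qquad
u = -\frac{f_x\,\tilde{v}_x}{\tilde{v}_z} + c_x,\qquad
v' = \frac{f_y\,\tilde{v}_y}{\tilde{v}_z} + c_y
```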

View File

@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import numpy as np
import os
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
PerspectiveCameras,
FoVPerspectiveCameras,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
TexturesUV,
TexturesVertex,
blending,
)
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer.blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
class SoftSimpleShader(nn.Module):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the
soft aggregated color using all the faces per pixel.
To use the default values, simply initialize the shader with the desired
device, e.g. device="cuda:0".
"""
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = meshes.sample_textures(fragments)
blend_params = kwargs.get("blend_params", self.blend_params)
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
images = softmax_rgb_blend(
texels, fragments, blend_params, znear=znear, zfar=zfar
)
return images
class Render_3DMM(nn.Module):
def __init__(
self,
focal=1015,
img_h=500,
img_w=500,
batch_size=1,
device=torch.device("cuda:0"),
):
super(Render_3DMM, self).__init__()
self.focal = focal
self.img_h = img_h
self.img_w = img_w
self.device = device
self.renderer = self.get_render(batch_size)
dir_path = os.path.dirname(os.path.realpath(__file__))
topo_info = np.load(
os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
).item()
self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)
def compute_normal(self, geometry):
vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
tri_normal = nn.functional.normalize(nnorm, dim=2)
v_norm = tri_normal[:, self.vert_tris, :].sum(2)
vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
return vert_normal
def get_render(self, batch_size=1):
half_s = self.img_w * 0.5
R, T = look_at_view_transform(10, 0, 0)
R = R.repeat(batch_size, 1, 1)
T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)
cameras = FoVPerspectiveCameras(
device=self.device,
R=R,
T=T,
znear=0.01,
zfar=20,
fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
)
lights = PointLights(
device=self.device,
location=[[0.0, 0.0, 1e5]],
ambient_color=[[1, 1, 1]],
specular_color=[[0.0, 0.0, 0.0]],
diffuse_color=[[0.0, 0.0, 0.0]],
)
sigma = 1e-4
raster_settings = RasterizationSettings(
image_size=(self.img_h, self.img_w),
blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
faces_per_pixel=2,
perspective_correct=False,
)
blend_params = blending.BlendParams(background_color=[0, 0, 0])
renderer = MeshRenderer(
rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
shader=SoftSimpleShader(
lights=lights, blend_params=blend_params, cameras=cameras
),
)
return renderer.to(self.device)
@staticmethod
def Illumination_layer(face_texture, norm, gamma):
n_b, num_vertex, _ = face_texture.size()
n_v_full = n_b * num_vertex
gamma = gamma.view(-1, 3, 9).clone()
gamma[:, :, 0] += 0.8
gamma = gamma.permute(0, 2, 1)
a0 = np.pi
a1 = 2 * np.pi / np.sqrt(3.0)
a2 = 2 * np.pi / np.sqrt(8.0)
c0 = 1 / np.sqrt(4 * np.pi)
c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
d0 = 0.5 / np.sqrt(3.0)
Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
norm = norm.view(-1, 3)
nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
arrH = []
arrH.append(Y0)
arrH.append(-a1 * c1 * ny)
arrH.append(a1 * c1 * nz)
arrH.append(-a1 * c1 * nx)
arrH.append(a2 * c2 * nx * ny)
arrH.append(-a2 * c2 * ny * nz)
arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
arrH.append(-a2 * c2 * nx * nz)
arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))
H = torch.stack(arrH, 1)
Y = H.view(n_b, num_vertex, 9)
lighting = Y.bmm(gamma)
face_color = face_texture * lighting
return face_color
def forward(self, rott_geometry, texture, diffuse_sh):
face_normal = self.compute_normal(rott_geometry)
face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
face_color = TexturesVertex(face_color)
mesh = Meshes(
rott_geometry,
self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
face_color,
)
rendered_img = self.renderer(mesh)
rendered_img = torch.clamp(rendered_img, 0, 255)
return rendered_img
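Illumination_layer evaluates a second-order spherical-harmonics lighting model per vertex: the 27 light parameters are 9 SH coefficients per RGB channel (with a +0.8 offset on the DC term), and the rendered colour is, roughly:

```latex
c(v) \;=\; t(v)\,\odot\,\sum_{k=0}^{8}\gamma_{k}\,Y_{k}\!\bigl(n_v\bigr),
\qquad \gamma \in \mathbb{R}^{3\times 9}
```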

View File

@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import render_util
import geo_transform
import numpy as np
def compute_tri_normal(geometry, tris):
geometry = geometry.permute(0, 2, 1)
tri_1 = tris[:, 0]
tri_2 = tris[:, 1]
tri_3 = tris[:, 2]
vert_1 = torch.index_select(geometry, 2, tri_1)
vert_2 = torch.index_select(geometry, 2, tri_2)
vert_3 = torch.index_select(geometry, 2, tri_3)
nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1)
normal = nn.functional.normalize(nnorm).permute(0, 2, 1)
return normal
class Compute_normal_base(torch.autograd.Function):
@staticmethod
def forward(ctx, normal):
(normal_b,) = render_util.normal_base_forward(normal)
ctx.save_for_backward(normal)
return normal_b
@staticmethod
def backward(ctx, grad_normal_b):
(normal,) = ctx.saved_tensors
(grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal)
return grad_normal
class Normal_Base(torch.nn.Module):
def __init__(self):
super(Normal_Base, self).__init__()
def forward(self, normal):
return Compute_normal_base.apply(normal)
def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
point_num = geometry.shape[1]
rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans)
proj_geo = geo_transform.proj_geo(rott_geo, cam)
rot_tri_normal = compute_tri_normal(rott_geo, tris)
rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris)
is_visible = -torch.bmm(
rot_vert_normal.reshape(-1, 1, 3),
nn.functional.normalize(rott_geo.reshape(-1, 3, 1)),
).reshape(-1, point_num)
is_visible[is_visible < 0.01] = -1
pixel_valid = torch.zeros(
(ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]),
dtype=torch.float32,
device=ori_img.device,
)
return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid
class Render_Face(torch.autograd.Function):
@staticmethod
def forward(
ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
):
batch_size, h, w, _ = ori_img.shape
ori_img = ori_img.view(batch_size, -1, 3)
ori_size = torch.cat(
(
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* h,
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* w,
),
dim=1,
).view(-1)
tri_index, tri_coord, render, real = render_util.render_face_forward(
proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid
)
ctx.save_for_backward(
ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord
)
return render, real
@staticmethod
def backward(ctx, grad_render, grad_real):
(
ori_img,
ori_size,
proj_geo,
texture,
nbl,
tri_inds,
tri_index,
tri_coord,
) = ctx.saved_tensors
grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
grad_render,
grad_real,
ori_img,
ori_size,
proj_geo,
texture,
nbl,
tri_inds,
tri_index,
tri_coord,
)
return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None
class Render_RGB(nn.Module):
def __init__(self):
super(Render_RGB, self).__init__()
def forward(
self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
):
return Render_Face.apply(
proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
)
def cal_land(proj_geo, is_visible, lands_info, land_num):
(land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[
:, :2
].reshape(-1, land_num, 2)
return proj_land
class Render_Land(nn.Module):
def __init__(self):
super(Render_Land, self).__init__()
lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32)
self.lands_info = torch.as_tensor(lands_info).cuda()
tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64)
self.tris = torch.as_tensor(tris).cuda() - 1
vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64)
self.vert_tris = torch.as_tensor(vert_tris).cuda()
self.normal_baser = Normal_Base().cuda()
self.renderer = Render_RGB().cuda()
def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
batch_size, h, w, _ = ori_img.shape
ori_img = ori_img.view(batch_size, -1, 3)
ori_size = torch.cat(
(
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* h,
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
* w,
),
dim=1,
).view(-1)
rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render(
geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
)
tri_nb = self.normal_baser(rot_tri_normal.contiguous())
nbl = torch.bmm(
tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3)
)
texture = torch.ones_like(geometry) * 200
(render,) = render_util.render_mesh(
proj_geo, ori_img, ori_size, texture, nbl, self.tris
)
return render.view(batch_size, h, w, 3).byte()
def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
)
tri_nb = self.normal_baser(rot_tri_normal.contiguous())
nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))
render, real = self.renderer(
proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
)
proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])
col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
ori_img.shape[0], -1
)
col_dis = torch.mean(col_minus * pixel_valid) / (
torch.mean(pixel_valid) + 0.00001
)
land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape(
ori_img.shape[0], -1
)
lan_dis = torch.mean(land_dists)
return col_dis, lan_dis

View File

@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_tri_normal(geometry, tris):
tri_1 = tris[:, 0]
tri_2 = tris[:, 1]
tri_3 = tris[:, 2]
vert_1 = torch.index_select(geometry, 1, tri_1)
vert_2 = torch.index_select(geometry, 1, tri_2)
vert_3 = torch.index_select(geometry, 1, tri_3)
nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
normal = nn.functional.normalize(nnorm, dim=2)  # normalize each triangle normal (last dim)
return normal
def euler2rot(euler_angle):
batch_size = euler_angle.shape[0]
theta = euler_angle[:, 0].reshape(-1, 1, 1)
phi = euler_angle[:, 1].reshape(-1, 1, 1)
psi = euler_angle[:, 2].reshape(-1, 1, 1)
one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
rot_x = torch.cat(
(
torch.cat((one, zero, zero), 1),
torch.cat((zero, theta.cos(), theta.sin()), 1),
torch.cat((zero, -theta.sin(), theta.cos()), 1),
),
2,
)
rot_y = torch.cat(
(
torch.cat((phi.cos(), zero, -phi.sin()), 1),
torch.cat((zero, one, zero), 1),
torch.cat((phi.sin(), zero, phi.cos()), 1),
),
2,
)
rot_z = torch.cat(
(
torch.cat((psi.cos(), -psi.sin(), zero), 1),
torch.cat((psi.sin(), psi.cos(), zero), 1),
torch.cat((zero, zero, one), 1),
),
2,
)
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
def rot_trans_pts(geometry, rot, trans):
rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None]
return rott_geo.permute(0, 2, 1)
def cal_lap_loss(tensor_list, weight_list):
lap_kernel = (
torch.Tensor((-0.5, 1.0, -0.5))
.unsqueeze(0)
.unsqueeze(0)
.float()
.to(tensor_list[0].device)
)
loss_lap = 0
for i in range(len(tensor_list)):
in_tensor = tensor_list[i]
in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1])
out_tensor = F.conv1d(in_tensor, lap_kernel)
loss_lap += torch.mean(out_tensor ** 2) * weight_list[i]
return loss_lap
def proj_pts(rott_geo, focal_length, cxy):
cx, cy = cxy[0], cxy[1]
X = rott_geo[:, :, 0]
Y = rott_geo[:, :, 1]
Z = rott_geo[:, :, 2]
fxX = focal_length * X
fyY = focal_length * Y
proj_x = -fxX / Z + cx
proj_y = fyY / Z + cy
return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
def forward_rott(geometry, euler_angle, trans):
rot = euler2rot(euler_angle)
rott_geo = rot_trans_pts(geometry, rot, trans)
return rott_geo
def forward_transform(geometry, euler_angle, trans, focal_length, cxy):
rot = euler2rot(euler_angle)
rott_geo = rot_trans_pts(geometry, rot, trans)
proj_geo = proj_pts(rott_geo, focal_length, cxy)
return proj_geo
def cal_lan_loss(proj_lan, gt_lan):
return torch.mean((proj_lan - gt_lan) ** 2)
def cal_col_loss(pred_img, gt_img, img_mask):
pred_img = pred_img.float()
# loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255
loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255
loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2))
loss = torch.mean(loss)
return loss
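For reference, cal_lap_loss above convolves each sequence with the kernel (-0.5, 1, -0.5), i.e. a temporal smoothness penalty of roughly the following form, with per-tensor weights taken from weight_list:

```latex
\mathcal{L}_{\text{lap}} \;=\; \sum_i w_i\,\operatorname{mean}_t
\Bigl(x^{(i)}_t - \tfrac{1}{2}\,x^{(i)}_{t-1} - \tfrac{1}{2}\,x^{(i)}_{t+1}\Bigr)^{2}
```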

402
data_utils/process.py Normal file
View File

@ -0,0 +1,402 @@
import os
import glob
import tqdm
import json
import argparse
import cv2
import numpy as np
def extract_audio(path, out_path, sample_rate=16000):
print(f'[INFO] ===== extract audio from {path} to {out_path} =====')
cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}'
os.system(cmd)
print(f'[INFO] ===== extracted audio =====')
def extract_audio_features(path, mode='wav2vec'):
print(f'[INFO] ===== extract audio labels for {path} =====')
if mode == 'wav2vec':
cmd = f'python nerf/asr.py --wav {path} --save_feats'
else: # deepspeech
cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}'
os.system(cmd)
print(f'[INFO] ===== extracted audio labels =====')
def extract_images(path, out_path, fps=25):
print(f'[INFO] ===== extract images from {path} to {out_path} =====')
cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}'
os.system(cmd)
print(f'[INFO] ===== extracted images =====')
def extract_semantics(ori_imgs_dir, parsing_dir):
print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====')
cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}'
os.system(cmd)
print(f'[INFO] ===== extracted semantics =====')
def extract_landmarks(ori_imgs_dir):
print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
import face_alignment
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
for image_path in tqdm.tqdm(image_paths):
input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
preds = fa.get_landmarks(input)
if preds is not None and len(preds) > 0:
lands = preds[0].reshape(-1, 2)[:,:2]
np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f')
del fa
print(f'[INFO] ===== extracted face landmarks =====')
def extract_background(base_dir, ori_imgs_dir):
print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====')
from sklearn.neighbors import NearestNeighbors
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
# only use 1/20 image_paths
image_paths = image_paths[::20]
# read one image to get H/W
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
h, w = tmp_image.shape[:2]
# nearest neighbors
all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
distss = []
for image_path in tqdm.tqdm(image_paths):
parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255)
fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
dists, _ = nbrs.kneighbors(all_xys)
distss.append(dists)
distss = np.stack(distss)
max_dist = np.max(distss, 0)
max_id = np.argmax(distss, 0)
bc_pixs = max_dist > 5
bc_pixs_id = np.nonzero(bc_pixs)
bc_ids = max_id[bc_pixs]
imgs = []
num_pixs = distss.shape[1]
for image_path in image_paths:
img = cv2.imread(image_path)
imgs.append(img)
imgs = np.stack(imgs).reshape(-1, num_pixs, 3)
bc_img = np.zeros((h*w, 3), dtype=np.uint8)
bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
bc_img = bc_img.reshape(h, w, 3)
max_dist = max_dist.reshape(h, w)
bc_pixs = max_dist > 5
bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
distances, indices = nbrs.kneighbors(bg_xys)
bg_fg_xys = fg_xys[indices[:, 0]]
bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img)
print(f'[INFO] ===== extracted background image =====')
def extract_torso_and_gt(base_dir, ori_imgs_dir):
print(f'[INFO] ===== extract torso and gt images for {base_dir} =====')
from scipy.ndimage import binary_erosion, binary_dilation
# load bg
bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED)
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
for image_path in tqdm.tqdm(image_paths):
# read ori image
ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
# read semantics
seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0)
neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0)
torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255)
bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255)
# get gt image
gt_image = ori_image.copy()
gt_image[bg_part] = bg_image[bg_part]
cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
# get torso image
torso_image = gt_image.copy() # rgb
torso_image[head_part] = bg_image[head_part]
torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
# torso part "vertical" in-painting...
L = 8 + 1
torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
# lexsort: sort 2D coords first by y then by x,
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
torso_coords = torso_coords[inds]
# choose the top pixel for each column
u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
top_torso_coords = torso_coords[uid] # [m, 2]
# only keep top-is-head pixels
top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0])
mask = head_part[tuple(top_torso_coords_up.T)]
if mask.any():
top_torso_coords = top_torso_coords[mask]
# get the color
top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
# construct inpaint coords (vertically up, or minus in x)
inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
inpaint_torso_coords += inpaint_offsets
inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
# set color
torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
else:
inpaint_torso_mask = None
# neck part "vertical" in-painting...
push_down = 4
L = 48 + push_down + 1
neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
# lexsort: sort 2D coords first by y then by x,
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
neck_coords = neck_coords[inds]
# choose the top pixel for each column
u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
top_neck_coords = neck_coords[uid] # [m, 2]
# only keep top-is-head pixels
top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
mask = head_part[tuple(top_neck_coords_up.T)]
top_neck_coords = top_neck_coords[mask]
# push these top down for 4 pixels to make the neck inpainting more natural...
offset_down = np.minimum(ucnt[mask] - 1, push_down)
top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
# get the color
top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
# construct inpaint coords (vertically up, or minus in x)
inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
inpaint_neck_coords += inpaint_offsets
inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
# set color
torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
# apply blurring to the inpaint area to avoid vertical-line artifacts...
inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
inpaint_mask[tuple(inpaint_neck_coords.T)] = True
blur_img = torso_image.copy()
blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
torso_image[inpaint_mask] = blur_img[inpaint_mask]
# set mask
mask = (neck_part | torso_part | inpaint_mask)
if inpaint_torso_mask is not None:
mask = mask | inpaint_torso_mask
torso_image[~mask] = 0
torso_alpha[~mask] = 0
cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1))
print(f'[INFO] ===== extracted torso and gt images =====')
def face_tracking(ori_imgs_dir):
print(f'[INFO] ===== perform face tracking =====')
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
# read one image to get H/W
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
h, w = tmp_image.shape[:2]
cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}'
os.system(cmd)
print(f'[INFO] ===== finished face tracking =====')
def save_transforms(base_dir, ori_imgs_dir):
print(f'[INFO] ===== save transforms =====')
import torch
image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
# read one image to get H/W
tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
h, w = tmp_image.shape[:2]
params_dict = torch.load(os.path.join(base_dir, 'track_params.pt'))
focal_len = params_dict['focal']
euler_angle = params_dict['euler']
trans = params_dict['trans'] / 10.0
valid_num = euler_angle.shape[0]
def euler2rot(euler_angle):
batch_size = euler_angle.shape[0]
theta = euler_angle[:, 0].reshape(-1, 1, 1)
phi = euler_angle[:, 1].reshape(-1, 1, 1)
psi = euler_angle[:, 2].reshape(-1, 1, 1)
one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
rot_x = torch.cat((
torch.cat((one, zero, zero), 1),
torch.cat((zero, theta.cos(), theta.sin()), 1),
torch.cat((zero, -theta.sin(), theta.cos()), 1),
), 2)
rot_y = torch.cat((
torch.cat((phi.cos(), zero, -phi.sin()), 1),
torch.cat((zero, one, zero), 1),
torch.cat((phi.sin(), zero, phi.cos()), 1),
), 2)
rot_z = torch.cat((
torch.cat((psi.cos(), -psi.sin(), zero), 1),
torch.cat((psi.sin(), psi.cos(), zero), 1),
torch.cat((zero, zero, one), 1)
), 2)
return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
# train_val_split = int(valid_num*0.5)
# train_val_split = valid_num - 25 * 20 # take the last 20s as valid set.
train_val_split = int(valid_num * 10 / 11)
train_ids = torch.arange(0, train_val_split)
val_ids = torch.arange(train_val_split, valid_num)
rot = euler2rot(euler_angle)
rot_inv = rot.permute(0, 2, 1)
trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))
pose = torch.eye(4, dtype=torch.float32)
save_ids = ['train', 'val']
train_val_ids = [train_ids, val_ids]
mean_z = -float(torch.mean(trans[:, 2]).item())
for split in range(2):
transform_dict = dict()
transform_dict['focal_len'] = float(focal_len[0])
transform_dict['cx'] = float(w/2.0)
transform_dict['cy'] = float(h/2.0)
transform_dict['frames'] = []
ids = train_val_ids[split]
save_id = save_ids[split]
for i in ids:
i = i.item()
frame_dict = dict()
frame_dict['img_id'] = i
frame_dict['aud_id'] = i
pose[:3, :3] = rot_inv[i]
pose[:3, 3] = trans_inv[i, :, 0]
frame_dict['transform_matrix'] = pose.numpy().tolist()
transform_dict['frames'].append(frame_dict)
with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
json.dump(transform_dict, fp, indent=2, separators=(',', ': '))
print(f'[INFO] ===== finished saving transforms =====')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str, help="path to video file")
parser.add_argument('--task', type=int, default=-1, help="-1 means all")
parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech")
opt = parser.parse_args()
base_dir = os.path.dirname(opt.path)
wav_path = os.path.join(base_dir, 'aud.wav')
ori_imgs_dir = os.path.join(base_dir, 'ori_imgs')
parsing_dir = os.path.join(base_dir, 'parsing')
gt_imgs_dir = os.path.join(base_dir, 'gt_imgs')
torso_imgs_dir = os.path.join(base_dir, 'torso_imgs')
os.makedirs(ori_imgs_dir, exist_ok=True)
os.makedirs(parsing_dir, exist_ok=True)
os.makedirs(gt_imgs_dir, exist_ok=True)
os.makedirs(torso_imgs_dir, exist_ok=True)
# extract audio
if opt.task == -1 or opt.task == 1:
extract_audio(opt.path, wav_path)
# extract audio features
if opt.task == -1 or opt.task == 2:
extract_audio_features(wav_path, mode=opt.asr)
# extract images
if opt.task == -1 or opt.task == 3:
extract_images(opt.path, ori_imgs_dir)
# face parsing
if opt.task == -1 or opt.task == 4:
extract_semantics(ori_imgs_dir, parsing_dir)
# extract bg
if opt.task == -1 or opt.task == 5:
extract_background(base_dir, ori_imgs_dir)
# extract torso images and gt_images
if opt.task == -1 or opt.task == 6:
extract_torso_and_gt(base_dir, ori_imgs_dir)
# extract face landmarks
if opt.task == -1 or opt.task == 7:
extract_landmarks(ori_imgs_dir)
# face tracking
if opt.task == -1 or opt.task == 8:
face_tracking(ori_imgs_dir)
# save transforms.json
if opt.task == -1 or opt.task == 9:
save_transforms(base_dir, ori_imgs_dir)
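End to end, preprocessing is a single call on the source video; a hedged sketch where the data/obama layout is only an example, any directory containing the video works:

```python
# Hedged sketch: run every stage (--task defaults to -1) and drop the outputs next to the video:
# aud.wav, ori_imgs/, parsing/, gt_imgs/, torso_imgs/, bc.jpg, track_params.pt,
# transforms_train.json and transforms_val.json.
import os

os.system('python data_utils/process.py data/obama/obama.mp4 --asr deepspeech')
# re-run a single stage if needed, e.g. face tracking only:
os.system('python data_utils/process.py data/obama/obama.mp4 --task 8')
```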

38
encoding.py Normal file
View File

@ -0,0 +1,38 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_encoder(encoding, input_dim=3,
multires=6,
degree=4,
num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
**kwargs):
if encoding == 'None':
return lambda x, **kwargs: x, input_dim
elif encoding == 'frequency':
from freqencoder import FreqEncoder
encoder = FreqEncoder(input_dim=input_dim, degree=multires)
elif encoding == 'spherical_harmonics':
from shencoder import SHEncoder
encoder = SHEncoder(input_dim=input_dim, degree=degree)
elif encoding == 'hashgrid':
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
elif encoding == 'tiledgrid':
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
elif encoding == 'ash':
from ashencoder import AshEncoder
encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)
else:
raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid, ash]')
return encoder, encoder.output_dim
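A small usage sketch with the frequency backend defined in freqencoder/ below; it assumes the CUDA extension has been built, and the batch size is arbitrary:

```python
# Hedged sketch: encode 3-D coordinates; output width is input_dim + input_dim * 2 * multires.
import torch
from encoding import get_encoder

encoder, feat_dim = get_encoder('frequency', input_dim=3, multires=6)
xyz = torch.rand(1024, 3).cuda()
feats = encoder(xyz)      # [1024, 39] since 3 + 3 * 2 * 6 = 39
```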

1
freqencoder/__init__.py Normal file
View File

@ -0,0 +1 @@
from .freq import FreqEncoder

41
freqencoder/backend.py Normal file
View File

@ -0,0 +1,41 @@
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
'-use_fast_math'
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
_backend = load(name='_freqencoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
)
__all__ = ['_backend']

77
freqencoder/freq.py Normal file
View File

@ -0,0 +1,77 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _freqencoder as _backend
except ImportError:
from .backend import _backend
class _freq_encoder(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
def forward(ctx, inputs, degree, output_dim):
# inputs: [B, input_dim], float
# RETURN: [B, F], float
if not inputs.is_cuda: inputs = inputs.cuda()
inputs = inputs.contiguous()
B, input_dim = inputs.shape # batch size, coord dim
outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
_backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
ctx.save_for_backward(inputs, outputs)
ctx.dims = [B, input_dim, degree, output_dim]
return outputs
@staticmethod
#@once_differentiable
@custom_bwd
def backward(ctx, grad):
# grad: [B, output_dim]
grad = grad.contiguous()
inputs, outputs = ctx.saved_tensors
B, input_dim, degree, output_dim = ctx.dims
grad_inputs = torch.zeros_like(inputs)
_backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
return grad_inputs, None, None
freq_encode = _freq_encoder.apply
class FreqEncoder(nn.Module):
def __init__(self, input_dim=3, degree=4):
super().__init__()
self.input_dim = input_dim
self.degree = degree
self.output_dim = input_dim + input_dim * 2 * degree
def __repr__(self):
return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
def forward(self, inputs, **kwargs):
# inputs: [..., input_dim]
# return: [..., output_dim]
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.reshape(-1, self.input_dim)
outputs = freq_encode(inputs, self.degree, self.output_dim)
outputs = outputs.reshape(prefix_shape + [self.output_dim])
return outputs

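A pure-PyTorch reference of the layout produced by the CUDA frequency encoder above. This is a sanity-check sketch, not part of the extension: per `kernel_freq`, the output is `[x, sin(2^0 x), cos(2^0 x), sin(2^1 x), cos(2^1 x), ...]`, one block of `input_dim` channels per term.

```python
# Reference sketch of the frequency encoding layout (for sanity checks only).
import torch

def freq_encode_reference(x: torch.Tensor, degree: int) -> torch.Tensor:
    outs = [x]
    for f in range(degree):
        scaled = x * (2.0 ** f)
        outs.append(torch.sin(scaled))   # phase shift 0
        outs.append(torch.cos(scaled))   # phase shift pi/2
    return torch.cat(outs, dim=-1)       # [..., input_dim * (1 + 2 * degree)]

# Example check against the CUDA path (requires the built extension and a GPU):
# x = torch.rand(8, 3, device='cuda')
# enc = FreqEncoder(input_dim=3, degree=4).cuda()
# assert torch.allclose(enc(x), freq_encode_reference(x, 4), atol=1e-3)
```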
51
freqencoder/setup.py Normal file
View File

@ -0,0 +1,51 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
'-use_fast_math'
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
setup(
name='freqencoder', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_freqencoder', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)

8
freqencoder/src/bindings.cpp Normal file
View File

@ -0,0 +1,8 @@
#include <torch/extension.h>
#include "freqencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}

129
freqencoder/src/freqencoder.cu Normal file
View File

@ -0,0 +1,129 @@
#include <stdint.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <algorithm>
#include <stdexcept>
#include <cstdio>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
inline constexpr __device__ float PI() { return 3.141592653589793f; }
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
return (val + divisor - 1) / divisor;
}
// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
__global__ void kernel_freq(
const float * __restrict__ inputs,
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
float * outputs
) {
// parallel on per-element
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * C) return;
// get index
const uint32_t b = t / C;
const uint32_t c = t - b * C; // t % C;
// locate
inputs += b * D;
outputs += t;
// write self
if (c < D) {
outputs[0] = inputs[c];
// write freq
} else {
const uint32_t col = c / D - 1;
const uint32_t d = c % D;
const uint32_t freq = col / 2;
const float phase_shift = (col % 2) * (PI() / 2);
outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
}
}
// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]
// grad_inputs: [B, D]
__global__ void kernel_freq_backward(
const float * __restrict__ grad,
const float * __restrict__ outputs,
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
float * grad_inputs
) {
// parallel on per-element
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * D) return;
const uint32_t b = t / D;
const uint32_t d = t - b * D; // t % D;
// locate
grad += b * C;
outputs += b * C;
grad_inputs += t;
// register
float result = grad[d];
grad += D;
outputs += D;
for (uint32_t f = 0; f < deg; f++) {
result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
grad += 2 * D;
outputs += 2 * D;
}
// write
grad_inputs[0] = result;
}
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
CHECK_CUDA(inputs);
CHECK_CUDA(outputs);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(outputs);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(outputs);
static constexpr uint32_t N_THREADS = 128;
kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
CHECK_CUDA(grad);
CHECK_CUDA(outputs);
CHECK_CUDA(grad_inputs);
CHECK_CONTIGUOUS(grad);
CHECK_CONTIGUOUS(outputs);
CHECK_CONTIGUOUS(grad_inputs);
CHECK_IS_FLOATING(grad);
CHECK_IS_FLOATING(outputs);
CHECK_IS_FLOATING(grad_inputs);
static constexpr uint32_t N_THREADS = 128;
kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}

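A quick gradient sanity-check sketch for the backward kernel above. It assumes the built `_freqencoder` extension, a GPU, and the `freq_encode_reference` helper from the earlier sketch; the chain rule d/dx sin(2^f x + p) = 2^f cos(2^f x + p) is exactly what `kernel_freq_backward` accumulates per frequency.

```python
# Gradient check sketch under the assumptions stated above.
import torch
from freqencoder import FreqEncoder

x = torch.rand(16, 3, device='cuda', requires_grad=True)

FreqEncoder(input_dim=3, degree=4).cuda()(x).sum().backward()
grad_cuda = x.grad.clone()

x.grad = None
freq_encode_reference(x, degree=4).sum().backward()   # reference from the earlier sketch
print((grad_cuda - x.grad).abs().max())               # should be small (fast-math sine)
```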
10
freqencoder/src/freqencoder.h Normal file
View File

@ -0,0 +1,10 @@
# pragma once
#include <stdint.h>
#include <torch/torch.h>
// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);

1
gridencoder/__init__.py Normal file
View File

@ -0,0 +1 @@
from .grid import GridEncoder

40
gridencoder/backend.py Normal file
View File

@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14', '-finput-charset=UTF-8']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
_backend = load(name='_grid_encoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'gridencoder.cu',
'bindings.cpp',
]],
)
__all__ = ['_backend']

155
gridencoder/grid.py Normal file
View File

@ -0,0 +1,155 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _gridencoder as _backend
except ImportError:
from .backend import _backend
_gridtype_to_id = {
'hash': 0,
'tiled': 1,
}
class _grid_encode(Function):
@staticmethod
@custom_fwd
def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
# inputs: [B, D], float in [0, 1]
# embeddings: [sO, C], float
# offsets: [L + 1], int
# RETURN: [B, F], float
inputs = inputs.float().contiguous()
B, D = inputs.shape # batch size, coord dim
L = offsets.shape[0] - 1 # level
C = embeddings.shape[1] # embedding dim for each level
S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
H = base_resolution # base resolution
# manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
# if C % 2 != 0, force float, since half for atomicAdd is very slow.
if torch.is_autocast_enabled() and C % 2 == 0:
embeddings = embeddings.to(torch.half)
# L first, optimize cache for cuda kernel, but needs an extra permute later
outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
if calc_grad_inputs:
dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
else:
dy_dx = None
_backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
# permute back to [B, L * C]
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
ctx.dims = [B, D, C, L, S, H, gridtype]
ctx.align_corners = align_corners
return outputs
@staticmethod
#@once_differentiable
@custom_bwd
def backward(ctx, grad):
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
B, D, C, L, S, H, gridtype = ctx.dims
align_corners = ctx.align_corners
# grad: [B, L * C] --> [L, B, C]
grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
grad_embeddings = torch.zeros_like(embeddings)
if dy_dx is not None:
grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
else:
grad_inputs = None
_backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
if dy_dx is not None:
grad_inputs = grad_inputs.to(inputs.dtype)
return grad_inputs, grad_embeddings, None, None, None, None, None, None
grid_encode = _grid_encode.apply
class GridEncoder(nn.Module):
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
super().__init__()
# the finest resolution desired at the last level; if provided, overrides per_level_scale
if desired_resolution is not None:
per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
self.input_dim = input_dim # coord dims, 2 or 3
self.num_levels = num_levels # num levels; resolution is multiplied by per_level_scale at each level
self.level_dim = level_dim # encode channels per level
self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
self.log2_hashmap_size = log2_hashmap_size
self.base_resolution = base_resolution
self.output_dim = num_levels * level_dim
self.gridtype = gridtype
self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
self.align_corners = align_corners
# allocate parameters
offsets = []
offset = 0
self.max_params = 2 ** log2_hashmap_size
for i in range(num_levels):
resolution = int(np.ceil(base_resolution * per_level_scale ** i))
params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
offsets.append(offset)
offset += params_in_level
# print(resolution, params_in_level)
offsets.append(offset)
offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
self.register_buffer('offsets', offsets)
self.n_params = offsets[-1] * level_dim
# parameters
self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
self.reset_parameters()
def reset_parameters(self):
std = 1e-4
self.embeddings.data.uniform_(-std, std)
def __repr__(self):
return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
def forward(self, inputs, bound=1):
# inputs: [..., input_dim], normalized real world positions in [-bound, bound]
# return: [..., num_levels * level_dim]
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
#print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.view(-1, self.input_dim)
outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
outputs = outputs.view(prefix_shape + [self.output_dim])
#print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
return outputs

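A minimal usage sketch for `GridEncoder` on its own. It assumes the `_gridencoder` CUDA extension has already been built and a GPU is available; the hyperparameters shown are just the defaults used elsewhere in this commit.

```python
# Minimal usage sketch (requires the built _gridencoder extension and a GPU).
import torch
from gridencoder import GridEncoder

enc = GridEncoder(input_dim=3, num_levels=16, level_dim=2, base_resolution=16,
                  log2_hashmap_size=19, desired_resolution=2048,
                  gridtype='hash', align_corners=False).cuda()

x = torch.rand(4096, 3, device='cuda') * 2 - 1   # positions in [-bound, bound]
feats = enc(x, bound=1)                          # [4096, 16 * 2] multi-resolution features
print(enc)                                       # shows resolution range and table shape
```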
50
gridencoder/setup.py Normal file
View File

@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
setup(
name='gridencoder', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_gridencoder', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'gridencoder.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)

8
gridencoder/src/bindings.cpp Normal file
View File

@ -0,0 +1,8 @@
#include <torch/extension.h>
#include "gridencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}

479
gridencoder/src/gridencoder.cu Normal file
View File

@ -0,0 +1,479 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <cstdio>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
// requires CUDA >= 10 and ARCH >= 70
// this is very slow compared to float or __half2, and never used.
//return atomicAdd(reinterpret_cast<__half*>(address), val);
}
template <typename T>
static inline __host__ __device__ T div_round_up(T val, T divisor) {
return (val + divisor - 1) / divisor;
}
template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
// and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
// coordinates.
constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
uint32_t result = 0;
#pragma unroll
for (uint32_t i = 0; i < D; ++i) {
result ^= pos_grid[i] * primes[i];
}
return result;
}
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
uint32_t stride = 1;
uint32_t index = 0;
#pragma unroll
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
index += pos_grid[d] * stride;
stride *= align_corners ? resolution: (resolution + 1);
}
// NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
// gridtype: 0 == hash, 1 == tiled
if (gridtype == 0 && stride > hashmap_size) {
index = fast_hash<D>(pos_grid);
}
return (index % hashmap_size) * C + ch;
}
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
const float * __restrict__ inputs,
const scalar_t * __restrict__ grid,
const int * __restrict__ offsets,
scalar_t * __restrict__ outputs,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
scalar_t * __restrict__ dy_dx,
const uint32_t gridtype,
const bool align_corners
) {
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B) return;
const uint32_t level = blockIdx.y;
// locate
grid += (uint32_t)offsets[level] * C;
inputs += b * D;
outputs += level * B * C + b * C;
// check input range (should be in [0, 1])
bool flag_oob = false;
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
flag_oob = true;
}
}
// if input out of bound, just set output to 0
if (flag_oob) {
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
outputs[ch] = 0;
}
if (dy_dx) {
dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
dy_dx[d * C + ch] = 0;
}
}
}
return;
}
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const float scale = exp2f(level * S) * H - 1.0f;
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
// calculate coordinate
float pos[D];
uint32_t pos_grid[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
pos_grid[d] = floorf(pos[d]);
pos[d] -= (float)pos_grid[d];
}
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
// interpolate
scalar_t results[C] = {0}; // temp results in register
#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
float w = 1;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if ((idx & (1 << d)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = pos_grid[d] + 1;
}
}
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
// writing to register (fast)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
results[ch] += w * grid[index + ch];
}
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
}
// writing to global memory (slow)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
outputs[ch] = results[ch];
}
// prepare dy_dx
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
if (dy_dx) {
dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
for (uint32_t gd = 0; gd < D; gd++) {
scalar_t results_grad[C] = {0};
#pragma unroll
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
float w = scale;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t nd = 0; nd < D - 1; nd++) {
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
if ((idx & (1 << nd)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = pos_grid[d] + 1;
}
}
pos_grid_local[gd] = pos_grid[gd];
uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
pos_grid_local[gd] = pos_grid[gd] + 1;
uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
}
}
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
dy_dx[gd * C + ch] = results_grad[ch];
}
}
}
}
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
const scalar_t * __restrict__ grad,
const float * __restrict__ inputs,
const scalar_t * __restrict__ grid,
const int * __restrict__ offsets,
scalar_t * __restrict__ grad_grid,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
const uint32_t gridtype,
const bool align_corners
) {
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
if (b >= B) return;
const uint32_t level = blockIdx.y;
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
// locate
grad_grid += offsets[level] * C;
inputs += b * D;
grad += level * B * C + b * C + ch; // L, B, C
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const float scale = exp2f(level * S) * H - 1.0f;
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
// check input range (should be in [0, 1])
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
return; // grad is init as 0, so we simply return.
}
}
// calculate coordinate
float pos[D];
uint32_t pos_grid[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
pos_grid[d] = floorf(pos[d]);
pos[d] -= (float)pos_grid[d];
}
scalar_t grad_cur[N_C] = {0}; // fetch to register
#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
grad_cur[c] = grad[c];
}
// interpolate
#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
float w = 1;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if ((idx & (1 << d)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = pos_grid[d] + 1;
}
}
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
// TODO: use float which is better than __half, if N_C % 2 != 0
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
#pragma unroll
for (uint32_t c = 0; c < N_C; c += 2) {
// process two __half at once (by interpreting as a __half2)
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
atomicAdd((__half2*)&grad_grid[index + c], v);
}
// float, or __half when N_C % 2 != 0 (which means C == 1)
} else {
#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
}
}
}
}
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
const scalar_t * __restrict__ grad,
const scalar_t * __restrict__ dy_dx,
scalar_t * __restrict__ grad_inputs,
uint32_t B, uint32_t L
) {
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * D) return;
const uint32_t b = t / D;
const uint32_t d = t - b * D;
dy_dx += b * L * D * C;
scalar_t result = 0;
# pragma unroll
for (int l = 0; l < L; l++) {
# pragma unroll
for (int ch = 0; ch < C; ch++) {
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
}
}
grad_inputs[t] = result;
}
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
static constexpr uint32_t N_THREAD = 512;
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
switch (C) {
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
}
}
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"};
}
}
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
static constexpr uint32_t N_THREAD = 256;
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
switch (C) {
case 1:
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 2:
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 4:
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 8:
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
}
}
// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"};
}
}
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners) {
CHECK_CUDA(inputs);
CHECK_CUDA(embeddings);
CHECK_CUDA(offsets);
CHECK_CUDA(outputs);
// CHECK_CUDA(dy_dx);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(embeddings);
CHECK_CONTIGUOUS(offsets);
CHECK_CONTIGUOUS(outputs);
// CHECK_CONTIGUOUS(dy_dx);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(embeddings);
CHECK_IS_INT(offsets);
CHECK_IS_FLOATING(outputs);
// CHECK_IS_FLOATING(dy_dx);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
embeddings.scalar_type(), "grid_encode_forward", ([&] {
grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
}));
}
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners) {
CHECK_CUDA(grad);
CHECK_CUDA(inputs);
CHECK_CUDA(embeddings);
CHECK_CUDA(offsets);
CHECK_CUDA(grad_embeddings);
// CHECK_CUDA(dy_dx);
// CHECK_CUDA(grad_inputs);
CHECK_CONTIGUOUS(grad);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(embeddings);
CHECK_CONTIGUOUS(offsets);
CHECK_CONTIGUOUS(grad_embeddings);
// CHECK_CONTIGUOUS(dy_dx);
// CHECK_CONTIGUOUS(grad_inputs);
CHECK_IS_FLOATING(grad);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(embeddings);
CHECK_IS_INT(offsets);
CHECK_IS_FLOATING(grad_embeddings);
// CHECK_IS_FLOATING(dy_dx);
// CHECK_IS_FLOATING(grad_inputs);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "grid_encode_backward", ([&] {
grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
}));
}

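For readers following `get_grid_index` above, here is an illustrative pure-Python mirror of the same indexing rule (not part of the extension): dense "tiled" indexing while the level still fits in the table, otherwise the prime-XOR hash. The CUDA code works in uint32, so products are masked to 32 bits below to mimic that wrap-around.

```python
# Illustrative Python mirror of get_grid_index / fast_hash in gridencoder.cu.
PRIMES = (1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737)

def grid_index(pos_grid, hashmap_size, resolution, gridtype='hash',
               align_corners=False, C=2, ch=0):
    stride, index = 1, 0
    for p in pos_grid:                        # dense index while it still fits
        if stride > hashmap_size:
            break
        index += p * stride
        stride *= resolution if align_corners else resolution + 1
    if gridtype == 'hash' and stride > hashmap_size:
        index = 0                             # level overflows the table: hash instead
        for p, prime in zip(pos_grid, PRIMES):
            index ^= (p * prime) & 0xFFFFFFFF
    return (index % hashmap_size) * C + ch    # C feature channels per grid entry
```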
15
gridencoder/src/gridencoder.h Normal file
View File

@ -0,0 +1,15 @@
#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H
#include <stdint.h>
#include <torch/torch.h>
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [B, L * C], float
// H: base resolution
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners);
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners);
#endif

260
main.py Normal file
View File

@ -0,0 +1,260 @@
import torch
import argparse
from nerf_triplane.provider import NeRFDataset
from nerf_triplane.utils import *
from nerf_triplane.network import NeRFNetwork
# torch.autograd.set_detect_anomaly(True)
# Disable TF32 to avoid low numerical accuracy on RTX 30xx GPUs.
try:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
except AttributeError as e:
print('[INFO] This PyTorch version does not support TF32.')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str)
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)")
parser.add_argument('--test_train', action='store_true', help="test mode (load model and train dataset)")
parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
parser.add_argument('--workspace', type=str, default='workspace')
parser.add_argument('--seed', type=int, default=0)
### training options
parser.add_argument('--iters', type=int, default=200000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate")
parser.add_argument('--ckpt', type=str, default='latest')
parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
### loss set
parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")
### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--bg_img', type=str, default='', help="background image")
parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")
### dataset options
parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
# (the default value is for the fox dataset)
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
parser.add_argument('--init_lips', action='store_true', help="init lips region")
parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a with exponential decay")
parser.add_argument('--torso', action='store_true', help="fix head and train torso")
parser.add_argument('--head_ckpt', type=str, default='', help="head model")
### GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
parser.add_argument('--W', type=int, default=450, help="GUI width")
parser.add_argument('--H', type=int, default=450, help="GUI height")
parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
### else
parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")
parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")
# asr
parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
parser.add_argument('--asr_play', action='store_true', help="play out the audio")
parser.add_argument('--asr_model', type=str, default='deepspeech')
# parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
# parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
parser.add_argument('--asr_save_feats', action='store_true')
# audio FPS
parser.add_argument('--fps', type=int, default=50)
# sliding window left-middle-right length (unit: 20ms)
parser.add_argument('-l', type=int, default=10)
parser.add_argument('-m', type=int, default=50)
parser.add_argument('-r', type=int, default=10)
opt = parser.parse_args()
if opt.O:
opt.fp16 = True
opt.exp_eye = True
if opt.test and False:
opt.smooth_path = True
opt.smooth_eye = True
opt.smooth_lips = True
opt.cuda_ray = True
# assert opt.cuda_ray, "Only support CUDA ray mode."
if opt.patch_size > 1:
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
assert opt.num_rays % (opt.patch_size ** 2) == 0, "num_rays must be divisible by patch_size ** 2."
# if opt.finetune_lips:
# # do not update density grid in finetune stage
# opt.update_extra_interval = 1e9
print(opt)
seed_everything(opt.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeRFNetwork(opt)
# manually load state dict for head
if opt.torso and opt.head_ckpt != '':
model_dict = torch.load(opt.head_ckpt, map_location='cpu')['model']
missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False)
if len(missing_keys) > 0:
print(f"[WARN] missing keys: {missing_keys}")
if len(unexpected_keys) > 0:
print(f"[WARN] unexpected keys: {unexpected_keys}")
# freeze these keys
for k, v in model.named_parameters():
if k in model_dict:
# print(f'[INFO] freeze {k}, {v.shape}')
v.requires_grad = False
# print(model)
criterion = torch.nn.MSELoss(reduction='none')
if opt.test:
if opt.gui:
metrics = [] # use no metric in GUI for faster initialization...
else:
# metrics = [PSNRMeter(), LPIPSMeter(device=device)]
metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')]
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
if opt.test_train:
test_set = NeRFDataset(opt, device=device, type='train')
# a manual fix to test on the training dataset
test_set.training = False
test_set.num_rays = -1
test_loader = test_set.dataloader()
else:
test_loader = NeRFDataset(opt, device=device, type='test').dataloader()
# temp fix: for update_extra_states
model.aud_features = test_loader._data.auds
model.eye_areas = test_loader._data.eye_area
if opt.gui:
from nerf_triplane.gui import NeRFGUI
# we still need test_loader to provide audio features for testing.
with NeRFGUI(opt, trainer, test_loader) as gui:
gui.render()
else:
### test and save video (fast)
trainer.test(test_loader)
### evaluate metrics (slow)
if test_loader.has_gt:
trainer.evaluate(test_loader)
else:
optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr, opt.lr_net), betas=(0, 0.99), eps=1e-8)
train_loader = NeRFDataset(opt, device=device, type='train').dataloader()
assert len(train_loader) < opt.ind_num, f"[ERROR] dataset has too many frames: {len(train_loader)}; please increase --ind_num to at least this number!"
# temp fix: for update_extra_states
model.aud_features = train_loader._data.auds
model.eye_area = train_loader._data.eye_area
model.poses = train_loader._data.poses
# decay to 0.1 * init_lr at last iter step
if opt.finetune_lips:
scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.05 ** (iter / opt.iters))
else:
scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.5 ** (iter / opt.iters))
metrics = [PSNRMeter(), LPIPSMeter(device=device)]
eval_interval = max(1, int(5000 / len(train_loader)))
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=eval_interval)
with open(os.path.join(opt.workspace, 'opt.txt'), 'a') as f:
f.write(str(opt))
if opt.gui:
with NeRFGUI(opt, trainer, train_loader) as gui:
gui.render()
else:
valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader()
max_epochs = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
print(f'[INFO] max_epoch = {max_epochs}')
trainer.train(train_loader, valid_loader, max_epochs)
# free some mem
del train_loader, valid_loader
torch.cuda.empty_cache()
# also test
test_loader = NeRFDataset(opt, device=device, type='test').dataloader()
if test_loader.has_gt:
trainer.evaluate(test_loader) # blender has gt, so evaluate it.
trainer.test(test_loader)

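Hedged command-line examples for the training script above. `<id>` and the checkpoint/feature paths are placeholders; the flag combinations are only illustrative of the argparse options defined here (the `-O` shortcut enables the recommended fp16 / cuda_ray / exp_eye settings, per its help string).

```
# train the head model
python main.py data/<id> --workspace trial_<id> -O --iters 200000

# freeze the head and train the torso from a head checkpoint
python main.py data/<id> --workspace trial_<id>_torso -O --torso --head_ckpt <head_checkpoint>.pth --iters 200000

# test / render with a custom audio feature file
python main.py data/<id> --workspace trial_<id> -O --test --aud <audio_features>.npy
```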
419
nerf_triplane/asr.py Normal file
View File

@ -0,0 +1,419 @@
import time
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor
import pyaudio
import soundfile as sf
import resampy
from queue import Queue
from threading import Thread, Event
def _read_frame(stream, exit_event, queue, chunk):
while True:
if exit_event.is_set():
print(f'[INFO] read frame thread ends')
break
frame = stream.read(chunk, exception_on_overflow=False)
frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
queue.put(frame)
def _play_frame(stream, exit_event, queue, chunk):
while True:
if exit_event.is_set():
print(f'[INFO] play frame thread ends')
break
frame = queue.get()
frame = (frame * 32767).astype(np.int16).tobytes()
stream.write(frame, chunk)
class ASR:
def __init__(self, opt):
self.opt = opt
self.play = opt.asr_play
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.fps = opt.fps # 20 ms per frame
self.sample_rate = 16000
self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.mode = 'live' if opt.asr_wav == '' else 'file'
if 'esperanto' in self.opt.asr_model:
self.audio_dim = 44
elif 'deepspeech' in self.opt.asr_model:
self.audio_dim = 29
else:
self.audio_dim = 32
# prepare context cache
# each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
self.context_size = opt.m
self.stride_left_size = opt.l
self.stride_right_size = opt.r
self.text = '[START]\n'
self.terminated = False
self.frames = []
# pad left frames
if self.stride_left_size > 0:
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
self.exit_event = Event()
self.audio_instance = pyaudio.PyAudio()
# create input stream
if self.mode == 'file':
self.file_stream = self.create_file_stream()
else:
# start a background process to read frames
self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
self.queue = Queue()
self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
# play out the audio too...?
if self.play:
self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
self.output_queue = Queue()
self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
# current location of audio
self.idx = 0
# create wav2vec model
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
# prepare to save logits
if self.opt.asr_save_feats:
self.all_feats = []
# the extracted features
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]
self.feat_buffer_size = 4
self.feat_buffer_idx = 0
self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
# TODO: hard coded 16 and 8 window size...
self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
self.tail = 8
# attention window...
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
# warm up steps needed: mid + right + window_size + attention_size
self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
self.listening = False
self.playing = False
def listen(self):
# start
if self.mode == 'live' and not self.listening:
print(f'[INFO] starting read frame thread...')
self.process_read_frame.start()
self.listening = True
if self.play and not self.playing:
print(f'[INFO] starting play frame thread...')
self.process_play_frame.start()
self.playing = True
def stop(self):
self.exit_event.set()
if self.play:
self.output_stream.stop_stream()
self.output_stream.close()
if self.playing:
self.process_play_frame.join()
self.playing = False
if self.mode == 'live':
self.input_stream.stop_stream()
self.input_stream.close()
if self.listening:
self.process_read_frame.join()
self.listening = False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
if self.mode == 'live':
# live mode: also print the result text.
self.text += '\n[END]'
print(self.text)
def get_next_feat(self):
# return a [1/8, 16] window, for the next input to nerf side.
while len(self.att_feats) < 8:
# [------f+++t-----]
if self.front < self.tail:
feat = self.feat_queue[self.front:self.tail]
# [++t-----------f+]
else:
feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
self.front = (self.front + 2) % self.feat_queue.shape[0]
self.tail = (self.tail + 2) % self.feat_queue.shape[0]
# print(self.front, self.tail, feat.shape)
self.att_feats.append(feat.permute(1, 0))
att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
# discard old
self.att_feats = self.att_feats[1:]
return att_feat
def run_step(self):
if self.terminated:
return
# get a frame of audio
frame = self.get_audio_frame()
# the last frame
if frame is None:
# terminate, but always run the network for the left frames
self.terminated = True
else:
self.frames.append(frame)
# put to output
if self.play:
self.output_queue.put(frame)
# context not enough, do not run network.
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
# discard the old part to save memory
if not self.terminated:
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
logits, labels, text = self.frame_to_text(inputs)
feats = logits # better lips-sync than labels
# save feats
if self.opt.asr_save_feats:
self.all_feats.append(feats)
# record the feats efficiently.. (no concat, constant memory)
start = self.feat_buffer_idx * self.context_size
end = start + feats.shape[0]
self.feat_queue[start:end] = feats
self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size
# very naive, just concat the text output.
if text != '':
self.text = self.text + ' ' + text
# will only run once at termination
if self.terminated:
self.text += '\n[END]'
print(self.text)
if self.opt.asr_save_feats:
print(f'[INFO] save all feats for training purpose... ')
feats = torch.cat(self.all_feats, dim=0) # [N, C]
# print('[INFO] before unfold', feats.shape)
window_size = 16
padding = window_size // 2
feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
# print('[INFO] after unfold', unfold_feats.shape)
# save to a npy file
if 'esperanto' in self.opt.asr_model:
output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
else:
output_path = self.opt.asr_wav.replace('.wav', '.npy')
np.save(output_path, unfold_feats.cpu().numpy())
print(f"[INFO] saved logits to {output_path}")
def create_file_stream(self):
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self.sample_rate:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
return stream
def create_pyaudio_stream(self):
import pyaudio
print(f'[INFO] creating live audio stream ...')
audio = pyaudio.PyAudio()
# get devices
info = audio.get_host_api_info_by_index(0)
n_devices = info.get('deviceCount')
for i in range(0, n_devices):
if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
print(f'[INFO] choose audio device {name}, id {i}')
break
# get stream
stream = audio.open(input_device_index=i,
format=pyaudio.paInt16,
channels=1,
rate=self.sample_rate,
input=True,
frames_per_buffer=self.chunk)
return audio, stream
def get_audio_frame(self):
if self.mode == 'file':
if self.idx < self.file_stream.shape[0]:
frame = self.file_stream[self.idx: self.idx + self.chunk]
self.idx = self.idx + self.chunk
return frame
else:
return None
else:
frame = self.queue.get()
# print(f'[INFO] get frame {frame.shape}')
self.idx = self.idx + self.chunk
return frame
def frame_to_text(self, frame):
# frame: [N * 320], N = (context_size + 2 * stride_size)
inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
with torch.no_grad():
result = self.model(inputs.input_values.to(self.device))
logits = result.logits # [1, N - 1, 32]
# cut off stride
left = max(0, self.stride_left_size)
right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
# do not cut right if terminated.
if self.terminated:
right = logits.shape[1]
logits = logits[:, left:right]
# print(frame.shape, inputs.input_values.shape, logits.shape)
predicted_ids = torch.argmax(logits, dim=-1)
transcription = self.processor.batch_decode(predicted_ids)[0].lower()
# for esperanto
# labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '', 'fi', 'l', 'p', '', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
# labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
# print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
# print(predicted_ids[0])
# print(transcription)
return logits[0], predicted_ids[0], transcription # [N,]
def run(self):
self.listen()
while not self.terminated:
self.run_step()
def clear_queue(self):
# clear the queue, to reduce potential latency...
print(f'[INFO] clear queue')
if self.mode == 'live':
self.queue.queue.clear()
if self.play:
self.output_queue.queue.clear()
def warm_up(self):
self.listen()
print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
t = time.time()
for _ in range(self.warm_up_steps):
self.run_step()
if torch.cuda.is_available():
torch.cuda.synchronize()
t = time.time() - t
print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
self.clear_queue()
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--wav', type=str, default='')
parser.add_argument('--play', action='store_true', help="play out the audio")
parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
# parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
parser.add_argument('--save_feats', action='store_true')
# audio FPS
parser.add_argument('--fps', type=int, default=50)
# sliding window left-middle-right length.
parser.add_argument('-l', type=int, default=10)
parser.add_argument('-m', type=int, default=50)
parser.add_argument('-r', type=int, default=10)
opt = parser.parse_args()
# fix
opt.asr_wav = opt.wav
opt.asr_play = opt.play
opt.asr_model = opt.model
opt.asr_save_feats = opt.save_feats
if 'deepspeech' in opt.asr_model:
raise ValueError("DeepSpeech features should not use this code to extract...")
with ASR(opt) as asr:
asr.run()
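# ---- hedged usage sketch (comments only, not part of the original file) ----
# Typical invocations of this script, based on the flags defined above
# (the module path and data layout are placeholders, not project defaults):
#   python nerf_triplane/asr.py --wav data/<name>/aud.wav --save_feats   # offline feature extraction
#   python nerf_triplane/asr.py --play                                   # live microphone test with playback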

565
nerf_triplane/gui.py Normal file

@ -0,0 +1,565 @@
import math
import torch
import numpy as np
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R
from .utils import *
from .asr import ASR
class OrbitCamera:
def __init__(self, W, H, r=2, fovy=60):
self.W = W
self.H = H
self.radius = r # camera distance from center
self.fovy = fovy # in degree
self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
self.rot = R.from_matrix([[0, -1, 0], [0, 0, -1], [1, 0, 0]]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention)
self.up = np.array([1, 0, 0], dtype=np.float32) # need to be normalized!
# pose
@property
def pose(self):
# first move camera to radius
res = np.eye(4, dtype=np.float32)
res[2, 3] -= self.radius
# rotate
rot = np.eye(4, dtype=np.float32)
rot[:3, :3] = self.rot.as_matrix()
res = rot @ res
# translate
res[:3, 3] -= self.center
return res
def update_pose(self, pose):
# pose: [4, 4] numpy array
# assert self.center is 0
self.radius = np.linalg.norm(pose[:3, 3])
T = np.eye(4)
T[2, 3] = -self.radius
rot = pose @ np.linalg.inv(T)
self.rot = R.from_matrix(rot[:3, :3])
def update_intrinsics(self, intrinsics):
fl_x, fl_y, cx, cy = intrinsics
self.W = int(cx * 2)
self.H = int(cy * 2)
self.fovy = np.rad2deg(2 * np.arctan2(self.H, 2 * fl_y))
# intrinsics
@property
def intrinsics(self):
focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
return np.array([focal, focal, self.W // 2, self.H // 2])
def orbit(self, dx, dy):
# rotate along camera up/side axis!
side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
rotvec_x = self.up * np.radians(-0.01 * dx)
rotvec_y = side * np.radians(-0.01 * dy)
self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
def scale(self, delta):
self.radius *= 1.1 ** (-delta)
def pan(self, dx, dy, dz=0):
# pan in camera coordinate system (careful on the sensitivity!)
self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])
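# ---- hedged usage sketch (commented out; values are illustrative) ----
# A quick sanity check of the camera math above: build a camera, apply a small
# drag/scroll, then read back the 4x4 pose and the pinhole intrinsics.
#   cam = OrbitCamera(450, 450, r=3.35, fovy=21.24)
#   cam.orbit(100, 0)        # drag 100 px horizontally
#   cam.scale(-1)            # one wheel notch out -> radius *= 1.1
#   print(cam.pose)          # [4, 4] camera-to-world matrix
#   print(cam.intrinsics)    # [focal, focal, W // 2, H // 2]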
class NeRFGUI:
def __init__(self, opt, trainer, data_loader, debug=True):
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
self.debug = debug
self.training = False
self.step = 0 # training step
self.trainer = trainer
self.data_loader = data_loader
# override with dataloader's intrinsics
self.W = data_loader._data.W
self.H = data_loader._data.H
self.cam.update_intrinsics(data_loader._data.intrinsics)
# use dataloader's pose
pose_init = data_loader._data.poses[0]
self.cam.update_pose(pose_init.detach().cpu().numpy())
# use dataloader's bg
bg_img = data_loader._data.bg_img #.view(1, -1, 3)
if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
self.bg_color = bg_img.view(1, -1, 3)
# audio features (from dataloader, only used in non-playing mode)
self.audio_features = data_loader._data.auds # [N, 29, 16]
self.audio_idx = 0
# control eye
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
# playing seq from dataloader, or pause.
self.playing = False
self.loader = iter(data_loader)
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
self.need_update = True # camera moved, should reset accumulation
self.spp = 1 # sample per pixel
self.mode = 'image' # choose from ['image', 'depth']
self.dynamic_resolution = False # assert False!
self.downscale = 1
self.train_steps = 16
self.ind_index = 0
self.ind_num = trainer.model.individual_codes.shape[0]
# build asr
if self.opt.asr:
self.asr = ASR(opt)
dpg.create_context()
self.register_dpg()
self.test_step()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.opt.asr:
self.asr.stop()
dpg.destroy_context()
def train_step(self):
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
starter.record()
outputs = self.trainer.train_gui(self.data_loader, step=self.train_steps)
ender.record()
torch.cuda.synchronize()
t = starter.elapsed_time(ender)
self.step += self.train_steps
self.need_update = True
dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')
# dynamic train steps
# max allowed train time per-frame is 500 ms
full_t = t / self.train_steps * 16
train_steps = min(16, max(4, int(16 * 500 / full_t)))
if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
self.train_steps = train_steps
def prepare_buffer(self, outputs):
if self.mode == 'image':
return outputs['image']
else:
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
def test_step(self):
if self.need_update or self.spp < self.opt.max_spp:
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
starter.record()
if self.playing:
try:
data = next(self.loader)
except StopIteration:
self.loader = iter(self.data_loader)
data = next(self.loader)
if self.opt.asr:
# use the live audio stream
data['auds'] = self.asr.get_next_feat()
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
# sync local camera pose
self.cam.update_pose(data['poses_matrix'][0].detach().cpu().numpy())
else:
if self.audio_features is not None:
auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
else:
auds = None
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
ender.record()
torch.cuda.synchronize()
t = starter.elapsed_time(ender)
# update dynamic resolution
if self.dynamic_resolution:
# max allowed infer time per-frame is 200 ms
full_t = t / (self.downscale ** 2)
downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
self.downscale = downscale
if self.need_update:
self.render_buffer = self.prepare_buffer(outputs)
self.spp = 1
self.need_update = False
else:
self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
self.spp += 1
if self.playing:
self.need_update = True
dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
dpg.set_value("_log_spp", self.spp)
dpg.set_value("_texture", self.render_buffer)
def register_dpg(self):
### register texture
with dpg.texture_registry(show=False):
dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
### register window
# the rendered image, as the primary window
with dpg.window(tag="_primary_window", width=self.W, height=self.H):
# add the texture
dpg.add_image("_texture")
# dpg.set_primary_window("_primary_window", True)
dpg.show_tool(dpg.mvTool_Metrics)
# control window
with dpg.window(label="Control", tag="_control_window", width=400, height=300):
# button theme
with dpg.theme() as theme_button:
with dpg.theme_component(dpg.mvButton):
dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
# time
if not self.opt.test:
with dpg.group(horizontal=True):
dpg.add_text("Train time: ")
dpg.add_text("no data", tag="_log_train_time")
with dpg.group(horizontal=True):
dpg.add_text("Infer time: ")
dpg.add_text("no data", tag="_log_infer_time")
with dpg.group(horizontal=True):
dpg.add_text("SPP: ")
dpg.add_text("1", tag="_log_spp")
# train button
if not self.opt.test:
with dpg.collapsing_header(label="Train", default_open=True):
# train / stop
with dpg.group(horizontal=True):
dpg.add_text("Train: ")
def callback_train(sender, app_data):
if self.training:
self.training = False
dpg.configure_item("_button_train", label="start")
else:
self.training = True
dpg.configure_item("_button_train", label="stop")
dpg.add_button(label="start", tag="_button_train", callback=callback_train)
dpg.bind_item_theme("_button_train", theme_button)
def callback_reset(sender, app_data):
@torch.no_grad()
def weight_reset(m: nn.Module):
reset_parameters = getattr(m, "reset_parameters", None)
if callable(reset_parameters):
m.reset_parameters()
self.trainer.model.apply(fn=weight_reset)
self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
self.need_update = True
dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
dpg.bind_item_theme("_button_reset", theme_button)
# save ckpt
with dpg.group(horizontal=True):
dpg.add_text("Checkpoint: ")
def callback_save(sender, app_data):
self.trainer.save_checkpoint(full=True, best=False)
dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
self.trainer.epoch += 1 # use epoch to indicate different calls.
dpg.add_button(label="save", tag="_button_save", callback=callback_save)
dpg.bind_item_theme("_button_save", theme_button)
dpg.add_text("", tag="_log_ckpt")
# save mesh
with dpg.group(horizontal=True):
dpg.add_text("Marching Cubes: ")
def callback_mesh(sender, app_data):
self.trainer.save_mesh(resolution=256, threshold=10)
dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
self.trainer.epoch += 1 # use epoch to indicate different calls.
dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
dpg.bind_item_theme("_button_mesh", theme_button)
dpg.add_text("", tag="_log_mesh")
with dpg.group(horizontal=True):
dpg.add_text("", tag="_log_train_log")
# rendering options
with dpg.collapsing_header(label="Options", default_open=True):
# playing
with dpg.group(horizontal=True):
dpg.add_text("Play: ")
def callback_play(sender, app_data):
if self.playing:
self.playing = False
dpg.configure_item("_button_play", label="start")
else:
self.playing = True
dpg.configure_item("_button_play", label="stop")
if self.opt.asr:
self.asr.warm_up()
self.need_update = True
dpg.add_button(label="start", tag="_button_play", callback=callback_play)
dpg.bind_item_theme("_button_play", theme_button)
# set asr
if self.opt.asr:
# clear queue button
def callback_clear_queue(sender, app_data):
self.asr.clear_queue()
self.need_update = True
dpg.add_button(label="clear", tag="_button_clear_queue", callback=callback_clear_queue)
dpg.bind_item_theme("_button_clear_queue", theme_button)
# dynamic rendering resolution
with dpg.group(horizontal=True):
def callback_set_dynamic_resolution(sender, app_data):
if self.dynamic_resolution:
self.dynamic_resolution = False
self.downscale = 1
else:
self.dynamic_resolution = True
self.need_update = True
# Disable dynamic resolution for face.
# dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")
# mode combo
def callback_change_mode(sender, app_data):
self.mode = app_data
self.need_update = True
dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)
# bg_color picker
def callback_change_bg(sender, app_data):
self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
self.need_update = True
dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)
# audio index slider
if not self.opt.asr:
def callback_set_audio_index(sender, app_data):
self.audio_idx = app_data
self.need_update = True
dpg.add_slider_int(label="Audio", min_value=0, max_value=self.audio_features.shape[0] - 1, format="%d", default_value=self.audio_idx, callback=callback_set_audio_index)
# ind code index slider
if self.opt.ind_dim > 0:
def callback_set_individual_code(sender, app_data):
self.ind_index = app_data
self.need_update = True
dpg.add_slider_int(label="Individual", min_value=0, max_value=self.ind_num - 1, format="%d", default_value=self.ind_index, callback=callback_set_individual_code)
# eye area slider
if self.opt.exp_eye:
def callback_set_eye(sender, app_data):
self.eye_area = app_data
self.need_update = True
dpg.add_slider_float(label="eye area", min_value=0, max_value=0.5, format="%.2f percent", default_value=self.eye_area, callback=callback_set_eye)
# fov slider
def callback_set_fovy(sender, app_data):
self.cam.fovy = app_data
self.need_update = True
dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)
# dt_gamma slider
def callback_set_dt_gamma(sender, app_data):
self.opt.dt_gamma = app_data
self.need_update = True
dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)
# max_steps slider
def callback_set_max_steps(sender, app_data):
self.opt.max_steps = app_data
self.need_update = True
dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)
# aabb slider
def callback_set_aabb(sender, app_data, user_data):
# user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
self.trainer.model.aabb_infer[user_data] = app_data
# also change train aabb ? [better not...]
#self.trainer.model.aabb_train[user_data] = app_data
self.need_update = True
dpg.add_separator()
dpg.add_text("Axis-aligned bounding box:")
with dpg.group(horizontal=True):
dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)
with dpg.group(horizontal=True):
dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)
with dpg.group(horizontal=True):
dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)
# debug info
if self.debug:
with dpg.collapsing_header(label="Debug"):
# pose
dpg.add_separator()
dpg.add_text("Camera Pose:")
dpg.add_text(str(self.cam.pose), tag="_log_pose")
### register camera handler
def callback_camera_drag_rotate(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
dx = app_data[1]
dy = app_data[2]
self.cam.orbit(dx, dy)
self.need_update = True
if self.debug:
dpg.set_value("_log_pose", str(self.cam.pose))
def callback_camera_wheel_scale(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
delta = app_data
self.cam.scale(delta)
self.need_update = True
if self.debug:
dpg.set_value("_log_pose", str(self.cam.pose))
def callback_camera_drag_pan(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
dx = app_data[1]
dy = app_data[2]
self.cam.pan(dx, dy)
self.need_update = True
if self.debug:
dpg.set_value("_log_pose", str(self.cam.pose))
with dpg.handler_registry():
dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)
dpg.create_viewport(title='RAD-NeRF', width=1080, height=720, resizable=True)
### global theme
with dpg.theme() as theme_no_padding:
with dpg.theme_component(dpg.mvAll):
# set all padding to 0 to avoid scroll bar
dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
dpg.bind_item_theme("_primary_window", theme_no_padding)
dpg.setup_dearpygui()
#dpg.show_metrics()
dpg.show_viewport()
def render(self):
while dpg.is_dearpygui_running():
# update texture every frame
if self.training:
self.train_step()
# audio stream thread...
if self.opt.asr and self.playing:
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
for _ in range(2):
self.asr.run_step()
self.test_step()
dpg.render_dearpygui_frame()
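# ---- hedged usage sketch (commented out, not part of the original file) ----
# NeRFGUI is normally driven by the project's entry point; a minimal sketch,
# assuming an opt namespace, a trainer and a data loader already exist:
#   with NeRFGUI(opt, trainer, data_loader) as gui:
#       gui.render()
# The context-manager form matters: __exit__ stops the live ASR (if --asr is
# enabled) and destroys the dearpygui context.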

352
nerf_triplane/network.py Normal file

@ -0,0 +1,352 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoding import get_encoder
from .renderer import NeRFRenderer
# Audio feature extractor
class AudioAttNet(nn.Module):
def __init__(self, dim_aud=64, seq_len=8):
super(AudioAttNet, self).__init__()
self.seq_len = seq_len
self.dim_aud = dim_aud
self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len
nn.Conv1d(self.dim_aud, 16, kernel_size=3, stride=1, padding=1, bias=True),
nn.LeakyReLU(0.02, True),
nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True),
nn.LeakyReLU(0.02, True),
nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True),
nn.LeakyReLU(0.02, True),
nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.LeakyReLU(0.02, True),
nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True),
nn.LeakyReLU(0.02, True)
)
self.attentionNet = nn.Sequential(
nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True),
nn.Softmax(dim=1)
)
def forward(self, x):
# x: [1, seq_len, dim_aud]
y = x.permute(0, 2, 1) # [1, dim_aud, seq_len]
y = self.attentionConvNet(y)
y = self.attentionNet(y.view(1, self.seq_len)).view(1, self.seq_len, 1)
return torch.sum(y * x, dim=1) # [1, dim_aud]
# Audio feature extractor
class AudioNet(nn.Module):
def __init__(self, dim_in=29, dim_aud=64, win_size=16):
super(AudioNet, self).__init__()
self.win_size = win_size
self.dim_aud = dim_aud
self.encoder_conv = nn.Sequential( # n x 29 x 16
nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 8
nn.LeakyReLU(0.02, True),
nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 4
nn.LeakyReLU(0.02, True),
nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 2
nn.LeakyReLU(0.02, True),
nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 1
nn.LeakyReLU(0.02, True),
)
self.encoder_fc1 = nn.Sequential(
nn.Linear(64, 64),
nn.LeakyReLU(0.02, True),
nn.Linear(64, dim_aud),
)
def forward(self, x):
half_w = int(self.win_size/2)
x = x[:, :, 8-half_w:8+half_w]
x = self.encoder_conv(x).squeeze(-1)
x = self.encoder_fc1(x)
return x
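# ---- hedged shape check (commented out; dummy tensors, illustrative only) ----
# With the default dims above, a window of 8 DeepSpeech-style frames
# [8, 29, 16] is squeezed to one code per frame and then attention-pooled:
#   a = torch.randn(8, 29, 16)
#   audio_net = AudioNet(dim_in=29, dim_aud=64)
#   att_net = AudioAttNet(dim_aud=64, seq_len=8)
#   enc = audio_net(a)                # [8, 64]
#   enc = att_net(enc.unsqueeze(0))   # [1, 64], weighted sum over the 8 frames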
class MLP(nn.Module):
def __init__(self, dim_in, dim_out, dim_hidden, num_layers):
super().__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.dim_hidden = dim_hidden
self.num_layers = num_layers
net = []
for l in range(num_layers):
net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False))
self.net = nn.ModuleList(net)
def forward(self, x):
for l in range(self.num_layers):
x = self.net[l](x)
if l != self.num_layers - 1:
x = F.relu(x, inplace=True)
# x = F.dropout(x, p=0.1, training=self.training)
return x
class NeRFNetwork(NeRFRenderer):
def __init__(self,
opt,
# torso net (hard coded for now)
):
super().__init__(opt)
# audio embedding
self.emb = self.opt.emb
if 'esperanto' in self.opt.asr_model:
self.audio_in_dim = 44
elif 'deepspeech' in self.opt.asr_model:
self.audio_in_dim = 29
else:
self.audio_in_dim = 32
if self.emb:
self.embedding = nn.Embedding(self.audio_in_dim, self.audio_in_dim)
# audio network
audio_dim = 32
self.audio_dim = audio_dim
self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim)
self.att = self.opt.att
if self.att > 0:
self.audio_att_net = AudioAttNet(self.audio_dim)
# DYNAMIC PART
self.num_levels = 12
self.level_dim = 1
self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz
## sigma network
self.num_layers = 3
self.hidden_dim = 64
self.geo_feat_dim = 64
self.eye_att_net = MLP(self.in_dim, 1, 16, 2)
self.eye_dim = 1 if self.exp_eye else 0
self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers)
## color network
self.num_layers_color = 2
self.hidden_dim_color = 64
self.encoder_dir, self.in_dim_dir = get_encoder('spherical_harmonics')
self.color_net = MLP(self.in_dim_dir + self.geo_feat_dim + self.individual_dim, 3, self.hidden_dim_color, self.num_layers_color)
# audio processing heads
self.unc_net = MLP(self.in_dim, 1, 32, 2)
self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 64, 2)
self.testing = False
if self.torso:
# torso deform network
self.register_parameter('anchor_points',
nn.Parameter(torch.tensor([[0.01, 0.01, 0.1, 1], [-0.1, -0.1, 0.1, 1], [0.1, -0.1, 0.1, 1]])))
self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('frequency', input_dim=2, multires=8)
# self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=1, base_resolution=16, log2_hashmap_size=16, desired_resolution=512)
self.anchor_encoder, self.anchor_in_dim = get_encoder('frequency', input_dim=6, multires=3)
self.torso_deform_net = MLP(self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 2, 32, 3)
# torso color network
self.torso_encoder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048)
self.torso_net = MLP(self.torso_in_dim + self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 4, 32, 3)
def forward_torso(self, x, poses, c=None):
# x: [N, 2] in [-1, 1]
# head poses: [1, 4, 4]
# c: [1, ind_dim], individual code
# test: shrink x
x = x * self.opt.torso_shrink
# the pose is adjusted here
# deformation-based
wrapped_anchor = self.anchor_points[None, ...] @ poses.permute(0, 2, 1).inverse()
wrapped_anchor = (wrapped_anchor[:, :, :2] / wrapped_anchor[:, :, 3, None] / wrapped_anchor[:, :, 2, None]).view(1, -1)
# print(wrapped_anchor)
# enc_pose = self.pose_encoder(poses)
enc_anchor = self.anchor_encoder(wrapped_anchor)
enc_x = self.torso_deform_encoder(x)
if c is not None:
h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1)
else:
h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1)], dim=-1)
dx = self.torso_deform_net(h)
x = (x + dx).clamp(-1, 1)
x = self.torso_encoder(x, bound=1)
# h = torch.cat([x, h, enc_a.repeat(x.shape[0], 1)], dim=-1)
h = torch.cat([x, h], dim=-1)
h = self.torso_net(h)
alpha = torch.sigmoid(h[..., :1])*(1 + 2*0.001) - 0.001
color = torch.sigmoid(h[..., 1:])*(1 + 2*0.001) - 0.001
return alpha, color, dx
@staticmethod
@torch.jit.script
def split_xyz(x):
xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:,:1], x[:,-1:]], dim=-1)
return xy, yz, xz
def encode_x(self, xyz, bound):
# x: [N, 3], in [-bound, bound]
N, M = xyz.shape
xy, yz, xz = self.split_xyz(xyz)
feat_xy = self.encoder_xy(xy, bound=bound)
feat_yz = self.encoder_yz(yz, bound=bound)
feat_xz = self.encoder_xz(xz, bound=bound)
return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1)
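# ---- hedged note (comments only) ----
# Tri-plane trick: a 3D point (x, y, z) is projected onto the three axis-aligned
# planes (x, y), (y, z) and (x, z); each 2D coordinate is hash-encoded separately
# and the three feature vectors are concatenated. With num_levels=12 and
# level_dim=1 as set in __init__, each plane contributes 12 features, so
# encode_x returns [N, 3 * 12 = 36].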
def encode_audio(self, a):
# a: [1, 29, 16] or [8, 29, 16], audio features from deepspeech
# if emb, a should be: [1, 16] or [8, 16]
# fix audio training
if a is None: return None
if self.emb:
a = self.embedding(a).transpose(-1, -2).contiguous() # [1/8, 29, 16]
enc_a = self.audio_net(a) # [1/8, 64]
if self.att > 0:
enc_a = self.audio_att_net(enc_a.unsqueeze(0)) # [1, 64]
return enc_a
def predict_uncertainty(self, unc_inp):
if self.testing or not self.opt.unc_loss:
unc = torch.zeros_like(unc_inp)
else:
unc = self.unc_net(unc_inp.detach())
return unc
def forward(self, x, d, enc_a, c, e=None):
# x: [N, 3], in [-bound, bound]
# d: [N, 3], normalized in [-1, 1]
# enc_a: [1, aud_dim]
# c: [1, ind_dim], individual code
# e: [1, 1], eye feature
enc_x = self.encode_x(x, bound=self.bound)
sigma_result = self.density(x, enc_a, e, enc_x)
sigma = sigma_result['sigma']
geo_feat = sigma_result['geo_feat']
aud_ch_att = sigma_result['ambient_aud']
eye_att = sigma_result['ambient_eye']
# color
enc_d = self.encoder_dir(d)
if c is not None:
h = torch.cat([enc_d, geo_feat, c.repeat(x.shape[0], 1)], dim=-1)
else:
h = torch.cat([enc_d, geo_feat], dim=-1)
h_color = self.color_net(h)
color = torch.sigmoid(h_color)*(1 + 2*0.001) - 0.001
uncertainty = self.predict_uncertainty(enc_x)
uncertainty = torch.log(1 + torch.exp(uncertainty))
return sigma, color, aud_ch_att, eye_att, uncertainty[..., None]
def density(self, x, enc_a, e=None, enc_x=None):
# x: [N, 3], in [-bound, bound]
if enc_x is None:
enc_x = self.encode_x(x, bound=self.bound)
enc_a = enc_a.repeat(enc_x.shape[0], 1)
aud_ch_att = self.aud_ch_att_net(enc_x)
enc_w = enc_a * aud_ch_att
if e is not None:
# e = self.encoder_eye(e)
eye_att = torch.sigmoid(self.eye_att_net(enc_x))
e = e * eye_att
# e = e.repeat(enc_x.shape[0], 1)
h = torch.cat([enc_x, enc_w, e], dim=-1)
else:
h = torch.cat([enc_x, enc_w], dim=-1)
h = self.sigma_net(h)
sigma = torch.exp(h[..., 0])
geo_feat = h[..., 1:]
return {
'sigma': sigma,
'geo_feat': geo_feat,
'ambient_aud' : aud_ch_att.norm(dim=-1, keepdim=True),
'ambient_eye' : eye_att,
}
# optimizer utils
def get_params(self, lr, lr_net, wd=0):
# ONLY train torso
if self.torso:
params = [
{'params': self.torso_encoder.parameters(), 'lr': lr},
{'params': self.torso_deform_encoder.parameters(), 'lr': lr, 'weight_decay': wd},
{'params': self.torso_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
{'params': self.torso_deform_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
{'params': self.anchor_points, 'lr': lr_net, 'weight_decay': wd}
]
if self.individual_dim_torso > 0:
params.append({'params': self.individual_codes_torso, 'lr': lr_net, 'weight_decay': wd})
return params
params = [
{'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
{'params': self.encoder_xy.parameters(), 'lr': lr},
{'params': self.encoder_yz.parameters(), 'lr': lr},
{'params': self.encoder_xz.parameters(), 'lr': lr},
# {'params': self.encoder_xyz.parameters(), 'lr': lr},
{'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
{'params': self.color_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
]
if self.att > 0:
params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001})
if self.emb:
params.append({'params': self.embedding.parameters(), 'lr': lr})
if self.individual_dim > 0:
params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd})
if self.train_camera:
params.append({'params': self.camera_dT, 'lr': 1e-5, 'weight_decay': 0})
params.append({'params': self.camera_dR, 'lr': 1e-5, 'weight_decay': 0})
params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
params.append({'params': self.unc_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
params.append({'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
return params
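# ---- hedged usage sketch (commented out; learning rates are illustrative) ----
# The grouped parameters above are meant to be handed straight to a torch
# optimizer, e.g.:
#   model = NeRFNetwork(opt)
#   optimizer = torch.optim.Adam(model.get_params(lr=1e-2, lr_net=1e-3),
#                                betas=(0.9, 0.99), eps=1e-15)
# Hash-grid encoders get the larger lr, the MLPs get lr_net, and the audio
# attention net gets lr_net * 5 as coded above.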

732
nerf_triplane/provider.py Normal file

@ -0,0 +1,732 @@
import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import matplotlib.pyplot as plt
import trimesh
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from .utils import get_audio_features, get_rays, get_bg_coords, convert_poses
# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]):
new_pose = np.array([
[pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
[pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
[pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
[0, 0, 0, 1],
], dtype=np.float32)
return new_pose
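# ---- hedged worked example (comments only) ----
# nerf_matrix_to_ngp permutes the axes into the instant-ngp convention and
# applies scale/offset to the translation, e.g. (values illustrative):
#   pose = np.eye(4, dtype=np.float32); pose[:3, 3] = [1, 2, 3]
#   nerf_matrix_to_ngp(pose, scale=0.33)[:3, 3]    # -> [0.66, 0.99, 0.33]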
def smooth_camera_path(poses, kernel_size=5):
# smooth the camera trajectory...
# poses: [N, 4, 4], numpy array
N = poses.shape[0]
K = kernel_size // 2
trans = poses[:, :3, 3].copy() # [N, 3]
rots = poses[:, :3, :3].copy() # [N, 3, 3]
for i in range(N):
start = max(0, i - K)
end = min(N, i + K + 1)
poses[i, :3, 3] = trans[start:end].mean(0)
poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
return poses
def polygon_area(x, y):
x_ = x - x.mean()
y_ = y - y.mean()
correction = x_[-1] * y_[0] - y_[-1]* x_[0]
main_area = np.dot(x_[:-1], y_[1:]) - np.dot(y_[:-1], x_[1:])
return 0.5 * np.abs(main_area + correction)
def visualize_poses(poses, size=0.1):
# poses: [B, 4, 4]
print(f'[INFO] visualize poses: {poses.shape}')
axes = trimesh.creation.axis(axis_length=4)
box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
box.colors = np.array([[128, 128, 128]] * len(box.entities))
objects = [axes, box]
for pose in poses:
# a camera is visualized with 8 line segments.
pos = pose[:3, 3]
a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
dir = (a + b + c + d) / 4 - pos
dir = dir / (np.linalg.norm(dir) + 1e-8)
o = pos + dir * 3
segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]])
segs = trimesh.load_path(segs)
objects.append(segs)
trimesh.Scene(objects).show()
class NeRFDataset_Test:
def __init__(self, opt, device, downscale=1):
super().__init__()
self.opt = opt
self.device = device
self.downscale = downscale
self.scale = opt.scale # camera radius scale to make sure the camera is inside the bounding box.
self.offset = opt.offset # camera offset
self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses.
self.fp16 = opt.fp16
self.start_index = opt.data_range[0]
self.end_index = opt.data_range[1]
self.training = False
self.num_rays = -1
# load nerf-compatible format data.
with open(opt.pose, 'r') as f:
transform = json.load(f)
# load image size
self.H = int(transform['cy']) * 2 // downscale
self.W = int(transform['cx']) * 2 // downscale
# read images
frames = transform["frames"]
# use a slice of the dataset
if self.end_index == -1: # abuse...
self.end_index = len(frames)
frames = frames[self.start_index:self.end_index]
print(f'[INFO] load {len(frames)} frames.')
# only load pre-calculated aud features when not live-streaming
if not self.opt.asr:
aud_features = np.load(self.opt.aud)
aud_features = torch.from_numpy(aud_features)
# support both [N, 16] labels and [N, 16, K] logits
if len(aud_features.shape) == 3:
aud_features = aud_features.float().permute(0, 2, 1) # [N, 16, 29] --> [N, 29, 16]
if self.opt.emb:
print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
aud_features = aud_features.argmax(1) # [N, 16]
else:
assert self.opt.emb, "aud only provide labels, must use --emb"
aud_features = aud_features.long()
print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')
self.poses = []
self.auds = []
self.eye_area = []
for f in tqdm.tqdm(frames, desc=f'Loading data'):
pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4]
pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
self.poses.append(pose)
# find the corresponding audio to the image frame
if not self.opt.asr and self.opt.aud == '':
aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)] # careful for the last frame...
self.auds.append(aud)
if self.opt.exp_eye:
if 'eye_ratio' in f:
area = f['eye_ratio']
else:
area = 0.25 # default value for opened eye
self.eye_area.append(area)
# load pre-extracted background image (should be the same size as training image...)
if self.opt.bg_img == 'white': # special
bg_img = np.ones((self.H, self.W, 3), dtype=np.float32)
elif self.opt.bg_img == 'black': # special
bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32)
else: # load from file
bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED) # [H, W, 3]
if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
bg_img = bg_img.astype(np.float32) / 255 # [H, W, 3/4]
self.bg_img = bg_img
self.poses = np.stack(self.poses, axis=0)
# smooth camera path...
if self.opt.smooth_path:
self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)
self.poses = torch.from_numpy(self.poses) # [N, 4, 4]
if self.opt.asr:
# live streaming, no pre-calculated auds
self.auds = None
else:
# auds corresponding to images
if self.opt.aud == '':
self.auds = torch.stack(self.auds, dim=0) # [N, 32, 16]
# auds are novel and may have a different length than the images
else:
self.auds = aud_features
self.bg_img = torch.from_numpy(self.bg_img)
if self.opt.exp_eye:
self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N]
print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')
if self.opt.smooth_eye:
# naive 5 window average
ori_eye = self.eye_area.copy()
for i in range(ori_eye.shape[0]):
start = max(0, i - 1)
end = min(ori_eye.shape[0], i + 2)
self.eye_area[i] = ori_eye[start:end].mean()
self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1) # [N, 1]
# always preload
self.poses = self.poses.to(self.device)
if self.auds is not None:
self.auds = self.auds.to(self.device)
self.bg_img = self.bg_img.to(torch.half).to(self.device)
if self.opt.exp_eye:
self.eye_area = self.eye_area.to(self.device)
# load intrinsics
fl_x = fl_y = transform['focal_len']
cx = (transform['cx'] / downscale)
cy = (transform['cy'] / downscale)
self.intrinsics = np.array([fl_x, fl_y, cx, cy])
# directly build the coordinate meshgrid in [-1, 1]^2
self.bg_coords = get_bg_coords(self.H, self.W, self.device) # [1, H*W, 2] in [-1, 1]
def mirror_index(self, index):
size = self.poses.shape[0]
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
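# ---- hedged worked example (comments only) ----
# mirror_index replays the pose sequence back and forth. With 4 poses,
# indices 0..7 map to 0, 1, 2, 3, 3, 2, 1, 0, so the "2 * len(poses)" test
# loader below sweeps forward and then backward without a jump cut.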
def collate(self, index):
B = len(index) # a list of length 1
# assert B == 1
results = {}
# audio use the original index
if self.auds is not None:
auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
results['auds'] = auds
# head pose and bg image may mirror (replay --> <-- --> <--).
index[0] = self.mirror_index(index[0])
poses = self.poses[index].to(self.device) # [B, 4, 4]
rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)
results['index'] = index # for ind. code
results['H'] = self.H
results['W'] = self.W
results['rays_o'] = rays['rays_o']
results['rays_d'] = rays['rays_d']
if self.opt.exp_eye:
results['eye'] = self.eye_area[index].to(self.device) # [1]
else:
results['eye'] = None
bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)
results['bg_color'] = bg_img
bg_coords = self.bg_coords # [1, N, 2]
results['bg_coords'] = bg_coords
# results['poses'] = convert_poses(poses) # [B, 6]
# results['poses_matrix'] = poses # [B, 4, 4]
results['poses'] = poses # [B, 4, 4]
return results
def dataloader(self):
# test with novel auds, then use its length
if self.auds is not None:
size = self.auds.shape[0]
# live stream test, use 2 * len(poses), so it naturally mirrors.
else:
size = 2 * self.poses.shape[0]
loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=False, num_workers=0)
loader._data = self # an ugly fix... we need poses in trainer.
# only evaluate when GT images exist and the self-driven setting is used
loader.has_gt = False
return loader
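# ---- hedged usage sketch (commented out, not part of the original file) ----
# NeRFDataset_Test is consumed through its dataloader; a minimal sketch,
# assuming opt carries the fields read in __init__ (pose, aud, bg_img, ...):
#   dataset = NeRFDataset_Test(opt, device=torch.device('cuda'))
#   loader = dataset.dataloader()
#   for data in loader:
#       pass  # data['rays_o'], data['rays_d'], data['auds'], data['bg_color'], ...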
class NeRFDataset:
def __init__(self, opt, device, type='train', downscale=1):
super().__init__()
self.opt = opt
self.device = device
self.type = type # train, val, test
self.downscale = downscale
self.root_path = opt.path
self.preload = opt.preload # 0 = disk, 1 = cpu, 2 = gpu
self.scale = opt.scale # camera radius scale to make sure the camera is inside the bounding box.
self.offset = opt.offset # camera offset
self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses.
self.fp16 = opt.fp16
self.start_index = opt.data_range[0]
self.end_index = opt.data_range[1]
self.training = self.type in ['train', 'all', 'trainval']
self.num_rays = self.opt.num_rays if self.training else -1
# load nerf-compatible format data.
with open(opt.pose, 'r') as f:
transform = json.load(f)
# load image size
if 'h' in transform and 'w' in transform:
self.H = int(transform['h']) // downscale
self.W = int(transform['w']) // downscale
else:
self.H = int(transform['cy']) * 2 // downscale
self.W = int(transform['cx']) * 2 // downscale
# read images
frames = transform["frames"]
# use a slice of the dataset
if self.end_index == -1: # abuse...
self.end_index = len(frames)
frames = frames[self.start_index:self.end_index]
print(f'[INFO] load {len(frames)} {type} frames.')
# only load pre-calculated aud features when not live-streaming
if not self.opt.asr:
# empty means the default self-driven extracted features.
if self.opt.aud == '':
if 'esperanto' in self.opt.asr_model:
aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy'))
elif 'deepspeech' in self.opt.asr_model:
aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy'))
else:
aud_features = np.load(os.path.join(self.root_path, 'aud.npy'))
# cross-driven extracted features.
else:
aud_features = np.load(self.opt.aud)
aud_features = torch.from_numpy(aud_features)
# support both [N, 16] labels and [N, 16, K] logits
if len(aud_features.shape) == 3:
aud_features = aud_features.float().permute(0, 2, 1) # [N, 16, 29] --> [N, 29, 16]
if self.opt.emb:
print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
aud_features = aud_features.argmax(1) # [N, 16]
else:
assert self.opt.emb, "aud only provide labels, must use --emb"
aud_features = aud_features.long()
print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')
# load action units
import pandas as pd
au_blink_info=pd.read_csv(os.path.join(self.root_path, 'au.csv'))
au_blink = au_blink_info[' AU45_r'].values
self.torso_img = []
self.images = []
self.poses = []
self.exps = []
self.auds = []
self.face_rect = []
self.lhalf_rect = []
self.lips_rect = []
self.eye_area = []
self.eye_rect = []
for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):
f_path = os.path.join(self.root_path, 'gt_imgs', str(f['img_id']) + '.jpg')
if not os.path.exists(f_path):
print('[WARN]', f_path, 'NOT FOUND!')
continue
pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4]
pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
self.poses.append(pose)
if self.preload > 0:
image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] or [H, W, 4]
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.astype(np.float32) / 255 # [H, W, 3/4]
self.images.append(image)
else:
self.images.append(f_path)
# load frame-wise bg
torso_img_path = os.path.join(self.root_path, 'torso_imgs', str(f['img_id']) + '.png')
if self.preload > 0:
torso_img = cv2.imread(torso_img_path, cv2.IMREAD_UNCHANGED) # [H, W, 4]
torso_img = cv2.cvtColor(torso_img, cv2.COLOR_BGRA2RGBA)
torso_img = torso_img.astype(np.float32) / 255 # [H, W, 3/4]
self.torso_img.append(torso_img)
else:
self.torso_img.append(torso_img_path)
# find the corresponding audio to the image frame
if not self.opt.asr and self.opt.aud == '':
aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)] # careful for the last frame...
self.auds.append(aud)
# load lms and extract face
lms = np.loadtxt(os.path.join(self.root_path, 'ori_imgs', str(f['img_id']) + '.lms')) # [68, 2]
lh_xmin, lh_xmax = int(lms[31:36, 1].min()), int(lms[:, 1].max()) # actually lower half area
xmin, xmax = int(lms[:, 1].min()), int(lms[:, 1].max())
ymin, ymax = int(lms[:, 0].min()), int(lms[:, 0].max())
self.face_rect.append([xmin, xmax, ymin, ymax])
self.lhalf_rect.append([lh_xmin, lh_xmax, ymin, ymax])
if self.opt.exp_eye:
# eyes_left = slice(36, 42)
# eyes_right = slice(42, 48)
# area_left = polygon_area(lms[eyes_left, 0], lms[eyes_left, 1])
# area_right = polygon_area(lms[eyes_right, 0], lms[eyes_right, 1])
# # area percentage of two eyes of the whole image...
# area = (area_left + area_right) / (self.H * self.W) * 100
# action units blink AU45
area = au_blink[f['img_id']]
area = np.clip(area, 0, 2) / 2
# area = area + np.random.rand() / 10
self.eye_area.append(area)
xmin, xmax = int(lms[36:48, 1].min()), int(lms[36:48, 1].max())
ymin, ymax = int(lms[36:48, 0].min()), int(lms[36:48, 0].max())
self.eye_rect.append([xmin, xmax, ymin, ymax])
if self.opt.finetune_lips:
lips = slice(48, 60)
xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
# padding to H == W
cx = (xmin + xmax) // 2
cy = (ymin + ymax) // 2
l = max(xmax - xmin, ymax - ymin) // 2
xmin = max(0, cx - l)
xmax = min(self.H, cx + l)
ymin = max(0, cy - l)
ymax = min(self.W, cy + l)
self.lips_rect.append([xmin, xmax, ymin, ymax])
# load pre-extracted background image (should be the same size as training image...)
if self.opt.bg_img == 'white': # special
bg_img = np.ones((self.H, self.W, 3), dtype=np.float32)
elif self.opt.bg_img == 'black': # special
bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32)
else: # load from file
# default bg
if self.opt.bg_img == '':
self.opt.bg_img = os.path.join(self.root_path, 'bc.jpg')
bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED) # [H, W, 3]
if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
bg_img = bg_img.astype(np.float32) / 255 # [H, W, 3/4]
self.bg_img = bg_img
self.poses = np.stack(self.poses, axis=0)
# smooth camera path...
if self.opt.smooth_path:
self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)
self.poses = torch.from_numpy(self.poses) # [N, 4, 4]
if self.preload > 0:
self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C]
self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0)) # [N, H, W, C]
else:
self.images = np.array(self.images)
self.torso_img = np.array(self.torso_img)
if self.opt.asr:
# live streaming, no pre-calculated auds
self.auds = None
else:
# auds corresponding to images
if self.opt.aud == '':
self.auds = torch.stack(self.auds, dim=0) # [N, 32, 16]
# auds are novel and may have a different length than the images
else:
self.auds = aud_features
self.bg_img = torch.from_numpy(self.bg_img)
if self.opt.exp_eye:
self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N]
print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')
if self.opt.smooth_eye:
# naive 5 window average
ori_eye = self.eye_area.copy()
for i in range(ori_eye.shape[0]):
start = max(0, i - 1)
end = min(ori_eye.shape[0], i + 2)
self.eye_area[i] = ori_eye[start:end].mean()
self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1) # [N, 1]
# calculate mean radius of all camera poses
self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
#print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}')
# [debug] uncomment to view all training poses.
# visualize_poses(self.poses.numpy())
# [debug] uncomment to view examples of randomly generated poses.
# visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy())
if self.preload > 1:
self.poses = self.poses.to(self.device)
if self.auds is not None:
self.auds = self.auds.to(self.device)
self.bg_img = self.bg_img.to(torch.half).to(self.device)
self.torso_img = self.torso_img.to(torch.half).to(self.device)
self.images = self.images.to(torch.half).to(self.device)
if self.opt.exp_eye:
self.eye_area = self.eye_area.to(self.device)
# load intrinsics
if 'focal_len' in transform:
fl_x = fl_y = transform['focal_len']
elif 'fl_x' in transform or 'fl_y' in transform:
fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
# blender, assumed to be in radians. already downscaled since we use H/W
fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
if fl_x is None: fl_x = fl_y
if fl_y is None: fl_y = fl_x
else:
raise RuntimeError('Failed to load focal length, please check the transforms.json!')
cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)
self.intrinsics = np.array([fl_x, fl_y, cx, cy])
# directly build the coordinate meshgrid in [-1, 1]^2
self.bg_coords = get_bg_coords(self.H, self.W, self.device) # [1, H*W, 2] in [-1, 1]
def mirror_index(self, index):
size = self.poses.shape[0]
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
def collate(self, index):
B = len(index) # a list of length 1
# assert B == 1
results = {}
# audio use the original index
if self.auds is not None:
auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
results['auds'] = auds
# head pose and bg image may mirror (replay --> <-- --> <--).
index[0] = self.mirror_index(index[0])
poses = self.poses[index].to(self.device) # [B, 4, 4]
if self.training and self.opt.finetune_lips:
rect = self.lips_rect[index[0]]
results['rect'] = rect
rays = get_rays(poses, self.intrinsics, self.H, self.W, -1, rect=rect)
else:
rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)
results['index'] = index # for ind. code
results['H'] = self.H
results['W'] = self.W
results['rays_o'] = rays['rays_o']
results['rays_d'] = rays['rays_d']
# get a mask for rays inside rect_face
if self.training:
xmin, xmax, ymin, ymax = self.face_rect[index[0]]
face_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
results['face_mask'] = face_mask
xmin, xmax, ymin, ymax = self.lhalf_rect[index[0]]
lhalf_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
results['lhalf_mask'] = lhalf_mask
if self.opt.exp_eye:
results['eye'] = self.eye_area[index].to(self.device) # [1]
if self.training:
results['eye'] += (np.random.rand()-0.5) / 10
xmin, xmax, ymin, ymax = self.eye_rect[index[0]]
eye_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
results['eye_mask'] = eye_mask
else:
results['eye'] = None
# load bg
bg_torso_img = self.torso_img[index]
if self.preload == 0: # on the fly loading
bg_torso_img = cv2.imread(bg_torso_img[0], cv2.IMREAD_UNCHANGED) # [H, W, 4]
bg_torso_img = cv2.cvtColor(bg_torso_img, cv2.COLOR_BGRA2RGBA)
bg_torso_img = bg_torso_img.astype(np.float32) / 255 # [H, W, 3/4]
bg_torso_img = torch.from_numpy(bg_torso_img).unsqueeze(0)
bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:])
bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device)
if not self.opt.torso:
bg_img = bg_torso_img
else:
bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)
if self.training:
bg_img = torch.gather(bg_img, 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3]
results['bg_color'] = bg_img
if self.opt.torso and self.training:
bg_torso_img = torch.gather(bg_torso_img, 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3]
results['bg_torso_color'] = bg_torso_img
images = self.images[index] # [B, H, W, 3/4]
if self.preload == 0:
images = cv2.imread(images[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
images = images.astype(np.float32) / 255 # [H, W, 3]
images = torch.from_numpy(images).unsqueeze(0)
images = images.to(self.device)
if self.training:
C = images.shape[-1]
images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
results['images'] = images
if self.training:
bg_coords = torch.gather(self.bg_coords, 1, torch.stack(2 * [rays['inds']], -1)) # [1, N, 2]
else:
bg_coords = self.bg_coords # [1, N, 2]
results['bg_coords'] = bg_coords
# results['poses'] = convert_poses(poses) # [B, 6]
# results['poses_matrix'] = poses # [B, 4, 4]
results['poses'] = poses # [B, 4, 4]
return results
def dataloader(self):
if self.training:
# training len(poses) == len(auds)
size = self.poses.shape[0]
else:
# test with novel auds, then use its length
if self.auds is not None:
size = self.auds.shape[0]
# live stream test, use 2 * len(poses), so it naturally mirrors.
else:
size = 2 * self.poses.shape[0]
loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
loader._data = self # an ugly fix... we need poses in trainer.
# only evaluate when GT images exist and the self-driven setting is used
loader.has_gt = (self.opt.aud == '')
return loader

700
nerf_triplane/renderer.py Normal file
View File

@ -0,0 +1,700 @@
import math
import trimesh
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import raymarching
from .utils import custom_meshgrid, get_audio_features, euler_angles_to_matrix, convert_poses
def sample_pdf(bins, weights, n_samples, det=False):
# This implementation is from NeRF
# bins: [B, T], old_z_vals
# weights: [B, T - 1], bin weights.
# return: [B, n_samples], new_z_vals
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
# Take uniform samples
if det:
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
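# ---- hedged usage sketch (commented out; dummy tensors) ----
# sample_pdf is the inverse-CDF (hierarchical) sampling step from NeRF:
#   bins = torch.linspace(0, 1, 65).expand(4, 65)      # [B=4, T=65] coarse z values
#   weights = torch.rand(4, 64)                        # [B, T-1] bin weights
#   new_z = sample_pdf(bins, weights, n_samples=128, det=True)   # [4, 128]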
def plot_pointcloud(pc, color=None):
# pc: [N, 3]
# color: [N, 3/4]
print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
pc = trimesh.PointCloud(pc, color)
# axis
axes = trimesh.creation.axis(axis_length=4)
# sphere
sphere = trimesh.creation.icosphere(radius=1)
trimesh.Scene([pc, axes, sphere]).show()
class NeRFRenderer(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.bound = opt.bound
self.cascade = 1 + math.ceil(math.log2(opt.bound))
self.grid_size = 128
self.density_scale = 1
self.min_near = opt.min_near
self.density_thresh = opt.density_thresh
self.density_thresh_torso = opt.density_thresh_torso
self.exp_eye = opt.exp_eye
self.test_train = opt.test_train
self.smooth_lips = opt.smooth_lips
self.torso = opt.torso
self.cuda_ray = opt.cuda_ray
# prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
# NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
aabb_train = torch.FloatTensor([-opt.bound, -opt.bound/2, -opt.bound, opt.bound, opt.bound/2, opt.bound])
aabb_infer = aabb_train.clone()
self.register_buffer('aabb_train', aabb_train)
self.register_buffer('aabb_infer', aabb_infer)
# individual codes
self.individual_num = opt.ind_num
self.individual_dim = opt.ind_dim
if self.individual_dim > 0:
self.individual_codes = nn.Parameter(torch.randn(self.individual_num, self.individual_dim) * 0.1)
if self.torso:
self.individual_dim_torso = opt.ind_dim_torso
if self.individual_dim_torso > 0:
self.individual_codes_torso = nn.Parameter(torch.randn(self.individual_num, self.individual_dim_torso) * 0.1)
# optimize camera pose
self.train_camera = self.opt.train_camera
if self.train_camera:
self.camera_dR = nn.Parameter(torch.zeros(self.individual_num, 3)) # euler angle
self.camera_dT = nn.Parameter(torch.zeros(self.individual_num, 3)) # xyz offset
# extra state for cuda raymarching
# 3D head density grid
density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
self.register_buffer('density_grid', density_grid)
self.register_buffer('density_bitfield', density_bitfield)
self.mean_density = 0
self.iter_density = 0
# 2D torso density grid
if self.torso:
density_grid_torso = torch.zeros([self.grid_size ** 2]) # [H * H]
self.register_buffer('density_grid_torso', density_grid_torso)
self.mean_density_torso = 0
# step counter
step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
self.register_buffer('step_counter', step_counter)
self.mean_count = 0
self.local_step = 0
# decay for enc_a
if self.smooth_lips:
self.enc_a = None
def forward(self, x, d):
raise NotImplementedError()
# separated density and color query (can accelerate non-cuda-ray mode.)
def density(self, x):
raise NotImplementedError()
def color(self, x, d, mask=None, **kwargs):
raise NotImplementedError()
def reset_extra_state(self):
if not self.cuda_ray:
return
# density grid
self.density_grid.zero_()
self.mean_density = 0
self.iter_density = 0
# step counter
self.step_counter.zero_()
self.mean_count = 0
self.local_step = 0
def run_cuda(self, rays_o, rays_d, auds, bg_coords, poses, eye=None, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# auds: [B, 16]
# index: [B]
# return: image: [B, N, 3], depth: [B, N]
prefix = rays_o.shape[:-1]
rays_o = rays_o.contiguous().view(-1, 3)
rays_d = rays_d.contiguous().view(-1, 3)
bg_coords = bg_coords.contiguous().view(-1, 2)
# only add camera offset at training!
if self.train_camera and (self.training or self.test_train):
dT = self.camera_dT[index] # [1, 3]
dR = euler_angles_to_matrix(self.camera_dR[index] / 180 * np.pi + 1e-8).squeeze(0) # [1, 3] --> [3, 3]
rays_o = rays_o + dT
rays_d = rays_d @ dR
N = rays_o.shape[0] # N = B * N, in fact
device = rays_o.device
results = {}
# pre-calculate near far
nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
nears = nears.detach()
fars = fars.detach()
# encode audio
enc_a = self.encode_audio(auds) # [1, 64]
if enc_a is not None and self.smooth_lips:
if self.enc_a is not None:
_lambda = 0.35
enc_a = _lambda * self.enc_a + (1 - _lambda) * enc_a
self.enc_a = enc_a
if self.individual_dim > 0:
if self.training:
ind_code = self.individual_codes[index]
# use a fixed ind code for the unknown test data.
else:
ind_code = self.individual_codes[0]
else:
ind_code = None
if self.training:
# setup counter
counter = self.step_counter[self.local_step % 16]
counter.zero_() # set to 0
self.local_step += 1
xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
sigmas, rgbs, amb_aud, amb_eye, uncertainty = self(xyzs, dirs, enc_a, ind_code, eye)
sigmas = self.density_scale * sigmas
#print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')
# weights_sum, ambient_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_uncertainty(sigmas, rgbs, ambient.abs().sum(-1), uncertainty, deltas, rays)
weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_triplane(sigmas, rgbs, amb_aud.abs().sum(-1), amb_eye.abs().sum(-1), uncertainty, deltas, rays)
# for training only
results['weights_sum'] = weights_sum
results['ambient_aud'] = amb_aud_sum
results['ambient_eye'] = amb_eye_sum
results['uncertainty'] = uncertainty_sum
results['rays'] = xyzs, dirs, enc_a, ind_code, eye
else:
dtype = torch.float32
weights_sum = torch.zeros(N, dtype=dtype, device=device)
depth = torch.zeros(N, dtype=dtype, device=device)
image = torch.zeros(N, 3, dtype=dtype, device=device)
amb_aud_sum = torch.zeros(N, dtype=dtype, device=device)
amb_eye_sum = torch.zeros(N, dtype=dtype, device=device)
uncertainty_sum = torch.zeros(N, dtype=dtype, device=device)
n_alive = N
rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
rays_t = nears.clone() # [N]
step = 0
while step < max_steps:
# count alive rays
n_alive = rays_alive.shape[0]
# exit loop
if n_alive <= 0:
break
# decide compact_steps
n_step = max(min(N // n_alive, 8), 1)
xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
sigmas, rgbs, ambients_aud, ambients_eye, uncertainties = self(xyzs, dirs, enc_a, ind_code, eye)
sigmas = self.density_scale * sigmas
# raymarching.composite_rays_uncertainty(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh)
raymarching.composite_rays_triplane(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients_aud, ambients_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh)
rays_alive = rays_alive[rays_alive >= 0]
# print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
step += n_step
torso_results = self.run_torso(rays_o, bg_coords, poses, index, bg_color)
bg_color = torso_results['bg_color']
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
image = image.view(*prefix, 3)
image = image.clamp(0, 1)
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
depth = depth.view(*prefix)
amb_aud_sum = amb_aud_sum.view(*prefix)
amb_eye_sum = amb_eye_sum.view(*prefix)
results['depth'] = depth
results['image'] = image # head_image if train, else com_image
results['ambient_aud'] = amb_aud_sum
results['ambient_eye'] = amb_eye_sum
results['uncertainty'] = uncertainty_sum
return results
def run_torso(self, rays_o, bg_coords, poses, index=0, bg_color=None, **kwargs):
        # rays_o: [B, N, 3], assumes B == 1
        # bg_coords: [B, N, 2]
        # index: [B]
        # return: results dict with 'bg_color' [N, 3] (torso composited over the background),
        #         plus 'torso_alpha' / 'torso_color' / 'deform' when the torso branch is enabled
rays_o = rays_o.contiguous().view(-1, 3)
bg_coords = bg_coords.contiguous().view(-1, 2)
N = rays_o.shape[0] # N = B * N, in fact
device = rays_o.device
results = {}
# background
if bg_color is None:
bg_color = 1
# first mix torso with background
if self.torso:
# torso ind code
if self.individual_dim_torso > 0:
if self.training:
ind_code_torso = self.individual_codes_torso[index]
# use a fixed ind code for the unknown test data.
else:
ind_code_torso = self.individual_codes_torso[0]
else:
ind_code_torso = None
# 2D density grid for acceleration...
density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
occupancy = F.grid_sample(self.density_grid_torso.view(1, 1, self.grid_size, self.grid_size), bg_coords.view(1, -1, 1, 2), align_corners=True).view(-1)
mask = occupancy > density_thresh_torso
# masked query of torso
torso_alpha = torch.zeros([N, 1], device=device)
torso_color = torch.zeros([N, 3], device=device)
if mask.any():
torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, ind_code_torso)
torso_alpha[mask] = torso_alpha_mask.float()
torso_color[mask] = torso_color_mask.float()
results['deform'] = deform
# first mix torso with background
bg_color = torso_color * torso_alpha + bg_color * (1 - torso_alpha)
results['torso_alpha'] = torso_alpha
results['torso_color'] = bg_color
# print(torso_alpha.shape, torso_alpha.max().item(), torso_alpha.min().item())
results['bg_color'] = bg_color
return results
@torch.no_grad()
def mark_untrained_grid(self, poses, intrinsic, S=64):
# poses: [B, 4, 4]
# intrinsic: [3, 3]
if not self.cuda_ray:
return
if isinstance(poses, np.ndarray):
poses = torch.from_numpy(poses)
B = poses.shape[0]
fx, fy, cx, cy = intrinsic
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
count = torch.zeros_like(self.density_grid)
poses = poses.to(count.device)
# 5-level loop, forgive me...
for xs in X:
for ys in Y:
for zs in Z:
# construct points
xx, yy, zz = custom_meshgrid(xs, ys, zs)
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
indices = raymarching.morton3D(coords).long() # [N]
world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1]
# cascading
for cas in range(self.cascade):
bound = min(2 ** cas, self.bound)
half_grid_size = bound / self.grid_size
# scale to current cascade's resolution
cas_world_xyzs = world_xyzs * (bound - half_grid_size)
# split batch to avoid OOM
head = 0
while head < B:
tail = min(head + S, B)
# world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
# query if point is covered by any camera
mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]
# update count
count[cas, indices] += mask
head += S
# mark untrained grid as -1
self.density_grid[count == 0] = -1
#print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}')
@torch.no_grad()
def update_extra_state(self, decay=0.95, S=128):
# call before each epoch to update extra states.
if not self.cuda_ray:
return
# use random auds (different expressions should have similar density grid...)
rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
# encode audio
enc_a = self.encode_audio(auds)
### update density grid
if not self.torso: # forbid updating head if is training torso...
tmp_grid = torch.zeros_like(self.density_grid)
# use a random eye area based on training dataset's statistics...
if self.exp_eye:
eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
else:
eye = None
# full update
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
for xs in X:
for ys in Y:
for zs in Z:
# construct points
xx, yy, zz = custom_meshgrid(xs, ys, zs)
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
indices = raymarching.morton3D(coords).long() # [N]
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
# cascading
for cas in range(self.cascade):
bound = min(2 ** cas, self.bound)
half_grid_size = bound / self.grid_size
# scale to current cascade's resolution
cas_xyzs = xyzs * (bound - half_grid_size)
# add noise in [-hgs, hgs]
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
# query density
sigmas = self.density(cas_xyzs, enc_a, eye)['sigma'].reshape(-1).detach().to(tmp_grid.dtype)
sigmas *= self.density_scale
# assign
tmp_grid[cas, indices] = sigmas
# dilate the density_grid (less aggressive culling)
tmp_grid = raymarching.morton3D_dilation(tmp_grid)
# ema update
valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 non-training regions are viewed as 0 density.
self.iter_density += 1
# convert to bitfield
density_thresh = min(self.mean_density, self.density_thresh)
self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
### update torso density grid
if self.torso:
tmp_grid_torso = torch.zeros_like(self.density_grid_torso)
# random pose, random ind_code
rand_idx = random.randint(0, self.poses.shape[0] - 1)
# pose = convert_poses(self.poses[[rand_idx]]).to(self.density_bitfield.device)
pose = self.poses[[rand_idx]].to(self.density_bitfield.device)
if self.opt.ind_dim_torso > 0:
ind_code = self.individual_codes_torso[[rand_idx]]
else:
ind_code = None
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
half_grid_size = 1 / self.grid_size
for xs in X:
for ys in Y:
xx, yy = custom_meshgrid(xs, ys)
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=-1) # [N, 2], in [0, 128)
indices = (coords[:, 1] * self.grid_size + coords[:, 0]).long() # NOTE: xy transposed!
xys = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 2] in [-1, 1]
xys = xys * (1 - half_grid_size)
# add noise in [-hgs, hgs]
xys += (torch.rand_like(xys) * 2 - 1) * half_grid_size
# query density
alphas, _, _ = self.forward_torso(xys, pose, ind_code) # [N, 1]
# assign
tmp_grid_torso[indices] = alphas.squeeze(1).float()
# dilate
tmp_grid_torso = tmp_grid_torso.view(1, 1, self.grid_size, self.grid_size)
# tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=3, stride=1, padding=1)
tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=5, stride=1, padding=2)
tmp_grid_torso = tmp_grid_torso.view(-1)
self.density_grid_torso = torch.maximum(self.density_grid_torso * decay, tmp_grid_torso)
self.mean_density_torso = torch.mean(self.density_grid_torso).item()
# density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
# print(f'[density grid torso] min={self.density_grid_torso.min().item():.4f}, max={self.density_grid_torso.max().item():.4f}, mean={self.mean_density_torso:.4f}, occ_rate={(self.density_grid_torso > density_thresh_torso).sum() / (128**2):.3f}')
### update step counter
total_step = min(16, self.local_step)
if total_step > 0:
self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
self.local_step = 0
#print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
@torch.no_grad()
def get_audio_grid(self, S=128):
# call before each epoch to update extra states.
if not self.cuda_ray:
return
# use random auds (different expressions should have similar density grid...)
rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
# encode audio
enc_a = self.encode_audio(auds)
tmp_grid = torch.zeros_like(self.density_grid)
# use a random eye area based on training dataset's statistics...
if self.exp_eye:
eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
else:
eye = None
# full update
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
for xs in X:
for ys in Y:
for zs in Z:
# construct points
xx, yy, zz = custom_meshgrid(xs, ys, zs)
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
indices = raymarching.morton3D(coords).long() # [N]
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
# cascading
for cas in range(self.cascade):
bound = min(2 ** cas, self.bound)
half_grid_size = bound / self.grid_size
# scale to current cascade's resolution
cas_xyzs = xyzs * (bound - half_grid_size)
# add noise in [-hgs, hgs]
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
# query density
aud_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_aud'].reshape(-1).detach().to(tmp_grid.dtype)
# assign
tmp_grid[cas, indices] = aud_norms
# dilate the density_grid (less aggressive culling)
tmp_grid = raymarching.morton3D_dilation(tmp_grid)
return tmp_grid
# # ema update
# valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
# self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
@torch.no_grad()
def get_eye_grid(self, S=128):
# call before each epoch to update extra states.
if not self.cuda_ray:
return
# use random auds (different expressions should have similar density grid...)
rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
# encode audio
enc_a = self.encode_audio(auds)
tmp_grid = torch.zeros_like(self.density_grid)
# use a random eye area based on training dataset's statistics...
if self.exp_eye:
eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
else:
eye = None
# full update
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
for xs in X:
for ys in Y:
for zs in Z:
# construct points
xx, yy, zz = custom_meshgrid(xs, ys, zs)
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
indices = raymarching.morton3D(coords).long() # [N]
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
# cascading
for cas in range(self.cascade):
bound = min(2 ** cas, self.bound)
half_grid_size = bound / self.grid_size
# scale to current cascade's resolution
cas_xyzs = xyzs * (bound - half_grid_size)
# add noise in [-hgs, hgs]
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
# query density
eye_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_eye'].reshape(-1).detach().to(tmp_grid.dtype)
# assign
tmp_grid[cas, indices] = eye_norms
# dilate the density_grid (less aggressive culling)
tmp_grid = raymarching.morton3D_dilation(tmp_grid)
return tmp_grid
# # ema update
# valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
# self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
def render(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# auds: [B, 29, 16]
# eye: [B, 1]
# bg_coords: [1, N, 2]
# return: pred_rgb: [B, N, 3]
_run = self.run_cuda
B, N = rays_o.shape[:2]
device = rays_o.device
# never stage when cuda_ray
if staged and not self.cuda_ray:
# not used
raise NotImplementedError
else:
results = _run(rays_o, rays_d, auds, bg_coords, poses, **kwargs)
return results
def render_torso(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# auds: [B, 29, 16]
# eye: [B, 1]
# bg_coords: [1, N, 2]
# return: pred_rgb: [B, N, 3]
_run = self.run_torso
B, N = rays_o.shape[:2]
device = rays_o.device
# never stage when cuda_ray
if staged and not self.cuda_ray:
# not used
raise NotImplementedError
else:
results = _run(rays_o, bg_coords, poses, **kwargs)
return results
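# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a toy version of the
# density-grid bookkeeping performed by update_extra_state() above. It relies
# on this module's torch / raymarching imports and a CUDA build of the
# raymarching extension; the sigma values here are random placeholders.
def _density_grid_update_demo(cascade=1, grid_size=16, decay=0.95, density_thresh=0.01):
    grid = torch.zeros(cascade, grid_size ** 3)                    # running density grid [CAS, H^3]
    new_sigmas = torch.rand_like(grid)                             # stand-in for freshly queried sigmas
    valid = (grid >= 0) & (new_sigmas >= 0)                        # cells marked -1 (untrained) are skipped
    grid[valid] = torch.maximum(grid[valid] * decay, new_sigmas[valid])
    mean_density = grid.clamp(min=0).mean().item()
    thresh = min(mean_density, density_thresh)                     # same rule as update_extra_state()
    bitfield = raymarching.packbits(grid, thresh)                  # uint8, [CAS * H^3 // 8]
    return grid, bitfield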

1516
nerf_triplane/utils.py Normal file

File diff suppressed because it is too large

158
nerfreal.py Normal file
View File

@ -0,0 +1,158 @@
import math
import torch
import torch.nn.functional as F  # needed for the F.interpolate call below
import numpy as np
#from .utils import *
import subprocess
import os
# get_audio_features is used in test_step(); assumed to live in nerf_triplane.utils
# (the star import above is commented out, so it has to be imported explicitly).
from nerf_triplane.utils import get_audio_features
from asrreal import ASR
class NeRFReal:
def __init__(self, opt, trainer, data_loader, debug=True):
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
self.H = opt.H
self.debug = debug
self.training = False
self.step = 0 # training step
self.trainer = trainer
self.data_loader = data_loader
# use dataloader's bg
bg_img = data_loader._data.bg_img #.view(1, -1, 3)
if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
self.bg_color = bg_img.view(1, -1, 3)
# audio features (from dataloader, only used in non-playing mode)
self.audio_features = data_loader._data.auds # [N, 29, 16]
self.audio_idx = 0
# control eye
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
# playing seq from dataloader, or pause.
self.playing = True #False todo
self.loader = iter(data_loader)
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
self.need_update = True # camera moved, should reset accumulation
self.spp = 1 # sample per pixel
self.mode = 'image' # choose from ['image', 'depth']
self.dynamic_resolution = False # assert False!
self.downscale = 1
self.train_steps = 16
self.ind_index = 0
self.ind_num = trainer.model.individual_codes.shape[0]
# build asr
if self.opt.asr:
self.asr = ASR(opt)
video_path = 'video_stream'
if not os.path.exists(video_path):
os.mkfifo(video_path, mode=0o777)
audio_path = 'audio_stream'
if not os.path.exists(audio_path):
os.mkfifo(audio_path, mode=0o777)
width=450
height=450
fps=25
push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
command = ['ffmpeg',
'-y', #'-an',
#'-re',
'-f', 'rawvideo',
'-vcodec','rawvideo',
                   '-pix_fmt', 'rgb24',  # pixel format
'-s', "{}x{}".format(width, height),
'-r', str(fps),
'-i', video_path,
'-f', 's16le',
'-acodec','pcm_s16le',
'-ac', '1',
'-ar', '16000',
'-i', audio_path,
#'-fflags', '+genpts',
'-map', '0:v',
'-map', '1:a',
#'-copyts',
'-acodec', 'aac',
'-pix_fmt', 'yuv420p', #'-vcodec', "h264",
#"-rtmp_buffer", "100",
'-f' , 'flv',
push_url]
self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE)
self.fifo_video = open(video_path, 'wb')
self.fifo_audio = open(audio_path, 'wb')
#self.test_step()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.opt.asr:
self.asr.stop()
def push_audio(self,chunk):
self.asr.push_audio(chunk)
def prepare_buffer(self, outputs):
if self.mode == 'image':
return outputs['image']
else:
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
def test_step(self):
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
starter.record()
if self.playing:
try:
data = next(self.loader)
except StopIteration:
self.loader = iter(self.data_loader)
data = next(self.loader)
if self.opt.asr:
# use the live audio stream
data['auds'] = self.asr.get_next_feat()
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
            print('[INFO] outputs shape', outputs['image'].shape)
image = (outputs['image'] * 255).astype(np.uint8)
#self.pipe.stdin.write(image.tostring())
for _ in range(2):
frame = self.asr.get_audio_out()
                print('[INFO] get_audio_out shape', frame.shape)
frame = (frame * 32767).astype(np.int16).tobytes()
self.fifo_audio.write(frame)
            self.fifo_video.write(image.tobytes())  # tobytes() replaces the deprecated ndarray.tostring()
else:
if self.audio_features is not None:
auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
else:
auds = None
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
ender.record()
torch.cuda.synchronize()
t = starter.elapsed_time(ender)
def render(self):
if self.opt.asr:
self.asr.warm_up()
while True: #todo
# update texture every frame
# audio stream thread...
if self.opt.asr and self.playing:
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
for _ in range(2):
self.asr.run_step()
self.test_step()
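# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): why render() pulls two
# ASR steps / two audio frames per video frame. The ffmpeg pipe above is fed
# 25 fps rgb24 video and 16 kHz mono pcm_s16le audio, so one video frame spans
# 40 ms of audio; assuming the ASR emits 20 ms chunks (audio at 50 "fps", as
# the comment in render() puts it), that is exactly two chunks per frame.
def _stream_budget(width=450, height=450, fps=25, sample_rate=16000, chunk_ms=20):
    video_bytes_per_frame = width * height * 3             # rgb24
    audio_samples_per_chunk = sample_rate * chunk_ms // 1000
    audio_bytes_per_chunk = audio_samples_per_chunk * 2    # pcm_s16le, mono
    chunks_per_video_frame = (1000 // fps) // chunk_ms     # 40 ms / 20 ms = 2
    return video_bytes_per_frame, audio_bytes_per_chunk, chunks_per_video_frame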

1
raymarching/__init__.py Normal file
View File

@ -0,0 +1 @@
from .raymarching import *

40
raymarching/backend.py Normal file
View File

@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
_backend = load(name='_raymarching_face',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'raymarching.cu',
'bindings.cpp',
]],
)
__all__ = ['_backend']

671
raymarching/raymarching.py Normal file
View File

@ -0,0 +1,671 @@
import numpy as np
import time
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _raymarching_face as _backend
except ImportError:
from .backend import _backend
# ----------------------------------------
# utils
# ----------------------------------------
class _near_far_from_aabb(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
''' near_far_from_aabb, CUDA implementation
Calculate rays' intersection time (near and far) with aabb
Args:
rays_o: float, [N, 3]
rays_d: float, [N, 3]
aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
min_near: float, scalar
Returns:
nears: float, [N]
fars: float, [N]
'''
if not rays_o.is_cuda: rays_o = rays_o.cuda()
if not rays_d.is_cuda: rays_d = rays_d.cuda()
rays_o = rays_o.contiguous().view(-1, 3)
rays_d = rays_d.contiguous().view(-1, 3)
N = rays_o.shape[0] # num rays
nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
_backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
return nears, fars
near_far_from_aabb = _near_far_from_aabb.apply
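# Minimal pure-PyTorch reference of the same computation (the classic slab
# method), kept only to document the CUDA kernel's semantics; edge cases such
# as rays that miss the box entirely may be handled differently by the kernel.
def _near_far_from_aabb_reference(rays_o, rays_d, aabb, min_near=0.2):
    # rays_o, rays_d: [N, 3]; aabb: [6] = (xmin, ymin, zmin, xmax, ymax, zmax)
    safe_d = torch.where(rays_d.abs() > 1e-15, rays_d, torch.full_like(rays_d, 1e-15))  # avoid div by 0
    t_min = (aabb[:3] - rays_o) / safe_d                            # [N, 3]
    t_max = (aabb[3:] - rays_o) / safe_d                            # [N, 3]
    nears = torch.minimum(t_min, t_max).amax(dim=-1).clamp(min=min_near)  # last slab entry
    fars = torch.maximum(t_min, t_max).amin(dim=-1)                       # first slab exit
    fars = torch.maximum(fars, nears + 1e-6)                              # keep intervals non-empty
    return nears, fars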
class _sph_from_ray(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, rays_o, rays_d, radius):
''' sph_from_ray, CUDA implementation
get spherical coordinate on the background sphere from rays.
Assume rays_o are inside the Sphere(radius).
Args:
rays_o: [N, 3]
rays_d: [N, 3]
radius: scalar, float
Return:
coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
'''
if not rays_o.is_cuda: rays_o = rays_o.cuda()
if not rays_d.is_cuda: rays_d = rays_d.cuda()
rays_o = rays_o.contiguous().view(-1, 3)
rays_d = rays_d.contiguous().view(-1, 3)
N = rays_o.shape[0] # num rays
coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
_backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
return coords
sph_from_ray = _sph_from_ray.apply
class _morton3D(Function):
@staticmethod
def forward(ctx, coords):
''' morton3D, CUDA implementation
Args:
coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
TODO: check if the coord range is valid! (current 128 is safe)
Returns:
indices: [N], int32, in [0, 128^3)
'''
if not coords.is_cuda: coords = coords.cuda()
N = coords.shape[0]
indices = torch.empty(N, dtype=torch.int32, device=coords.device)
_backend.morton3D(coords.int(), N, indices)
return indices
morton3D = _morton3D.apply
class _morton3D_invert(Function):
@staticmethod
def forward(ctx, indices):
''' morton3D_invert, CUDA implementation
Args:
indices: [N], int32, in [0, 128^3)
Returns:
coords: [N, 3], int32, in [0, 128)
'''
if not indices.is_cuda: indices = indices.cuda()
N = indices.shape[0]
coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
_backend.morton3D_invert(indices.int(), N, coords)
return coords
morton3D_invert = _morton3D_invert.apply
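# Tiny pure-Python reference of 3D Morton (Z-order) encoding for a single
# cell, to document what the two ops above compute for coords in [0, 128).
# Which axis ends up in the lowest bit is an assumption here; the CUDA kernel
# may interleave the axes in a different order.
def _morton3D_single(x, y, z):
    def spread(v):
        # insert two zero bits between every bit of v (valid for v < 2**10)
        out = 0
        for i in range(10):
            out |= ((v >> i) & 1) << (3 * i)
        return out
    return spread(x) | (spread(y) << 1) | (spread(z) << 2)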
class _packbits(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, grid, thresh, bitfield=None):
''' packbits, CUDA implementation
Pack up the density grid into a bit field to accelerate ray marching.
Args:
grid: float, [C, H * H * H], assume H % 2 == 0
thresh: float, threshold
Returns:
bitfield: uint8, [C, H * H * H / 8]
'''
if not grid.is_cuda: grid = grid.cuda()
grid = grid.contiguous()
C = grid.shape[0]
H3 = grid.shape[1]
N = C * H3 // 8
if bitfield is None:
bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
_backend.packbits(grid, N, thresh, bitfield)
return bitfield
packbits = _packbits.apply
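# A rough pure-PyTorch equivalent of packbits, to document the layout: every
# group of 8 consecutive cells (in morton order) becomes one byte, with a bit
# set when the cell's density exceeds thresh. The bit order inside the byte is
# an assumption and may differ from the CUDA kernel.
def _packbits_reference(grid, thresh):
    occ = (grid.reshape(-1, 8) > thresh).to(torch.uint8)            # [C * H^3 / 8, 8]
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=grid.device)
    return (occ * weights).sum(dim=-1, dtype=torch.uint8)           # [C * H^3 / 8]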
class _morton3D_dilation(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, grid):
        ''' max pooling over morton-ordered grid cells, CUDA implementation
            (effectively a dilation; the kernel size is fixed and cannot be adjusted).
            Args:
                grid: float, [C, H * H * H], assume H % 2 == 0
            Returns:
                grid_dilation: float, [C, H * H * H], the dilated grid
'''
if not grid.is_cuda: grid = grid.cuda()
grid = grid.contiguous()
C = grid.shape[0]
H3 = grid.shape[1]
H = int(np.cbrt(H3))
grid_dilation = torch.empty_like(grid)
_backend.morton3D_dilation(grid, C, H, grid_dilation)
return grid_dilation
morton3D_dilation = _morton3D_dilation.apply
# ----------------------------------------
# train functions
# ----------------------------------------
class _march_rays_train(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
''' march rays to generate points (forward only)
Args:
rays_o/d: float, [N, 3]
bound: float, scalar
density_bitfield: uint8: [CHHH // 8]
C: int
H: int
nears/fars: float, [N]
step_counter: int32, (2), used to count the actual number of generated points.
                mean_count: int32, estimated mean steps to accelerate training. (but rays will be randomly dropped if the actual point count exceeds this threshold.)
                perturb: bool
                align: int, pad output so its size is divisible by align; set to -1 to disable.
                force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful when rendering the whole image instead of just some rays.
                dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
                max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
            Returns:
                xyzs: float, [M, 3], all generated points' coords. (all rays concatenated; use `rays` to extract the points belonging to each ray)
dirs: float, [M, 3], all generated points' view dirs.
deltas: float, [M, 2], first is delta_t, second is rays_t
rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0]
'''
if not rays_o.is_cuda: rays_o = rays_o.cuda()
if not rays_d.is_cuda: rays_d = rays_d.cuda()
if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
rays_o = rays_o.contiguous().view(-1, 3)
rays_d = rays_d.contiguous().view(-1, 3)
density_bitfield = density_bitfield.contiguous()
N = rays_o.shape[0] # num rays
M = N * max_steps # init max points number in total
# running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
        # It estimates the max number of points to enable faster training, but rays will be randomly dropped if it is underestimated.
if not force_all_rays and mean_count > 0:
if align > 0:
mean_count += align - mean_count % align
M = mean_count
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
if step_counter is None:
step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
if perturb:
noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
else:
noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
_backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
#print(step_counter, M)
# only used at the first (few) epochs.
if force_all_rays or mean_count <= 0:
m = step_counter[0].item() # D2H copy
if align > 0:
m += align - m % align
xyzs = xyzs[:m]
dirs = dirs[:m]
deltas = deltas[:m]
torch.cuda.empty_cache()
ctx.save_for_backward(rays, deltas)
return xyzs, dirs, deltas, rays
# to support optimizing camera poses.
@staticmethod
@custom_bwd
def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays):
# grad_xyzs/dirs: [M, 3]
rays, deltas = ctx.saved_tensors
N = rays.shape[0]
M = grad_xyzs.shape[0]
grad_rays_o = torch.zeros(N, 3, device=rays.device)
grad_rays_d = torch.zeros(N, 3, device=rays.device)
_backend.march_rays_train_backward(grad_xyzs, grad_dirs, rays, deltas, N, M, grad_rays_o, grad_rays_d)
return grad_rays_o, grad_rays_d, None, None, None, None, None, None, None, None, None, None, None, None, None
march_rays_train = _march_rays_train.apply
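# Example of consuming the `rays` layout returned above: rays[:, 1] holds the
# per-ray offsets into xyzs/dirs/deltas and rays[:, 2] the per-ray sample
# counts, so the samples of ray i are xyzs[rays[i, 1] : rays[i, 1] + rays[i, 2]]
# and the average number of samples per ray is rays[:, 2].float().mean().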
class _composite_rays_train(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
rgbs: float, [M, 3]
sigmas: float, [M,]
ambient: float, [M,] (after summing up the last dimension)
deltas: float, [M, 2]
rays: int32, [N, 3]
Returns:
                weights_sum: float, [N,], the alpha channel
                ambient_sum: float, [N,], the composited ambient term
                depth: float, [N,], the depth
                image: float, [N, 3], the RGB channel (after multiplying alpha!)
'''
sigmas = sigmas.contiguous()
rgbs = rgbs.contiguous()
ambient = ambient.contiguous()
M = sigmas.shape[0]
N = rays.shape[0]
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
_backend.composite_rays_train_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image)
ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image)
ctx.dims = [M, N, T_thresh]
return weights_sum, ambient_sum, depth, image
@staticmethod
@custom_bwd
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
grad_weights_sum = grad_weights_sum.contiguous()
grad_ambient_sum = grad_ambient_sum.contiguous()
grad_image = grad_image.contiguous()
sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors
M, N, T_thresh = ctx.dims
grad_sigmas = torch.zeros_like(sigmas)
grad_rgbs = torch.zeros_like(rgbs)
grad_ambient = torch.zeros_like(ambient)
_backend.composite_rays_train_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient)
return grad_sigmas, grad_rgbs, grad_ambient, None, None, None
composite_rays_train = _composite_rays_train.apply
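# A minimal single-ray, pure-PyTorch reference of the compositing formula the
# kernel above implements (standard volume rendering). Early termination at
# T_thresh and the kernel's exact depth bookkeeping are intentionally omitted.
def _composite_one_ray_reference(sigmas, rgbs, ambient, deltas):
    # sigmas: [S], rgbs: [S, 3], ambient: [S], deltas: [S, 2], one ray's samples
    alphas = 1.0 - torch.exp(-sigmas * deltas[:, 0])                # per-sample opacity
    trans = torch.cumprod(1.0 - alphas + 1e-10, dim=0)
    T = torch.cat([torch.ones(1, device=sigmas.device), trans[:-1]])  # transmittance before each sample
    weights = alphas * T
    return {
        'weights_sum': weights.sum(),                               # the alpha channel
        'ambient_sum': (weights * ambient).sum(),
        'image': (weights.unsqueeze(-1) * rgbs).sum(dim=0),         # premultiplied RGB
    }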
# ----------------------------------------
# infer functions
# ----------------------------------------
class _march_rays(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
''' march rays to generate points (forward only, for inference)
Args:
n_alive: int, number of alive rays
n_step: int, how many steps we march
rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
rays_t: float, [N], the alive rays' time, we only use the first n_alive.
rays_o/d: float, [N, 3]
bound: float, scalar
density_bitfield: uint8: [CHHH // 8]
C: int
H: int
nears/fars: float, [N]
                align: int, pad output so its size is divisible by align; set to -1 to disable.
                perturb: bool/int, int > 0 is used as the random seed.
                dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
Returns:
xyzs: float, [n_alive * n_step, 3], all generated points' coords
dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
'''
if not rays_o.is_cuda: rays_o = rays_o.cuda()
if not rays_d.is_cuda: rays_d = rays_d.cuda()
rays_o = rays_o.contiguous().view(-1, 3)
rays_d = rays_d.contiguous().view(-1, 3)
M = n_alive * n_step
if align > 0:
M += align - (M % align)
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
if perturb:
# torch.manual_seed(perturb) # test_gui uses spp index as seed
noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
else:
noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
_backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
return xyzs, dirs, deltas
march_rays = _march_rays.apply
class _composite_rays(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
''' composite rays' rgbs, according to the ray marching formula. (for inference)
Args:
n_alive: int, number of alive rays
n_step: int, how many steps we march
rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
rays_t: float, [N], the alive rays' time
sigmas: float, [n_alive * n_step,]
rgbs: float, [n_alive * n_step, 3]
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
In-place Outputs:
weights_sum: float, [N,], the alpha channel
depth: float, [N,], the depth value
image: float, [N, 3], the RGB channel (after multiplying alpha!)
'''
_backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
return tuple()
composite_rays = _composite_rays.apply
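# Usage note: unlike the training op above, this family of inference kernels
# writes into pre-allocated buffers in place. In this repo that is exactly how
# NeRFRenderer.run_cuda drives the triplane variant further below: weights_sum,
# depth and image are allocated once per frame and updated n_step samples at a
# time for the rays that are still alive.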
class _composite_rays_ambient(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
_backend.composite_rays_ambient(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
return tuple()
composite_rays_ambient = _composite_rays_ambient.apply
# custom
class _composite_rays_train_sigma(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
rgbs: float, [M, 3]
sigmas: float, [M,]
ambient: float, [M,] (after summing up the last dimension)
deltas: float, [M, 2]
rays: int32, [N, 3]
Returns:
                weights_sum: float, [N,], the alpha channel
                ambient_sum: float, [N,], the composited ambient term
                depth: float, [N,], the depth
                image: float, [N, 3], the RGB channel (after multiplying alpha!)
'''
sigmas = sigmas.contiguous()
rgbs = rgbs.contiguous()
ambient = ambient.contiguous()
M = sigmas.shape[0]
N = rays.shape[0]
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
_backend.composite_rays_train_sigma_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image)
ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image)
ctx.dims = [M, N, T_thresh]
return weights_sum, ambient_sum, depth, image
@staticmethod
@custom_bwd
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
grad_weights_sum = grad_weights_sum.contiguous()
grad_ambient_sum = grad_ambient_sum.contiguous()
grad_image = grad_image.contiguous()
sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors
M, N, T_thresh = ctx.dims
grad_sigmas = torch.zeros_like(sigmas)
grad_rgbs = torch.zeros_like(rgbs)
grad_ambient = torch.zeros_like(ambient)
_backend.composite_rays_train_sigma_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient)
return grad_sigmas, grad_rgbs, grad_ambient, None, None, None
composite_rays_train_sigma = _composite_rays_train_sigma.apply
class _composite_rays_ambient_sigma(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
_backend.composite_rays_ambient_sigma(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
return tuple()
composite_rays_ambient_sigma = _composite_rays_ambient_sigma.apply
# uncertainty
class _composite_rays_train_uncertainty(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, sigmas, rgbs, ambient, uncertainty, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
rgbs: float, [M, 3]
sigmas: float, [M,]
                ambient: float, [M,] (after summing up the last dimension)
                uncertainty: float, [M,]
                deltas: float, [M, 2]
                rays: int32, [N, 3]
            Returns:
                weights_sum: float, [N,], the alpha channel
                ambient_sum / uncertainty_sum: float, [N,], the composited extra terms
                depth: float, [N,], the depth
                image: float, [N, 3], the RGB channel (after multiplying alpha!)
'''
sigmas = sigmas.contiguous()
rgbs = rgbs.contiguous()
ambient = ambient.contiguous()
uncertainty = uncertainty.contiguous()
M = sigmas.shape[0]
N = rays.shape[0]
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
_backend.composite_rays_train_uncertainty_forward(sigmas, rgbs, ambient, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, uncertainty_sum, depth, image)
ctx.save_for_backward(sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image)
ctx.dims = [M, N, T_thresh]
return weights_sum, ambient_sum, uncertainty_sum, depth, image
@staticmethod
@custom_bwd
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_depth, grad_image):
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
grad_weights_sum = grad_weights_sum.contiguous()
grad_ambient_sum = grad_ambient_sum.contiguous()
grad_uncertainty_sum = grad_uncertainty_sum.contiguous()
grad_image = grad_image.contiguous()
sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image = ctx.saved_tensors
M, N, T_thresh = ctx.dims
grad_sigmas = torch.zeros_like(sigmas)
grad_rgbs = torch.zeros_like(rgbs)
grad_ambient = torch.zeros_like(ambient)
grad_uncertainty = torch.zeros_like(uncertainty)
_backend.composite_rays_train_uncertainty_backward(grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty)
return grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty, None, None, None
composite_rays_train_uncertainty = _composite_rays_train_uncertainty.apply
class _composite_rays_uncertainty(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh=1e-2):
_backend.composite_rays_uncertainty(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum)
return tuple()
composite_rays_uncertainty = _composite_rays_uncertainty.apply
# triplane(eye)
class _composite_rays_train_triplane(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
rgbs: float, [M, 3]
sigmas: float, [M,]
                amb_aud: float, [M,] (audio-driven ambient term, after summing up the last dimension)
                amb_eye: float, [M,] (eye-driven ambient term, after summing up the last dimension)
                uncertainty: float, [M,]
                deltas: float, [M, 2]
                rays: int32, [N, 3]
            Returns:
                weights_sum: float, [N,], the alpha channel
                amb_aud_sum, amb_eye_sum, uncertainty_sum: float, [N,], the composited extra terms
                depth: float, [N,], the depth
                image: float, [N, 3], the RGB channel (after multiplying alpha!)
'''
sigmas = sigmas.contiguous()
rgbs = rgbs.contiguous()
amb_aud = amb_aud.contiguous()
amb_eye = amb_eye.contiguous()
uncertainty = uncertainty.contiguous()
M = sigmas.shape[0]
N = rays.shape[0]
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
amb_aud_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
amb_eye_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
_backend.composite_rays_train_triplane_forward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image)
ctx.save_for_backward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image)
ctx.dims = [M, N, T_thresh]
return weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image
@staticmethod
@custom_bwd
def backward(ctx, grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_depth, grad_image):
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
grad_weights_sum = grad_weights_sum.contiguous()
grad_amb_aud_sum = grad_amb_aud_sum.contiguous()
grad_amb_eye_sum = grad_amb_eye_sum.contiguous()
grad_uncertainty_sum = grad_uncertainty_sum.contiguous()
grad_image = grad_image.contiguous()
sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = ctx.saved_tensors
M, N, T_thresh = ctx.dims
grad_sigmas = torch.zeros_like(sigmas)
grad_rgbs = torch.zeros_like(rgbs)
grad_amb_aud = torch.zeros_like(amb_aud)
grad_amb_eye = torch.zeros_like(amb_eye)
grad_uncertainty = torch.zeros_like(uncertainty)
_backend.composite_rays_train_triplane_backward(grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty)
return grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty, None, None, None
composite_rays_train_triplane = _composite_rays_train_triplane.apply
class _composite_rays_triplane(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh=1e-2):
_backend.composite_rays_triplane(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum)
return tuple()
composite_rays_triplane = _composite_rays_triplane.apply

63
raymarching/setup.py Normal file
View File

@ -0,0 +1,63 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
# '-lineinfo', # to debug illegal memory access
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
'''
Usage:
python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
python setup.py install # build extensions and install (copy) to PATH.
pip install . # ditto but better (e.g., dependency & metadata handling)
python setup.py develop # build extensions and install (symbolic) to PATH.
pip install -e . # ditto but better (e.g., dependency & metadata handling)
'''
setup(
name='raymarching_face', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_raymarching_face', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'raymarching.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)

39
raymarching/src/bindings.cpp Normal file
View File

@ -0,0 +1,39 @@
#include <torch/extension.h>
#include "raymarching.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// utils
m.def("packbits", &packbits, "packbits (CUDA)");
m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
m.def("morton3D", &morton3D, "morton3D (CUDA)");
m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
m.def("morton3D_dilation", &morton3D_dilation, "morton3D_dilation (CUDA)");
// train
m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
m.def("march_rays_train_backward", &march_rays_train_backward, "march_rays_train_backward (CUDA)");
m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
// infer
m.def("march_rays", &march_rays, "march rays (CUDA)");
m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
m.def("composite_rays_ambient", &composite_rays_ambient, "composite rays with ambient (CUDA)");
// train
m.def("composite_rays_train_sigma_forward", &composite_rays_train_sigma_forward, "composite_rays_train_forward (CUDA)");
m.def("composite_rays_train_sigma_backward", &composite_rays_train_sigma_backward, "composite_rays_train_backward (CUDA)");
// infer
m.def("composite_rays_ambient_sigma", &composite_rays_ambient_sigma, "composite rays with ambient (CUDA)");
// uncertainty train
m.def("composite_rays_train_uncertainty_forward", &composite_rays_train_uncertainty_forward, "composite_rays_train_forward (CUDA)");
m.def("composite_rays_train_uncertainty_backward", &composite_rays_train_uncertainty_backward, "composite_rays_train_backward (CUDA)");
m.def("composite_rays_uncertainty", &composite_rays_uncertainty, "composite rays with ambient (CUDA)");
// triplane
m.def("composite_rays_train_triplane_forward", &composite_rays_train_triplane_forward, "composite_rays_train_forward (CUDA)");
m.def("composite_rays_train_triplane_backward", &composite_rays_train_triplane_backward, "composite_rays_train_backward (CUDA)");
m.def("composite_rays_triplane", &composite_rays_triplane, "composite rays with ambient (CUDA)");
}

raymarching/src/raymarching.cu Normal file

File diff suppressed because it is too large

38
raymarching/src/raymarching.h Normal file
View File

@ -0,0 +1,38 @@
#pragma once
#include <stdint.h>
#include <torch/torch.h>
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation);
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d);
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient);
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
void composite_rays_ambient(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum);
void composite_rays_train_sigma_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_sigma_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient);
void composite_rays_ambient_sigma(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum);
// uncertainty
void composite_rays_train_uncertainty_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_uncertainty_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient, at::Tensor grad_uncertainty);
void composite_rays_uncertainty(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum, at::Tensor uncertainty_sum);
// triplane
void composite_rays_train_triplane_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_triplane_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_amb_aud_sum, const at::Tensor grad_amb_eye_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor amb_aud_sum, const at::Tensor amb_eye_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_amb_aud, at::Tensor grad_amb_eye, at::Tensor grad_uncertainty);
void composite_rays_triplane(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambs_aud, at::Tensor ambs_eye, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum);
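The `composite_rays_*` declarations above are all variants of the same transmittance-weighted compositing along a ray; they differ only in which auxiliary terms (ambient, audio/eye ambient, uncertainty) are accumulated alongside colour. As a rough reference only, here is a minimal PyTorch sketch of that accumulation with the same early-termination threshold `T_thresh`; the per-sample loop and tensor shapes are assumptions for illustration, while the actual CUDA kernels are fused and also accumulate depth and the ambient sums.

```python
import torch

def composite_ray(sigmas, rgbs, deltas, T_thresh=1e-4):
    # sigmas: [S] densities, rgbs: [S, 3] colours, deltas: [S] step sizes along one ray.
    T = torch.tensor(1.0)             # accumulated transmittance
    weights_sum = torch.tensor(0.0)
    image = torch.zeros(3)
    for sigma, rgb, delta in zip(sigmas, rgbs, deltas):
        alpha = 1.0 - torch.exp(-sigma * delta)   # opacity of this sample
        w = T * alpha                             # compositing weight
        image = image + w * rgb
        weights_sum = weights_sum + w
        T = T * (1.0 - alpha)                     # attenuate transmittance
        if T < T_thresh:                          # early termination, as in the kernels
            break
    return weights_sum, image
```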

26
requirements.txt Normal file
View File

@ -0,0 +1,26 @@
torch-ema
ninja
trimesh
opencv-python
tensorboardX
numpy
pandas
tqdm
matplotlib
PyMCubes
rich
dearpygui
packaging
scipy
face_alignment
python_speech_features
numba
resampy
pyaudio
soundfile
einops
configargparse
lpips
imageio-ffmpeg

5
scripts/train_obama.sh Normal file
View File

@ -0,0 +1,5 @@
python main.py data/obama/ --workspace trial_obama_triplane/ -O --iters 100000
cp -r trial_obama_triplane/checkpoints trial_obama_triplane/checkpoints_
python main.py data/obama/ --workspace trial_obama_triplane/ -O --iters 125000 --finetune_lips --patch_size 32
python main.py data/obama/ --workspace trial_obama_triplane/ -O --test
# python main.py data/obama/ --workspace trial_obama_triplane_torso/ -O --torso --head_ckpt <head>.pth --iters 200000

1
shencoder/__init__.py Normal file
View File

@ -0,0 +1 @@
from .sphere_harmonics import SHEncoder

40
shencoder/backend.py Normal file
View File

@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14', '-finput-charset=utf-8']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17', '/source-charset:utf-8']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_sh_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'shencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
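Here `load(...)` compiles `shencoder.cu`/`bindings.cpp` on first import and caches the build (by default under `~/.cache/torch_extensions`, or wherever `TORCH_EXTENSIONS_DIR` points), so the first import can take a while. A minimal way to pre-trigger and sanity-check the JIT build:

```python
# Importing the backend triggers the JIT build of '_sh_encoder' if it is not cached yet.
from shencoder.backend import _backend
print(_backend)  # compiled extension exposing sh_encode_forward / sh_encode_backward
```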

50
shencoder/setup.py Normal file
View File

@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='shencoder',  # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_shencoder',  # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'shencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
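This is the ahead-of-time alternative to the JIT path in `backend.py`: running `python setup.py install` inside `shencoder/` builds and installs `_shencoder`, so the `try: import _shencoder` in `sphere_harmonics.py` below picks up the prebuilt extension and skips the runtime compile. A small check of which backend ends up being used (assuming the package is importable):

```python
# Prints the imported backend module; its name/path shows whether the prebuilt
# '_shencoder' (setup.py) or the JIT-built '_sh_encoder' (backend.py) is in use.
import shencoder.sphere_harmonics as sh
print(sh._backend)
```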

87
shencoder/sphere_harmonics.py Normal file
View File

@ -0,0 +1,87 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _shencoder as _backend
except ImportError:
    from .backend import _backend


class _sh_encoder(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        # inputs: [B, input_dim], float in [-1, 1]
        # RETURN: [B, F], float
        inputs = inputs.contiguous()
        B, input_dim = inputs.shape  # batch size, coord dim
        output_dim = degree ** 2

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
        else:
            dy_dx = None

        _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)

        ctx.save_for_backward(inputs, dy_dx)
        ctx.dims = [B, input_dim, degree]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, C * C]
        inputs, dy_dx = ctx.saved_tensors

        if dy_dx is not None:
            grad = grad.contiguous()
            B, input_dim, degree = ctx.dims
            grad_inputs = torch.zeros_like(inputs)
            _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
            return grad_inputs, None, None
        else:
            return None, None, None


sh_encode = _sh_encoder.apply


class SHEncoder(nn.Module):
    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim  # coord dims, must be 3
        self.degree = degree        # SH degree, 1 ~ 8 (default 4)
        self.output_dim = degree ** 2

        assert self.input_dim == 3, "SH encoder only supports input dim == 3"
        assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"

    def __repr__(self):
        return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"

    def forward(self, inputs, size=1):
        # inputs: [..., input_dim], normalized real world positions in [-size, size]
        # return: [..., degree^2]
        inputs = inputs / size  # [-1, 1]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs
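A minimal usage sketch of `SHEncoder` as defined above (the batch size and CUDA device are arbitrary; with the default `degree=4`, each direction maps to 4² = 16 coefficients):

```python
import torch
from shencoder import SHEncoder

encoder = SHEncoder(input_dim=3, degree=4)        # output_dim = 16
dirs = torch.randn(1024, 3, device='cuda')
dirs = dirs / dirs.norm(dim=-1, keepdim=True)     # unit view directions, components in [-1, 1]
feats = encoder(dirs)                             # [1024, 16]
print(feats.shape)
```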

8
shencoder/src/bindings.cpp Normal file
View File

@ -0,0 +1,8 @@
#include <torch/extension.h>
#include "shencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
}

439
shencoder/src/shencoder.cu Normal file
View File

@ -0,0 +1,439 @@
#include <stdint.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <algorithm>
#include <stdexcept>
#include <cstdio>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
return (val + divisor - 1) / divisor;
}
template <typename scalar_t>
__global__ void kernel_sh(
const scalar_t * __restrict__ inputs,
scalar_t * outputs,
uint32_t B, uint32_t D, uint32_t C,
scalar_t * dy_dx
) {
const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;
if (b >= B) return;
const uint32_t C2 = C * C;
// locate
inputs += b * D;
outputs += b * C2;
scalar_t x = inputs[0], y = inputs[1], z = inputs[2];
scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;
scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;
auto write_sh = [&]() {
outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi))
if (C <= 1) { return; }
outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi))
outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi))
outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi))
if (C <= 2) { return; }
outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi))
outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi))
outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi))
outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi))
if (C <= 3) { return; }
outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi))
outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
if (C <= 4) { return; }
outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
if (C <= 5) { return; }
outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
if (C <= 6) { return; }
outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
if (C <= 7) { return; }
outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
};
write_sh();
if (dy_dx) {
scalar_t *dx = dy_dx + b * D * C2;
scalar_t *dy = dx + C2;
scalar_t *dz = dy + C2;
auto write_sh_dx = [&]() {
dx[0] = 0.0f ; // 0
if (C <= 1) { return; }
dx[1] = 0.0f ; // 0
dx[2] = 0.0f ; // 0
dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
if (C <= 2) { return; }
dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi))
dx[5] = 0.0f ; // 0
dx[6] = 0.0f ; // 0
dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
if (C <= 3) { return; }
dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi))
dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi))
dx[11] = 0.0f ; // 0
dx[12] = 0.0f ; // 0
dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
if (C <= 4) { return; }
dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))
dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi))
dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))
dx[19] = 0.0f ; // 0
dx[20] = 0.0f ; // 0
dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
if (C <= 5) { return; }
dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))
dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))
dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))
dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))
dx[29] = 0.0f ; // 0
dx[30] = 0.0f ; // 0
dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))
dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
if (C <= 6) { return; }
dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))
dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))
dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
dx[41] = 0.0f ; // 0
dx[42] = 0.0f ; // 0
dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
if (C <= 7) { return; }
dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))
dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))
dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
dx[55] = 0.0f ; // 0
dx[56] = 0.0f ; // 0
dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))
dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))
dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
};
auto write_sh_dy = [&]() {
dy[0] = 0.0f ; // 0
if (C <= 1) { return; }
dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
dy[2] = 0.0f ; // 0
dy[3] = 0.0f ; // 0
if (C <= 2) { return; }
dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
dy[6] = 0.0f ; // 0
dy[7] = 0.0f ; // 0
dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
if (C <= 3) { return; }
dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
dy[12] = 0.0f ; // 0
dy[13] = 0.0f ; // 0
dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi))
dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi))
if (C <= 4) { return; }
dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
dy[20] = 0.0f ; // 0
dy[21] = 0.0f ; // 0
dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))
dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi))
dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))
if (C <= 5) { return; }
dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
dy[30] = 0.0f ; // 0
dy[31] = 0.0f ; // 0
dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))
dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))
dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))
dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))
if (C <= 6) { return; }
dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
dy[42] = 0.0f ; // 0
dy[43] = 0.0f ; // 0
dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))
dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))
dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))
dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
if (C <= 7) { return; }
dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))
dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
dy[56] = 0.0f ; // 0
dy[57] = 0.0f ; // 0
dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
};
auto write_sh_dz = [&]() {
dz[0] = 0.0f ; // 0
if (C <= 1) { return; }
dz[1] = 0.0f ; // 0
dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi))
dz[3] = 0.0f ; // 0
if (C <= 2) { return; }
dz[4] = 0.0f ; // 0
dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi))
dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi))
dz[8] = 0.0f ; // 0
if (C <= 3) { return; }
dz[9] = 0.0f ; // 0
dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi))
dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi))
dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))
dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi))
dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi))
dz[15] = 0.0f ; // 0
if (C <= 4) { return; }
dz[16] = 0.0f ; // 0
dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi))
dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))
dz[20] = 14.809976568128603f*z*z*z - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi))
dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))
dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))
dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
dz[24] = 0.0f ; // 0
if (C <= 5) { return; }
dz[25] = 0.0f ; // 0
dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))
dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))
dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))
dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))
dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))
dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))
dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))
dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))
dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
dz[35] = 0.0f ; // 0
if (C <= 6) { return; }
dz[36] = 0.0f ; // 0
dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))
dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))
dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))
dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))
dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))
dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))
dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
dz[48] = 0.0f ; // 0
if (C <= 7) { return; }
dz[49] = 0.0f ; // 0
dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
dz[63] = 0.0f ; // 0
};
write_sh_dx();
write_sh_dy();
write_sh_dz();
}
}
template <typename scalar_t>
__global__ void kernel_sh_backward(
const scalar_t * __restrict__ grad,
const scalar_t * __restrict__ inputs,
uint32_t B, uint32_t D, uint32_t C,
const scalar_t * __restrict__ dy_dx,
scalar_t * grad_inputs
) {
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
const uint32_t b = t / D;
if (b >= B) return;
const uint32_t d = t - b * D;
const uint32_t C2 = C * C;
// locate
grad += b * C2;
dy_dx += b * D * C2 + d * C2;
for (int ch = 0; ch < C2; ch++) {
grad_inputs[t] += grad[ch] * dy_dx[ch];
//printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
}
}
// inputs: [B, D], float, in [-1, 1]
// outputs: [B, C * C], float
template <typename scalar_t>
void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {
static constexpr uint32_t N_THREADS = 256;
kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);
}
template <typename scalar_t>
void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
static constexpr uint32_t N_THREADS = 256;
kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
}
void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {
CHECK_CUDA(inputs);
CHECK_CUDA(outputs);
// CHECK_CUDA(dy_dx);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(outputs);
// CHECK_CONTIGUOUS(dy_dx);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(outputs);
// CHECK_IS_FLOATING(dy_dx);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr);
}));
}
void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {
CHECK_CUDA(grad);
CHECK_CUDA(inputs);
CHECK_CUDA(dy_dx);
CHECK_CUDA(grad_inputs);
CHECK_CONTIGUOUS(grad);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(dy_dx);
CHECK_CONTIGUOUS(grad_inputs);
CHECK_IS_FLOATING(grad);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(dy_dx);
CHECK_IS_FLOATING(grad_inputs);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
}));
}
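The literal constants in `kernel_sh` are the usual real spherical-harmonic normalization factors; their symbolic forms are given in the inline comments. A quick Python spot-check of a few of them:

```python
import math

two_sqrt_pi = 2.0 * math.sqrt(math.pi)
assert abs(0.28209479177387814 - 1.0 / two_sqrt_pi) < 1e-12             # Y_0^0 = 1/(2*sqrt(pi))
assert abs(0.48860251190291987 - math.sqrt(3.0) / two_sqrt_pi) < 1e-12  # sqrt(3)/(2*sqrt(pi))
assert abs(1.0925484305920792 - math.sqrt(15.0) / two_sqrt_pi) < 1e-12  # sqrt(15)/(2*sqrt(pi))
print("SH constants match their closed forms")
```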

10
shencoder/src/shencoder.h Normal file
View File

@ -0,0 +1,10 @@
# pragma once
#include <stdint.h>
#include <torch/torch.h>
// inputs: [B, D], float, in [-1, 1]
// outputs: [B, F], float
void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);
void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);

68
test.html Normal file

File diff suppressed because one or more lines are too long