git init
commit e4b2cab164 (parent 6e909475c5)
@@ -0,0 +1,17 @@
__pycache__/
build/
*.egg-info/
*.so
*.mp4

tmp*
trial*/

data
data_utils/face_tracking/3DMM/*
data_utils/face_parsing/79999_iter.pth

pretrained
*.mp4
.DS_Store
workspace/log_ngp.txt
@@ -0,0 +1,121 @@
# Virtual Human Talking-Head Generation (Real-Time Photo-Driven Avatar)

![](/img/example.gif)

# Get Started

## Installation

Tested on Ubuntu 22.04 with PyTorch 1.12 and CUDA 11.6, or PyTorch 1.12 and CUDA 11.3.

```bash
git clone https://github.com/waityousea/xuniren.git
cd xuniren
```

### Install dependencies

```bash
# for ubuntu, portaudio is needed for pyaudio to work.
sudo apt install portaudio19-dev

pip install -r requirements.txt
# or (environment.yml pins PyTorch 1.12 with CUDA 11.3):
conda env create -f environment.yml

## install pytorch3d
# ubuntu/mac
pip install "git+https://github.com/facebookresearch/pytorch3d.git"
```

**Installing pytorch3d on Windows**

- gcc & g++ ≥ 4.9

On Windows, a gcc compiler is required; install one as needed, for example MinGW.

The following installation steps come from the official [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) instructions; pick the parts that apply to your setup.

```bash
conda create -n pytorch3d python=3.9
conda activate pytorch3d
conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
```

The CUB build-time dependency is only needed when your CUDA version is older than 11.7. If you are using conda, you can install it with:

```
conda install -c bottler nvidiacub
```

```
# Demos and examples
conda install jupyter
pip install scikit-image matplotlib imageio plotly opencv-python

# Tests/Linting
pip install black usort flake8 flake8-bugbear flake8-comprehensions
```

After applying any necessary patches, you can compile and install from the "x64 Native Tools Command Prompt for VS 2019":

```
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d
python setup.py install
```

### Build extension

By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime. However, this may be inconvenient sometimes. Therefore, we also provide `setup.py` to build each extension:

```bash
# install all extension modules
# note: this module must be installed.
# on Windows, run this from the "x64 Native Tools Command Prompt for VS 2019".
bash scripts/install_ext.sh
```
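For reference, the runtime path goes through PyTorch's JIT extension builder. A minimal sketch of what such a runtime build looks like (the module and source names below are illustrative placeholders, not this repository's actual extension files):

```python
# Minimal sketch of a runtime (JIT) C++/CUDA extension build with PyTorch.
# 'example_ext' and its source path are hypothetical placeholders.
from torch.utils.cpp_extension import load

_backend = load(
    name='example_ext',               # name of the compiled module
    sources=['ext/example_ext.cu'],   # C++/CUDA sources to compile
    extra_cuda_cflags=['-O3'],
    verbose=True,                     # print the compiler commands
)
# the returned module exposes whatever functions the sources bind
```

Building through `scripts/install_ext.sh` instead compiles each extension once up front, so the first run does not pay the JIT compilation cost.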

### **Start (standalone)**

Once the environment is set up, start the virtual human generator:

```bash
python app.py
```

### **Start (integrated with Fay, tested on Ubuntu 20)**

Once the environment is set up, start the Fay integration script:

```bash
python fay_connect.py
```

![](img/weplay.png)

Scan the QR code to support the open-source development work; use your payment receipt number to join the QQ discussion group.

Input and output details of the interface: [WebSocket.md](https://github.com/waityousea/xuniren/blob/main/WebSocket.md)
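To make the interface concrete, here is a minimal client sketch. It assumes the third-party `websocket-client` package and the default `ws://127.0.0.1:8800/dighuman` endpoint that `server.py` registers; the payload handling follows the JSON that its `send_information` function produces:

```python
# Minimal sketch of a /dighuman WebSocket client (pip install websocket-client).
# Endpoint and payload format follow server.py in this commit.
import base64
import json

from websocket import create_connection

ws = create_connection('ws://127.0.0.1:8800/dighuman')
ws.send('你好')  # text for the avatar to speak

reply = json.loads(ws.recv())
# server.py replies with {'video': 'data:video/mp4;base64,<...>'}
b64_data = reply['video'].split('base64,', 1)[1]
with open('reply.mp4', 'wb') as f:
    f.write(base64.b64decode(b64_data))
ws.close()
```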

Core files for virtual human generation:

```
## note: the core files must be trained separately
.
├── data
│   ├── kf.json
│   └── pretrained
│       └── ngp_kg.pth

```

### Inference Speed

On a desktop RTX A4000 or a laptop RTX 3080 Ti (16 GB VRAM), video inference runs at 35-43 frames per second. If the video plays at 25 fps, one second of wall-clock time therefore renders roughly 1.5 seconds of video.
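As a quick sanity check on that figure, a back-of-the-envelope calculation using the midpoint of the reported range:

```python
# Real-time factor: seconds of rendered video per second of wall-clock time.
inference_fps = (35 + 43) / 2  # midpoint of the reported 35-43 fps range
video_fps = 25
print(inference_fps / video_fps)  # -> 1.56, i.e. roughly 1.5 s of video per second
```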

# Acknowledgement

- The data pre-processing part is adapted from [AD-NeRF](https://github.com/YudongGuo/AD-NeRF).
- The NeRF framework is based on [torch-ngp](https://github.com/ashawkey/torch-ngp).
- The algorithm core comes from [RAD-NeRF](https://github.com/ashawkey/RAD-NeRF).
- Usage example: [Fay](https://github.com/TheRamU/Fay).

For academic exchange, email: waityousea@126.com
@@ -0,0 +1,251 @@
# server.py
from flask import Flask, request, jsonify
from flask_sockets import Sockets
import base64
import time
import json
import gevent
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
from tools import audio_pre_process, video_pre_process, generate_video, audio_process
import os
import re
import numpy as np

import argparse
from nerf_triplane.provider import NeRFDataset_Test
from nerf_triplane.utils import *
from nerf_triplane.network import NeRFNetwork
from nerfreal import NeRFReal

import shutil
import asyncio
import edge_tts

app = Flask(__name__)
sockets = Sockets(app)
video_list = []
global nerfreal


async def main(voicename: str, text: str, render):
    communicate = edge_tts.Communicate(text, voicename)

    # stream TTS audio chunks straight into the renderer
    #with open(OUTPUT_FILE, "wb") as file:
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            render.push_audio(chunk["data"])
            #file.write(chunk["data"])
        elif chunk["type"] == "WordBoundary":
            pass


def send_information(path, ws):

    print('start sending video data!')
    #path = video_list[0]
    with open(path, 'rb') as f:
        video_data = base64.b64encode(f.read()).decode()

    data = {
        'video': 'data:video/mp4;base64,%s' % video_data,
    }
    json_data = json.dumps(data)

    ws.send(json_data)


def txt_to_audio(text_):
    audio_list = []
    #audio_path = 'data/audio/aud_0.wav'
    voicename = "zh-CN-YunxiaNeural"
    text = text_
    asyncio.get_event_loop().run_until_complete(main(voicename, text, nerfreal))
    #audio_process(audio_path)


@sockets.route('/dighuman')
def echo_socket(ws):
    # get the WebSocket object
    #ws = request.environ.get('wsgi.websocket')
    # if it could not be obtained, return an error message
    if not ws:
        print('connection not established!')
        return 'Please use WebSocket'
    # otherwise, receive and send messages in a loop
    else:
        print('connection established!')
        while True:
            message = ws.receive()

            # receive() returns None when the client disconnects
            if message is None or len(message) == 0:
                return 'input message is empty'
            else:
                txt_to_audio(message)
                audio_path = 'data/audio/aud_0.wav'
                audio_path_eo = 'data/audio/aud_0_eo.npy'
                video_path = 'data/video/results/ngp_0.mp4'
                output_path = 'data/video/results/output_0.mp4'
                generate_video(audio_path, audio_path_eo, video_path, output_path)
                video_list.append(output_path)
                send_information(output_path, ws)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")

    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")

    parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
    parser.add_argument('--workspace', type=str, default='data/video')
    parser.add_argument('--seed', type=int, default=0)

    ### training options
    parser.add_argument('--ckpt', type=str, default='data/pretrained/ngp_kf.pth')

    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

    ### loss set
    parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
    parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
    parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
    parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
    parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")

    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")

    parser.add_argument('--bg_img', type=str, default='white', help="background image")
    parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
    parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
    parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
    parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")

    parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

    parser.add_argument('--init_lips', action='store_true', help="init lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
    parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in an exponential decay way...")

    parser.add_argument('--torso', action='store_true', help="fix head and train torso")
    parser.add_argument('--head_ckpt', type=str, default='', help="head model")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")
    parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    ### else
    parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
    parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
    parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")

    parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
    parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")

    parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")

    parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
    parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
    parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")

    parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
    parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

    # asr
    parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
    parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
    parser.add_argument('--asr_play', action='store_true', help="play out the audio")

    #parser.add_argument('--asr_model', type=str, default='deepspeech')
    parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
    # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')

    parser.add_argument('--asr_save_feats', action='store_true')
    # audio FPS
    parser.add_argument('--fps', type=int, default=50)
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=50)
    parser.add_argument('-r', type=int, default=10)

    opt = parser.parse_args()

    # assert test mode
    opt.test = True
    opt.test_train = False
    #opt.train_camera = True
    # explicit smoothing
    opt.smooth_path = True
    opt.smooth_eye = True
    opt.smooth_lips = True

    assert opt.pose != '', 'Must provide a pose source'

    # if opt.O:
    opt.fp16 = True
    opt.exp_eye = True

    opt.cuda_ray = True
    opt.torso = True
    # assert opt.cuda_ray, "Only support CUDA ray mode."
    opt.asr = True

    if opt.patch_size > 1:
        # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
        assert opt.num_rays % (opt.patch_size ** 2) == 0, "num_rays should be divisible by patch_size ** 2."
    seed_everything(opt.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = NeRFNetwork(opt)

    criterion = torch.nn.MSELoss(reduction='none')
    metrics = []  # use no metric in GUI for faster initialization...
    print(model)
    trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

    test_loader = NeRFDataset_Test(opt, device=device).dataloader()
    model.aud_features = test_loader._data.auds
    model.eye_areas = test_loader._data.eye_area

    # we still need test_loader to provide audio features for testing.
    nerfreal = NeRFReal(opt, trainer, test_loader)
    txt_to_audio('我是中国人,我来自北京')
    nerfreal.render()

    #############################################################################

    server = pywsgi.WSGIServer(('127.0.0.1', 8800), app, handler_class=WebSocketHandler)
    server.serve_forever()
@@ -0,0 +1,464 @@
import time
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor

import pyaudio
import soundfile as sf
import resampy

from queue import Queue
#from collections import deque
from threading import Thread, Event
from io import BytesIO


def _read_frame(stream, exit_event, queue, chunk):

    while True:
        if exit_event.is_set():
            print(f'[INFO] read frame thread ends')
            break
        frame = stream.read(chunk, exception_on_overflow=False)
        frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767  # [chunk]
        queue.put(frame)


def _play_frame(stream, exit_event, queue, chunk):

    while True:
        if exit_event.is_set():
            print(f'[INFO] play frame thread ends')
            break
        frame = queue.get()
        frame = (frame * 32767).astype(np.int16).tobytes()
        stream.write(frame, chunk)


class ASR:
    def __init__(self, opt):

        self.opt = opt

        self.play = opt.asr_play  # false

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fps = opt.fps  # 20 ms per frame
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps  # 320 samples per chunk (20ms * 16000 / 1000)
        self.mode = 'live' if opt.asr_wav == '' else 'file'

        if 'esperanto' in self.opt.asr_model:
            self.audio_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
            self.audio_dim = 29
        else:
            self.audio_dim = 32

        # prepare context cache
        # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
        self.context_size = opt.m
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r
        self.text = '[START]\n'
        self.terminated = False
        self.frames = []
        self.inwarm = False

        # pad left frames
        if self.stride_left_size > 0:
            self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)

        self.exit_event = Event()
        #self.audio_instance = pyaudio.PyAudio()  # not needed

        # create input stream
        if self.mode == 'file':  # file mode
            self.file_stream = self.create_file_stream()
        else:
            # live mode: frames are fed in through push_audio() instead of a microphone
            self.queue = Queue()
            self.input_stream = BytesIO()
            self.output_queue = Queue()
            # start a background process to read frames
            #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
            #self.queue = Queue()
            #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))

        # play out the audio too...?
        if self.play:
            self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
            self.output_queue = Queue()
            self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))

        # current location of audio
        self.idx = 0

        # create wav2vec model
        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
        self.processor = AutoProcessor.from_pretrained(opt.asr_model)
        self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

        # prepare to save logits
        if self.opt.asr_save_feats:
            self.all_feats = []

        # the extracted features
        # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
        self.feat_buffer_size = 4
        self.feat_buffer_idx = 0
        self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)

        # TODO: hard coded 16 and 8 window size...
        self.front = self.feat_buffer_size * self.context_size - 8  # fake padding
        self.tail = 8
        # attention window...
        self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4  # 4 zero padding...

        # warm up steps needed: mid + right + window_size + attention_size
        self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3

        self.listening = False
        self.playing = False

    def listen(self):
        # start
        if self.mode == 'live' and not self.listening:
            print(f'[INFO] starting read frame thread...')
            self.process_read_frame.start()
            self.listening = True

        if self.play and not self.playing:
            print(f'[INFO] starting play frame thread...')
            self.process_play_frame.start()
            self.playing = True

    def stop(self):

        self.exit_event.set()

        if self.play:
            self.output_stream.stop_stream()
            self.output_stream.close()
            if self.playing:
                self.process_play_frame.join()
                self.playing = False

        if self.mode == 'live':
            #self.input_stream.stop_stream() todo
            self.input_stream.close()
            if self.listening:
                self.process_read_frame.join()
                self.listening = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):

        self.stop()

        if self.mode == 'live':
            # live mode: also print the result text.
            self.text += '\n[END]'
            print(self.text)

    def get_next_feat(self):
        # return a [1/8, 16] window, for the next input to the nerf side.

        while len(self.att_feats) < 8:
            # [------f+++t-----]
            if self.front < self.tail:
                feat = self.feat_queue[self.front:self.tail]
            # [++t-----------f+]
            else:
                feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)

            self.front = (self.front + 2) % self.feat_queue.shape[0]
            self.tail = (self.tail + 2) % self.feat_queue.shape[0]

            # print(self.front, self.tail, feat.shape)

            self.att_feats.append(feat.permute(1, 0))

        att_feat = torch.stack(self.att_feats, dim=0)  # [8, 44, 16]

        # discard old
        self.att_feats = self.att_feats[1:]

        return att_feat

    def run_step(self):

        if self.terminated:
            return

        # get a frame of audio
        frame = self.get_audio_frame()

        # the last frame
        if frame is None:
            # terminate, but always run the network for the remaining frames
            self.terminated = True
        else:
            self.frames.append(frame)
            # put to output
            #if self.play:
            self.output_queue.put(frame)
            # context not enough, do not run network.
            if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
                return

        inputs = np.concatenate(self.frames)  # [N * chunk]

        # discard the old part to save memory
        if not self.terminated:
            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

        print(f'[INFO] frame_to_text... ')
        logits, labels, text = self.frame_to_text(inputs)
        feats = logits  # better lip-sync than labels

        # save feats
        if self.opt.asr_save_feats:
            self.all_feats.append(feats)

        # record the feats efficiently.. (no concat, constant memory)
        start = self.feat_buffer_idx * self.context_size
        end = start + feats.shape[0]
        self.feat_queue[start:end] = feats
        self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size

        # very naive, just concat the text output.
        if text != '':
            self.text = self.text + ' ' + text

        # will only run once at termination
        if self.terminated:
            self.text += '\n[END]'
            print(self.text)
            if self.opt.asr_save_feats:
                print(f'[INFO] save all feats for training purpose... ')
                feats = torch.cat(self.all_feats, dim=0)  # [N, C]
                # print('[INFO] before unfold', feats.shape)
                window_size = 16
                padding = window_size // 2
                feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous()  # [C, M]
                feats = feats.view(1, self.audio_dim, -1, 1)  # [1, C, M, 1]
                unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1))  # [1, C * window_size, M / 2 + 1]
                unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous()  # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
                # print('[INFO] after unfold', unfold_feats.shape)
                # save to a npy file
                if 'esperanto' in self.opt.asr_model:
                    output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
                else:
                    output_path = self.opt.asr_wav.replace('.wav', '.npy')
                np.save(output_path, unfold_feats.cpu().numpy())
                print(f"[INFO] saved logits to {output_path}")

    def create_file_stream(self):

        stream, sample_rate = sf.read(self.opt.asr_wav)  # [T*sample_rate,] float64
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self.sample_rate:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)

        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')

        return stream

    def create_pyaudio_stream(self):

        import pyaudio

        print(f'[INFO] creating live audio stream ...')

        audio = pyaudio.PyAudio()

        # get devices
        info = audio.get_host_api_info_by_index(0)
        n_devices = info.get('deviceCount')

        for i in range(0, n_devices):
            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
                print(f'[INFO] choose audio device {name}, id {i}')
                break

        # get stream
        stream = audio.open(input_device_index=i,
                            format=pyaudio.paInt16,
                            channels=1,
                            rate=self.sample_rate,
                            input=True,
                            frames_per_buffer=self.chunk)

        return audio, stream

    def get_audio_frame(self):

        if self.inwarm:  # warm up
            return np.zeros(self.chunk, dtype=np.float32)

        if self.mode == 'file':

            if self.idx < self.file_stream.shape[0]:
                frame = self.file_stream[self.idx: self.idx + self.chunk]
                self.idx = self.idx + self.chunk
                return frame
            else:
                return None

        else:

            frame = self.queue.get()
            print(f'[INFO] get frame {frame.shape}')

            self.idx = self.idx + self.chunk

            return frame

    def frame_to_text(self, frame):
        # frame: [N * 320], N = (context_size + 2 * stride_size)

        inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)

        with torch.no_grad():
            result = self.model(inputs.input_values.to(self.device))
            logits = result.logits  # [1, N - 1, 32]

        # cut off stride
        left = max(0, self.stride_left_size)
        right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1)  # +1 to make sure output is the same length as input.

        # do not cut right if terminated.
        if self.terminated:
            right = logits.shape[1]

        logits = logits[:, left:right]

        # print(frame.shape, inputs.input_values.shape, logits.shape)

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0].lower()

        # for esperanto
        # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])

        # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
        # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
        # print(predicted_ids[0])
        # print(transcription)

        return logits[0], predicted_ids[0], transcription  # [N,]

    def create_bytes_stream(self, byte_stream):
        #byte_stream = BytesIO(buffer)
        stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
        print(f'[INFO] tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self.sample_rate:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)

        return stream

    def push_audio(self, buffer):
        print(f'[INFO] push_audio {len(buffer)}')
        # buffer the incoming TTS bytes; an empty buffer marks the end of the stream
        self.input_stream.write(buffer)
        if len(buffer) <= 0:
            self.input_stream.seek(0)
            stream = self.create_bytes_stream(self.input_stream)
            streamlen = stream.shape[0]
            idx = 0
            while streamlen >= self.chunk:
                self.queue.put(stream[idx:idx + self.chunk])
                streamlen -= self.chunk
                idx += self.chunk
            if streamlen > 0:
                self.queue.put(stream[idx:])

    def get_audio_out(self):
        return self.output_queue.get()

    def run(self):

        self.listen()

        while not self.terminated:
            self.run_step()

    def clear_queue(self):
        # clear the queue, to reduce potential latency...
        print(f'[INFO] clear queue')
        if self.mode == 'live':
            self.queue.queue.clear()
        if self.play:
            self.output_queue.queue.clear()

    def warm_up(self):

        #self.listen()

        self.inwarm = True
        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
        t = time.time()
        for _ in range(self.warm_up_steps):
            self.run_step()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t = time.time() - t
        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
        self.inwarm = False

        #self.clear_queue()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--wav', type=str, default='')
    parser.add_argument('--play', action='store_true', help="play out the audio")

    parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
    # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')

    parser.add_argument('--save_feats', action='store_true')
    # audio FPS
    parser.add_argument('--fps', type=int, default=50)
    # sliding window left-middle-right length.
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=50)
    parser.add_argument('-r', type=int, default=10)

    opt = parser.parse_args()

    # fix
    opt.asr_wav = opt.wav
    opt.asr_play = opt.play
    opt.asr_model = opt.model
    opt.asr_save_feats = opt.save_feats

    if 'deepspeech' in opt.asr_model:
        raise ValueError("DeepSpeech features should not use this code to extract...")

    with ASR(opt) as asr:
        asr.run()
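The latency comment in `ASR.__init__` ("each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms") can be made concrete with the default window sizes (`-l 10 -m 50 -r 10`, 20 ms frames at 16 kHz):

```python
# Segment length and latency implied by the ASR defaults (-l 10, -m 50, -r 10).
frame_ms = 20                                # 16000 Hz / 50 fps = 320 samples = 20 ms
stride_left, ctx, stride_right = 10, 50, 10  # default sliding-window sizes
segment_ms = (stride_left + ctx + stride_right) * frame_ms  # -> 1400 ms per segment
latency_ms = (ctx + stride_right) * frame_ms                # -> 1200 ms latency
print(segment_ms, latency_ms)
```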
(Two binary files added but not shown by the viewer — one an image of 182 KiB — and one oversized file diff suppressed.)
@@ -0,0 +1,20 @@
# Routines for DeepSpeech features processing
Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for the [VOCA](https://github.com/TimoBolkart/voca) model.

## Installation

```
pip3 install -r requirements.txt
```

## Usage

Generate wav files:
```
python3 extract_wav.py --in-video=<your_data_dir>
```

Generate files with DeepSpeech features:
```
python3 extract_ds_features.py --input=<your_data_dir>
```
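For programmatic use, the same extraction can be driven directly from Python; a sketch based on the function signatures defined below (the input file name is a placeholder, and `get_deepspeech_model_file()` downloads the frozen model on first use):

```python
# Sketch: calling the DeepSpeech feature extractor from Python.
from deepspeech_features import conv_audios_to_deepspeech
from deepspeech_store import get_deepspeech_model_file

conv_audios_to_deepspeech(
    audios=["clip.wav"],     # hypothetical input wav
    out_files=["clip.npy"],
    num_frames_info=[None],  # None lets the routine infer the frame count
    deepspeech_pb_path=get_deepspeech_model_file())
```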
@@ -0,0 +1,275 @@
"""
DeepSpeech features processing routines.
NB: Based on VOCA code. See the corresponding license restrictions.
"""

__all__ = ['conv_audios_to_deepspeech']

import numpy as np
import warnings
import resampy
from scipy.io import wavfile
from python_speech_features import mfcc
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


def conv_audios_to_deepspeech(audios,
                              out_files,
                              num_frames_info,
                              deepspeech_pb_path,
                              audio_window_size=1,
                              audio_window_stride=1):
    """
    Convert list of audio files into files with DeepSpeech features.

    Parameters
    ----------
    audios : list of str or list of None
        Paths to input audio files.
    out_files : list of str
        Paths to output files with DeepSpeech features.
    num_frames_info : list of int
        List of numbers of frames.
    deepspeech_pb_path : str
        Path to DeepSpeech 0.1.0 frozen model.
    audio_window_size : int, default 1
        Audio window size.
    audio_window_stride : int, default 1
        Audio window stride.
    """
    # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
    graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
        deepspeech_pb_path)

    with tf.compat.v1.Session(graph=graph) as sess:
        for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
            print(audio_file_path)
            print(out_file_path)
            audio_sample_rate, audio = wavfile.read(audio_file_path)
            if audio.ndim != 1:
                warnings.warn(
                    "Audio has multiple channels, the first channel is used")
                audio = audio[:, 0]
            ds_features = pure_conv_audio_to_deepspeech(
                audio=audio,
                audio_sample_rate=audio_sample_rate,
                audio_window_size=audio_window_size,
                audio_window_stride=audio_window_stride,
                num_frames=num_frames,
                net_fn=lambda x: sess.run(
                    logits_ph,
                    feed_dict={
                        input_node_ph: x[np.newaxis, ...],
                        input_lengths_ph: [x.shape[0]]}))

            net_output = ds_features.reshape(-1, 29)
            win_size = 16
            zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
            net_output = np.concatenate(
                (zero_pad, net_output, zero_pad), axis=0)
            windows = []
            for window_index in range(0, net_output.shape[0] - win_size, 2):
                windows.append(
                    net_output[window_index:window_index + win_size])
            print(np.array(windows).shape)
            np.save(out_file_path, np.array(windows))


def prepare_deepspeech_net(deepspeech_pb_path):
    """
    Load and prepare DeepSpeech network.

    Parameters
    ----------
    deepspeech_pb_path : str
        Path to DeepSpeech 0.1.0 frozen model.

    Returns
    -------
    graph : obj
        TensorFlow graph.
    logits_ph : obj
        TensorFlow placeholder for `logits`.
    input_node_ph : obj
        TensorFlow placeholder for `input_node`.
    input_lengths_ph : obj
        TensorFlow placeholder for `input_lengths`.
    """
    # Load graph and place_holders:
    with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    graph = tf.compat.v1.get_default_graph()
    tf.import_graph_def(graph_def, name="deepspeech")
    logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
    input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
    input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")

    return graph, logits_ph, input_node_ph, input_lengths_ph


def pure_conv_audio_to_deepspeech(audio,
                                  audio_sample_rate,
                                  audio_window_size,
                                  audio_window_stride,
                                  num_frames,
                                  net_fn):
    """
    Core routine for converting audio into DeepSpeech features.

    Parameters
    ----------
    audio : np.array
        Audio data.
    audio_sample_rate : int
        Audio sample rate.
    audio_window_size : int
        Audio window size.
    audio_window_stride : int
        Audio window stride.
    num_frames : int or None
        Numbers of frames.
    net_fn : func
        Function for DeepSpeech model call.

    Returns
    -------
    np.array
        DeepSpeech features.
    """
    target_sample_rate = 16000
    if audio_sample_rate != target_sample_rate:
        resampled_audio = resampy.resample(
            x=audio.astype(np.float32),  # np.float was removed in NumPy 1.24
            sr_orig=audio_sample_rate,
            sr_new=target_sample_rate)
    else:
        resampled_audio = audio.astype(np.float32)
    input_vector = conv_audio_to_deepspeech_input_vector(
        audio=resampled_audio.astype(np.int16),
        sample_rate=target_sample_rate,
        num_cepstrum=26,
        num_context=9)

    network_output = net_fn(input_vector)
    # print(network_output.shape)

    deepspeech_fps = 50
    video_fps = 50  # Change this option if video fps is different
    audio_len_s = float(audio.shape[0]) / audio_sample_rate
    if num_frames is None:
        num_frames = int(round(audio_len_s * video_fps))
    else:
        video_fps = num_frames / audio_len_s
    network_output = interpolate_features(
        features=network_output[:, 0],
        input_rate=deepspeech_fps,
        output_rate=video_fps,
        output_len=num_frames)

    # Make windows:
    zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
    network_output = np.concatenate(
        (zero_pad, network_output, zero_pad), axis=0)
    windows = []
    for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
        windows.append(
            network_output[window_index:window_index + audio_window_size])

    return np.array(windows)


def conv_audio_to_deepspeech_input_vector(audio,
                                          sample_rate,
                                          num_cepstrum,
                                          num_context):
    """
    Convert audio raw data into DeepSpeech input vector.

    Parameters
    ----------
    audio : np.array
        Audio data.
    sample_rate : int
        Audio sample rate.
    num_cepstrum : int
        Number of cepstrum.
    num_context : int
        Number of context.

    Returns
    -------
    np.array
        DeepSpeech input vector.
    """
    # Get mfcc coefficients:
    features = mfcc(
        signal=audio,
        samplerate=sample_rate,
        numcep=num_cepstrum)

    # We only keep every second feature (BiRNN stride = 2):
    features = features[::2]

    # One stride per time step in the input:
    num_strides = len(features)

    # Add empty initial and final contexts:
    empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
    features = np.concatenate((empty_context, features, empty_context))

    # Create a view into the array with overlapping strides of size
    # numcontext (past) + 1 (present) + numcontext (future):
    window_size = 2 * num_context + 1
    train_inputs = np.lib.stride_tricks.as_strided(
        features,
        shape=(num_strides, window_size, num_cepstrum),
        strides=(features.strides[0],
                 features.strides[0], features.strides[1]),
        writeable=False)

    # Flatten the second and third dimensions:
    train_inputs = np.reshape(train_inputs, [num_strides, -1])

    train_inputs = np.copy(train_inputs)
    train_inputs = (train_inputs - np.mean(train_inputs)) / \
        np.std(train_inputs)

    return train_inputs


def interpolate_features(features,
                         input_rate,
                         output_rate,
                         output_len):
    """
    Interpolate DeepSpeech features.

    Parameters
    ----------
    features : np.array
        DeepSpeech features.
    input_rate : int
        Input rate (FPS).
    output_rate : int
        Output rate (FPS).
    output_len : int
        Output data length.

    Returns
    -------
    np.array
        Interpolated data.
    """
    input_len = features.shape[0]
    num_features = features.shape[1]
    input_timestamps = np.arange(input_len) / float(input_rate)
    output_timestamps = np.arange(output_len) / float(output_rate)
    output_features = np.zeros((output_len, num_features))
    for feature_idx in range(num_features):
        output_features[:, feature_idx] = np.interp(
            x=output_timestamps,
            xp=input_timestamps,
            fp=features[:, feature_idx])
    return output_features
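As a shape sanity check for `conv_audio_to_deepspeech_input_vector` above (a sketch assuming one second of 16 kHz audio and the default 10 ms MFCC hop of `python_speech_features`):

```python
# Shape walk-through for the DeepSpeech input vector, assuming 1 s @ 16 kHz.
# mfcc() with its default 10 ms hop yields ~100 frames of 26 cepstral
# coefficients; keeping every second frame (BiRNN stride = 2) leaves ~50.
num_strides = 100 // 2             # time steps after the stride-2 subsampling
num_cepstrum, num_context = 26, 9
window_size = 2 * num_context + 1  # 9 past + 1 present + 9 future = 19
print(num_strides, window_size * num_cepstrum)  # -> 50 494 features per step
```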
@@ -0,0 +1,172 @@
"""
Routines for loading DeepSpeech model.
"""

__all__ = ['get_deepspeech_model_file']

import os
import zipfile
import logging
import hashlib


deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'


def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
    """
    Return location of the pretrained model on the local file system. This function will download from the online
    model zoo when the model cannot be found or has a hash mismatch. The root directory will be created if it
    doesn't exist.

    Parameters
    ----------
    local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
    file_name = "deepspeech-0_1_0-b90017e8.pb"
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if _check_sha1(file_path, sha1_hash):
            return file_path
        else:
            logging.warning("Mismatch in the content of model file detected. Downloading again.")
    else:
        logging.info("Model file not found. Downloading to {}.".format(file_path))

    if not os.path.exists(local_model_store_dir_path):
        os.makedirs(local_model_store_dir_path)

    zip_file_path = file_path + ".zip"
    _download(
        url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
            repo_url=deepspeech_features_repo_url,
            repo_release_tag="v0.0.1",
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    if _check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError("Downloaded file has different hash. Please try again.")


def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """
    Download a given URL.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in the URL.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non-200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    import warnings
    try:
        import requests
    except ImportError:
        class requests_failed_to_import(object):
            pass
        requests = requests_failed_to_import

    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid
        assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised.")

    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        while retries + 1 > 0:
            # Disable pylint too-broad Exception
            # pylint: disable=W0703
            try:
                print("Downloading {} from {}...".format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url {}".format(url))
                with open(fname, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                if sha1_hash and not _check_sha1(fname, sha1_hash):
                    raise UserWarning("File {} is downloaded but the content hash does not match."
                                      " The repo may be outdated or download may be incomplete. "
                                      "If the `repo_url` is overridden, consider switching to "
                                      "the default repo.".format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    print("download failed, retrying, {} attempt{} left"
                          .format(retries, "s" if retries > 1 else ""))

    return fname


def _check_sha1(filename, sha1_hash):
    """
    Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash
@@ -0,0 +1,132 @@
"""
Script for extracting DeepSpeech features from audio file.
"""

import os
import argparse
import numpy as np
import pandas as pd
from deepspeech_store import get_deepspeech_model_file
from deepspeech_features import conv_audios_to_deepspeech


def parse_args():
    """
    Create python script parameters.

    Returns
    -------
    ArgumentParser
        Resulted args.
    """
    parser = argparse.ArgumentParser(
        description="Extract DeepSpeech features from audio file",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="path to input audio file or directory")
    parser.add_argument(
        "--output",
        type=str,
        help="path to output file with DeepSpeech features")
    parser.add_argument(
        "--deepspeech",
        type=str,
        help="path to DeepSpeech 0.1.0 frozen model")
    parser.add_argument(
        "--metainfo",
        type=str,
        help="path to file with meta-information")

    args = parser.parse_args()
    return args


def extract_features(in_audios,
                     out_files,
                     deepspeech_pb_path,
                     metainfo_file_path=None):
    """
    Extract DeepSpeech features from audio files.

    Parameters
    ----------
    in_audios : list of str
        Paths to input audio files.
    out_files : list of str
        Paths to output files with DeepSpeech features.
    deepspeech_pb_path : str
        Path to DeepSpeech 0.1.0 frozen model.
    metainfo_file_path : str, default None
        Path to file with meta-information.
    """
    #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
    if metainfo_file_path is None:
        num_frames_info = [None] * len(in_audios)
    else:
        train_df = pd.read_csv(
            metainfo_file_path,
            sep="\t",
            index_col=False,
            dtype={"Id": int, "File": str, "Count": int})  # np.int/np.unicode were removed from NumPy
        num_frames_info = train_df["Count"].values
        assert (len(num_frames_info) == len(in_audios))

    for i, in_audio in enumerate(in_audios):
        if not out_files[i]:
            file_stem, _ = os.path.splitext(in_audio)
            out_files[i] = file_stem + ".npy"
        #print(out_files[i])
    conv_audios_to_deepspeech(
        audios=in_audios,
        out_files=out_files,
        num_frames_info=num_frames_info,
        deepspeech_pb_path=deepspeech_pb_path)


def main():
    """
    Main body of script.
    """
    args = parse_args()
    in_audio = os.path.expanduser(args.input)
    if not os.path.exists(in_audio):
        raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
    deepspeech_pb_path = args.deepspeech
    #add: force the default model location; the model is downloaded below if missing
    deepspeech_pb_path = True
    args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
    #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
    if deepspeech_pb_path is None:
        deepspeech_pb_path = ""
    if deepspeech_pb_path:
        deepspeech_pb_path = os.path.expanduser(args.deepspeech)
    if not os.path.exists(deepspeech_pb_path):
        deepspeech_pb_path = get_deepspeech_model_file()
    if os.path.isfile(in_audio):
        extract_features(
            in_audios=[in_audio],
            out_files=[args.output],
            deepspeech_pb_path=deepspeech_pb_path,
            metainfo_file_path=args.metainfo)
    else:
        audio_file_paths = []
        for file_name in os.listdir(in_audio):
            if not os.path.isfile(os.path.join(in_audio, file_name)):
                continue
            _, file_ext = os.path.splitext(file_name)
            if file_ext.lower() == ".wav":
                audio_file_path = os.path.join(in_audio, file_name)
                audio_file_paths.append(audio_file_path)
        audio_file_paths = sorted(audio_file_paths)
        out_file_paths = [""] * len(audio_file_paths)
        extract_features(
            in_audios=audio_file_paths,
            out_files=out_file_paths,
            deepspeech_pb_path=deepspeech_pb_path,
            metainfo_file_path=args.metainfo)


if __name__ == "__main__":
    main()
@@ -0,0 +1,87 @@
"""
Script for extracting audio (16-bit, mono, 16000 Hz) from video file.
"""

import os
import argparse
import subprocess


def parse_args():
    """
    Create python script parameters.

    Returns
    -------
    ArgumentParser
        Resulted args.
    """
    parser = argparse.ArgumentParser(
        description="Extract audio from video file",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--in-video",
        type=str,
        required=True,
        help="path to input video file or directory")
    parser.add_argument(
        "--out-audio",
        type=str,
        help="path to output audio file")

    args = parser.parse_args()
    return args


def extract_audio(in_video,
                  out_audio):
    """
    Actually extract audio from a video file.

    Parameters
    ----------
    in_video : str
        Path to input video file.
    out_audio : str
        Path to output audio file.
    """
    if not out_audio:
        file_stem, _ = os.path.splitext(in_video)
        out_audio = file_stem + ".wav"
    # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
    # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
    # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
    command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
    subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)


def main():
    """
    Main body of script.
    """
    args = parse_args()
    in_video = os.path.expanduser(args.in_video)
    if not os.path.exists(in_video):
        raise Exception("Input file/directory doesn't exist: {}".format(in_video))
    if os.path.isfile(in_video):
        extract_audio(
            in_video=in_video,
            out_audio=args.out_audio)
    else:
        video_file_paths = []
        for file_name in os.listdir(in_video):
            if not os.path.isfile(os.path.join(in_video, file_name)):
                continue
            _, file_ext = os.path.splitext(file_name)
            if file_ext.lower() in (".mp4", ".mkv", ".avi"):
                video_file_path = os.path.join(in_video, file_name)
                video_file_paths.append(video_file_path)
        video_file_paths = sorted(video_file_paths)
        for video_file_path in video_file_paths:
            extract_audio(
                in_video=video_file_path,
                out_audio="")


if __name__ == "__main__":
    main()
@@ -0,0 +1,11 @@
import numpy as np

net_output = np.load('french.ds.npy').reshape(-1, 29)
win_size = 16
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
windows = []
for window_index in range(0, net_output.shape[0] - win_size, 2):
    windows.append(net_output[window_index:window_index + win_size])
print(np.array(windows).shape)
np.save('aud_french.npy', np.array(windows))
@ -0,0 +1,23 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import os.path as osp
import time
import sys
import logging

import torch.distributed as dist


def setup_logger(logpth):
    logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
    logfile = osp.join(logpth, logfile)
    FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
    log_level = logging.INFO
    # only rank 0 logs at INFO level in distributed runs
    if dist.is_initialized() and dist.get_rank() != 0:
        log_level = logging.ERROR
    logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
    logging.root.addHandler(logging.StreamHandler())
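# Minimal usage sketch (the log directory is assumed to exist):
#   setup_logger('./res')
#   logging.info('training started')   # written to ./res/BiSeNet-<timestamp>.log and echoed to stderr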
@ -0,0 +1,285 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from resnet import Resnet18
# from modules.bn import InPlaceABNSync as BatchNorm2d


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                              out_chan,
                              kernel_size=ks,
                              stride=stride,
                              padding=padding,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)


class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        H0, W0 = x.size()[2:]
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


### This is not used, since I replace this with the resnet feature with the same size
class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan,
                               out_chan // 4,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4,
                               out_chan,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class BiSeNet(nn.Module):
    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)

        # return feat_out, feat_out16, feat_out32
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
    net = BiSeNet(19)
    net.cuda()
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).cuda()
    # forward() returns only the fused output (the auxiliary heads are commented out above)
    out = net(in_ten)
    print(out.shape)

    net.get_params()
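# Illustrative post-processing sketch: collapse the (N, 19, H, W) logits into a
# per-pixel label map, as the parsing script later in this commit does.
#   parsing = net(in_ten).argmax(dim=1)   # (N, H, W) int64 class indices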
@ -0,0 +1,109 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

# from modules.bn import InPlaceABNSync as BatchNorm2d

resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum - 1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        # load ImageNet weights, skipping the fully-connected classifier
        state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k:
                continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


if __name__ == "__main__":
    net = Resnet18()
    x = torch.randn(16, 3, 224, 224)
    out = net(x)
    print(out[0].size())
    print(out[1].size())
    print(out[2].size())
    net.get_params()
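# Shape sketch for the test input above: feat8 is (16, 128, 28, 28),
# feat16 is (16, 256, 14, 14) and feat32 is (16, 512, 7, 7) for 224x224 inputs
# (strides 8, 16 and 32 respectively).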
@ -0,0 +1,98 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import numpy as np
from model import BiSeNet

import torch

import os
import os.path as osp

from PIL import Image
import torchvision.transforms as transforms
import cv2
from pathlib import Path
import configargparse
import tqdm

# import ttach as tta

def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
    im = np.array(im)
    vis_im = im.copy().astype(np.uint8)
    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(
        vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
    vis_parsing_anno_color = np.zeros(
        (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255])  # white background

    num_of_class = np.max(vis_parsing_anno)
    # face-region classes -> red
    for pi in range(1, 14):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
    # neck-region classes -> green
    for pi in range(14, 16):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
    # clothes -> blue
    for pi in range(16, 17):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
    # hair and any remaining classes -> red
    for pi in range(17, num_of_class + 1):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])

    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
    index = np.where(vis_parsing_anno == num_of_class - 1)  # unused
    vis_im = cv2.resize(vis_parsing_anno_color, img_size,
                        interpolation=cv2.INTER_NEAREST)
    if save_im:
        cv2.imwrite(save_path, vis_im)


def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):

    Path(respth).mkdir(parents=True, exist_ok=True)

    print(f'[INFO] loading model...')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.load_state_dict(torch.load(cp))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    image_paths = os.listdir(dspth)

    with torch.no_grad():
        for image_path in tqdm.tqdm(image_paths):
            if image_path.endswith('.jpg') or image_path.endswith('.png'):
                img = Image.open(osp.join(dspth, image_path))
                ori_size = img.size
                image = img.resize((512, 512), Image.BILINEAR)
                image = image.convert("RGB")
                img = to_tensor(image)

                # test-time augmentation.
                inputs = torch.unsqueeze(img, 0)  # [1, 3, 512, 512]
                outputs = net(inputs.cuda())
                parsing = outputs.mean(0).cpu().numpy().argmax(0)

                # frames are assumed to be numerically named, e.g. 0001.jpg -> 1.png
                image_path = int(image_path[:-4])
                image_path = str(image_path) + '.png'

                vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)


if __name__ == "__main__":
    parser = configargparse.ArgumentParser()
    parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
    parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
    parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
    args = parser.parse_args()
    evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
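# Hypothetical invocation (the script name is assumed; the default checkpoint path
# matches the repository layout):
#   python test.py --imgpath data/obama/ori_imgs --respath data/obama/parsing \
#       --modelpath data_utils/face_parsing/79999_iter.pth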
@ -0,0 +1,39 @@
import numpy as np
from scipy.io import loadmat

# Convert the raw Basel Face Model to the reduced topology used by the tracker.
original_BFM = loadmat("3DMM/01_MorphableModel.mat")
sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"]

shapePC = original_BFM["shapePC"]
shapeEV = original_BFM["shapeEV"]
shapeMU = original_BFM["shapeMU"]
texPC = original_BFM["texPC"]
texEV = original_BFM["texEV"]
texMU = original_BFM["texMU"]

# 199 principal components, reshaped to (199, num_vertices, 3)
b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_shape = shapeMU.reshape(-1, 3)

b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3)
mu_tex = texMU.reshape(-1, 3)

# keep only the sub-sampled vertices and flatten back to (199, 3 * num_sub_vertices)
b_shape = b_shape[:, sub_inds, :].reshape(199, -1)
mu_shape = mu_shape[sub_inds, :].reshape(-1)
b_tex = b_tex[:, sub_inds, :].reshape(199, -1)
mu_tex = mu_tex[sub_inds, :].reshape(-1)

exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item()
np.save(
    "3DMM/3DMM_info.npy",
    {
        "mu_shape": mu_shape,
        "b_shape": b_shape,
        "sig_shape": shapeEV.reshape(-1),
        "mu_exp": exp_info["mu_exp"],
        "b_exp": exp_info["base_exp"],
        "sig_exp": exp_info["sig_exp"],
        "mu_tex": mu_tex,
        "b_tex": b_tex,
        "sig_tex": texEV.reshape(-1),
    },
)
@ -0,0 +1,16 @@
import os
import torch
import numpy as np


def load_dir(path, start, end):
    """Load landmark files <i>.lms from path and record the matching image paths."""
    lmss = []
    imgs_paths = []
    for i in range(start, end):
        if os.path.isfile(os.path.join(path, str(i) + ".lms")):
            lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32)
            lmss.append(lms)
            imgs_paths.append(os.path.join(path, str(i) + ".jpg"))
    lmss = np.stack(lmss)
    lmss = torch.as_tensor(lmss).cuda()
    return lmss, imgs_paths
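# Usage sketch (directory layout assumed: 0.jpg/0.lms, 1.jpg/1.lms, ...;
# 68-point 2D landmarks are assumed):
#   lms, img_paths = load_dir("obama/ori_imgs", 0, 11000)
#   lms.shape   # (num_found_frames, 68, 2), already on the GPU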
@ -0,0 +1,390 @@
import os
import sys
import cv2
import argparse
from pathlib import Path
import torch
import numpy as np
from data_loader import load_dir
from facemodel import Face_3DMM
from util import *
from render_3dmm import Render_3DMM


# torch.autograd.set_detect_anomaly(True)

dir_path = os.path.dirname(os.path.realpath(__file__))


def set_requires_grad(tensor_list):
    for tensor in tensor_list:
        tensor.requires_grad = True


parser = argparse.ArgumentParser()
parser.add_argument(
    "--path", type=str, default="obama/ori_imgs", help="idname of target person"
)
parser.add_argument("--img_h", type=int, default=512, help="image height")
parser.add_argument("--img_w", type=int, default=512, help="image width")
parser.add_argument("--frame_num", type=int, default=11000, help="image number")
args = parser.parse_args()

start_id = 0
end_id = args.frame_num

lms, img_paths = load_dir(args.path, start_id, end_id)
num_frames = lms.shape[0]
h, w = args.img_h, args.img_w
cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda()
id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
model_3dmm = Face_3DMM(
    os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num
)

# use only one frame in every 40 to fit the focal length
sel_ids = np.arange(0, num_frames, 40)
sel_num = sel_ids.shape[0]
arg_focal = 1600
arg_landis = 1e5

print(f'[INFO] fitting focal length...')

# grid-search the focal length, keeping the value with the lowest landmark loss
for focal in range(600, 1500, 100):
    id_para = lms.new_zeros((1, id_dim), requires_grad=True)
    exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
    euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans.data[:, 2] -= 7
    focal_length = lms.new_zeros(1, requires_grad=False)
    focal_length.data += focal
    set_requires_grad([id_para, exp_para, euler_angle, trans])

    optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
    optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)

    # stage 1: pose only
    for iter in range(2000):
        id_para_batch = id_para.expand(sel_num, -1)
        geometry = model_3dmm.get_3dlandmarks(
            id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
        )
        proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
        loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
        loss = loss_lan
        optimizer_frame.zero_grad()
        loss.backward()
        optimizer_frame.step()
        # if iter % 100 == 0:
        #     print(focal, 'pose', iter, loss.item())

    # stage 2: pose plus identity/expression, with regularization
    for iter in range(2500):
        id_para_batch = id_para.expand(sel_num, -1)
        geometry = model_3dmm.get_3dlandmarks(
            id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
        )
        proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
        loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach())
        loss_regid = torch.mean(id_para * id_para)
        loss_regexp = torch.mean(exp_para * exp_para)
        loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
        optimizer_idexp.zero_grad()
        optimizer_frame.zero_grad()
        loss.backward()
        optimizer_idexp.step()
        optimizer_frame.step()
        # if iter % 100 == 0:
        #     print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())

        if iter % 1500 == 0 and iter >= 1500:
            for param_group in optimizer_idexp.param_groups:
                param_group["lr"] *= 0.2
            for param_group in optimizer_frame.param_groups:
                param_group["lr"] *= 0.2

    print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())

    if loss_lan.item() < arg_landis:
        arg_landis = loss_lan.item()
        arg_focal = focal

print("[INFO] found best focal:", arg_focal)

print(f'[INFO] coarse fitting...')

# coarse landmark-only fitting over all frames
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
tex_para = lms.new_zeros(
    (1, tex_dim), requires_grad=True
)  # not optimized in this block
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
trans.data[:, 2] -= 7  # initialize the head roughly 7 units in front of the camera
focal_length = lms.new_zeros(1, requires_grad=True)
focal_length.data += arg_focal

set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para])

optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1)

for iter in range(1500):
    id_para_batch = id_para.expand(num_frames, -1)
    geometry = model_3dmm.get_3dlandmarks(
        id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
    )
    proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
    loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
    loss = loss_lan
    optimizer_frame.zero_grad()
    loss.backward()
    optimizer_frame.step()
    if iter == 1000:
        for param_group in optimizer_frame.param_groups:
            param_group["lr"] = 0.1
    # if iter % 100 == 0:
    #     print('pose', iter, loss.item())

for param_group in optimizer_frame.param_groups:
    param_group["lr"] = 0.1

for iter in range(2000):
    id_para_batch = id_para.expand(num_frames, -1)
    geometry = model_3dmm.get_3dlandmarks(
        id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
    )
    proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
    loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach())
    loss_regid = torch.mean(id_para * id_para)
    loss_regexp = torch.mean(exp_para * exp_para)
    loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
    optimizer_idexp.zero_grad()
    optimizer_frame.zero_grad()
    loss.backward()
    optimizer_idexp.step()
    optimizer_frame.step()
    # if iter % 100 == 0:
    #     print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item())
    if iter % 1000 == 0 and iter >= 1000:
        for param_group in optimizer_idexp.param_groups:
            param_group["lr"] *= 0.2
        for param_group in optimizer_frame.param_groups:
            param_group["lr"] *= 0.2

print(loss_lan.item(), torch.mean(trans[:, 2]).item())

print(f'[INFO] fitting light...')

batch_size = 32

device_default = torch.device("cuda:0")
device_render = torch.device("cuda:0")
renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)

# photometric fitting of texture and spherical-harmonics lighting on a frame subset
sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]
imgs = []
for sel_id in sel_ids:
    imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])  # BGR -> RGB
imgs = np.stack(imgs)
sel_imgs = torch.as_tensor(imgs).cuda()
sel_lms = lms[sel_ids]
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
set_requires_grad([sel_light])

optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1)
optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01)

for iter in range(71):
    sel_exp_para, sel_euler, sel_trans = (
        exp_para[sel_ids],
        euler_angle[sel_ids],
        trans[sel_ids],
    )
    sel_id_para = id_para.expand(batch_size, -1)
    geometry = model_3dmm.get_3dlandmarks(
        sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
    )
    proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)

    loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
    loss_regid = torch.mean(id_para * id_para)
    loss_regexp = torch.mean(sel_exp_para * sel_exp_para)

    sel_tex_para = tex_para.expand(batch_size, -1)
    sel_texture = model_3dmm.forward_tex(sel_tex_para)
    geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
    rott_geo = forward_rott(geometry, sel_euler, sel_trans)
    render_imgs = renderer(
        rott_geo.to(device_render),
        sel_texture.to(device_render),
        sel_light.to(device_render),
    )
    render_imgs = render_imgs.to(device_default)

    mask = (render_imgs[:, :, :, 3]).detach() > 0.0
    render_proj = sel_imgs.clone()
    render_proj[mask] = render_imgs[mask][..., :3].byte()
    loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)

    if iter > 50:
        loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8
    else:
        loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0

    optimizer_tl.zero_grad()
    optimizer_id_frame.zero_grad()
    loss.backward()

    optimizer_tl.step()
    optimizer_id_frame.step()

    if iter % 50 == 0 and iter > 0:
        for param_group in optimizer_id_frame.param_groups:
            param_group["lr"] *= 0.2
        for param_group in optimizer_tl.param_groups:
            param_group["lr"] *= 0.2
    # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item())


light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1)
light_para.data = light_mean

exp_para = exp_para.detach()
euler_angle = euler_angle.detach()
trans = trans.detach()
light_para = light_para.detach()

print(f'[INFO] fine frame-wise fitting...')

for i in range(int((num_frames - 1) / batch_size + 1)):

    if (i + 1) * batch_size > num_frames:
        start_n = num_frames - batch_size
        sel_ids = np.arange(num_frames - batch_size, num_frames)
    else:
        start_n = i * batch_size
        sel_ids = np.arange(i * batch_size, i * batch_size + batch_size)

    imgs = []
    for sel_id in sel_ids:
        imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1])
    imgs = np.stack(imgs)
    sel_imgs = torch.as_tensor(imgs).cuda()
    sel_lms = lms[sel_ids]

    sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True)
    sel_exp_para.data = exp_para[sel_ids].clone()
    sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True)
    sel_euler.data = euler_angle[sel_ids].clone()
    sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
    sel_trans.data = trans[sel_ids].clone()
    sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)
    sel_light.data = light_para[sel_ids].clone()

    set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light])

    optimizer_cur_batch = torch.optim.Adam(
        [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005
    )

    sel_id_para = id_para.expand(batch_size, -1).detach()
    sel_tex_para = tex_para.expand(batch_size, -1).detach()

    pre_num = 5

    if i > 0:
        pre_ids = np.arange(start_n - pre_num, start_n)

    for iter in range(50):

        geometry = model_3dmm.get_3dlandmarks(
            sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
        )
        proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
        loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
        loss_regexp = torch.mean(sel_exp_para * sel_exp_para)

        sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
        sel_texture = model_3dmm.forward_tex(sel_tex_para)
        geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
        rott_geo = forward_rott(geometry, sel_euler, sel_trans)
        render_imgs = renderer(
            rott_geo.to(device_render),
            sel_texture.to(device_render),
            sel_light.to(device_render),
        )
        render_imgs = render_imgs.to(device_default)

        mask = (render_imgs[:, :, :, 3]).detach() > 0.0

        loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)

        # temporal Laplacian smoothing; overlap the previous pre_num frames when available
        if i > 0:
            geometry_lap = model_3dmm.forward_geo_sub(
                id_para.expand(batch_size + pre_num, -1).detach(),
                torch.cat((exp_para[pre_ids].detach(), sel_exp_para)),
                model_3dmm.rigid_ids,
            )
            rott_geo_lap = forward_rott(
                geometry_lap,
                torch.cat((euler_angle[pre_ids].detach(), sel_euler)),
                torch.cat((trans[pre_ids].detach(), sel_trans)),
            )
            loss_lap = cal_lap_loss(
                [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
            )
        else:
            geometry_lap = model_3dmm.forward_geo_sub(
                id_para.expand(batch_size, -1).detach(),
                sel_exp_para,
                model_3dmm.rigid_ids,
            )
            rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans)
            loss_lap = cal_lap_loss(
                [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
            )


        if iter > 30:
            loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0
        else:
            loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0

        optimizer_cur_batch.zero_grad()
        loss.backward()
        optimizer_cur_batch.step()

        # if iter % 10 == 0:
        #     print(
        #         i,
        #         iter,
        #         loss_col.item(),
        #         loss_lan.item(),
        #         loss_lap.item(),
        #         loss_regexp.item(),
        #     )

    print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done")

    render_proj = sel_imgs.clone()
    render_proj[mask] = render_imgs[mask][..., :3].byte()

    exp_para[sel_ids] = sel_exp_para.clone()
    euler_angle[sel_ids] = sel_euler.clone()
    trans[sel_ids] = sel_trans.clone()
    light_para[sel_ids] = sel_light.clone()

torch.save(
    {
        "id": id_para.detach().cpu(),
        "exp": exp_para.detach().cpu(),
        "euler": euler_angle.detach().cpu(),
        "trans": trans.detach().cpu(),
        "focal": focal_length.detach().cpu(),
    },
    os.path.join(os.path.dirname(args.path), "track_params.pt"),
)

print("params saved")
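# Consumer sketch: the fitted parameters can be reloaded downstream with
#   params = torch.load("obama/track_params.pt")   # path follows the default --path above
#   params["exp"].shape     # (num_frames, 79) expression coefficients
#   params["euler"].shape   # (num_frames, 3) per-frame head rotation
#   params["trans"].shape   # (num_frames, 3) per-frame translation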
@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import numpy as np
import os
from util import *


class Face_3DMM(nn.Module):
    def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num):
        super(Face_3DMM, self).__init__()
        # id_dim = 100
        # exp_dim = 79
        # tex_dim = 100
        self.point_num = point_num
        DMM_info = np.load(
            os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True
        ).item()
        base_id = DMM_info["b_shape"][:id_dim, :]
        mu_id = DMM_info["mu_shape"]
        base_exp = DMM_info["b_exp"][:exp_dim, :]
        mu_exp = DMM_info["mu_exp"]
        mu = mu_id + mu_exp
        mu = mu.reshape(-1, 3)
        # center the mean shape at the origin
        for i in range(3):
            mu[:, i] -= np.mean(mu[:, i])
        mu = mu.reshape(-1)
        self.base_id = torch.as_tensor(base_id).cuda() / 100000.0
        self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0
        self.mu = torch.as_tensor(mu).cuda() / 100000.0
        base_tex = DMM_info["b_tex"][:tex_dim, :]
        mu_tex = DMM_info["mu_tex"]
        self.base_tex = torch.as_tensor(base_tex).cuda()
        self.mu_tex = torch.as_tensor(mu_tex).cuda()
        sig_id = DMM_info["sig_shape"][:id_dim]
        sig_tex = DMM_info["sig_tex"][:tex_dim]
        sig_exp = DMM_info["sig_exp"][:exp_dim]
        self.sig_id = torch.as_tensor(sig_id).cuda()
        self.sig_tex = torch.as_tensor(sig_tex).cuda()
        self.sig_exp = torch.as_tensor(sig_exp).cuda()

        keys_info = np.load(
            os.path.join(modelpath, "keys_info.npy"), allow_pickle=True
        ).item()
        self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda()
        self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda()
        self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda()
        self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda()

    def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy):
        id_para = id_para * self.sig_id
        exp_para = exp_para * self.sig_exp
        batch_size = id_para.shape[0]
        num_per_contour = self.left_contours.shape[1]
        left_contours_flat = self.left_contours.reshape(-1)
        right_contours_flat = self.right_contours.reshape(-1)
        # gather the x/y/z components of the candidate contour vertices
        sel_index = torch.cat(
            (
                3 * left_contours_flat.unsqueeze(1),
                3 * left_contours_flat.unsqueeze(1) + 1,
                3 * left_contours_flat.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        left_geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        left_geometry = left_geometry.view(batch_size, -1, 3)
        proj_x = forward_transform(
            left_geometry, euler_angle, trans, focal_length, cxy
        )[:, :, 0]
        proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
        # the silhouette point of each contour line is the projected extremum
        arg_min = proj_x.argmin(dim=2)
        left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3)
        left_3dlands = left_geometry[
            torch.arange(batch_size * 8), arg_min.view(-1), :
        ].view(batch_size, 8, 3)

        sel_index = torch.cat(
            (
                3 * right_contours_flat.unsqueeze(1),
                3 * right_contours_flat.unsqueeze(1) + 1,
                3 * right_contours_flat.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        right_geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        right_geometry = right_geometry.view(batch_size, -1, 3)
        proj_x = forward_transform(
            right_geometry, euler_angle, trans, focal_length, cxy
        )[:, :, 0]
        proj_x = proj_x.reshape(batch_size, 8, num_per_contour)
        arg_max = proj_x.argmax(dim=2)
        right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3)
        right_3dlands = right_geometry[
            torch.arange(batch_size * 8), arg_max.view(-1), :
        ].view(batch_size, 8, 3)

        sel_index = torch.cat(
            (
                3 * self.keyinds.unsqueeze(1),
                3 * self.keyinds.unsqueeze(1) + 1,
                3 * self.keyinds.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        lands_3d = geometry.view(-1, self.keyinds.shape[0], 3)
        # replace the fixed jaw-contour keypoints with the view-dependent silhouette points
        lands_3d[:, :8, :] = left_3dlands
        lands_3d[:, 9:17, :] = right_3dlands
        return lands_3d

    def forward_geo_sub(self, id_para, exp_para, sub_index):
        id_para = id_para * self.sig_id
        exp_para = exp_para * self.sig_exp
        sel_index = torch.cat(
            (
                3 * sub_index.unsqueeze(1),
                3 * sub_index.unsqueeze(1) + 1,
                3 * sub_index.unsqueeze(1) + 2,
            ),
            dim=1,
        ).reshape(-1)
        geometry = (
            torch.mm(id_para, self.base_id[:, sel_index])
            + torch.mm(exp_para, self.base_exp[:, sel_index])
            + self.mu[sel_index]
        )
        return geometry.reshape(-1, sub_index.shape[0], 3)

    def forward_geo(self, id_para, exp_para):
        id_para = id_para * self.sig_id
        exp_para = exp_para * self.sig_exp
        geometry = (
            torch.mm(id_para, self.base_id)
            + torch.mm(exp_para, self.base_exp)
            + self.mu
        )
        return geometry.reshape(-1, self.point_num, 3)

    def forward_tex(self, tex_para):
        tex_para = tex_para * self.sig_tex
        texture = torch.mm(tex_para, self.base_tex) + self.mu_tex
        return texture.reshape(-1, self.point_num, 3)
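# Usage sketch (model files assumed under ./3DMM, dimensions as in the tracker above):
#   model = Face_3DMM("3DMM", id_dim=100, exp_dim=79, tex_dim=100, point_num=34650)
#   geo = model.forward_geo(torch.zeros(1, 100).cuda(), torch.zeros(1, 79).cuda())
#   geo.shape   # (1, 34650, 3): the centered mean face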
@ -0,0 +1,69 @@
"""This module contains functions for geometry transform and camera projection"""
import torch
import torch.nn as nn
import numpy as np


def euler2rot(euler_angle):
    batch_size = euler_angle.shape[0]
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)
    one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
    zero = torch.zeros(
        (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device
    )
    rot_x = torch.cat(
        (
            torch.cat((one, zero, zero), 1),
            torch.cat((zero, theta.cos(), theta.sin()), 1),
            torch.cat((zero, -theta.sin(), theta.cos()), 1),
        ),
        2,
    )
    rot_y = torch.cat(
        (
            torch.cat((phi.cos(), zero, -phi.sin()), 1),
            torch.cat((zero, one, zero), 1),
            torch.cat((phi.sin(), zero, phi.cos()), 1),
        ),
        2,
    )
    rot_z = torch.cat(
        (
            torch.cat((psi.cos(), -psi.sin(), zero), 1),
            torch.cat((psi.sin(), psi.cos(), zero), 1),
            torch.cat((zero, zero, one), 1),
        ),
        2,
    )
    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))


def rot_trans_geo(geometry, rot, trans):
    rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1)
    return rott_geo.permute(0, 2, 1)


def euler_trans_geo(geometry, euler, trans):
    rot = euler2rot(euler)
    return rot_trans_geo(geometry, rot, trans)


def proj_geo(rott_geo, camera_para):
    fx = camera_para[:, 0]
    fy = camera_para[:, 0]  # fx and fy deliberately share the same focal entry
    cx = camera_para[:, 1]
    cy = camera_para[:, 2]

    X = rott_geo[:, :, 0]
    Y = rott_geo[:, :, 1]
    Z = rott_geo[:, :, 2]

    fxX = fx[:, None] * X
    fyY = fy[:, None] * Y

    proj_x = -fxX / Z + cx[:, None]
    proj_y = fyY / Z + cy[:, None]

    return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)
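# Sanity-check sketch: rotation matrices from euler2rot should be orthonormal.
#   R = euler2rot(torch.tensor([[0.1, 0.2, 0.3]]))
#   torch.allclose(torch.bmm(R, R.transpose(1, 2)), torch.eye(3)[None], atol=1e-5)  # True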
@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import numpy as np
import os
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    blending,
)

from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.renderer.blending import (
    BlendParams,
    hard_rgb_blend,
    sigmoid_alpha_blend,
    softmax_rgb_blend,
)


class SoftSimpleShader(nn.Module):
    """
    Per pixel lighting - the lighting model is applied using the interpolated
    coordinates and normals for each pixel. The blending function returns the
    soft aggregated color using all the faces per pixel.

    To use the default values, simply initialize the shader with the desired
    device, e.g. SoftSimpleShader(device=torch.device("cuda:0")).
    """

    def __init__(
        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
    ):
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def to(self, device):
        # Manually move to device modules which are not subclasses of nn.Module
        self.cameras = self.cameras.to(device)
        self.materials = self.materials.to(device)
        self.lights = self.lights.to(device)
        return self

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:

        texels = meshes.sample_textures(fragments)
        blend_params = kwargs.get("blend_params", self.blend_params)

        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = "Cameras must be specified either at initialization \
                or in the forward pass of SoftPhongShader"
            raise ValueError(msg)
        znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
        zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
        images = softmax_rgb_blend(
            texels, fragments, blend_params, znear=znear, zfar=zfar
        )
        return images


class Render_3DMM(nn.Module):
    def __init__(
        self,
        focal=1015,
        img_h=500,
        img_w=500,
        batch_size=1,
        device=torch.device("cuda:0"),
    ):
        super(Render_3DMM, self).__init__()

        self.focal = focal
        self.img_h = img_h
        self.img_w = img_w
        self.device = device
        self.renderer = self.get_render(batch_size)

        dir_path = os.path.dirname(os.path.realpath(__file__))
        topo_info = np.load(
            os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
        ).item()
        self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
        self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)

    def compute_normal(self, geometry):
        vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
        vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
        vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
        nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
        tri_normal = nn.functional.normalize(nnorm, dim=2)
        v_norm = tri_normal[:, self.vert_tris, :].sum(2)
        vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
        return vert_normal

    def get_render(self, batch_size=1):
        half_s = self.img_w * 0.5
        R, T = look_at_view_transform(10, 0, 0)
        R = R.repeat(batch_size, 1, 1)
        T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)

        cameras = FoVPerspectiveCameras(
            device=self.device,
            R=R,
            T=T,
            znear=0.01,
            zfar=20,
            fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
        )
        lights = PointLights(
            device=self.device,
            location=[[0.0, 0.0, 1e5]],
            ambient_color=[[1, 1, 1]],
            specular_color=[[0.0, 0.0, 0.0]],
            diffuse_color=[[0.0, 0.0, 0.0]],
        )
        sigma = 1e-4
        raster_settings = RasterizationSettings(
            image_size=(self.img_h, self.img_w),
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
            faces_per_pixel=2,
            perspective_correct=False,
        )
        blend_params = blending.BlendParams(background_color=[0, 0, 0])
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
            shader=SoftSimpleShader(
                lights=lights, blend_params=blend_params, cameras=cameras
            ),
        )
        return renderer.to(self.device)

    @staticmethod
    def Illumination_layer(face_texture, norm, gamma):
        # evaluate second-order spherical-harmonics lighting per vertex
        n_b, num_vertex, _ = face_texture.size()
        n_v_full = n_b * num_vertex
        gamma = gamma.view(-1, 3, 9).clone()
        gamma[:, :, 0] += 0.8

        gamma = gamma.permute(0, 2, 1)

        a0 = np.pi
        a1 = 2 * np.pi / np.sqrt(3.0)
        a2 = 2 * np.pi / np.sqrt(8.0)
        c0 = 1 / np.sqrt(4 * np.pi)
        c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
        c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
        d0 = 0.5 / np.sqrt(3.0)

        Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
        norm = norm.view(-1, 3)
        nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
        arrH = []

        arrH.append(Y0)
        arrH.append(-a1 * c1 * ny)
        arrH.append(a1 * c1 * nz)
        arrH.append(-a1 * c1 * nx)
        arrH.append(a2 * c2 * nx * ny)
        arrH.append(-a2 * c2 * ny * nz)
        arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
        arrH.append(-a2 * c2 * nx * nz)
        arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))

        H = torch.stack(arrH, 1)
        Y = H.view(n_b, num_vertex, 9)
        lighting = Y.bmm(gamma)

        face_color = face_texture * lighting
        return face_color

    def forward(self, rott_geometry, texture, diffuse_sh):
        face_normal = self.compute_normal(rott_geometry)
        face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
        face_color = TexturesVertex(face_color)
        mesh = Meshes(
            rott_geometry,
            self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
            face_color,
        )
        rendered_img = self.renderer(mesh)
        rendered_img = torch.clamp(rendered_img, 0, 255)

        return rendered_img
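# Output-shape sketch: the renderer yields RGBA images, which is why the tracker
# above masks on channel 3.
#   renderer = Render_3DMM(1200, 512, 512, 32, torch.device("cuda:0"))
#   renderer(rott_geo, texture, light).shape   # (32, 512, 512, 4), values in [0, 255]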
@ -0,0 +1,192 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import render_util
|
||||
import geo_transform
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_tri_normal(geometry, tris):
|
||||
geometry = geometry.permute(0, 2, 1)
|
||||
tri_1 = tris[:, 0]
|
||||
tri_2 = tris[:, 1]
|
||||
tri_3 = tris[:, 2]
|
||||
|
||||
vert_1 = torch.index_select(geometry, 2, tri_1)
|
||||
vert_2 = torch.index_select(geometry, 2, tri_2)
|
||||
vert_3 = torch.index_select(geometry, 2, tri_3)
|
||||
|
||||
nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1)
|
||||
normal = nn.functional.normalize(nnorm).permute(0, 2, 1)
|
||||
return normal
|
||||
|
||||
|
||||
class Compute_normal_base(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, normal):
|
||||
(normal_b,) = render_util.normal_base_forward(normal)
|
||||
ctx.save_for_backward(normal)
|
||||
return normal_b
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_normal_b):
|
||||
(normal,) = ctx.saved_tensors
|
||||
(grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal)
|
||||
return grad_normal
|
||||
|
||||
|
||||
class Normal_Base(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Normal_Base, self).__init__()
|
||||
|
||||
def forward(self, normal):
|
||||
return Compute_normal_base.apply(normal)
|
||||
|
||||
|
||||
def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
|
||||
point_num = geometry.shape[1]
|
||||
rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans)
|
||||
proj_geo = geo_transform.proj_geo(rott_geo, cam)
|
||||
rot_tri_normal = compute_tri_normal(rott_geo, tris)
|
||||
rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris)
|
||||
is_visible = -torch.bmm(
|
||||
rot_vert_normal.reshape(-1, 1, 3),
|
||||
nn.functional.normalize(rott_geo.reshape(-1, 3, 1)),
|
||||
).reshape(-1, point_num)
|
||||
is_visible[is_visible < 0.01] = -1
|
||||
pixel_valid = torch.zeros(
|
||||
(ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=ori_img.device,
|
||||
)
|
||||
return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid
|
||||
|
||||
|
||||
class Render_Face(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
|
||||
):
|
||||
batch_size, h, w, _ = ori_img.shape
|
||||
ori_img = ori_img.view(batch_size, -1, 3)
|
||||
ori_size = torch.cat(
|
||||
(
|
||||
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
|
||||
* h,
|
||||
torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
|
||||
* w,
|
||||
),
|
||||
dim=1,
|
||||
).view(-1)
|
||||
tri_index, tri_coord, render, real = render_util.render_face_forward(
|
||||
proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord
|
||||
)
|
||||
return render, real
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_render, grad_real):
|
||||
(
|
||||
ori_img,
|
||||
ori_size,
|
||||
proj_geo,
|
||||
texture,
|
||||
nbl,
|
||||
tri_inds,
|
||||
tri_index,
|
||||
tri_coord,
|
||||
) = ctx.saved_tensors
|
||||
        grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
            grad_render,
            grad_real,
            ori_img,
            ori_size,
            proj_geo,
            texture,
            nbl,
            tri_inds,
            tri_index,
            tri_coord,
        )
        return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None


class Render_RGB(nn.Module):
    def __init__(self):
        super(Render_RGB, self).__init__()

    def forward(
        self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
    ):
        return Render_Face.apply(
            proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
        )


def cal_land(proj_geo, is_visible, lands_info, land_num):
    (land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
    proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[
        :, :2
    ].reshape(-1, land_num, 2)
    return proj_land


class Render_Land(nn.Module):
    def __init__(self):
        super(Render_Land, self).__init__()
        lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32)
        self.lands_info = torch.as_tensor(lands_info).cuda()
        tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64)
        self.tris = torch.as_tensor(tris).cuda() - 1
        vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64)
        self.vert_tris = torch.as_tensor(vert_tris).cuda()
        self.normal_baser = Normal_Base().cuda()
        self.renderer = Render_RGB().cuda()

    def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
        batch_size, h, w, _ = ori_img.shape
        ori_img = ori_img.view(batch_size, -1, 3)
        ori_size = torch.cat(
            (
                torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
                * h,
                torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)
                * w,
            ),
            dim=1,
        ).view(-1)
        rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
        )
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        nbl = torch.bmm(
            tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3)
        )
        texture = torch.ones_like(geometry) * 200
        (render,) = render_util.render_mesh(
            proj_geo, ori_img, ori_size, texture, nbl, self.tris
        )
        return render.view(batch_size, h, w, 3).byte()

    def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
        rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
        )
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))
        render, real = self.renderer(
            proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
        )
        proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])
        col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
            ori_img.shape[0], -1
        )
        col_dis = torch.mean(col_minus * pixel_valid) / (
            torch.mean(pixel_valid) + 0.00001
        )
        land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape(
            ori_img.shape[0], -1
        )
        lan_dis = torch.mean(land_dists)
        return col_dis, lan_dis
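A hedged sketch of how `Render_Land.cal_loss_rgb` could drive one fitting step (the optimizer choice, the 0.1 landmark weight, and the input tensors are assumptions for illustration, not code from this commit; running it needs a CUDA device, the `render_util` extension, and the `../data/3DMM/*.txt` files):

```python
# Hypothetical fitting step; geometry/euler/trans/cam/ori_img/light/texture/lands
# are assumed to be pre-built tensors with the shapes cal_loss_rgb expects.
renderer = Render_Land()
optimizer = torch.optim.Adam([euler, trans, texture, light], lr=1e-2)

col_dis, lan_dis = renderer.cal_loss_rgb(
    geometry, euler, trans, cam, ori_img, light, texture, lands
)
loss = col_dis + 0.1 * lan_dis  # the landmark weight is illustrative
optimizer.zero_grad()
loss.backward()
optimizer.step()
```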
@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_tri_normal(geometry, tris):
    tri_1 = tris[:, 0]
    tri_2 = tris[:, 1]
    tri_3 = tris[:, 2]
    vert_1 = torch.index_select(geometry, 1, tri_1)
    vert_2 = torch.index_select(geometry, 1, tri_2)
    vert_3 = torch.index_select(geometry, 1, tri_3)
    nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2)
    normal = nn.functional.normalize(nnorm)
    return normal


def euler2rot(euler_angle):
    batch_size = euler_angle.shape[0]
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)
    one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
    zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
    rot_x = torch.cat(
        (
            torch.cat((one, zero, zero), 1),
            torch.cat((zero, theta.cos(), theta.sin()), 1),
            torch.cat((zero, -theta.sin(), theta.cos()), 1),
        ),
        2,
    )
    rot_y = torch.cat(
        (
            torch.cat((phi.cos(), zero, -phi.sin()), 1),
            torch.cat((zero, one, zero), 1),
            torch.cat((phi.sin(), zero, phi.cos()), 1),
        ),
        2,
    )
    rot_z = torch.cat(
        (
            torch.cat((psi.cos(), -psi.sin(), zero), 1),
            torch.cat((psi.sin(), psi.cos(), zero), 1),
            torch.cat((zero, zero, one), 1),
        ),
        2,
    )
    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))


def rot_trans_pts(geometry, rot, trans):
    rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None]
    return rott_geo.permute(0, 2, 1)


def cal_lap_loss(tensor_list, weight_list):
    lap_kernel = (
        torch.Tensor((-0.5, 1.0, -0.5))
        .unsqueeze(0)
        .unsqueeze(0)
        .float()
        .to(tensor_list[0].device)
    )
    loss_lap = 0
    for i in range(len(tensor_list)):
        in_tensor = tensor_list[i]
        in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1])
        out_tensor = F.conv1d(in_tensor, lap_kernel)
        loss_lap += torch.mean(out_tensor ** 2) * weight_list[i]
    return loss_lap


def proj_pts(rott_geo, focal_length, cxy):
    cx, cy = cxy[0], cxy[1]
    X = rott_geo[:, :, 0]
    Y = rott_geo[:, :, 1]
    Z = rott_geo[:, :, 2]
    fxX = focal_length * X
    fyY = focal_length * Y
    proj_x = -fxX / Z + cx
    proj_y = fyY / Z + cy
    return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2)


def forward_rott(geometry, euler_angle, trans):
    rot = euler2rot(euler_angle)
    rott_geo = rot_trans_pts(geometry, rot, trans)
    return rott_geo


def forward_transform(geometry, euler_angle, trans, focal_length, cxy):
    rot = euler2rot(euler_angle)
    rott_geo = rot_trans_pts(geometry, rot, trans)
    proj_geo = proj_pts(rott_geo, focal_length, cxy)
    return proj_geo


def cal_lan_loss(proj_lan, gt_lan):
    return torch.mean((proj_lan - gt_lan) ** 2)


def cal_col_loss(pred_img, gt_img, img_mask):
    pred_img = pred_img.float()
    # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255
    loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255
    loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2))
    loss = torch.mean(loss)
    return loss
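A small sanity-check sketch for the transform utilities above (the vertex positions, focal length, and principal point are illustrative values, not repo data):

```python
# Two vertices, identity rotation (zero Euler angles), camera 10 units along -Z.
geometry = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])  # [B=1, N=2, 3]
euler = torch.zeros(1, 3)
trans = torch.tensor([[0.0, 0.0, -10.0]])
focal = torch.tensor([1200.0])

proj = forward_transform(geometry, euler, trans, focal, cxy=(256.0, 256.0))
# proj[..., 0] is proj_x, proj[..., 1] is proj_y, proj[..., 2] is the depth Z
print(proj.shape)  # torch.Size([1, 2, 3])
```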
@ -0,0 +1,402 @@
import os
import glob
import tqdm
import json
import argparse
import cv2
import numpy as np

def extract_audio(path, out_path, sample_rate=16000):

    print(f'[INFO] ===== extract audio from {path} to {out_path} =====')
    cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}'
    os.system(cmd)
    print(f'[INFO] ===== extracted audio =====')


def extract_audio_features(path, mode='wav2vec'):

    print(f'[INFO] ===== extract audio labels for {path} =====')
    if mode == 'wav2vec':
        cmd = f'python nerf/asr.py --wav {path} --save_feats'
    else: # deepspeech
        cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}'
    os.system(cmd)
    print(f'[INFO] ===== extracted audio labels =====')


def extract_images(path, out_path, fps=25):

    print(f'[INFO] ===== extract images from {path} to {out_path} =====')
    cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}'
    os.system(cmd)
    print(f'[INFO] ===== extracted images =====')


def extract_semantics(ori_imgs_dir, parsing_dir):

    print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====')
    cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}'
    os.system(cmd)
    print(f'[INFO] ===== extracted semantics =====')


def extract_landmarks(ori_imgs_dir):

    print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')

    import face_alignment
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
    for image_path in tqdm.tqdm(image_paths):
        input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
        input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
        preds = fa.get_landmarks(input)
        # get_landmarks returns None when no face is detected in the frame
        if preds is not None and len(preds) > 0:
            lands = preds[0].reshape(-1, 2)[:, :2]
            np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f')
    del fa
    print(f'[INFO] ===== extracted face landmarks =====')


def extract_background(base_dir, ori_imgs_dir):

    print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====')

    from sklearn.neighbors import NearestNeighbors

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
    # only use 1/20 image_paths
    image_paths = image_paths[::20]
    # read one image to get H/W
    tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
    h, w = tmp_image.shape[:2]

    # nearest neighbors
    all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
    distss = []
    for image_path in tqdm.tqdm(image_paths):
        parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
        bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255)
        fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
        dists, _ = nbrs.kneighbors(all_xys)
        distss.append(dists)

    distss = np.stack(distss)
    max_dist = np.max(distss, 0)
    max_id = np.argmax(distss, 0)

    bc_pixs = max_dist > 5
    bc_pixs_id = np.nonzero(bc_pixs)
    bc_ids = max_id[bc_pixs]

    imgs = []
    num_pixs = distss.shape[1]
    for image_path in image_paths:
        img = cv2.imread(image_path)
        imgs.append(img)
    imgs = np.stack(imgs).reshape(-1, num_pixs, 3)

    bc_img = np.zeros((h*w, 3), dtype=np.uint8)
    bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
    bc_img = bc_img.reshape(h, w, 3)

    max_dist = max_dist.reshape(h, w)
    bc_pixs = max_dist > 5
    bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
    fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
    distances, indices = nbrs.kneighbors(bg_xys)
    bg_fg_xys = fg_xys[indices[:, 0]]
    bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]

    cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img)

    print(f'[INFO] ===== extracted background image =====')


def extract_torso_and_gt(base_dir, ori_imgs_dir):

    print(f'[INFO] ===== extract torso and gt images for {base_dir} =====')

    from scipy.ndimage import binary_erosion, binary_dilation

    # load bg
    bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED)

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))

    for image_path in tqdm.tqdm(image_paths):
        # read ori image
        ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]

        # read semantics
        seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png'))
        head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0)
        neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0)
        torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255)
        bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255)

        # get gt image
        gt_image = ori_image.copy()
        gt_image[bg_part] = bg_image[bg_part]
        cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)

        # get torso image
        torso_image = gt_image.copy() # rgb
        torso_image[head_part] = bg_image[head_part]
        torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha

        # torso part "vertical" in-painting...
        L = 8 + 1
        torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
        # lexsort: sort 2D coords first by y then by x,
        # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
        inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
        torso_coords = torso_coords[inds]
        # choose the top pixel for each column
        u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
        top_torso_coords = torso_coords[uid] # [m, 2]
        # only keep top-is-head pixels
        top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0])
        mask = head_part[tuple(top_torso_coords_up.T)]
        if mask.any():
            top_torso_coords = top_torso_coords[mask]
            # get the color
            top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
            # construct inpaint coords (vertically up, or minus in x)
            inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
            inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
            inpaint_torso_coords += inpaint_offsets
            inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
            inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
            darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
            inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
            # set color
            torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors

            inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
            inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
        else:
            inpaint_torso_mask = None


        # neck part "vertical" in-painting...
        push_down = 4
        L = 48 + push_down + 1

        neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)

        neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
        # lexsort: sort 2D coords first by y then by x,
        # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
        inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
        neck_coords = neck_coords[inds]
        # choose the top pixel for each column
        u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
        top_neck_coords = neck_coords[uid] # [m, 2]
        # only keep top-is-head pixels
        top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
        mask = head_part[tuple(top_neck_coords_up.T)]

        top_neck_coords = top_neck_coords[mask]
        # push these top down for 4 pixels to make the neck inpainting more natural...
        offset_down = np.minimum(ucnt[mask] - 1, push_down)
        top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
        # get the color
        top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
        # construct inpaint coords (vertically up, or minus in x)
        inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
        inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
        inpaint_neck_coords += inpaint_offsets
        inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
        inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
        darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
        inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
        # set color
        torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors

        # apply blurring to the inpaint area to avoid vertical-line artifacts...
        inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
        inpaint_mask[tuple(inpaint_neck_coords.T)] = True

        blur_img = torso_image.copy()
        blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)

        torso_image[inpaint_mask] = blur_img[inpaint_mask]

        # set mask
        mask = (neck_part | torso_part | inpaint_mask)
        if inpaint_torso_mask is not None:
            mask = mask | inpaint_torso_mask
        torso_image[~mask] = 0
        torso_alpha[~mask] = 0

        cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1))

    print(f'[INFO] ===== extracted torso and gt images =====')


def face_tracking(ori_imgs_dir):

    print(f'[INFO] ===== perform face tracking =====')

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))

    # read one image to get H/W
    tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
    h, w = tmp_image.shape[:2]

    cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}'

    os.system(cmd)

    print(f'[INFO] ===== finished face tracking =====')


def save_transforms(base_dir, ori_imgs_dir):
    print(f'[INFO] ===== save transforms =====')

    import torch

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))

    # read one image to get H/W
    tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
    h, w = tmp_image.shape[:2]

    params_dict = torch.load(os.path.join(base_dir, 'track_params.pt'))
    focal_len = params_dict['focal']
    euler_angle = params_dict['euler']
    trans = params_dict['trans'] / 10.0
    valid_num = euler_angle.shape[0]

    def euler2rot(euler_angle):
        batch_size = euler_angle.shape[0]
        theta = euler_angle[:, 0].reshape(-1, 1, 1)
        phi = euler_angle[:, 1].reshape(-1, 1, 1)
        psi = euler_angle[:, 2].reshape(-1, 1, 1)
        one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
        zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
        rot_x = torch.cat((
            torch.cat((one, zero, zero), 1),
            torch.cat((zero, theta.cos(), theta.sin()), 1),
            torch.cat((zero, -theta.sin(), theta.cos()), 1),
        ), 2)
        rot_y = torch.cat((
            torch.cat((phi.cos(), zero, -phi.sin()), 1),
            torch.cat((zero, one, zero), 1),
            torch.cat((phi.sin(), zero, phi.cos()), 1),
        ), 2)
        rot_z = torch.cat((
            torch.cat((psi.cos(), -psi.sin(), zero), 1),
            torch.cat((psi.sin(), psi.cos(), zero), 1),
            torch.cat((zero, zero, one), 1)
        ), 2)
        return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))


    # train_val_split = int(valid_num*0.5)
    # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set.
    train_val_split = int(valid_num * 10 / 11)

    train_ids = torch.arange(0, train_val_split)
    val_ids = torch.arange(train_val_split, valid_num)

    rot = euler2rot(euler_angle)
    rot_inv = rot.permute(0, 2, 1)
    trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))

    pose = torch.eye(4, dtype=torch.float32)
    save_ids = ['train', 'val']
    train_val_ids = [train_ids, val_ids]
    mean_z = -float(torch.mean(trans[:, 2]).item())

    for split in range(2):
        transform_dict = dict()
        transform_dict['focal_len'] = float(focal_len[0])
        transform_dict['cx'] = float(w / 2.0)
        transform_dict['cy'] = float(h / 2.0)
        transform_dict['frames'] = []
        ids = train_val_ids[split]
        save_id = save_ids[split]

        for i in ids:
            i = i.item()
            frame_dict = dict()
            frame_dict['img_id'] = i
            frame_dict['aud_id'] = i

            pose[:3, :3] = rot_inv[i]
            pose[:3, 3] = trans_inv[i, :, 0]

            frame_dict['transform_matrix'] = pose.numpy().tolist()

            transform_dict['frames'].append(frame_dict)

        with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
            json.dump(transform_dict, fp, indent=2, separators=(',', ': '))

    print(f'[INFO] ===== finished saving transforms =====')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str, help="path to video file")
    parser.add_argument('--task', type=int, default=-1, help="-1 means all")
    parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech")

    opt = parser.parse_args()

    base_dir = os.path.dirname(opt.path)

    wav_path = os.path.join(base_dir, 'aud.wav')
    ori_imgs_dir = os.path.join(base_dir, 'ori_imgs')
    parsing_dir = os.path.join(base_dir, 'parsing')
    gt_imgs_dir = os.path.join(base_dir, 'gt_imgs')
    torso_imgs_dir = os.path.join(base_dir, 'torso_imgs')

    os.makedirs(ori_imgs_dir, exist_ok=True)
    os.makedirs(parsing_dir, exist_ok=True)
    os.makedirs(gt_imgs_dir, exist_ok=True)
    os.makedirs(torso_imgs_dir, exist_ok=True)


    # extract audio
    if opt.task == -1 or opt.task == 1:
        extract_audio(opt.path, wav_path)

    # extract audio features
    if opt.task == -1 or opt.task == 2:
        extract_audio_features(wav_path, mode=opt.asr)

    # extract images
    if opt.task == -1 or opt.task == 3:
        extract_images(opt.path, ori_imgs_dir)

    # face parsing
    if opt.task == -1 or opt.task == 4:
        extract_semantics(ori_imgs_dir, parsing_dir)

    # extract bg
    if opt.task == -1 or opt.task == 5:
        extract_background(base_dir, ori_imgs_dir)

    # extract torso images and gt_images
    if opt.task == -1 or opt.task == 6:
        extract_torso_and_gt(base_dir, ori_imgs_dir)

    # extract face landmarks
    if opt.task == -1 or opt.task == 7:
        extract_landmarks(ori_imgs_dir)

    # face tracking
    if opt.task == -1 or opt.task == 8:
        face_tracking(ori_imgs_dir)

    # save transforms.json
    if opt.task == -1 or opt.task == 9:
        save_transforms(base_dir, ori_imgs_dir)
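A hedged invocation sketch for the pipeline above (the on-disk path `data_utils/process.py` is an assumption inferred from the helper paths this script shells out to; the video path is illustrative):

```python
# Hypothetical driver: run all stages once, then re-run only the transforms export.
import subprocess

video = 'data/obama/obama.mp4'  # illustrative input path
subprocess.run(['python', 'data_utils/process.py', video], check=True)               # --task -1: all stages
subprocess.run(['python', 'data_utils/process.py', video, '--task', '9'], check=True)  # stage 9 only
```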
@ -0,0 +1,38 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
                **kwargs):

    if encoding == 'None':
        return lambda x, **kwargs: x, input_dim

    elif encoding == 'frequency':
        from freqencoder import FreqEncoder
        encoder = FreqEncoder(input_dim=input_dim, degree=multires)

    elif encoding == 'spherical_harmonics':
        from shencoder import SHEncoder
        encoder = SHEncoder(input_dim=input_dim, degree=degree)

    elif encoding == 'hashgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)

    elif encoding == 'tiledgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)

    elif encoding == 'ash':
        from ashencoder import AshEncoder
        encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)

    else:
        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid, ash]')

    return encoder, encoder.output_dim
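A minimal usage sketch for `get_encoder` (assumes the gridencoder CUDA extension is already built, e.g. via scripts/install_ext.sh; the batch size is illustrative):

```python
import torch

encoder, out_dim = get_encoder('hashgrid', input_dim=3, desired_resolution=2048)
encoder = encoder.cuda()

xyz = torch.rand(4096, 3, device='cuda') * 2 - 1  # positions in [-1, 1]
feats = encoder(xyz, bound=1)
print(out_dim, feats.shape)  # 32, torch.Size([4096, 32]) with the defaults above
```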
@ -0,0 +1 @@
from .freq import FreqEncoder
@ -0,0 +1,41 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_freqencoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'freqencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
@ -0,0 +1,77 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _freqencoder as _backend
except ImportError:
    from .backend import _backend


class _freq_encoder(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        # inputs: [B, input_dim], float
        # RETURN: [B, F], float

        if not inputs.is_cuda: inputs = inputs.cuda()
        inputs = inputs.contiguous()

        B, input_dim = inputs.shape # batch size, coord dim

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)

        ctx.save_for_backward(inputs, outputs)
        ctx.dims = [B, input_dim, degree, output_dim]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, F], matching the forward output

        grad = grad.contiguous()
        inputs, outputs = ctx.saved_tensors
        B, input_dim, degree, output_dim = ctx.dims

        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)

        return grad_inputs, None, None


freq_encode = _freq_encoder.apply


class FreqEncoder(nn.Module):
    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        self.output_dim = input_dim + input_dim * 2 * degree

    def __repr__(self):
        return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"

    def forward(self, inputs, **kwargs):
        # inputs: [..., input_dim]
        # return: [..., output_dim]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = freq_encode(inputs, self.degree, self.output_dim)

        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs
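For reference, a hedged pure-PyTorch equivalent of the output layout the CUDA kernel produces (identity features first, then a sin/cos pair per frequency octave; `freq_encode_reference` is a name introduced here for illustration, not part of the extension):

```python
import torch

def freq_encode_reference(x: torch.Tensor, degree: int) -> torch.Tensor:
    # x: [B, D] -> [B, D + D * 2 * degree], matching kernel_freq's layout:
    # [x, sin(2^0 x), cos(2^0 x), sin(2^1 x), cos(2^1 x), ...]
    outs = [x]
    for f in range(degree):
        outs.append(torch.sin((2 ** f) * x))
        outs.append(torch.cos((2 ** f) * x))
    return torch.cat(outs, dim=-1)
```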
@ -0,0 +1,51 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='freqencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_freqencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'freqencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "freqencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
    m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
@ -0,0 +1,129 @@
#include <stdint.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")

inline constexpr __device__ float PI() { return 3.141592653589793f; }

template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
__global__ void kernel_freq(
    const float * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * outputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * C) return;

    // get index
    const uint32_t b = t / C;
    const uint32_t c = t - b * C; // t % C;

    // locate
    inputs += b * D;
    outputs += t;

    // write self
    if (c < D) {
        outputs[0] = inputs[c];
    // write freq
    } else {
        const uint32_t col = c / D - 1;
        const uint32_t d = c % D;
        const uint32_t freq = col / 2;
        const float phase_shift = (col % 2) * (PI() / 2);
        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
    }
}

// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]
// grad_inputs: [B, D]
__global__ void kernel_freq_backward(
    const float * __restrict__ grad,
    const float * __restrict__ outputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * grad_inputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D; // t % D;

    // locate
    grad += b * C;
    outputs += b * C;
    grad_inputs += t;

    // register
    float result = grad[d];
    grad += D;
    outputs += D;

    for (uint32_t f = 0; f < deg; f++) {
        result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
        grad += 2 * D;
        outputs += 2 * D;
    }

    // write
    grad_inputs[0] = result;
}


void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);

    static constexpr uint32_t N_THREADS = 128;

    kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}


void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(outputs);
    CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(grad_inputs);

    static constexpr uint32_t N_THREADS = 128;

    kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}
@ -0,0 +1,10 @@
# pragma once

#include <stdint.h>
#include <torch/torch.h>

// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);

// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
@ -0,0 +1 @@
from .grid import GridEncoder
@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14', '-finput-charset=UTF-8']
elif os.name == "nt":
    # /source-charset is MSVC's counterpart to GCC's -finput-charset
    c_flags = ['/O2', '/std:c++17', '/source-charset:utf-8']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_grid_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'gridencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
@ -0,0 +1,155 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}

class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.float().contiguous()

        B, D = inputs.shape # batch size, coord dim
        L = offsets.shape[0] - 1 # level
        C = embeddings.shape[1] # embedding dim for each level
        S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
        else:
            dy_dx = None

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if dy_dx is not None:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = None

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)

        if dy_dx is not None:
            grad_inputs = grad_inputs.to(inputs.dtype)

        return grad_inputs, grad_embeddings, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
        super().__init__()

        # the finest resolution desired at the last level, if provided, overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim # coord dims, 2 or 3
        self.num_levels = num_levels # num levels, each level multiply resolution by 2
        self.level_dim = level_dim # encode channels per level
        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
            offsets.append(offset)
            offset += params_in_level
            # print(resolution, params_in_level)
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound) # map to [0, 1]

        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs
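A minimal usage sketch for `GridEncoder` (assumes the _gridencoder CUDA extension is built; the batch size is illustrative):

```python
import torch

enc = GridEncoder(input_dim=3, num_levels=16, level_dim=2,
                  base_resolution=16, desired_resolution=2048).cuda()
xyz = torch.rand(8192, 3, device='cuda') * 2 - 1  # positions in [-bound, bound]
feats = enc(xyz, bound=1)                         # [8192, 16 * 2] == [8192, 32]
```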
@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='gridencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_gridencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'gridencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "gridencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}
@ -0,0 +1,479 @@
|
|||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
||||
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
||||
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
||||
|
||||
|
||||
// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
|
||||
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
|
||||
// requires CUDA >= 10 and ARCH >= 70
|
||||
// this is very slow compared to float or __half2, and never used.
|
||||
//return atomicAdd(reinterpret_cast<__half*>(address), val);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static inline __host__ __device__ T div_round_up(T val, T divisor) {
|
||||
return (val + divisor - 1) / divisor;
|
||||
}
|
||||
|
||||
|
||||
template <uint32_t D>
|
||||
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
|
||||
static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
|
||||
|
||||
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
|
||||
// and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
|
||||
// coordinates.
|
||||
constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
|
||||
|
||||
uint32_t result = 0;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < D; ++i) {
|
||||
result ^= pos_grid[i] * primes[i];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
template <uint32_t D, uint32_t C>
|
||||
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
|
||||
uint32_t stride = 1;
|
||||
uint32_t index = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
|
||||
index += pos_grid[d] * stride;
|
||||
stride *= align_corners ? resolution: (resolution + 1);
|
||||
}
|
||||
|
||||
// NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
|
||||
// gridtype: 0 == hash, 1 == tiled
|
||||
if (gridtype == 0 && stride > hashmap_size) {
|
||||
index = fast_hash<D>(pos_grid);
|
||||
}
|
||||
|
||||
return (index % hashmap_size) * C + ch;
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C>
|
||||
__global__ void kernel_grid(
|
||||
const float * __restrict__ inputs,
|
||||
const scalar_t * __restrict__ grid,
|
||||
const int * __restrict__ offsets,
|
||||
scalar_t * __restrict__ outputs,
|
||||
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
||||
scalar_t * __restrict__ dy_dx,
|
||||
const uint32_t gridtype,
|
||||
const bool align_corners
|
||||
) {
|
||||
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (b >= B) return;
|
||||
|
||||
const uint32_t level = blockIdx.y;
|
||||
|
||||
// locate
|
||||
grid += (uint32_t)offsets[level] * C;
|
||||
inputs += b * D;
|
||||
outputs += level * B * C + b * C;
|
||||
|
||||
// check input range (should be in [0, 1])
|
||||
bool flag_oob = false;
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if (inputs[d] < 0 || inputs[d] > 1) {
|
||||
flag_oob = true;
|
||||
}
|
||||
}
|
||||
// if input out of bound, just set output to 0
|
||||
if (flag_oob) {
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
outputs[ch] = 0;
|
||||
}
|
||||
if (dy_dx) {
|
||||
dy_dx += b * D * L * C + level * D * C; // B L D C
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
dy_dx[d * C + ch] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
||||
const float scale = exp2f(level * S) * H - 1.0f;
|
||||
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
||||
|
||||
// calculate coordinate
|
||||
float pos[D];
|
||||
uint32_t pos_grid[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
||||
pos_grid[d] = floorf(pos[d]);
|
||||
pos[d] -= (float)pos_grid[d];
|
||||
}
|
||||
|
||||
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
||||
|
||||
// interpolate
|
||||
scalar_t results[C] = {0}; // temp results in register
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
||||
float w = 1;
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if ((idx & (1 << d)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = pos_grid[d] + 1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
// writing to register (fast)
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
results[ch] += w * grid[index + ch];
|
||||
}
|
||||
|
||||
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
|
||||
}
|
||||
|
||||
// writing to global memory (slow)
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
outputs[ch] = results[ch];
|
||||
}
|
||||
|
||||
// prepare dy_dx
|
||||
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
|
||||
if (dy_dx) {
|
||||
|
||||
dy_dx += b * D * L * C + level * D * C; // B L D C
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t gd = 0; gd < D; gd++) {
|
||||
|
||||
scalar_t results_grad[C] = {0};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
|
||||
float w = scale;
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t nd = 0; nd < D - 1; nd++) {
|
||||
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
|
||||
|
||||
if ((idx & (1 << nd)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = pos_grid[d] + 1;
|
||||
}
|
||||
}
|
||||
|
||||
pos_grid_local[gd] = pos_grid[gd];
|
||||
uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
||||
pos_grid_local[gd] = pos_grid[gd] + 1;
|
||||
uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
dy_dx[gd * C + ch] = results_grad[ch];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
|
||||
__global__ void kernel_grid_backward(
|
||||
const scalar_t * __restrict__ grad,
|
||||
const float * __restrict__ inputs,
|
||||
const scalar_t * __restrict__ grid,
|
||||
const int * __restrict__ offsets,
|
||||
scalar_t * __restrict__ grad_grid,
|
||||
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
||||
const uint32_t gridtype,
|
||||
const bool align_corners
|
||||
) {
|
||||
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
|
||||
if (b >= B) return;
|
||||
|
||||
const uint32_t level = blockIdx.y;
|
||||
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
|
||||
|
||||
// locate
|
||||
grad_grid += offsets[level] * C;
|
||||
inputs += b * D;
|
||||
grad += level * B * C + b * C + ch; // L, B, C
|
||||
|
||||
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
||||
const float scale = exp2f(level * S) * H - 1.0f;
|
||||
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
||||
|
||||
// check input range (should be in [0, 1])
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if (inputs[d] < 0 || inputs[d] > 1) {
|
||||
return; // grad is init as 0, so we simply return.
|
||||
}
|
||||
}
|
||||
|
||||
// calculate coordinate
|
||||
float pos[D];
|
||||
uint32_t pos_grid[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
||||
pos_grid[d] = floorf(pos[d]);
|
||||
pos[d] -= (float)pos_grid[d];
|
||||
}
|
||||
|
||||
scalar_t grad_cur[N_C] = {0}; // fetch to register
|
||||
#pragma unroll
|
||||
for (uint32_t c = 0; c < N_C; c++) {
|
||||
grad_cur[c] = grad[c];
|
||||
}
|
||||
|
||||
// interpolate
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
||||
float w = 1;
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if ((idx & (1 << d)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = pos_grid[d] + 1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
|
||||
// TODO: use float which is better than __half, if N_C % 2 != 0
|
||||
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t c = 0; c < N_C; c += 2) {
|
||||
// process two __half at once (by interpreting as a __half2)
|
||||
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
|
||||
atomicAdd((__half2*)&grad_grid[index + c], v);
|
||||
}
|
||||
// float, or __half when N_C % 2 != 0 (which means C == 1)
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t c = 0; c < N_C; c++) {
|
||||
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C>
|
||||
__global__ void kernel_input_backward(
|
||||
const scalar_t * __restrict__ grad,
|
||||
const scalar_t * __restrict__ dy_dx,
|
||||
scalar_t * __restrict__ grad_inputs,
|
||||
uint32_t B, uint32_t L
|
||||
) {
|
||||
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (t >= B * D) return;
|
||||
|
||||
const uint32_t b = t / D;
|
||||
const uint32_t d = t - b * D;
|
||||
|
||||
dy_dx += b * L * D * C;
|
||||
|
||||
scalar_t result = 0;
|
||||
|
||||
# pragma unroll
|
||||
for (int l = 0; l < L; l++) {
|
||||
# pragma unroll
|
||||
for (int ch = 0; ch < C; ch++) {
|
||||
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
|
||||
}
|
||||
}
|
||||
|
||||
grad_inputs[t] = result;
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D>
|
||||
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
|
||||
static constexpr uint32_t N_THREAD = 512;
|
||||
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
|
||||
switch (C) {
|
||||
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
||||
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
||||
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
||||
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
||||
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
||||
}
|
||||
}
|
||||
|
||||
// inputs: [B, D], float, in [0, 1]
|
||||
// embeddings: [sO, C], float
|
||||
// offsets: [L + 1], uint32_t
|
||||
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
|
||||
// H: base resolution
|
||||
// dy_dx: [B, L * D * C]
|
||||
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"};
    }
}

template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
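    // Each backward thread handles N_C = min(2, C) features, so for __half embeddings with
    // C >= 2, kernel_grid_backward can take the __half2 fast path above (two __half values
    // per atomicAdd); float gradients and C == 1 use plain per-value atomics.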
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
    switch (C) {
        case 1:
            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 2:
            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 4:
            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 8:
            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}


// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"};
    }
}


void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        embeddings.scalar_type(), "grid_encode_forward", ([&] {
            grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
        }));
}

void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    // CHECK_CUDA(dy_dx);
    // CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    // CHECK_CONTIGUOUS(dy_dx);
    // CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);
    // CHECK_IS_FLOATING(dy_dx);
    // CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        grad.scalar_type(), "grid_encode_backward", ([&] {
            grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
        }));
}

@ -0,0 +1,15 @@

#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H

#include <stdint.h>
#include <torch/torch.h>

// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [B, L * C], float
// H: base resolution
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners);
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners);

#endif

@ -0,0 +1,260 @@

import torch
import argparse

from nerf_triplane.provider import NeRFDataset
from nerf_triplane.utils import *
from nerf_triplane.network import NeRFNetwork

# torch.autograd.set_detect_anomaly(True)
# Disable TF32 to fix low numerical accuracy on RTX 30xx GPUs.
try:
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
except AttributeError as e:
    print('[INFO] This PyTorch version does not support tf32.')

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
    parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)")
    parser.add_argument('--test_train', action='store_true', help="test mode (load model and train dataset)")
    parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--seed', type=int, default=0)

    ### training options
    parser.add_argument('--iters', type=int, default=200000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
    parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

    ### loss set
    parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
    parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
    parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
    parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
    parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")

    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")

    parser.add_argument('--bg_img', type=str, default='', help="background image")
    parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
    parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
    parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
    parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")

    parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

    parser.add_argument('--init_lips', action='store_true', help="init lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
    parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in an exponential decay way...")

    parser.add_argument('--torso', action='store_true', help="fix head and train torso")
    parser.add_argument('--head_ckpt', type=str, default='', help="head model")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")
    parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    ### else
    parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
    parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
    parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")

    parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
    parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")

    parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")

    parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
    parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
    parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")

    parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
    parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

    # asr
    parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
    parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
    parser.add_argument('--asr_play', action='store_true', help="play out the audio")

    parser.add_argument('--asr_model', type=str, default='deepspeech')
    # parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
    # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')

    parser.add_argument('--asr_save_feats', action='store_true')
    # audio FPS
    parser.add_argument('--fps', type=int, default=50)
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=50)
    parser.add_argument('-r', type=int, default=10)

    opt = parser.parse_args()

    if opt.O:
        opt.fp16 = True
        opt.exp_eye = True

    if opt.test and False:
        opt.smooth_path = True
        opt.smooth_eye = True
        opt.smooth_lips = True

    opt.cuda_ray = True
    # assert opt.cuda_ray, "Only support CUDA ray mode."

    if opt.patch_size > 1:
        # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
        assert opt.num_rays % (opt.patch_size ** 2) == 0, "num_rays should be divisible by patch_size ** 2."

    # if opt.finetune_lips:
    #     # do not update density grid in finetune stage
    #     opt.update_extra_interval = 1e9

    print(opt)

    seed_everything(opt.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = NeRFNetwork(opt)

    # manually load state dict for head
    if opt.torso and opt.head_ckpt != '':

        model_dict = torch.load(opt.head_ckpt, map_location='cpu')['model']

        missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False)

        if len(missing_keys) > 0:
            print(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            print(f"[WARN] unexpected keys: {unexpected_keys}")

        # freeze these keys
        for k, v in model.named_parameters():
            if k in model_dict:
                # print(f'[INFO] freeze {k}, {v.shape}')
                v.requires_grad = False

    # print(model)

    criterion = torch.nn.MSELoss(reduction='none')

    if opt.test:

        if opt.gui:
            metrics = [] # use no metric in GUI for faster initialization...
        else:
            # metrics = [PSNRMeter(), LPIPSMeter(device=device)]
            metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')]

        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

        if opt.test_train:
            test_set = NeRFDataset(opt, device=device, type='train')
            # a manual fix to test on the training dataset
            test_set.training = False
            test_set.num_rays = -1
            test_loader = test_set.dataloader()
        else:
            test_loader = NeRFDataset(opt, device=device, type='test').dataloader()

        # temp fix: for update_extra_states
        model.aud_features = test_loader._data.auds
        model.eye_areas = test_loader._data.eye_area

        if opt.gui:
            from nerf_triplane.gui import NeRFGUI
            # we still need test_loader to provide audio features for testing.
            with NeRFGUI(opt, trainer, test_loader) as gui:
                gui.render()

        else:
            ### test and save video (fast)
            trainer.test(test_loader)

            ### evaluate metrics (slow)
            if test_loader.has_gt:
                trainer.evaluate(test_loader)

    else:

        optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr, opt.lr_net), betas=(0, 0.99), eps=1e-8)

        train_loader = NeRFDataset(opt, device=device, type='train').dataloader()

        assert len(train_loader) < opt.ind_num, f"[ERROR] dataset has too many frames: {len(train_loader)}; please increase --ind_num beyond this number!"

        # temp fix: for update_extra_states
        model.aud_features = train_loader._data.auds
        model.eye_area = train_loader._data.eye_area
        model.poses = train_loader._data.poses
        # lr decays to 0.05x (lips finetuning) or 0.5x of its initial value at the last iter step
        if opt.finetune_lips:
            scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.05 ** (iter / opt.iters))
        else:
            scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.5 ** (iter / opt.iters))
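        # e.g. with the defaults --lr 1e-2 and --iters 200000, the grid lr ends at
        # 1e-2 * 0.5 = 5e-3, or 1e-2 * 0.05 = 5e-4 when finetuning lips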

        metrics = [PSNRMeter(), LPIPSMeter(device=device)]

        eval_interval = max(1, int(5000 / len(train_loader)))
        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=eval_interval)
        with open(os.path.join(opt.workspace, 'opt.txt'), 'a') as f:
            f.write(str(opt))

        if opt.gui:
            with NeRFGUI(opt, trainer, train_loader) as gui:
                gui.render()

        else:
            valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader()

            max_epochs = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
            print(f'[INFO] max_epoch = {max_epochs}')
            trainer.train(train_loader, valid_loader, max_epochs)

            # free some mem
            del train_loader, valid_loader
            torch.cuda.empty_cache()

            # also test
            test_loader = NeRFDataset(opt, device=device, type='test').dataloader()

            if test_loader.has_gt:
                trainer.evaluate(test_loader) # blender has gt, so evaluate it.

            trainer.test(test_loader)

@ -0,0 +1,419 @@

import time
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor

import pyaudio
import soundfile as sf
import resampy

from queue import Queue
from threading import Thread, Event


def _read_frame(stream, exit_event, queue, chunk):

    while True:
        if exit_event.is_set():
            print(f'[INFO] read frame thread ends')
            break
        frame = stream.read(chunk, exception_on_overflow=False)
        frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
        queue.put(frame)

def _play_frame(stream, exit_event, queue, chunk):

    while True:
        if exit_event.is_set():
            print(f'[INFO] play frame thread ends')
            break
        frame = queue.get()
        frame = (frame * 32767).astype(np.int16).tobytes()
        stream.write(frame, chunk)

class ASR:
    def __init__(self, opt):

        self.opt = opt

        self.play = opt.asr_play

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fps = opt.fps # 20 ms per frame
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
        self.mode = 'live' if opt.asr_wav == '' else 'file'

        if 'esperanto' in self.opt.asr_model:
            self.audio_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
            self.audio_dim = 29
        else:
            self.audio_dim = 32

        # prepare context cache
        # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
        self.context_size = opt.m
        self.stride_left_size = opt.l
        self.stride_right_size = opt.r
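        # e.g. with the defaults -l 10 -m 50 -r 10 at fps = 50 (20 ms frames): each segment
        # covers (10 + 50 + 10) * 20 ms = 1.4 s of audio, and the expected latency is
        # (50 + 10) * 20 ms = 1.2 s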
        self.text = '[START]\n'
        self.terminated = False
        self.frames = []

        # pad left frames
        if self.stride_left_size > 0:
            self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)


        self.exit_event = Event()
        self.audio_instance = pyaudio.PyAudio()

        # create input stream
        if self.mode == 'file':
            self.file_stream = self.create_file_stream()
        else:
            # start a background thread to read frames
            self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
            self.queue = Queue()
            self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))

        # play out the audio too...?
        if self.play:
            self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
            self.output_queue = Queue()
            self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))

        # current location of audio
        self.idx = 0

        # create wav2vec model
        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
        self.processor = AutoProcessor.from_pretrained(opt.asr_model)
        self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)

        # prepare to save logits
        if self.opt.asr_save_feats:
            self.all_feats = []

        # the extracted features
        # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
        self.feat_buffer_size = 4
        self.feat_buffer_idx = 0
        self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
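        # ring-buffer layout (see run_step): each new chunk of context_size feature frames is
        # written at feat_buffer_idx * context_size, and feat_buffer_idx wraps modulo
        # feat_buffer_size, so the queue keeps the last 4 * context_size frames without concat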

        # TODO: hard coded 16 and 8 window size...
        self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
        self.tail = 8
        # attention window...
        self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...

        # warm up steps needed: mid + right + window_size + attention_size
        self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3

        self.listening = False
        self.playing = False

    def listen(self):
        # start
        if self.mode == 'live' and not self.listening:
            print(f'[INFO] starting read frame thread...')
            self.process_read_frame.start()
            self.listening = True

        if self.play and not self.playing:
            print(f'[INFO] starting play frame thread...')
            self.process_play_frame.start()
            self.playing = True

    def stop(self):

        self.exit_event.set()

        if self.play:
            self.output_stream.stop_stream()
            self.output_stream.close()
            if self.playing:
                self.process_play_frame.join()
                self.playing = False

        if self.mode == 'live':
            self.input_stream.stop_stream()
            self.input_stream.close()
            if self.listening:
                self.process_read_frame.join()
                self.listening = False


    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):

        self.stop()

        if self.mode == 'live':
            # live mode: also print the result text.
            self.text += '\n[END]'
            print(self.text)

    def get_next_feat(self):
        # return a [1/8, 16] window, for the next input to nerf side.

        while len(self.att_feats) < 8:
            # [------f+++t-----]
            if self.front < self.tail:
                feat = self.feat_queue[self.front:self.tail]
            # [++t-----------f+]
            else:
                feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)

            self.front = (self.front + 2) % self.feat_queue.shape[0]
            self.tail = (self.tail + 2) % self.feat_queue.shape[0]

            # print(self.front, self.tail, feat.shape)

            self.att_feats.append(feat.permute(1, 0))

        att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]

        # discard old
        self.att_feats = self.att_feats[1:]

        return att_feat

    def run_step(self):

        if self.terminated:
            return

        # get a frame of audio
        frame = self.get_audio_frame()

        # the last frame
        if frame is None:
            # terminate, but always run the network for the left frames
            self.terminated = True
        else:
            self.frames.append(frame)
            # put to output
            if self.play:
                self.output_queue.put(frame)
            # context not enough, do not run network.
            if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
                return

        inputs = np.concatenate(self.frames) # [N * chunk]

        # discard the old part to save memory
        if not self.terminated:
            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]

        logits, labels, text = self.frame_to_text(inputs)
        feats = logits # better lips-sync than labels

        # save feats
        if self.opt.asr_save_feats:
            self.all_feats.append(feats)

        # record the feats efficiently.. (no concat, constant memory)
        start = self.feat_buffer_idx * self.context_size
        end = start + feats.shape[0]
        self.feat_queue[start:end] = feats
        self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size

        # very naive, just concat the text output.
        if text != '':
            self.text = self.text + ' ' + text

        # will only run once at termination
        if self.terminated:
            self.text += '\n[END]'
            print(self.text)
            if self.opt.asr_save_feats:
                print(f'[INFO] save all feats for training purpose... ')
                feats = torch.cat(self.all_feats, dim=0) # [N, C]
                # print('[INFO] before unfold', feats.shape)
                window_size = 16
                padding = window_size // 2
                feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
                feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
                unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
                unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
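                # worked example (hypothetical M = 1000 feature frames at 50 fps): unfold
                # yields floor(M / 2) + 1 = 501 windows of [window_size, C] = [16, C],
                # i.e. one 16-frame audio window per 25 fps video frame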
                # print('[INFO] after unfold', unfold_feats.shape)
                # save to a npy file
                if 'esperanto' in self.opt.asr_model:
                    output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
                else:
                    output_path = self.opt.asr_wav.replace('.wav', '.npy')
                np.save(output_path, unfold_feats.cpu().numpy())
                print(f"[INFO] saved logits to {output_path}")

    def create_file_stream(self):

        stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self.sample_rate:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)

        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')

        return stream


    def create_pyaudio_stream(self):

        import pyaudio

        print(f'[INFO] creating live audio stream ...')

        audio = pyaudio.PyAudio()

        # get devices
        info = audio.get_host_api_info_by_index(0)
        n_devices = info.get('deviceCount')

        for i in range(0, n_devices):
            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
                print(f'[INFO] choose audio device {name}, id {i}')
                break

        # get stream
        stream = audio.open(input_device_index=i,
                            format=pyaudio.paInt16,
                            channels=1,
                            rate=self.sample_rate,
                            input=True,
                            frames_per_buffer=self.chunk)

        return audio, stream


    def get_audio_frame(self):

        if self.mode == 'file':

            if self.idx < self.file_stream.shape[0]:
                frame = self.file_stream[self.idx: self.idx + self.chunk]
                self.idx = self.idx + self.chunk
                return frame
            else:
                return None

        else:

            frame = self.queue.get()
            # print(f'[INFO] get frame {frame.shape}')

            self.idx = self.idx + self.chunk

            return frame


    def frame_to_text(self, frame):
        # frame: [N * 320], N = (context_size + 2 * stride_size)

        inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)

        with torch.no_grad():
            result = self.model(inputs.input_values.to(self.device))
            logits = result.logits # [1, N - 1, 32]

        # cut off stride
        left = max(0, self.stride_left_size)
        right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.

        # do not cut right if terminated.
        if self.terminated:
            right = logits.shape[1]

        logits = logits[:, left:right]

        # print(frame.shape, inputs.input_values.shape, logits.shape)

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0].lower()


        # for esperanto
        # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])

        # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
        # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
        # print(predicted_ids[0])
        # print(transcription)

        return logits[0], predicted_ids[0], transcription # [N,]


    def run(self):

        self.listen()

        while not self.terminated:
            self.run_step()

    def clear_queue(self):
        # clear the queue, to reduce potential latency...
        print(f'[INFO] clear queue')
        if self.mode == 'live':
            self.queue.queue.clear()
        if self.play:
            self.output_queue.queue.clear()

    def warm_up(self):

        self.listen()

        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
        t = time.time()
        for _ in range(self.warm_up_steps):
            self.run_step()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t = time.time() - t
        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')

        self.clear_queue()




if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--wav', type=str, default='')
    parser.add_argument('--play', action='store_true', help="play out the audio")

    parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
    # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')

    parser.add_argument('--save_feats', action='store_true')
    # audio FPS
    parser.add_argument('--fps', type=int, default=50)
    # sliding window left-middle-right length.
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=50)
    parser.add_argument('-r', type=int, default=10)

    opt = parser.parse_args()

    # fix
    opt.asr_wav = opt.wav
    opt.asr_play = opt.play
    opt.asr_model = opt.model
    opt.asr_save_feats = opt.save_feats

    if 'deepspeech' in opt.asr_model:
        raise ValueError("DeepSpeech features should not use this code to extract...")

    with ASR(opt) as asr:
        asr.run()

@ -0,0 +1,565 @@

import math
import torch
import numpy as np
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R

from .utils import *

from .asr import ASR


class OrbitCamera:
    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r # camera distance from center
        self.fovy = fovy # in degree
        self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
        self.rot = R.from_matrix([[0, -1, 0], [0, 0, -1], [1, 0, 0]]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention)
        self.up = np.array([1, 0, 0], dtype=np.float32) # need to be normalized!

    # pose
    @property
    def pose(self):
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] -= self.radius
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

    def update_pose(self, pose):
        # pose: [4, 4] numpy array
        # assert self.center is 0
        self.radius = np.linalg.norm(pose[:3, 3])
        T = np.eye(4)
        T[2, 3] = -self.radius
        rot = pose @ np.linalg.inv(T)
        self.rot = R.from_matrix(rot[:3, :3])

    def update_intrinsics(self, intrinsics):
        fl_x, fl_y, cx, cy = intrinsics
        self.W = int(cx * 2)
        self.H = int(cy * 2)
        self.fovy = np.rad2deg(2 * np.arctan2(self.H, 2 * fl_y))

    # intrinsics
    @property
    def intrinsics(self):
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])
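    # e.g. with the GUI defaults W = H = 450 and fovy = 21.24 deg:
    # focal = 450 / (2 * tan(10.62 deg)) ≈ 1200 pixels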

    def orbit(self, dx, dy):
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
        rotvec_x = self.up * np.radians(-0.01 * dx)
        rotvec_y = side * np.radians(-0.01 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        # pan in camera coordinate system (careful on the sensitivity!)
        self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])


class NeRFGUI:
    def __init__(self, opt, trainer, data_loader, debug=True):
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.training = False
        self.step = 0 # training step

        self.trainer = trainer
        self.data_loader = data_loader

        # override with dataloader's intrinsics
        self.W = data_loader._data.W
        self.H = data_loader._data.H
        self.cam.update_intrinsics(data_loader._data.intrinsics)

        # use dataloader's pose
        pose_init = data_loader._data.poses[0]
        self.cam.update_pose(pose_init.detach().cpu().numpy())

        # use dataloader's bg
        bg_img = data_loader._data.bg_img #.view(1, -1, 3)
        if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
            bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
        self.bg_color = bg_img.view(1, -1, 3)

        # audio features (from dataloader, only used in non-playing mode)
        self.audio_features = data_loader._data.auds # [N, 29, 16]
        self.audio_idx = 0

        # control eye
        self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()

        # playing seq from dataloader, or pause.
        self.playing = False
        self.loader = iter(data_loader)

        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True # camera moved, should reset accumulation
        self.spp = 1 # sample per pixel
        self.mode = 'image' # choose from ['image', 'depth']

        self.dynamic_resolution = False # assert False!
        self.downscale = 1
        self.train_steps = 16

        self.ind_index = 0
        self.ind_num = trainer.model.individual_codes.shape[0]

        # build asr
        if self.opt.asr:
            self.asr = ASR(opt)

        dpg.create_context()
        self.register_dpg()
        self.test_step()


    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.opt.asr:
            self.asr.stop()
        dpg.destroy_context()

    def train_step(self):

        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.data_loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # dynamic train steps
        # max allowed train time per-frame is 500 ms
        full_t = t / self.train_steps * 16
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps
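        # e.g. if 16 steps are estimated to cost full_t = 1000 ms, the target becomes
        # int(16 * 500 / 1000) = 8 steps, clamped to [4, 16]; the change is only applied
        # when it differs from the current value by more than 20%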

    def prepare_buffer(self, outputs):
        if self.mode == 'image':
            return outputs['image']
        else:
            return np.expand_dims(outputs['depth'], -1).repeat(3, -1)

    def test_step(self):

        if self.need_update or self.spp < self.opt.max_spp:

            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            if self.playing:
                try:
                    data = next(self.loader)
                except StopIteration:
                    self.loader = iter(self.data_loader)
                    data = next(self.loader)

                if self.opt.asr:
                    # use the live audio stream
                    data['auds'] = self.asr.get_next_feat()

                outputs = self.trainer.test_gui_with_data(data, self.W, self.H)

                # sync local camera pose
                self.cam.update_pose(data['poses_matrix'][0].detach().cpu().numpy())

            else:
                if self.audio_features is not None:
                    auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
                else:
                    auds = None
                outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            # update dynamic resolution
            if self.dynamic_resolution:
                # max allowed infer time per-frame is 200 ms
                full_t = t / (self.downscale ** 2)
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            if self.playing:
                self.need_update = True

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)


    def register_dpg(self):

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):

            # add the texture
            dpg.add_image("_texture")

        # dpg.set_primary_window("_primary_window", True)

        dpg.show_tool(dpg.mvTool_Metrics)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # time
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # train button
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):

                    # train / stop
                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        def callback_train(sender, app_data):
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        def callback_reset(sender, app_data):
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)

                    # save ckpt
                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1 # use epoch to indicate different calls.

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # save mesh
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh(resolution=256, threshold=10)
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1 # use epoch to indicate different calls.

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")


            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # playing
                with dpg.group(horizontal=True):
                    dpg.add_text("Play: ")

                    def callback_play(sender, app_data):

                        if self.playing:
                            self.playing = False
                            dpg.configure_item("_button_play", label="start")
                        else:
                            self.playing = True
                            dpg.configure_item("_button_play", label="stop")
                            if self.opt.asr:
                                self.asr.warm_up()
                        self.need_update = True

                    dpg.add_button(label="start", tag="_button_play", callback=callback_play)
                    dpg.bind_item_theme("_button_play", theme_button)

                # set asr
                if self.opt.asr:

                    # clear queue button
                    def callback_clear_queue(sender, app_data):

                        self.asr.clear_queue()
                        self.need_update = True

                    dpg.add_button(label="clear", tag="_button_clear_queue", callback=callback_clear_queue)
                    dpg.bind_item_theme("_button_clear_queue", theme_button)

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    # Disable dynamic resolution for face.
                    # dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)


                # bg_color picker
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # audio index slider
                if not self.opt.asr:
                    def callback_set_audio_index(sender, app_data):
                        self.audio_idx = app_data
                        self.need_update = True

                    dpg.add_slider_int(label="Audio", min_value=0, max_value=self.audio_features.shape[0] - 1, format="%d", default_value=self.audio_idx, callback=callback_set_audio_index)

                # ind code index slider
                if self.opt.ind_dim > 0:
                    def callback_set_individual_code(sender, app_data):
                        self.ind_index = app_data
                        self.need_update = True

                    dpg.add_slider_int(label="Individual", min_value=0, max_value=self.ind_num - 1, format="%d", default_value=self.ind_index, callback=callback_set_individual_code)

                # eye area slider
                if self.opt.exp_eye:
                    def callback_set_eye(sender, app_data):
                        self.eye_area = app_data
                        self.need_update = True

                    dpg.add_slider_float(label="eye area", min_value=0, max_value=0.5, format="%.2f percent", default_value=self.eye_area, callback=callback_set_eye)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma slider
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max_steps slider
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # aabb slider
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data

                    # also change train aabb ? [better not...]
                    #self.trainer.model.aabb_train[user_data] = app_data

                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)


        # debug info
        if self.debug:
            with dpg.collapsing_header(label="Debug"):
                # pose
                dpg.add_separator()
                dpg.add_text("Camera Pose:")
                dpg.add_text(str(self.cam.pose), tag="_log_pose")


        ### register camera handler

        def callback_camera_drag_rotate(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))


        def callback_camera_wheel_scale(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))


        def callback_camera_drag_pan(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))


        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)


        dpg.create_viewport(title='RAD-NeRF', width=1080, height=720, resizable=True)

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        #dpg.show_metrics()

        dpg.show_viewport()


    def render(self):

        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            # audio stream thread...
            if self.opt.asr and self.playing:
                # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
                for _ in range(2):
                    self.asr.run_step()
            self.test_step()
            dpg.render_dearpygui_frame()
|
|
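    # Pacing note (illustrative, not part of the original file): the loop
    # above renders one video frame per GUI frame; since the ASR features
    # arrive at 50 FPS while video renders at 25 FPS, the two asr.run_step()
    # calls per rendered frame (50 / 25 = 2) keep audio and video in sync.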
@@ -0,0 +1,352 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from encoding import get_encoder
from .renderer import NeRFRenderer

# Audio feature attention module
class AudioAttNet(nn.Module):
    def __init__(self, dim_aud=64, seq_len=8):
        super(AudioAttNet, self).__init__()
        self.seq_len = seq_len
        self.dim_aud = dim_aud
        self.attentionConvNet = nn.Sequential(  # b x subspace_dim x seq_len
            nn.Conv1d(self.dim_aud, 16, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True)
        )
        self.attentionNet = nn.Sequential(
            nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # x: [1, seq_len, dim_aud]
        y = x.permute(0, 2, 1)  # [1, dim_aud, seq_len]
        y = self.attentionConvNet(y)
        y = self.attentionNet(y.view(1, self.seq_len)).view(1, self.seq_len, 1)
        return torch.sum(y * x, dim=1)  # [1, dim_aud]


# Audio feature extractor
class AudioNet(nn.Module):
    def __init__(self, dim_in=29, dim_aud=64, win_size=16):
        super(AudioNet, self).__init__()
        self.win_size = win_size
        self.dim_aud = dim_aud
        self.encoder_conv = nn.Sequential(  # n x 29 x 16
            nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True),  # n x 32 x 8
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True),  # n x 32 x 4
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True),  # n x 64 x 2
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),  # n x 64 x 1
            nn.LeakyReLU(0.02, True),
        )
        self.encoder_fc1 = nn.Sequential(
            nn.Linear(64, 64),
            nn.LeakyReLU(0.02, True),
            nn.Linear(64, dim_aud),
        )

    def forward(self, x):
        half_w = int(self.win_size / 2)
        x = x[:, :, 8 - half_w:8 + half_w]
        x = self.encoder_conv(x).squeeze(-1)
        x = self.encoder_fc1(x)
        return x

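# A minimal shape sketch (illustrative, not part of the original file):
# assuming DeepSpeech features with a 16-frame window, AudioNet maps
# [8, 29, 16] -> [8, 64], and AudioAttNet fuses the 8-window sequence into
# a single temporally smoothed code:
#
#   feats = torch.randn(8, 29, 16)                   # seq_len=8 windows
#   codes = AudioNet(29, 64)(feats)                  # [8, 64]
#   fused = AudioAttNet(64, 8)(codes.unsqueeze(0))   # [1, 64]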
class MLP(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hidden, num_layers):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
                # x = F.dropout(x, p=0.1, training=self.training)

        return x

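# For reference (illustrative): MLP(36, 65, 64, 3) builds the layer chain
# Linear(36, 64) -> ReLU -> Linear(64, 64) -> ReLU -> Linear(64, 65),
# i.e. ReLU on all but the last layer, and no bias terms anywhere.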
class NeRFNetwork(NeRFRenderer):
    def __init__(self,
                 opt,
                 # torso net (hard coded for now)
                 ):
        super().__init__(opt)

        # audio embedding
        self.emb = self.opt.emb

        if 'esperanto' in self.opt.asr_model:
            self.audio_in_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
            self.audio_in_dim = 29
        else:
            self.audio_in_dim = 32

        if self.emb:
            self.embedding = nn.Embedding(self.audio_in_dim, self.audio_in_dim)

        # audio network
        audio_dim = 32
        self.audio_dim = audio_dim
        self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim)

        self.att = self.opt.att
        if self.att > 0:
            self.audio_att_net = AudioAttNet(self.audio_dim)

        # DYNAMIC PART
        self.num_levels = 12
        self.level_dim = 1
        self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
        self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
        self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)

        self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz

        ## sigma network
        self.num_layers = 3
        self.hidden_dim = 64
        self.geo_feat_dim = 64
        self.eye_att_net = MLP(self.in_dim, 1, 16, 2)
        self.eye_dim = 1 if self.exp_eye else 0
        self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers)

        ## color network
        self.num_layers_color = 2
        self.hidden_dim_color = 64
        self.encoder_dir, self.in_dim_dir = get_encoder('spherical_harmonics')
        self.color_net = MLP(self.in_dim_dir + self.geo_feat_dim + self.individual_dim, 3, self.hidden_dim_color, self.num_layers_color)

        # for processing audio
        self.unc_net = MLP(self.in_dim, 1, 32, 2)

        self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 64, 2)

        self.testing = False

        if self.torso:
            # torso deform network
            self.register_parameter('anchor_points',
                                    nn.Parameter(torch.tensor([[0.01, 0.01, 0.1, 1], [-0.1, -0.1, 0.1, 1], [0.1, -0.1, 0.1, 1]])))
            self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('frequency', input_dim=2, multires=8)
            # self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=1, base_resolution=16, log2_hashmap_size=16, desired_resolution=512)
            self.anchor_encoder, self.anchor_in_dim = get_encoder('frequency', input_dim=6, multires=3)
            self.torso_deform_net = MLP(self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 2, 32, 3)

            # torso color network
            self.torso_encoder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048)
            self.torso_net = MLP(self.torso_in_dim + self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 4, 32, 3)

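    # Dimension note (illustrative, assuming the usual gridencoder convention
    # where each hash encoder returns num_levels * level_dim features): every
    # 2D plane encoder above yields 12 * 1 = 12 features, so self.in_dim for
    # the three planes (xy, yz, xz) is expected to be 36.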
    def forward_torso(self, x, poses, c=None):
        # x: [N, 2] in [-1, 1]
        # head poses: [1, 4, 4]
        # c: [1, ind_dim], individual code

        # test: shrink x
        x = x * self.opt.torso_shrink

        # the pose is adjusted here
        # deformation-based
        wrapped_anchor = self.anchor_points[None, ...] @ poses.permute(0, 2, 1).inverse()
        wrapped_anchor = (wrapped_anchor[:, :, :2] / wrapped_anchor[:, :, 3, None] / wrapped_anchor[:, :, 2, None]).view(1, -1)
        # print(wrapped_anchor)
        # enc_pose = self.pose_encoder(poses)
        enc_anchor = self.anchor_encoder(wrapped_anchor)
        enc_x = self.torso_deform_encoder(x)

        if c is not None:
            h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1)
        else:
            h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1)], dim=-1)

        dx = self.torso_deform_net(h)

        x = (x + dx).clamp(-1, 1)

        x = self.torso_encoder(x, bound=1)

        # h = torch.cat([x, h, enc_a.repeat(x.shape[0], 1)], dim=-1)
        h = torch.cat([x, h], dim=-1)

        h = self.torso_net(h)

        alpha = torch.sigmoid(h[..., :1]) * (1 + 2 * 0.001) - 0.001
        color = torch.sigmoid(h[..., 1:]) * (1 + 2 * 0.001) - 0.001

        return alpha, color, dx

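    # Note (illustrative): sigmoid(h) * (1 + 2 * 0.001) - 0.001 stretches the
    # output range to [-0.001, 1.001], so exact 0 and 1 stay reachable after
    # clamping; a plain sigmoid only approaches them asymptotically.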
    @staticmethod
    @torch.jit.script
    def split_xyz(x):
        xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:, :1], x[:, -1:]], dim=-1)
        return xy, yz, xz

    def encode_x(self, xyz, bound):
        # xyz: [N, 3], in [-bound, bound]
        N, M = xyz.shape
        xy, yz, xz = self.split_xyz(xyz)
        feat_xy = self.encoder_xy(xy, bound=bound)
        feat_yz = self.encoder_yz(yz, bound=bound)
        feat_xz = self.encoder_xz(xz, bound=bound)

        return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1)

    def encode_audio(self, a):
        # a: [1, 29, 16] or [8, 29, 16], audio features from deepspeech
        # if emb, a should be: [1, 16] or [8, 16]

        # fix audio training
        if a is None: return None

        if self.emb:
            a = self.embedding(a).transpose(-1, -2).contiguous()  # [1/8, 29, 16]

        enc_a = self.audio_net(a)  # [1/8, 64]

        if self.att > 0:
            enc_a = self.audio_att_net(enc_a.unsqueeze(0))  # [1, 64]

        return enc_a

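    # Flow sketch (illustrative): with attention enabled (opt.att > 0),
    # encode_audio maps each of the 8 neighbouring windows to an audio_dim-d
    # code via AudioNet, and AudioAttNet then blends the sequence into one
    # temporally smoothed [1, audio_dim] feature for the current frame.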
    def predict_uncertainty(self, unc_inp):
        if self.testing or not self.opt.unc_loss:
            unc = torch.zeros_like(unc_inp)
        else:
            unc = self.unc_net(unc_inp.detach())

        return unc

    def forward(self, x, d, enc_a, c, e=None):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], normalized in [-1, 1]
        # enc_a: [1, aud_dim]
        # c: [1, ind_dim], individual code
        # e: [1, 1], eye feature
        enc_x = self.encode_x(x, bound=self.bound)

        sigma_result = self.density(x, enc_a, e, enc_x)
        sigma = sigma_result['sigma']
        geo_feat = sigma_result['geo_feat']
        aud_ch_att = sigma_result['ambient_aud']
        eye_att = sigma_result['ambient_eye']

        # color
        enc_d = self.encoder_dir(d)

        if c is not None:
            h = torch.cat([enc_d, geo_feat, c.repeat(x.shape[0], 1)], dim=-1)
        else:
            h = torch.cat([enc_d, geo_feat], dim=-1)

        h_color = self.color_net(h)
        color = torch.sigmoid(h_color) * (1 + 2 * 0.001) - 0.001

        uncertainty = self.predict_uncertainty(enc_x)
        uncertainty = torch.log(1 + torch.exp(uncertainty))

        return sigma, color, aud_ch_att, eye_att, uncertainty[..., None]

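    # Note (illustrative): log(1 + exp(u)) is the softplus function; it keeps
    # the predicted uncertainty strictly positive (e.g. u = 0 maps to ~0.693).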
    def density(self, x, enc_a, e=None, enc_x=None):
        # x: [N, 3], in [-bound, bound]
        if enc_x is None:
            enc_x = self.encode_x(x, bound=self.bound)

        enc_a = enc_a.repeat(enc_x.shape[0], 1)
        aud_ch_att = self.aud_ch_att_net(enc_x)
        enc_w = enc_a * aud_ch_att

        eye_att = None  # stays None when no eye feature is given
        if e is not None:
            # e = self.encoder_eye(e)
            eye_att = torch.sigmoid(self.eye_att_net(enc_x))
            e = e * eye_att
            # e = e.repeat(enc_x.shape[0], 1)
            h = torch.cat([enc_x, enc_w, e], dim=-1)
        else:
            h = torch.cat([enc_x, enc_w], dim=-1)

        h = self.sigma_net(h)

        sigma = torch.exp(h[..., 0])
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
            'ambient_aud': aud_ch_att.norm(dim=-1, keepdim=True),
            'ambient_eye': eye_att,
        }

    # optimizer utils
    def get_params(self, lr, lr_net, wd=0):

        # ONLY train torso
        if self.torso:
            params = [
                {'params': self.torso_encoder.parameters(), 'lr': lr},
                {'params': self.torso_deform_encoder.parameters(), 'lr': lr, 'weight_decay': wd},
                {'params': self.torso_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
                {'params': self.torso_deform_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
                {'params': self.anchor_points, 'lr': lr_net, 'weight_decay': wd}
            ]

            if self.individual_dim_torso > 0:
                params.append({'params': self.individual_codes_torso, 'lr': lr_net, 'weight_decay': wd})

            return params

        params = [
            {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd},

            {'params': self.encoder_xy.parameters(), 'lr': lr},
            {'params': self.encoder_yz.parameters(), 'lr': lr},
            {'params': self.encoder_xz.parameters(), 'lr': lr},
            # {'params': self.encoder_xyz.parameters(), 'lr': lr},

            {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.color_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
        ]
        if self.att > 0:
            params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001})
        if self.emb:
            params.append({'params': self.embedding.parameters(), 'lr': lr})
        if self.individual_dim > 0:
            params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd})
        if self.train_camera:
            params.append({'params': self.camera_dT, 'lr': 1e-5, 'weight_decay': 0})
            params.append({'params': self.camera_dR, 'lr': 1e-5, 'weight_decay': 0})

        params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
        params.append({'params': self.unc_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
        params.append({'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})

        return params
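# Usage sketch (illustrative, assuming a fully configured `opt`): the
# per-group learning rates above feed directly into a standard optimizer:
#
#   model = NeRFNetwork(opt)
#   optimizer = torch.optim.Adam(model.get_params(lr=1e-2, lr_net=1e-3), betas=(0.9, 0.99), eps=1e-15)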
@@ -0,0 +1,732 @@
import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import matplotlib.pyplot as plt

import trimesh

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from .utils import get_audio_features, get_rays, get_bg_coords, convert_poses

# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]):
    new_pose = np.array([
        [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
        [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
        [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
        [0, 0, 0, 1],
    ], dtype=np.float32)
    return new_pose

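# Convention note (illustrative): the row permutation above cycles the pose
# rows (x, y, z) -> (y, z, x) and negates the second and third columns of the
# rotation, which is the axis convention instant-ngp expects; the translation
# is also rescaled (scale=0.33 by default) so the camera rig fits its unit-ish
# scene box.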
def smooth_camera_path(poses, kernel_size=5):
    # smooth the camera trajectory...
    # poses: [N, 4, 4], numpy array

    N = poses.shape[0]
    K = kernel_size // 2

    trans = poses[:, :3, 3].copy()  # [N, 3]
    rots = poses[:, :3, :3].copy()  # [N, 3, 3]

    for i in range(N):
        start = max(0, i - K)
        end = min(N, i + K + 1)
        poses[i, :3, 3] = trans[start:end].mean(0)
        poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()

    return poses


def polygon_area(x, y):
    x_ = x - x.mean()
    y_ = y - y.mean()
    correction = x_[-1] * y_[0] - y_[-1] * x_[0]
    main_area = np.dot(x_[:-1], y_[1:]) - np.dot(y_[:-1], x_[1:])
    return 0.5 * np.abs(main_area + correction)

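# Worked example (illustrative): polygon_area implements the shoelace
# formula, so the unit square gives
#   polygon_area(np.array([0., 1., 1., 0.]), np.array([0., 0., 1., 1.]))  # -> 1.0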
def visualize_poses(poses, size=0.1):
    # poses: [B, 4, 4]

    print(f'[INFO] visualize poses: {poses.shape}')

    axes = trimesh.creation.axis(axis_length=4)
    box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
    box.colors = np.array([[128, 128, 128]] * len(box.entities))
    objects = [axes, box]

    for pose in poses:
        # a camera is visualized with 9 line segments (frustum edges + view direction).
        pos = pose[:3, 3]
        a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
        d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]

        dir = (a + b + c + d) / 4 - pos
        dir = dir / (np.linalg.norm(dir) + 1e-8)
        o = pos + dir * 3

        segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]])
        segs = trimesh.load_path(segs)
        objects.append(segs)

    trimesh.Scene(objects).show()

class NeRFDataset_Test:
    def __init__(self, opt, device, downscale=1):
        super().__init__()

        self.opt = opt
        self.device = device
        self.downscale = downscale
        self.scale = opt.scale  # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset  # camera offset
        self.bound = opt.bound  # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16

        self.start_index = opt.data_range[0]
        self.end_index = opt.data_range[1]

        self.training = False
        self.num_rays = -1

        # load nerf-compatible format data.

        with open(opt.pose, 'r') as f:
            transform = json.load(f)

        # load image size
        self.H = int(transform['cy']) * 2 // downscale
        self.W = int(transform['cx']) * 2 // downscale

        # read images
        frames = transform["frames"]

        # use a slice of the dataset
        if self.end_index == -1:  # abuse...
            self.end_index = len(frames)

        frames = frames[self.start_index:self.end_index]

        print(f'[INFO] load {len(frames)} frames.')

        # only load pre-calculated aud features when not live-streaming
        if not self.opt.asr:

            aud_features = np.load(self.opt.aud)

            aud_features = torch.from_numpy(aud_features)

            # support both [N, 16] labels and [N, 16, K] logits
            if len(aud_features.shape) == 3:
                aud_features = aud_features.float().permute(0, 2, 1)  # [N, 16, 29] --> [N, 29, 16]

                if self.opt.emb:
                    print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
                    aud_features = aud_features.argmax(1)  # [N, 16]

            else:
                assert self.opt.emb, "aud only provide labels, must use --emb"
                aud_features = aud_features.long()

            print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')

        self.poses = []
        self.auds = []
        self.eye_area = []

        for f in tqdm.tqdm(frames, desc='Loading data'):

            pose = np.array(f['transform_matrix'], dtype=np.float32)  # [4, 4]
            pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
            self.poses.append(pose)

            # find the corresponding audio to the image frame
            if not self.opt.asr and self.opt.aud == '':
                aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)]  # careful for the last frame...
                self.auds.append(aud)

            if self.opt.exp_eye:

                if 'eye_ratio' in f:
                    area = f['eye_ratio']
                else:
                    area = 0.25  # default value for opened eye

                self.eye_area.append(area)

        # load pre-extracted background image (should be the same size as training image...)

        if self.opt.bg_img == 'white':  # special
            bg_img = np.ones((self.H, self.W, 3), dtype=np.float32)
        elif self.opt.bg_img == 'black':  # special
            bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32)
        else:  # load from file
            bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED)  # [H, W, 3]
            if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
                bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
            bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
            bg_img = bg_img.astype(np.float32) / 255  # [H, W, 3/4]

        self.bg_img = bg_img

        self.poses = np.stack(self.poses, axis=0)

        # smooth camera path...
        if self.opt.smooth_path:
            self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)

        self.poses = torch.from_numpy(self.poses)  # [N, 4, 4]

        if self.opt.asr:
            # live streaming, no pre-calculated auds
            self.auds = None
        else:
            # auds corresponding to images
            if self.opt.aud == '':
                self.auds = torch.stack(self.auds, dim=0)  # [N, 32, 16]
            # auds is novel, may have a different length with images
            else:
                self.auds = aud_features

        self.bg_img = torch.from_numpy(self.bg_img)

        if self.opt.exp_eye:
            self.eye_area = np.array(self.eye_area, dtype=np.float32)  # [N]
            print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')

            if self.opt.smooth_eye:

                # naive 3-frame sliding window average
                ori_eye = self.eye_area.copy()
                for i in range(ori_eye.shape[0]):
                    start = max(0, i - 1)
                    end = min(ori_eye.shape[0], i + 2)
                    self.eye_area[i] = ori_eye[start:end].mean()

            self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1)  # [N, 1]

        # always preload
        self.poses = self.poses.to(self.device)

        if self.auds is not None:
            self.auds = self.auds.to(self.device)

        self.bg_img = self.bg_img.to(torch.half).to(self.device)

        if self.opt.exp_eye:
            self.eye_area = self.eye_area.to(self.device)

        # load intrinsics

        fl_x = fl_y = transform['focal_len']

        cx = (transform['cx'] / downscale)
        cy = (transform['cy'] / downscale)

        self.intrinsics = np.array([fl_x, fl_y, cx, cy])

        # directly build the coordinate meshgrid in [-1, 1]^2
        self.bg_coords = get_bg_coords(self.H, self.W, self.device)  # [1, H*W, 2] in [-1, 1]

    def mirror_index(self, index):
        size = self.poses.shape[0]
        turn = index // size
        res = index % size
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1

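    # Example (illustrative): with 3 poses, mirror_index maps the stream of
    # indices 0,1,2,3,4,5,6,... to 0,1,2,2,1,0,0,... so playback ping-pongs
    # through the pose sequence instead of jumping back to the start.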
    def collate(self, index):

        B = len(index)  # a list of length 1
        # assert B == 1

        results = {}

        # audio use the original index
        if self.auds is not None:
            auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
            results['auds'] = auds

        # head pose and bg image may mirror (replay --> <-- --> <--).
        index[0] = self.mirror_index(index[0])

        poses = self.poses[index].to(self.device)  # [B, 4, 4]

        rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)

        results['index'] = index  # for ind. code
        results['H'] = self.H
        results['W'] = self.W
        results['rays_o'] = rays['rays_o']
        results['rays_d'] = rays['rays_d']

        if self.opt.exp_eye:
            results['eye'] = self.eye_area[index].to(self.device)  # [1]
        else:
            results['eye'] = None

        bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)

        results['bg_color'] = bg_img

        bg_coords = self.bg_coords  # [1, N, 2]
        results['bg_coords'] = bg_coords

        # results['poses'] = convert_poses(poses)  # [B, 6]
        # results['poses_matrix'] = poses  # [B, 4, 4]
        results['poses'] = poses  # [B, 4, 4]

        return results

    def dataloader(self):

        # test with novel auds, then use its length
        if self.auds is not None:
            size = self.auds.shape[0]
        # live stream test, use 2 * len(poses), so it naturally mirrors.
        else:
            size = 2 * self.poses.shape[0]

        loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=False, num_workers=0)
        loader._data = self  # an ugly fix... we need poses in trainer.

        # do evaluate if has gt images and use self-driven setting
        loader.has_gt = False

        return loader

class NeRFDataset:
    def __init__(self, opt, device, type='train', downscale=1):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type  # train, val, test
        self.downscale = downscale
        self.root_path = opt.path
        self.preload = opt.preload  # 0 = disk, 1 = cpu, 2 = gpu
        self.scale = opt.scale  # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset  # camera offset
        self.bound = opt.bound  # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16

        self.start_index = opt.data_range[0]
        self.end_index = opt.data_range[1]

        self.training = self.type in ['train', 'all', 'trainval']
        self.num_rays = self.opt.num_rays if self.training else -1

        # load nerf-compatible format data.

        with open(opt.pose, 'r') as f:
            transform = json.load(f)

        # load image size
        if 'h' in transform and 'w' in transform:
            self.H = int(transform['h']) // downscale
            self.W = int(transform['w']) // downscale
        else:
            self.H = int(transform['cy']) * 2 // downscale
            self.W = int(transform['cx']) * 2 // downscale

        # read images
        frames = transform["frames"]

        # use a slice of the dataset
        if self.end_index == -1:  # abuse...
            self.end_index = len(frames)

        frames = frames[self.start_index:self.end_index]
        print(f'[INFO] load {len(frames)} {type} frames.')

        # only load pre-calculated aud features when not live-streaming
        if not self.opt.asr:

            # empty means the default self-driven extracted features.
            if self.opt.aud == '':
                if 'esperanto' in self.opt.asr_model:
                    aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy'))
                elif 'deepspeech' in self.opt.asr_model:
                    aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy'))
                else:
                    aud_features = np.load(os.path.join(self.root_path, 'aud.npy'))
            # cross-driven extracted features.
            else:
                aud_features = np.load(self.opt.aud)

            aud_features = torch.from_numpy(aud_features)

            # support both [N, 16] labels and [N, 16, K] logits
            if len(aud_features.shape) == 3:
                aud_features = aud_features.float().permute(0, 2, 1)  # [N, 16, 29] --> [N, 29, 16]

                if self.opt.emb:
                    print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
                    aud_features = aud_features.argmax(1)  # [N, 16]

            else:
                assert self.opt.emb, "aud only provide labels, must use --emb"
                aud_features = aud_features.long()

            print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')

        # load action units
        import pandas as pd
        au_blink_info = pd.read_csv(os.path.join(self.root_path, 'au.csv'))
        au_blink = au_blink_info[' AU45_r'].values

        self.torso_img = []
        self.images = []

        self.poses = []
        self.exps = []

        self.auds = []
        self.face_rect = []
        self.lhalf_rect = []
        self.lips_rect = []
        self.eye_area = []
        self.eye_rect = []

        for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):

            f_path = os.path.join(self.root_path, 'gt_imgs', str(f['img_id']) + '.jpg')

            if not os.path.exists(f_path):
                print('[WARN]', f_path, 'NOT FOUND!')
                continue

            pose = np.array(f['transform_matrix'], dtype=np.float32)  # [4, 4]
            pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
            self.poses.append(pose)

            if self.preload > 0:
                image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED)  # [H, W, 3] or [H, W, 4]
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = image.astype(np.float32) / 255  # [H, W, 3/4]

                self.images.append(image)
            else:
                self.images.append(f_path)

            # load frame-wise bg

            torso_img_path = os.path.join(self.root_path, 'torso_imgs', str(f['img_id']) + '.png')

            if self.preload > 0:
                torso_img = cv2.imread(torso_img_path, cv2.IMREAD_UNCHANGED)  # [H, W, 4]
                torso_img = cv2.cvtColor(torso_img, cv2.COLOR_BGRA2RGBA)
                torso_img = torso_img.astype(np.float32) / 255  # [H, W, 3/4]

                self.torso_img.append(torso_img)
            else:
                self.torso_img.append(torso_img_path)

            # find the corresponding audio to the image frame
            if not self.opt.asr and self.opt.aud == '':
                aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)]  # careful for the last frame...
                self.auds.append(aud)

            # load lms and extract face
            lms = np.loadtxt(os.path.join(self.root_path, 'ori_imgs', str(f['img_id']) + '.lms'))  # [68, 2]

            lh_xmin, lh_xmax = int(lms[31:36, 1].min()), int(lms[:, 1].max())  # actually lower half area
            xmin, xmax = int(lms[:, 1].min()), int(lms[:, 1].max())
            ymin, ymax = int(lms[:, 0].min()), int(lms[:, 0].max())
            self.face_rect.append([xmin, xmax, ymin, ymax])
            self.lhalf_rect.append([lh_xmin, lh_xmax, ymin, ymax])

            if self.opt.exp_eye:
                # eyes_left = slice(36, 42)
                # eyes_right = slice(42, 48)

                # area_left = polygon_area(lms[eyes_left, 0], lms[eyes_left, 1])
                # area_right = polygon_area(lms[eyes_right, 0], lms[eyes_right, 1])

                # # area percentage of two eyes of the whole image...
                # area = (area_left + area_right) / (self.H * self.W) * 100

                # action units blink AU45
                area = au_blink[f['img_id']]
                area = np.clip(area, 0, 2) / 2
                # area = area + np.random.rand() / 10
                self.eye_area.append(area)

                xmin, xmax = int(lms[36:48, 1].min()), int(lms[36:48, 1].max())
                ymin, ymax = int(lms[36:48, 0].min()), int(lms[36:48, 0].max())
                self.eye_rect.append([xmin, xmax, ymin, ymax])

            if self.opt.finetune_lips:
                lips = slice(48, 60)
                xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
                ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())

                # padding to H == W
                cx = (xmin + xmax) // 2
                cy = (ymin + ymax) // 2

                l = max(xmax - xmin, ymax - ymin) // 2
                xmin = max(0, cx - l)
                xmax = min(self.H, cx + l)
                ymin = max(0, cy - l)
                ymax = min(self.W, cy + l)

                self.lips_rect.append([xmin, xmax, ymin, ymax])

        # load pre-extracted background image (should be the same size as training image...)

        if self.opt.bg_img == 'white':  # special
            bg_img = np.ones((self.H, self.W, 3), dtype=np.float32)
        elif self.opt.bg_img == 'black':  # special
            bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32)
        else:  # load from file
            # default bg
            if self.opt.bg_img == '':
                self.opt.bg_img = os.path.join(self.root_path, 'bc.jpg')
            bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED)  # [H, W, 3]
            if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
                bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
            bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
            bg_img = bg_img.astype(np.float32) / 255  # [H, W, 3/4]

        self.bg_img = bg_img

        self.poses = np.stack(self.poses, axis=0)

        # smooth camera path...
        if self.opt.smooth_path:
            self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)

        self.poses = torch.from_numpy(self.poses)  # [N, 4, 4]

        if self.preload > 0:
            self.images = torch.from_numpy(np.stack(self.images, axis=0))  # [N, H, W, C]
            self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0))  # [N, H, W, C]
        else:
            self.images = np.array(self.images)
            self.torso_img = np.array(self.torso_img)

        if self.opt.asr:
            # live streaming, no pre-calculated auds
            self.auds = None
        else:
            # auds corresponding to images
            if self.opt.aud == '':
                self.auds = torch.stack(self.auds, dim=0)  # [N, 32, 16]
            # auds is novel, may have a different length with images
            else:
                self.auds = aud_features

        self.bg_img = torch.from_numpy(self.bg_img)

        if self.opt.exp_eye:
            self.eye_area = np.array(self.eye_area, dtype=np.float32)  # [N]
            print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')

            if self.opt.smooth_eye:

                # naive 3-frame sliding window average
                ori_eye = self.eye_area.copy()
                for i in range(ori_eye.shape[0]):
                    start = max(0, i - 1)
                    end = min(ori_eye.shape[0], i + 2)
                    self.eye_area[i] = ori_eye[start:end].mean()

            self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1)  # [N, 1]

        # calculate mean radius of all camera poses
        self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
        # print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}')

        # [debug] uncomment to view all training poses.
        # visualize_poses(self.poses.numpy())

        # [debug] uncomment to view examples of randomly generated poses.
        # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy())

        if self.preload > 1:
            self.poses = self.poses.to(self.device)

            if self.auds is not None:
                self.auds = self.auds.to(self.device)

            self.bg_img = self.bg_img.to(torch.half).to(self.device)

            self.torso_img = self.torso_img.to(torch.half).to(self.device)
            self.images = self.images.to(torch.half).to(self.device)

            if self.opt.exp_eye:
                self.eye_area = self.eye_area.to(self.device)

        # load intrinsics
        if 'focal_len' in transform:
            fl_x = fl_y = transform['focal_len']
        elif 'fl_x' in transform or 'fl_y' in transform:
            fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
            fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
        elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
            fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
            if fl_x is None: fl_x = fl_y
            if fl_y is None: fl_y = fl_x
        else:
            raise RuntimeError('Failed to load focal length, please check the transforms.json!')

        cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
        cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)

        self.intrinsics = np.array([fl_x, fl_y, cx, cy])

        # directly build the coordinate meshgrid in [-1, 1]^2
        self.bg_coords = get_bg_coords(self.H, self.W, self.device)  # [1, H*W, 2] in [-1, 1]

    def mirror_index(self, index):
        size = self.poses.shape[0]
        turn = index // size
        res = index % size
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1

    def collate(self, index):

        B = len(index)  # a list of length 1
        # assert B == 1

        results = {}

        # audio use the original index
        if self.auds is not None:
            auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
            results['auds'] = auds

        # head pose and bg image may mirror (replay --> <-- --> <--).
        index[0] = self.mirror_index(index[0])

        poses = self.poses[index].to(self.device)  # [B, 4, 4]

        if self.training and self.opt.finetune_lips:
            rect = self.lips_rect[index[0]]
            results['rect'] = rect
            rays = get_rays(poses, self.intrinsics, self.H, self.W, -1, rect=rect)
        else:
            rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)

        results['index'] = index  # for ind. code
        results['H'] = self.H
        results['W'] = self.W
        results['rays_o'] = rays['rays_o']
        results['rays_d'] = rays['rays_d']

        # get a mask for rays inside rect_face
        if self.training:
            xmin, xmax, ymin, ymax = self.face_rect[index[0]]
            face_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax)  # [B, N]
            results['face_mask'] = face_mask

            xmin, xmax, ymin, ymax = self.lhalf_rect[index[0]]
            lhalf_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax)  # [B, N]
            results['lhalf_mask'] = lhalf_mask

        if self.opt.exp_eye:
            results['eye'] = self.eye_area[index].to(self.device)  # [1]
            if self.training:
                results['eye'] += (np.random.rand() - 0.5) / 10
                xmin, xmax, ymin, ymax = self.eye_rect[index[0]]
                eye_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax)  # [B, N]
                results['eye_mask'] = eye_mask

        else:
            results['eye'] = None

        # load bg
        bg_torso_img = self.torso_img[index]
        if self.preload == 0:  # on the fly loading
            bg_torso_img = cv2.imread(bg_torso_img[0], cv2.IMREAD_UNCHANGED)  # [H, W, 4]
            bg_torso_img = cv2.cvtColor(bg_torso_img, cv2.COLOR_BGRA2RGBA)
            bg_torso_img = bg_torso_img.astype(np.float32) / 255  # [H, W, 3/4]
            bg_torso_img = torch.from_numpy(bg_torso_img).unsqueeze(0)
        bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:])
        bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device)

        if not self.opt.torso:
            bg_img = bg_torso_img
        else:
            bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)

        if self.training:
            bg_img = torch.gather(bg_img, 1, torch.stack(3 * [rays['inds']], -1))  # [B, N, 3]

        results['bg_color'] = bg_img

        if self.opt.torso and self.training:
            bg_torso_img = torch.gather(bg_torso_img, 1, torch.stack(3 * [rays['inds']], -1))  # [B, N, 3]
            results['bg_torso_color'] = bg_torso_img

        images = self.images[index]  # [B, H, W, 3/4]
        if self.preload == 0:
            images = cv2.imread(images[0], cv2.IMREAD_UNCHANGED)  # [H, W, 3]
            images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
            images = images.astype(np.float32) / 255  # [H, W, 3]
            images = torch.from_numpy(images).unsqueeze(0)
        images = images.to(self.device)

        if self.training:
            C = images.shape[-1]
            images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1))  # [B, N, 3/4]

        results['images'] = images

        if self.training:
            bg_coords = torch.gather(self.bg_coords, 1, torch.stack(2 * [rays['inds']], -1))  # [1, N, 2]
        else:
            bg_coords = self.bg_coords  # [1, N, 2]

        results['bg_coords'] = bg_coords

        # results['poses'] = convert_poses(poses)  # [B, 6]
        # results['poses_matrix'] = poses  # [B, 4, 4]
        results['poses'] = poses  # [B, 4, 4]

        return results

    def dataloader(self):

        if self.training:
            # training len(poses) == len(auds)
            size = self.poses.shape[0]
        else:
            # test with novel auds, then use its length
            if self.auds is not None:
                size = self.auds.shape[0]
            # live stream test, use 2 * len(poses), so it naturally mirrors.
            else:
                size = 2 * self.poses.shape[0]

        loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        loader._data = self  # an ugly fix... we need poses in trainer.

        # do evaluate if has gt images and use self-driven setting
        loader.has_gt = (self.opt.aud == '')

        return loader

@@ -0,0 +1,700 @@
import math
import trimesh
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import raymarching
from .utils import custom_meshgrid, get_audio_features, euler_angles_to_matrix, convert_poses

def sample_pdf(bins, weights, n_samples, det=False):
    # This implementation is from NeRF
    # bins: [B, T], old_z_vals
    # weights: [B, T - 1], bin weights.
    # return: [B, n_samples], new_z_vals

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples

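# Usage sketch (illustrative, not part of the original file): given coarse z
# values and their composited weights, sample_pdf draws extra depths
# concentrated where the weight mass is:
#
#   bins = torch.linspace(0.1, 1.0, 65).expand(4, 65)          # [B, T]
#   weights = torch.rand(4, 64)                                # [B, T - 1]
#   new_z = sample_pdf(bins, weights, n_samples=32, det=True)  # [4, 32]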
def plot_pointcloud(pc, color=None):
    # pc: [N, 3]
    # color: [N, 3/4]
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
    pc = trimesh.PointCloud(pc, color)
    # axis
    axes = trimesh.creation.axis(axis_length=4)
    # sphere
    sphere = trimesh.creation.icosphere(radius=1)
    trimesh.Scene([pc, axes, sphere]).show()

class NeRFRenderer(nn.Module):
    def __init__(self, opt):

        super().__init__()

        self.opt = opt
        self.bound = opt.bound
        self.cascade = 1 + math.ceil(math.log2(opt.bound))
        self.grid_size = 128
        self.density_scale = 1

        self.min_near = opt.min_near
        self.density_thresh = opt.density_thresh
        self.density_thresh_torso = opt.density_thresh_torso

        self.exp_eye = opt.exp_eye
        self.test_train = opt.test_train
        self.smooth_lips = opt.smooth_lips

        self.torso = opt.torso
        self.cuda_ray = opt.cuda_ray

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-opt.bound, -opt.bound / 2, -opt.bound, opt.bound, opt.bound / 2, opt.bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer('aabb_train', aabb_train)
        self.register_buffer('aabb_infer', aabb_infer)

        # individual codes
        self.individual_num = opt.ind_num

        self.individual_dim = opt.ind_dim
        if self.individual_dim > 0:
            self.individual_codes = nn.Parameter(torch.randn(self.individual_num, self.individual_dim) * 0.1)

        if self.torso:
            self.individual_dim_torso = opt.ind_dim_torso
            if self.individual_dim_torso > 0:
                self.individual_codes_torso = nn.Parameter(torch.randn(self.individual_num, self.individual_dim_torso) * 0.1)

        # optimize camera pose
        self.train_camera = self.opt.train_camera
        if self.train_camera:
            self.camera_dR = nn.Parameter(torch.zeros(self.individual_num, 3))  # euler angles
            self.camera_dT = nn.Parameter(torch.zeros(self.individual_num, 3))  # xyz offset

        # extra state for cuda raymarching

        # 3D head density grid
        density_grid = torch.zeros([self.cascade, self.grid_size ** 3])  # [CAS, H * H * H]
        density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8)  # [CAS * H * H * H // 8]
        self.register_buffer('density_grid', density_grid)
        self.register_buffer('density_bitfield', density_bitfield)
        self.mean_density = 0
        self.iter_density = 0

        # 2D torso density grid
        if self.torso:
            density_grid_torso = torch.zeros([self.grid_size ** 2])  # [H * H]
            self.register_buffer('density_grid_torso', density_grid_torso)
            self.mean_density_torso = 0

        # step counter
        step_counter = torch.zeros(16, 2, dtype=torch.int32)  # 16 is hardcoded for averaging...
        self.register_buffer('step_counter', step_counter)
        self.mean_count = 0
        self.local_step = 0

        # decay for enc_a
        if self.smooth_lips:
            self.enc_a = None

    def forward(self, x, d):
        raise NotImplementedError()

    # separated density and color query (can accelerate non-cuda-ray mode.)
    def density(self, x):
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        raise NotImplementedError()

    def reset_extra_state(self):
        if not self.cuda_ray:
            return
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0

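    # Size note (illustrative): each cascade stores a 128^3 occupancy grid,
    # and the packed bitfield keeps 1 bit per cell, i.e.
    # 128**3 // 8 = 262144 bytes (256 KiB) per cascade.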
    def run_cuda(self, rays_o, rays_d, auds, bg_coords, poses, eye=None, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # auds: [B, 16]
        # index: [B]
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        bg_coords = bg_coords.contiguous().view(-1, 2)

        # only add camera offset at training!
        if self.train_camera and (self.training or self.test_train):
            dT = self.camera_dT[index]  # [1, 3]
            dR = euler_angles_to_matrix(self.camera_dR[index] / 180 * np.pi + 1e-8).squeeze(0)  # [1, 3] --> [3, 3]

            rays_o = rays_o + dT
            rays_d = rays_d @ dR

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        results = {}

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
        nears = nears.detach()
        fars = fars.detach()

        # encode audio
        enc_a = self.encode_audio(auds)  # [1, 64]

        if enc_a is not None and self.smooth_lips:
            if self.enc_a is not None:
                _lambda = 0.35
                enc_a = _lambda * self.enc_a + (1 - _lambda) * enc_a
            self.enc_a = enc_a

        if self.individual_dim > 0:
            if self.training:
                ind_code = self.individual_codes[index]
            # use a fixed ind code for the unknown test data.
            else:
                ind_code = self.individual_codes[0]
        else:
            ind_code = None

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
            sigmas, rgbs, amb_aud, amb_eye, uncertainty = self(xyzs, dirs, enc_a, ind_code, eye)
            sigmas = self.density_scale * sigmas

            # print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')

            # weights_sum, ambient_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_uncertainty(sigmas, rgbs, ambient.abs().sum(-1), uncertainty, deltas, rays)
            weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_triplane(sigmas, rgbs, amb_aud.abs().sum(-1), amb_eye.abs().sum(-1), uncertainty, deltas, rays)

            # for training only
            results['weights_sum'] = weights_sum
            results['ambient_aud'] = amb_aud_sum
            results['ambient_eye'] = amb_eye_sum
            results['uncertainty'] = uncertainty_sum

            results['rays'] = xyzs, dirs, enc_a, ind_code, eye

        else:

            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)
            amb_aud_sum = torch.zeros(N, dtype=dtype, device=device)
            amb_eye_sum = torch.zeros(N, dtype=dtype, device=device)
            uncertainty_sum = torch.zeros(N, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
            rays_t = nears.clone()  # [N]

            step = 0

            while step < max_steps:

                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs, ambients_aud, ambients_eye, uncertainties = self(xyzs, dirs, enc_a, ind_code, eye)
                sigmas = self.density_scale * sigmas

                # raymarching.composite_rays_uncertainty(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh)
                raymarching.composite_rays_triplane(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients_aud, ambients_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh)

                rays_alive = rays_alive[rays_alive >= 0]

                # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

                step += n_step

        torso_results = self.run_torso(rays_o, bg_coords, poses, index, bg_color)
        bg_color = torso_results['bg_color']

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        image = image.view(*prefix, 3)
        image = image.clamp(0, 1)

        depth = torch.clamp(depth - nears, min=0) / (fars - nears)
        depth = depth.view(*prefix)

        amb_aud_sum = amb_aud_sum.view(*prefix)
        amb_eye_sum = amb_eye_sum.view(*prefix)

        results['depth'] = depth
        results['image'] = image  # head_image if train, else com_image
        results['ambient_aud'] = amb_aud_sum
        results['ambient_eye'] = amb_eye_sum
        results['uncertainty'] = uncertainty_sum

        return results

    def run_torso(self, rays_o, bg_coords, poses, index=0, bg_color=None, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # auds: [B, 16]
        # index: [B]
        # return: image: [B, N, 3], depth: [B, N]

        rays_o = rays_o.contiguous().view(-1, 3)
        bg_coords = bg_coords.contiguous().view(-1, 2)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        results = {}

        # background
        if bg_color is None:
            bg_color = 1

        # first mix torso with background
        if self.torso:
            # torso ind code
            if self.individual_dim_torso > 0:
                if self.training:
                    ind_code_torso = self.individual_codes_torso[index]
                # use a fixed ind code for the unknown test data.
                else:
                    ind_code_torso = self.individual_codes_torso[0]
            else:
                ind_code_torso = None

            # 2D density grid for acceleration...
            density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
            occupancy = F.grid_sample(self.density_grid_torso.view(1, 1, self.grid_size, self.grid_size), bg_coords.view(1, -1, 1, 2), align_corners=True).view(-1)
            mask = occupancy > density_thresh_torso

            # masked query of torso
            torso_alpha = torch.zeros([N, 1], device=device)
            torso_color = torch.zeros([N, 3], device=device)

            if mask.any():
                torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, ind_code_torso)

                torso_alpha[mask] = torso_alpha_mask.float()
                torso_color[mask] = torso_color_mask.float()

                results['deform'] = deform

            # first mix torso with background
            bg_color = torso_color * torso_alpha + bg_color * (1 - torso_alpha)

            results['torso_alpha'] = torso_alpha
            results['torso_color'] = bg_color

            # print(torso_alpha.shape, torso_alpha.max().item(), torso_alpha.min().item())

        results['bg_color'] = bg_color

        return results

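    # Compositing note (illustrative): bg_color = torso_color * torso_alpha
    # + bg_color * (1 - torso_alpha) is a standard "over" blend, so the torso
    # layer is pre-composited onto the background before the head pass adds
    # its own (1 - weights_sum) residual on top in run_cuda.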
@torch.no_grad()
|
||||
def mark_untrained_grid(self, poses, intrinsic, S=64):
|
||||
# poses: [B, 4, 4]
|
||||
# intrinsic: [3, 3]
|
||||
|
||||
if not self.cuda_ray:
|
||||
return
|
||||
|
||||
if isinstance(poses, np.ndarray):
|
||||
poses = torch.from_numpy(poses)
|
||||
|
||||
B = poses.shape[0]
|
||||
|
||||
fx, fy, cx, cy = intrinsic
|
||||
|
||||
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
|
||||
count = torch.zeros_like(self.density_grid)
|
||||
poses = poses.to(count.device)
|
||||
|
||||
# 5-level loop, forgive me...
|
||||
|
||||
for xs in X:
|
||||
for ys in Y:
|
||||
for zs in Z:
|
||||
|
||||
# construct points
|
||||
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
||||
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
|
||||
indices = raymarching.morton3D(coords).long() # [N]
|
||||
world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1]
|
||||
|
||||
# cascading
|
||||
for cas in range(self.cascade):
|
||||
bound = min(2 ** cas, self.bound)
|
||||
half_grid_size = bound / self.grid_size
|
||||
# scale to current cascade's resolution
|
||||
cas_world_xyzs = world_xyzs * (bound - half_grid_size)
|
||||
|
||||
# split batch to avoid OOM
|
||||
head = 0
|
||||
while head < B:
|
||||
tail = min(head + S, B)
|
||||
|
||||
# world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
|
||||
cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
|
||||
cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
|
||||
|
||||
# query if point is covered by any camera
|
||||
mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
|
||||
mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
|
||||
mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
|
||||
mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]
|
||||
|
||||
# update count
|
||||
count[cas, indices] += mask
|
||||
head += S
|
||||
|
||||
# mark untrained grid as -1
|
||||
self.density_grid[count == 0] = -1
|
||||
|
||||
#print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')
|
||||
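
The loop above marks a grid cell as trained only if at least one camera sees it: grid points are moved into each camera frame (poses are camera-to-world, so the inverse rotation becomes a right-multiplication) and tested against a pinhole frustum widened by `half_grid_size * 2`. The same test, standalone and with made-up intrinsics (camera convention assumed OpenCV-style, +z forward):

```python
import torch

def covered_by_any_camera(points, poses, fx, fy, cx, cy, margin=0.0):
    # points: [N, 3] world-space; poses: [B, 4, 4] camera-to-world.
    # world -> camera: p_cam = R^T @ (p - t), done as a batched right-multiply.
    cam = points.unsqueeze(0) - poses[:, :3, 3].unsqueeze(1)   # [B, N, 3]
    cam = cam @ poses[:, :3, :3]                               # [B, N, 3]
    mask_z = cam[..., 2] > 0                                   # in front of the camera
    mask_x = cam[..., 0].abs() < cx / fx * cam[..., 2] + margin
    mask_y = cam[..., 1].abs() < cy / fy * cam[..., 2] + margin
    return (mask_z & mask_x & mask_y).any(0)                   # [N]

poses = torch.eye(4).unsqueeze(0)   # one camera at the origin, looking down +z
pts = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, -1.0]])
print(covered_by_any_camera(pts, poses, fx=500., fy=500., cx=250., cy=250.))
# tensor([ True, False])
```
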
|
||||
@torch.no_grad()
|
||||
def update_extra_state(self, decay=0.95, S=128):
|
||||
# call before each epoch to update extra states.
|
||||
|
||||
if not self.cuda_ray:
|
||||
return
|
||||
|
||||
# use random auds (different expressions should have similar density grid...)
|
||||
rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
|
||||
auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
|
||||
|
||||
# encode audio
|
||||
enc_a = self.encode_audio(auds)
|
||||
|
||||
### update density grid
|
||||
if not self.torso: # forbid updating the head grid while training the torso...
|
||||
|
||||
tmp_grid = torch.zeros_like(self.density_grid)
|
||||
|
||||
# use a random eye area based on training dataset's statistics...
|
||||
if self.exp_eye:
|
||||
eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
|
||||
else:
|
||||
eye = None
|
||||
|
||||
# full update
|
||||
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
|
||||
for xs in X:
|
||||
for ys in Y:
|
||||
for zs in Z:
|
||||
|
||||
# construct points
|
||||
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
||||
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
|
||||
indices = raymarching.morton3D(coords).long() # [N]
|
||||
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
|
||||
|
||||
# cascading
|
||||
for cas in range(self.cascade):
|
||||
bound = min(2 ** cas, self.bound)
|
||||
half_grid_size = bound / self.grid_size
|
||||
# scale to current cascade's resolution
|
||||
cas_xyzs = xyzs * (bound - half_grid_size)
|
||||
# add noise in [-hgs, hgs]
|
||||
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
|
||||
# query density
|
||||
sigmas = self.density(cas_xyzs, enc_a, eye)['sigma'].reshape(-1).detach().to(tmp_grid.dtype)
|
||||
sigmas *= self.density_scale
|
||||
# assign
|
||||
tmp_grid[cas, indices] = sigmas
|
||||
|
||||
# dilate the density_grid (less aggressive culling)
|
||||
tmp_grid = raymarching.morton3D_dilation(tmp_grid)
|
||||
|
||||
# ema update
|
||||
valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
|
||||
self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
|
||||
self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 non-training regions are viewed as 0 density.
|
||||
self.iter_density += 1
|
||||
|
||||
# convert to bitfield
|
||||
density_thresh = min(self.mean_density, self.density_thresh)
|
||||
self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
|
||||
|
||||
### update torso density grid
|
||||
if self.torso:
|
||||
tmp_grid_torso = torch.zeros_like(self.density_grid_torso)
|
||||
|
||||
# random pose, random ind_code
|
||||
rand_idx = random.randint(0, self.poses.shape[0] - 1)
|
||||
# pose = convert_poses(self.poses[[rand_idx]]).to(self.density_bitfield.device)
|
||||
pose = self.poses[[rand_idx]].to(self.density_bitfield.device)
|
||||
|
||||
if self.opt.ind_dim_torso > 0:
|
||||
ind_code = self.individual_codes_torso[[rand_idx]]
|
||||
else:
|
||||
ind_code = None
|
||||
|
||||
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
|
||||
half_grid_size = 1 / self.grid_size
|
||||
|
||||
for xs in X:
|
||||
for ys in Y:
|
||||
xx, yy = custom_meshgrid(xs, ys)
|
||||
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=-1) # [N, 2], in [0, 128)
|
||||
indices = (coords[:, 1] * self.grid_size + coords[:, 0]).long() # NOTE: xy transposed!
|
||||
xys = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 2] in [-1, 1]
|
||||
xys = xys * (1 - half_grid_size)
|
||||
# add noise in [-hgs, hgs]
|
||||
xys += (torch.rand_like(xys) * 2 - 1) * half_grid_size
|
||||
# query density
|
||||
alphas, _, _ = self.forward_torso(xys, pose, ind_code) # [N, 1]
|
||||
|
||||
# assign
|
||||
tmp_grid_torso[indices] = alphas.squeeze(1).float()
|
||||
|
||||
# dilate
|
||||
tmp_grid_torso = tmp_grid_torso.view(1, 1, self.grid_size, self.grid_size)
|
||||
# tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=3, stride=1, padding=1)
|
||||
tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=5, stride=1, padding=2)
|
||||
tmp_grid_torso = tmp_grid_torso.view(-1)
|
||||
|
||||
self.density_grid_torso = torch.maximum(self.density_grid_torso * decay, tmp_grid_torso)
|
||||
self.mean_density_torso = torch.mean(self.density_grid_torso).item()
|
||||
|
||||
# density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
|
||||
# print(f'[density grid torso] min={self.density_grid_torso.min().item():.4f}, max={self.density_grid_torso.max().item():.4f}, mean={self.mean_density_torso:.4f}, occ_rate={(self.density_grid_torso > density_thresh_torso).sum() / (128**2):.3f}')
|
||||
|
||||
### update step counter
|
||||
total_step = min(16, self.local_step)
|
||||
if total_step > 0:
|
||||
self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
|
||||
self.local_step = 0
|
||||
|
||||
#print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
|
||||
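
The EMA step above keeps a moving *maximum* rather than a mean: `max(old * decay, new)` lets stale occupancy fade at rate `decay` while newly observed density is adopted immediately, and cells marked `-1` by `mark_untrained_grid` stay excluded. A minimal sketch:

```python
import torch

decay = 0.95
old = torch.tensor([0.8, 0.1, -1.0, 0.0])   # -1 marks cells no camera ever sees
new = torch.tensor([0.2, 0.5,  0.3, 0.0])   # densities sampled this epoch

valid = (old >= 0) & (new >= 0)             # skip untrained cells
old[valid] = torch.maximum(old[valid] * decay, new[valid])
print(old)                                  # tensor([0.7600, 0.5000, -1.0000, 0.0000])

# mean density over trainable cells, clamping -1 regions to 0 as in the code
print(old.clamp(min=0).mean().item())       # 0.315
```
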
|
||||
|
||||
@torch.no_grad()
|
||||
def get_audio_grid(self, S=128):
|
||||
# call before each epoch to update extra states.
|
||||
|
||||
if not self.cuda_ray:
|
||||
return
|
||||
|
||||
# use random auds (different expressions should have similar density grid...)
|
||||
rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
|
||||
auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
|
||||
|
||||
# encode audio
|
||||
enc_a = self.encode_audio(auds)
|
||||
tmp_grid = torch.zeros_like(self.density_grid)
|
||||
|
||||
# use a random eye area based on training dataset's statistics...
|
||||
if self.exp_eye:
|
||||
eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
|
||||
else:
|
||||
eye = None
|
||||
|
||||
# full update
|
||||
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
|
||||
for xs in X:
|
||||
for ys in Y:
|
||||
for zs in Z:
|
||||
|
||||
# construct points
|
||||
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
||||
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
|
||||
indices = raymarching.morton3D(coords).long() # [N]
|
||||
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
|
||||
|
||||
# cascading
|
||||
for cas in range(self.cascade):
|
||||
bound = min(2 ** cas, self.bound)
|
||||
half_grid_size = bound / self.grid_size
|
||||
# scale to current cascade's resolution
|
||||
cas_xyzs = xyzs * (bound - half_grid_size)
|
||||
# add noise in [-hgs, hgs]
|
||||
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
|
||||
# query density
|
||||
aud_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_aud'].reshape(-1).detach().to(tmp_grid.dtype)
|
||||
# assign
|
||||
tmp_grid[cas, indices] = aud_norms
|
||||
|
||||
# dilate the density_grid (less aggressive culling)
|
||||
tmp_grid = raymarching.morton3D_dilation(tmp_grid)
|
||||
return tmp_grid
|
||||
# # ema update
|
||||
# valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
|
||||
# self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
|
||||
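
All of these grid loops map integer cell coordinates to world space the same way: normalize to [-1, 1], scale into the cascade's bound shrunk by half a cell, then jitter within the cell so different epochs sample different positions. A worked sketch of the mapping:

```python
import torch

grid_size, scene_bound, cascade = 128, 2, 2
coords = torch.tensor([[0, 0, 0], [64, 64, 64], [127, 127, 127]], dtype=torch.float32)
xyzs = 2 * coords / (grid_size - 1) - 1                 # cell centers in [-1, 1]^3

for cas in range(cascade):
    bound = min(2 ** cas, scene_bound)                  # cascade 0 covers [-1,1]^3, cascade 1 covers [-2,2]^3
    half_grid_size = bound / grid_size
    cas_xyzs = xyzs * (bound - half_grid_size)          # shrink by half a cell
    cas_xyzs = cas_xyzs + (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size  # jitter within the cell
    print(cas, bool(cas_xyzs.abs().max() <= bound))     # True: samples never leave the bound
```
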
|
||||
|
||||
@torch.no_grad()
|
||||
def get_eye_grid(self, S=128):
|
||||
# call before each epoch to update extra states.
|
||||
|
||||
if not self.cuda_ray:
|
||||
return
|
||||
|
||||
# use random auds (different expressions should have similar density grid...)
|
||||
rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
|
||||
auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
|
||||
|
||||
# encode audio
|
||||
enc_a = self.encode_audio(auds)
|
||||
tmp_grid = torch.zeros_like(self.density_grid)
|
||||
|
||||
# use a random eye area based on training dataset's statistics...
|
||||
if self.exp_eye:
|
||||
eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
|
||||
else:
|
||||
eye = None
|
||||
|
||||
# full update
|
||||
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
||||
|
||||
for xs in X:
|
||||
for ys in Y:
|
||||
for zs in Z:
|
||||
|
||||
# construct points
|
||||
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
||||
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
|
||||
indices = raymarching.morton3D(coords).long() # [N]
|
||||
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
|
||||
|
||||
# cascading
|
||||
for cas in range(self.cascade):
|
||||
bound = min(2 ** cas, self.bound)
|
||||
half_grid_size = bound / self.grid_size
|
||||
# scale to current cascade's resolution
|
||||
cas_xyzs = xyzs * (bound - half_grid_size)
|
||||
# add noise in [-hgs, hgs]
|
||||
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
|
||||
# query density
|
||||
eye_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_eye'].reshape(-1).detach().to(tmp_grid.dtype)
|
||||
# assign
|
||||
tmp_grid[cas, indices] = eye_norms
|
||||
|
||||
# dilate the density_grid (less aggressive culling)
|
||||
tmp_grid = raymarching.morton3D_dilation(tmp_grid)
|
||||
return tmp_grid
|
||||
# # ema update
|
||||
# valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
|
||||
# self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
|
||||
|
||||
|
||||
|
||||
def render(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
|
||||
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
||||
# auds: [B, 29, 16]
|
||||
# eye: [B, 1]
|
||||
# bg_coords: [1, N, 2]
|
||||
# return: pred_rgb: [B, N, 3]
|
||||
|
||||
_run = self.run_cuda
|
||||
|
||||
B, N = rays_o.shape[:2]
|
||||
device = rays_o.device
|
||||
|
||||
# never stage when cuda_ray
|
||||
if staged and not self.cuda_ray:
|
||||
# not used
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
results = _run(rays_o, rays_d, auds, bg_coords, poses, **kwargs)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def render_torso(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
|
||||
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
||||
# auds: [B, 29, 16]
|
||||
# eye: [B, 1]
|
||||
# bg_coords: [1, N, 2]
|
||||
# return: pred_rgb: [B, N, 3]
|
||||
|
||||
_run = self.run_torso
|
||||
|
||||
B, N = rays_o.shape[:2]
|
||||
device = rays_o.device
|
||||
|
||||
# never stage when cuda_ray
|
||||
if staged and not self.cuda_ray:
|
||||
# not used
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
results = _run(rays_o, bg_coords, poses, **kwargs)
|
||||
|
||||
return results
|
File diff suppressed because it is too large
|
@ -0,0 +1,158 @@
|
|||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
#from .utils import *
import torch.nn.functional as F  # needed below: F.interpolate on the background image
# NOTE: get_audio_features (used in test_step) must also be importable, e.g. from the project's utils module.
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
from asrreal import ASR
|
||||
|
||||
class NeRFReal:
|
||||
def __init__(self, opt, trainer, data_loader, debug=True):
|
||||
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
||||
self.W = opt.W
|
||||
self.H = opt.H
|
||||
self.debug = debug
|
||||
self.training = False
|
||||
self.step = 0 # training step
|
||||
|
||||
self.trainer = trainer
|
||||
self.data_loader = data_loader
|
||||
|
||||
# use dataloader's bg
|
||||
bg_img = data_loader._data.bg_img #.view(1, -1, 3)
|
||||
if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
|
||||
bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
|
||||
self.bg_color = bg_img.view(1, -1, 3)
|
||||
|
||||
# audio features (from dataloader, only used in non-playing mode)
|
||||
self.audio_features = data_loader._data.auds # [N, 29, 16]
|
||||
self.audio_idx = 0
|
||||
|
||||
# control eye
|
||||
self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
|
||||
|
||||
# playing seq from dataloader, or pause.
|
||||
self.playing = True  # TODO: was False; always play the dataset sequence for now
|
||||
self.loader = iter(data_loader)
|
||||
|
||||
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
|
||||
self.need_update = True # camera moved, should reset accumulation
|
||||
self.spp = 1 # sample per pixel
|
||||
self.mode = 'image' # choose from ['image', 'depth']
|
||||
|
||||
self.dynamic_resolution = False # assert False!
|
||||
self.downscale = 1
|
||||
self.train_steps = 16
|
||||
|
||||
self.ind_index = 0
|
||||
self.ind_num = trainer.model.individual_codes.shape[0]
|
||||
|
||||
# build asr
|
||||
if self.opt.asr:
|
||||
self.asr = ASR(opt)
|
||||
|
||||
video_path = 'video_stream'
|
||||
if not os.path.exists(video_path):
|
||||
os.mkfifo(video_path, mode=0o777)
|
||||
audio_path = 'audio_stream'
|
||||
if not os.path.exists(audio_path):
|
||||
os.mkfifo(audio_path, mode=0o777)
|
||||
width=450
|
||||
height=450
|
||||
fps=25
|
||||
push_url='rtmp://localhost/live/livestream' #'data/video/output_0.mp4'
|
||||
command = ['ffmpeg',
|
||||
'-y', #'-an',
|
||||
#'-re',
|
||||
'-f', 'rawvideo',
|
||||
'-vcodec','rawvideo',
|
||||
'-pix_fmt', 'rgb24', # pixel format
|
||||
'-s', "{}x{}".format(width, height),
|
||||
'-r', str(fps),
|
||||
'-i', video_path,
|
||||
'-f', 's16le',
|
||||
'-acodec','pcm_s16le',
|
||||
'-ac', '1',
|
||||
'-ar', '16000',
|
||||
'-i', audio_path,
|
||||
#'-fflags', '+genpts',
|
||||
'-map', '0:v',
|
||||
'-map', '1:a',
|
||||
#'-copyts',
|
||||
'-acodec', 'aac',
|
||||
'-pix_fmt', 'yuv420p', #'-vcodec', "h264",
|
||||
#"-rtmp_buffer", "100",
|
||||
'-f' , 'flv',
|
||||
push_url]
|
||||
self.pipe = subprocess.Popen(command, shell=False) #, stdin=subprocess.PIPE)
|
||||
self.fifo_video = open(video_path, 'wb')
|
||||
self.fifo_audio = open(audio_path, 'wb')
|
||||
#self.test_step()
|
||||
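
The command above makes ffmpeg read raw `rgb24` video (450x450 @ 25 fps) and mono `s16le` PCM (16 kHz) from the two FIFOs and mux them to RTMP as FLV; the writer must keep both FIFOs fed at matching rates or ffmpeg stalls. The per-frame byte budget, from the parameters above:

```python
# Per-frame byte budget for the raw streams fed to ffmpeg above.
width, height, fps = 450, 450, 25
video_bytes_per_frame = width * height * 3        # rgb24: 3 bytes/pixel
print(video_bytes_per_frame)                      # 607500

sample_rate, bytes_per_sample = 16000, 2          # s16le mono
audio_bytes_per_video_frame = sample_rate // fps * bytes_per_sample
print(audio_bytes_per_video_frame)                # 1280 bytes = 640 samples

# The render loop therefore writes two 320-sample (640-byte) ASR chunks
# per video frame: audio runs at 50 chunks/s against 25 video frames/s.
```
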
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if self.opt.asr:
|
||||
self.asr.stop()
|
||||
|
||||
def push_audio(self,chunk):
|
||||
self.asr.push_audio(chunk)
|
||||
|
||||
def prepare_buffer(self, outputs):
|
||||
if self.mode == 'image':
|
||||
return outputs['image']
|
||||
else:
|
||||
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
|
||||
|
||||
def test_step(self):
|
||||
|
||||
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||
starter.record()
|
||||
|
||||
if self.playing:
|
||||
try:
|
||||
data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.loader = iter(self.data_loader)
|
||||
data = next(self.loader)
|
||||
|
||||
if self.opt.asr:
|
||||
# use the live audio stream
|
||||
data['auds'] = self.asr.get_next_feat()
|
||||
|
||||
outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
|
||||
print('[INFO] outputs shape', outputs['image'].shape)
|
||||
image = (outputs['image'] * 255).astype(np.uint8)
|
||||
#self.pipe.stdin.write(image.tostring())
|
||||
for _ in range(2):
|
||||
frame = self.asr.get_audio_out()
|
||||
print('[INFO] get_audio_out shape', frame.shape)
|
||||
frame = (frame * 32767).astype(np.int16).tobytes()
|
||||
self.fifo_audio.write(frame)
|
||||
self.fifo_video.write(image.tobytes())  # tostring() was removed in newer NumPy
|
||||
else:
|
||||
if self.audio_features is not None:
|
||||
auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
|
||||
else:
|
||||
auds = None
|
||||
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale) # NOTE: expects a GUI camera (self.cam), which this headless class never creates
|
||||
|
||||
ender.record()
|
||||
torch.cuda.synchronize()
|
||||
t = starter.elapsed_time(ender)
|
||||
|
||||
def render(self):
|
||||
if self.opt.asr:
|
||||
self.asr.warm_up()
|
||||
while True: #todo
|
||||
# update texture every frame
|
||||
# audio stream thread...
|
||||
if self.opt.asr and self.playing:
|
||||
# run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
|
||||
for _ in range(2):
|
||||
self.asr.run_step()
|
||||
self.test_step()
|
|
@ -0,0 +1 @@
|
|||
from .raymarching import *
|
|
@ -0,0 +1,40 @@
|
|||
import os
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
nvcc_flags = [
|
||||
'-O3', '-std=c++14',
|
||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
||||
]
|
||||
|
||||
if os.name == "posix":
|
||||
c_flags = ['-O3', '-std=c++14']
|
||||
elif os.name == "nt":
|
||||
c_flags = ['/O2', '/std:c++17']
|
||||
|
||||
# find cl.exe
|
||||
def find_cl_path():
|
||||
import glob
|
||||
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
||||
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
||||
if paths:
|
||||
return paths[0]
|
||||
|
||||
# If cl.exe is not on path, try to find it.
|
||||
if os.system("where cl.exe >nul 2>nul") != 0:
|
||||
cl_path = find_cl_path()
|
||||
if cl_path is None:
|
||||
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
||||
os.environ["PATH"] += ";" + cl_path
|
||||
|
||||
_backend = load(name='_raymarching_face',
|
||||
extra_cflags=c_flags,
|
||||
extra_cuda_cflags=nvcc_flags,
|
||||
sources=[os.path.join(_src_path, 'src', f) for f in [
|
||||
'raymarching.cu',
|
||||
'bindings.cpp',
|
||||
]],
|
||||
)
|
||||
|
||||
__all__ = ['_backend']
|
|
@ -0,0 +1,671 @@
|
|||
import numpy as np
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
try:
|
||||
import _raymarching_face as _backend
|
||||
except ImportError:
|
||||
from .backend import _backend
|
||||
|
||||
# ----------------------------------------
|
||||
# utils
|
||||
# ----------------------------------------
|
||||
|
||||
class _near_far_from_aabb(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
|
||||
''' near_far_from_aabb, CUDA implementation
|
||||
Calculate rays' intersection time (near and far) with aabb
|
||||
Args:
|
||||
rays_o: float, [N, 3]
|
||||
rays_d: float, [N, 3]
|
||||
aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
|
||||
min_near: float, scalar
|
||||
Returns:
|
||||
nears: float, [N]
|
||||
fars: float, [N]
|
||||
'''
|
||||
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
||||
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
||||
|
||||
rays_o = rays_o.contiguous().view(-1, 3)
|
||||
rays_d = rays_d.contiguous().view(-1, 3)
|
||||
|
||||
N = rays_o.shape[0] # num rays
|
||||
|
||||
nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
||||
fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
||||
|
||||
_backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
|
||||
|
||||
return nears, fars
|
||||
|
||||
near_far_from_aabb = _near_far_from_aabb.apply
|
||||
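
`near_far_from_aabb` is the classic slab test: intersect the ray with each pair of axis-aligned planes and keep the latest entry and earliest exit. A pure-PyTorch reference of what the kernel computes (a sketch for checking; the CUDA path may handle missed rays differently, here they collapse to near == far):

```python
import torch

def near_far_from_aabb_ref(rays_o, rays_d, aabb, min_near=0.2):
    # rays_o, rays_d: [N, 3]; aabb: [6] = (xmin, ymin, zmin, xmax, ymax, zmax)
    eps = 1e-15
    inv_d = 1.0 / torch.where(rays_d.abs() < eps, torch.full_like(rays_d, eps), rays_d)
    t0 = (aabb[:3] - rays_o) * inv_d          # hit times for the min planes
    t1 = (aabb[3:] - rays_o) * inv_d          # hit times for the max planes
    tmin = torch.minimum(t0, t1).amax(-1)     # latest entry over the three slabs
    tmax = torch.maximum(t0, t1).amin(-1)     # earliest exit
    nears = tmin.clamp(min=min_near)
    fars = torch.where(tmax < nears, nears, tmax)  # missed rays collapse to near == far
    return nears, fars

rays_o = torch.tensor([[0.0, 0.0, -2.0]])
rays_d = torch.tensor([[0.0, 0.0, 1.0]])
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0])
print(near_far_from_aabb_ref(rays_o, rays_d, aabb))  # (tensor([1.]), tensor([3.]))
```
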
|
||||
|
||||
class _sph_from_ray(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, rays_o, rays_d, radius):
|
||||
''' sph_from_ray, CUDA implementation
|
||||
get spherical coordinate on the background sphere from rays.
|
||||
Assume rays_o are inside the Sphere(radius).
|
||||
Args:
|
||||
rays_o: [N, 3]
|
||||
rays_d: [N, 3]
|
||||
radius: scalar, float
|
||||
Return:
|
||||
coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
|
||||
'''
|
||||
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
||||
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
||||
|
||||
rays_o = rays_o.contiguous().view(-1, 3)
|
||||
rays_d = rays_d.contiguous().view(-1, 3)
|
||||
|
||||
N = rays_o.shape[0] # num rays
|
||||
|
||||
coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
|
||||
|
||||
_backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
|
||||
|
||||
return coords
|
||||
|
||||
sph_from_ray = _sph_from_ray.apply
|
||||
|
||||
|
||||
class _morton3D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, coords):
|
||||
''' morton3D, CUDA implementation
|
||||
Args:
|
||||
coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
|
||||
TODO: check if the coord range is valid! (current 128 is safe)
|
||||
Returns:
|
||||
indices: [N], int32, in [0, 128^3)
|
||||
|
||||
'''
|
||||
if not coords.is_cuda: coords = coords.cuda()
|
||||
|
||||
N = coords.shape[0]
|
||||
|
||||
indices = torch.empty(N, dtype=torch.int32, device=coords.device)
|
||||
|
||||
_backend.morton3D(coords.int(), N, indices)
|
||||
|
||||
return indices
|
||||
|
||||
morton3D = _morton3D.apply
|
||||
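
Morton (Z-order) indexing interleaves the bits of x, y and z so that spatially close cells stay close in memory. A pure-Python reference for coordinates in [0, 128), using the standard bit-spreading trick (for intuition only; the real call is the CUDA `morton3D` above):

```python
def expand_bits(v: int) -> int:
    # Spread the low 10 bits of v so there are two zero bits between each.
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton3d(x: int, y: int, z: int) -> int:
    # Interleave as ... z2 y2 x2 z1 y1 x1 z0 y0 x0.
    return expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)

print(morton3d(1, 0, 0))  # 1
print(morton3d(0, 1, 0))  # 2
print(morton3d(0, 0, 1))  # 4
print(morton3d(7, 7, 7))  # 511 = 8**3 - 1: (7,7,7) is the last cell of the first 8x8x8 block
```
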
|
||||
class _morton3D_invert(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, indices):
|
||||
''' morton3D_invert, CUDA implementation
|
||||
Args:
|
||||
indices: [N], int32, in [0, 128^3)
|
||||
Returns:
|
||||
coords: [N, 3], int32, in [0, 128)
|
||||
|
||||
'''
|
||||
if not indices.is_cuda: indices = indices.cuda()
|
||||
|
||||
N = indices.shape[0]
|
||||
|
||||
coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
|
||||
|
||||
_backend.morton3D_invert(indices.int(), N, coords)
|
||||
|
||||
return coords
|
||||
|
||||
morton3D_invert = _morton3D_invert.apply
|
||||
|
||||
|
||||
class _packbits(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, grid, thresh, bitfield=None):
|
||||
''' packbits, CUDA implementation
|
||||
Pack up the density grid into a bit field to accelerate ray marching.
|
||||
Args:
|
||||
grid: float, [C, H * H * H], assume H % 2 == 0
|
||||
thresh: float, threshold
|
||||
Returns:
|
||||
bitfield: uint8, [C, H * H * H / 8]
|
||||
'''
|
||||
if not grid.is_cuda: grid = grid.cuda()
|
||||
grid = grid.contiguous()
|
||||
|
||||
C = grid.shape[0]
|
||||
H3 = grid.shape[1]
|
||||
N = C * H3 // 8
|
||||
|
||||
if bitfield is None:
|
||||
bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
|
||||
|
||||
_backend.packbits(grid, N, thresh, bitfield)
|
||||
|
||||
return bitfield
|
||||
|
||||
packbits = _packbits.apply
|
||||
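
`packbits` thresholds the float grid and stores one occupancy bit per cell, so the marching kernel can test a cell with a byte load and a mask. A NumPy stand-in for the packing (the bit order shown is an assumption; the CUDA kernel defines the authoritative layout):

```python
import numpy as np

grid = np.array([0.9, 0.0, 0.5, 0.0, 0.0, 0.7, 0.0, 0.0], dtype=np.float32)
thresh = 0.4

bits = (grid > thresh).astype(np.uint8)          # [1, 0, 1, 0, 0, 1, 0, 0]
bitfield = np.packbits(bits, bitorder='little')  # 8 cells -> 1 byte
print(bitfield)                                  # [37]  (0b00100101: cells 0, 2 and 5 set)

# Occupancy test for cell i:
i = 5
occupied = (bitfield[i // 8] >> (i % 8)) & 1
print(bool(occupied))                            # True
```
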
|
||||
|
||||
class _morton3D_dilation(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, grid):
|
||||
''' max pooling with morton coord, CUDA implementation
|
||||
or maybe call it dilation... we don't support adjusting the kernel size.
|
||||
Args:
|
||||
grid: float, [C, H * H * H], assume H % 2 == 0
|
||||
Returns:
|
||||
grid_dilate: float, [C, H * H * H]
|
||||
'''
|
||||
if not grid.is_cuda: grid = grid.cuda()
|
||||
grid = grid.contiguous()
|
||||
|
||||
C = grid.shape[0]
|
||||
H3 = grid.shape[1]
|
||||
H = int(np.cbrt(H3))
|
||||
grid_dilation = torch.empty_like(grid)
|
||||
|
||||
_backend.morton3D_dilation(grid, C, H, grid_dilation)
|
||||
|
||||
return grid_dilation
|
||||
|
||||
morton3D_dilation = _morton3D_dilation.apply
|
||||
|
||||
# ----------------------------------------
|
||||
# train functions
|
||||
# ----------------------------------------
|
||||
|
||||
class _march_rays_train(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
|
||||
''' march rays to generate points (forward only)
|
||||
Args:
|
||||
rays_o/d: float, [N, 3]
|
||||
bound: float, scalar
|
||||
density_bitfield: uint8: [CHHH // 8]
|
||||
C: int
|
||||
H: int
|
||||
nears/fars: float, [N]
|
||||
step_counter: int32, (2), used to count the actual number of generated points.
|
||||
mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
|
||||
perturb: bool
|
||||
align: int, pad output so its size is divisible by align, set to -1 to disable.
|
||||
force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
|
||||
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
|
||||
max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
|
||||
Returns:
|
||||
xyzs: float, [M, 3], all generated points' coords. (all rays concatenated; use `rays` to extract the points belonging to each ray)
|
||||
dirs: float, [M, 3], all generated points' view dirs.
|
||||
deltas: float, [M, 2], first is delta_t, second is rays_t
|
||||
rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0]
|
||||
'''
|
||||
|
||||
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
||||
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
||||
if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
|
||||
|
||||
rays_o = rays_o.contiguous().view(-1, 3)
|
||||
rays_d = rays_d.contiguous().view(-1, 3)
|
||||
density_bitfield = density_bitfield.contiguous()
|
||||
|
||||
N = rays_o.shape[0] # num rays
|
||||
M = N * max_steps # init max points number in total
|
||||
|
||||
# running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
|
||||
# It estimates the max points number to enable faster training, but will lead to randomly ignored rays if underestimated.
|
||||
if not force_all_rays and mean_count > 0:
|
||||
if align > 0:
|
||||
mean_count += align - mean_count % align
|
||||
M = mean_count
|
||||
|
||||
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
||||
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
||||
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
|
||||
rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
|
||||
|
||||
if step_counter is None:
|
||||
step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
|
||||
|
||||
if perturb:
|
||||
noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
|
||||
else:
|
||||
noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
|
||||
|
||||
_backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
|
||||
|
||||
#print(step_counter, M)
|
||||
|
||||
# only used at the first (few) epochs.
|
||||
if force_all_rays or mean_count <= 0:
|
||||
m = step_counter[0].item() # D2H copy
|
||||
if align > 0:
|
||||
m += align - m % align
|
||||
xyzs = xyzs[:m]
|
||||
dirs = dirs[:m]
|
||||
deltas = deltas[:m]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
ctx.save_for_backward(rays, deltas)
|
||||
|
||||
return xyzs, dirs, deltas, rays
|
||||
|
||||
# to support optimizing camera poses.
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays):
|
||||
# grad_xyzs/dirs: [M, 3]
|
||||
|
||||
rays, deltas = ctx.saved_tensors
|
||||
|
||||
N = rays.shape[0]
|
||||
M = grad_xyzs.shape[0]
|
||||
|
||||
grad_rays_o = torch.zeros(N, 3, device=rays.device)
|
||||
grad_rays_d = torch.zeros(N, 3, device=rays.device)
|
||||
|
||||
_backend.march_rays_train_backward(grad_xyzs, grad_dirs, rays, deltas, N, M, grad_rays_o, grad_rays_d)
|
||||
|
||||
return grad_rays_o, grad_rays_d, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
march_rays_train = _march_rays_train.apply
|
||||
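
Since every ray yields a different number of samples, the outputs of `march_rays_train` are flat: all points concatenated, plus a `rays` table of `(index, offset, count)` rows. Per-ray extraction is a slice, as in this sketch with made-up outputs:

```python
import torch

# Hypothetical outputs of march_rays_train: 7 points over 3 rays.
xyzs = torch.arange(21, dtype=torch.float32).view(7, 3)
rays = torch.tensor([[0, 0, 3],    # ray 0: points [0, 3)
                     [1, 3, 0],    # ray 1: culled, no samples
                     [2, 3, 4]],   # ray 2: points [3, 7)
                    dtype=torch.int32)

for i in range(rays.shape[0]):
    ray_id, offset, count = rays[i].tolist()
    pts = xyzs[offset:offset + count]
    print(f'ray {ray_id}: {count} samples, shape {tuple(pts.shape)}')
```
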
|
||||
|
||||
class _composite_rays_train(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
|
||||
''' composite rays' rgbs, according to the ray marching formula.
|
||||
Args:
|
||||
rgbs: float, [M, 3]
|
||||
sigmas: float, [M,]
|
||||
ambient: float, [M,] (after summing up the last dimension)
|
||||
deltas: float, [M, 2]
|
||||
rays: int32, [N, 3]
|
||||
Returns:
|
||||
weights_sum: float, [N,], the alpha channel
|
||||
depth: float, [N, ], the Depth
|
||||
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
||||
'''
|
||||
|
||||
sigmas = sigmas.contiguous()
|
||||
rgbs = rgbs.contiguous()
|
||||
ambient = ambient.contiguous()
|
||||
|
||||
M = sigmas.shape[0]
|
||||
N = rays.shape[0]
|
||||
|
||||
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
|
||||
|
||||
_backend.composite_rays_train_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image)
|
||||
|
||||
ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image)
|
||||
ctx.dims = [M, N, T_thresh]
|
||||
|
||||
return weights_sum, ambient_sum, depth, image
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):
|
||||
|
||||
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
|
||||
|
||||
grad_weights_sum = grad_weights_sum.contiguous()
|
||||
grad_ambient_sum = grad_ambient_sum.contiguous()
|
||||
grad_image = grad_image.contiguous()
|
||||
|
||||
sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors
|
||||
M, N, T_thresh = ctx.dims
|
||||
|
||||
grad_sigmas = torch.zeros_like(sigmas)
|
||||
grad_rgbs = torch.zeros_like(rgbs)
|
||||
grad_ambient = torch.zeros_like(ambient)
|
||||
|
||||
_backend.composite_rays_train_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient)
|
||||
|
||||
return grad_sigmas, grad_rgbs, grad_ambient, None, None, None
|
||||
|
||||
|
||||
composite_rays_train = _composite_rays_train.apply
|
||||
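
The "ray marching formula" the forward kernel implements is volumetric alpha compositing: per sample `alpha = 1 - exp(-sigma * delta)`, weighted by the transmittance `T` accumulated so far, with early termination once `T < T_thresh`. A single-ray PyTorch reference (ambient channel omitted for brevity):

```python
import torch

def composite_one_ray(sigmas, rgbs, deltas, T_thresh=1e-4):
    # sigmas: [M], rgbs: [M, 3], deltas: [M] (step sizes along the ray)
    T = 1.0
    color = torch.zeros(3)
    weight_sum = 0.0
    for sigma, rgb, delta in zip(sigmas, rgbs, deltas):
        alpha = 1.0 - torch.exp(-sigma * delta)
        w = alpha * T                 # contribution of this sample
        color += w * rgb
        weight_sum += w
        T *= 1.0 - alpha              # remaining transmittance
        if T < T_thresh:              # early termination, as in the kernel
            break
    return weight_sum, color          # alpha channel, premultiplied RGB

sigmas = torch.tensor([0.0, 5.0, 50.0])
rgbs = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
deltas = torch.full((3,), 0.1)
print(composite_one_ray(sigmas, rgbs, deltas))
```
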
|
||||
# ----------------------------------------
|
||||
# infer functions
|
||||
# ----------------------------------------
|
||||
|
||||
class _march_rays(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
|
||||
''' march rays to generate points (forward only, for inference)
|
||||
Args:
|
||||
n_alive: int, number of alive rays
|
||||
n_step: int, how many steps we march
|
||||
rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
|
||||
rays_t: float, [N], the alive rays' time, we only use the first n_alive.
|
||||
rays_o/d: float, [N, 3]
|
||||
bound: float, scalar
|
||||
density_bitfield: uint8: [CHHH // 8]
|
||||
C: int
|
||||
H: int
|
||||
nears/fars: float, [N]
|
||||
align: int, pad output so its size is divisible by align, set to -1 to disable.
|
||||
perturb: bool/int, int > 0 is used as the random seed.
|
||||
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
|
||||
max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
|
||||
Returns:
|
||||
xyzs: float, [n_alive * n_step, 3], all generated points' coords
|
||||
dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
|
||||
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
||||
'''
|
||||
|
||||
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
||||
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
||||
|
||||
rays_o = rays_o.contiguous().view(-1, 3)
|
||||
rays_d = rays_d.contiguous().view(-1, 3)
|
||||
|
||||
M = n_alive * n_step
|
||||
|
||||
if align > 0:
|
||||
M += align - (M % align)
|
||||
|
||||
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
||||
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
||||
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
|
||||
|
||||
if perturb:
|
||||
# torch.manual_seed(perturb) # test_gui uses spp index as seed
|
||||
noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
||||
else:
|
||||
noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
||||
|
||||
_backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
|
||||
|
||||
return xyzs, dirs, deltas
|
||||
|
||||
march_rays = _march_rays.apply
|
||||
|
||||
|
||||
class _composite_rays(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
||||
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
|
||||
''' composite rays' rgbs, according to the ray marching formula. (for inference)
|
||||
Args:
|
||||
n_alive: int, number of alive rays
|
||||
n_step: int, how many steps we march
|
||||
rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
|
||||
rays_t: float, [N], the alive rays' time
|
||||
sigmas: float, [n_alive * n_step,]
|
||||
rgbs: float, [n_alive * n_step, 3]
|
||||
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
||||
In-place Outputs:
|
||||
weights_sum: float, [N,], the alpha channel
|
||||
depth: float, [N,], the depth value
|
||||
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
||||
'''
|
||||
_backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
|
||||
return tuple()
|
||||
|
||||
|
||||
composite_rays = _composite_rays.apply
|
||||
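
At inference these kernels are meant to run in an alive-rays loop: march a handful of steps for every surviving ray, query the network at the produced points, composite in place, repeat. A schematic driver under stated assumptions: `model_query` stands in for the NeRF network, the import path mirrors this package, and finished rays are assumed to be flagged with a negative id in `rays_alive` (the real driver is `run_cuda` in the renderer):

```python
import torch
from raymarching import march_rays, composite_rays  # assumes this package is importable as `raymarching`

def render_alive_rays(model_query, rays_o, rays_d, nears, fars,
                      bound, density_bitfield, C, H, max_steps=1024, T_thresh=1e-2):
    # model_query(xyzs, dirs) -> (sigmas, rgbs) is a hypothetical stand-in for the network.
    N = rays_o.shape[0]
    device = rays_o.device
    weights_sum = torch.zeros(N, device=device)
    depth = torch.zeros(N, device=device)
    image = torch.zeros(N, 3, device=device)
    rays_alive = torch.arange(N, dtype=torch.int32, device=device)
    rays_t = nears.clone()

    step = 0
    while step < max_steps:
        n_alive = rays_alive.shape[0]
        if n_alive == 0:
            break                                    # every ray terminated early
        n_step = max(min(N // n_alive, 8), 1)        # march further per call as rays die off
        xyzs, dirs, deltas = march_rays(n_alive, n_step, rays_alive, rays_t,
                                        rays_o, rays_d, bound, density_bitfield,
                                        C, H, nears, fars)
        sigmas, rgbs = model_query(xyzs, dirs)
        composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs,
                       deltas, weights_sum, depth, image, T_thresh)
        # assumption: composite_rays flags finished rays with a negative id
        rays_alive = rays_alive[rays_alive >= 0]
        step += n_step
    return image, depth, weights_sum
```
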
|
||||
|
||||
class _composite_rays_ambient(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
||||
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
|
||||
_backend.composite_rays_ambient(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
|
||||
return tuple()
|
||||
|
||||
|
||||
composite_rays_ambient = _composite_rays_ambient.apply
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# custom
|
||||
|
||||
class _composite_rays_train_sigma(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
|
||||
''' composite rays' rgbs, according to the ray marching formula.
|
||||
Args:
|
||||
rgbs: float, [M, 3]
|
||||
sigmas: float, [M,]
|
||||
ambient: float, [M,] (after summing up the last dimension)
|
||||
deltas: float, [M, 2]
|
||||
rays: int32, [N, 3]
|
||||
Returns:
|
||||
weights_sum: float, [N,], the alpha channel
|
||||
depth: float, [N, ], the Depth
|
||||
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
||||
'''
|
||||
|
||||
sigmas = sigmas.contiguous()
|
||||
rgbs = rgbs.contiguous()
|
||||
ambient = ambient.contiguous()
|
||||
|
||||
M = sigmas.shape[0]
|
||||
N = rays.shape[0]
|
||||
|
||||
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
|
||||
|
||||
_backend.composite_rays_train_sigma_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image)
|
||||
|
||||
ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image)
|
||||
ctx.dims = [M, N, T_thresh]
|
||||
|
||||
return weights_sum, ambient_sum, depth, image
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):
|
||||
|
||||
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
|
||||
|
||||
grad_weights_sum = grad_weights_sum.contiguous()
|
||||
grad_ambient_sum = grad_ambient_sum.contiguous()
|
||||
grad_image = grad_image.contiguous()
|
||||
|
||||
sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors
|
||||
M, N, T_thresh = ctx.dims
|
||||
|
||||
grad_sigmas = torch.zeros_like(sigmas)
|
||||
grad_rgbs = torch.zeros_like(rgbs)
|
||||
grad_ambient = torch.zeros_like(ambient)
|
||||
|
||||
_backend.composite_rays_train_sigma_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient)
|
||||
|
||||
return grad_sigmas, grad_rgbs, grad_ambient, None, None, None
|
||||
|
||||
|
||||
composite_rays_train_sigma = _composite_rays_train_sigma.apply
|
||||
|
||||
|
||||
class _composite_rays_ambient_sigma(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
||||
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
|
||||
_backend.composite_rays_ambient_sigma(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
|
||||
return tuple()
|
||||
|
||||
|
||||
composite_rays_ambient_sigma = _composite_rays_ambient_sigma.apply
|
||||
|
||||
|
||||
|
||||
# uncertainty
|
||||
class _composite_rays_train_uncertainty(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, sigmas, rgbs, ambient, uncertainty, deltas, rays, T_thresh=1e-4):
|
||||
''' composite rays' rgbs, according to the ray marching formula.
|
||||
Args:
|
||||
rgbs: float, [M, 3]
|
||||
sigmas: float, [M,]
|
||||
ambient: float, [M,] (after summing up the last dimension)
uncertainty: float, [M,]
|
||||
deltas: float, [M, 2]
|
||||
rays: int32, [N, 3]
|
||||
Returns:
|
||||
weights_sum: float, [N,], the alpha channel
|
||||
depth: float, [N, ], the Depth
|
||||
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
||||
'''
|
||||
|
||||
sigmas = sigmas.contiguous()
|
||||
rgbs = rgbs.contiguous()
|
||||
ambient = ambient.contiguous()
|
||||
uncertainty = uncertainty.contiguous()
|
||||
|
||||
M = sigmas.shape[0]
|
||||
N = rays.shape[0]
|
||||
|
||||
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
|
||||
|
||||
_backend.composite_rays_train_uncertainty_forward(sigmas, rgbs, ambient, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, uncertainty_sum, depth, image)
|
||||
|
||||
ctx.save_for_backward(sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image)
|
||||
ctx.dims = [M, N, T_thresh]
|
||||
|
||||
return weights_sum, ambient_sum, uncertainty_sum, depth, image
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_depth, grad_image):
|
||||
|
||||
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
|
||||
|
||||
grad_weights_sum = grad_weights_sum.contiguous()
|
||||
grad_ambient_sum = grad_ambient_sum.contiguous()
|
||||
grad_uncertainty_sum = grad_uncertainty_sum.contiguous()
|
||||
grad_image = grad_image.contiguous()
|
||||
|
||||
sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image = ctx.saved_tensors
|
||||
M, N, T_thresh = ctx.dims
|
||||
|
||||
grad_sigmas = torch.zeros_like(sigmas)
|
||||
grad_rgbs = torch.zeros_like(rgbs)
|
||||
grad_ambient = torch.zeros_like(ambient)
|
||||
grad_uncertainty = torch.zeros_like(uncertainty)
|
||||
|
||||
_backend.composite_rays_train_uncertainty_backward(grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty)
|
||||
|
||||
return grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty, None, None, None
|
||||
|
||||
|
||||
composite_rays_train_uncertainty = _composite_rays_train_uncertainty.apply
|
||||
|
||||
|
||||
class _composite_rays_uncertainty(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
||||
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh=1e-2):
|
||||
_backend.composite_rays_uncertainty(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum)
|
||||
return tuple()
|
||||
|
||||
|
||||
composite_rays_uncertainty = _composite_rays_uncertainty.apply
|
||||
|
||||
|
||||
|
||||
# triplane(eye)
|
||||
class _composite_rays_train_triplane(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, T_thresh=1e-4):
|
||||
''' composite rays' rgbs, according to the ray marching formula.
|
||||
Args:
|
||||
rgbs: float, [M, 3]
|
||||
sigmas: float, [M,]
|
||||
amb_aud, amb_eye: float, [M,] (after summing up the last dimension)
uncertainty: float, [M,]
|
||||
deltas: float, [M, 2]
|
||||
rays: int32, [N, 3]
|
||||
Returns:
|
||||
weights_sum: float, [N,], the alpha channel
|
||||
depth: float, [N, ], the Depth
|
||||
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
||||
'''
|
||||
|
||||
sigmas = sigmas.contiguous()
|
||||
rgbs = rgbs.contiguous()
|
||||
amb_aud = amb_aud.contiguous()
|
||||
amb_eye = amb_eye.contiguous()
|
||||
uncertainty = uncertainty.contiguous()
|
||||
|
||||
M = sigmas.shape[0]
|
||||
N = rays.shape[0]
|
||||
|
||||
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
amb_aud_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
amb_eye_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
||||
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
|
||||
|
||||
_backend.composite_rays_train_triplane_forward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image)
|
||||
|
||||
ctx.save_for_backward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image)
|
||||
ctx.dims = [M, N, T_thresh]
|
||||
|
||||
return weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_depth, grad_image):
|
||||
|
||||
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
|
||||
|
||||
grad_weights_sum = grad_weights_sum.contiguous()
|
||||
grad_amb_aud_sum = grad_amb_aud_sum.contiguous()
|
||||
grad_amb_eye_sum = grad_amb_eye_sum.contiguous()
|
||||
grad_uncertainty_sum = grad_uncertainty_sum.contiguous()
|
||||
grad_image = grad_image.contiguous()
|
||||
|
||||
sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = ctx.saved_tensors
|
||||
M, N, T_thresh = ctx.dims
|
||||
|
||||
grad_sigmas = torch.zeros_like(sigmas)
|
||||
grad_rgbs = torch.zeros_like(rgbs)
|
||||
grad_amb_aud = torch.zeros_like(amb_aud)
|
||||
grad_amb_eye = torch.zeros_like(amb_eye)
|
||||
grad_uncertainty = torch.zeros_like(uncertainty)
|
||||
|
||||
_backend.composite_rays_train_triplane_backward(grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty)
|
||||
|
||||
return grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty, None, None, None
|
||||
|
||||
|
||||
composite_rays_train_triplane = _composite_rays_train_triplane.apply
|
||||
|
||||
|
||||
class _composite_rays_triplane(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
||||
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh=1e-2):
|
||||
_backend.composite_rays_triplane(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum)
|
||||
return tuple()
|
||||
|
||||
|
||||
composite_rays_triplane = _composite_rays_triplane.apply
|
|
@ -0,0 +1,63 @@
|
|||
import os
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
nvcc_flags = [
|
||||
'-O3', '-std=c++14',
|
||||
# '-lineinfo', # to debug illegal memory access
|
||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
||||
]
|
||||
|
||||
if os.name == "posix":
|
||||
c_flags = ['-O3', '-std=c++14']
|
||||
elif os.name == "nt":
|
||||
c_flags = ['/O2', '/std:c++17']
|
||||
|
||||
# find cl.exe
|
||||
def find_cl_path():
|
||||
import glob
|
||||
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
||||
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
||||
if paths:
|
||||
return paths[0]
|
||||
|
||||
# If cl.exe is not on path, try to find it.
|
||||
if os.system("where cl.exe >nul 2>nul") != 0:
|
||||
cl_path = find_cl_path()
|
||||
if cl_path is None:
|
||||
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
||||
os.environ["PATH"] += ";" + cl_path
|
||||
|
||||
'''
|
||||
Usage:
|
||||
|
||||
python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
|
||||
|
||||
python setup.py install # build extensions and install (copy) to PATH.
|
||||
pip install . # ditto but better (e.g., dependency & metadata handling)
|
||||
|
||||
python setup.py develop # build extensions and install (symbolic) to PATH.
|
||||
pip install -e . # ditto but better (e.g., dependency & metadata handling)
|
||||
|
||||
'''
|
||||
setup(
|
||||
name='raymarching_face', # package name, import this to use python API
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name='_raymarching_face', # extension name, import this to use CUDA API
|
||||
sources=[os.path.join(_src_path, 'src', f) for f in [
|
||||
'raymarching.cu',
|
||||
'bindings.cpp',
|
||||
]],
|
||||
extra_compile_args={
|
||||
'cxx': c_flags,
|
||||
'nvcc': nvcc_flags,
|
||||
}
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension,
|
||||
}
|
||||
)
|
|
@ -0,0 +1,39 @@
|
|||
#include <torch/extension.h>
|
||||
|
||||
#include "raymarching.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// utils
|
||||
m.def("packbits", &packbits, "packbits (CUDA)");
|
||||
m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
|
||||
m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
|
||||
m.def("morton3D", &morton3D, "morton3D (CUDA)");
|
||||
m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
|
||||
m.def("morton3D_dilation", &morton3D_dilation, "morton3D_dilation (CUDA)");
|
||||
// train
|
||||
m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
|
||||
m.def("march_rays_train_backward", &march_rays_train_backward, "march_rays_train_backward (CUDA)");
|
||||
m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
|
||||
m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
|
||||
// infer
|
||||
m.def("march_rays", &march_rays, "march rays (CUDA)");
|
||||
m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
|
||||
m.def("composite_rays_ambient", &composite_rays_ambient, "composite rays with ambient (CUDA)");
|
||||
|
||||
// train
|
||||
m.def("composite_rays_train_sigma_forward", &composite_rays_train_sigma_forward, "composite_rays_train_forward (CUDA)");
|
||||
m.def("composite_rays_train_sigma_backward", &composite_rays_train_sigma_backward, "composite_rays_train_backward (CUDA)");
|
||||
// infer
|
||||
m.def("composite_rays_ambient_sigma", &composite_rays_ambient_sigma, "composite rays with ambient (CUDA)");
|
||||
|
||||
// uncertainty train
|
||||
m.def("composite_rays_train_uncertainty_forward", &composite_rays_train_uncertainty_forward, "composite_rays_train_forward (CUDA)");
|
||||
m.def("composite_rays_train_uncertainty_backward", &composite_rays_train_uncertainty_backward, "composite_rays_train_backward (CUDA)");
|
||||
m.def("composite_rays_uncertainty", &composite_rays_uncertainty, "composite rays with ambient (CUDA)");
|
||||
|
||||
// triplane
|
||||
m.def("composite_rays_train_triplane_forward", &composite_rays_train_triplane_forward, "composite_rays_train_forward (CUDA)");
|
||||
m.def("composite_rays_train_triplane_backward", &composite_rays_train_triplane_backward, "composite_rays_train_backward (CUDA)");
|
||||
m.def("composite_rays_triplane", &composite_rays_triplane, "composite rays with ambient (CUDA)");
|
||||
|
||||
}
|
File diff suppressed because it is too large
|
@@ -0,0 +1,38 @@
#pragma once

#include <stdint.h>
#include <torch/torch.h>


void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation);

void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d);
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient);

void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
void composite_rays_ambient(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum);


void composite_rays_train_sigma_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_sigma_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient);

void composite_rays_ambient_sigma(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum);


// uncertainty
void composite_rays_train_uncertainty_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_uncertainty_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient, at::Tensor grad_uncertainty);
void composite_rays_uncertainty(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum, at::Tensor uncertainty_sum);

// triplane
void composite_rays_train_triplane_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_triplane_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_amb_aud_sum, const at::Tensor grad_amb_eye_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor amb_aud_sum, const at::Tensor amb_eye_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_amb_aud, at::Tensor grad_amb_eye, at::Tensor grad_uncertainty);
void composite_rays_triplane(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambs_aud, at::Tensor ambs_eye, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum);
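The declarations above fall into three groups: occupancy-grid utilities (morton3D*, packbits), the training path (march_rays_train plus the composite_rays_train_* forward/backward pairs, with _sigma, _uncertainty and _triplane variants carrying extra per-sample channels), and an inference path that repeatedly marches and composites only the rays still alive (march_rays, composite_rays*). For orientation on the first declaration, here is a pure-PyTorch sketch of the ray/AABB slab test that near_far_from_aabb implements on the GPU; the aabb layout and the min_near default are assumptions read off the signature, not verified against the kernel:

```python
import torch

def near_far_from_aabb_ref(rays_o, rays_d, aabb, min_near=0.05):
    # rays_o, rays_d: [N, 3]; aabb assumed laid out as [xmin, ymin, zmin, xmax, ymax, zmax].
    # Classic slab test: per-axis entry/exit distances, then intersect the three intervals.
    # (Assumes no exactly-zero direction components; the CUDA kernel guards such cases.)
    t1 = (aabb[:3] - rays_o) / rays_d
    t2 = (aabb[3:] - rays_o) / rays_d
    nears = torch.minimum(t1, t2).amax(dim=-1).clamp(min=min_near)
    fars = torch.maximum(t1, t2).amin(dim=-1)
    return nears, torch.maximum(fars, nears + 1e-6)
```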
@@ -0,0 +1,26 @@
torch-ema
ninja
trimesh
opencv-python
tensorboardX
numpy
pandas
tqdm
matplotlib
PyMCubes
rich
dearpygui
packaging
scipy

face_alignment
python_speech_features
numba
resampy
pyaudio
soundfile
einops
configargparse

lpips
imageio-ffmpeg
@@ -0,0 +1,5 @@
python main.py data/obama/ --workspace trial_obama_triplane/ -O --iters 100000
cp -r trial_obama_triplane/checkpoints trial_obama_triplane/checkpoints_
python main.py data/obama/ --workspace trial_obama_triplane/ -O --iters 125000 --finetune_lips --patch_size 32
python main.py data/obama/ --workspace trial_obama_triplane/ -O --test
# python main.py data/obama/ --workspace trial_obama_triplane_torso/ -O --torso --head_ckpt <head>.pth --iters 200000
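Read top to bottom, this script trains the triplane head model for 100k iterations, snapshots the checkpoints directory, fine-tunes for a further 25k iterations with a 32-pixel lip-region patch loss, and then renders the test split. The commented-out final line is the optional torso stage, which resumes from a trained head checkpoint (the `<head>.pth` placeholder must be filled in by the user).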
@@ -0,0 +1 @@
from .sphere_harmonics import SHEncoder
@@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14', '-finput-charset=utf-8']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17', '/source-charset:utf-8']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_sh_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'shencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
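A note on this JIT path: the first import compiles shencoder.cu and bindings.cpp via torch.utils.cpp_extension.load and caches the build, so the first import is slow but later ones are instant. A minimal sketch, assuming the package directory is importable as `shencoder` (its name is inferred from the files in this commit):

```python
# First import triggers a one-time nvcc build (cached by torch's extension loader);
# subsequent imports reuse the cached shared object.
from shencoder.backend import _backend
```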
@@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='shencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_shencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'shencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
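Building ahead of time avoids the JIT compile at import: run `python setup.py install` inside this extension's directory, after which the prebuilt CUDA API resolves directly and the try/except in sphere_harmonics.py skips the JIT fallback.

```python
# After `python setup.py install` has been run in the shencoder directory:
import _shencoder  # prebuilt CUDA API; backend.py's JIT build is never triggered
```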
@@ -0,0 +1,87 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _shencoder as _backend
except ImportError:
    from .backend import _backend


class _sh_encoder(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        # inputs: [B, input_dim], float in [-1, 1]
        # RETURN: [B, F], float

        inputs = inputs.contiguous()
        B, input_dim = inputs.shape # batch size, coord dim
        output_dim = degree ** 2

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        if calc_grad_inputs:
            # dy_dx holds the analytic Jacobian [B, input_dim * output_dim],
            # needed later to backprop through the encoding to the inputs.
            dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
        else:
            dy_dx = None

        _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)

        ctx.save_for_backward(inputs, dy_dx)
        ctx.dims = [B, input_dim, degree]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, degree ** 2]

        inputs, dy_dx = ctx.saved_tensors

        if dy_dx is not None:
            grad = grad.contiguous()
            B, input_dim, degree = ctx.dims
            grad_inputs = torch.zeros_like(inputs)
            _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
            return grad_inputs, None, None
        else:
            return None, None, None


sh_encode = _sh_encoder.apply


class SHEncoder(nn.Module):
    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim # coord dims, must be 3
        self.degree = degree # SH degree, 1 ~ 8 (output dim = degree ** 2)
        self.output_dim = degree ** 2

        assert self.input_dim == 3, "SH encoder only supports input dim == 3"
        assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"

    def __repr__(self):
        return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"

    def forward(self, inputs, size=1):
        # inputs: [..., input_dim], normalized real-world positions in [-size, size]
        # return: [..., degree ** 2]

        inputs = inputs / size # [-1, 1]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs
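A minimal usage sketch for the encoder above, assuming the package is importable as `shencoder` (per the `__init__.py` and `setup.py` in this commit) and a CUDA device is available:

```python
import torch
from shencoder import SHEncoder

enc = SHEncoder(input_dim=3, degree=4)          # degree-4 SH -> 16 features per direction
dirs = torch.randn(1024, 3, device='cuda')
dirs = dirs / dirs.norm(dim=-1, keepdim=True)   # unit vectors lie inside [-1, 1]^3
feats = enc(dirs)                               # -> [1024, 16]
assert feats.shape == (1024, 16)
```

Because `calc_grad_inputs` is wired to `inputs.requires_grad`, the Jacobian buffer is only allocated when gradients with respect to the directions are actually needed.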
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "shencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
    m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
}
@@ -0,0 +1,439 @@
#include <stdint.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")


template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

// One thread per input point: evaluates the real SH basis up to degree C
// (C * C coefficients) and, optionally, its analytic gradient w.r.t. (x, y, z).
template <typename scalar_t>
__global__ void kernel_sh(
    const scalar_t * __restrict__ inputs,
    scalar_t * outputs,
    uint32_t B, uint32_t D, uint32_t C,
    scalar_t * dy_dx
) {
    const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;
    if (b >= B) return;

    const uint32_t C2 = C * C;

    // locate
    inputs += b * D;
    outputs += b * C2;

    scalar_t x = inputs[0], y = inputs[1], z = inputs[2];

    scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
    scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;
    scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;

    auto write_sh = [&]() {
        outputs[0] = 0.28209479177387814f; // 1/(2*sqrt(pi))
        if (C <= 1) { return; }
        outputs[1] = -0.48860251190291987f*y; // -sqrt(3)*y/(2*sqrt(pi))
        outputs[2] = 0.48860251190291987f*z; // sqrt(3)*z/(2*sqrt(pi))
        outputs[3] = -0.48860251190291987f*x; // -sqrt(3)*x/(2*sqrt(pi))
        if (C <= 2) { return; }
        outputs[4] = 1.0925484305920792f*xy; // sqrt(15)*xy/(2*sqrt(pi))
        outputs[5] = -1.0925484305920792f*yz; // -sqrt(15)*yz/(2*sqrt(pi))
        outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
        outputs[7] = -1.0925484305920792f*xz; // -sqrt(15)*xz/(2*sqrt(pi))
        outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2; // sqrt(15)*(x2 - y2)/(4*sqrt(pi))
        if (C <= 3) { return; }
        outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2); // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
        outputs[10] = 2.8906114426405538f*xy*z; // sqrt(105)*xy*z/(2*sqrt(pi))
        outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2); // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
        outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f); // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
        outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2); // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
        outputs[14] = 1.4453057213202769f*z*(x2 - y2); // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
        outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2); // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
        if (C <= 4) { return; }
        outputs[16] = 2.5033429417967046f*xy*(x2 - y2); // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
        outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2); // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
        outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f); // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
        outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2); // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
        outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
        outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2); // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
        outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f); // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
        outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2); // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
        outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
        if (C <= 5) { return; }
        outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4); // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
        outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2); // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
        outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f); // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
        outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f); // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
        outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f); // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
        outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f); // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
        outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f); // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
        outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f); // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
        outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f); // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
        outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4); // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
        outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4); // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
        if (C <= 6) { return; }
        outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
        outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4); // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
        outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f); // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
        outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f); // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
        outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f); // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
        outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f); // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
        outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
        outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f); // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
        outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f); // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
        outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f); // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
        outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4); // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
        outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4); // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
        outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
        if (C <= 7) { return; }
        outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6); // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
        outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
        outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4); // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
        outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f); // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
        outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f); // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
        outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f); // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
        outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f); // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
        outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f); // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
        outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f); // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
        outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f); // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
        outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f); // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
        outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4); // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
        outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4); // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
        outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6); // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
        outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6); // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
    };

    write_sh();

    if (dy_dx) {
        // Jacobian layout per point: three consecutive C2-sized blocks,
        // d(outputs)/dx, d(outputs)/dy, d(outputs)/dz.
        scalar_t *dx = dy_dx + b * D * C2;
        scalar_t *dy = dx + C2;
        scalar_t *dz = dy + C2;

        auto write_sh_dx = [&]() {
            dx[0] = 0.0f; // 0
            if (C <= 1) { return; }
            dx[1] = 0.0f; // 0
            dx[2] = 0.0f; // 0
            dx[3] = -0.48860251190291992f; // -sqrt(3)/(2*sqrt(pi))
            if (C <= 2) { return; }
            dx[4] = 1.0925484305920792f*y; // sqrt(15)*y/(2*sqrt(pi))
            dx[5] = 0.0f; // 0
            dx[6] = 0.0f; // 0
            dx[7] = -1.0925484305920792f*z; // -sqrt(15)*z/(2*sqrt(pi))
            dx[8] = 1.0925484305920792f*x; // sqrt(15)*x/(2*sqrt(pi))
            if (C <= 3) { return; }
            dx[9] = -3.5402615395598609f*xy; // -3*sqrt(70)*xy/(4*sqrt(pi))
            dx[10] = 2.8906114426405538f*yz; // sqrt(105)*yz/(2*sqrt(pi))
            dx[11] = 0.0f; // 0
            dx[12] = 0.0f; // 0
            dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
            dx[14] = 2.8906114426405538f*xz; // sqrt(105)*xz/(2*sqrt(pi))
            dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
            if (C <= 4) { return; }
            dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2); // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))
            dx[17] = -10.620784618679583f*xy*z; // -9*sqrt(70)*xy*z/(4*sqrt(pi))
            dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f); // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))
            dx[19] = 0.0f; // 0
            dx[20] = 0.0f; // 0
            dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
            dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
            dx[23] = 5.3103923093397913f*z*(-x2 + y2); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
            dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
            if (C <= 5) { return; }
            dx[25] = 13.127641136803401f*xy*(-x2 + y2); // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))
            dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2); // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))
            dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2); // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))
            dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f); // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))
            dx[29] = 0.0f; // 0
            dx[30] = 0.0f; // 0
            dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
            dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
            dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))
            dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
            dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            if (C <= 6) { return; }
            dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4); // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
            dx[37] = 47.332383244635047f*xy*z*(-x2 + y2); // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))
            dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f); // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
            dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2); // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))
            dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f); // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
            dx[41] = 0.0f; // 0
            dx[42] = 0.0f; // 0
            dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
            dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
            dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
            dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
            dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            if (C <= 7) { return; }
            dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4); // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))
            dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4); // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
            dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f); // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
            dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f); // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
            dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f); // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))
            dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f); // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
            dx[55] = 0.0f; // 0
            dx[56] = 0.0f; // 0
            dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
            dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f); // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
            dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))
            dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
            dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4); // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))
            dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
        };

        auto write_sh_dy = [&]() {
            dy[0] = 0.0f; // 0
            if (C <= 1) { return; }
            dy[1] = -0.48860251190291992f; // -sqrt(3)/(2*sqrt(pi))
            dy[2] = 0.0f; // 0
            dy[3] = 0.0f; // 0
            if (C <= 2) { return; }
            dy[4] = 1.0925484305920792f*x; // sqrt(15)*x/(2*sqrt(pi))
            dy[5] = -1.0925484305920792f*z; // -sqrt(15)*z/(2*sqrt(pi))
            dy[6] = 0.0f; // 0
            dy[7] = 0.0f; // 0
            dy[8] = -1.0925484305920792f*y; // -sqrt(15)*y/(2*sqrt(pi))
            if (C <= 3) { return; }
            dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
            dy[10] = 2.8906114426405538f*xz; // sqrt(105)*xz/(2*sqrt(pi))
            dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
            dy[12] = 0.0f; // 0
            dy[13] = 0.0f; // 0
            dy[14] = -2.8906114426405538f*yz; // -sqrt(105)*yz/(2*sqrt(pi))
            dy[15] = 3.5402615395598609f*xy; // 3*sqrt(70)*xy/(4*sqrt(pi))
            if (C <= 4) { return; }
            dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
            dy[17] = 5.3103923093397913f*z*(-x2 + y2); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
            dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
            dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
            dy[20] = 0.0f; // 0
            dy[21] = 0.0f; // 0
            dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2); // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))
            dy[23] = 10.620784618679583f*xy*z; // 9*sqrt(70)*xy*z/(4*sqrt(pi))
            dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2); // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))
            if (C <= 5) { return; }
            dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
            dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f); // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
            dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
            dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
            dy[30] = 0.0f; // 0
            dy[31] = 0.0f; // 0
            dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2); // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))
            dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f); // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))
            dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2); // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))
            dy[35] = 13.127641136803401f*xy*(x2 - y2); // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))
            if (C <= 6) { return; }
            dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
            dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
            dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
            dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
            dy[42] = 0.0f; // 0
            dy[43] = 0.0f; // 0
            dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f); // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))
            dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))
            dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f); // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
            dy[47] = 47.332383244635047f*xy*z*(x2 - y2); // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))
            dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4); // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            if (C <= 7) { return; }
            dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
            dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4); // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))
            dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
            dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f); // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
            dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f); // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
            dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
            dy[56] = 0.0f; // 0
            dy[57] = 0.0f; // 0
            dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f); // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
            dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f); // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
            dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f); // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
            dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f); // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
            dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4); // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
        };

        auto write_sh_dz = [&]() {
            dz[0] = 0.0f; // 0
            if (C <= 1) { return; }
            dz[1] = 0.0f; // 0
            dz[2] = 0.48860251190291992f; // sqrt(3)/(2*sqrt(pi))
            dz[3] = 0.0f; // 0
            if (C <= 2) { return; }
            dz[4] = 0.0f; // 0
            dz[5] = -1.0925484305920792f*y; // -sqrt(15)*y/(2*sqrt(pi))
            dz[6] = 1.8923493915151202f*z; // 3*sqrt(5)*z/(2*sqrt(pi))
            dz[7] = -1.0925484305920792f*x; // -sqrt(15)*x/(2*sqrt(pi))
            dz[8] = 0.0f; // 0
            if (C <= 3) { return; }
            dz[9] = 0.0f; // 0
            dz[10] = 2.8906114426405538f*xy; // sqrt(105)*xy/(2*sqrt(pi))
            dz[11] = -4.5704579946446566f*yz; // -5*sqrt(42)*yz/(4*sqrt(pi))
            dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))
            dz[13] = -4.5704579946446566f*xz; // -5*sqrt(42)*xz/(4*sqrt(pi))
            dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2; // sqrt(105)*(x2 - y2)/(4*sqrt(pi))
            dz[15] = 0.0f; // 0
            if (C <= 4) { return; }
            dz[16] = 0.0f; // 0
            dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2); // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
            dz[18] = 13.246445740605839f*xy*z; // 21*sqrt(5)*xy*z/(2*sqrt(pi))
            dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2); // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))
            dz[20] = 14.809976568128603f*z*z*z - 6.3471328149122579f*z; // (105*z**3 - 45*z)/(4*sqrt(pi))
            dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2); // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))
            dz[22] = 6.6232228703029197f*z*(x2 - y2); // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))
            dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2); // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
            dz[24] = 0.0f; // 0
            if (C <= 5) { return; }
            dz[25] = 0.0f; // 0
            dz[26] = 8.3026492595241645f*xy*(x2 - y2); // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))
            dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2); // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))
            dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f); // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))
            dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2); // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))
            dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))
            dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2); // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))
            dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f); // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))
            dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2); // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))
            dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
            dz[35] = 0.0f; // 0
            if (C <= 6) { return; }
            dz[36] = 0.0f; // 0
            dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4); // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            dz[38] = 44.401711264127719f*xy*z*(x2 - y2); // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))
            dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f); // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))
            dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))
            dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f); // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
            dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f); // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))
            dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f); // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
            dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f); // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))
            dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f); // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))
            dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4); // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
            dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4); // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
            dz[48] = 0.0f; // 0
            if (C <= 7) { return; }
            dz[49] = 0.0f; // 0
            dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
            dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4); // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f); // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
            dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f); // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
            dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f); // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
            dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f); // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
            dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
            dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f); // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
            dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f); // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
            dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f); // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
            dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4); // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
            dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4); // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
            dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
            dz[63] = 0.0f; // 0
        };

        write_sh_dx();
        write_sh_dy();
        write_sh_dz();
    }
}


// One thread per (point, coordinate): contracts the incoming gradient with the
// stored Jacobian. grad_inputs must be zero-initialized by the caller.
template <typename scalar_t>
__global__ void kernel_sh_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t C,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * grad_inputs
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    const uint32_t b = t / D;
    if (b >= B) return;

    const uint32_t d = t - b * D;
    const uint32_t C2 = C * C;

    // locate
    grad += b * C2;
    dy_dx += b * D * C2 + d * C2;

    for (int ch = 0; ch < C2; ch++) {
        grad_inputs[t] += grad[ch] * dy_dx[ch];
        //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
    }
}

// inputs: [B, D], float, in [-1, 1]
// outputs: [B, C * C], float
template <typename scalar_t>
void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {
    static constexpr uint32_t N_THREADS = 256;
    kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);
}


template <typename scalar_t>
void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
    static constexpr uint32_t N_THREADS = 256;
    kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
}


void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
        sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr);
    }));
}

void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(dy_dx);
    CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(dy_dx);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(dy_dx);
    CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
        sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
    }));
}
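For reference, the constants hard-coded in `kernel_sh` above are the real spherical harmonics written out in Cartesian form (each code comment records the exact closed form). The first two bands, matching `outputs[0..3]`, are:

```latex
Y_0^0 = \frac{1}{2\sqrt{\pi}} \approx 0.28209479, \qquad
Y_1^{-1} = -\frac{\sqrt{3}}{2\sqrt{\pi}}\,y, \quad
Y_1^{0} = \frac{\sqrt{3}}{2\sqrt{\pi}}\,z, \quad
Y_1^{1} = -\frac{\sqrt{3}}{2\sqrt{\pi}}\,x .
```

With degree C the kernel emits C² coefficients per point; when `dy_dx` is requested it also stores the analytic Jacobian as three consecutive C²-sized blocks (∂/∂x, ∂/∂y, ∂/∂z), which `kernel_sh_backward` then contracts against the incoming gradient.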
@@ -0,0 +1,10 @@
#pragma once

#include <stdint.h>
#include <torch/torch.h>

// inputs: [B, D], float, in [-1, 1]
// outputs: [B, F], float

void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);
void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);