2023-12-19 09:41:52 +08:00
import torch
import argparse
2024-05-02 21:05:16 +08:00
from . nerf_triplane . provider import NeRFDataset , NeRFDataset_Test
from . nerf_triplane . utils import *
from . nerf_triplane . network import NeRFNetwork
2023-12-19 09:41:52 +08:00
# Uncomment to have autograd report the op that produced NaN/Inf gradients
# (slow; debugging only).
# torch.autograd.set_detect_anomaly(True)

# Turn off TF32 for CUDA matmuls and cuDNN convolutions. TF32 trades mantissa
# precision for speed on RTX 30xx-class GPUs, which gave low numerical accuracy
# here. The except branch covers PyTorch builds that do not expose these flags.
try:
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
except AttributeError:
    print('Info. This pytorch version does not support tf32.')
# Script entry point: parse options, build NeRFNetwork, then either run
# evaluation (--test, optionally in the GUI) or train (optionally in the GUI)
# and finish with a test pass over the 'test' split.
if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
    parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)")
    parser.add_argument('--test_train', action='store_true', help="test mode (load model and train dataset)")
    parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--seed', type=int, default=0)

    parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
    parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")

    ### training options
    parser.add_argument('--iters', type=int, default=200000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
    parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

    ### loss set
    parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
    parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
    parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
    parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
    parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss")

    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")

    parser.add_argument('--bg_img', type=str, default='white', help="background image")

    parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
    parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
    parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
    parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
    parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

    parser.add_argument('--init_lips', action='store_true', help="init lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
    parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...")

    parser.add_argument('--torso', action='store_true', help="fix head and train torso")
    parser.add_argument('--head_ckpt', type=str, default='', help="head model")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")
    parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    ### else
    parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
    parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
    parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")

    parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
    parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size")

    parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")

    parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
    parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
    parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")

    parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
    parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")

    # asr
    parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
    parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
    parser.add_argument('--asr_play', action='store_true', help="play out the audio")

    parser.add_argument('--asr_model', type=str, default='deepspeech')
    # parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
    # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')

    parser.add_argument('--asr_save_feats', action='store_true')
    # audio FPS
    parser.add_argument('--fps', type=int, default=50)
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=50)
    parser.add_argument('-r', type=int, default=10)

    opt = parser.parse_args()

    if opt.O:
        opt.fp16 = True
        opt.exp_eye = True

    # Deliberately disabled via `and False`: smoothing is NOT forced in test mode.
    if opt.test and False:
        opt.smooth_path = True
        opt.smooth_eye = True
        opt.smooth_lips = True

    # CUDA ray marching is always enabled, whatever -O / --cuda_ray say.
    opt.cuda_ray = True
    # assert opt.cuda_ray, "Only support CUDA ray mode."

    if opt.patch_size > 1:
        # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
        # Use parser.error instead of assert: asserts are stripped under `python -O`.
        if opt.num_rays % (opt.patch_size ** 2) != 0:
            parser.error("num_rays should be divisible by patch_size ** 2.")

    # if opt.finetune_lips:
    #     # do not update density grid in finetune stage
    #     opt.update_extra_interval = 1e9

    print(opt)

    seed_everything(opt.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = NeRFNetwork(opt)

    # manually load state dict for head
    if opt.torso and opt.head_ckpt != '':
        model_dict = torch.load(opt.head_ckpt, map_location='cpu')['model']

        missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False)

        if missing_keys:
            print(f"[WARN] missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"[WARN] unexpected keys: {unexpected_keys}")

        # Freeze every parameter that was loaded from the head checkpoint.
        for name, param in model.named_parameters():
            if name in model_dict:
                # print(f'[INFO] freeze {name}, {param.shape}')
                param.requires_grad = False

    # print(model)

    criterion = torch.nn.MSELoss(reduction='none')

    # Both the test and the training branch construct NeRFGUI, so import it once
    # here. Package-relative, matching the other nerf_triplane imports.
    if opt.gui:
        from .nerf_triplane.gui import NeRFGUI

    if opt.test:

        if opt.gui:
            metrics = []  # use no metric in GUI for faster initialization...
        else:
            # metrics = [PSNRMeter(), LPIPSMeter(device=device)]
            metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')]

        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

        if opt.test_train:
            test_set = NeRFDataset(opt, device=device, type='train')
            # a manual fix to test on the training dataset
            test_set.training = False
            test_set.num_rays = -1
            test_loader = test_set.dataloader()
        else:
            test_loader = NeRFDataset(opt, device=device, type='test').dataloader()

        # temp fix: for update_extra_states
        model.aud_features = test_loader._data.auds
        # NOTE(review): the training branch assigns `model.eye_area` (singular);
        # this assigns `model.eye_areas`. Confirm which attribute the model reads.
        model.eye_areas = test_loader._data.eye_area

        if opt.gui:
            # we still need test_loader to provide audio features for testing.
            with NeRFGUI(opt, trainer, test_loader) as gui:
                gui.render()
        else:
            ### test and save video (fast)
            trainer.test(test_loader)

            ### evaluate metrics (slow)
            if test_loader.has_gt:
                trainer.evaluate(test_loader)

    else:

        optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr, opt.lr_net), betas=(0, 0.99), eps=1e-8)

        train_loader = NeRFDataset(opt, device=device, type='train').dataloader()
        num_train = len(train_loader)

        # Guard the divisions by num_train below (eval_interval, max_epochs).
        if num_train == 0:
            parser.error("training dataset is empty")
        # One individual code per training frame: --ind_num must be strictly larger.
        if num_train >= opt.ind_num:
            parser.error(f"dataset too many frames: {num_train}, please increase --ind_num to more than this number!")

        # temp fix: for update_extra_states
        model.aud_features = train_loader._data.auds
        model.eye_area = train_loader._data.eye_area
        model.poses = train_loader._data.poses

        # LambdaLR scales the initial lr by lr_gamma ** (step / iters). The Trainer
        # is told to step the scheduler every iteration, so the final lr is
        # 0.05x the initial lr when fine-tuning lips and 0.5x otherwise.
        lr_gamma = 0.05 if opt.finetune_lips else 0.5
        scheduler = lambda optimizer: torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: lr_gamma ** (step / opt.iters))

        metrics = [PSNRMeter(), LPIPSMeter(device=device)]

        eval_interval = max(1, 5000 // num_train)
        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=eval_interval)

        # Append this run's options; the newline keeps successive runs on separate lines.
        with open(os.path.join(opt.workspace, 'opt.txt'), 'a', encoding='utf-8') as f:
            f.write(str(opt) + '\n')

        if opt.gui:
            with NeRFGUI(opt, trainer, train_loader) as gui:
                gui.render()

        else:
            valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader()

            max_epochs = -(-opt.iters // num_train)  # ceil(iters / num_train) in exact integer math
            print(f'[INFO] max_epoch = {max_epochs}')
            trainer.train(train_loader, valid_loader, max_epochs)

            # free some mem
            del train_loader, valid_loader
            torch.cuda.empty_cache()

            # also test
            test_loader = NeRFDataset(opt, device=device, type='test').dataloader()

            if test_loader.has_gt:
                trainer.evaluate(test_loader)  # blender has gt, so evaluate it.

            trainer.test(test_loader)