391 lines
14 KiB
Python
391 lines
14 KiB
Python
|
import os
|
||
|
import sys
|
||
|
import cv2
|
||
|
import argparse
|
||
|
from pathlib import Path
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from data_loader import load_dir
|
||
|
from facemodel import Face_3DMM
|
||
|
from util import *
|
||
|
from render_3dmm import Render_3DMM
|
||
|
|
||
|
|
||
|
# Debug aid: uncomment to make autograd report the op that produced NaN/inf
# gradients (slows every backward pass considerably).
# torch.autograd.set_detect_anomaly(True)

# Absolute, symlink-resolved directory of this script; the 3DMM assets are
# looked up relative to it.
dir_path = os.path.split(os.path.realpath(__file__))[0]
|
||
|
|
||
|
|
||
|
def set_requires_grad(tensor_list):
    """Turn on gradient tracking for every tensor in ``tensor_list``.

    The flag is set in place via the ``requires_grad`` attribute; intended for
    freshly created leaf tensors that will be handed to an optimizer.
    Returns ``None``.
    """
    for param in tensor_list:
        param.requires_grad = True
|
||
|
|
||
|
|
||
|
# Command-line options. ``--path`` is the directory holding the extracted
# frames; the fitted parameters are written into its parent directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--path",
    type=str,
    default="obama/ori_imgs",
    # The previous help text ("idname of target person") described a
    # different option; this argument is a directory path.
    help="directory of the target person's extracted frames",
)
parser.add_argument("--img_h", type=int, default=512, help="image height")
parser.add_argument("--img_w", type=int, default=512, help="image width")
parser.add_argument("--frame_num", type=int, default=11000, help="image number")
args = parser.parse_args()
|
||
|
|
||
|
# ---- Data and model setup --------------------------------------------------
# Landmarks and image paths for the requested frame range, as returned by
# load_dir.
start_id, end_id = 0, args.frame_num
lms, img_paths = load_dir(args.path, start_id, end_id)
num_frames = lms.size(0)

h, w = args.img_h, args.img_w
# Principal point fixed at the image centre, stored as (cx, cy).
cxy = torch.tensor([w * 0.5, h * 0.5], dtype=torch.float32).cuda()

# Sizes of the identity / expression / texture coefficient vectors, and
# point_num (presumably the mesh vertex count of the 3DMM assets).
id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
model_3dmm = Face_3DMM(
    os.path.join(dir_path, "3DMM"),
    id_dim,
    exp_dim,
    tex_dim,
    point_num,
)
|
||
|
|
||
|
# ---- Focal-length search ---------------------------------------------------
# Fitting every frame for every candidate focal would be expensive, so the
# search only uses one frame out of every 40.
sel_ids = np.arange(0, num_frames, 40)
sel_num = sel_ids.shape[0]
# Best focal so far and its landmark loss. The fallback 1600 is kept only if
# no candidate gets below the initial bound of 1e5.
arg_focal = 1600
arg_landis = 1e5

print("[INFO] fitting focal length...")

# The landmark targets never change during the search; index them once
# instead of copying lms[sel_ids] on every optimization step.
sel_lms_target = lms[sel_ids].detach()

for focal in range(600, 1500, 100):
    # Fresh parameters per candidate so every focal is fitted from scratch.
    # new_zeros(..., requires_grad=True) already makes them trainable.
    id_para = lms.new_zeros((1, id_dim), requires_grad=True)
    exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
    euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans = lms.new_zeros((sel_num, 3), requires_grad=True)
    # Initial depth guess: start the head at z = -7.
    trans.data[:, 2] -= 7
    # The candidate focal is held fixed while everything else is fitted.
    focal_length = lms.new_zeros(1, requires_grad=False)
    focal_length.data += focal

    optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
    optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1)

    # Stage 1: rigid pose only (only optimizer_frame steps).
    for step in range(2000):
        id_para_batch = id_para.expand(sel_num, -1)
        geometry = model_3dmm.get_3dlandmarks(
            id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
        )
        proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
        loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms_target)
        optimizer_frame.zero_grad()
        loss_lan.backward()
        optimizer_frame.step()

    # Stage 2: pose, identity and expression jointly, with L2 priors on the
    # identity/expression coefficients.
    for step in range(2500):
        id_para_batch = id_para.expand(sel_num, -1)
        geometry = model_3dmm.get_3dlandmarks(
            id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
        )
        proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
        loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms_target)
        loss_regid = torch.mean(id_para * id_para)
        loss_regexp = torch.mean(exp_para * exp_para)
        loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
        optimizer_idexp.zero_grad()
        optimizer_frame.zero_grad()
        loss.backward()
        optimizer_idexp.step()
        optimizer_frame.step()

        # One-off learning-rate decay (x0.2) after step 1500.
        if step == 1500:
            for param_group in optimizer_idexp.param_groups:
                param_group["lr"] *= 0.2
            for param_group in optimizer_frame.param_groups:
                param_group["lr"] *= 0.2

    print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())

    # Keep the candidate with the lowest final landmark loss.
    if loss_lan.item() < arg_landis:
        arg_landis = loss_lan.item()
        arg_focal = focal

print("[INFO] find best focal:", arg_focal)
|
||
|
|
||
|
print("[INFO] coarse fitting...")

# ---- Coarse landmark fit over all frames -----------------------------------
# One shared identity vector plus per-frame expression, rotation and
# translation, fitted to the 2D landmarks with the focal length fixed at
# arg_focal.
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
# Texture and lighting are allocated here but not fitted in this stage.
tex_para = lms.new_zeros((1, tex_dim), requires_grad=True)
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
# Initial depth guess: start every frame's head at z = -7.
trans.data[:, 2] -= 7
# The focal length is a constant from here on (it is never given to an
# optimizer). Tracking gradients for it would only accumulate an unused,
# never-zeroed .grad across every backward pass below.
focal_length = lms.new_zeros(1, requires_grad=False)
focal_length.data += arg_focal

optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1)
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1)

lms_target = lms.detach()

# Stage 1: rigid pose only. Large step size (1.0) up to and including step
# 1000, then 0.1 for the remaining steps (and for stage 2).
for step in range(1500):
    id_para_batch = id_para.expand(num_frames, -1)
    geometry = model_3dmm.get_3dlandmarks(
        id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
    )
    proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
    loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms_target)
    optimizer_frame.zero_grad()
    loss_lan.backward()
    optimizer_frame.step()
    if step == 1000:
        for param_group in optimizer_frame.param_groups:
            param_group["lr"] = 0.1

# Stage 2: pose, identity and expression jointly with L2 priors; both
# learning rates are decayed by x0.2 once, after step 1000.
for step in range(2000):
    id_para_batch = id_para.expand(num_frames, -1)
    geometry = model_3dmm.get_3dlandmarks(
        id_para_batch, exp_para, euler_angle, trans, focal_length, cxy
    )
    proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy)
    loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms_target)
    loss_regid = torch.mean(id_para * id_para)
    loss_regexp = torch.mean(exp_para * exp_para)
    loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4
    optimizer_idexp.zero_grad()
    optimizer_frame.zero_grad()
    loss.backward()
    optimizer_idexp.step()
    optimizer_frame.step()
    if step == 1000:
        for param_group in optimizer_idexp.param_groups:
            param_group["lr"] *= 0.2
        for param_group in optimizer_frame.param_groups:
            param_group["lr"] *= 0.2

print(loss_lan.item(), torch.mean(trans[:, 2]).item())
|
||
|
|
||
|
print("[INFO] fitting light...")

# Number of frames rendered together; the renderer is constructed for this
# batch size.
batch_size = 32

# Sampling batch_size evenly spaced frames needs at least one full batch.
# Without this check, np.arange below gets a step of 0 and fails with an
# opaque ZeroDivisionError.
if num_frames < batch_size:
    raise ValueError(
        f"face tracking needs at least {batch_size} frames, got {num_frames}"
    )

# Rendering and the rest of the optimization both run on the first GPU.
device_default = torch.device("cuda:0")
device_render = torch.device("cuda:0")
renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)

# batch_size frames spread evenly over the whole video.
sel_ids = np.arange(0, num_frames, num_frames // batch_size)[:batch_size]
imgs = []
for sel_id in sel_ids:
    img = cv2.imread(img_paths[sel_id])
    if img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"failed to read image: {img_paths[sel_id]}")
    imgs.append(img[:, :, ::-1])  # BGR -> RGB
imgs = np.stack(imgs)
sel_imgs = torch.as_tensor(imgs).cuda()
sel_lms = lms[sel_ids]
# One lighting vector per sampled frame, fitted jointly with the texture.
sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True)

optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1)
optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01)

for step in range(71):
    sel_exp_para, sel_euler, sel_trans = (
        exp_para[sel_ids],
        euler_angle[sel_ids],
        trans[sel_ids],
    )
    sel_id_para = id_para.expand(batch_size, -1)

    # Landmark term on the sampled frames.
    geometry = model_3dmm.get_3dlandmarks(
        sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
    )
    proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
    loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
    loss_regid = torch.mean(id_para * id_para)
    loss_regexp = torch.mean(sel_exp_para * sel_exp_para)

    # Photometric term: render the textured, lit mesh and compare it with the
    # frames.
    sel_tex_para = tex_para.expand(batch_size, -1)
    sel_texture = model_3dmm.forward_tex(sel_tex_para)
    geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
    rott_geo = forward_rott(geometry, sel_euler, sel_trans)
    render_imgs = renderer(
        rott_geo.to(device_render),
        sel_texture.to(device_render),
        sel_light.to(device_render),
    )
    render_imgs = render_imgs.to(device_default)

    # Only pixels where the rendered alpha channel is positive enter the
    # colour loss.
    mask = (render_imgs[:, :, :, 3]).detach() > 0.0
    loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)

    # Steps 0-50 weight landmarks and priors heavily; afterwards the
    # photometric term dominates.
    if step > 50:
        loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8
    else:
        loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0

    optimizer_tl.zero_grad()
    optimizer_id_frame.zero_grad()
    loss.backward()
    optimizer_tl.step()
    optimizer_id_frame.step()

    # One-off learning-rate decay (x0.2) after step 50.
    if step == 50:
        for param_group in optimizer_id_frame.param_groups:
            param_group["lr"] *= 0.2
        for param_group in optimizer_tl.param_groups:
            param_group["lr"] *= 0.2

# Every frame starts the frame-wise refinement from the mean lighting of the
# sampled frames.
light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1)
light_para.data = light_mean

# From here on these are plain initial values, detached from autograd.
exp_para = exp_para.detach()
euler_angle = euler_angle.detach()
trans = trans.detach()
light_para = light_para.detach()
|
||
|
|
||
|
print("[INFO] fine frame-wise fitting...")

# ---- Frame-wise refinement -------------------------------------------------
# Frames are refined batch by batch with identity and texture frozen; each
# batch optimizes its own copies of expression, pose and lighting.
pre_num = 5  # already-refined frames prepended to a batch for the smoothness term
num_batches = (num_frames + batch_size - 1) // batch_size  # ceil division

for i in range(num_batches):
    # The final batch is shifted back so it still holds batch_size frames,
    # which means it may overlap frames refined by the previous batch.
    if (i + 1) * batch_size > num_frames:
        start_n = num_frames - batch_size
    else:
        start_n = i * batch_size
    sel_ids = np.arange(start_n, start_n + batch_size)

    imgs = []
    for sel_id in sel_ids:
        img = cv2.imread(img_paths[sel_id])
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"failed to read image: {img_paths[sel_id]}")
        imgs.append(img[:, :, ::-1])  # BGR -> RGB
    imgs = np.stack(imgs)
    sel_imgs = torch.as_tensor(imgs).cuda()
    sel_lms = lms[sel_ids]

    # Trainable per-batch copies initialised from the current estimates.
    sel_exp_para = exp_para[sel_ids].detach().clone().requires_grad_(True)
    sel_euler = euler_angle[sel_ids].detach().clone().requires_grad_(True)
    sel_trans = trans[sel_ids].detach().clone().requires_grad_(True)
    sel_light = light_para[sel_ids].detach().clone().requires_grad_(True)

    optimizer_cur_batch = torch.optim.Adam(
        [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005
    )

    sel_id_para = id_para.expand(batch_size, -1).detach()
    sel_tex_para = tex_para.expand(batch_size, -1).detach()
    # The texture depends only on the frozen (detached) texture coefficients,
    # so it is computed once per batch rather than once per step.
    sel_texture = model_3dmm.forward_tex(sel_tex_para)

    if i > 0:
        # Up to pre_num frames immediately before this batch. Clamp at frame 0:
        # when the shifted final batch starts fewer than pre_num frames into
        # the video, an unclamped range would contain negative indices that
        # wrap around to the end of the video.
        pre_ids = np.arange(max(start_n - pre_num, 0), start_n)
        n_pre = len(pre_ids)

    for step in range(50):
        # Landmark term.
        geometry = model_3dmm.get_3dlandmarks(
            sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy
        )
        proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy)
        loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach())
        loss_regexp = torch.mean(sel_exp_para * sel_exp_para)

        # Photometric term on pixels where the rendered alpha is positive.
        geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para)
        rott_geo = forward_rott(geometry, sel_euler, sel_trans)
        render_imgs = renderer(
            rott_geo.to(device_render),
            sel_texture.to(device_render),
            sel_light.to(device_render),
        )
        render_imgs = render_imgs.to(device_default)
        mask = (render_imgs[:, :, :, 3]).detach() > 0.0
        loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask)

        # Temporal smoothness over the vertices selected by
        # model_3dmm.rigid_ids. Frames are laid out along the last axis after
        # the permute (cal_lap_loss presumably smooths along it). After the
        # first batch, the refined frames just before the batch are prepended
        # (detached) so consecutive batches join smoothly.
        if i > 0:
            geometry_lap = model_3dmm.forward_geo_sub(
                id_para.expand(batch_size + n_pre, -1).detach(),
                torch.cat((exp_para[pre_ids].detach(), sel_exp_para)),
                model_3dmm.rigid_ids,
            )
            rott_geo_lap = forward_rott(
                geometry_lap,
                torch.cat((euler_angle[pre_ids].detach(), sel_euler)),
                torch.cat((trans[pre_ids].detach(), sel_trans)),
            )
        else:
            geometry_lap = model_3dmm.forward_geo_sub(
                id_para.expand(batch_size, -1).detach(),
                sel_exp_para,
                model_3dmm.rigid_ids,
            )
            rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans)
        loss_lap = cal_lap_loss(
            [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0]
        )

        # Landmarks weigh more during steps 0-30, less afterwards.
        if step > 30:
            loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0
        else:
            loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0

        optimizer_cur_batch.zero_grad()
        loss.backward()
        optimizer_cur_batch.step()

    print(f"{i} of {num_batches} done")

    # Write the refined values back as plain (detached) tensors. Assigning the
    # grad-tracking batch tensors would turn exp_para & co. into autograd graph
    # nodes that keep growing with every batch.
    exp_para[sel_ids] = sel_exp_para.detach()
    euler_angle[sel_ids] = sel_euler.detach()
    trans[sel_ids] = sel_trans.detach()
    light_para[sel_ids] = sel_light.detach()
|
||
|
|
||
|
# Save the fitted tracking parameters (as CPU tensors) to track_params.pt in
# the parent directory of --path. Path.parent ignores a trailing separator, so
# "obama/ori_imgs/" and "obama/ori_imgs" both save to obama/track_params.pt.
# With os.path.dirname, the trailing-slash form put the file inside the frames
# directory instead.
torch.save(
    {
        "id": id_para.detach().cpu(),
        "exp": exp_para.detach().cpu(),
        "euler": euler_angle.detach().cpu(),
        "trans": trans.detach().cpu(),
        "focal": focal_length.detach().cpu(),
    },
    os.path.join(Path(args.path).parent, "track_params.pt"),
)

print("params saved")
|