livetalking/ernerf/data_utils/face_tracking/render_land.py

193 lines
6.6 KiB
Python
Raw Permalink Normal View History

2023-12-19 09:41:52 +08:00
import torch
import torch.nn as nn
import render_util
import geo_transform
import numpy as np
def compute_tri_normal(geometry, tris):
    """Return one unit-length normal per triangle of a batched mesh.

    Args:
        geometry: Vertex tensor indexed as (batch, vertex, xyz).
        tris: Integer tensor indexed as (triangle, corner); each column
            holds the vertex index of one triangle corner.

    Returns:
        Tensor indexed as (batch, triangle, xyz) holding the normalized
        cross product of the two edges leaving the first corner.
    """
    # Move xyz to dim 1 so vertices can be gathered along the last axis.
    channels_first = geometry.permute(0, 2, 1)
    corner_a, corner_b, corner_c = (
        torch.index_select(channels_first, 2, tris[:, corner])
        for corner in range(3)
    )
    edge_ab = corner_b - corner_a
    edge_ac = corner_c - corner_a
    raw_normal = torch.cross(edge_ab, edge_ac, dim=1)
    # Normalize along xyz (dim 1), then restore the (batch, triangle, xyz) layout.
    unit_normal = nn.functional.normalize(raw_normal, dim=1)
    return unit_normal.permute(0, 2, 1)
class Compute_normal_base(torch.autograd.Function):
    """Autograd bridge to the ``render_util`` normal-basis kernels.

    The forward pass calls ``render_util.normal_base_forward`` and the
    backward pass calls ``render_util.normal_base_backward``; both are
    expected to return a one-element sequence.
    """

    @staticmethod
    def forward(ctx, normal):
        """Compute the basis for ``normal`` and keep ``normal`` for backward."""
        # The backward kernel needs the original input, not the output.
        ctx.save_for_backward(normal)
        (normal_b,) = render_util.normal_base_forward(normal)
        return normal_b

    @staticmethod
    def backward(ctx, grad_normal_b):
        """Map the gradient w.r.t. the basis back to a gradient w.r.t. ``normal``."""
        saved_normal = ctx.saved_tensors[0]
        (grad_normal,) = render_util.normal_base_backward(
            grad_normal_b, saved_normal
        )
        return grad_normal
class Normal_Base(torch.nn.Module):
    """Parameter-free module wrapper around ``Compute_normal_base``."""

    def __init__(self):
        super().__init__()

    def forward(self, normal):
        """Apply the ``Compute_normal_base`` autograd function to ``normal``."""
        basis = Compute_normal_base.apply(normal)
        return basis
def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
    """Pose and project the mesh and prepare the per-vertex / per-pixel buffers.

    Args:
        geometry: Vertex tensor; its second dimension is the vertex count.
        euler, trans: Pose arguments forwarded to
            ``geo_transform.euler_trans_geo``.
        cam: Camera argument forwarded to ``geo_transform.proj_geo``.
        tris: Triangle index table passed to ``compute_tri_normal``.
        vert_tris: Index tensor selecting one triangle normal per vertex.
        ori_img: Image tensor; only its first three dimensions and its
            device are read here.

    Returns:
        Tuple ``(rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid)``:
        the posed geometry, its projection, per-triangle normals of the
        posed geometry, a (batch, vertex) visibility score, and a zeroed
        float32 buffer of shape ``(ori_img.shape[0], shape[1] * shape[2])``.
    """
    vertex_count = geometry.shape[1]

    posed = geo_transform.euler_trans_geo(geometry, euler, trans)
    projected = geo_transform.proj_geo(posed, cam)

    tri_normals = compute_tri_normal(posed, tris)
    vert_normals = torch.index_select(tri_normals, 1, vert_tris)

    # Score = -(vertex normal . unit vector along the posed vertex position);
    # presumably positive when the vertex faces a camera at the origin.
    normal_rows = vert_normals.reshape(-1, 1, 3)
    position_cols = nn.functional.normalize(posed.reshape(-1, 3, 1))
    facing = torch.bmm(normal_rows, position_cols)
    is_visible = torch.neg(facing.reshape(-1, vertex_count))
    # Everything below the 0.01 threshold is flagged with the sentinel -1.
    is_visible[is_visible < 0.01] = -1

    batch_size = ori_img.shape[0]
    pixel_count = ori_img.shape[1] * ori_img.shape[2]
    pixel_valid = torch.zeros(
        (batch_size, pixel_count), dtype=torch.float32, device=ori_img.device
    )
    return posed, projected, tri_normals, is_visible, pixel_valid
class Render_Face(torch.autograd.Function):
    """Autograd bridge to the ``render_util`` face-rendering kernels.

    ``forward`` delegates to ``render_util.render_face_forward`` and
    ``backward`` to ``render_util.render_face_backward``. Gradients are
    produced for ``proj_geo``, ``texture`` and ``nbl`` only.
    """

    @staticmethod
    def forward(
        ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
    ):
        """Render the face and return ``(render, real)`` from the kernel.

        ``ori_img`` is unpacked as (batch, h, w, _) and handed to the kernel
        flattened to (batch, h * w, 3).
        """
        batch_size, h, w, _ = ori_img.shape
        flat_img = ori_img.view(batch_size, -1, 3)
        # One (h, w) pair per batch item, laid out flat as [h, w, h, w, ...].
        img_size = torch.tensor(
            [h, w], dtype=torch.int32, device=flat_img.device
        ).repeat(batch_size)

        tri_index, tri_coord, render, real = render_util.render_face_forward(
            proj_geo,
            flat_img,
            img_size,
            texture,
            nbl,
            is_visible,
            tri_inds,
            pixel_valid,
        )
        # Saved in exactly the order the backward kernel takes them.
        ctx.save_for_backward(
            flat_img,
            img_size,
            proj_geo,
            texture,
            nbl,
            tri_inds,
            tri_index,
            tri_coord,
        )
        return render, real

    @staticmethod
    def backward(ctx, grad_render, grad_real):
        """Return gradients for the seven ``forward`` inputs (last four are None)."""
        grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
            grad_render, grad_real, *ctx.saved_tensors
        )
        # ori_img, is_visible, tri_inds and pixel_valid receive no gradient.
        return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None
class Render_RGB(nn.Module):
    """Parameter-free module wrapper around the ``Render_Face`` autograd function."""

    def __init__(self):
        super().__init__()

    def forward(
        self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid
    ):
        """Forward all arguments, in order, to ``Render_Face.apply``."""
        render_args = (
            proj_geo,
            texture,
            nbl,
            ori_img,
            is_visible,
            tri_inds,
            pixel_valid,
        )
        return Render_Face.apply(*render_args)
def cal_land(proj_geo, is_visible, lands_info, land_num):
    """Gather the 2-D projected positions of the landmark vertices.

    Landmark vertex indices come from ``render_util.update_contour``; they
    index into ``proj_geo`` flattened to rows of three values.

    Returns:
        Tensor of shape (-1, land_num, 2) holding the first two components
        of each selected projected vertex.
    """
    (land_index,) = render_util.update_contour(lands_info, is_visible, land_num)
    flat_vertices = proj_geo.reshape(-1, 3)
    selected = torch.index_select(flat_vertices, 0, land_index)
    # Keep only the first two components of each selected vertex.
    return selected[:, :2].reshape(-1, land_num, 2)
class Render_Land(nn.Module):
    """Mesh preview and landmark / RGB fitting losses for a 3DMM face mesh.

    The topology tables are read from text files once at construction.
    Posing, projection and rasterization are delegated to the
    ``geo_transform`` and ``render_util`` modules.
    """

    def __init__(self, data_dir="../data/3DMM", device="cuda"):
        """Load the mesh tables and build the helper modules.

        Args:
            data_dir: Directory containing ``lands_info.txt``, ``tris.txt``
                and ``vert_tris.txt``. The default keeps the historical
                location, which is relative to the current working directory.
            device: Device the tables and helper modules are moved to. The
                default matches the previous unconditional ``.cuda()``.
        """
        super().__init__()
        self.lands_info = self._load_table(
            data_dir, "lands_info.txt", np.int32, device
        )
        # Triangle indices are shifted down by one after loading
        # (presumably the file is 1-based -- confirm against the data).
        self.tris = self._load_table(data_dir, "tris.txt", np.int64, device) - 1
        self.vert_tris = self._load_table(
            data_dir, "vert_tris.txt", np.int64, device
        )
        self.normal_baser = Normal_Base().to(device)
        self.renderer = Render_RGB().to(device)

    @staticmethod
    def _load_table(data_dir, file_name, dtype, device):
        """Read a whitespace-separated numeric text table as a tensor on ``device``."""
        table = np.loadtxt(f"{data_dir}/{file_name}", dtype=dtype)
        return torch.as_tensor(table).to(device)

    @staticmethod
    def _flat_image_size(batch_size, h, w, device):
        """Return an int32 vector ``[h, w, h, w, ...]`` with one pair per batch item."""
        return torch.tensor([h, w], dtype=torch.int32, device=device).repeat(
            batch_size
        )

    def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
        """Render the posed mesh with a constant texture value of 200.

        Args:
            geometry, euler, trans, cam: Forwarded to ``preprocess_render``.
            ori_img: Image tensor unpacked as (batch, h, w, _); passed to the
                kernel flattened to (batch, h * w, 3).
            light: Lighting tensor, reshaped to (-1, 9, 3). Only channel 0
                is used and replicated to all three channels.

        Returns:
            The kernel output viewed as (batch, h, w, 3) and cast to uint8.
        """
        batch_size, h, w, _ = ori_img.shape
        flat_img = ori_img.view(batch_size, -1, 3)
        img_size = self._flat_image_size(batch_size, h, w, flat_img.device)

        # Only the projection and the triangle normals are needed here.
        _, proj_geo, rot_tri_normal, _, _ = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, flat_img
        )
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())

        # Replicate the first lighting channel to all three channels.
        first_channel = light.reshape(-1, 9, 3)[:, :, 0]
        mono_light = first_channel.unsqueeze(-1).repeat(1, 1, 3)
        nbl = torch.bmm(tri_nb, mono_light)

        texture = torch.ones_like(geometry) * 200
        (render,) = render_util.render_mesh(
            proj_geo, flat_img, img_size, texture, nbl, self.tris
        )
        return render.view(batch_size, h, w, 3).byte()

    def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands):
        """Compute the photometric and landmark losses for the current fit.

        Args:
            geometry, euler, trans, cam: Forwarded to ``preprocess_render``.
            ori_img: Target image tensor; its first dimension is the batch.
            light: Lighting tensor, reshaped to (-1, 9, 3).
            texture: Texture tensor passed to the renderer.
            lands: Target landmarks; ``lands.shape[1]`` is the landmark count.

        Returns:
            Tuple ``(col_dis, lan_dis)``: the mean per-pixel color distance
            weighted by ``pixel_valid`` and divided by its mean (plus 1e-5),
            and the mean distance between projected and target landmarks.
        """
        _, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img
        )
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))

        # pixel_valid is created as zeros; presumably the renderer fills it
        # in place with the covered-pixel mask -- confirm in render_util.
        render, real = self.renderer(
            proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid
        )
        proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1])

        batch_size = ori_img.shape[0]
        col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape(
            batch_size, -1
        )
        # The small constant guards against division by zero.
        col_dis = torch.mean(col_minus * pixel_valid) / (
            torch.mean(pixel_valid) + 0.00001
        )

        land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape(
            batch_size, -1
        )
        lan_dis = torch.mean(land_dists)
        return col_dis, lan_dis