import os

import numpy as np
import torch
import torch.nn as nn

from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    blending,
)
from pytorch3d.renderer.blending import (
    BlendParams,
    hard_rgb_blend,
    sigmoid_alpha_blend,
    softmax_rgb_blend,
)
from pytorch3d.structures import Meshes

class SoftSimpleShader(nn.Module):
    """
    A simplified soft shader: unlike SoftPhongShader, no lighting model is
    applied here. The per-pixel texels are sampled directly from the mesh
    textures (lighting is expected to be baked into the vertex colors), and
    the blending function returns the soft aggregated color using all the
    faces per pixel.

    To use the default values, simply initialize the shader with the desired
    device, e.g.

        shader = SoftSimpleShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
    ):
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def to(self, device):
        # Manually move to the given device modules which are not subclasses
        # of nn.Module (cameras may be None until the forward pass).
        if self.cameras is not None:
            self.cameras = self.cameras.to(device)
        self.materials = self.materials.to(device)
        self.lights = self.lights.to(device)
        return self

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        # Sample the per-pixel texels directly; no lighting model is applied.
        texels = meshes.sample_textures(fragments)
        blend_params = kwargs.get("blend_params", self.blend_params)

        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = (
                "Cameras must be specified either at initialization "
                "or in the forward pass of SoftSimpleShader"
            )
            raise ValueError(msg)
        znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
        zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
        images = softmax_rgb_blend(
            texels, fragments, blend_params, znear=znear, zfar=zfar
        )
        return images

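# A minimal sketch of how this shader slots into a PyTorch3D MeshRenderer (the
# camera and raster settings below are illustrative placeholders; the actual
# configuration used by this module is built in Render_3DMM.get_render):
#
#     cameras = FoVPerspectiveCameras(device=device)
#     renderer = MeshRenderer(
#         rasterizer=MeshRasterizer(
#             cameras=cameras,
#             raster_settings=RasterizationSettings(image_size=256),
#         ),
#         shader=SoftSimpleShader(device=device, cameras=cameras),
#     )
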
class Render_3DMM(nn.Module):
    def __init__(
        self,
        focal=1015,
        img_h=500,
        img_w=500,
        batch_size=1,
        device=torch.device("cuda:0"),
    ):
        super(Render_3DMM, self).__init__()

        self.focal = focal
        self.img_h = img_h
        self.img_w = img_w
        self.device = device
        self.renderer = self.get_render(batch_size)

        # Load the fixed 3DMM topology: "tris" holds the (F, 3) face indices
        # and "vert_tris" lists, for each vertex, the faces adjacent to it.
        dir_path = os.path.dirname(os.path.realpath(__file__))
        topo_info = np.load(
            os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
        ).item()
        self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
        self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)

    def compute_normal(self, geometry):
        # Per-face normals from the cross product of two triangle edges.
        vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
        vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
        vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
        nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, dim=2)
        tri_normal = nn.functional.normalize(nnorm, dim=2)
        # Vertex normals: sum the normals of the faces adjacent to each vertex
        # (via the precomputed vert_tris lookup), then renormalize.
        v_norm = tri_normal[:, self.vert_tris, :].sum(2)
        vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
        return vert_normal

    def get_render(self, batch_size=1):
        # Camera: keep the look-at rotation but place the camera at the origin
        # (the translation from look_at_view_transform is discarded).
        R, T = look_at_view_transform(10, 0, 0)
        R = R.repeat(batch_size, 1, 1)
        T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)

        cameras = FoVPerspectiveCameras(
            device=self.device,
            R=R,
            T=T,
            znear=0.01,
            zfar=20,
            # Field of view (in degrees) matching the pinhole intrinsics:
            # fov = 2 * arctan((W / 2) / focal).
            fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
        )
        # Ambient-only lights: shading is baked into the vertex colors by
        # Illumination_layer, so diffuse and specular terms are zeroed out.
        lights = PointLights(
            device=self.device,
            location=[[0.0, 0.0, 1e5]],
            ambient_color=[[1, 1, 1]],
            specular_color=[[0.0, 0.0, 0.0]],
            diffuse_color=[[0.0, 0.0, 0.0]],
        )
        # Soft rasterization: the blur radius follows the SoftRas heuristic
        # log(1 / 1e-4 - 1) * sigma, scaled down further here.
        sigma = 1e-4
        raster_settings = RasterizationSettings(
            image_size=(self.img_h, self.img_w),
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
            faces_per_pixel=2,
            perspective_correct=False,
        )
        blend_params = blending.BlendParams(background_color=[0, 0, 0])
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
            shader=SoftSimpleShader(
                lights=lights, blend_params=blend_params, cameras=cameras
            ),
        )
        return renderer.to(self.device)

    @staticmethod
    def Illumination_layer(face_texture, norm, gamma):
        # Shade per-vertex albedo with second-order spherical harmonics (SH)
        # lighting: gamma holds 9 SH coefficients per RGB channel.
        n_b, num_vertex, _ = face_texture.size()
        n_v_full = n_b * num_vertex
        gamma = gamma.view(-1, 3, 9).clone()
        gamma[:, :, 0] += 0.8  # bias the constant (ambient) SH band

        gamma = gamma.permute(0, 2, 1)

        # SH band constants (Ramamoorthi & Hanrahan irradiance formulation).
        a0 = np.pi
        a1 = 2 * np.pi / np.sqrt(3.0)
        a2 = 2 * np.pi / np.sqrt(8.0)
        c0 = 1 / np.sqrt(4 * np.pi)
        c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
        c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
        d0 = 0.5 / np.sqrt(3.0)

        # Evaluate the 9 SH basis functions at each vertex normal.
        Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
        norm = norm.view(-1, 3)
        nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]
        arrH = []

        arrH.append(Y0)
        arrH.append(-a1 * c1 * ny)
        arrH.append(a1 * c1 * nz)
        arrH.append(-a1 * c1 * nx)
        arrH.append(a2 * c2 * nx * ny)
        arrH.append(-a2 * c2 * ny * nz)
        arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
        arrH.append(-a2 * c2 * nx * nz)
        arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))

        H = torch.stack(arrH, 1)
        Y = H.view(n_b, num_vertex, 9)
        # Per-vertex RGB lighting = SH basis (B, V, 9) x coefficients (B, 9, 3).
        lighting = Y.bmm(gamma)

        face_color = face_texture * lighting
        return face_color

    def forward(self, rott_geometry, texture, diffuse_sh):
        # Bake SH lighting into per-vertex colors, then render with the
        # lighting-free SoftSimpleShader.
        face_normal = self.compute_normal(rott_geometry)
        face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
        face_color = TexturesVertex(face_color)
        mesh = Meshes(
            rott_geometry,
            self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
            face_color,
        )
        rendered_img = self.renderer(mesh)
        rendered_img = torch.clamp(rendered_img, 0, 255)

        return rendered_img
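
# Example usage: a minimal sketch, assuming "3DMM/topology_info.npy" exists next
# to this file and a CUDA device is available. Shapes follow the forward pass
# above: geometry and texture are (B, V, 3); diffuse_sh is (B, 27), i.e. nine
# second-order SH coefficients per RGB channel.
if __name__ == "__main__":
    device = torch.device("cuda:0")
    model = Render_3DMM(focal=1015, img_h=500, img_w=500, batch_size=1, device=device)
    num_verts = int(model.tris.max().item()) + 1  # vertex count implied by topology
    geometry = torch.randn(1, num_verts, 3, device=device)
    texture = torch.rand(1, num_verts, 3, device=device) * 255.0
    diffuse_sh = torch.zeros(1, 27, device=device)
    image = model(geometry, texture, diffuse_sh)  # (1, img_h, img_w, 4) RGBA
    print(image.shape)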