livetalking/shencoder/sphere_harmonics.py

87 lines
2.6 KiB
Python
Raw Normal View History

2023-12-19 09:41:52 +08:00
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _shencoder as _backend
except ImportError:
from .backend import _backend
class _sh_encoder(Function):
    """Autograd wrapper around the compiled ``_backend`` SH encoding kernels.

    ``forward`` fills a ``[B, degree ** 2]`` output tensor and, on request, a
    ``[B, input_dim * degree ** 2]`` buffer that is saved for ``backward``;
    ``backward`` hands that buffer back to the extension to build the input
    gradient.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        """Encode a ``[B, input_dim]`` batch (floats in [-1, 1]) into ``[B, degree ** 2]``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inputs: 2-D tensor of coordinates, one row per sample.
            degree: SH degree; the output has ``degree ** 2`` columns.
            calc_grad_inputs: when true, also allocate the ``dy_dx`` buffer that
                the backend fills and ``backward`` later consumes.

        Returns:
            A freshly allocated ``[B, degree ** 2]`` tensor with the input's
            dtype and device.
        """
        coords = inputs.contiguous()
        batch, in_dim = coords.shape  # batch size, coord dim
        out_dim = degree ** 2
        alloc = dict(dtype=coords.dtype, device=coords.device)

        encoded = torch.empty(batch, out_dim, **alloc)
        # Presumably holds d(outputs)/d(inputs) per sample once the backend
        # has written it -- only needed when a backward pass is expected.
        jacobian = torch.empty(batch, in_dim * out_dim, **alloc) if calc_grad_inputs else None

        _backend.sh_encode_forward(coords, encoded, batch, in_dim, degree, jacobian)

        ctx.save_for_backward(coords, jacobian)
        ctx.dims = [batch, in_dim, degree]
        return encoded

    @staticmethod
    # once_differentiable is intentionally left disabled, as in the original code.
    @custom_bwd
    def backward(ctx, grad):
        """Propagate ``grad`` (shape ``[B, C * C]`` with ``C == degree``) to the inputs.

        Returns one entry per ``forward`` argument. ``degree`` and
        ``calc_grad_inputs`` never get a gradient. The inputs get ``None``
        too when ``forward`` ran without saving the ``dy_dx`` buffer.
        """
        coords, jacobian = ctx.saved_tensors
        if jacobian is None:
            return None, None, None

        batch, in_dim, degree = ctx.dims
        grad = grad.contiguous()
        grad_coords = torch.zeros_like(coords)
        _backend.sh_encode_backward(grad, coords, batch, in_dim, degree, jacobian, grad_coords)
        return grad_coords, None, None


# Functional entry point: sh_encode(inputs, degree, calc_grad_inputs=False).
sh_encode = _sh_encoder.apply
class SHEncoder(nn.Module):
    """Spherical-harmonics encoder for 3-D coordinates.

    Each ``input_dim``-vector is mapped to ``degree ** 2`` spherical-harmonic
    features by the ``sh_encode`` autograd function.
    """

    def __init__(self, input_dim=3, degree=4):
        """Configure the encoder.

        Args:
            input_dim: number of coordinate components; only 3 is supported.
            degree: SH degree in ``[1, 8]``; the output has ``degree ** 2``
                features.

        Raises:
            ValueError: if ``input_dim != 3`` or ``degree`` is outside ``[1, 8]``.
        """
        super().__init__()
        # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
        if input_dim != 3:
            raise ValueError("SH encoder only support input dim == 3")
        if not 0 < degree <= 8:
            raise ValueError("SH encoder only supports degree in [1, 8]")
        self.input_dim = input_dim  # coord dims, must be 3
        self.degree = degree  # SH degree, in [1, 8]
        self.output_dim = degree ** 2

    def __repr__(self):
        return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"

    def forward(self, inputs, size=1):
        """Encode ``inputs`` with the spherical-harmonics basis.

        Args:
            inputs: tensor of shape ``[..., input_dim]`` holding positions in
                ``[-size, size]``; it is divided by ``size`` before encoding.
            size: half-extent of the input range.

        Returns:
            Tensor of shape ``[..., degree ** 2]``.

        Raises:
            ValueError: if ``inputs`` is 0-d or its last dimension is not
                ``input_dim``. Without this check, ``reshape(-1, input_dim)``
                could regroup the data and run the kernel on garbage before
                failing with an unrelated reshape error.
        """
        if inputs.dim() == 0 or inputs.shape[-1] != self.input_dim:
            raise ValueError(
                f"SH encoder expects inputs of shape [..., {self.input_dim}], "
                f"got {tuple(inputs.shape)}"
            )
        inputs = inputs / size  # [-1, 1]
        prefix_shape = list(inputs.shape[:-1])
        flat = inputs.reshape(-1, self.input_dim)
        # Request the backend's gradient buffer only when autograd can use it.
        outputs = sh_encode(flat, self.degree, flat.requires_grad)
        return outputs.reshape(prefix_shape + [self.output_dim])