70 lines
1.9 KiB
Python
70 lines
1.9 KiB
Python
|
"""This module contains functions for geometry transform and camera projection"""
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
def euler2rot(euler_angle):
    """Convert a batch of Euler angles into 3x3 rotation matrices.

    Args:
        euler_angle: Tensor whose columns 0, 1 and 2 (``euler_angle[:, i]``)
            hold the angles theta (about x), phi (about y) and psi (about z),
            in radians as consumed by ``cos``/``sin``.

    Returns:
        Tensor of shape ``(batch, 3, 3)`` equal to ``rot_x @ rot_y @ rot_z``.
        The result has the dtype and device of the angles' cosines, i.e. it
        follows the input instead of always being float32.
    """
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)

    cos_t, sin_t = theta.cos(), theta.sin()
    cos_p, sin_p = phi.cos(), phi.sin()
    cos_s, sin_s = psi.cos(), psi.sin()

    # Derive the constants from a trig result so they always share its dtype,
    # device and batch size (a hard-coded float32 would silently up-cast
    # half/bfloat16 inputs and then clash with same-precision geometry).
    one = torch.ones_like(cos_t)
    zero = torch.zeros_like(cos_t)

    # Each inner cat builds one (batch, 3, 1) COLUMN; the outer cat along
    # dim 2 places the columns side by side.
    rot_x = torch.cat(
        (
            torch.cat((one, zero, zero), 1),
            torch.cat((zero, cos_t, sin_t), 1),
            torch.cat((zero, -sin_t, cos_t), 1),
        ),
        2,
    )
    rot_y = torch.cat(
        (
            torch.cat((cos_p, zero, -sin_p), 1),
            torch.cat((zero, one, zero), 1),
            torch.cat((sin_p, zero, cos_p), 1),
        ),
        2,
    )
    # NOTE(review): written out as rows, rot_z is
    # [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]], i.e. the opposite sign
    # convention to rot_x / rot_y above. Kept unchanged to preserve the
    # existing behavior; confirm with callers before "fixing".
    rot_z = torch.cat(
        (
            torch.cat((cos_s, -sin_s, zero), 1),
            torch.cat((sin_s, cos_s, zero), 1),
            torch.cat((zero, zero, one), 1),
        ),
        2,
    )
    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
|
||
|
|
||
|
|
||
|
def rot_trans_geo(geometry, rot, trans):
    """Apply a batched rotation followed by a translation to point sets.

    Args:
        geometry: 3-D tensor of points; the last two axes are swapped before
            the batched matmul, so points are presumably laid out as
            ``(batch, num_points, 3)``.
        rot: Batch of matrices left-multiplied onto the transposed points.
        trans: Translation; viewed as ``(-1, 3, 1)`` and added to every point.

    Returns:
        The transformed points, permuted back to the layout of ``geometry``.
    """
    # Work in column-vector layout so that ``rot @ points`` is a plain bmm.
    points_as_columns = geometry.permute(0, 2, 1)
    offset = trans.view(-1, 3, 1)
    moved = torch.bmm(rot, points_as_columns) + offset
    return moved.permute(0, 2, 1)
|
||
|
|
||
|
|
||
|
def euler_trans_geo(geometry, euler, trans):
    """Rotate ``geometry`` by Euler angles, then translate it.

    Convenience wrapper: builds the rotation matrices with ``euler2rot(euler)``
    and forwards them, together with ``geometry`` and ``trans``, to
    ``rot_trans_geo``.
    """
    return rot_trans_geo(geometry, euler2rot(euler), trans)
|
||
|
|
||
|
|
||
|
def proj_geo(rott_geo, camera_para):
    """Project points onto the image plane with a pinhole-style model.

    Args:
        rott_geo: Points indexed as ``rott_geo[:, :, 0..2]`` for X, Y and Z.
        camera_para: Per-batch camera parameters; column 0 is used as the
            focal length for BOTH axes, column 1 as ``cx`` and column 2 as
            ``cy``.

    Returns:
        Tensor whose last axis holds ``(proj_x, proj_y, Z)`` where
        ``proj_x = -f * X / Z + cx`` and ``proj_y = f * Y / Z + cy``.
        Z is passed through unchanged; no guard is applied for ``Z == 0``.
    """
    # A single focal length is shared by x and y (both read column 0).
    focal = camera_para[:, 0][:, None]
    center_x = camera_para[:, 1][:, None]
    center_y = camera_para[:, 2][:, None]

    x_cam = rott_geo[:, :, 0]
    y_cam = rott_geo[:, :, 1]
    depth = rott_geo[:, :, 2]

    # The horizontal axis is flipped (note the minus sign); the vertical is not.
    u = -(focal * x_cam) / depth + center_x
    v = (focal * y_cam) / depth + center_y

    return torch.stack((u, v, depth), 2)
|