147 lines
5.1 KiB
Python
147 lines
5.1 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
def axis_angle_from_quat(quat: np.ndarray, eps: float = 1.0e-6) -> np.ndarray:
|
|
"""Convert rotations given as quaternions to axis/angle.
|
|
|
|
Args:
|
|
quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
|
|
eps: The tolerance for Taylor approximation. Defaults to 1.0e-6.
|
|
|
|
Returns:
|
|
Rotations given as a vector in axis angle form. Shape is (..., 3).
|
|
The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction.
|
|
|
|
Reference:
|
|
https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554
|
|
"""
|
|
# Modified to take in quat as [q_w, q_x, q_y, q_z]
|
|
# Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)]
|
|
# Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z]
|
|
# Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta)
|
|
# When theta = 0, (sin(theta/2) / theta) is undefined
|
|
# However, as theta --> 0, we can use the Taylor approximation 1/2 -
|
|
# theta^2 / 48
|
|
quat = quat * (1.0 - 2.0 * (quat[..., 0:1] < 0.0))
|
|
mag = np.linalg.norm(quat[..., 1:], axis=-1)
|
|
half_angle = np.arctan2(mag, quat[..., 0])
|
|
angle = 2.0 * half_angle
|
|
# check whether to apply Taylor approximation
|
|
sin_half_angles_over_angles = np.where(
|
|
np.abs(angle) > eps,
|
|
np.sin(half_angle) / angle,
|
|
0.5 - angle * angle / 48
|
|
)
|
|
return quat[..., 1:4] / sin_half_angles_over_angles[..., None]
|
|
|
|
|
|
def quat_from_angle_axis(
|
|
angle: torch.Tensor,
|
|
axis: torch.Tensor = None) -> torch.Tensor:
|
|
"""Convert rotations given as angle-axis to quaternions.
|
|
|
|
Args:
|
|
angle: The angle turned anti-clockwise in radians around the vector's direction. Shape is (N,).
|
|
axis: The axis of rotation. Shape is (N, 3).
|
|
|
|
Returns:
|
|
The quaternion in (w, x, y, z). Shape is (N, 4).
|
|
"""
|
|
if axis is None:
|
|
axa = angle
|
|
angle = torch.linalg.norm(axa, dim=-1)
|
|
axis = axa / angle[..., None].clamp_min(1e-6)
|
|
theta = (angle / 2).unsqueeze(-1)
|
|
xyz = normalize(axis) * theta.sin()
|
|
w = theta.cos()
|
|
return normalize(torch.cat([w, xyz], dim=-1))
|
|
|
|
|
|
def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
|
|
"""Multiply two quaternions together.
|
|
|
|
Args:
|
|
q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
|
|
q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
|
|
|
|
Returns:
|
|
The product of the two quaternions in (w, x, y, z). Shape is (..., 4).
|
|
|
|
Raises:
|
|
ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
|
|
"""
|
|
# check input is correct
|
|
if q1.shape != q2.shape:
|
|
msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
|
|
raise ValueError(msg)
|
|
# reshape to (N, 4) for multiplication
|
|
shape = q1.shape
|
|
q1 = q1.reshape(-1, 4)
|
|
q2 = q2.reshape(-1, 4)
|
|
# extract components from quaternions
|
|
w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
|
|
w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
|
|
# perform multiplication
|
|
ww = (z1 + x1) * (x2 + y2)
|
|
yy = (w1 - y1) * (w2 + z2)
|
|
zz = (w1 + y1) * (w2 - z2)
|
|
xx = ww + yy + zz
|
|
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
|
|
w = qq - ww + (z1 - y1) * (y2 - z2)
|
|
x = qq - xx + (x1 + w1) * (x2 + w2)
|
|
y = qq - yy + (w1 - x1) * (y2 + z2)
|
|
z = qq - zz + (z1 + y1) * (w2 - x2)
|
|
|
|
return torch.stack([w, x, y, z], dim=-1).view(shape)
|
|
|
|
|
|
def quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray:
|
|
"""Rotate a vector by a quaternion along the last dimension of q and v.
|
|
|
|
Args:
|
|
q: The quaternion in (w, x, y, z). Shape is (..., 4).
|
|
v: The vector in (x, y, z). Shape is (..., 3).
|
|
|
|
Returns:
|
|
The rotated vector in (x, y, z). Shape is (..., 3).
|
|
"""
|
|
q_w = q[..., 0]
|
|
q_vec = q[..., 1:]
|
|
a = v * (2.0 * q_w**2 - 1.0)[..., None]
|
|
b = np.cross(q_vec, v, axis=-1) * q_w[..., None] * 2.0
|
|
c = q_vec * np.einsum("...i,...i->...", q_vec, v)[..., None] * 2.0
|
|
return a + b + c
|
|
|
|
|
|
def quat_rotate_inverse(q: np.ndarray, v: np.ndarray) -> np.ndarray:
|
|
"""Rotate a vector by the inverse of a quaternion along the last dimension of q and v.
|
|
|
|
Args:
|
|
q: The quaternion in (w, x, y, z). Shape is (..., 4).
|
|
v: The vector in (x, y, z). Shape is (..., 3).
|
|
|
|
Returns:
|
|
The rotated vector in (x, y, z). Shape is (..., 3).
|
|
"""
|
|
q_w = q[..., 0]
|
|
q_vec = q[..., 1:]
|
|
a = v * (2.0 * q_w**2 - 1.0)[..., None]
|
|
b = np.cross(q_vec, v, axis=-1) * q_w[..., None] * 2.0
|
|
c = q_vec * np.einsum("...i,...i->...", q_vec, v)[..., None] * 2.0
|
|
return a - b + c
|
|
|
|
|
|
def index_map(k_to, k_from):
|
|
"""
|
|
Returns an index mapping from k_from to k_to.
|
|
|
|
Given k_to=a, k_from=b,
|
|
returns an index map "a_from_b" such that
|
|
array_a[a_from_b] = array_b
|
|
|
|
Missing values are set to -1.
|
|
"""
|
|
index_dict = {k: i for i, k in enumerate(k_to)} # O(len(k_from))
|
|
return [index_dict.get(k, -1) for k in k_from] # O(len(k_to)) |