temoprarily copy noramlize()

This commit is contained in:
yjinzero 2025-02-15 11:18:04 +09:00
parent 2e7355a346
commit 016fd440f1
1 changed files with 12 additions and 0 deletions

View File

@ -37,6 +37,18 @@ def axis_angle_from_quat(quat: np.ndarray, eps: float = 1.0e-6) -> np.ndarray:
) )
return quat[..., 1:4] / sin_half_angles_over_angles[..., None] return quat[..., 1:4] / sin_half_angles_over_angles[..., None]
def normalize(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
"""Normalizes a given input tensor to unit length.
Args:
x: Input tensor of shape (N, dims).
eps: A small value to avoid division by zero. Defaults to 1e-9.
Returns:
Normalized tensor of shape (N, dims).
"""
return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)
def quat_from_angle_axis( def quat_from_angle_axis(
angle: torch.Tensor, angle: torch.Tensor,