目录
[gsplat 四元数转旋转矩阵等同代码实现](#gsplat 四元数转旋转矩阵等同代码实现)
[scipy 四元数转旋转矩阵替换代码](#scipy 四元数转旋转矩阵替换代码)
gsplat 四元数转旋转矩阵等同代码实现
python
import torch
import torch.nn.functional as F
def quat_act(x: torch.Tensor) -> torch.Tensor:
return x / x.norm(dim=-1, keepdim=True)
def normalized_quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
# 源码来自: from gsplat.utils import normalized_quat_to_rotmat
"""Convert normalized quaternion to rotation matrix.
Args:
quat: Normalized quaternion in wxyz convension. (..., 4)
Returns:
Rotation matrix (..., 3, 3)
"""
assert quat.shape[-1] == 4, quat.shape
w, x, y, z = torch.unbind(quat, dim=-1)
mat = torch.stack(
[
1 - 2 * (y**2 + z**2),
2 * (x * y - w * z),
2 * (x * z + w * y),
2 * (x * y + w * z),
1 - 2 * (x**2 + z**2),
2 * (y * z - w * x),
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x**2 + y**2),
],
dim=-1,
)
return mat.reshape(quat.shape[:-1] + (3, 3))
def quat2mat(quat):
qw, qx, qy, qz = torch.unbind(quat, dim=-1) # 原为wxyz
quat_xyzw = torch.stack([qx, qy, qz, qw], dim=-1) # 转为xyzw顺序
# 后续代码保持原逻辑
qx, qy, qz, qw = torch.unbind(quat_xyzw, dim=-1)
# 计算旋转矩阵
R00 = 1 - 2 * (qy ** 2 + qz ** 2)
R01 = 2 * (qx * qy - qw * qz)
R02 = 2 * (qx * qz + qw * qy)
R10 = 2 * (qx * qy + qw * qz)
R11 = 1 - 2 * (qx ** 2 + qz ** 2)
R12 = 2 * (qy * qz - qw * qx)
R20 = 2 * (qx * qz - qw * qy)
R21 = 2 * (qy * qz + qw * qx)
R22 = 1 - 2 * (qx ** 2 + qy ** 2)
# 将旋转矩阵堆叠在一起
matrix = torch.stack([R00, R01, R02, R10, R11, R12, R20, R21, R22], dim=-1)
# 变换为 3x3 的矩阵
return matrix.view(-1, 3, 3)
x=torch.range(0,3*4-1)
x=x.reshape(-1,4)
print(x)
# 调用 quat_act 函数进行归一化
normalized_x = quat_act(x)
aa=F.normalize(x, dim=-1)
print('diff',(normalized_x-aa).sum(dim=-1))
print("\nNormalized x:")
print(aa) # 应该返回一个全为 1 的张量
if 1:
mat= normalized_quat_to_rotmat(aa)
print(mat)
mat2=quat2mat(aa)
print('diff2', (mat2 - mat).sum(dim=-1))
scipy 四元数转旋转矩阵替换代码
python
import torch
from scipy.spatial.transform import Rotation as R
import torch.nn.functional as F
def quat2mat_scipy(quat):
# 从四元数中提取 qx, qy, qz, qw
qx, qy, qz, qw = torch.unbind(quat, dim=-1)
# 计算旋转矩阵
R00 = 1 - 2 * (qy ** 2 + qz ** 2)
R01 = 2 * (qx * qy - qw * qz)
R02 = 2 * (qx * qz + qw * qy)
R10 = 2 * (qx * qy + qw * qz)
R11 = 1 - 2 * (qx ** 2 + qz ** 2)
R12 = 2 * (qy * qz - qw * qx)
R20 = 2 * (qx * qz - qw * qy)
R21 = 2 * (qy * qz + qw * qx)
R22 = 1 - 2 * (qx ** 2 + qy ** 2)
# 将旋转矩阵堆叠在一起
matrix = torch.stack([R00, R01, R02, R10, R11, R12, R20, R21, R22], dim=-1)
# 变换为 3x3 的矩阵
return matrix.view(-1, 3, 3)
if 1:
x = torch.range(0, 3 * 4 - 1)
x = x.reshape(-1, 4)
aa = F.normalize(x, dim=-1)
r = R.from_quat(aa.numpy())
mat3= r.as_matrix()
mat4=quat2mat_scipy(aa)
print('diff3', (mat4.numpy() - mat3).sum(axis=-1))