AlphaFold3 rigid_utils 模块的 identity_trans 函数的功能是生成带有批次维度的全零平移向量张量。
源代码:
@lru_cache(maxsize=None)
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
trans = torch.zeros(
(*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
return trans
源码解读:
1. 函数定义
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
✅ 参数解析:
-
batch_dims: 表示输入的批次维度,比如(16,)就创建 16 个零向量(16, 3)。 -
dtype: 数据类型,float32,float64之类的。 -
device: 张量在哪个设备上,cpu或cuda。 -
requires_grad: 是否需要梯度,默认为True,适合训练场景。
👉 目标 :创建一个形状 [*, 3] 的零向量,代表初始平移向量 (0, 0, 0)。
2. 创建张量
trans = torch.zeros(
(*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
✅ 解析逐项看:
-
(*batch_dims, 3):-
batch_dims展开成多个批次维度,比如(16,)展开后就是16。 -
3是每个向量的长度,表示(x, y, z)三个方向的位移。 -
最终形状类似
torch.Size([16, 3])。
-
-
dtype: 控制数据类型,比如torch.float32。 -
device: 控制张量生成在哪个设备,比如cuda:0。 -
requires_grad: 如果True,这个张量就会参与梯度计算(适合训练用)。
3. 返回张量
return trans
最终返回 形状 [*, 3] 的全零平移向量张量。
4. 总结
identity_trans() 的核心功能:
✅ 生成初始位移向量 ------ 形状 [*, 3] 的零向量,代表 "不移动"。
✅ 支持多批次输入 ------ batch_dims 灵活扩展支持多维数据,比如 [(8, 4, 3)]。
✅ 缓存加速 ------ 重复调用相同参数时,不重复创建张量,直接用缓存结果。
✅ 支持梯度训练 ------ 默认开启 requires_grad=True,可以在训练时更新平移向量。