AlphaFold3 rigid_utils 模块的
identity_trans
函数的功能是生成带有批次维度的全零平移向量张量。
源代码:
@lru_cache(maxsize=None)
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
trans = torch.zeros(
(*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
return trans
源码解读:
1.
函数定义
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
✅ 参数解析:
-
batch_dims
: 表示输入的批次维度,比如(16,)
就创建 16 个零向量(16, 3)
。 -
dtype
: 数据类型,float32
,float64
之类的。 -
device
: 张量在哪个设备上,cpu
或cuda
。 -
requires_grad
: 是否需要梯度,默认为True
,适合训练场景。
👉 目标 :创建一个形状 [*, 3]
的零向量,代表初始平移向量 (0, 0, 0)。
2. 创建张量
trans = torch.zeros(
(*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
✅ 解析逐项看:
-
(*batch_dims, 3)
:-
batch_dims
展开成多个批次维度,比如(16,)
展开后就是16
。 -
3
是每个向量的长度,表示(x, y, z)
三个方向的位移。 -
最终形状类似
torch.Size([16, 3])
。
-
-
dtype
: 控制数据类型,比如torch.float32
。 -
device
: 控制张量生成在哪个设备,比如cuda:0
。 -
requires_grad
: 如果True
,这个张量就会参与梯度计算(适合训练用)。
3. 返回张量
return trans
最终返回 形状 [*, 3]
的全零平移向量张量。
4. 总结
identity_trans()
的核心功能:
✅ 生成初始位移向量 ------ 形状 [*, 3]
的零向量,代表 "不移动"。
✅ 支持多批次输入 ------ batch_dims
灵活扩展支持多维数据,比如 [(8, 4, 3)]
。
✅ 缓存加速 ------ 重复调用相同参数时,不重复创建张量,直接用缓存结果。
✅ 支持梯度训练 ------ 默认开启 requires_grad=True
,可以在训练时更新平移向量。