AF3 superimpose函数解读

AlphaFold3 superimpose函数通过使用SVD最小化RMSD,将坐标叠加到参考上,在蛋白质结构预测中用于比较预测结构与真实结构的相似性。

源代码:

复制代码
from src.utils.geometry.alignment import weighted_rigid_align
from src.utils.geometry.vector import Vec3Array
import torch


def compute_rmsd(tensor1, tensor2, mask, eps=1e-6):
    """Compute the RMSD between two tensors."""
    diff = tensor1 - tensor2
    squared_diff = diff ** 2
    sum_squared_diff = squared_diff.sum(dim=-1)

    # Mask out invalid positions
    sum_squared_diff = sum_squared_diff * mask

    # Average over valid positions
    denom = mask.sum(dim=-1) + eps
    mean_squared_diff = torch.sum(sum_squared_diff, dim=-1) / denom

    # Square root to get RMSD
    rmsd = torch.sqrt(mean_squared_diff + eps)
    return rmsd


def superimpose(reference, coords, mask):
    """
        Superimposes coordinates onto a reference by minimizing RMSD using SVD.

        Args:
            reference:
                [*, N, 3] reference tensor
            coords:
                [*, N, 3] tensor
            mask:
                [*, N] tensor
        Returns:
            A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
    """
    # To Vec3Array for alignment
    reference = Vec3Array.from_array(reference)
    coords = Vec3Array.from_array(coords)

    # Align the coordinates to the reference
    aligned_coords = weighted_rigid_align(coords, reference, weights=mask, mask=mask)
    aligned_coords = aligned_coords.to_tensor()
    reference = reference.to_tensor()
    # Compute RMSD
    rmsds = compute_rmsd(reference, aligned_coords)
    return aligned_coords, rmsds

代码解析

1. compute_rmsd 函数

功能: 计算两组坐标 tensor1tensor2 之间的 RMSD,常用于比较蛋白质结构的相似性。

复制代码
def compute_rmsd(tensor1, tensor2, mask, eps=1e-6):
  • tensor1: 真实坐标 (ground truth) [*, N, 3]
  • tensor2: 预测坐标 [*, N, 3]
  • mask: 指定哪些坐标是有效的 [*, N]
  • eps: 避免除零问题的小常数

步骤解析

复制代码
diff = tensor1 - tensor2  # 计算两者的差值
squared_diff = diff ** 2  # 差值平方
sum_squared_diff = squared_diff.sum(dim=-1)  # 在最后一个维度 (3D) 上求和,即每个点的平方误差和

计算的是欧几里得距离的平方,形状 [*, N]

复制代码
sum_squared_diff = sum_squared_diff * mask  # 只保留mask中有效的坐标

屏蔽无效坐标(未对齐的或者不感兴趣的原子)。

复制代码
denom = mask.sum(dim=-1) + eps  # 计算有效坐标数
mean_squared_diff = torch.sum(sum_squared_diff, dim=-1) / denom  # 计算均方误差

有效原子位置的均方误差(MSE)。

复制代码
rmsd = torch.sqrt(mean_squared_diff + eps)  # 开方得到 RMSD

最终 RMSD 计算完成,返回的是 [ * ] 形状的张量,即每个样本的 RMSD 值。


2. superimpose 函数

功能:

使用加权刚性对齐 (Weighted Rigid Alignment) 方法,通过奇异值分解 (SVD) 使预测坐标 coords 尽可能匹配 reference,然后计算 RMSD。

复制代码
def superimpose(reference, coords, mask):
  • reference: 真实的结构坐标 [*, N, 3]
  • coords: 需要对齐的预测坐标 [*, N, 3]
  • mask: 有效位置的掩码 [*, N]

步骤解析

复制代码
reference = Vec3Array.from_array(reference)
coords = Vec3Array.from_array(coords)

这里 Vec3Array 是一个三维向量的封装类,使得操作更符合几何向量计算需求。

复制代码
aligned_coords = weighted_rigid_align(coords, reference, weights=mask, mask=mask)

核心部分

  • weighted_rigid_align 是一个加权刚性对齐函数,使用SVD(奇异值分解)寻找最优旋转和平移 ,将 coordsreference 对齐。

  • weights=mask 代表对对齐过程加权,仅使用有效坐标对齐。

    aligned_coords = aligned_coords.to_tensor()
    reference = reference.to_tensor()

Vec3Array 转回 PyTorch 张量,方便计算 RMSD。

复制代码
rmsds = compute_rmsd(reference, aligned_coords)
  • 计算对齐后的 RMSD,衡量对齐质量。

最终返回:

  • aligned_coords: 对齐后的坐标
  • rmsds: 对齐后的 RMSD 值

代码核心逻辑

  1. compute_rmsd
    • 计算两组坐标的 RMSD(用于衡量预测误差)。
  2. superimpose
    • 通过 SVD 进行刚性对齐,使预测结构尽量匹配真实结构。
    • 计算对齐后的 RMSD,评估对齐质量。

这两个函数主要用于蛋白质结构比对,常用于 AlphaFold3 这种结构预测模型的评估。

相关推荐
Blossom.1188 分钟前
把AI“绣”进丝绸:生成式刺绣神经网络让古装自带摄像头
人工智能·pytorch·python·深度学习·神经网络·机器学习·fpga开发
大力财经12 分钟前
百度搜索开启公测AI短剧平台,将投入亿元基金、百亿流量扶持创作者
人工智能
RPA中国23 分钟前
谷雨互动赵乾坤 | AI答案时代生存法则:从流量变迁到GEO实践
人工智能
paopaokaka_luck28 分钟前
基于SpringBoot+Vue的数码交流管理系统(AI问答、协同过滤算法、websocket实时聊天、Echarts图形化分析)
vue.js·人工智能·spring boot·websocket·echarts
BB_CC_DD1 小时前
在NVIDIA Jetson Orin NX (Ubuntu 22.04, JetPack 5.1, CUDA 11 cuDnn8) 上安装PyTorch 2
pytorch·深度学习·ubuntu
youngfengying1 小时前
身体活动(physical activity)---深度学习
人工智能·深度学习
START_GAME1 小时前
语音合成系统---IndexTTS2:环境配置与实战
人工智能·语音识别
2501_930799241 小时前
访答知识库#Pdf转word#人工智能#Al编辑器#访答RAG#企业知识库,个人知识库,本地知识库,访答编辑器,访答浏览器……
人工智能
max5006001 小时前
多GPU数据并行训练中GPU利用率不均衡问题深度分析与解决方案
人工智能·机器学习·分类·数据挖掘