图神经网络和分子表征:3. 不变网络最后的辉煌

写这篇文章的时候已经是2023年的8月份,GNN for molecule property prediction 这个小领域正在变得火热起来,各大榜单被不断刷新,颇有当年 CNN 刷榜 imagenet 的势头。

由于对力、维里等性质有着内禀优势,当下高居榜首的模型毫无疑问是NequIP, MACE这些等变模型。与此同时,传统不变模型的江山正在一点点被蚕食,逐渐走向没落。话虽如此,学习不变模型依然是有必要的,因为不变的3D-based模型成功打败了传统的2D模型,其对角度、二面角等几何信息的利用,消息传递机制的设计都非常值得我们学习。

本篇博客,我们将依次介绍首次纳入角度信息的DimeNet(2020 ICLR),受DimeNet启发的GemNet (NeurIPS 2021),PAINN(2021 ICML)和SphereNet(2022 ICLR)以及做到局域完备性的ComENet(NeurIPS 2022)。

(放GNN Expressive的slide镇楼)

DimeNet

DimeNet 是划时代的一个网络架构。

之所以给这么高的评价,是因为在 DimeNet 之前,所有的网络结构最多只利用了几何信息中的距离信息,而 DimeNet 是第一个将角度信息引入 GNN 的模型,虽然引入的方式现在看十分的笨重。

其消息传递机制如下所示:

节点 i 有 Ni 个邻居,节点 j 是其中一个。节点 j 又有 Nj 个邻居,k1, k2, k3是其中三个。

节点 i 将收到来自 Ni 个邻居的 Ni 条信息,每一条信息的制作过程由 各个邻居负责。以邻居 j 为例,它将除了 i 意外的所有自己邻居的信息进行一个汇总,后再传给 i,这期间就包含了以 j 为核心的,角度 k1 j i 的信息。

如果你觉得上面这段话抽象,你可以设想节点 i 是一个国王,他将向子民收税。首先,国王向大臣(最近的一圈邻居)收税。大臣们收到命令后,向老百姓收税,大臣们汇集之后,再向国王缴税。当然这是一个简化的模型。

不过,可以很清晰的看到,这种消息传递的机制涉及 2-hop 的消息传递,这使得 DimeNet 的算力消耗大幅增加。

为了充分利用引入的角度信息,DimeNet 从 DFT 领域中引入了球极坐标系,以便适配 spherical harmonics 基组。

然而,DimeNet 部分模块设计非常不合理,其改进版本 DimeNet++ 很快发布,并在2020年底登上了 NeurIPS。现在 DIG 仓库中已经没有 DimeNet 了,默认的是 DimeNet++ 版本。

GemNet,PAINN和SphereNet

DimeNet 的发布引起了不小的轰动,多位研究小组加入战场。其中比较有代表性的分别是 GemNet 和 PAINN。SphereNet 入局较晚,但也收获了一篇ICLR。

GemNet 的思路比较好理解,既然 2-hop 的消息传递能够提升精度,3-hop 势必能更进一步。如果说, 2-hop 可以引入 三体 间的角度,那么 3-hop 则可以引入 四体 间的二面角。

我们顺延上一小节的例子就是,国王向大臣征税,大臣向乡绅征税,乡绅向老百姓征税,过三层。国王、大臣、乡绅和百姓构成的四体关系能够有效提炼出二面角信息。

然而,这样做的代价就是,虽然精度提高了,但计算量猛增。从 DimeNet 的 O ( n k 2 ) O(nk^2) O(nk2) 增到 O ( n k 3 ) O(nk^3) O(nk3)

GemNet 在 qm9 上维持了相当长时间的 SOTA 记录,但从技术角度看,这并不是一个 smart 的技术路线。

PAINN 则指出,DimeNet 的角度引入方式可以进一步精简。

PAINN 创新性的引入了向量的概念,如上图所示。PAINN 对比了三个几何信息。距离信息仅需要 O(N) 的计算量,但无法反映键角的变化。键角信息的计算需要 O ( n 2 ) O(n^2) O(n2) 的成本,然而方向信息可以进一步缩减至 O ( n ) O(n) O(n) 。(两向量求和本身就包含了键角信息)

因此,PAINN 将消息传递分成了由标量几何信息(距离)承载的路线和由向量几何信息(方向)承载的路线。(并行的两条)

这种方式进一步可以将整个网络架构改造成等变网络,对力,导热性质等的预测精度大幅提升。我想,这大概是等变网络兴起的开端。(至少也是起了推动作用)

最后,发布 DimeNet 的 ShuiwangJi 小组在2021年也发布了能够嵌入二面角信息的网络 SphereNet,然而由于 GemNet 表现过于亮眼(虽然大部分是计算量换来的)SphereNet 未能在 2021 年见刊,最终拖延到了2022年年初发表在了 ICLR 上。

SphereNet 的思路和 GemNet 是类似的,也是在 距离、角度 后面缀上二面角。既然是二面角信息,就不可避免要设计到 四体。

在 GemNet 中,四体是国王,大臣,乡绅和老百姓。等级分明。所以是 3-hop 的消息传递机制。

在 SphereNet 中,四体设计为,国王,大臣,乡绅和乡绅。具体来说,国王和大臣构成了球极坐标的z轴,众乡绅向大臣进贡的同时需要遵循一个逆时针的顺序,即需要拿自己紧邻的,逆时针方向前,的乡绅当参考点,求出四体的二面角。

显然,这是一个 2-hop 的消息传递机制,计算量比 GemNet O ( n k 3 ) O(nk^3) O(nk3) 小,与 DimeNet O ( n k 2 ) O(nk^2) O(nk2) 相当。

但是,SphereNet 并不是一个 local 完备的模型。在 SphereNet 原文中,作者指出,SphereNet 可以有效区分 手性分子,但对于一些特别刁钻的情形,SphereNet 依然无法有效区分。

ComENet

把 ComENet 放在最后并单独成章,是因为 ComENet 的发布代表了不变网络最后的辉煌。

彼时已进入2022年下半年,各大数据集早已被等变网络霸榜。不变网络似乎迎来了英雄落幕。

ComENet 从数学角度严格证明了其 local 完备性,并大幅降低计算量至 1-hop 水平。这使得 ComENet 计算速度大幅提升,同时理论上拥有不输于 SphereNet 的 local 鉴别能力。但这并没有在部分数据集上取得精度的提升(QM9),部分原因可能是(笔者个人猜测) 2-hop 的模型可能比 1-hop 抓取信息的效率更高。

那 ComENet 是怎么进行消息传递的呢?

如下图所示:

ComENet 也定义了一个 四体(为了拿到二面角信息),只不过这个四体是以边为中心的。

我们还用国王收税举例:

  1. 国王 i 最终是要从邻居(大臣 j )那里一一收税的,这其实对应了消息传递的最后阶段。
  2. 那么大臣 j 怎么收税呢?
    2.1 大臣 j 首先收角度的税。他找来了同样是大臣的,距离 i 最近的点 f i j f_{i\\j} fij 作为参考点,收一个角度税 f i j f_{i\\j} fij, i, j
    2.2 大臣 j 收一个二面角的税。他找来了同样是大臣的,距离 i 最近的点 f i j f_{i\\j} fij 作为一个参考点,又找来了同样是大臣的,距离 i 第二近的点 s i s_i si 作为第二个参考点。平面 f i j f_{i\\j} fij, i, j 和 平面 f i j f_{i\\j} fij, i, s i s_i si 构成一个二面角。注意, f i j f_{i\\j} fij, i, s i s_i si 平面与 j 无关,因此是一个参考平面。
    2.3 大臣 j 向自己最近的乡绅收一个二面角的税。为什么要收这个税?是因为,前面这些税,实际上已经可以将 i 的 local 固定住了。但 local i 的 completeness 并不能随着消息传递拓展至外边。对此,作者提出将 节点 j 的最近一个邻居作为参考点,只要固定住了 f i j f_{i\\j} fij, i, j, f j i f_{j\\i} fji 四者构成的二面角,我们就可以在经过消息传递后收获全局的完备性。

ComENet 的伪代码如下:

其中的 Eq.1 如下:

大眼一看,ComENet 确实是一个 1-hop 的模型。但上面国王的例子中,我们可以看到,ComENet 又确实涉及到了乡绅阶层。这是一个相对模糊的事情:

  1. ComENet 在进行点节点向量的迭代时,仅涉及了邻居节点的特征向量,从这个角度看是 1-hop 的。( Eq.1 )
  2. ComENet 在计算所谓的 rotable bond 的时候,用到了邻居的邻居,(但也紧紧是计算了一个二面角作为输入),这里面确实有一些 2-hop 的几何信息泄露。
  3. 从计算量角度看,ComENet 是一个 O ( n k ) O(nk) O(nk) ,k 指 i 的邻居,n 指 i 的个数。这个是毋庸置疑的。

总体来说,这种 local -> global 的思路与 ClofNet 不谋而和,然而, ComENet 是基于球极坐标系的,这使其难以使用等变网络的算子。虽然 ComENet 在部分实验中,和不变网络相比取得了不错的表现(尤其是速度),但这些很难和等变网络进行对比。

与之相对的, ClofNet 继承了 local completeness 的设计优点,并进一步将其与等变算子结合形成了 leftnet。leftnet 在多个数据集上爆杀特杀,预计在很长一段时间内将保持 SOTA 优势。

小结

不变网络的发展历程总体来看是一个"不断纳入更多几何信息"的过程,发展至 ComENet ,我想几何信息的收集已经可以画上完美的句号。从更庞大的视角看,不变网络第一次完美击败了 2DGNN,让3DGNN登上舞台。同时,不变网络所衍生出的技术路线------以 PAINN 为代表的等变网络,则以更加猛烈的攻势抢占分子表征疆土。

谁会是这片领土最终的统治者,我们不得而知。但至少属于不变网络的辉煌已经落幕,本文间记之。

相关推荐
IT古董16 分钟前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
centurysee17 分钟前
【最佳实践】Anthropic:Agentic系统实践案例
人工智能
mahuifa17 分钟前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai
四口鲸鱼爱吃盐19 分钟前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
落魄君子23 分钟前
ELM分类-单隐藏层前馈神经网络(Single Hidden Layer Feedforward Neural Network, SLFN)
神经网络·分类·数据挖掘
蓝天星空32 分钟前
Python调用open ai接口
人工智能·python
睡觉狂魔er33 分钟前
自动驾驶控制与规划——Project 3: LQR车辆横向控制
人工智能·机器学习·自动驾驶
scan7241 小时前
LILAC采样算法
人工智能·算法·机器学习
leaf_leaves_leaf1 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零11 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志