元学习结合图神经网络用于药物发现

寻找具有良好药理活性、低毒性和适当药代动力学性质的候选分子是药物发现的重要任务。深度神经网络在加速和改进药物发现方面取得了令人印象深刻的进展。然而,这些技术依赖于大量的标记数据来形成分子性质的准确预测。在药物发现管道的每个阶段,通常只有少量候选分子和衍生物的生物学数据可用,这表明深度神经网络在低数据药物发现中的应用仍然是一个巨大的挑战。在这里,作者提出了一个带有GAT的元学习架构,Meta-GAT,用于预测低数据药物发现中的分子特性。GAT通过注意力机制在原子水平上捕捉原子群的局部效应,并隐含地在分子水平上捕捉不同原子群之间的相互作用。Meta-GAT进一步发展了基于双层优化的元学习策略,将来自其他属性预测任务的元知识转移到低数据目标任务中。总之,Meta-GAT展示了元学习如何减少在低数据情景下进行分子有意义预测所需的数据量。元学习很可能成为低数据药物发现中的新学习范式。

来自:Meta Learning With Graph Attention Networks for Low-Data Drug Discovery

工程地址:https://github.com/lol88/Meta-GAT

目录

背景概述

药物发现的关键问题是候选分子的筛选和优化,候选分子必须满足一系列标准:化合物需要具有合适的生物靶点潜力,并具有良好的物理化学性质,比如吸收、分布、代谢、排泄和毒性等,这些性质简称ADMET。然而,通常只有少数经过验证的数据可用。因此,如何以低数据准确预测候选分子的物理化学性质越来越受到研究人员的重视。

近年来,深度学习技术在分子性质预测、药物相互作用预测、虚拟筛选等方面取得了一些关键进展。特别是图神经网络(GNNs),它可以直接从化学图结构中学习节点和边所包含的信息,引起了生物信息学家的强烈兴趣。深度学习的性能很大程度上取决于训练数据的大小,更大的样本量通常会产生更准确的模型。在有大量标注数据的情况下,深度神经网络有足够的能力学习输入的复杂表示。然而,这显然与新疾病下的药物发现相矛盾。由于标记数据的稀缺性,对低数据药物发现取得令人满意的结果仍然是一个挑战。

人脑对客观事物的理解并不一定需要大样本训练,很多情况下可以通过简单的类比来学习。DeepMind探索大脑如何在很少的经验下学习,即"元学习"或"学会学习"。与当下火热的大模型相反,对元学习模式的理解是实现通用智能的重要途径之一。

元学习利用元知识来降低对样本复杂度的要求。然而,分子结构通常是由原子之间的相互作用和复杂的电子构型组成的。即使分子结构的微小变化也可能导致完全相反的分子性质。一个模型学习了分子结构的复杂性,这就要求该模型能够完美地提取出邻近原子对中心原子的局部环境影响,以及拓扑上相距较远的原子对之间包含的丰富的非局部信息。因此,用于低数据药物发现的元学习高度依赖于图结构,并且需要针对广泛不同的任务进行重新设计。

元学习在预测分子性质方面做了一些有代表性的尝试。Altae-Tran等人[43]引入了一种迭代改进的长短期记忆(IterRefLSTM)架构,该架构使用IterRefLSTM为one-shot学习生成嵌入。Adler等[44]提出了跨域Hebbian集成小样本学习(CHEF),通过Hebbian学习器的集成作用于深度神经网络的不同层来实现表示融合。元分子图神经网络(MGNN)利用预训练的GNN,并引入额外的自监督任务,如键重建和原子型预测,与分子性质预测任务共同优化[45]。Meta-MGNN,CHEF通过对大规模分子语料库和附加的自监督模型参数进行预训练获得元知识。IterRefLSTM训练内存要求高,这限制了模型结构,只能在特定的领域场景中使用。如何有效地表示分子特征以及如何捕获不同任务之间的共同知识是元学习中存在的巨大挑战。

在这项工作中,作者提出了一种基于图注意网络的元学习架构Meta-GAT,用于预测低数据药物发现中分子的生化特性。图注意网络通过三重注意机制捕获原子水平上原子群的局部效应,从而使GAT能够了解原子群对化合物性质的影响。在分子水平上,GAT将整个分子视为连接分子中每个原子的虚拟节点,隐式地捕获不同原子群之间的相互作用。门控递归单元(GRU)分层模型主要致力于将有限的分子信息抽象或转化为更高层次的特征向量或元知识,提高门控递归单元感知化学环境和分子连通性的能力,从而有效降低样本复杂度。这对于低数据药物发现非常重要。Meta-GAT受益于元知识,并进一步发展了一种基于双层优化的元学习策略,该策略将元知识从其他属性预测任务转移到低数据目标任务,使模型能够快速适应少样本的分子属性预测。

贡献包括:

  • 作者创造了一种化学工具来预测模型看不见的新分子的多种生理特性。这个工具可以推动低数据药物发现的分子表示的边界。
  • 所提出的Meta-GAT通过三重注意机制在原子水平上捕捉原子群的局部效应,也可以在分子水平上模拟分子的全局效应。
  • 作者提出了一种元学习策略,通过双层优化(bilevel optimization)有选择地更新每个任务中的参数,这对捕获不同任务之间共享的通用知识有帮助。
  • Meta-GAT展示了元学习如何减少在低数据药物发现中对分子进行有意义预测所需的数据量。

方法

问题的公式化

考虑几个常见的药物发现任务 T T T,例如预测新分子的毒性和副作用, x x x是要测量的化合物分子,标记 y y y是分子性质的二元实验标记(正/负)。假设考虑了所有可能的规则 H H H(假设空间)。 h h h是从 x x x到 y y y的最优假设。期望风险 R ( h ) R(h) R(h)表示决策模型对所有样本的预测能力。经验风险 R ( h I ) R(h_I) R(hI)通过计算损失函数的平均值表示模型对训练集中样本的预测能力, I I I表示训练集中样本的个数。使用经验风险 R ( h I ) R(h_I) R(hI)来估计期望风险 R ( h ) R(h) R(h)。在实际应用中,对于新分子的性质预测任务,只有几个例子可用,即 I → I→ I→few。根据经验风险最小化理论,如果只提供少量的训练样本,使得经验风险 R ( h I ) R(h_I) R(hI)与期望风险 R ( h ) R(h) R(h)的近似值相去甚远,则得到的经验风险最小化器是不可靠的。学习的挑战是从几个例子中获得可靠的经验风险最小化: E [ R ( h I → f e w ) − R ( h ) ] = 0 \mathbb{E}[R(h_{I→few})-R(h)]=0 E[R(hI→few)−R(h)]=0经验风险最小化与样本复杂度密切相关。样本复杂度是指最小化经验风险 R ( h I ) R(h_I) R(hI)所需的训练样本数量。我们使用元知识 w w w来降低学习样本的复杂性,从而解决最小化不可靠经验风险的核心问题。

元学习

元学习,也称为learning to learn,是指通过系统地观察模型在广泛的学习任务中的表现来学习如何学习的经验。这种学习经验被称为元知识 w w w。元学习的目标是找到不同任务之间共享的 w w w,这样模型就可以快速泛化到只包含少数有监督示例的新任务。

元学习和迁移学习的区别在于,迁移学习通常是拟合一个数据的分布,而元学习是拟合多个相似任务的分布。因此,元学习的训练样本是一系列的任务。

MAML被用作Meta-GAT框架的基本元学习算法。Meta-GAT通过双层优化有选择地更新每个任务中的参数,并将元知识转移到标签样本较少的新任务中,如图1所示。双层优化意味着一个优化包含另一个优化作为约束。在内层优化中,我们希望从训练任务的支持集中学习一个通用的元知识 w w w,使不同任务的损失尽可能小。内层优化阶段可以形式化,如下所示: θ ∗ ( i ) ( w ) = a r g m i n θ L f θ t a s k ( θ , w , D t r a i n s ( i ) ) \theta^{*(i)}(w)=argmin_{\theta}L_{f_{\theta}}^{task}(\theta,w,D_{train}^{s(i)}) θ∗(i)(w)=argminθLfθtask(θ,w,Dtrains(i))外层优化阶段,Meta-GAT计算每个任务query集中相对于最优参数的梯度,计算所有训练任务的最小总损失值,对 w w w参数进行优化,从而降低训练任务的预期损失,如下所示: w ∗ = a r g m i n w ∑ i = 1 M L f θ m e t a ( θ ∗ ( i ) ( w ) , D t r a i n q ) w^{*}=argmin_{w}\sum_{i=1}^{M}L_{f_{\theta}}^{meta}(\theta^{*(i)}(w),D^{q}_{train}) w∗=argminwi=1∑MLfθmeta(θ∗(i)(w),Dtrainq)其中, L m e t a L^{meta} Lmeta和 L t a s k L^{task} Ltask分别是外层和内层优化目标。 i i i为第 i i i个训练任务。

  • 图1:元学习框架的few-shot分子性质预测。蓝色框和橙色框分别表示训练阶段和测试阶段的数据流。Eye Disorders:眼部疾病,Cardiac Disorders:心脏疾病,Vascular Disorders:血管疾病。

具体来说:

首先 ,训练任务 T t r a i n T_{train} Ttrain和测试任务 T t e s t T_{test} Ttest是从一组用于药物发现的任务集合 T T T划分的,每个任务都有一个support set D s D^{s} Ds和query set D q D^{q} Dq。Meta-GAT使用大量的训练任务来拟合多个相似任务 T T T的分布。

其次 ,Meta-GAT依次迭代一批训练任务,学习特定于任务的参数。对应的优化参数 θ \theta θ从每个任务的support set中获得: θ i ′ = θ − α ∇ θ L T t r a i n ( f θ ) \theta_{i}'=\theta - \alpha\nabla_{\theta}L_{T_{train}}(f_{\theta}) θi′=θ−α∇θLTtrain(fθ)然后,优化参数不会赋给 θ \theta θ,而是被缓存。然后,外部优化学习 w w w,再产生模型 f θ f_{\theta} fθ: θ = θ − β ∇ θ ∑ T t r a i n ∼ p ( T ) L T t r a i n ′ ( f θ i ′ ) \theta=\theta-\beta\nabla_{\theta}\sum_{T_{train}\sim p(T)}L'{T{train}}(f_{\theta_{i}'}) θ=θ−β∇θTtrain∼p(T)∑LTtrain′(fθi′)每个任务的query set用于获得每个任务特定参数 θ θ θ的梯度值。从上述批处理任务query set获得的梯度值的向量之和用于更新元学习器的参数。模型继续迭代到预设次数,并根据query set选择最佳元模型。

最后 ,在测试阶段,学习了元知识 w w w的Meta-GAT通过对新任务的support set进行一些内部优化来学习新测试任务的特异性,如下所示。注意,模型参数 θ θ θ单独存在或存在于元知识 w w w中。 θ ∗ = a r g m i n θ E T t e s t ∈ T L ( θ , w , D t e s t s ) \theta^{*}=argmin_{\theta}\mathbb{E}{T{test}\in T}L(\theta,w,D_{test}^{s}) θ∗=argminθETtest∈TL(θ,w,Dtests)模型的性能通过对新任务query set的精度来衡量。在学习新任务的过程中,模型利用元知识降低了对样本复杂度的要求,从而实现在假设空间 H H H中更快地搜索参数优化策略。

Meta-GAT本质上是寻找一种更适合预测药物分子性质的假设。因此,在更新参数时,它结合query set上所有任务的损失来指定梯度更新。用这种方法得到的参数已经是对新任务的近似最优假设,并且只需很少的内部迭代即可达到最优假设。

分子表示

分子被编码成具有节点特征、边特征和邻接矩阵的图,以便输入到GNNs中。总共使用了9个原子特征和4个键特征来表征分子图的结构(表1)。分子结构通常涉及原子相互作用和复杂的电子结构,键的特征包含了分子支架和构象异构体的丰富信息。

  • 表1:特征编码初始化。

图注意力

GNN在化学学领域取得了重大进展。它具有学习结构和性质之间复杂关系的能力。注意力机制在预测分子性质方面具有突出的作用。

分子结构涉及原子的空间位置和化学键的类型。分子中拓扑上相邻的节点有更大的机会相互作用。在某些情况下,它们还可以形成决定分子化学性质的官能团。此外,拓扑结构上相距较远的原子对也可能具有重要的相互作用,例如氢键。图注意力网络从局部和全局角度提取分子结构和特征,如图2所示。GAT通过注意力机制在原子水平上捕捉原子群的局部效应,也可以在分子水平上模拟分子的全局效应。

  • 图2:Meta-GAT中的GAT结构。

分子 G = ( v , e ) G=(v,e) G=(v,e)被定义为图,我们将包含9个原子特征和4个键特征的化学信息编码到分子图中,作为图注意网络的输入。对于分子内部的局部环境,以往的图网络只聚集了相邻节点的信息,可能导致边(键)信息提取不足

具体来说,GAT首先对相邻节点的向量 v i v_i vi、 v j v_j vj及其边状态 e i j e_{ij} eij进行线性变换和非线性激活,使这些向量对齐到同一维,并将它们连接成三重嵌入向量。然后, h i j h_{ij} hij被softmax进行归一化得到注意力 a i j a_{ij} aij。最后,将节点隐藏状态和边隐藏状态元素级乘以邻居节点表示,根据注意力对邻居(包括邻居节点和边)的信息进行聚合,得到原子 i i i的上下文状态 c i c_i ci: h i j = L e a k y R e L U ( W ⋅ [ v i , e i j , v j ] ) a i j = s o f t m a x ( h i j ) = e x p ( h i j ) ∑ j ∈ N ( i ) e x p ( h i j ) c i = ∑ j ∈ N ( i ) a i j ⋅ W ⋅ [ e i j , v j ] h_{ij}=LeakyReLU(W\cdot[v_{i},e_{ij},v_{j}])\\ a_{ij}=softmax(h_{ij})=\frac{exp(h_{ij})}{\sum_{j\in N(i)}exp(h_{ij})}\\ c_{i}=\sum_{j\in N(i)}a_{ij}\cdot W\cdot [e_{ij},v_{j}] hij=LeakyReLU(W⋅[vi,eij,vj])aij=softmax(hij)=∑j∈N(i)exp(hij)exp(hij)ci=j∈N(i)∑aij⋅W⋅[eij,vj]其中, W W W是可学习矩阵,然后,使用GRU作为消息传递函数,将半径更远的消息进行融合,生成新的上下文状态,如图2(左下)所示。 h i t = G R U ( h i t − 1 , c i t − 1 ) h_{i}^{t}=GRU(h_{i}^{t-1},c_{i}^{t-1}) hit=GRU(hit−1,cit−1)为了包含来自分子的更多全局信息,GAT通过readout函数聚合原子级表示,该函数将整个分子视为连接分子中原子的超节点。使用BiGRU和注意力拼接节点特征,得到分子表示 M M M。最终的向量表示是分子结构信息的高质量描述符,降低了元学习模型学习分子图中无监督信息的难度。

数据集

作者报告了Meta-GAT模型在多个公共基准数据集上的实验结果。表2显示了基准数据集的详细信息,包括类别、任务和分子数量。所有数据集都可以在公共项目MoleculNet上下载。

  • 表2:数据集统计信息。

模型实现与评估

Meta-GAT使用pytorch框架实现,并使用Adam优化器,其学习率为0.001,用于梯度下降优化。内部迭代的学习率是0.1。生成半径为2的原子周围的信息。全连接层的输出单位为200。GRU和BiGRU都有200个隐藏单元。梯度下降在训练和测试阶段的每个迭代中执行五次, α , β = 5 \alpha,\beta=5 α,β=5。在训练过程中,在N-way K-shot中,生成1万episodes, K = N p o s + N n e g   ( N p o s , N n e g ∈ [ 1 , 5 , 10 ] ) K = N_{pos} + N_{neg}\thinspace(N_{pos}, N_{neg}∈[1,5,10]) K=Npos+Nneg(Npos,Nneg∈[1,5,10])。 N p o s N_{pos} Npos和 N n e g N_{neg} Nneg分别表示support set中正反例的个数。使用CrossEntropyLoss作为分类任务的损失函数。在预测阶段,从任务的数据集中随机抽取一批大小为 N p o s + N n e g N_{pos} + N_{neg} Npos+Nneg的支持集和一批大小为 K = 128 K = 128 K=128的查询集。对于每个测试任务,基于不同的随机种子进行20次独立运行。

此外,作者还分析了总训练时间、元训练时间、元测试时间,MACs和模型大小,以评估所提出方法的计算复杂度。Meta-GAT包括两个步骤,元训练阶段和元测试阶段。总训练时间是指在新任务上稳定Meta-GAT性能的成本。元训练时间是元训练阶段一次迭代的花费。元测试时间是指Meta-GAT在元测试阶段学习少样本的分子新性质预测任务的成本。在迭代中,support set和query set都参与模型前向计算,并执行一次或多次梯度下降迭代。元训练阶段的一次迭代开销即元训练时间为模型前向计算时间的 2 N ∗ α 2N∗α 2N∗α倍,而元测试时间为模型前向计算时间的 2 ∗ β 2∗β 2∗β倍。GeForce RTX 2060上, N N N为8, α α α和 β β β为5。

Meta-GAT在Tox21和Side Effect Resource (SIDER)数据集上的平均前向计算时间分别为14.84和23.08 ms。因此,元训练时间约为1187.2 ms和1846.4 ms,元测试时间约为148.4 ms和230.8 ms,总训练时间约为7.3 h和6 h。Meta-GAT的MACs为3.17E9,模型大小为4.8 M。

Reference

[43]Low data drug discovery with one-shot learning

[44]Cross-domain few-shot learning by representation fusion

[45]Few-shot graph learning for molecular property prediction

相关推荐
青椒大仙KI118 分钟前
24/11/14 算法笔记<强化学习> 马尔可夫
人工智能·笔记·机器学习
GOTXX17 分钟前
NAT、代理服务与内网穿透技术全解析
linux·网络·人工智能·计算机网络·智能路由器
进击的小小学生26 分钟前
2024年第45周ETF周报
大数据·人工智能
VertexGeek30 分钟前
Rust学习(四):作用域、所有权和生命周期:
java·学习·rust
TaoYuan__1 小时前
机器学习【激活函数】
人工智能·机器学习
TaoYuan__1 小时前
机器学习的常用算法
人工智能·算法·机器学习
正义的彬彬侠1 小时前
协方差矩阵及其计算方法
人工智能·机器学习·协方差·协方差矩阵
致Great1 小时前
Invar-RAG:基于不变性对齐的LLM检索方法提升生成质量
人工智能·大模型·rag
抱走江江1 小时前
SpringCloud框架学习(第二部分:Consul、LoadBalancer和openFeign)
学习·spring·spring cloud
华奥系科技1 小时前
智慧安防丨以科技之力,筑起防范人贩的铜墙铁壁
人工智能·科技·安全·生活