论文《Mixture of Weak & Strong Experts on Graphs》笔记

【Mowst 2024 ICLR】论文提出了一种新的图神经网络架构,称为Mixture of weak and strong experts(Mowst),通过将轻量级的多层感知机(MLP)作为弱专家和现成的GNN作为强专家相结合,以处理图中的节点特征和邻域结构。引入了基于弱专家预测结果离散度的"置信度"机制,以适应不同目标节点的专家协作。论文分析了置信函数对损失的影响,并揭示了训练动态,表明训练算法通过软分割图来鼓励专家的专门化。Mowst易于优化,表达能力强,计算成本与单一GNN相当,在多个节点分类基准测试中显示出显著的准确性提升。

发表在2024年ICLR会议上,作者学校:Meta AI和University of Rochester,引用量:2。

ICLR会议简介:全称International Conference on Learning Representations(国际学习表征会议),深度学习顶会。

查询会议:

原文和开源代码链接:

0、核心内容

现实的图包含(1)丰富的节点自特征和(2)邻域的信息结构,在典型的设置中由GNN共同处理。

我们建议通过弱专家和强专家的混合(Mowst)来解耦这两种模式,其中弱专家是一个轻量级的多层感知机(MLP),而强专家是一个现成的GNN。

为了使专家的协作适用于不同的目标节点,我们提出了一种基于弱专家预测日志的离散度(base on the dispersion of the weak expert's prediction logits)的"置信度"机制("confidence" mechanism)。当节点的分类依赖于邻域信息,或弱专家的模型质量较低时,强专家在低置信度区域被有条件地激活。通过分析置信函数对损失的影响,我们揭示了有趣的训练动力学:我们的训练算法通过有效地生成图的软分裂来鼓励每个专家的专门化。此外,我们的"confidence"设计对强专家施加了一个可取的偏好,以受益于GNN更好的泛化能力。

Mowst易于优化,并实现了强大的表达能力,其计算成本与单一GNN相当。根据经验,Mowst在4个主干GNN架构上,在6个标准节点分类基准上显示了显著的精度提高,包括同配图和异配图。

(引自摘要)

图:Mowst模型的总体设计框架

整个系统由一个弱专家(weak expert)、一个强专家(strong expert)和一个门控模块组成(gating module)。由于门控模块的协调,强弱专家之前出现了不同的协作行为。门控函数可以手动定义,也可以自动学习(通过一个额外的紧凑的MLP),根据弱专家预测日志的离散度计算置信分数(base on the dispersion of only the weak expert's prediction logits)。置信度得分在不同的目标节点上,根据专家在局部图区域上的相对强度而不同。分数还直接控制了每个专家自己的日志如何组合到系统的最终预测中。

1、先验知识
① base on the dispersion of only the weak expert's prediction logits是什么?

base on the dispersion of only the weak expert's prediction logits指的是基于弱专家(在这个模型中是一个轻量级的多层感知器,MLP)的预测结果的分散程度。这里的"分散程度"(dispersion)是一个统计学术语,用于衡量一组数值的分布范围或离散程度。在机器学习模型的上下文中,它可以用来衡量模型预测的不确定性或信心水平。

具体来说,在Mowst模型中,弱专家(MLP)首先对节点进行预测,生成一组预测结果(logits),这些结果可以被视为一个概率分布,表示模型对不同类别的置信度。然后,模型计算这个概率分布的分散度,这通常可以通过方差或熵等统计量来衡量。

  • 如果分散度较高,意味着模型对其预测不是很有信心,因为概率分布比较分散。
  • 如果分散度较低,意味着模型对其预测非常有信心,因为概率分布比较集中。

这个分散度被用来计算置信度(confidence)分数,该分数决定了在最终预测中,弱专家的预测结果应该被赋予多少权重。如果置信度分数较高,表示弱专家的预测被信任,那么它的预测结果将在最终的模型输出中占据较大比重;如果置信度分数较低,则表示强专家(GNN)的预测将被更多地依赖。

这种方法允许模型动态地根据每个节点的特定情况来调整不同专家的协作方式,从而提高整体模型的性能和泛化能力。

② The training should minimize the expected loss incurred in inference(训练节点应该尽量减少推理过程中产生的预期损失),如何理解这句话?

这句话的意思是,在训练机器学习模型时,应该尽量减少在推理(Inference)阶段预期会发生的损失(loss)。这里的"损失"通常指的是模型预测值与实际值之间的差异,这种差异可以通过特定的损失函数来量化。在机器学习中,模型训练的目标就是通过优化算法调整模型的参数,使得这个损失函数的值最小化。

简单来说,这句话强调了模型训练的一个重要目标:确保模型在实际应用中(即在推理阶段)能够尽可能准确地预测或分类,从而减少预测错误或分类错误带来的损失。

③ L M o w s t L_{Mowst} LMowst is fully differentiable( L M o w s t L_{Mowst} LMowst是完全可微分的),如何理解这句话?

在数学和机器学习领域,如果一个函数在某点可微分,那么它在该点的导数存在,这意味着函数在该点的局部可以用切线来近似。当一个函数是"完全可微分"的,这通常意味着它不仅在某个点可微分,而且在整个定义域内都可微分,并且其导数也是连续的。

在深度学习的上下文中,如果一个模型或其组成部分是完全可微分的,这意味着可以通过反向传播算法来计算模型参数的梯度,这是训练过程中优化模型的关键步骤。完全可微分的模型允许使用标准的梯度下降方法来更新参数,从而最小化损失函数。

④ 什么是standard (n-1)-simplex, S n − 1 S_{n-1} Sn−1?

在机器学习和统计学中,标准(n-1)单纯形(standard (n-1)-simplex)经常用来表示概率分布,因为概率分布的和必须为1,且每个概率值必须是非负的。在这种情况下, S n − 1 S_{n-1} Sn−1可以用来表示一个n类分类问题中的概率分布,其中 p i p_i pi表示第i类的预测概率,且所有概率之和为1。

2、引言
① 研究发现:图的不同部分可能出现不同的模式------同配模式和异配模式。

参考论文:

  • 局部同配和局部异配区域可能在一个图中共存:《Graph Neural Networks with Heterophily》
  • 根据局部连通性,图信号可以以不同的方式混合,通过节点级分类进行量化:《Breaking the Limit of Graph Neural Networks by Improving the Assortativity of Graphs with Local Mixing Patterns》
  • 图卷积迭代的次数应该根据每个目标节点周围邻域的拓扑结构进行调整:《Node Dependent Local Smoothing for Scalable Graph Learning》
② 现有GNN的局限性:许多被广泛使用的GNNs都有一个基本的局限性,因为它们是基于图的全局属性而设计的。

例如:

  • GCN和SGC使用全局拉普拉斯算子进行信号平滑;
  • GIN在所有目标节点上模拟具有相同k的k-hop子图同质检验;
  • GraphSAGE和GAT聚集了来自k跳邻居的特征,同样具有全局k。
③ 结论:通过在每个节点的基础上进行多样化对待来提高GNN的能力有很大的潜力。
④ 模型能力可以通过两种方式得到增强:

1)为单个GNN开发更高级的层架构,目的是使模型能够自动适应不同目标节点的独特特征。

  • 《How Attentive are Graph Attention Networks?》
  • 《Finding the missing-half:Graph complementary learning for homophily-prone and heterophily-prone graphs》
  • 《On the expressive power of geometric graph neural networks》

2)将现有的GNN模型纳入专家混合(Mixture-of-Experts,MoE)系统,考虑到MoE有效地改进了许多领域的模型能力。

  • 《Adaptive mixtures of local experts》
  • 《Graphdive:Graph classification by mixture of diverse experts》
  • 《Mixture of experts:a literature survey》
  • 《GLaM:Efficient scaling of language models with mixture-of-experts》
  • 《GShard:Scaling giant models with conditional computation and automatic sharding》
⑤ 本文的研究

在这项研究中,我们遵循MoE的设计理念,但后退一步,混合了一个简单的多层感知机(MLP)和一个现成的GNN------这是传统MoE中看到的故意不平衡的组合。

其主要动机是MLP和GNN模型可以专门解决图中两种最基本的模态:节点本身的特征及其邻域的结构。

MLP虽然比GNN弱得多,但是在各种情况下都可以发挥重要作用。

  • 例如,在节点特征相似的同型区域,利用MLP关注单个节点的丰富特征可能比通过GNN层聚合邻域特征更有效。
  • 相反,在高度异配的区域,信息传递可能会引入噪声,可能造成的危害大于利。(refer to《Beyond homophily in graph neural networks:Current limitations and effective designs》)

MLP专家可以帮助"清理"GNN的数据集,使强大的专家能够专注于更复杂的节点,这些节点的邻域结构为学习任务提供了有用的信息。

3、Mowst模型
① 整体模型

关键的挑战是设计混合模块,考虑到不平衡专家之间的微妙互动。

  • 一方面,弱专家应谨慎激活,以避免准确性下降。
  • 另一方面,对于能够真正被MLP掌握的节点,弱专家应该做出有意义的贡献,而不是被其更强的对应专家所掩盖。

算法1是模型的推理(预测)阶段,算法2是模型的训练阶段。在训练阶段,训练节点应该尽量减少推理过程中产生的预期损失。

损失函数:

训练策略:

  • 交替优化:通过交替固定一个专家的参数并优化另一个专家的参数,可以确保每个专家都能在不影响对方的情况下充分优化自己。
  • 置信度学习:如果置信度函数C是可学习的,那么在训练过程中也会更新它的参数。

总的来说,Mowst的训练过程是一个迭代的、交替优化的过程,旨在最小化整体损失,同时通过置信度机制来平衡两个专家的贡献,从而提高模型在节点分类任务中的性能。

② 协作行为

当优化公式1中的训练损失时,置信度C会在一些节点上积累,而在其他节点上减少。C的不同分布对应于两位专家可以专门化和协作的不同方式。

下面,我们从理论上揭示了控制C值的三个因素:自我特征信息的丰富度、两位专家之间的相对强度,以及置信函数的形状。对专家的相对强度的分析也揭示了为什么基于confidence的门是有偏差的。由于C和MLP的损失都是MLP预测 p p p的函数,我们分析了在给定一个固定GNN专家的情况下,最小化 L M o w s t L_{Mowst} LMowst的最优 p p p。

定理2.4:

定理2.4是论文中提出的一个理论结果,它描述了在特定的优化问题中,如何根据置信度函数C和损失函数L的性质来确定模型参数的最优值。

这个定理是关于Mowst模型中专家(MLP和GNN)的协作行为和训练动态的分析。

这个定理的直观理解是:

  • 当MLP的预测损失大于GNN的平均预测损失时,MLP应该完全不参与预测,这意味着我们完全信任GNN的预测。
  • 当MLP的预测损失等于GNN的平均预测损失时,MLP可以参与预测,但最终的预测可能完全由MLP或GNN决定。
  • 当MLP的预测损失小于GNN的平均预测损失时,MLP应该参与预测,并且其预测应该接近最优预测,或者在置信度函数C的约束下足够接近 L μ α L_μ^α Lμα的水平集。

这个定理为Mowst模型的训练提供了理论支持,说明了如何通过调整MLP和GNN的协作来优化整体模型的性能。它揭示了置信度函数C如何影响模型在不同情况下的预测行为,以及如何通过训练动态来调整这种协作。

③ Mowst的一种变体

4、实验部分

5、参考资料
相关推荐
迅易科技18 分钟前
借助腾讯云质检平台的新范式,做工业制造企业质检的“AI慧眼”
人工智能·视觉检测·制造
古希腊掌管学习的神1 小时前
[机器学习]XGBoost(3)——确定树的结构
人工智能·机器学习
ZHOU_WUYI2 小时前
4.metagpt中的软件公司智能体 (ProjectManager 角色)
人工智能·metagpt
靴子学长2 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
AI_NEW_COME3 小时前
知识库管理系统可扩展性深度测评
人工智能
海棠AI实验室4 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
hunteritself4 小时前
AI Weekly『12月16-22日』:OpenAI公布o3,谷歌发布首个推理模型,GitHub Copilot免费版上线!
人工智能·gpt·chatgpt·github·openai·copilot
IT古董4 小时前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
centurysee4 小时前
【最佳实践】Anthropic:Agentic系统实践案例
人工智能
mahuifa4 小时前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai