All in One: Multi-Task Prompting for Graph Neural Networks学习笔记

简介

主要研究了图神经网络(GNN)中多任务提示(multi-task prompting)的方法。文中讨论了传统的GNN"预训练与微调"方法和下游任务割裂,特别是在节点级(node-level)、边级(edge-level)和图级(graph-level)任务之间(不同任务之间不通用)。为了克服这些问题,文中引入了"提示学习"(灵感来自自然语言处理中的方法)并将其应用到图任务中。核心思想是构建一个框架,利用图级提示来提升预训练图模型在多种任务中的可迁移性和泛化能力。

主要贡献

  • 提出了一个新的prompt结构,将NLP中的提示概念适应到图中,其中主要涉及三部分:

    • prompt tokens:每个token都是和节点特征向量大小一样的向量;token的数量一般远远小于节点数量和预训练时的隐藏层数量;
    • token structure:类似于节点的邻接矩阵,用来表示token之间的关系(通常是隐式的),提出了三种方法来得到:
      • 基于可调参数学出来,表示为
        A = ∪ i = 1 , j = i + 1 ∣ P ∣ − 1 a i j A=\cup_{i=1,j=i+1}^{|P|-1}{a_{ij}} A=∪i=1,j=i+1∣P∣−1aij
        其中 ∣ P ∣ |P| ∣P∣表示为token集的大小, a i j a_{ij} aij是一个可学习的参数反映了 p i p_i pi和 p j p_j pj两个token之间的联系
      • 通过计算token点积来确定联系,比如小于某个阈值(提前确定)时才确定存在关系
      • 把token视为独立的存在,两两之间不存在关系
    • inserting pattern:解决怎么将prompt图和输入子图结合起来的问题。比如可以定义成token p k p_k pk和节点特征向量 x i x_i xi的点积,这样结合后的节点向量就可以表示成 x ^ i = x i + ∑ k = 1 ∣ P ∣ w i k p k \hat x_i=x_i+\sum_{k=1}^{|P|}w_{ik}p_k x^i=xi+∑k=1∣P∣wikpk。 w i k w_{ik} wik是一个权重值,可以用来剔除不必要的结合(插入)
      w i k = { σ ( p k ⋅ x i T ) , σ ( p k ⋅ x i T ) > δ 0 , o t h e r w i s e w_{ik}=\begin{cases}\sigma(p_k·x_i^T),\quad &\sigma(p_k·x_i^T)>\delta\\0,\quad &otherwise\end{cases} wik={σ(pk⋅xiT),0,σ(pk⋅xiT)>δotherwise
      当然更直接的情况就是将 w i k w_{ik} wik视为1。
  • 提出了一些技术(诱导子图),将节点级和边级任务转化为图级任务(更通用),以便与预训练策略对齐。

  • 使用元学习技术来有效地学习多任务提示的初始化,从而提高任务的适应性和性能。元学习的目标是找到一个好的提示图初始化,使得图神经网络能够快速适应新的任务,尤其是在少样本的情况下。通过内循环和外循环的训练过程,模型在多个任务上进行训练,并利用查询集进行评估,最终优化出能够跨任务迁移的提示参数。这个方法不仅提高了任务之间的迁移能力,还能显著提升模型在少样本场景下的学习效率和泛化能力。

模型结构

红色即传统的GNN用于迁移时所使用的框架:先针对某一特定任务进行预训练,用于另外的下游任务时在不断的进行微调,费时费力,且效果不好。

蓝色即作者提出的结构,在传统的GNN的基础上,或者甚至可以去除微调的过程,先对每个任务做预训练(内循环)调参,而后进行多任务进行外循环调参。

相关推荐
北顾南栀倾寒4 分钟前
[杂学笔记]HTTP1.0和HTTP1.1区别、socket系列接口与TCP协议、传输长数据的时候考虑网络问题、慢查询如何优化、C++的垃圾回收机制
网络·c++·笔记·tcp/ip·mysql·http
啥也不会的菜鸟·20 分钟前
Redis7——进阶篇(三)
redis·学习·缓存·redis经典面试题
whennl22 分钟前
IO学习day3
学习
codexu_4612291871 小时前
Tauri跨端笔记实战(4) - 如何实现系统级截图
前端·笔记·rust·app·tauri
Q一件事1 小时前
生态安全相关文献推荐
学习
汇能感知1 小时前
不同类型光谱相机的技术差异比较
经验分享·笔记·科技
朝九晚五ฺ2 小时前
【Linux探索学习】第三十二弹——生产消费模型:基于阻塞队列和基于环形队列的两种主要的实现方法
linux·运维·学习
老哥不老2 小时前
深入 Vue.js 组件开发:从基础到实践
vue.js·笔记
小馒头学python2 小时前
【AIGC实战】蓝耘元生代部署通义万相2.1文生图,结尾附上提示词合集
python·学习·算法·aigc
Suckerbin2 小时前
Raven: 2靶场渗透测试
数据库·学习·安全·网络安全