GPPT: Graph Pre-training and Prompt Tuning to Generalize Graph Neural Networks
KDD22
推荐指数:#paper/⭐⭐#
动机
本文探讨了图神经网络(GNN)在迁移学习中"预训练-微调"框架的局限性及改进方向。现有方法通过预训练(如边预测、对比学习)学习可迁移的图结构知识,在微调时将其应用于下游任务(如节点分类)。然而,预训练目标与下游任务之间的差异(如二元边预测与多类节点分类)导致知识传递低效甚至负迁移------微调效果可能逊于从头训练。传统改进方案依赖为每个下游任务定制预训练目标(目标工程),但需大量领域知识与试错成本。
受自然语言处理(NLP)中提示(Prompt)技术 的启发,作者提出"预训练-提示-微调"新范式,旨在通过任务重表述缩小预训练与下游任务差异。例如,NLP通过添加语义模板将分类任务转化为与预训练一致的填空任务(如情感分类转为预测掩码词)。然而,图数据面临两大挑战:
- 符号化图数据适配难题:节点为抽象符号,无法直接套用基于文本模板的语义改写。
- 提示设计的有效性:需结合图结构(如节点邻域信息)设计高效的提示函数,以提升分类等任务精度。
因此,本文核心研究问题聚焦于如何设计图感知提示函数,以桥接预训练与下游任务,从而高效激发预训练模型的知识。该方向有望通过任务形式统一化提升预训练模型的泛用性,减少对定制化目标工程的依赖,推动少样本图分析的进一步发展。
图提示框架
Pre-train, Prompt, Fine-tune
Graph prompting function(图提示函数)
v i ′ = f p r o m p t ( v i ) v_{i}^{\prime}=f_{\mathrm{prompt}}(v_{i}) vi′=fprompt(vi), v i ′ v_i' vi′和映射头有相似的输入形状
Pairwise prompting function(成对提示函数)
v i ′ = f p r o m p t ( v i ) = [ T t a s k ( y ) , T s r t ( v i ) ] v_{i}^{\prime}=f_{\mathrm{prompt}}(v_{i})=[T_{\mathbf{task}}(y),T_{\mathbf{srt}}( v_{i})] vi′=fprompt(vi)=[Ttask(y),Tsrt(vi)]
T t a s k T_{task} Ttask是下有任务的token, T s r c T_{src} Tsrc是目标节点结构的token。前者由待分类节点的标签得到,后者由目标节点周围子图表示,以提供更多的结构信息。很自然,可以利用函数来捕获他们两个的联系
Prompt addition
y 1 , ⋯ , y C \] \[y_1,\\cdots,y_C\] \[y1,⋯,yC\]为C个类的prompt。自然可以构造token对: \[ T t a s k ( y c ) , T s r t ( v i ) \] , f o r c = 1 , ⋯ , C \[T_{\\mathrm{task}}(y_{c}),T_{\\mathrm{srt}}(v_{i})\],\\mathrm{for\~}c=1,\\cdots,C \[Ttask(yc),Tsrt(vi)\],for c=1,⋯,C ##### Prompt answer 对于每个token对,我们可以拼接,并将其放入预训练的映射头,如果目标节点 v i v_i vi 与某类得到最高的链接概率,我们就将其归为一类。 ##### prompt tuning: min θ , ϕ ∑ ( v i , y c ) L p r e ( p ϕ p r e ( T t a s k ( y c ) , T s r t ( v i ) ) ; g ( y c , v i ) ) . \\min_{\\theta,\\phi}\\sum_{(v_i,y_c)}\\mathcal{L}\^{\\mathrm{pre}}(p_\\phi\^{\\mathrm{pre}}(T_{\\mathrm{task}}(y_c),T_{\\mathrm{srt}}(v_i));g(y_c,v_i)). minθ,ϕ∑(vi,yc)Lpre(pϕpre(Ttask(yc),Tsrt(vi));g(yc,vi)).其中,g为真实的标签函数 #### 图形提示功能设计 ##### 任务token的生成: e c = T t a s k ( y c ) ∈ R d e_c=T_\\mathrm{task}(y_c)\\in\\mathbb{R}\^d ec=Ttask(yc)∈Rd E = \[ e 1 , ⋯ , e C \] ⊤ ∈ R C × d E=\[e_{1},\\cdots,e_{C}\]\^{\\top}\\in\\mathbb{R}\^{C\\times d} E=\[e1,⋯,eC\]⊤∈RC×d,C是类别数。 很自然,每个节点的token可以通过查询如上的任务token得到自己的类别。很自然的是, T t a s k ( y c ) T_{\\mathbf{task}}(y_c) Ttask(yc)最优应该是类 y c y_c yc的中心。因此,我们通过聚类,来获得初始的tasktoken: 1. 利用可扩展聚类(比如metis)获得M个类: { G 1 , ⋯ , G M } \\{\\mathcal{G}_1,\\cdots,\\mathcal{G}_M\\} {G1,⋯,GM},M是类别超参。 2. 对于每个类,我们得到相应的task token: E m = \[ e 1 m , ⋯ , e C m \] ⊤ ∈ R C × d E\^m=\[e_1\^m,\\cdots,e_C\^m\]\^\\top\\in\\mathbb{R}\^{C\\times d} Em=\[e1m,⋯,eCm\]⊤∈RC×d(怎么感觉有问题这一行表述) 3. 给定集群 处节点 v i v_i vi 的任务令牌 T t a s k ( y c ) T_{task}(y_c) Ttask(yc) ,它使用向量嵌入 e c m e_c\^m ecm 表示。 ##### Structure Token Generation.(结构token的升成) 如果直接用节点v用于下游分类,会失去结构信息。因此我们使用 T s t r ( v i ) T_{\\mathrm{str}}(v_i) Tstr(vi)来表示子图结构,来涵盖结构信息。在本文中,作者使用一阶子图来表示。 e v i = a i ∗ h i + ∑ v j ∈ N ( v i ) a j ∗ h j . e_{v_i}=a_i\*h_i+\\sum_{v_j\\in\\mathcal{N}(v_i)}a_j\*h_j. evi=ai∗hi+∑vj∈N(vi)aj∗hj. a通过注意力机制得到 #### Prompt 初始化以及正交约束: 直接使用随机初始化肯定不太好,因此我们使用预训练的GNN来初始化 E m = \[ e 1 m , ⋯ , e C m \] ⊤ E\^{m}=\[e_{1}\^{m},\\cdots,e_{C}\^{m}\]\^{\\top} Em=\[e1m,⋯,eCm\]⊤。 因此,我们通过节点表示来初始化标记嵌入 e c m e\^m_c ecm,节点表示由集群 m 处 y c y_c yc类的训练节点给出。 不同类的中心的距离应该尽可能的打,因此有: L o = ∑ m ∥ E m ( E m ) ⊤ − I ∥ F 2 . \\mathcal{L}_o=\\sum_m\\\|E\^m(E\^m)\^\\top-I\\\|_F\^2. Lo=∑m∥Em(Em)⊤−I∥F2. #### 损失: min θ , ϕ , E 1 , ⋯ , E M ∑ ( v i , y c ) L p r e ( p ϕ p r e ( e c m , e v i ) ; g ( y c , v i ) ) + λ L o , s . t . θ i n i t = θ p r e , ϕ i n i t = ϕ p r e . \\begin{aligned}\\min_{\\theta,\\phi,E\^{1},\\cdots,E\^{M}}\&\\sum_{(v_{i},y_{c})}\\mathcal{L}\^{\\mathrm{pre}}(p_{\\phi}\^{\\mathrm{pre}}(e_{c}\^{m},e_{v_{i}});g(y_{c},v_{i}))+\\lambda\\mathcal{L}_{o},\\\\\\mathrm{s.t.}\&\\theta\^{\\mathrm{init}}=\\theta\^{\\mathrm{pre}},\\phi\^{\\mathrm{init}}=\\phi\^{\\mathrm{pre}}.\\end{aligned} θ,ϕ,E1,⋯,EMmins.t.(vi,yc)∑Lpre(pϕpre(ecm,evi);g(yc,vi))+λLo,θinit=θpre,ϕinit=ϕpre. ### 结果:  