图神经网络实战(18)——消息传播神经网络

图神经网络实战(18)------消息传播神经网络

    • [0. 前言](#0. 前言)
    • [1. 消息传播神经网络](#1. 消息传播神经网络)
    • [2. 实现 MPNN 框架](#2. 实现 MPNN 框架)
    • 小结
    • 系列链接

0. 前言

我们已经学习了多种图神经网络 (Graph Neural Networks, GNN)变体,包括图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)GraphSAGE 等。在本节中,我们将对这些变体 GNN 结构进行一般性总结,即 GNN 的通用框架,也是 GNN 架构的通用范式。研究 GNN 通用框架能够帮助我们更加清晰的对比各类 GNN 模型,同时也为 GNN 模型的扩展提供了灵活性。

1. 消息传播神经网络

我们已经学习了用于聚合和组合不同节点特征的不同函数,最简单的图神经网络 (Graph Neural Networks, GNN)层将相邻节点(包括目标节点本身)的特征与权重矩阵的线性组合求和,然后,使用求和输出取代之前的目标节点嵌入。

节点级算子可以表示为如下形式:
h i ′ = ∑ j ∈ N i h j W T h_i'=\sum_{j\in \mathcal N_i}h_jW^T hi′=j∈Ni∑hjWT

其中, N i \mathcal N_i Ni 是节点 i i i (包括节点本身)的邻居节点集, h i h_i hi 是第 i i i 个节点的嵌入, W W W 是权重矩阵。
图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)层分别为节点特征添加了固定权重和动态权重,但于 GNN 具有相同的思想。包括 GraphSAGE 的邻接采样和图同构网络 (Graph Isomorphism Network, GIN)的全局求和池化也没有改变 GNN 层的主要思想。综观这些变体,我们可以将 GNN 层归纳为称为消息传递神经网络 (Message Passing Neural Network, MPNNMP-GNN) 的通用框架。该框架于 2017 年由 Gilmer 等人提出,由三个主要操作组成:

  • 消息 (Message): 每个节点使用一个函数为每个邻居创建一条信息。它可以只包含自己的特征,也可以同时考虑相邻节点的特征和边特征
  • 聚合 (Aggregate): 每个节点都会使用置换不变函数(如求和)聚合来自邻居的信息
  • 更新 (`Update):每个节点使用一个函数来结合当前特征和聚合的消息来更新自己的特征。例如,通过引入了自循环来聚合第 i i i 个节点的当前特征。

上述过程可以用以下等式进行概括:
h i ′ = γ ( h i , ⊕ j ∈ N i ϕ ( h i , h j , e j , i ) ) h_i'=\gamma(h_i,\oplus_{j\in \mathcal N_i}\phi(h_i,h_j,e_{j,i})) hi′=γ(hi,⊕j∈Niϕ(hi,hj,ej,i))

其中, h i h_i hi 是节点 i i i 的节点嵌入, e j , i e_{j,i} ej,i 是 j → i j→i j→i 链接的边嵌入, ϕ \phi ϕ 是信息函数, ⊕ \oplus ⊕ 是聚合函数, γ γ γ 是更新函数。MPNN 框架如下所示:

2. 实现 MPNN 框架

使用 PyTorch Geometric 可以通过 MessagePassing 类直接实现消息传递神经网络 (Message Passing Neural Network, MPNN) 框架。接下来,使用 MessagePassing 类实现 GCN 层。

(1) 首先,导入所需的库:

python 复制代码
import numpy as np

import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

(2) 声明继承自 MessagePassingGCN 类:

python 复制代码
class GCNConv(MessagePassing):

GCN 类需要两个参数------输入维度和输出(隐藏)维度进行初始化。MessagePassing 以 "add" 聚合初始化,在 __init__() 方法中定义一个未使用偏置的线性层:

python 复制代码
    def __init__(self, dim_in, dim_h):
        super().__init__(aggr='add')
        self.linear = Linear(dim_in, dim_h, bias=False)

定义 forward() 方法,首先在邻接矩阵中添加自循环,以考虑目标节点:

python 复制代码
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

然后,使用线性层进行线性变换:

python 复制代码
        x = self.linear(x)

计算归一化因子 − 1 d e g ( i ) ⋅ d e g ( j ) -\frac 1 {\sqrt{deg(i)}\cdot \sqrt{deg(j)}} −deg(i) ⋅deg(j) 1.

python 复制代码
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

使用更新后的 edge_index (包括自循环)和存储在 norm 张量中的归一化因子调用 propagate() 方法,该方法调用 message()aggregate()update()。我们不需要重新定义 update(),因为它已经包含了自循环,在 __init__ 方法中,已经用 aggr='add' 指定了 aggregate() 函数:

python 复制代码
        out = self.propagate(edge_index, x=x, norm=norm)
        return out

重新定义 message() 函数,用 norm 对相邻节点特征 x 进行归一化处理:

python 复制代码
    def message(self, x, norm):
        return norm.view(-1, 1) * x

(3) 创建 GCN 实例:

python 复制代码
conv = GCNConv(16, 32)

以上代码展示了如何在 PyTorch Geometric 中使用 MPNN 创建 GCN 层,我们也可以尝试使用 MPNN 框架实现 GINGAT 层。

小结

消息传递神经网络 (Message Passing Neural Network, MPNN) 通过消息传播机制对多种图神经网络 (Graph Neural Networks, GNN) 模型做出了一般性总结。在本节中,我们介绍了 MPNN 框架,该框架通过信息、聚合和更新三个步骤统一了 GNN 层,同时介绍了如何在 PyTorch Geometric 中使用 MPNN 创建图卷积网络 (Graph Convolutional Network, GCN) 层。

系列链接

图神经网络实战(1)------图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)------图论基础
图神经网络实战(3)------基于DeepWalk创建节点表示
图神经网络实战(4)------基于Node2Vec改进嵌入质量
图神经网络实战(5)------常用图数据集
图神经网络实战(6)------使用PyTorch构建图神经网络
图神经网络实战(7)------图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)------图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)------GraphSAGE详解与实现
图神经网络实战(10)------归纳学习
图神经网络实战(11)------Weisfeiler-Leman测试
图神经网络实战(12)------图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)------经典链接预测算法
图神经网络实战(14)------基于节点嵌入预测链接
图神经网络实战(15)------SEAL链接预测算法
图神经网络实战(16)------经典图生成算法
图神经网络实战(17)------深度图生成模型

相关推荐
摸鱼仙人~1 小时前
Attention Free Transformer (AFT)-2020论文笔记
论文阅读·深度学习·transformer
python算法(魔法师版)2 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
kakaZhui2 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20253 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
davenian6 小时前
DeepSeek-R1 论文. Reinforcement Learning 通过强化学习激励大型语言模型的推理能力
人工智能·深度学习·语言模型·deepseek
CM莫问6 小时前
什么是门控循环单元?
人工智能·pytorch·python·rnn·深度学习·算法·gru
饮马长城窟6 小时前
Paddle和pytorch不可以同时引用
人工智能·pytorch·paddle
纠结哥_Shrek7 小时前
pytorch生成对抗网络
人工智能·pytorch·生成对抗网络
纠结哥_Shrek8 小时前
pytorch实现文本摘要
人工智能·pytorch·python
sirius123451239 小时前
自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数
人工智能·pytorch·逻辑回归