图神经网络实战(18)——消息传播神经网络

图神经网络实战(18)------消息传播神经网络

    • [0. 前言](#0. 前言)
    • [1. 消息传播神经网络](#1. 消息传播神经网络)
    • [2. 实现 MPNN 框架](#2. 实现 MPNN 框架)
    • 小结
    • 系列链接

0. 前言

我们已经学习了多种图神经网络 (Graph Neural Networks, GNN)变体,包括图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)GraphSAGE 等。在本节中,我们将对这些变体 GNN 结构进行一般性总结,即 GNN 的通用框架,也是 GNN 架构的通用范式。研究 GNN 通用框架能够帮助我们更加清晰的对比各类 GNN 模型,同时也为 GNN 模型的扩展提供了灵活性。

1. 消息传播神经网络

我们已经学习了用于聚合和组合不同节点特征的不同函数,最简单的图神经网络 (Graph Neural Networks, GNN)层将相邻节点(包括目标节点本身)的特征与权重矩阵的线性组合求和,然后,使用求和输出取代之前的目标节点嵌入。

节点级算子可以表示为如下形式:
h i ′ = ∑ j ∈ N i h j W T h_i'=\sum_{j\in \mathcal N_i}h_jW^T hi′=j∈Ni∑hjWT

其中, N i \mathcal N_i Ni 是节点 i i i (包括节点本身)的邻居节点集, h i h_i hi 是第 i i i 个节点的嵌入, W W W 是权重矩阵。
图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)层分别为节点特征添加了固定权重和动态权重,但于 GNN 具有相同的思想。包括 GraphSAGE 的邻接采样和图同构网络 (Graph Isomorphism Network, GIN)的全局求和池化也没有改变 GNN 层的主要思想。综观这些变体,我们可以将 GNN 层归纳为称为消息传递神经网络 (Message Passing Neural Network, MPNNMP-GNN) 的通用框架。该框架于 2017 年由 Gilmer 等人提出,由三个主要操作组成:

  • 消息 (Message): 每个节点使用一个函数为每个邻居创建一条信息。它可以只包含自己的特征,也可以同时考虑相邻节点的特征和边特征
  • 聚合 (Aggregate): 每个节点都会使用置换不变函数(如求和)聚合来自邻居的信息
  • 更新 (`Update):每个节点使用一个函数来结合当前特征和聚合的消息来更新自己的特征。例如,通过引入了自循环来聚合第 i i i 个节点的当前特征。

上述过程可以用以下等式进行概括:
h i ′ = γ ( h i , ⊕ j ∈ N i ϕ ( h i , h j , e j , i ) ) h_i'=\gamma(h_i,\oplus_{j\in \mathcal N_i}\phi(h_i,h_j,e_{j,i})) hi′=γ(hi,⊕j∈Niϕ(hi,hj,ej,i))

其中, h i h_i hi 是节点 i i i 的节点嵌入, e j , i e_{j,i} ej,i 是 j → i j→i j→i 链接的边嵌入, ϕ \phi ϕ 是信息函数, ⊕ \oplus ⊕ 是聚合函数, γ γ γ 是更新函数。MPNN 框架如下所示:

2. 实现 MPNN 框架

使用 PyTorch Geometric 可以通过 MessagePassing 类直接实现消息传递神经网络 (Message Passing Neural Network, MPNN) 框架。接下来,使用 MessagePassing 类实现 GCN 层。

(1) 首先,导入所需的库:

python 复制代码
import numpy as np

import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

(2) 声明继承自 MessagePassingGCN 类:

python 复制代码
class GCNConv(MessagePassing):

GCN 类需要两个参数------输入维度和输出(隐藏)维度进行初始化。MessagePassing 以 "add" 聚合初始化,在 __init__() 方法中定义一个未使用偏置的线性层:

python 复制代码
    def __init__(self, dim_in, dim_h):
        super().__init__(aggr='add')
        self.linear = Linear(dim_in, dim_h, bias=False)

定义 forward() 方法,首先在邻接矩阵中添加自循环,以考虑目标节点:

python 复制代码
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

然后,使用线性层进行线性变换:

python 复制代码
        x = self.linear(x)

计算归一化因子 − 1 d e g ( i ) ⋅ d e g ( j ) -\frac 1 {\sqrt{deg(i)}\cdot \sqrt{deg(j)}} −deg(i) ⋅deg(j) 1.

python 复制代码
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

使用更新后的 edge_index (包括自循环)和存储在 norm 张量中的归一化因子调用 propagate() 方法,该方法调用 message()aggregate()update()。我们不需要重新定义 update(),因为它已经包含了自循环,在 __init__ 方法中,已经用 aggr='add' 指定了 aggregate() 函数:

python 复制代码
        out = self.propagate(edge_index, x=x, norm=norm)
        return out

重新定义 message() 函数,用 norm 对相邻节点特征 x 进行归一化处理:

python 复制代码
    def message(self, x, norm):
        return norm.view(-1, 1) * x

(3) 创建 GCN 实例:

python 复制代码
conv = GCNConv(16, 32)

以上代码展示了如何在 PyTorch Geometric 中使用 MPNN 创建 GCN 层,我们也可以尝试使用 MPNN 框架实现 GINGAT 层。

小结

消息传递神经网络 (Message Passing Neural Network, MPNN) 通过消息传播机制对多种图神经网络 (Graph Neural Networks, GNN) 模型做出了一般性总结。在本节中,我们介绍了 MPNN 框架,该框架通过信息、聚合和更新三个步骤统一了 GNN 层,同时介绍了如何在 PyTorch Geometric 中使用 MPNN 创建图卷积网络 (Graph Convolutional Network, GCN) 层。

系列链接

图神经网络实战(1)------图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)------图论基础
图神经网络实战(3)------基于DeepWalk创建节点表示
图神经网络实战(4)------基于Node2Vec改进嵌入质量
图神经网络实战(5)------常用图数据集
图神经网络实战(6)------使用PyTorch构建图神经网络
图神经网络实战(7)------图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)------图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)------GraphSAGE详解与实现
图神经网络实战(10)------归纳学习
图神经网络实战(11)------Weisfeiler-Leman测试
图神经网络实战(12)------图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)------经典链接预测算法
图神经网络实战(14)------基于节点嵌入预测链接
图神经网络实战(15)------SEAL链接预测算法
图神经网络实战(16)------经典图生成算法
图神经网络实战(17)------深度图生成模型

相关推荐
靴子学长44 分钟前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
海棠AI实验室2 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
四口鲸鱼爱吃盐3 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
leaf_leaves_leaf3 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零14 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
四口鲸鱼爱吃盐4 小时前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗4 小时前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
paixiaoxin7 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
weixin_515202497 小时前
第R3周:RNN-心脏病预测
人工智能·rnn·深度学习
吕小明么9 小时前
OpenAI o3 “震撼” 发布后回归技术本身的审视与进一步思考
人工智能·深度学习·算法·aigc·agi