图神经网络实战(18)——消息传播神经网络

图神经网络实战(18)------消息传播神经网络

    • [0. 前言](#0. 前言)
    • [1. 消息传播神经网络](#1. 消息传播神经网络)
    • [2. 实现 MPNN 框架](#2. 实现 MPNN 框架)
    • 小结
    • 系列链接

0. 前言

我们已经学习了多种图神经网络 (Graph Neural Networks, GNN)变体,包括图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)GraphSAGE 等。在本节中,我们将对这些变体 GNN 结构进行一般性总结,即 GNN 的通用框架,也是 GNN 架构的通用范式。研究 GNN 通用框架能够帮助我们更加清晰的对比各类 GNN 模型,同时也为 GNN 模型的扩展提供了灵活性。

1. 消息传播神经网络

我们已经学习了用于聚合和组合不同节点特征的不同函数,最简单的图神经网络 (Graph Neural Networks, GNN)层将相邻节点(包括目标节点本身)的特征与权重矩阵的线性组合求和,然后,使用求和输出取代之前的目标节点嵌入。

节点级算子可以表示为如下形式:
h i ′ = ∑ j ∈ N i h j W T h_i'=\sum_{j\in \mathcal N_i}h_jW^T hi′=j∈Ni∑hjWT

其中, N i \mathcal N_i Ni 是节点 i i i (包括节点本身)的邻居节点集, h i h_i hi 是第 i i i 个节点的嵌入, W W W 是权重矩阵。
图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)层分别为节点特征添加了固定权重和动态权重,但于 GNN 具有相同的思想。包括 GraphSAGE 的邻接采样和图同构网络 (Graph Isomorphism Network, GIN)的全局求和池化也没有改变 GNN 层的主要思想。综观这些变体,我们可以将 GNN 层归纳为称为消息传递神经网络 (Message Passing Neural Network, MPNNMP-GNN) 的通用框架。该框架于 2017 年由 Gilmer 等人提出,由三个主要操作组成:

  • 消息 (Message): 每个节点使用一个函数为每个邻居创建一条信息。它可以只包含自己的特征,也可以同时考虑相邻节点的特征和边特征
  • 聚合 (Aggregate): 每个节点都会使用置换不变函数(如求和)聚合来自邻居的信息
  • 更新 (`Update):每个节点使用一个函数来结合当前特征和聚合的消息来更新自己的特征。例如,通过引入了自循环来聚合第 i i i 个节点的当前特征。

上述过程可以用以下等式进行概括:
h i ′ = γ ( h i , ⊕ j ∈ N i ϕ ( h i , h j , e j , i ) ) h_i'=\gamma(h_i,\oplus_{j\in \mathcal N_i}\phi(h_i,h_j,e_{j,i})) hi′=γ(hi,⊕j∈Niϕ(hi,hj,ej,i))

其中, h i h_i hi 是节点 i i i 的节点嵌入, e j , i e_{j,i} ej,i 是 j → i j→i j→i 链接的边嵌入, ϕ \phi ϕ 是信息函数, ⊕ \oplus ⊕ 是聚合函数, γ γ γ 是更新函数。MPNN 框架如下所示:

2. 实现 MPNN 框架

使用 PyTorch Geometric 可以通过 MessagePassing 类直接实现消息传递神经网络 (Message Passing Neural Network, MPNN) 框架。接下来,使用 MessagePassing 类实现 GCN 层。

(1) 首先,导入所需的库:

python 复制代码
import numpy as np

import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

(2) 声明继承自 MessagePassingGCN 类:

python 复制代码
class GCNConv(MessagePassing):

GCN 类需要两个参数------输入维度和输出(隐藏)维度进行初始化。MessagePassing 以 "add" 聚合初始化,在 __init__() 方法中定义一个未使用偏置的线性层:

python 复制代码
    def __init__(self, dim_in, dim_h):
        super().__init__(aggr='add')
        self.linear = Linear(dim_in, dim_h, bias=False)

定义 forward() 方法,首先在邻接矩阵中添加自循环,以考虑目标节点:

python 复制代码
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

然后,使用线性层进行线性变换:

python 复制代码
        x = self.linear(x)

计算归一化因子 − 1 d e g ( i ) ⋅ d e g ( j ) -\frac 1 {\sqrt{deg(i)}\cdot \sqrt{deg(j)}} −deg(i) ⋅deg(j) 1.

python 复制代码
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

使用更新后的 edge_index (包括自循环)和存储在 norm 张量中的归一化因子调用 propagate() 方法,该方法调用 message()aggregate()update()。我们不需要重新定义 update(),因为它已经包含了自循环,在 __init__ 方法中,已经用 aggr='add' 指定了 aggregate() 函数:

python 复制代码
        out = self.propagate(edge_index, x=x, norm=norm)
        return out

重新定义 message() 函数,用 norm 对相邻节点特征 x 进行归一化处理:

python 复制代码
    def message(self, x, norm):
        return norm.view(-1, 1) * x

(3) 创建 GCN 实例:

python 复制代码
conv = GCNConv(16, 32)

以上代码展示了如何在 PyTorch Geometric 中使用 MPNN 创建 GCN 层,我们也可以尝试使用 MPNN 框架实现 GINGAT 层。

小结

消息传递神经网络 (Message Passing Neural Network, MPNN) 通过消息传播机制对多种图神经网络 (Graph Neural Networks, GNN) 模型做出了一般性总结。在本节中,我们介绍了 MPNN 框架,该框架通过信息、聚合和更新三个步骤统一了 GNN 层,同时介绍了如何在 PyTorch Geometric 中使用 MPNN 创建图卷积网络 (Graph Convolutional Network, GCN) 层。

系列链接

图神经网络实战(1)------图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)------图论基础
图神经网络实战(3)------基于DeepWalk创建节点表示
图神经网络实战(4)------基于Node2Vec改进嵌入质量
图神经网络实战(5)------常用图数据集
图神经网络实战(6)------使用PyTorch构建图神经网络
图神经网络实战(7)------图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)------图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)------GraphSAGE详解与实现
图神经网络实战(10)------归纳学习
图神经网络实战(11)------Weisfeiler-Leman测试
图神经网络实战(12)------图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)------经典链接预测算法
图神经网络实战(14)------基于节点嵌入预测链接
图神经网络实战(15)------SEAL链接预测算法
图神经网络实战(16)------经典图生成算法
图神经网络实战(17)------深度图生成模型

相关推荐
棒棒的皮皮26 分钟前
【深度学习】YOLO-Python基础认知与算法演进
python·深度学习·yolo·计算机视觉
人工智能培训2 小时前
10分钟了解向量数据库(1)
人工智能·深度学习·算法·机器学习·大模型·智能体搭建
不吃香菜的鱼2 小时前
PyTorch-CUDA-v2.9镜像自动混合精度训练配置指南
pytorch·cuda·自动混合精度
新职语2 小时前
打造个人AI实验室:低成本使用PyTorch-CUDA-v2.8云实例
pytorch·cuda·云实例
小程故事多_802 小时前
从零吃透PyTorch,最易懂的入门全指南
人工智能·pytorch·python
大叔and小萝莉2 小时前
PyTorch-v2.8新特性解析:性能提升背后的秘密
pytorch· torch.compile· 性能优化
lifetime‵(+﹏+)′2 小时前
5060显卡Windows配置Anaconda中的CUDA及Pytorch
人工智能·pytorch·windows
老鱼说AI2 小时前
万字长文警告!一次性搞定GAN(生成对抗网络):从浅入深原理级精析 + PyTorch代码逐行讲解实现
人工智能·深度学习·神经网络·生成对抗网络·计算机视觉·ai作画·超分辨率重建
Kingston Chang2 小时前
利用PyTorch-CUDA镜像快速复现顶会论文实验结果
pytorch·镜像·cuda
START_GAME2 小时前
深度学习环境配置:PyTorch、CUDA和Python版本选择
人工智能·pytorch·深度学习