图神经网络实战(18)——消息传播神经网络

图神经网络实战(18)------消息传播神经网络

    • [0. 前言](#0. 前言)
    • [1. 消息传播神经网络](#1. 消息传播神经网络)
    • [2. 实现 MPNN 框架](#2. 实现 MPNN 框架)
    • 小结
    • 系列链接

0. 前言

我们已经学习了多种图神经网络 (Graph Neural Networks, GNN)变体,包括图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)GraphSAGE 等。在本节中,我们将对这些变体 GNN 结构进行一般性总结,即 GNN 的通用框架,也是 GNN 架构的通用范式。研究 GNN 通用框架能够帮助我们更加清晰的对比各类 GNN 模型,同时也为 GNN 模型的扩展提供了灵活性。

1. 消息传播神经网络

我们已经学习了用于聚合和组合不同节点特征的不同函数,最简单的图神经网络 (Graph Neural Networks, GNN)层将相邻节点(包括目标节点本身)的特征与权重矩阵的线性组合求和,然后,使用求和输出取代之前的目标节点嵌入。

节点级算子可以表示为如下形式:
h i ′ = ∑ j ∈ N i h j W T h_i'=\sum_{j\in \mathcal N_i}h_jW^T hi′=j∈Ni∑hjWT

其中, N i \mathcal N_i Ni 是节点 i i i (包括节点本身)的邻居节点集, h i h_i hi 是第 i i i 个节点的嵌入, W W W 是权重矩阵。
图卷积网络 (Graph Convolutional Network, GCN)图注意力网络 (Graph Attention Networks,GAT)层分别为节点特征添加了固定权重和动态权重,但于 GNN 具有相同的思想。包括 GraphSAGE 的邻接采样和图同构网络 (Graph Isomorphism Network, GIN)的全局求和池化也没有改变 GNN 层的主要思想。综观这些变体,我们可以将 GNN 层归纳为称为消息传递神经网络 (Message Passing Neural Network, MPNNMP-GNN) 的通用框架。该框架于 2017 年由 Gilmer 等人提出,由三个主要操作组成:

  • 消息 (Message): 每个节点使用一个函数为每个邻居创建一条信息。它可以只包含自己的特征,也可以同时考虑相邻节点的特征和边特征
  • 聚合 (Aggregate): 每个节点都会使用置换不变函数(如求和)聚合来自邻居的信息
  • 更新 (`Update):每个节点使用一个函数来结合当前特征和聚合的消息来更新自己的特征。例如,通过引入了自循环来聚合第 i i i 个节点的当前特征。

上述过程可以用以下等式进行概括:
h i ′ = γ ( h i , ⊕ j ∈ N i ϕ ( h i , h j , e j , i ) ) h_i'=\gamma(h_i,\oplus_{j\in \mathcal N_i}\phi(h_i,h_j,e_{j,i})) hi′=γ(hi,⊕j∈Niϕ(hi,hj,ej,i))

其中, h i h_i hi 是节点 i i i 的节点嵌入, e j , i e_{j,i} ej,i 是 j → i j→i j→i 链接的边嵌入, ϕ \phi ϕ 是信息函数, ⊕ \oplus ⊕ 是聚合函数, γ γ γ 是更新函数。MPNN 框架如下所示:

2. 实现 MPNN 框架

使用 PyTorch Geometric 可以通过 MessagePassing 类直接实现消息传递神经网络 (Message Passing Neural Network, MPNN) 框架。接下来,使用 MessagePassing 类实现 GCN 层。

(1) 首先,导入所需的库:

python 复制代码
import numpy as np

import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

(2) 声明继承自 MessagePassingGCN 类:

python 复制代码
class GCNConv(MessagePassing):

GCN 类需要两个参数------输入维度和输出(隐藏)维度进行初始化。MessagePassing 以 "add" 聚合初始化,在 __init__() 方法中定义一个未使用偏置的线性层:

python 复制代码
    def __init__(self, dim_in, dim_h):
        super().__init__(aggr='add')
        self.linear = Linear(dim_in, dim_h, bias=False)

定义 forward() 方法,首先在邻接矩阵中添加自循环,以考虑目标节点:

python 复制代码
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

然后,使用线性层进行线性变换:

python 复制代码
        x = self.linear(x)

计算归一化因子 − 1 d e g ( i ) ⋅ d e g ( j ) -\frac 1 {\sqrt{deg(i)}\cdot \sqrt{deg(j)}} −deg(i) ⋅deg(j) 1.

python 复制代码
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

使用更新后的 edge_index (包括自循环)和存储在 norm 张量中的归一化因子调用 propagate() 方法,该方法调用 message()aggregate()update()。我们不需要重新定义 update(),因为它已经包含了自循环,在 __init__ 方法中,已经用 aggr='add' 指定了 aggregate() 函数:

python 复制代码
        out = self.propagate(edge_index, x=x, norm=norm)
        return out

重新定义 message() 函数,用 norm 对相邻节点特征 x 进行归一化处理:

python 复制代码
    def message(self, x, norm):
        return norm.view(-1, 1) * x

(3) 创建 GCN 实例:

python 复制代码
conv = GCNConv(16, 32)

以上代码展示了如何在 PyTorch Geometric 中使用 MPNN 创建 GCN 层,我们也可以尝试使用 MPNN 框架实现 GINGAT 层。

小结

消息传递神经网络 (Message Passing Neural Network, MPNN) 通过消息传播机制对多种图神经网络 (Graph Neural Networks, GNN) 模型做出了一般性总结。在本节中,我们介绍了 MPNN 框架,该框架通过信息、聚合和更新三个步骤统一了 GNN 层,同时介绍了如何在 PyTorch Geometric 中使用 MPNN 创建图卷积网络 (Graph Convolutional Network, GCN) 层。

系列链接

图神经网络实战(1)------图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)------图论基础
图神经网络实战(3)------基于DeepWalk创建节点表示
图神经网络实战(4)------基于Node2Vec改进嵌入质量
图神经网络实战(5)------常用图数据集
图神经网络实战(6)------使用PyTorch构建图神经网络
图神经网络实战(7)------图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)------图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)------GraphSAGE详解与实现
图神经网络实战(10)------归纳学习
图神经网络实战(11)------Weisfeiler-Leman测试
图神经网络实战(12)------图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)------经典链接预测算法
图神经网络实战(14)------基于节点嵌入预测链接
图神经网络实战(15)------SEAL链接预测算法
图神经网络实战(16)------经典图生成算法
图神经网络实战(17)------深度图生成模型

相关推荐
AI大模型知识分享2 小时前
Prompt最佳实践|如何用参考文本让ChatGPT答案更精准?
人工智能·深度学习·机器学习·chatgpt·prompt·gpt-3
小言从不摸鱼4 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
artificiali8 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python
酱香编程,风雨兼程8 小时前
深度学习——基础知识
人工智能·深度学习
#include<菜鸡>9 小时前
动手学深度学习(pytorch土堆)-04torchvision中数据集的使用
人工智能·pytorch·深度学习
拓端研究室TRL9 小时前
TensorFlow深度学习框架改进K-means聚类、SOM自组织映射算法及上海招生政策影响分析研究...
深度学习·算法·tensorflow·kmeans·聚类
chnyi6_ya10 小时前
深度学习的笔记
服务器·人工智能·pytorch
i嗑盐の小F11 小时前
【IEEE出版,高录用 | EI快检索】第二届人工智能与自动化控制国际学术会议(AIAC 2024,10月25-27)
图像处理·人工智能·深度学习·算法·自然语言处理·自动化
卡卡大怪兽11 小时前
深度学习:数据集处理简单记录
人工智能·深度学习
菜就多练_082811 小时前
《深度学习》深度学习 框架、流程解析、动态展示及推导
人工智能·深度学习