图神经网络并在 TensorFlow 中实现

asokraju.medium.com

一、说明

本文将引导您了解图神经网络 (GNN) 并使用 TensorFlow 实现该网络。在后续的 文章中,我们讨论 GNN 的不同变体及其实现。这是一个分步计划:

  1. 图神经网络 (GNN) 的使用:我们首先讨论 GNN 是什么、它们如何工作以及它们的使用地点。
  2. 理解图:在深入研究 GNN 之前,了解图的基础知识非常重要,包括节点、边、邻接矩阵和图表示。
  3. 理解图神经网络:我们还将简要介绍神经网络的基础知识,因为 GNN 是神经网络的一种。
  4. 图神经网络 (GNN) 的变体
  5. 使用 TensorFlow 实现 GNN:最后,我们将介绍使用 TensorFlow 实现简单 GNN 的过程。

二、图神经网络 (GNN) 的使用

图神经网络 (GNN) 是一种神经网络,旨在对图数据结构执行机器学习任务。它们对于数据以图形表示的任务特别有用,例如社交网络、分子结构和推荐系统。

GNN 的工作原理是将信息从节点传播到其邻居。图中的节点根据其邻居的状态进行更新,并且此过程会重复多次迭代。然后可以使用节点的最终状态进行预测。

例如,在社交网络中,GNN 可用于根据用户朋友的兴趣来预测用户的兴趣。 GNN 将从每个用户的一些初始表示开始,然后根据其朋友的表示更新每个用户的表示。经过几次迭代后,每个用户的最终表示不仅会捕获他们自己的兴趣,还会捕获他们的朋友、朋友的朋友等的兴趣。

三、理解图表:

是一种对对象之间的关系进行建模的数学结构。它由节点 (也称为顶点)和组成。节点代表对象,边代表这些对象之间的关系。

例如,在社交网络中,每个人可以由一个节点表示,每个友谊可以由连接两个节点的边表示。

有两种主要类型的图表:

  1. 无向图:在无向图中,边没有方向。也就是说,如果存在从节点 A 到节点 B 的边,则也存在从节点 B 到节点 A 的边。 Facebook 友谊就是这样的一个示例:如果人 A 是人 B 的朋友,那么人 B 也是人与 A 是朋友。
  2. 有向图:在有向图中,边确实有方向。也就是说,如果从节点 A 到节点 B 存在一条边,并不一定意味着从节点 B 到节点 A 也存在一条边。 Twitter 关注就是一个例子:如果 A 关注了 B,那么它就会关注 B。并不意味着B跟随A。

图可以用多种方式表示,但最常见的方式之一是通过邻接矩阵。邻接矩阵是一个方阵,其中第 i 行第 j 列中的条目等于节点 i 和 j 之间的边数。对于无向图,邻接矩阵是对称的。

另一种常见的表示形式是边列表,其中每条边由一对节点表示。

了解图的这些基础知识对于理解图神经网络的工作原理至关重要,因为它们直接在图结构上运行。

四、理解图神经网络

GNN 是一种神经网络,旨在对图数据结构执行机器学习任务。它们对于数据以图形表示的任务特别有用,例如社交网络、分子结构和推荐系统。

GNN 背后的关键思想是捕获图中连接之间的依赖关系。他们通过聚合相邻节点的特征来为每个节点生成嵌入来实现这一点。然后,这些嵌入可用于执行各种任务,例如节点分类、链接预测和图分类。

以下是 GNN 工作原理的更详细的分步过程:

  1. 节点特征初始化:图中的每个节点都使用特征向量进行初始化。这可能是节点标签的单热编码、特定于节点的一些实值向量,甚至是零向量。
  2. 特征聚合:每个节点聚合其邻近节点的特征向量以更新自己的特征向量。这通常是使用一个函数来完成的,该函数接收节点及其邻居的特征向量并输出一个新的特征向量。该函数可以是简单平均值、加权和或更复杂的函数。
  3. 特征变换:然后对聚合的特征向量进行变换,通常使用线性变换,然后使用非线性激活函数。这与传统神经网络层中发生的情况类似。
  4. 重复步骤 2 和 3:重复步骤 2 和 3 一定次数的迭代。在每次迭代中,节点都会聚合并转换来自越来越大邻域的特征。
  5. 读出:最终迭代后,使用读出函数聚合图中所有节点的特征向量以产生图级输出。

GNN 的优点在于它们可以处理不同大小和形状的图,并且可以捕获图的局部和全局结构。

五、使用 TensorFlow 实现 GNN

有几个构建在 TensorFlow 之上的库提供了各种类型的 GNN 的实现,例如 Graph Nets 和 Spektral。我们可以使用这些库之一来简化实现过程。

首先,您需要安装 Spektral 库。您可以使用 pip 执行此操作:

ba 复制代码
pip install spektral

安装 Spektral 后,您可以首先导入必要的库:

ba 复制代码
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout
from spektral.layers import GCNConv, global_sum_pool
from spektral.data import DisjointLoader, Dataset
from spektral.datasets import TUDataset

在此示例中,我们将使用 TUDataset,它是用于图分类的基准数据集的集合。

接下来,让我们加载数据集:

ba 复制代码
dataset = TUDataset('PROTEINS')

这将下载 PROTEINS 数据集,这是蛋白质结构的图形分类数据集。

  1. 读出:在最后一层之后,使用读出函数聚合图中所有节点的特征向量以产生图级输出。

现在,让我们看看如何使用 TensorFlow 中的 Spektral 库实现一个简单的 GraphSAGE 模型:

ba 复制代码
import spektral
from spektral.layers import GraphSageConv
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dropout, Dense

# Define the model
class GraphSageModel(Model):
    def __init__(self, n_hidden, n_labels):
        super().__init__()
        self.sage_conv1 = GraphSageConv(n_hidden)
        self.sage_conv2 = GraphSageConv(n_labels)
        self.dropout = Dropout(0.5)
        self.dense = Dense(n_labels, 'softmax')

    def call(self, inputs, training=False):
        x, a = inputs
        x = self.dropout(x, training=training)
        x = self.sage_conv1([x, a])
        x = self.sage_conv2([x, a])
        return self.dense(x)

# Instantiate the model
model = GraphSageModel(n_hidden=64, n_labels=dataset.n_labels)

该模型将由其节点特征表示的图作为输入x、邻接矩阵a和批次索引i.该模型首先对节点特征应用 dropout,然后应用两个图卷积层,将节点特征池化为图级表示,最后应用密集层来预测每个图的类别。

接下来,让我们编译并训练我们的模型:

ba 复制代码
model = GNN(n_hidden=64, n_labels=dataset.n_labels)
model.compile('adam', 'categorical_crossentropy', ['acc'])
loader = DisjointLoader(dataset, batch_size=32, epochs=10)
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch)

什么是global_sum_pool represent?**

在图神经网络(GNN)的背景下,池化是一种用于将整个图的信息聚合成单个向量表示的技术。这对于图级预测任务特别有用,我们想要对整个图(而不是单个节点或边)进行预测。

global_sum_pool是 Spektral 库提供的一种此类池化操作。顾名思义,它只是将图中所有节点的特征向量相加以生成单个向量。此操作对于图中节点的顺序是不变的,这对于许多基于图的任务来说是一个重要属性。

值得注意的是,求和池化是一种非常简单的池化操作,GNN 中还可以使用许多其他更复杂的池化操作,例如均值池化、最大池化以及更复杂的方法,例如图注意力池化和图同构池化。池化操作的选择会对 GNN 的性能产生重大影响,而最佳选择通常取决于具体的任务和数据。

i 表示 x = self.pool(x, i) 是什么?

函数调用中的i表示每个节点的批次索引。global_sum_pool(x, i)

当您在批量设置中处理图形数据(即单个批次中的多个图形)时,您需要一种方法来指示哪些节点属于哪些图形。这是因为与图像或文本数据不同,批次中的图可以具有不同的大小(即不同数量的节点和边),因此不能简单地将它们堆叠在单个张量中。

批次索引i 是一个向量,它将每个节点分配给批次中的特定图。例如,如果批次中有两个图表,第一个有 3 个节点,第二个有 2 个节点,则批次索引 i 将为 [0, 0, 0, 1, 1]。这表明前三个节点属于第一个图,最后两个节点属于第二个图。

在后续文章中,我们讨论 GNN 的不同变体及其实现。

相关推荐
jay神几秒前
基于YOLOv8的传送带异物检测系统
人工智能·python·深度学习·yolo·可视化·计算机毕业设计
强风7942 分钟前
OpenCV基础入门
人工智能·opencv·计算机视觉
小超同学你好3 分钟前
Langgragh 19. Skills 4. SkillToolset 式设计 —— 工具化按需加载的 Skills(含代码示例)
人工智能·语言模型·langchain
人工智能培训4 分钟前
如何衔接知识图谱与图神经网络
人工智能·神经网络·知识图谱
火星资讯7 分钟前
Zenlayer Fabric Port 新加坡首发:城域免费,全球畅连
人工智能·科技
新缸中之脑7 分钟前
20个Nano Banana 2创意工作流
人工智能
智驱力人工智能9 分钟前
馆藏文物预防性保护依赖的图像分析技术 文物损害检测 文物破损检测 文物损害识别误报率优化方案 文物安全巡查AI系统案例 智慧文保AI监测
人工智能·算法·安全·yolo·边缘计算
tobias.b11 分钟前
机器学习 超清晰通俗讲解 + 核心算法全解(深度+易懂版)
人工智能·算法·机器学习
code_pgf11 分钟前
Jetson 上 OpenClaw + Ollama + llama.cpp 的联动配置模板部署大模型
服务器·数据库·人工智能·llama