TensorFlow深度学习实战——节点分类

TensorFlow深度学习实战------节点分类

    • [0. 前言](#0. 前言)
    • [1. 数据分析](#1. 数据分析)
    • [2. 构建节点分类模型](#2. 构建节点分类模型)
    • [3. 模型训练与评估](#3. 模型训练与评估)
    • 相关链接

0. 前言

节点分类是图数据领域的一个常见任务。在这一任务中,模型的训练目标是预测节点的类别。非图分类方法仅使用节点特征向量实现节点分类,早期的图神经网络 (Graph Neural Network, GNN)方法(如 DeepWalknode2vec )仅使用邻接矩阵(连接信息)实现节点分类,而 GNN 能够同时利用节点特征向量和连接信息进行节点分类。

1. 数据分析

本质上,节点分类的思路是对图中的所有节点应用一个或多个图卷积,将节点的特征向量投影到相应的输出类别向量中,以预测节点的类别。本节,将使用 CORA 数据集训练节点分类模型,CORA 数据集是一个包含 2,708 篇科学论文的集合,每篇论文可以分类为七个类别之一。这些论文以及它们之间的引用关系构成了一个包含 5,429 个链接的引文网络,每篇论文由一个大小为 1,433 的词向量描述。

(1) 首先,导入所需库:

python 复制代码
import dgl
import dgl.data
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from dgl.nn.tensorflow import GraphConv

(2) 加载 CORA 数据集:

python 复制代码
dataset = dgl.data.CoraGraphDataset()

(3) 第一次调用时,它会记录下载和提取到本地文件的过程。完成后,它会输出一些有关 CORA 数据集的统计信息。可以看到,图中有 2,708 个节点和 10,566 条边。每个节点都有一个大小为 1,433 的特征向量,节点被分类为七个类别之一,此外,有 140 个训练样本、500 个验证样本和 1,000 个测试样本:

shell 复制代码
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.

CORA 数据集是一个单一的引文图,可以通过 len(dataset) 来验证,将返回 1。这意味着模型将处理 dataset[0] 提供的图,节点特征作为键值对包含在字典 dataset[0].ndata 中,边特征则在 dataset[0].edata 中。ndata 包含键 train_maskval_masktest_mask,这些是布尔掩码,表示哪些节点属于训练、验证和测试集,还有一个 feat 键,包含图中每个节点的特征向量。

2. 构建节点分类模型

构建一个包含两个 GraphConv 层的 NodeClassifier 网络。每一层将通过聚合邻居信息计算新的节点表示。GraphConv 层是 tf.keras.layers.Layer 对象,因此可以进行堆叠。第一个 GraphConv 层将输入特征(大小为 1,433 )投影到大小为 16 的隐藏特征向量上,第二个 GraphConv 层将隐藏特征向量投影到大小为 2 的输出类别向量,从中获取类别:

python 复制代码
"""Defining a Graph Convolutional Network (GCN)"""
class NodeClassifier(tf.keras.Model):
    def __init__(self, g, in_feats, h_feats, num_classes):
        super(NodeClassifier, self).__init__()
        self.g = g
        self.conv1 = GraphConv(in_feats, h_feats, activation=tf.nn.relu)
        self.conv2 = GraphConv(h_feats, num_classes)

    def call(self, in_feat):
        h = self.conv1(self.g, in_feat)
        h = self.conv2(self.g, h)
        return h

g = dataset[0]
model = NodeClassifier(g, g.ndata["feat"].shape[1], 16, dataset.num_classes)

需要注意的是,GraphConv 只是构建 NodeClassifier 模型的一种图神经网络层,DGL 提供了多种图卷积层,可以用来替换 GraphConv

3. 模型训练与评估

(1)CORA 数据集上训练模型。使用 AdamW 优化器,AdamW 优化器是 Adam 优化器的变体,能够得到更好的模型泛化能力,学习率为 1e-2,权重衰减为 5e-4,训练 200epoch。同时检测是否有可用的 GPU,如果有,将图数据转移到 GPU 上。如果检测到 GPUTensorFlow 会自动将模型转移到 GPU 上:

python 复制代码
"""Training the GCN"""
device = "/cpu:0"
gpus = tf.config.list_physical_devices("GPU")
if len(gpus) > 0:
    device = gpus[0]
g = g.to(device)

(2) 定义 do_eval() 方法,根据特征计算模型在(由布尔掩码拆分的)测试数据集上的准确率:

python 复制代码
def do_eval(model, features, labels, mask):
    logits = model(features, training=False)
    logits = logits[mask]
    labels = labels[mask]
    preds = tf.math.argmax(logits, axis=1)
    acc = tf.reduce_mean(tf.cast(preds == labels, dtype=tf.float32))
    return acc.numpy().item()

(3) 最后,定义训练循环:

python 复制代码
NUM_HIDDEN = 16
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 5e-4
NUM_EPOCHS = 200

with tf.device(device):
    feats = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = feats.shape[1]
    n_classes = dataset.num_classes
    n_edges = dataset[0].number_of_edges()

    model = NodeClassifier(g, in_feats, NUM_HIDDEN, n_classes)
    loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    best_val_acc, best_test_acc = 0, 0
    history = []
    for epoch in range(NUM_EPOCHS):
        with tf.GradientTape() as tape:
            logits = model(feats)
            loss = loss_fcn(labels[train_mask], logits[train_mask])
            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
        
        val_acc = do_eval(model, feats, labels, val_mask)
        history.append((epoch + 1, loss.numpy().item(), val_acc))

        if epoch % 10 == 0:
            print("Epoch {:3d} | train loss: {:.3f} | val acc: {:.3f}".format(epoch, loss.numpy().item(), val_acc))

epochs = [epoch for epoch, _, _ in history]
losses = [loss for _, loss, _ in history]
val_accs = [val_acc for _, _, val_acc in history]

plt.subplot(2, 1, 1)
plt.plot(epochs, losses)
plt.xlabel("epochs")
plt.ylabel("train loss")

plt.subplot(2, 1, 2)
plt.plot(epochs, val_accs)
plt.xlabel("epochs")
plt.ylabel("val acc")

plt.tight_layout()
plt.show()

运行代码,训练运行过程输出如下,可以看到训练损失从 1.9 降低到 0.02,验证准确率从 0.13 提高到 0.78

(4) 评估训练好的节点分类器在测试数据集上的表现:

python 复制代码
test_acc = do_eval(model, feats, labels, test_mask)
print("Test acc: {:.3f}".format(test_acc))

打印出模型在测试数据集上的准确率如下:

shell 复制代码
Test acc: 0.779

相关链接

TensorFlow深度学习实战(1)------神经网络与模型训练过程详解
TensorFlow深度学习实战(2)------使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)------深度学习中常用激活函数详解
TensorFlow深度学习实战(4)------正则化技术详解
TensorFlow深度学习实战(5)------神经网络性能优化技术详解
TensorFlow深度学习实战(6)------回归分析详解
TensorFlow深度学习实战(7)------分类任务详解
TensorFlow深度学习实战(8)------卷积神经网络
TensorFlow深度学习实战(9)------构建VGG模型实现图像分类
TensorFlow深度学习实战(10)------迁移学习详解
TensorFlow深度学习实战(11)------风格迁移详解
TensorFlow深度学习实战(12)------词嵌入技术详解
TensorFlow深度学习实战(13)------神经嵌入详解
TensorFlow深度学习实战(14)------循环神经网络详解
TensorFlow深度学习实战(15)------编码器-解码器架构
TensorFlow深度学习实战(16)------注意力机制详解
TensorFlow深度学习实战(17)------主成分分析详解
TensorFlow深度学习实战(18)------K-means 聚类详解
TensorFlow深度学习实战(19)------受限玻尔兹曼机
TensorFlow深度学习实战(20)------自组织映射详解
TensorFlow深度学习实战(21)------Transformer架构详解与实现
TensorFlow深度学习实战(22)------从零开始实现Transformer机器翻译
TensorFlow深度学习实战(23)------自编码器详解与实现
TensorFlow深度学习实战(24)------卷积自编码器详解与实现
TensorFlow深度学习实战(25)------变分自编码器详解与实现
TensorFlow深度学习实战(26)------生成对抗网络详解与实现
TensorFlow深度学习实战(27)------CycleGAN详解与实现
TensorFlow深度学习实战(28)------扩散模型(Diffusion Model)
TensorFlow深度学习实战(29)------自监督学习(Self-Supervised Learning)
TensorFlow深度学习实战(30)------强化学习(Reinforcement learning,RL)
TensorFlow深度学习实战(31)------强化学习仿真库Gymnasium
TensorFlow深度学习实战(32)------深度Q网络(Deep Q-Network,DQN)
TensorFlow深度学习实战(33)------深度确定性策略梯度
TensorFlow深度学习实战(34)------TensorFlow Probability
TensorFlow深度学习实战(35)------概率神经网络
TensorFlow深度学习实战(36)------自动机器学习(AutoML)
TensorFlow深度学习实战(37)------深度学习的数学原理
TensorFlow深度学习实战(38)------常用深度学习库
TensorFlow深度学习实战(39)------机器学习实践指南
TensorFlow深度学习实战(40)------图神经网络(GNN)

相关推荐
Aspect of twilight3 小时前
3D Gaussian Splatting论文简要解读与可视化复现(基于gsplat)
人工智能·深度学习·gsplat
成为深度学习高手4 小时前
DGCN+informer分类预测模型
人工智能·分类·数据挖掘
Sunhen_Qiletian4 小时前
卷积神经网络搭建实战(二)——基于PyTorch框架和本地自定义图像数据集的食物分类案例(附输入图片预测功能)
pytorch·分类·cnn
ouliten4 小时前
cuda编程笔记(29)-- CUDA Graph
笔记·深度学习·cuda
hzp6664 小时前
Magnus:面向大规模机器学习工作负载的综合数据管理方法
人工智能·深度学习·机器学习·大模型·llm·数据湖·大数据存储
m0_678693335 小时前
深度学习笔记39-CGAN|生成手势图像 | 可控制生成(Pytorch)
深度学习·学习·生成对抗网络
还是大剑师兰特5 小时前
Transformer 面试题及详细答案120道(91-100)-- 理论与扩展
人工智能·深度学习·transformer·大剑师
小白狮ww5 小时前
小米开源端到端语音模型 MiMo-Audio-7B-Instruct 语音智能与音频理解达 SOTA
人工智能·深度学习·机器学习
Blossom.1185 小时前
把AI“绣”进丝绸:生成式刺绣神经网络让古装自带摄像头
人工智能·pytorch·python·深度学习·神经网络·机器学习·fpga开发