PyTorch实战(14)——图注意力网络(Graph Attention Network,GAT)

PyTorch实战(14)------图注意力网络(Graph Attention Network,GAT)

    • [0. 前言](#0. 前言)
    • [1. 图注意力网络](#1. 图注意力网络)
    • [2. 模型构建](#2. 模型构建)
    • [3. 模型训练](#3. 模型训练)
    • 小结
    • 系列链接

0. 前言

我们已经通过使用图卷积网络 (Graph Convolutional Network, GCN) 模型在节点分类任务上具备了超越了基线多层感知机 (Multilayer Perceptron, MLP) 模型的性能。在本节中,我们将通过将 GCN 模型替换为图注意力网络 (Graph Attention Network, GAT) 模型来进一步提高分类准确率,核心改进在于将邻域节点信息平均聚合机制替换为注意力机制。接下来,将基于 GCN 的解决方案重构为基于 GAT 的解决方案。

1. 图注意力网络

GCN 采用平均值聚合机制来汇聚相邻节点的特征信息,但这种方式存在固有缺陷------它默认所有邻居节点的重要性相同,而实际情况未必如此。例如,假设节点 XY 的初始特征值完全相同,且邻居集合也一致,GCN 模型会将它们归为同一类别或簇。但这不一定是正确的。为了捕捉图中此类细微差异,我们可以用注意力机制替代简单的平均值聚合,这正是图注意力网络 (Graph Attention Network, GAT) 的核心思想。

我们已经在文本数据中学习过注意力机制,在文本数据的上下文中,其本质是为句子中的不同单词分配差异化的重要性权重,从而聚焦关键部分进行后续处理。在 GAT 中,注意力机制允许模型在分类节点类型时,为不同邻居节点分配不同的权重,从而构建更复杂、更强大的模型。通过注意力机制,我们可以为每个邻居节点学习注意力系数,这些系数为模型增加了可训练参数。下图展示了 GATGCN 在聚合邻居节点特征时的对比。

如上图所示,我们引入了一组新的可训练参数------注意力向量。该向量的长度是单个节点特征向量的两倍,因为它需要与"当前节点特征和邻居节点特征拼接结果"进行点积运算。在给定网络层中,这个可学习的注意力向量会在所有<节点,邻居>对之间共享。这组额外的可训练参数使 GAT 能够针对不同特征维度和不同邻居学习差异化权重。

每个<节点,邻居>对的注意力系数计算过程如下:先将节点特征与邻居特征拼接,再与注意力向量做点积,最后通过 Leaky ReLU 激活函数。邻居节点的最终权重通过对注意力系数进行 softmax 归一化得到。这个 softmax 函数不仅增加了非线性,还确保了所有权重之和为 1。通过这种方式,GAT 显著增强了利用可训练注意力参数从邻居节点提取信息的机制。

2. 模型构建

(1) 首先,定义 GAT 模型架构及其前向传播函数:

python 复制代码
class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        self.conv1 = GATConv(dataset.num_features, hidden_channels, heads)
        self.conv2 = GATConv(hidden_channels * heads, dataset.num_classes, heads=1)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GAT(hidden_channels=16, heads=8)
print(model)

输出结果如下所示:

shell 复制代码
GAT(
  (conv1): GATConv(3703, 16, heads=8)
  (conv2): GATConv(128, 6, heads=1)
)

该模型的一个关键特性在于每个 GATConv 层的注意力头数量。在第一层中,我们设置了 8 个注意力头,这意味着将通过 8 个并行且可独立训练的注意力系数,进行 8 个并行的邻居节点特征聚合,如下图所示。

经过第一层 GAT 处理后,这 8 个注意力头生成的特征向量(每个维度为 16)将被拼接起来,最终输出 128 维特征。第二层 GATConv 仅包含 1 个注意力头,直接输出对应 6 个节点类别的结果。在该模型的前向传播过程中,我们采用了多重 dropout 策略来应对模型复杂度提升可能导致的过拟合问题(特别是多注意力头机制使得图信息传递过程更为精细)。

(2) 将数据集所有节点输入刚定义好的未训练 GAT 模型,生成 6 个节点类别的概率向量。随后对这些 6 维特征应用 t-SNE 降维至 2 维,实现所有节点在二维平面上的可视化:

python 复制代码
model.eval()

out = model(data.x, data.edge_index)
visualize(out.detach().cpu().numpy(), data.y)

输出如下所示,节点呈现随机分布状态:

3. 模型训练

(1) 在定义优化器、损失函数以及模型训练和评估流程后,训练 GAT 模型 100epoch

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-1)
criterion = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()  # Clear gradients.
      out = model(data.x, data.edge_index)  # Perform a single forward pass.
      loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
      loss.backward()  # Derive gradients.
      optimizer.step()  # Update parameters based on gradients.
      return loss

def test(mask):
      model.eval()
      out = model(data.x, data.edge_index)
      pred = out.argmax(dim=1)  # Use the class with highest probability.
      correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.
      acc = int(correct.sum()) / int(mask.sum())  # Derive ratio of correct predictions.
      return acc

for epoch in range(1, 101):
    loss = train()
    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')

输出结果如下所示:

可以看到,验证集的准确率 (72.20%) 略高于 GCN 模型的验证集准确率 (70.00%)。

(2) 我们使用训练好的 GAT 模型,评估其在测试集上的准确率:

python 复制代码
test_acc = test(data.test_mask)
print(f'Test Accuracy: {test_acc:.4f}')

输出结果如下所示:

shell 复制代码
Test Accuracy: 0.7110

相较于 GCN 模型 69.60% 的准确率,采用 GAT 模型后性能进一步提升 2.5 个百分点。这一提升源自注意力层的强大特性------该机制不仅为模型增加了更多可训练参数,还赋予图模型更高灵活性,从而可以学习不同邻居节点之间的自定义关系。

(3) 最后,我们使用训练好的 GAT 模型对所有图节点进行预测,并将 6 维节点类别概率经 t-SNE 降维至 2 维实现可视化:

python 复制代码
out = model(data.x, data.edge_index)
visualize(out.detach().cpu().numpy(), data.y)

输出结果如下所示:

与随机分布相比,节点已形成明显的聚类分布,这表明模型确实得到了有效训练。更重要的是,该模型学习到的节点表征质量优于 GCN 模型,不同类别节点的分离边界更加清晰可辨。

小结

本节介绍了图注意力网络 (Graph Attention Network, GAT) 在节点分类任务中的应用。相比图卷积网络 (Graph Convolutional Network, GCN)的平均聚合机制,GAT 通过引入注意力机制,能够为不同邻居节点分配差异化权重,从而更精准地捕捉图结构信息。实验表明,GATCiteSeer 数据集上的分类准确率达到 71.1%,较 GCN 提升 2.5 个百分点。可视化结果显示,GAT 学习到的节点表征具有更好的类别区分度。这验证了注意力机制在图数据建模中的有效性,为处理复杂图结构任务提供了更优解决方案。

系列链接

PyTorch实战(1)------深度学习(Deep Learning)
PyTorch实战(2)------使用PyTorch构建神经网络
PyTorch实战(3)------PyTorch vs. TensorFlow详解
PyTorch实战(4)------卷积神经网络(Convolutional Neural Network,CNN)
PyTorch实战(5)------深度卷积神经网络
PyTorch实战(6)------模型微调详解
PyTorch实战(7)------循环神经网络
PyTorch实战(8)------图像描述生成
PyTorch实战(9)------从零开始实现Transformer
PyTorch实战(10)------从零开始实现GPT模型
PyTorch实战(11)------随机连接神经网络(RandWireNN)
PyTorch实战(12)------图神经网络(Graph Neural Network,GNN)
PyTorch实战(13)------图卷积网络(Graph Convolutional Network,GCN)

相关推荐
、、、、南山小雨、、、、2 小时前
云主机GPU pyTorch部署
人工智能·pytorch·python
西猫雷婶3 小时前
CNN计算|原始矩阵扩充后的多维度卷积核计算效果
人工智能·pytorch·深度学习·神经网络·机器学习·矩阵·cnn
盼小辉丶16 小时前
图机器学习(7)——图神经网络 (Graph Neural Network, GNN)
人工智能·神经网络·图神经网络·图机器学习
nix.gnehc19 小时前
PyTorch自动求导
人工智能·pytorch·python
多恩Stone19 小时前
【Pytorch 深入理解(2)】减少训练显存-Gradient Checkpointing
人工智能·pytorch·python
broken_utopia20 小时前
PyTorch中view/transpose/permute的内存可视化解析
人工智能·pytorch·python
LDG_AGI20 小时前
【推荐系统】深度学习训练框架(七):PyTorch DDP(DistributedDataParallel)中,每个rank的batch数必须相同
网络·人工智能·pytorch·深度学习·机器学习·spark·batch
远瞻。21 小时前
【环境部署】安装flash-attention
pip·注意力机制
上天夭21 小时前
PyTorch的Dataloader模块解析
人工智能·pytorch·python