什么是 GNN?用最简单的代码带你入门图神经网络!

💡 GNN 是什么?一篇文章带你入门图神经网络!

🌟 1. 什么是 GNN(图神经网络)?

GNN(Graph Neural Network) 是一种专门处理"图数据"的深度学习方法。

什么是"图"数据? 🤔

  • 社交网络 👥(微信、微博好友推荐)
  • 电商推荐 🛒(淘宝"猜你喜欢")
  • 知识图谱 📚(Google 搜索、ChatGPT 背后的逻辑)
  • 反欺诈检测 🔍(信用卡诈骗、金融风控)
  • 生物信息学 🧬(蛋白质结构预测)

📌 GNN 的目标
传统神经网络只能处理表格、文本、图片,而 GNN 能学习"节点之间的关系"

💡 一句话理解 GNN:

"你是谁,取决于你的邻居"

GNN 通过 "信息传递",让每个节点不仅学习自己的特征,还能学习"邻居节点"的特征,最终实现分类、推荐、预测等任务。


🌟 2. 用代码带你训练一个 GNN(GCN)

💡 我们用 PyTorch Geometric 训练 GCN(图卷积网络),看看 GNN 是怎么学习的! 🚀

✅ (1) 安装 PyTorch Geometric

如果你还没安装 PyG,可以运行:

bash 复制代码
pip install torch torchvision torchaudio
pip install torch-geometric

✅ (2) 载入 Cora 论文引用数据集

python 复制代码
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data', name='Cora')
data = dataset[0]  # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")

📌 解释:

  • Cora 数据集 是一个论文引用网络,论文是节点(Nodes) ,论文之间的引用是边(Edges)
  • 目标:根据论文内容,预测论文属于哪个学科(7 类)。

📊 数据统计

makefile 复制代码
节点数: 2708
边数: 10556
特征维度: 1433
类别数: 7

✅ (3) 定义 GNN(GCN)

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)  # 第一层 GCN
        self.conv2 = GCNConv(16, dataset.num_classes)  # 输出 7 类

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)  # 激活函数
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # 输出分类概率

📌 代码解析:

  • 第一层 GCN:将论文的 1433 维特征降维到 16 维。
  • ReLU 激活:让 GCN 学习更复杂的关系。
  • 第二层 GCN:输出 7 维分类得分(对应 7 个研究领域)。

✅ (4) 训练 GCN

python 复制代码
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

📌 训练过程

  1. 前向传播:计算论文的分类得分。
  2. 计算损失:交叉熵(NLL Loss)。
  3. 反向传播:更新模型参数。
  4. 循环训练 200 轮(每 20 轮输出一次损失)。

✅ (5) 测试 GCN

python 复制代码
def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 取最大类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = correct / data.test_mask.sum()
    print(f"测试集准确率: {acc:.4f}")

test()
  • model.eval() 👉 进入测试模式
  • argmax(dim=1) 👉 选取最大得分的类别
  • 最终测试集准确率约 80% 🎯

🌟 3. 训练 200 次 vs. 600 次,为什么提升不大?

你可能会好奇,为什么训练 200 轮后,增加到 600 轮,模型准确率提升很小?

📌 原因:

  1. 过早收敛(Early Convergence)

    • GCN 主要依赖 邻居信息传播,通常 100~200 轮就能收敛,继续训练效果不会大幅提升。
  2. 过拟合(Overfitting)

    • 训练 200 轮后,模型已经很好地学习到了数据模式,继续训练可能只是在记忆训练数据,而不能更好地泛化到测试集。

📌 解决方案

  • 观察训练曲线(如果 loss 变化不大,说明已经收敛)
  • 尝试 Dropout(减少过拟合)
  • 改变 GCN 结构(增加层数、更深的图传播)

🌟 4. 完整代码

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid  # 经典的GNN数据集
from torch_geometric.nn import GCNConv  # GCN(图卷积网络)层

# 载入 Cora 数据集(包含论文引用网络)
dataset = Planetoid(root='data', name='Cora')

# 获取数据
data = dataset[0]  # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")
print(f"特征维度: {data.num_node_features}, 类别数: {dataset.num_classes}")
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)  # 第一层 GCN
        self.conv2 = GCNConv(16, dataset.num_classes)  # 第二层 GCN 输出类别

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  # 获取节点特征 & 边信息
        x = self.conv1(x, edge_index)
        x = F.relu(x)  # ReLU 激活
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # 归一化得分

# 创建 GCN 模型
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()  # 清空梯度
    out = model(data)  # 前向传播
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数
    return loss.item()

for epoch in range(600):
    loss = train()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 取最大类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = correct / data.test_mask.sum()
    print(f"测试集准确率: {acc:.4f}")

test()

🌟 5. 结论

GNN 适用于图数据(社交网络、知识图谱、推荐系统)。

GCN 通过"邻居特征传播"实现节点分类任务,在 Cora 论文网络上达到了 80%+ 的准确率。

训练 200 轮 vs. 600 轮,提升不明显的原因是 GCN 收敛较快,继续训练可能导致过拟合。

🔥 现在你已经掌握了 GNN 的基本原理!你可以尝试:

  • 训练更复杂的 GNN(如 GAT、GraphSAGE)。
  • 用 GNN 处理真实世界的数据(社交推荐、金融风控)。
  • 结合 AI 搜索,让 GNN + LLM 提供更精准的智能问答!
相关推荐
努力学习的小廉1 小时前
我爱学算法之——滑动窗口攻克子数组和子串难题(上)
开发语言·c++·算法
梦想攻城狮1 小时前
深度学习之神经网络
人工智能·算法·机器学习
阿巴~阿巴~1 小时前
素数判定方法详解:从基础试除法到优化策略
c++·算法
Vitalia2 小时前
图论入门【数据结构基础】:什么是树?如何表示树?
数据结构·算法·图论·
埃菲尔铁塔_CV算法2 小时前
WPF 开发从入门到进阶(五)
深度学习·算法·机器学习·计算机视觉·wpf
Cindy辛蒂2 小时前
C语言:能够规定次数,处理非法字符的猜数游戏(三重循环)
c语言·算法·游戏
小卡皮巴拉2 小时前
【力扣刷题实战】无重复的最长字串
开发语言·c++·算法·leetcode·滑动窗口
huangyuchi.2 小时前
map容器练习:使用map容器识别统计单词个数
开发语言·数据结构·c++·笔记·算法
我是初生2 小时前
c++基础知识-图论进阶
算法·图论
啥都鼓捣的小yao2 小时前
Python解决“特定数组的逆序拼接”问题
开发语言·python·算法