什么是 GNN?用最简单的代码带你入门图神经网络!

💡 GNN 是什么?一篇文章带你入门图神经网络!

🌟 1. 什么是 GNN(图神经网络)?

GNN(Graph Neural Network) 是一种专门处理"图数据"的深度学习方法。

什么是"图"数据? 🤔

  • 社交网络 👥(微信、微博好友推荐)
  • 电商推荐 🛒(淘宝"猜你喜欢")
  • 知识图谱 📚(Google 搜索、ChatGPT 背后的逻辑)
  • 反欺诈检测 🔍(信用卡诈骗、金融风控)
  • 生物信息学 🧬(蛋白质结构预测)

📌 GNN 的目标
传统神经网络只能处理表格、文本、图片,而 GNN 能学习"节点之间的关系"

💡 一句话理解 GNN:

"你是谁,取决于你的邻居"

GNN 通过 "信息传递",让每个节点不仅学习自己的特征,还能学习"邻居节点"的特征,最终实现分类、推荐、预测等任务。


🌟 2. 用代码带你训练一个 GNN(GCN)

💡 我们用 PyTorch Geometric 训练 GCN(图卷积网络),看看 GNN 是怎么学习的! 🚀

✅ (1) 安装 PyTorch Geometric

如果你还没安装 PyG,可以运行:

bash 复制代码
pip install torch torchvision torchaudio
pip install torch-geometric

✅ (2) 载入 Cora 论文引用数据集

python 复制代码
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data', name='Cora')
data = dataset[0]  # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")

📌 解释:

  • Cora 数据集 是一个论文引用网络,论文是节点(Nodes) ,论文之间的引用是边(Edges)
  • 目标:根据论文内容,预测论文属于哪个学科(7 类)。

📊 数据统计

makefile 复制代码
节点数: 2708
边数: 10556
特征维度: 1433
类别数: 7

✅ (3) 定义 GNN(GCN)

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)  # 第一层 GCN
        self.conv2 = GCNConv(16, dataset.num_classes)  # 输出 7 类

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)  # 激活函数
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # 输出分类概率

📌 代码解析:

  • 第一层 GCN:将论文的 1433 维特征降维到 16 维。
  • ReLU 激活:让 GCN 学习更复杂的关系。
  • 第二层 GCN:输出 7 维分类得分(对应 7 个研究领域)。

✅ (4) 训练 GCN

python 复制代码
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

📌 训练过程

  1. 前向传播:计算论文的分类得分。
  2. 计算损失:交叉熵(NLL Loss)。
  3. 反向传播:更新模型参数。
  4. 循环训练 200 轮(每 20 轮输出一次损失)。

✅ (5) 测试 GCN

python 复制代码
def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 取最大类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = correct / data.test_mask.sum()
    print(f"测试集准确率: {acc:.4f}")

test()
  • model.eval() 👉 进入测试模式
  • argmax(dim=1) 👉 选取最大得分的类别
  • 最终测试集准确率约 80% 🎯

🌟 3. 训练 200 次 vs. 600 次,为什么提升不大?

你可能会好奇,为什么训练 200 轮后,增加到 600 轮,模型准确率提升很小?

📌 原因:

  1. 过早收敛(Early Convergence)

    • GCN 主要依赖 邻居信息传播,通常 100~200 轮就能收敛,继续训练效果不会大幅提升。
  2. 过拟合(Overfitting)

    • 训练 200 轮后,模型已经很好地学习到了数据模式,继续训练可能只是在记忆训练数据,而不能更好地泛化到测试集。

📌 解决方案

  • 观察训练曲线(如果 loss 变化不大,说明已经收敛)
  • 尝试 Dropout(减少过拟合)
  • 改变 GCN 结构(增加层数、更深的图传播)

🌟 4. 完整代码

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid  # 经典的GNN数据集
from torch_geometric.nn import GCNConv  # GCN(图卷积网络)层

# 载入 Cora 数据集(包含论文引用网络)
dataset = Planetoid(root='data', name='Cora')

# 获取数据
data = dataset[0]  # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")
print(f"特征维度: {data.num_node_features}, 类别数: {dataset.num_classes}")
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)  # 第一层 GCN
        self.conv2 = GCNConv(16, dataset.num_classes)  # 第二层 GCN 输出类别

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  # 获取节点特征 & 边信息
        x = self.conv1(x, edge_index)
        x = F.relu(x)  # ReLU 激活
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # 归一化得分

# 创建 GCN 模型
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()  # 清空梯度
    out = model(data)  # 前向传播
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数
    return loss.item()

for epoch in range(600):
    loss = train()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 取最大类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = correct / data.test_mask.sum()
    print(f"测试集准确率: {acc:.4f}")

test()

🌟 5. 结论

GNN 适用于图数据(社交网络、知识图谱、推荐系统)。

GCN 通过"邻居特征传播"实现节点分类任务,在 Cora 论文网络上达到了 80%+ 的准确率。

训练 200 轮 vs. 600 轮,提升不明显的原因是 GCN 收敛较快,继续训练可能导致过拟合。

🔥 现在你已经掌握了 GNN 的基本原理!你可以尝试:

  • 训练更复杂的 GNN(如 GAT、GraphSAGE)。
  • 用 GNN 处理真实世界的数据(社交推荐、金融风控)。
  • 结合 AI 搜索,让 GNN + LLM 提供更精准的智能问答!
相关推荐
S01d13r1 小时前
LeetCode 解题思路 47(最长回文子串、最长公共子序列)
算法·leetcode·职场和发展
摄殓永恒2 小时前
【入门】数字走向II
算法
饮啦冰美式3 小时前
PPO近端策略优化算法
人工智能·深度学习·算法
void_sk3 小时前
C/C++复习--C语言中的函数详细
c语言·c++·算法
evolution_language3 小时前
LintCode第485题-生成给定大小的数组,第220题-冰雹猜想,第235题-分解质因数
数据结构·算法·新手必刷编程50题
钢铁男儿4 小时前
C# 方法(参数数组)
java·算法·c#
KuaCpp4 小时前
5.8线性动态规划2
算法·动态规划
How_doyou_do5 小时前
备战菊厂笔试2-BFS记忆化MLE?用Set去重-Set会TLE?用SortedSet剪枝
算法·深度优先
晴空闲雲6 小时前
线性表-顺序表(Sequential List)
数据结构·算法
Javis2116 小时前
代码随想录算法训练营第九天 |【字符串】151.翻转字符串里的单词、卡码网55.右旋转字符串、28.实现strStr、459.重复的子字符串
数据结构·c++·算法