什么是 GNN?用最简单的代码带你入门图神经网络!

💡 GNN 是什么?一篇文章带你入门图神经网络!

🌟 1. 什么是 GNN(图神经网络)?

GNN(Graph Neural Network) 是一种专门处理"图数据"的深度学习方法。

什么是"图"数据? 🤔

  • 社交网络 👥(微信、微博好友推荐)
  • 电商推荐 🛒(淘宝"猜你喜欢")
  • 知识图谱 📚(Google 搜索、ChatGPT 背后的逻辑)
  • 反欺诈检测 🔍(信用卡诈骗、金融风控)
  • 生物信息学 🧬(蛋白质结构预测)

📌 GNN 的目标
传统神经网络只能处理表格、文本、图片,而 GNN 能学习"节点之间的关系"

💡 一句话理解 GNN:

"你是谁,取决于你的邻居"

GNN 通过 "信息传递",让每个节点不仅学习自己的特征,还能学习"邻居节点"的特征,最终实现分类、推荐、预测等任务。


🌟 2. 用代码带你训练一个 GNN(GCN)

💡 我们用 PyTorch Geometric 训练 GCN(图卷积网络),看看 GNN 是怎么学习的! 🚀

✅ (1) 安装 PyTorch Geometric

如果你还没安装 PyG,可以运行:

bash 复制代码
pip install torch torchvision torchaudio
pip install torch-geometric

✅ (2) 载入 Cora 论文引用数据集

python 复制代码
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data', name='Cora')
data = dataset[0]  # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")

📌 解释:

  • Cora 数据集 是一个论文引用网络,论文是节点(Nodes) ,论文之间的引用是边(Edges)
  • 目标:根据论文内容,预测论文属于哪个学科(7 类)。

📊 数据统计

makefile 复制代码
节点数: 2708
边数: 10556
特征维度: 1433
类别数: 7

✅ (3) 定义 GNN(GCN)

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)  # 第一层 GCN
        self.conv2 = GCNConv(16, dataset.num_classes)  # 输出 7 类

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)  # 激活函数
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # 输出分类概率

📌 代码解析:

  • 第一层 GCN:将论文的 1433 维特征降维到 16 维。
  • ReLU 激活:让 GCN 学习更复杂的关系。
  • 第二层 GCN:输出 7 维分类得分(对应 7 个研究领域)。

✅ (4) 训练 GCN

python 复制代码
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

📌 训练过程

  1. 前向传播:计算论文的分类得分。
  2. 计算损失:交叉熵(NLL Loss)。
  3. 反向传播:更新模型参数。
  4. 循环训练 200 轮(每 20 轮输出一次损失)。

✅ (5) 测试 GCN

python 复制代码
def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 取最大类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = correct / data.test_mask.sum()
    print(f"测试集准确率: {acc:.4f}")

test()
  • model.eval() 👉 进入测试模式
  • argmax(dim=1) 👉 选取最大得分的类别
  • 最终测试集准确率约 80% 🎯

🌟 3. 训练 200 次 vs. 600 次,为什么提升不大?

你可能会好奇,为什么训练 200 轮后,增加到 600 轮,模型准确率提升很小?

📌 原因:

  1. 过早收敛(Early Convergence)

    • GCN 主要依赖 邻居信息传播,通常 100~200 轮就能收敛,继续训练效果不会大幅提升。
  2. 过拟合(Overfitting)

    • 训练 200 轮后,模型已经很好地学习到了数据模式,继续训练可能只是在记忆训练数据,而不能更好地泛化到测试集。

📌 解决方案

  • 观察训练曲线(如果 loss 变化不大,说明已经收敛)
  • 尝试 Dropout(减少过拟合)
  • 改变 GCN 结构(增加层数、更深的图传播)

🌟 4. 完整代码

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid  # 经典的GNN数据集
from torch_geometric.nn import GCNConv  # GCN(图卷积网络)层

# 载入 Cora 数据集(包含论文引用网络)
dataset = Planetoid(root='data', name='Cora')

# 获取数据
data = dataset[0]  # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")
print(f"特征维度: {data.num_node_features}, 类别数: {dataset.num_classes}")
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)  # 第一层 GCN
        self.conv2 = GCNConv(16, dataset.num_classes)  # 第二层 GCN 输出类别

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  # 获取节点特征 & 边信息
        x = self.conv1(x, edge_index)
        x = F.relu(x)  # ReLU 激活
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # 归一化得分

# 创建 GCN 模型
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()  # 清空梯度
    out = model(data)  # 前向传播
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数
    return loss.item()

for epoch in range(600):
    loss = train()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 取最大类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = correct / data.test_mask.sum()
    print(f"测试集准确率: {acc:.4f}")

test()

🌟 5. 结论

GNN 适用于图数据(社交网络、知识图谱、推荐系统)。

GCN 通过"邻居特征传播"实现节点分类任务,在 Cora 论文网络上达到了 80%+ 的准确率。

训练 200 轮 vs. 600 轮,提升不明显的原因是 GCN 收敛较快,继续训练可能导致过拟合。

🔥 现在你已经掌握了 GNN 的基本原理!你可以尝试:

  • 训练更复杂的 GNN(如 GAT、GraphSAGE)。
  • 用 GNN 处理真实世界的数据(社交推荐、金融风控)。
  • 结合 AI 搜索,让 GNN + LLM 提供更精准的智能问答!
相关推荐
Haohao+++29 分钟前
Stable Diffusion原理解析
人工智能·深度学习·算法
ideaout技术团队3 小时前
leetcode学习笔记2:多数元素(摩尔投票算法)
学习·算法·leetcode
代码充电宝3 小时前
LeetCode 算法题【简单】283. 移动零
java·算法·leetcode·职场和发展
不枯石6 小时前
Matlab通过GUI实现点云的均值滤波(附最简版)
开发语言·图像处理·算法·计算机视觉·matlab·均值算法
不枯石6 小时前
Matlab通过GUI实现点云的双边(Bilateral)滤波(附最简版)
开发语言·图像处理·算法·计算机视觉·matlab
白水先森8 小时前
C语言作用域与数组详解
java·数据结构·算法
想唱rap8 小时前
直接选择排序、堆排序、冒泡排序
c语言·数据结构·笔记·算法·新浪微博
老葱头蒸鸡9 小时前
(27)APS.NET Core8.0 堆栈原理通俗理解
算法
视睿9 小时前
【C++练习】06.输出100以内的所有素数
开发语言·c++·算法·机器人·无人机
柠檬071110 小时前
matlab cell 数据转换及记录
算法