💡 GNN 是什么?一篇文章带你入门图神经网络!
🌟 1. 什么是 GNN(图神经网络)?
GNN(Graph Neural Network) 是一种专门处理"图数据"的深度学习方法。
什么是"图"数据? 🤔
- 社交网络 👥(微信、微博好友推荐)
- 电商推荐 🛒(淘宝"猜你喜欢")
- 知识图谱 📚(Google 搜索、ChatGPT 背后的逻辑)
- 反欺诈检测 🔍(信用卡诈骗、金融风控)
- 生物信息学 🧬(蛋白质结构预测)
📌 GNN 的目标 :
传统神经网络只能处理表格、文本、图片,而 GNN 能学习"节点之间的关系"。
💡 一句话理解 GNN:
"你是谁,取决于你的邻居"
GNN 通过 "信息传递",让每个节点不仅学习自己的特征,还能学习"邻居节点"的特征,最终实现分类、推荐、预测等任务。
🌟 2. 用代码带你训练一个 GNN(GCN)
💡 我们用 PyTorch Geometric 训练 GCN(图卷积网络),看看 GNN 是怎么学习的! 🚀
✅ (1) 安装 PyTorch Geometric
如果你还没安装 PyG,可以运行:
bash
pip install torch torchvision torchaudio
pip install torch-geometric
✅ (2) 载入 Cora 论文引用数据集
python
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data', name='Cora')
data = dataset[0] # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")
📌 解释:
- Cora 数据集 是一个论文引用网络,论文是节点(Nodes) ,论文之间的引用是边(Edges)。
- 目标:根据论文内容,预测论文属于哪个学科(7 类)。
📊 数据统计
makefile
节点数: 2708
边数: 10556
特征维度: 1433
类别数: 7
✅ (3) 定义 GNN(GCN)
python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16) # 第一层 GCN
self.conv2 = GCNConv(16, dataset.num_classes) # 输出 7 类
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x) # 激活函数
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1) # 输出分类概率
📌 代码解析:
- 第一层 GCN:将论文的 1433 维特征降维到 16 维。
- ReLU 激活:让 GCN 学习更复杂的关系。
- 第二层 GCN:输出 7 维分类得分(对应 7 个研究领域)。
✅ (4) 训练 GCN
python
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
model.train()
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
for epoch in range(200):
loss = train()
if epoch % 20 == 0:
print(f"Epoch {epoch}, Loss: {loss:.4f}")
📌 训练过程:
- 前向传播:计算论文的分类得分。
- 计算损失:交叉熵(NLL Loss)。
- 反向传播:更新模型参数。
- 循环训练 200 轮(每 20 轮输出一次损失)。
✅ (5) 测试 GCN
python
def test():
model.eval()
out = model(data)
pred = out.argmax(dim=1) # 取最大类别
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = correct / data.test_mask.sum()
print(f"测试集准确率: {acc:.4f}")
test()
model.eval()
👉 进入测试模式argmax(dim=1)
👉 选取最大得分的类别- 最终测试集准确率约 80% 🎯
🌟 3. 训练 200 次 vs. 600 次,为什么提升不大?
你可能会好奇,为什么训练 200 轮后,增加到 600 轮,模型准确率提升很小?
📌 原因:
-
过早收敛(Early Convergence)
- GCN 主要依赖 邻居信息传播,通常 100~200 轮就能收敛,继续训练效果不会大幅提升。
-
过拟合(Overfitting)
- 训练 200 轮后,模型已经很好地学习到了数据模式,继续训练可能只是在记忆训练数据,而不能更好地泛化到测试集。
📌 解决方案
- 观察训练曲线(如果 loss 变化不大,说明已经收敛)
- 尝试 Dropout(减少过拟合)
- 改变 GCN 结构(增加层数、更深的图传播)
🌟 4. 完整代码
python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid # 经典的GNN数据集
from torch_geometric.nn import GCNConv # GCN(图卷积网络)层
# 载入 Cora 数据集(包含论文引用网络)
dataset = Planetoid(root='data', name='Cora')
# 获取数据
data = dataset[0] # 只有一个图
print(f"节点数: {data.num_nodes}, 边数: {data.num_edges}")
print(f"特征维度: {data.num_node_features}, 类别数: {dataset.num_classes}")
class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16) # 第一层 GCN
self.conv2 = GCNConv(16, dataset.num_classes) # 第二层 GCN 输出类别
def forward(self, data):
x, edge_index = data.x, data.edge_index # 获取节点特征 & 边信息
x = self.conv1(x, edge_index)
x = F.relu(x) # ReLU 激活
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1) # 归一化得分
# 创建 GCN 模型
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
model.train()
optimizer.zero_grad() # 清空梯度
out = model(data) # 前向传播
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
return loss.item()
for epoch in range(600):
loss = train()
if epoch % 20 == 0:
print(f"Epoch {epoch}, Loss: {loss:.4f}")
def test():
model.eval()
out = model(data)
pred = out.argmax(dim=1) # 取最大类别
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = correct / data.test_mask.sum()
print(f"测试集准确率: {acc:.4f}")
test()
🌟 5. 结论
✅ GNN 适用于图数据(社交网络、知识图谱、推荐系统)。
✅ GCN 通过"邻居特征传播"实现节点分类任务,在 Cora 论文网络上达到了 80%+ 的准确率。
✅ 训练 200 轮 vs. 600 轮,提升不明显的原因是 GCN 收敛较快,继续训练可能导致过拟合。
🔥 现在你已经掌握了 GNN 的基本原理!你可以尝试:
- 训练更复杂的 GNN(如 GAT、GraphSAGE)。
- 用 GNN 处理真实世界的数据(社交推荐、金融风控)。
- 结合 AI 搜索,让 GNN + LLM 提供更精准的智能问答!