图卷积网络(GCN)简单示例

代码功能

这段代码的功能是使用图卷积网络(GCN)对图数据中的节点进行分类,并通过可视化展示节点的真实标签和预测结果。具体步骤如下:

  1. 加载数据集:使用 Cora 引用网络数据集,每个节点表示论文,边表示引用关系,节点标签为论文类别。
  2. 定义 GCN 模型:构建一个两层的 GCN 模型,第一层提取特征,第二层输出类别。
  3. 可视化原始图:使用真实标签颜色绘制图结构,以便对比分类效果。
  4. 训练模型:通过200轮迭代优化模型参数,使其学习节点类别特征。
  5. 可视化预测结果:用模型预测的标签颜色绘制图结构,直观展示分类效果。
  6. 评估准确率:计算并输出模型在测试集上的准确率。

    整体上,这段代码实现了图数据的节点分类及结果的可视化。

代码

python 复制代码
# 导入必要的库
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

# 1. 加载数据集(使用Cora数据集,这是一个引用网络数据集)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # 获取图数据

# 2. 定义GCN模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        # 定义两层GCN卷积层
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # 第一层卷积+ReLU激活函数
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # 第二层卷积+Softmax输出
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 3. 初始化模型和优化器
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 4. 定义绘制节点图函数
def plot_graph(data, color_map=None, title="Graph"):
    # 将数据转换为NetworkX图
    G = to_networkx(data, to_undirected=True)
    plt.figure(figsize=(8, 8))
    # 绘制图,并为节点上色
    nx.draw(G, pos=nx.spring_layout(G), with_labels=False, node_color=color_map, 
            node_size=50, cmap="coolwarm")
    plt.title(title)
    plt.show()

# 使用真实标签颜色绘制原始图
color_map = data.y.numpy()
plot_graph(data, color_map, title="Original Graph with True Labels")

# 5. 训练模型
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    # 计算交叉熵损失
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

# 6. 评估模型,并可视化预测结果
model.eval()
_, pred = model(data).max(dim=1)

# 使用预测标签颜色绘制图
pred_color_map = pred.numpy()  # 使用预测标签作为颜色映射
plot_graph(data, pred_color_map, title="Graph with Predicted Labels")

# 计算并输出准确率
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
accuracy = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {accuracy:.4f}')
相关推荐
我送炭你添花11 分钟前
Pelco KBD300A 模拟器:06+2.Pelco KBD300A 模拟器项目重构指南
python·重构·自动化·运维开发
Swizard13 分钟前
别再只会算直线距离了!用“马氏距离”揪出那个伪装的数据“卧底”
python·算法·ai
站大爷IP14 分钟前
Python函数与模块化编程:局部变量与全局变量的深度解析
python
我命由我1234522 分钟前
Python Flask 开发问题:ImportError: cannot import name ‘Markup‘ from ‘flask‘
开发语言·后端·python·学习·flask·学习方法·python3.11
databook31 分钟前
掌握相关性分析:读懂数据间的“悄悄话”
python·数据挖掘·数据分析
全栈陈序员1 小时前
【Python】基础语法入门(二十)——项目实战:从零构建命令行 To-Do List 应用
开发语言·人工智能·python·学习
jcsx1 小时前
如何将django项目发布为https
python·https·django
岁月宁静1 小时前
LangGraph 技术详解:基于图结构的 AI 工作流与多智能体编排框架
前端·python·langchain
百锦再1 小时前
京东云鼎入驻方案解读——通往协同的“高架桥”与“快速路”
android·java·python·rust·django·restful·京东云
岁月宁静1 小时前
LangChain 技术栈全解析:从模型编排到 RAG 实战
前端·python·langchain