5、Pytorch 实现简单图卷积GCN,数据集Cora分类任务

cora数据集- 下载地址

https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz

1、Cora数据集是什么?

Cora 数据集由 2708 篇科学出版物组成,分为七类之一。

引文网络由 5429 个链接组成。

数据集中的每个出版物都由一个 0/1 值的单词向量描述,表示字典中不存在/存在相应的单词。该词典由 1433 个独特的单词组成。

数据集下有两个文件

cora.cites

cora.cites共5429行, 每一行有两个论文编号,表示第一个编号的论文先写,第二个编号的论文引用第一个编号的论文。

cora.content

cora.content共有2708行,每一行代表一个样本点,即一篇论文。如下所示,每一行由三部分组成,分别是论文的编号,如31336;论文的词向量,一个有1433位的二进制;论文的类别,如Neural_Networks。

下面分别截图给大家看一下

2、python代码查看Cora形式

首先我们来看看数据集的形式,输出一下

复制代码
from torch_geometric.datasets import Planetoid

# 加载Cora数据集
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]  # 获取图数据

print("数据集信息:")
print(f"节点数: {data.num_nodes}")
print(f"边数: {data.num_edges}")
print(f"类别数: {dataset.num_classes}")
print(f"特征维度: {dataset.num_node_features}")

那么执行的输出情况如下:

3、实现GCN

不说废话,直接放代码

我这里直接写成了一万伦次,大家根据设备可以调整。

复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# 检查 CUDA 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载Cora数据集
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0].to(device)  # 将图数据移动到 CUDA 设备

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # 第一层卷积 + ReLU
        x = F.relu(self.conv1(x, edge_index))
        # 第二层卷积 + LogSoftmax
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 初始化模型并将其移动到 CUDA 设备
model = GCN(dataset.num_node_features, 16, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)  # 前向传播
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])  # 计算损失
    loss.backward()
    optimizer.step()
    return loss.item()

def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # 预测类别
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = (pred[mask] == data.y[mask]).sum()
        acc = int(correct) / int(mask.sum())
        accs.append(acc)
    return accs

# 训练过程
# 根据自己情况,修改训练epoch
for epoch in range(10000):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 10 == 0:
        print(f"Epoch: {epoch}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")

4、结果如何

训练结果如下:

相关推荐
江瀚视野42 分钟前
多家银行向甲骨文断贷,巨头甲骨文这是怎么了?
大数据·人工智能
码界筑梦坊42 分钟前
325-基于Python的校园卡消费行为数据可视化分析系统
开发语言·python·信息可视化·django·毕业设计
ccLianLian44 分钟前
计算机基础·cs336·损失函数,优化器,调度器,数据处理和模型加载保存
人工智能·深度学习·计算机视觉·transformer
asheuojj1 小时前
2026年GEO优化获客效果评估指南:如何精准衡量TOP5关
大数据·人工智能·python
多恩Stone1 小时前
【RoPE】Flux 中的 Image Tokenization
开发语言·人工智能·python
callJJ1 小时前
Spring AI ImageModel 完全指南:用 OpenAI DALL-E 生成图像
大数据·人工智能·spring·openai·springai·图像模型
铁蛋AI编程实战1 小时前
2026 大模型推理框架测评:vLLM 0.5/TGI 2.0/TensorRT-LLM 1.8/DeepSpeed-MII 0.9 性能与成本防线对比
人工智能·机器学习·vllm
23遇见1 小时前
CANN ops-nn 仓库高效开发指南:从入门到精通
人工智能
SAP工博科技1 小时前
SAP 公有云 ERP 多工厂多生产线数据统一管理技术实现解析
大数据·运维·人工智能
芷栀夏1 小时前
CANN ops-math:异构计算场景下基础数学算子的深度优化与硬件亲和设计解析
人工智能·cann