PyG-GCN-Cora(在Cora数据集上应用GCN做节点分类)

文章目录

model.py

py 复制代码
import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class gcn_cls(nn.Module):
    def __init__(self,in_dim,hid_dim,out_dim,dropout_size=0.5):
        super(gcn_cls,self).__init__()
        self.conv1 = GCNConv(in_dim,hid_dim)
        self.conv2 = GCNConv(hid_dim,hid_dim)
        self.fc = nn.Linear(hid_dim,out_dim)
        self.relu  = nn.ReLU()
        self.dropout_size = dropout_size
    def forward(self,x,edge_index):
        x = self.conv1(x,edge_index)
        x = F.dropout(x,p=self.dropout_size,training=self.training)
        x = self.relu(x)
        x = self.conv2(x,edge_index)
        x = self.relu(x)
        x = self.fc(x)
        return x

main.py

py 复制代码
import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from model import gcn_cls
import torch.optim as optim
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset[0])
cora_data = dataset[0]

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

net = gcn_cls(cora_data.x.shape[1],hidden_dim,output_dim)
optimizer = optim.AdamW(net.parameters(),lr=lr,weight_decay=weight_decay)
#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=momentum)
criterion = nn.CrossEntropyLoss()
print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    out = net(cora_data.x,cora_data.edge_index)
    optimizer.zero_grad()
    loss_train = criterion(out[cora_data.train_mask],cora_data.y[cora_data.train_mask])
    loss_val   = criterion(out[cora_data.val_mask],cora_data.y[cora_data.val_mask])
    loss_train.backward()
    print('epoch',epoch+1,'loss-train {:.2f}'.format(loss_train),'loss-val {:.2f}'.format(loss_val))
    optimizer.step()

net.eval()
out = net(cora_data.x,cora_data.edge_index)
loss_test = criterion(out[cora_data.test_mask],cora_data.y[cora_data.test_mask])
_,pred = torch.max(out,dim=1)
pred_label = pred[cora_data.test_mask]
true_label = cora_data.y[cora_data.test_mask]
acc = sum(pred_label==true_label)/len(pred_label)
print("****************Begin Testing****************")
print('loss-test {:.2f}'.format(loss_test),'acc {:.2f}'.format(acc))

参数设置

bash 复制代码
epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

output_dim是输出维度,也就是有多少可能的类别。

注意事项

1.发现loss不下降:

建议改一改lr(学习率),我做的时候开始用的SGD,学习率设的0.01发现loss不下降,改成0.1后好了很多。如果用AdamW,0.001(1e-3)基本就够用了

运行图

相关推荐
华新嘉华DTC创新营销1 天前
华新嘉华:AI搜索优化重塑本地生活行业:智能推荐正取代“关键词匹配”
人工智能·百度·生活
SmartBrain1 天前
DeerFlow 实践:华为IPD流程的评审智能体设计
人工智能·语言模型·架构
l1t1 天前
利用DeepSeek实现服务器客户端模式的DuckDB原型
服务器·c语言·数据库·人工智能·postgresql·协议·duckdb
寒月霜华1 天前
机器学习-数据标注
人工智能·机器学习
九章云极AladdinEdu1 天前
超参数自动化调优指南:Optuna vs. Ray Tune 对比评测
运维·人工智能·深度学习·ai·自动化·gpu算力
人工智能训练师1 天前
Ubuntu22.04如何安装新版本的Node.js和npm
linux·运维·前端·人工智能·ubuntu·npm·node.js
cxr8281 天前
SPARC方法论在Claude Code基于规则驱动开发中的应用
人工智能·驱动开发·claude·智能体
研梦非凡1 天前
ICCV 2025|从粗到细:用于高效3D高斯溅射的可学习离散小波变换
人工智能·深度学习·学习·3d
幂简集成1 天前
Realtime API 语音代理端到端接入全流程教程(含 Demo,延迟 280ms)
人工智能·个人开发
龙腾-虎跃1 天前
FreeSWITCH FunASR语音识别模块
人工智能·语音识别·xcode