网络模型训练完整代码

存个代码

具体看这位博主的网络模型训练完整套路 写的比较清晰

python 复制代码
import torchvision, torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
 
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
 
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
 
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 这一块替换为要训练的网络模型
''' 
class Mydata(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )
 
    def forward(self, x):
        x = self.model(x)
        return x
'''
mydata = Mydata()
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(mydata.parameters(), lr=learning_rate)  
 
total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("logs")
 
for i in range(epoch):
    print("------------第 {} 轮训练开始------------".format(i + 1))
 
    mydata.train()  
    for data in train_dataloader:
        imgs, targets = data
        outputs = mydata(imgs)
        loss = loss_fn(outputs, targets)
 
        optimizer.zero_grad()  
        loss.backward()
        optimizer.step()
 
        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))  
            writer.add_scalar("train_loss", loss.item(), total_train_step)  
 
    mydata.eval()  
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():  
        for data in test_dataloader:  
            imgs, targets = data
            outputs = mydata(imgs)  
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()  
            accuracy = (outputs.argmax(1) == targets).sum()  
            total_accuracy = total_accuracy + accuracy
 
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy / test_data_size))  
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1
 
    torch.save(mydata, "mydata_{}.pth".format(i)) 
    print("模型已保存")
 
writer.close()
相关推荐
后端小肥肠几秒前
一人公司如何用 WorkBuddy + Obsidian 搭一套长期记忆系统?
人工智能·aigc·agent
RFID舜识物联网1 分钟前
破局“信息孤岛”:RFID耐高温标签重塑汽车喷漆车间可视化
大数据·人工智能·科技·物联网·安全·汽车
05大叔2 分钟前
预训练模型演化,提示词工程
人工智能·深度学习·自然语言处理
BU摆烂会噶2 分钟前
【LangGraph】House_Agent 实战(一):架构与环境配置
人工智能·vscode·python·架构·langchain·人机交互
小小测试开发3 分钟前
OpenAI 模型攻克离散几何 80 年难题:Erdős 单位距离猜想被 AI 证明
人工智能·算法·机器学习
moonsims4 分钟前
从“传感器融合”升级为“多机器人约束融合系统”-Factor Graph 多约束融合
人工智能·算法
tedcloud1235 分钟前
agent-skills部署教程:打造工程化AI Agent系统
服务器·人工智能·系统架构·powerpoint·dreamweaver
测试员周周5 分钟前
【Appium 系列】第15节-视觉测试 — 截图、对比、视觉回归
人工智能·python·数据挖掘·回归·appium·测试用例·测试覆盖率
Dfreedom.8 分钟前
模型剪枝完全指南:从理论到实践,打造高效深度学习模型
人工智能·算法·机器学习·剪枝·模型加速
开始脱发的自然卷10 分钟前
用 Excel 手算 LSTM:从四个门到梯度下降的完整过程
人工智能