网络模型训练完整代码

存个代码

具体看这位博主的网络模型训练完整套路 写的比较清晰

python 复制代码
import torchvision, torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
 
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
 
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
 
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 这一块替换为要训练的网络模型
''' 
class Mydata(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )
 
    def forward(self, x):
        x = self.model(x)
        return x
'''
mydata = Mydata()
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(mydata.parameters(), lr=learning_rate)  
 
total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("logs")
 
for i in range(epoch):
    print("------------第 {} 轮训练开始------------".format(i + 1))
 
    mydata.train()  
    for data in train_dataloader:
        imgs, targets = data
        outputs = mydata(imgs)
        loss = loss_fn(outputs, targets)
 
        optimizer.zero_grad()  
        loss.backward()
        optimizer.step()
 
        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))  
            writer.add_scalar("train_loss", loss.item(), total_train_step)  
 
    mydata.eval()  
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():  
        for data in test_dataloader:  
            imgs, targets = data
            outputs = mydata(imgs)  
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()  
            accuracy = (outputs.argmax(1) == targets).sum()  
            total_accuracy = total_accuracy + accuracy
 
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy / test_data_size))  
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1
 
    torch.save(mydata, "mydata_{}.pth".format(i)) 
    print("模型已保存")
 
writer.close()
相关推荐
luoganttcc3 小时前
在 orin 上 安装了 miniconda 如何使用 orin 内置的 opencv
人工智能·opencv·计算机视觉
JinchuanMaster3 小时前
cv_bridge和openCV不兼容问题
人工智能·opencv·计算机视觉
心勤则明3 小时前
Spring AI 文档ETL实战:集成text-embedding-v4 与 Milvus
人工智能·spring·etl
啦啦啦在冲冲冲3 小时前
mse和交叉熵loss,为什么分类问题不用 mse
人工智能·分类·数据挖掘
SaaS_Product3 小时前
有安全好用且稳定的共享网盘吗?
人工智能·云计算·saas·onedrive
~~李木子~~3 小时前
图像分类项目:Fashion-MNIST 分类(SimpleCNN )
人工智能·分类·数据挖掘
轻赚时代3 小时前
新手做国风视频难?AI + 敦煌美学高效出片教程
人工智能·经验分享·笔记·创业创新·课程设计·学习方法
Xxtaoaooo3 小时前
原生多模态AI架构:统一训练与跨模态推理的系统实现与性能优化
人工智能·架构·分布式训练·多模态·模型优化
霖003 小时前
ZYNQ裸机开发指南笔记
人工智能·经验分享·笔记·matlab·fpga开发·信号处理