网络模型训练完整代码

存个代码

具体看这位博主的网络模型训练完整套路 写的比较清晰

python 复制代码
import torchvision, torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
 
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
 
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
 
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 这一块替换为要训练的网络模型
''' 
class Mydata(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )
 
    def forward(self, x):
        x = self.model(x)
        return x
'''
mydata = Mydata()
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(mydata.parameters(), lr=learning_rate)  
 
total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("logs")
 
for i in range(epoch):
    print("------------第 {} 轮训练开始------------".format(i + 1))
 
    mydata.train()  
    for data in train_dataloader:
        imgs, targets = data
        outputs = mydata(imgs)
        loss = loss_fn(outputs, targets)
 
        optimizer.zero_grad()  
        loss.backward()
        optimizer.step()
 
        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))  
            writer.add_scalar("train_loss", loss.item(), total_train_step)  
 
    mydata.eval()  
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():  
        for data in test_dataloader:  
            imgs, targets = data
            outputs = mydata(imgs)  
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()  
            accuracy = (outputs.argmax(1) == targets).sum()  
            total_accuracy = total_accuracy + accuracy
 
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy / test_data_size))  
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1
 
    torch.save(mydata, "mydata_{}.pth".format(i)) 
    print("模型已保存")
 
writer.close()
相关推荐
needn2 小时前
TRAE为什么要发布SOLO版本?
人工智能·ai编程
毅航2 小时前
自然语言处理发展史:从规则、统计到深度学习
人工智能·后端
前端付豪3 小时前
LangChain链 写一篇完美推文?用SequencialChain链接不同的组件
人工智能·python·langchain
ursazoo3 小时前
写了一份 7000字指南,让 AI 帮我消化每天的信息流
人工智能·开源·github
_志哥_6 小时前
Superpowers 技术指南:让 AI 编程助手拥有超能力
人工智能·ai编程·测试
YongGit7 小时前
OpenClaw 本地 AI 助手完全指南:飞书接入 + 远程部署实战
人工智能
程序员鱼皮8 小时前
斯坦福大学竟然开了个 AI 编程课?!我已经学上了
人工智能·ai编程
星浩AI9 小时前
Skill 的核心要素与渐进式加载架构——如何设计一个生产可用的 Skill?
人工智能·agent
树獭非懒9 小时前
告别繁琐多端开发:DivKit 带你玩转 Server-Driven UI!
android·前端·人工智能
阿尔的代码屋9 小时前
[大模型实战 07] 基于 LlamaIndex ReAct 框架手搓全自动博客监控 Agent
人工智能·python