网络模型训练完整代码

存个代码

具体看这位博主的网络模型训练完整套路 写的比较清晰

python 复制代码
import torchvision, torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
 
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
 
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
 
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 这一块替换为要训练的网络模型
''' 
class Mydata(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )
 
    def forward(self, x):
        x = self.model(x)
        return x
'''
mydata = Mydata()
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(mydata.parameters(), lr=learning_rate)  
 
total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("logs")
 
for i in range(epoch):
    print("------------第 {} 轮训练开始------------".format(i + 1))
 
    mydata.train()  
    for data in train_dataloader:
        imgs, targets = data
        outputs = mydata(imgs)
        loss = loss_fn(outputs, targets)
 
        optimizer.zero_grad()  
        loss.backward()
        optimizer.step()
 
        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))  
            writer.add_scalar("train_loss", loss.item(), total_train_step)  
 
    mydata.eval()  
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():  
        for data in test_dataloader:  
            imgs, targets = data
            outputs = mydata(imgs)  
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()  
            accuracy = (outputs.argmax(1) == targets).sum()  
            total_accuracy = total_accuracy + accuracy
 
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy / test_data_size))  
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1
 
    torch.save(mydata, "mydata_{}.pth".format(i)) 
    print("模型已保存")
 
writer.close()
相关推荐
人工智能训练1 小时前
【极速部署】Ubuntu24.04+CUDA13.0 玩转 VLLM 0.15.0:预编译 Wheel 包 GPU 版安装全攻略
运维·前端·人工智能·python·ai编程·cuda·vllm
源于花海2 小时前
迁移学习相关的期刊和会议
人工智能·机器学习·迁移学习·期刊会议
DisonTangor4 小时前
DeepSeek-OCR 2: 视觉因果流
人工智能·开源·aigc·ocr·deepseek
薛定谔的猫19824 小时前
二十一、基于 Hugging Face Transformers 实现中文情感分析情感分析
人工智能·自然语言处理·大模型 训练 调优
发哥来了4 小时前
《AI视频生成技术原理剖析及金管道·图生视频的应用实践》
人工智能
数智联AI团队4 小时前
AI搜索引领开源大模型新浪潮,技术创新重塑信息检索未来格局
人工智能·开源
不懒不懒4 小时前
【线性 VS 逻辑回归:一篇讲透两种核心回归模型】
人工智能·机器学习
冰西瓜6004 小时前
从项目入手机器学习——(四)特征工程(简单特征探索)
人工智能·机器学习
Ryan老房5 小时前
未来已来-AI标注工具的下一个10年
人工智能·yolo·目标检测·ai
丝斯20115 小时前
AI学习笔记整理(66)——多模态大模型MOE-LLAVA
人工智能·笔记·学习