Pytorch学习笔记(模型训练)

模型训练

在同一个包下创建train.pymodel.py,按照步骤先从数据处理,模型架构搭建,训练测试,统计损失,如下面代码所示

  1. train.py
py 复制代码
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import NNN

# 1. 准备数据集
train_data = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度:{train_data_size}")
print(f"测试数据集的长度:{test_data_size}")

# 2. 利用DataLoader 加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 3. 搭建神经网络
# 引入model.py
nnn = NNN()

# 4. 创建损失函数loss
loss_fn = nn.CrossEntropyLoss()  # 交叉熵

# 5. 优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(nnn.parameters(), lr=learning_rate)  # 随机梯度下降

# 6. 设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 训练测试次数
epoch = 10  # 训练轮数

# 补充tensorboard
writer = SummaryWriter("../logs")

# 开始训练
for i in range(epoch):
    print(f"--------第{i+1}轮训练开始--------")
    # 训练
    nnn.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = nnn(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step}---loss:{loss.item()}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试
    nnn.eval()
    total_test_loss = 0  # 总体的误差
    total_accuracy = 0  # 总体的正确率
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = nnn(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    print(f"整体测试集上的loss:{total_test_loss}")
    print(f"整体测试集上的准确率:{total_accuracy/test_data_size}")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("total_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step += 1

    # 保存每一轮训练的模型
    torch.save(nnn, f"nnn_{i+1}.pth")
    print("模式已保存")


writer.close()
  1. model.py
py 复制代码
import torch
from torch import nn


# 搭建神经网络
class NNN(nn.Module):
    def __init__(self):
        super(NNN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x


if __name__ == '__main__':
    nnn = NNN()
    input = torch.ones((64, 3, 32, 32))
    output = nnn(input)
    print(output.shape)

运行train.py后可以通过启动tensorboard进行查看我们的loss情况,损失是不断下降的。


补充argmax函数的使用

我们模型预测处理的是概率,我们需要使用argmax函数还得到预测的结果,就是选出概率最大的,上面测试准确率的计算使用到了。

简单代码示例:

py 复制代码
import torch
# 模型输出的概率
outputs = torch.tensor([[0.1, 0.3],
                        [0.7, 0.2]])
# 真实的分类
targets = torch.tensor([[1, 1]])
# 对概率进行预测
preds = outputs.argmax(1)  # 1:横向比较 0:竖向比较

# 预测与真实进行比较
print(preds == targets)
print((preds == targets).sum().item())  # 统计正确的个数

输出:

cpp 复制代码
tensor([[ True, False]])
1
相关推荐
June bug3 分钟前
(#数组/链表操作)合并两个有重复元素的无序数组,返回无重复的有序结果
数据结构·python·算法·leetcode·面试·跳槽
人工智能AI技术10 分钟前
【Agent从入门到实践】33 集成多工具,实现Agent的工具选择与执行
人工智能·python
AIFQuant20 分钟前
如何通过股票数据 API 计算 RSI、MACD 与移动平均线MA
大数据·后端·python·金融·restful
70asunflower27 分钟前
Python with 语句与上下文管理完全教程
linux·服务器·python
deephub35 分钟前
为什么标准化要用均值0和方差1?
人工智能·python·机器学习·标准化
hnxaoli40 分钟前
win10程序(十五)归档文件的xlsx目录自动分卷
python
喵手1 小时前
Python爬虫零基础入门【第九章:实战项目教学·第8节】限速器进阶:令牌桶 + 动态降速(429/5xx)!
爬虫·python·令牌桶·python爬虫工程化实战·python爬虫零基础入门·限速器·动态降速
煤炭里de黑猫1 小时前
使用 PyTorch 实现标准 LSTM 神经网络
人工智能·pytorch·lstm
深度学习lover1 小时前
<项目代码>yolo毛毛虫识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·毛毛虫识别
喵手1 小时前
Python爬虫零基础入门【第九章:实战项目教学·第3节】通用清洗工具包:日期/金额/单位/空值(可复用)!
爬虫·python·python爬虫实战·python爬虫工程化实战·python爬虫零基础入门·通用清洗工具包·爬虫实战项目