一个简单的图像分类项目(六)编写脚本:初步训练

训练的脚本 ,用于训练和测试。lib.train.py:

python 复制代码
import time

from load_imags import train_loader, train_num
from nets import *


def main():
    # 定义网络
    print('Please choose a network:')
    print('1. ResNet18')
    print('2. VGG')

    # 选择网络
    while True:
        net_choose = input('')
        if net_choose == '1':
            net = resnet18_model().to(device)
            print('You choose ResNet18,now start training')
            break
        elif net_choose == '2':
            net = vgg_model().to(device)
            print('You choose VGG,now start training')
            break
        else:
            print('Please input a correct number!')

    # 定义损失函数和优化器
    loss_func = nn.CrossEntropyLoss()  # 交叉熵损失函数
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)  # 优化器使用Adam
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5,
                                                gamma=0.9)  # 学习率衰减, 每5个epoch,学习率乘以0.9

    # 训练模型
    for epoch in range(num_epoches):
        trained_num = 0  # 记录训练过的图片数量
        total_correct = 0  # 记录正确数量
        print('-' * 100)
        print('Epoch {}/{}'.format(epoch + 1, num_epoches))
        begin_time = time.time()  # 记录开始时间
        net.train()  # 训练模式
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)  # 每batch_size个图像的数据
            labels = labels.to(device)  # 每batch_size个图像的标签
            trained_num += images.size(0)  # 记录训练过的图片数量
            outputs = net(images)  # 前向传播
            loss = loss_func(outputs, labels)  # 计算损失
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播
            optimizer.step()  # 优化器更新参数

            _, predicted = torch.max(outputs.data, 1)  # 预测结果
            correct = predicted.eq(labels).cpu().sum()  # 计算本batch_size的正确数量
            total_correct += correct  # 记录正确数量
            if (i + 1) % 50 == 0:  # 每50个batch_size打印一次
                print('trained: {}/{}'.format(trained_num, train_num))
                print('Loss: {:.4f}, Accuracy: {:.2f}%'.format(loss.item(), 100 * correct / images.size(0)))

        # 每5个epoch,学习率衰减
        scheduler.step()
        end_time = time.time()  # 记录结束时间
        print('Each train_epoch take time: {} s'.format(end_time - begin_time))
        print('This train_epoch accuracy: {:.2f}%'.format(100 * total_correct / train_num))


if __name__ == '__main__':
    main()
python 复制代码
C:\Users\DY\.conda\envs\torch\python.exe E:\AI_test\image_classification\lib\train.py 
Please choose a network:
1. ResNet18
2. VGG
2
You choose VGG,now start training
----------------------------------------------------------------------------------------------------
Epoch 1/100
trained: 6400/50000
Loss: 2.3902, Accuracy: 10.16%
trained: 12800/50000
Loss: 2.3063, Accuracy: 11.72%
trained: 19200/50000
Loss: 2.1875, Accuracy: 18.75%
trained: 25600/50000
Loss: 2.1349, Accuracy: 19.53%
trained: 32000/50000
Loss: 1.9848, Accuracy: 26.56%
trained: 38400/50000
Loss: 2.0000, Accuracy: 16.41%
trained: 44800/50000
Loss: 2.0151, Accuracy: 25.78%
Each train_epoch take time: 71.04850149154663 s
This train_epoch accuracy: 19.34%
----------------------------------------------------------------------------------------------------
Epoch 2/100
trained: 6400/50000
Loss: 1.8815, Accuracy: 28.12%
trained: 12800/50000
Loss: 1.8677, Accuracy: 34.38%
trained: 19200/50000
Loss: 1.7808, Accuracy: 39.06%
trained: 25600/50000
Loss: 1.9118, Accuracy: 29.69%
trained: 32000/50000
Loss: 1.6296, Accuracy: 39.84%
trained: 38400/50000
Loss: 1.6648, Accuracy: 35.94%
trained: 44800/50000
Loss: 1.7854, Accuracy: 33.59%
Each train_epoch take time: 66.71016025543213 s
This train_epoch accuracy: 33.65%
----------------------------------------------------------------------------------------------------
Epoch 3/100
trained: 6400/50000
Loss: 1.4987, Accuracy: 44.53%
trained: 12800/50000
Loss: 1.6677, Accuracy: 41.41%
trained: 19200/50000
Loss: 1.6952, Accuracy: 43.75%
trained: 25600/50000
Loss: 1.6941, Accuracy: 38.28%
trained: 32000/50000
Loss: 1.4057, Accuracy: 49.22%
trained: 38400/50000
Loss: 1.5183, Accuracy: 44.53%
trained: 44800/50000
Loss: 1.6591, Accuracy: 37.50%
Each train_epoch take time: 68.37232995033264 s
This train_epoch accuracy: 41.65%
----------------------------------------------------------------------------------------------------
Epoch 4/100
trained: 6400/50000
Loss: 1.6636, Accuracy: 43.75%
trained: 12800/50000
Loss: 1.5985, Accuracy: 42.19%
trained: 19200/50000
Loss: 1.4054, Accuracy: 52.34%
trained: 25600/50000
Loss: 1.4520, Accuracy: 40.62%
trained: 32000/50000
Loss: 1.4574, Accuracy: 46.09%
trained: 38400/50000
Loss: 1.4711, Accuracy: 42.19%
trained: 44800/50000
Loss: 1.4806, Accuracy: 43.75%
Each train_epoch take time: 68.32443571090698 s
This train_epoch accuracy: 46.48%
----------------------------------------------------------------------------------------------------
Epoch 5/100
trained: 6400/50000
Loss: 1.2265, Accuracy: 57.03%
trained: 12800/50000
Loss: 1.3454, Accuracy: 52.34%
trained: 19200/50000
Loss: 1.3527, Accuracy: 49.22%
trained: 25600/50000
Loss: 1.2874, Accuracy: 53.12%
trained: 32000/50000
Loss: 1.3666, Accuracy: 55.47%
trained: 38400/50000
Loss: 1.4465, Accuracy: 50.00%
trained: 44800/50000
Loss: 1.2802, Accuracy: 52.34%
Each train_epoch take time: 68.22098922729492 s
This train_epoch accuracy: 50.72%
----------------------------------------------------------------------------------------------------
Epoch 6/100
trained: 6400/50000
Loss: 1.3402, Accuracy: 51.56%
trained: 12800/50000
Loss: 1.2873, Accuracy: 53.91%
trained: 19200/50000
Loss: 1.3183, Accuracy: 52.34%
trained: 25600/50000
Loss: 1.3688, Accuracy: 48.44%
trained: 32000/50000
Loss: 1.2143, Accuracy: 55.47%
trained: 38400/50000
Loss: 1.2132, Accuracy: 56.25%
trained: 44800/50000
Loss: 1.3172, Accuracy: 53.12%
Each train_epoch take time: 68.76534986495972 s
This train_epoch accuracy: 54.53%
----------------------------------------------------------------------------------------------------
Epoch 7/100
trained: 6400/50000
Loss: 1.3156, Accuracy: 53.12%
trained: 12800/50000
Loss: 1.1412, Accuracy: 60.16%
trained: 19200/50000
Loss: 1.1978, Accuracy: 57.03%
trained: 25600/50000
Loss: 1.0312, Accuracy: 55.47%
trained: 32000/50000
Loss: 1.3486, Accuracy: 50.00%
trained: 38400/50000
Loss: 1.1591, Accuracy: 60.16%
trained: 44800/50000
Loss: 1.0707, Accuracy: 63.28%
Each train_epoch take time: 68.1180489063263 s
This train_epoch accuracy: 56.99%
----------------------------------------------------------------------------------------------------
Epoch 8/100

看得出,模型是在逐步收敛的。下一步,完善训练脚本,加入测试的代码。

相关推荐
古希腊掌管学习的神6 分钟前
[搜广推]王树森推荐系统——矩阵补充&最近邻查找
python·算法·机器学习·矩阵
martian66512 分钟前
【人工智能数学基础篇】——深入详解多变量微积分:在机器学习模型中优化损失函数时应用
人工智能·机器学习·微积分·数学基础
人机与认知实验室1 小时前
人、机、环境中各有其神经网络系统
人工智能·深度学习·神经网络·机器学习
LucianaiB1 小时前
探索CSDN博客数据:使用Python爬虫技术
开发语言·爬虫·python
黑色叉腰丶大魔王1 小时前
基于 MATLAB 的图像增强技术分享
图像处理·人工智能·计算机视觉
PieroPc3 小时前
Python 写的 智慧记 进销存 辅助 程序 导入导出 excel 可打印
开发语言·python·excel
迅易科技4 小时前
借助腾讯云质检平台的新范式,做工业制造企业质检的“AI慧眼”
人工智能·视觉检测·制造
古希腊掌管学习的神5 小时前
[机器学习]XGBoost(3)——确定树的结构
人工智能·机器学习
ZHOU_WUYI6 小时前
4.metagpt中的软件公司智能体 (ProjectManager 角色)
人工智能·metagpt
靴子学长6 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp