现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络

4、训练函数

4.1 调用训练函数

python 复制代码
train(epochs, net, train_loader, device, optimizer, test_loader, true_value)

因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的参数,训练函数都需要

这七个参数,是训练一个神经网络所需要的最少参数

4.2 训练函数

训练函数中,所有训练集进行多次迭代,而每次迭代又会将数据分成多个批次进行迭代

python 复制代码
def train(epochs, net, train_loader, device, optimizer, test_loader, true_value):
    for epoch in range(1, epochs + 1):
        net.train()
        all_train_loss = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            output = net(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            cur_train_loss = loss.item()
            all_train_loss.append(cur_train_loss)
        train_loss = np.round(np.mean(all_train_loss) * 1000, 2)
        print('\nepoch step:', epoch)
        print('training loss: ', train_loss)
        test(net, test_loader, device, true_value, epoch)
    print("\nTraining finished")
  1. 定义训练函数
  2. 安装epochs迭代数据
  3. 进入pytorch的训练模式
  4. all_train_loss 存放训练集5万张图片的损失值
  5. 按照batch取数据
  6. 数据进入GPU
  7. 标签进入GPU
  8. 梯度清零
  9. 当前batch进入网络后得到输出
  10. 根据输出得到当前损失
  11. 反向传播
  12. 梯度下降
  13. 获取损失的损失值(PyTorch框架中的数据)
  14. 把当前batch的损失加入all_train_loss数组中,结束batch的迭代
  15. 将5张图片的损失计算出来并且进行求平均,这里乘以1000是因为我觉得计算出的损失太小了,所以乘以1000,方便看损失的变化,保留两位有效数字
  16. 打印当前epoch
  17. 打印损失
  18. 调用测试函数,测试当前训练的网络的性能,结束epoch的迭代
  19. 打印训练完成

5、LeNet

5.1 网络结构

LeNet可以说是首次提出卷积神经网络的模型

主要包含下面的网络层:

  1. 5*5的二维卷积
  2. sigmoid激活函数(这里使用了relu)
  3. 5*5的二维卷积
  4. sigmoid激活函数
  5. 数据一维化
  6. 全连接层
  7. 全连接层
  8. softmax分类器

将网络结构打印出来:

LeNet(

-------(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))

-------(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))

-------(conv2_drop): Dropout2d(p=0.5, inplace=False)

-------(fc1): Linear(in_features=320, out_features=50, bias=True)

-------(fc2): Linear(in_features=50, out_features=10, bias=True)

)

5.2 PyTorch构建LeNet

python 复制代码
class LeNet(nn.Module):
    def __init__(self, num_classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

这个时候已经是一个完整的项目了,看看10个epoch训练过程的打印:

D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py
Reading data...

train_data: (60000, 28, 28) train_label (60000,)

test_data: (10000, 28, 28) test_label (10000,)
Initialize neural network

test loss: 2301.68

test accuracy: 11.3 %
epoch step: 1

training loss: 634.74

test loss: 158.03

test accuracy: 95.29 %
epoch step: 2

training loss: 324.04

test loss: 107.62

test accuracy: 96.55 %
epoch step: 3

training loss: 271.25

test loss: 88.43

test accuracy: 97.04 %
epoch step: 4

training loss: 236.69

test loss: 70.94

test accuracy: 97.61 %
epoch step: 5

training loss: 211.05

test loss: 69.69

test accuracy: 97.72 %
epoch step: 6

training loss: 199.28

test loss: 62.04

test accuracy: 97.98 %
epoch step: 7

training loss: 187.11

test loss: 59.65

test accuracy: 97.98 %
epoch step: 8

training loss: 178.79

test loss: 53.89

test accuracy: 98.2 %
epoch step: 9

training loss: 168.75

test loss: 51.83

test accuracy: 98.43 %
epoch step: 10

training loss: 160.83

test loss: 50.35

test accuracy: 98.4 %
Training finished

进程已结束,退出代码为 0

可以看出基本上只要一个epoch就可以得到很好的训练效果了,后续的epoch中的提升比较小

相关推荐
永霖光电_UVLED28 分钟前
IVWorks率先将8英寸GaN纳米线片商业化
人工智能·神经网络·生成对抗网络
如何原谅奋力过但无声1 小时前
TensorFlow 2.x常用函数总结(持续更新)
人工智能·python·tensorflow
chao1898441 小时前
多光谱图像融合:IHS、PCA与小波变换的MATLAB实现
图像处理·计算机视觉·matlab
qyresearch_1 小时前
大语言模型训推一体机:AI算力革命的“新引擎”,2031年市场规模突破123亿的黄金赛道
人工智能·语言模型·自然语言处理
计算机小手2 小时前
使用 llama.cpp 在本地高效运行大语言模型,支持 Docker 一键启动,兼容CPU与GPU
人工智能·经验分享·docker·语言模型·开源软件
短视频矩阵源码定制2 小时前
矩阵系统哪个好?2025年全方位选型指南与品牌深度解析
java·人工智能·矩阵·架构·aigc
java1234_小锋2 小时前
[免费]基于Python的Flask酒店客房管理系统【论文+源码+SQL脚本】
开发语言·人工智能·python·flask·酒店客房
hakuii2 小时前
SVD分解后的各个矩阵的深层理解
人工智能·机器学习·矩阵
这张生成的图像能检测吗2 小时前
(论文速读)基于图像堆栈的低频超宽带SAR叶簇隐蔽目标变化检测
图像处理·人工智能·深度学习·机器学习·信号处理·雷达·变化检测
leijiwen2 小时前
城市本地生活实体零售可信数据空间 RWA 平台方案
人工智能·生活·零售