DAY 40 训练和测试的规范写法

@浙大疏锦行https://blog.csdn.net/weixin_45655710
知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

今天代码中训练 (train) 和测试 (test) 函数的规范写法。这套代码框架是PyTorch深度学习项目的基石,理解了它,未来接触任何复杂的模型和任务,其核心训练逻辑都是万变不离其宗的。

一个通俗的比喻来理解:把整个过程想象成"学生(模型)备考(训练)和参加期末考(测试)"

核心框架概览

我们的代码主要由两个核心部分组成:

  1. train 函数 :学生在一个学期内(一个epoch)反复做练习题(train_loader中的数据)的过程。
  2. test 函数 :学期结束后,用一套全新的模拟卷(test_loader中的数据)来检验学生的真实水平。

主程序 (if __name__ == "__main__":) 则扮演**"教务处"**的角色,负责安排学期总数 (epochs),并协调"训练"和"测试"的进行。

一、 train 函数解析:学生的学习过程

def train(model, train_loader, ...)

这个函数的目标是让模型 (model) 通过学习训练数据 (train_loader) 来不断更新自己的知识(权重参数)。

它的内部逻辑可以分为两层循环:

外层循环:for epoch in range(epochs): (一个学期)
  • epoch 代表一个完整的学习周期,我们称之为"轮次"。在一轮中,学生(模型)会把所有的练习册(整个train_dataset)从头到尾做一遍。
  • model.train() :在每个学期开始时,学生要告诉自己:"现在是学习时间! " 这会开启一些只在学习时才用的"超能力",比如 Dropout (为了防止死记硬背而故意忘掉一些东西)和 BatchNorm(一种让学习更稳定的技巧)。
内层循环:for batch_idx, (data, target) in enumerate(train_loader): (做一页练习题)
  • train_loader 像是一本很厚的练习册,它被分成了很多页,每一页就是一批 (batch) 数据。
  • 这个循环就是学生一页一页地做练习题的过程。data是这一页的题目(图像),target是标准答案(标签)。
  • data, target = data.to(device), target.to(device): 把这一页练习题和答案都拿到"大脑"(GPU)里去处理,速度更快。

做一页练习题的核心四步曲:

  1. optimizer.zero_grad() (清空草稿纸):在做新一页题前,先把上一页的计算草稿(梯度)擦干净。PyTorch默认会累积梯度,所以每次都必须手动清零。
  2. output = model(data) (做题) :学生(模型)根据自己当前的知识水平,对这页的题目(data)给出自己的答案(output)。
  3. loss = criterion(output, target) (对答案并计算差距)criterion (损失函数) 就像一个评分老师,它会比较学生的答案(output)和标准答案(target),然后计算出一个差距值(loss 。差距越大,loss值也越大。
  4. loss.backward() (反思总结) :这是最神奇的一步,也叫反向传播 。学生根据差距 (loss),反思自己知识体系里的每一个知识点(模型参数)对这次做错题的"责任"有多大。这个"责任"就是梯度
  5. optimizer.step() (修正知识)optimizer (优化器) 像一个学习方法指导老师,它根据每个知识点的"责任"(梯度),告诉学生该如何去调整、更新自己的知识(模型参数),以便下次能做得更好。

二、 test 函数解析:学生的期末考试

def test(model, test_loader, ...)

这个函数的目标是检验模型在从未见过的新数据上的表现,以评估其真实的泛化能力。

它的核心逻辑如下:

  1. model.eval() (进入考试模式) :在考试前,学生要告诉自己:"现在是考试时间! " 这会关闭那些只在学习时才用的"超能力"(如Dropout和BatchNorm),确保每次考试的结果都是稳定、一致的。这是至关重要的一步。
  2. with torch.no_grad(): (收起草稿纸,只答题不学习):这个代码块告诉PyTorch:"接下来只进行计算,不需要记录任何'反思过程'(梯度)"。这能大大加快计算速度,并节省显存,因为考试时不需要再学习了。
  3. 循环与计算
    • 它会遍历测试题库 (test_loader) 中的每一批数据。
    • output = model(data) (做题):学生用自己最终学到的知识来解答这些全新的题目。
    • correct += ... (计分):将学生的答案与标准答案进行比较,统计做对的总题数。
  4. 返回结果 :最终计算出总的平均损失准确率,作为这次期末考的最终成绩。

总结:一条清晰的逻辑线

将整个流程想象成一个高度自动化、目标明确的"智能教育系统":

  1. 数据准备 (DataLoader):系统将海量的练习题和模拟卷整理成册,分门别类。
  2. 模型定义 (nn.Module):我们设计了一个"学生"的大脑结构。
  3. 主流程( if __name__ == "__main__": )
    • "教务处"宣布:"本学期共 epochs 轮学习!"
    • 进入每一轮学习 (for epoch in ...)
      • 首先,命令学生(模型)进入学习状态 (model.train()) ,并开始做一整本练习册 (train函数)
      • 做完练习册后,为了检验本轮学习效果,立刻命令学生进入考试状态 (model.eval()) ,做一套期末模拟卷 (test函数),并公布成绩。
    • 所有学期结束后,整个培养计划完成。

这个"训练一轮,测试一轮"的循环框架,是深度学习项目中最核心、最通用的代码结构。掌握了它,就掌握了驱动所有复杂模型进行学习和评估的"引擎"。

相关推荐
边缘常驻民21 分钟前
PyTorch深度学习入门记录3
人工智能·pytorch·深度学习
a1504631 小时前
人工智能——图像梯度处理、边缘检测、绘制图像轮廓、凸包特征检测
人工智能·深度学习·计算机视觉
格林威5 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现卫星图像识别(C#代码,UI界面版)
人工智能·深度学习·数码相机·yolo·计算机视觉
码字的字节18 小时前
深度学习损失函数的设计哲学:从交叉熵到Huber损失的深入探索
深度学习·交叉熵·huber
凪卄121318 小时前
图像预处理 二
人工智能·python·深度学习·计算机视觉·pycharm
碳酸的唐18 小时前
Inception网络架构:深度学习视觉模型的里程碑
网络·深度学习·架构
AI赋能18 小时前
自动驾驶训练-tub详解
人工智能·深度学习·自动驾驶
seasonsyy18 小时前
1.安装anaconda详细步骤(含安装截图)
python·深度学习·环境配置
deephub18 小时前
AI代理性能提升实战:LangChain+LangGraph内存管理与上下文优化完整指南
人工智能·深度学习·神经网络·langchain·大语言模型·rag
go546315846519 小时前
基于深度学习的食管癌右喉返神经旁淋巴结预测系统研究
图像处理·人工智能·深度学习·神经网络·算法