DAY 40 训练和测试的规范写法

@浙大疏锦行https://blog.csdn.net/weixin_45655710
知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

今天代码中训练 (train) 和测试 (test) 函数的规范写法。这套代码框架是PyTorch深度学习项目的基石,理解了它,未来接触任何复杂的模型和任务,其核心训练逻辑都是万变不离其宗的。

一个通俗的比喻来理解:把整个过程想象成"学生(模型)备考(训练)和参加期末考(测试)"

核心框架概览

我们的代码主要由两个核心部分组成:

  1. train 函数 :学生在一个学期内(一个epoch)反复做练习题(train_loader中的数据)的过程。
  2. test 函数 :学期结束后,用一套全新的模拟卷(test_loader中的数据)来检验学生的真实水平。

主程序 (if __name__ == "__main__":) 则扮演**"教务处"**的角色,负责安排学期总数 (epochs),并协调"训练"和"测试"的进行。

一、 train 函数解析:学生的学习过程

def train(model, train_loader, ...)

这个函数的目标是让模型 (model) 通过学习训练数据 (train_loader) 来不断更新自己的知识(权重参数)。

它的内部逻辑可以分为两层循环:

外层循环:for epoch in range(epochs): (一个学期)
  • epoch 代表一个完整的学习周期,我们称之为"轮次"。在一轮中,学生(模型)会把所有的练习册(整个train_dataset)从头到尾做一遍。
  • model.train() :在每个学期开始时,学生要告诉自己:"现在是学习时间! " 这会开启一些只在学习时才用的"超能力",比如 Dropout (为了防止死记硬背而故意忘掉一些东西)和 BatchNorm(一种让学习更稳定的技巧)。
内层循环:for batch_idx, (data, target) in enumerate(train_loader): (做一页练习题)
  • train_loader 像是一本很厚的练习册,它被分成了很多页,每一页就是一批 (batch) 数据。
  • 这个循环就是学生一页一页地做练习题的过程。data是这一页的题目(图像),target是标准答案(标签)。
  • data, target = data.to(device), target.to(device): 把这一页练习题和答案都拿到"大脑"(GPU)里去处理,速度更快。

做一页练习题的核心四步曲:

  1. optimizer.zero_grad() (清空草稿纸):在做新一页题前,先把上一页的计算草稿(梯度)擦干净。PyTorch默认会累积梯度,所以每次都必须手动清零。
  2. output = model(data) (做题) :学生(模型)根据自己当前的知识水平,对这页的题目(data)给出自己的答案(output)。
  3. loss = criterion(output, target) (对答案并计算差距)criterion (损失函数) 就像一个评分老师,它会比较学生的答案(output)和标准答案(target),然后计算出一个差距值(loss 。差距越大,loss值也越大。
  4. loss.backward() (反思总结) :这是最神奇的一步,也叫反向传播 。学生根据差距 (loss),反思自己知识体系里的每一个知识点(模型参数)对这次做错题的"责任"有多大。这个"责任"就是梯度
  5. optimizer.step() (修正知识)optimizer (优化器) 像一个学习方法指导老师,它根据每个知识点的"责任"(梯度),告诉学生该如何去调整、更新自己的知识(模型参数),以便下次能做得更好。

二、 test 函数解析:学生的期末考试

def test(model, test_loader, ...)

这个函数的目标是检验模型在从未见过的新数据上的表现,以评估其真实的泛化能力。

它的核心逻辑如下:

  1. model.eval() (进入考试模式) :在考试前,学生要告诉自己:"现在是考试时间! " 这会关闭那些只在学习时才用的"超能力"(如Dropout和BatchNorm),确保每次考试的结果都是稳定、一致的。这是至关重要的一步。
  2. with torch.no_grad(): (收起草稿纸,只答题不学习):这个代码块告诉PyTorch:"接下来只进行计算,不需要记录任何'反思过程'(梯度)"。这能大大加快计算速度,并节省显存,因为考试时不需要再学习了。
  3. 循环与计算
    • 它会遍历测试题库 (test_loader) 中的每一批数据。
    • output = model(data) (做题):学生用自己最终学到的知识来解答这些全新的题目。
    • correct += ... (计分):将学生的答案与标准答案进行比较,统计做对的总题数。
  4. 返回结果 :最终计算出总的平均损失准确率,作为这次期末考的最终成绩。

总结:一条清晰的逻辑线

将整个流程想象成一个高度自动化、目标明确的"智能教育系统":

  1. 数据准备 (DataLoader):系统将海量的练习题和模拟卷整理成册,分门别类。
  2. 模型定义 (nn.Module):我们设计了一个"学生"的大脑结构。
  3. 主流程( if __name__ == "__main__": )
    • "教务处"宣布:"本学期共 epochs 轮学习!"
    • 进入每一轮学习 (for epoch in ...)
      • 首先,命令学生(模型)进入学习状态 (model.train()) ,并开始做一整本练习册 (train函数)
      • 做完练习册后,为了检验本轮学习效果,立刻命令学生进入考试状态 (model.eval()) ,做一套期末模拟卷 (test函数),并公布成绩。
    • 所有学期结束后,整个培养计划完成。

这个"训练一轮,测试一轮"的循环框架,是深度学习项目中最核心、最通用的代码结构。掌握了它,就掌握了驱动所有复杂模型进行学习和评估的"引擎"。

相关推荐
studytosky2 小时前
深度学习理论与实战:MNIST 手写数字分类实战
人工智能·pytorch·python·深度学习·机器学习·分类·matplotlib
哥布林学者3 小时前
吴恩达深度学习课程三: 结构化机器学习项目 第一周:机器学习策略(二)数据集设置
深度学习·ai
【建模先锋】4 小时前
精品数据分享 | 锂电池数据集(四)PINN+锂离子电池退化稳定性建模和预测
深度学习·预测模型·pinn·锂电池剩余寿命预测·锂电池数据集·剩余寿命
九年义务漏网鲨鱼4 小时前
【大模型学习】现代大模型架构(二):旋转位置编码和SwiGLU
深度学习·学习·大模型·智能体
CoovallyAIHub4 小时前
破局红外小目标检测:异常感知Anomaly-Aware YOLO以“俭”驭“繁”
深度学习·算法·计算机视觉
云雾J视界5 小时前
AI芯片设计实战:用Verilog高级综合技术优化神经网络加速器功耗与性能
深度学习·神经网络·verilog·nvidia·ai芯片·卷积加速器
噜~噜~噜~14 小时前
最大熵原理(Principle of Maximum Entropy,MaxEnt)的个人理解
深度学习·最大熵原理
小女孩真可爱15 小时前
大模型学习记录(五)-------调用大模型API接口
pytorch·深度学习·学习
水月wwww19 小时前
深度学习——神经网络
人工智能·深度学习·神经网络
青瓷程序设计19 小时前
花朵识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习