DAY 40 训练和测试的规范写法

@浙大疏锦行https://blog.csdn.net/weixin_45655710
知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

今天代码中训练 (train) 和测试 (test) 函数的规范写法。这套代码框架是PyTorch深度学习项目的基石,理解了它,未来接触任何复杂的模型和任务,其核心训练逻辑都是万变不离其宗的。

一个通俗的比喻来理解:把整个过程想象成"学生(模型)备考(训练)和参加期末考(测试)"

核心框架概览

我们的代码主要由两个核心部分组成:

  1. train 函数 :学生在一个学期内(一个epoch)反复做练习题(train_loader中的数据)的过程。
  2. test 函数 :学期结束后,用一套全新的模拟卷(test_loader中的数据)来检验学生的真实水平。

主程序 (if __name__ == "__main__":) 则扮演**"教务处"**的角色,负责安排学期总数 (epochs),并协调"训练"和"测试"的进行。

一、 train 函数解析:学生的学习过程

def train(model, train_loader, ...)

这个函数的目标是让模型 (model) 通过学习训练数据 (train_loader) 来不断更新自己的知识(权重参数)。

它的内部逻辑可以分为两层循环:

外层循环:for epoch in range(epochs): (一个学期)
  • epoch 代表一个完整的学习周期,我们称之为"轮次"。在一轮中,学生(模型)会把所有的练习册(整个train_dataset)从头到尾做一遍。
  • model.train() :在每个学期开始时,学生要告诉自己:"现在是学习时间! " 这会开启一些只在学习时才用的"超能力",比如 Dropout (为了防止死记硬背而故意忘掉一些东西)和 BatchNorm(一种让学习更稳定的技巧)。
内层循环:for batch_idx, (data, target) in enumerate(train_loader): (做一页练习题)
  • train_loader 像是一本很厚的练习册,它被分成了很多页,每一页就是一批 (batch) 数据。
  • 这个循环就是学生一页一页地做练习题的过程。data是这一页的题目(图像),target是标准答案(标签)。
  • data, target = data.to(device), target.to(device): 把这一页练习题和答案都拿到"大脑"(GPU)里去处理,速度更快。

做一页练习题的核心四步曲:

  1. optimizer.zero_grad() (清空草稿纸):在做新一页题前,先把上一页的计算草稿(梯度)擦干净。PyTorch默认会累积梯度,所以每次都必须手动清零。
  2. output = model(data) (做题) :学生(模型)根据自己当前的知识水平,对这页的题目(data)给出自己的答案(output)。
  3. loss = criterion(output, target) (对答案并计算差距)criterion (损失函数) 就像一个评分老师,它会比较学生的答案(output)和标准答案(target),然后计算出一个差距值(loss 。差距越大,loss值也越大。
  4. loss.backward() (反思总结) :这是最神奇的一步,也叫反向传播 。学生根据差距 (loss),反思自己知识体系里的每一个知识点(模型参数)对这次做错题的"责任"有多大。这个"责任"就是梯度
  5. optimizer.step() (修正知识)optimizer (优化器) 像一个学习方法指导老师,它根据每个知识点的"责任"(梯度),告诉学生该如何去调整、更新自己的知识(模型参数),以便下次能做得更好。

二、 test 函数解析:学生的期末考试

def test(model, test_loader, ...)

这个函数的目标是检验模型在从未见过的新数据上的表现,以评估其真实的泛化能力。

它的核心逻辑如下:

  1. model.eval() (进入考试模式) :在考试前,学生要告诉自己:"现在是考试时间! " 这会关闭那些只在学习时才用的"超能力"(如Dropout和BatchNorm),确保每次考试的结果都是稳定、一致的。这是至关重要的一步。
  2. with torch.no_grad(): (收起草稿纸,只答题不学习):这个代码块告诉PyTorch:"接下来只进行计算,不需要记录任何'反思过程'(梯度)"。这能大大加快计算速度,并节省显存,因为考试时不需要再学习了。
  3. 循环与计算
    • 它会遍历测试题库 (test_loader) 中的每一批数据。
    • output = model(data) (做题):学生用自己最终学到的知识来解答这些全新的题目。
    • correct += ... (计分):将学生的答案与标准答案进行比较,统计做对的总题数。
  4. 返回结果 :最终计算出总的平均损失准确率,作为这次期末考的最终成绩。

总结:一条清晰的逻辑线

将整个流程想象成一个高度自动化、目标明确的"智能教育系统":

  1. 数据准备 (DataLoader):系统将海量的练习题和模拟卷整理成册,分门别类。
  2. 模型定义 (nn.Module):我们设计了一个"学生"的大脑结构。
  3. 主流程( if __name__ == "__main__": )
    • "教务处"宣布:"本学期共 epochs 轮学习!"
    • 进入每一轮学习 (for epoch in ...)
      • 首先,命令学生(模型)进入学习状态 (model.train()) ,并开始做一整本练习册 (train函数)
      • 做完练习册后,为了检验本轮学习效果,立刻命令学生进入考试状态 (model.eval()) ,做一套期末模拟卷 (test函数),并公布成绩。
    • 所有学期结束后,整个培养计划完成。

这个"训练一轮,测试一轮"的循环框架,是深度学习项目中最核心、最通用的代码结构。掌握了它,就掌握了驱动所有复杂模型进行学习和评估的"引擎"。

相关推荐
胡耀超5 小时前
DataOceanAI Dolphin(ffmpeg音频转化教程) 多语言(中国方言)语音识别系统部署与应用指南
python·深度学习·ffmpeg·音视频·语音识别·多模态·asr
HUIMU_6 小时前
DAY12&DAY13-新世纪DL(Deeplearning/深度学习)战士:破(改善神经网络)1
人工智能·深度学习
mit6.8247 小时前
[1Prompt1Story] 注意力机制增强 IPCA | 去噪神经网络 UNet | U型架构分步去噪
人工智能·深度学习·神经网络
Coovally AI模型快速验证7 小时前
YOLO、DarkNet和深度学习如何让自动驾驶看得清?
深度学习·算法·yolo·cnn·自动驾驶·transformer·无人机
科大饭桶8 小时前
昇腾AI自学Day2-- 深度学习基础工具与数学
人工智能·pytorch·python·深度学习·numpy
努力还债的学术吗喽8 小时前
2021 IEEE【论文精读】用GAN让音频隐写术骗过AI检测器 - 对抗深度学习的音频信息隐藏
人工智能·深度学习·生成对抗网络·密码学·音频·gan·隐写
weixin_5079299110 小时前
第G7周:Semi-Supervised GAN 理论与实战
人工智能·pytorch·深度学习
AI波克布林12 小时前
发文暴论!线性注意力is all you need!
人工智能·深度学习·神经网络·机器学习·注意力机制·线性注意力
Blossom.11812 小时前
把 AI 推理塞进「 8 位 MCU 」——0.5 KB RAM 跑通关键词唤醒的魔幻之旅
人工智能·笔记·单片机·嵌入式硬件·深度学习·机器学习·搜索引擎
2502_9271612814 小时前
DAY 40 训练和测试的规范写法
人工智能·深度学习·机器学习