【Pytorch】model.eval()与model.train()

model.train():

作用是启用Batch Normalization 和 Dropout

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

model.eval():

如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在测试集上进行测试之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

相关推荐
伊织code4 小时前
PyTorch API 5 - 全分片数据并行、流水线并行、概率分布
pytorch·python·ai·api·-·5
CM莫问5 小时前
<论文>(微软)避免推荐域外物品:基于LLM的受限生成式推荐
人工智能·算法·大模型·推荐算法·受限生成
康谋自动驾驶6 小时前
康谋分享 | 自动驾驶仿真进入“标准时代”:aiSim全面对接ASAM OpenX
人工智能·科技·算法·机器学习·自动驾驶·汽车
深蓝学院7 小时前
密西根大学新作——LightEMMA:自动驾驶中轻量级端到端多模态模型
人工智能·机器学习·自动驾驶
归去_来兮7 小时前
人工神经网络(ANN)模型
人工智能·机器学习·人工神经网络
2201_754918417 小时前
深入理解卷积神经网络:从基础原理到实战应用
人工智能·神经网络·cnn
强盛小灵通专卖员7 小时前
DL00219-基于深度学习的水稻病害检测系统含源码
人工智能·深度学习·水稻病害
Luke Ewin8 小时前
CentOS7.9部署FunASR实时语音识别接口 | 部署商用级别实时语音识别接口FunASR
人工智能·语音识别·实时语音识别·商用级别实时语音识别
Joern-Lee8 小时前
初探机器学习与深度学习
人工智能·深度学习·机器学习
云卓SKYDROID8 小时前
无人机数据处理与特征提取技术分析!
人工智能·科技·无人机·科普·云卓科技