model.train()
和 model.eval()
是 PyTorch 中常用的两个方法,用于切换模型的模式(training/evaluation)。它们的主要目的是在训练和评估过程中设置模型的行为,使其根据不同阶段进行合适的计算,特别是涉及一些特定层的行为差异(如 Dropout
和 BatchNorm
层)。以下是它们的详细介绍:
1. model.train()
model.train()
将模型设置为"训练模式"(training mode)。在调用此方法后,模型内部的各个层会自动调整到训练所需的状态。
-
关键影响层:
Dropout
:在训练模式下,Dropout
会随机丢弃一些神经元,以增加模型的泛化能力,减少过拟合。BatchNorm
:BatchNorm
会根据当前批次数据计算均值和方差,并更新内部的运行均值和方差,以逐步累积整体数据的统计信息。
-
使用场景 :训练模型时调用。每次开始训练循环之前,调用
model.train()
以确保模型处于正确的训练状态。 -
代码示例:
2. model.eval()
model.eval()
将模型设置为"评估模式"(evaluation mode)。在此模式下,模型会调整为适合推理或验证的状态。
-
关键影响层:
Dropout
:在评估模式下,Dropout
层会停用,不再随机丢弃神经元,确保每次前向传播都得到相同的结果。BatchNorm
:BatchNorm
层会使用训练期间累积的均值和方差,而不是当前批次的统计信息,以确保推理结果的稳定性。
-
使用场景:在验证或测试阶段,或者进行模型推理时调用。评估模式能确保模型在这些阶段的行为一致,并且减少不必要的计算负担。
-
代码示例:
model.eval() # 切换到评估模式
with torch.no_grad(): # 禁用梯度计算,节省内存
for data, target in test_loader:
output = model(data)
test_loss += loss_fn(output, target).item()
3. 注意事项
- 作用范围 :
model.train()
和model.eval()
对模型及其所有子模块有效,所有层都会递归切换模式。 - 与
torch.no_grad()
配合使用 :在评估模式下通常会使用with torch.no_grad()
禁用梯度计算,以减少内存占用和加速计算。model.eval()
本身并不会禁用梯度计算,二者需要配合使用。
总结
model.train()
:在训练时调用,适用于调整模型以适应训练的行为,如随机Dropout
和动态BatchNorm
。model.eval()
:在评估或推理时调用,确保推理的稳定性,Dropout
停用,BatchNorm
使用训练时的统计数据。