-
使用
torch.no_grad()
:这是一个上下文管理器(context manager),用于暂时禁用在其作用域内的所有计算的梯度计算。这在模型评估阶段非常有用,因为它可以减少内存消耗并提高计算效率,因为验证过程中不需要计算梯度信息。 -
调用
model.eval()
:这将模型设置为评估模式。在这种模式下,模型中的某些层(如批量归一化层 BatchNorm 和 dropout 层)会改变其行为,以适应评估(例如,BatchNorm 层会使用在训练时收集的运行时统计数据,而 dropout 层会关闭)。
py
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
def evaluate_model(model, data_loader, device):
"""
对给定的模型和数据加载器进行验证。
参数:
- model: 要验证的PyTorch模型。
- data_loader: 数据的PyTorch DataLoader。
- device: 用于模型和数据的设备('cuda' 或 'cpu')。
"""
# 将模型设置为评估模式,这会关闭dropout和batch normalization层的训练行为
model.eval()
# 初始化度量指标
total_correct = 0
total_samples = 0
# 使用torch.no_grad()上下文管理器来禁用梯度计算
with torch.no_grad():
# 遍历数据加载器中的所有批次
for inputs, targets in data_loader:
# 将数据移动到指定的设备
inputs, targets = inputs.to(device), targets.to(device)
# 前向传播,获取模型输出
outputs = model(inputs)
# 计算预测结果
_, predicted = torch.max(outputs, 1)
# 计算准确度
correct = (predicted == targets).sum().item()
total_correct += correct
total_samples += targets.size(0)
# 计算总体准确度
accuracy = total_correct / total_samples
return accuracy