pytorch evaluate model(torch.no_grad() and model.eval())

  • 使用 torch.no_grad():这是一个上下文管理器(context manager),用于暂时禁用在其作用域内的所有计算的梯度计算。这在模型评估阶段非常有用,因为它可以减少内存消耗并提高计算效率,因为验证过程中不需要计算梯度信息。

  • 调用 model.eval():这将模型设置为评估模式。在这种模式下,模型中的某些层(如批量归一化层 BatchNorm 和 dropout 层)会改变其行为,以适应评估(例如,BatchNorm 层会使用在训练时收集的运行时统计数据,而 dropout 层会关闭)。

py 复制代码
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

def evaluate_model(model, data_loader, device):
    """
    对给定的模型和数据加载器进行验证。
    
    参数:
    - model: 要验证的PyTorch模型。
    - data_loader: 数据的PyTorch DataLoader。
    - device: 用于模型和数据的设备('cuda' 或 'cpu')。
    """
    # 将模型设置为评估模式,这会关闭dropout和batch normalization层的训练行为
    model.eval()
    
    # 初始化度量指标
    total_correct = 0
    total_samples = 0
    
    # 使用torch.no_grad()上下文管理器来禁用梯度计算
    with torch.no_grad():
       # 遍历数据加载器中的所有批次
        for inputs, targets in data_loader:
            # 将数据移动到指定的设备
            inputs, targets = inputs.to(device), targets.to(device)
            
            # 前向传播,获取模型输出
            outputs = model(inputs)
            
            # 计算预测结果
            _, predicted = torch.max(outputs, 1)
            
            # 计算准确度
            correct = (predicted == targets).sum().item()
            
            total_correct += correct
            total_samples += targets.size(0)
    
    # 计算总体准确度
    accuracy = total_correct / total_samples
    
    return accuracy
相关推荐
科技圈快讯1 分钟前
破解企业低碳转型难题,港华商会携手碳启元出击
大数据·人工智能
hhzz13 分钟前
【Vision人工智能设计 】ComfyUI 基础文生图设计
人工智能·comfyui·视觉大模型·wan
有Li20 分钟前
用于CBCT到CT合成的纹理保留扩散模型/文献速递-基于人工智能的医学影像技术
论文阅读·人工智能·深度学习·计算机视觉·文献
大模型最新论文速读24 分钟前
NAtS-L: 自适应融合多种注意力架构,推理能力提高 36%
人工智能·深度学习·机器学习·语言模型·自然语言处理
TYFHVB1224 分钟前
11款CRM数字化方案横评:获客-履约-复购全链路能力对决
大数据·人工智能·架构·自动化·流程图
zcbk016836 分钟前
不踩坑!手把手教你在 Mac 上安装 Windows(含分区/虚拟机/驱动解决方案)
python
Dev7z41 分钟前
滚压表面强化过程中变形诱导位错演化与梯度晶粒细化机理的数值模拟研究
人工智能·python·算法
魔乐社区1 小时前
来魔乐,一键获取OpenClaw (原Moltbolt/Clawdbot),告别部署烦恼!
人工智能·开源·agent·clawdbot·openclaw
吴秋霖1 小时前
apple游客下单逆向分析
python·算法·逆向分析
feasibility.2 小时前
yolo11-seg在ISIC2016医疗数据集训练预测流程(含AOP调loss函数方法)
人工智能·python·yolo·计算机视觉·健康医疗·实例分割·isic2016