pytorch evaluate model(torch.no_grad() and model.eval())

  • 使用 torch.no_grad():这是一个上下文管理器(context manager),用于暂时禁用在其作用域内的所有计算的梯度计算。这在模型评估阶段非常有用,因为它可以减少内存消耗并提高计算效率,因为验证过程中不需要计算梯度信息。

  • 调用 model.eval():这将模型设置为评估模式。在这种模式下,模型中的某些层(如批量归一化层 BatchNorm 和 dropout 层)会改变其行为,以适应评估(例如,BatchNorm 层会使用在训练时收集的运行时统计数据,而 dropout 层会关闭)。

py 复制代码
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

def evaluate_model(model, data_loader, device):
    """
    对给定的模型和数据加载器进行验证。
    
    参数:
    - model: 要验证的PyTorch模型。
    - data_loader: 数据的PyTorch DataLoader。
    - device: 用于模型和数据的设备('cuda' 或 'cpu')。
    """
    # 将模型设置为评估模式,这会关闭dropout和batch normalization层的训练行为
    model.eval()
    
    # 初始化度量指标
    total_correct = 0
    total_samples = 0
    
    # 使用torch.no_grad()上下文管理器来禁用梯度计算
    with torch.no_grad():
       # 遍历数据加载器中的所有批次
        for inputs, targets in data_loader:
            # 将数据移动到指定的设备
            inputs, targets = inputs.to(device), targets.to(device)
            
            # 前向传播,获取模型输出
            outputs = model(inputs)
            
            # 计算预测结果
            _, predicted = torch.max(outputs, 1)
            
            # 计算准确度
            correct = (predicted == targets).sum().item()
            
            total_correct += correct
            total_samples += targets.size(0)
    
    # 计算总体准确度
    accuracy = total_correct / total_samples
    
    return accuracy
相关推荐
天云数据2 分钟前
我把小某薯运营做成了一个Agent系统
人工智能
会飞的老朱5 分钟前
活动 | AI重构协同办公 九思软件以技术创新赋能企业高质量发展
人工智能·oa协同办公·智能办公平台
2402_854808376 分钟前
c++怎么利用std--span在不拷贝的情况下解析大规模文件映射【进阶】
jvm·数据库·python
2501_948114248 分钟前
2026旗舰模型四强争霸:GPT-5.4、Claude Opus 4.6、Gemini 3.1 Pro与Grok 4.20深度横评
人工智能·gpt·ai·谷歌
大模型备案@虎虎8 分钟前
海珠区第四批大模型备案奖励启动:以合规技术激励,夯实 AI 产业安全底座
人工智能·大模型备案·大模型备案奖励政策·大模型备案流程·生成式人工智能服务备案·大模型备案模板·大语言模型备案
IDZSY04309 分钟前
2026 年 AI 社交发展趋势:Agent 社交将成主流
人工智能
2301_777599379 分钟前
Redis怎样管理16384个哈希槽_利用cluster-config-file持久化保存节点与槽位的映射关系
jvm·数据库·python
qq_3422958211 分钟前
Go语言怎么用GitHub Actions_Go语言GitHub Actions教程【基础】.txt
jvm·数据库·python
慧一居士12 分钟前
AI 领域MaaS平台介绍
人工智能
Wyz2012102413 分钟前
如何利用虚拟 DOM 实现无痕刷新?基于 VNode 对比的状态保持技巧
jvm·数据库·python