机器学习 - 预测训练模型

接着上篇博客机器学习-训练模型做进一步说明。

There are three things to make predictions (also called performing inference) with a PyTorch model:

  1. Set the model in evaluation mode (model.eval())
  2. Make the predictions using the inference mode context manager (with torch.inference_mode(): ...)
  3. All predictions should be made with objects on the same device (e.g. data and model on GPU only or data and model on CPU only).

The first two items make sure all helpful calculations and settings PyTorch uses behind the scenes during training but aren't necessary for inference are turned off (this results in faster computation). And the third ensures that you won't run into cross-device errors.


下面代码片段是连接之前的博客

python 复制代码
import torch

# 1. Set the model in evaluation mode 
model_0.eval() 

# 2. Setup the inference mode context manager
with torch.inference_mode():
  # 3. Make sure the calculations are done with the model and data on the same device
  y_preds = model_0(X_test)

print(y_preds)

plot_predictions(predictions=y_preds)

# 结果如下
tensor([[0.8685],
        [0.8825],
        [0.8965],
        [0.9105],
        [0.9245],
        [0.9384],
        [0.9524],
        [0.9664],
        [0.9804],
        [0.9944]])

在下图,能看到预测点跟测试点很靠近,这结果挺理想的

这里稍微介绍一下 torch.inference_mode()

torch.inference.mode() 是一个上下文管理器,用于控制推断模式下的模型行为。在深度学习中,模型在训练和推断 (或称为预测) 阶段有不同的行为。在推断阶段,通常不需要计算梯度,也不需要跟踪计算图,这样可以提高推断速度并减少内存占用。torch.inference_mode() 上下文管理器就是为了控制模型在推断阶段的行为。

当进入torch.inference_mode() 上下文环境时,PyTorch会关闭梯度跟踪,并且禁用自动微分机制。这意味着在此环境中,无法调用backward()方法计算梯度,也无法通过梯度进行参数更新。这样可以确保模型在推断阶段不会意外地计算梯度,提高了推断的速度和效率。


都看到这里,点个赞支持一下呗~

相关推荐
云安全助手1 分钟前
Anthropic年度报告解读:AI重塑网络攻击形态,传统防御体系亟待升级
人工智能·安全·网络安全·ai大模型
pythonpioneer2 分钟前
PyTorch3D:基于 PyTorch 的高效 3D 深度学习工具库
pytorch·深度学习·其他·3d
谁似人间西林客10 分钟前
汽车智能制造解决方案:如何通过智能仓储物流降本提效?
人工智能·汽车·制造
jiushiapwojdap23 分钟前
Antigravity Awesome Skills:1527+ AI 编程助手的可安装技能库
人工智能·其他
顾北顾24 分钟前
多头注意力机制
人工智能·深度学习·算法
hujinyuan2016039 分钟前
2025年12月中国电子学会青少年机器人技术等级考试试卷(二级) 真题+答案
人工智能·算法·机器人
码农小白AI1 小时前
采购合同与来料证书对标校验,IACheck联动AI报告审核通审Agent版自动识别指标不符单据
人工智能
大江东去浪淘尽千古风流人物1 小时前
【PromptStereo】零样本立体匹配新范式:用结构与运动Prompt驱动迭代优化(CVPR 2026)
深度学习·3d·slam·视觉定位·dust3r·3d重建·mast3r
元岳数字人小元1 小时前
AI 数字人开发公司浅谈 虚拟数字人打造景区新服务
人工智能·人机交互·交互
哦哦~9211 小时前
AI赋能生物医学:从临床数据到药物分子性质预测实战培
人工智能·生物医学·药物分子