机器学习 - 预测训练模型

接着上篇博客机器学习-训练模型做进一步说明。

There are three things to make predictions (also called performing inference) with a PyTorch model:

  1. Set the model in evaluation mode (model.eval())
  2. Make the predictions using the inference mode context manager (with torch.inference_mode(): ...)
  3. All predictions should be made with objects on the same device (e.g. data and model on GPU only or data and model on CPU only).

The first two items make sure all helpful calculations and settings PyTorch uses behind the scenes during training but aren't necessary for inference are turned off (this results in faster computation). And the third ensures that you won't run into cross-device errors.


下面代码片段是连接之前的博客

python 复制代码
import torch

# 1. Set the model in evaluation mode 
model_0.eval() 

# 2. Setup the inference mode context manager
with torch.inference_mode():
  # 3. Make sure the calculations are done with the model and data on the same device
  y_preds = model_0(X_test)

print(y_preds)

plot_predictions(predictions=y_preds)

# 结果如下
tensor([[0.8685],
        [0.8825],
        [0.8965],
        [0.9105],
        [0.9245],
        [0.9384],
        [0.9524],
        [0.9664],
        [0.9804],
        [0.9944]])

在下图,能看到预测点跟测试点很靠近,这结果挺理想的

这里稍微介绍一下 torch.inference_mode()

torch.inference.mode() 是一个上下文管理器,用于控制推断模式下的模型行为。在深度学习中,模型在训练和推断 (或称为预测) 阶段有不同的行为。在推断阶段,通常不需要计算梯度,也不需要跟踪计算图,这样可以提高推断速度并减少内存占用。torch.inference_mode() 上下文管理器就是为了控制模型在推断阶段的行为。

当进入torch.inference_mode() 上下文环境时,PyTorch会关闭梯度跟踪,并且禁用自动微分机制。这意味着在此环境中,无法调用backward()方法计算梯度,也无法通过梯度进行参数更新。这样可以确保模型在推断阶段不会意外地计算梯度,提高了推断的速度和效率。


都看到这里,点个赞支持一下呗~

相关推荐
冷色系里的一抹暖调13 小时前
OpenClaw Docker 部署避坑指南:服务启动成功但网页打不开?
人工智能·windows·docker·ai·容器·opencode
沪漂阿龙13 小时前
卷积神经网络(CNN)零基础通关指南:原理、图解与PyTorch实战
人工智能·pytorch·cnn
Data-Miner13 小时前
54页可编辑PPT | 数据中台建设方案汇报
大数据·人工智能
语戚13 小时前
深度解析:Stable Diffusion 底层原理 + U-Net Denoise 去噪机制全拆解
人工智能·ai·stable diffusion·aigc·模型
卡梅德生物科技小能手13 小时前
整合素家族核心靶点解析:CD51(Integrin αv)的分子机制与药物研发技术前瞻
经验分享·深度学习·生活
舒一笑13 小时前
AI 时代最火的新岗位,不是提示词工程师,而是 Harness 工程师
人工智能·程序员·设计
明月醉窗台13 小时前
[jetson] AGX Xavier 安装Ubuntu18.04及jetpack4.5
人工智能·算法·nvidia·cuda·jetson
青稞社区.13 小时前
从最基础的模型出发,深度剖析高性能 VLA 的设计空间
人工智能·agi
weixin_5134499613 小时前
walk_these_ways项目学习记录第九篇(通过行为多样性 (MoB) 实现地形泛化)--学习算法
学习·算法·机器学习
夜猫逐梦13 小时前
【AI】 Claude Code 源码泄露:一场关于安全与学习的风波
人工智能·安全·claude code·源码泄漏