机器学习 - 预测训练模型

接着上篇博客机器学习-训练模型做进一步说明。

There are three things to make predictions (also called performing inference) with a PyTorch model:

  1. Set the model in evaluation mode (model.eval())
  2. Make the predictions using the inference mode context manager (with torch.inference_mode(): ...)
  3. All predictions should be made with objects on the same device (e.g. data and model on GPU only or data and model on CPU only).

The first two items make sure all helpful calculations and settings PyTorch uses behind the scenes during training but aren't necessary for inference are turned off (this results in faster computation). And the third ensures that you won't run into cross-device errors.


下面代码片段是连接之前的博客

python 复制代码
import torch

# 1. Set the model in evaluation mode 
model_0.eval() 

# 2. Setup the inference mode context manager
with torch.inference_mode():
  # 3. Make sure the calculations are done with the model and data on the same device
  y_preds = model_0(X_test)

print(y_preds)

plot_predictions(predictions=y_preds)

# 结果如下
tensor([[0.8685],
        [0.8825],
        [0.8965],
        [0.9105],
        [0.9245],
        [0.9384],
        [0.9524],
        [0.9664],
        [0.9804],
        [0.9944]])

在下图,能看到预测点跟测试点很靠近,这结果挺理想的

这里稍微介绍一下 torch.inference_mode()

torch.inference.mode() 是一个上下文管理器,用于控制推断模式下的模型行为。在深度学习中,模型在训练和推断 (或称为预测) 阶段有不同的行为。在推断阶段,通常不需要计算梯度,也不需要跟踪计算图,这样可以提高推断速度并减少内存占用。torch.inference_mode() 上下文管理器就是为了控制模型在推断阶段的行为。

当进入torch.inference_mode() 上下文环境时,PyTorch会关闭梯度跟踪,并且禁用自动微分机制。这意味着在此环境中,无法调用backward()方法计算梯度,也无法通过梯度进行参数更新。这样可以确保模型在推断阶段不会意外地计算梯度,提高了推断的速度和效率。


都看到这里,点个赞支持一下呗~

相关推荐
勾股导航几秒前
灰狼优化算法GWO
人工智能·深度学习·机器学习
盼小辉丶5 分钟前
Transformer实战——Transformer跨语言零样本学习
深度学习·transformer·零样本学习
sheyuDemo5 分钟前
关于深度学习的d2l库的安装
人工智能·python·深度学习·机器学习·numpy
政安晨6 分钟前
政安晨【人工智能项目随笔】OpenClaw网关与子节点完整配对指南——从零构建分布式AI助手网络
人工智能·ai网关·openclaw·分布式ai助手网络·openclaw分布式子节点·分布式ai节点·主节点-子节点
shenxianasi6 分钟前
【论文精读】Language Is Not All You Need: Aligning Perceptionwith Language Models
人工智能·机器学习·计算机视觉·语言模型·自然语言处理·vllm·audiolm
这是个栗子9 分钟前
AI辅助编程工具(八) - Baidu Comate
人工智能·ai·baidu comate
Caesar Zou10 分钟前
深度学习14: Adversarial attacks
人工智能·深度学习
SmartBrain12 分钟前
FastAPI 进阶(第二部分):SQLAlchemy ORM(含考题)
数据库·人工智能·aigc·fastapi
向哆哆16 分钟前
道路表面多类型缺陷的图像识别数据集分享(适用于目标检测任务)
人工智能·目标检测·计算机视觉
格林威22 分钟前
Baumer相机药瓶铝盖压合完整性检测:防止密封失效的 7 个关键技术,附 OpenCV+Halcon 实战代码!
人工智能·opencv·计算机视觉·视觉检测·工业相机·智能相机·堡盟相机