机器学习 - 预测训练模型

接着上篇博客机器学习-训练模型做进一步说明。

There are three things to make predictions (also called performing inference) with a PyTorch model:

  1. Set the model in evaluation mode (model.eval())
  2. Make the predictions using the inference mode context manager (with torch.inference_mode(): ...)
  3. All predictions should be made with objects on the same device (e.g. data and model on GPU only or data and model on CPU only).

The first two items make sure all helpful calculations and settings PyTorch uses behind the scenes during training but aren't necessary for inference are turned off (this results in faster computation). And the third ensures that you won't run into cross-device errors.


下面代码片段是连接之前的博客

python 复制代码
import torch

# 1. Set the model in evaluation mode 
model_0.eval() 

# 2. Setup the inference mode context manager
with torch.inference_mode():
  # 3. Make sure the calculations are done with the model and data on the same device
  y_preds = model_0(X_test)

print(y_preds)

plot_predictions(predictions=y_preds)

# 结果如下
tensor([[0.8685],
        [0.8825],
        [0.8965],
        [0.9105],
        [0.9245],
        [0.9384],
        [0.9524],
        [0.9664],
        [0.9804],
        [0.9944]])

在下图,能看到预测点跟测试点很靠近,这结果挺理想的

这里稍微介绍一下 torch.inference_mode()

torch.inference.mode() 是一个上下文管理器,用于控制推断模式下的模型行为。在深度学习中,模型在训练和推断 (或称为预测) 阶段有不同的行为。在推断阶段,通常不需要计算梯度,也不需要跟踪计算图,这样可以提高推断速度并减少内存占用。torch.inference_mode() 上下文管理器就是为了控制模型在推断阶段的行为。

当进入torch.inference_mode() 上下文环境时,PyTorch会关闭梯度跟踪,并且禁用自动微分机制。这意味着在此环境中,无法调用backward()方法计算梯度,也无法通过梯度进行参数更新。这样可以确保模型在推断阶段不会意外地计算梯度,提高了推断的速度和效率。


都看到这里,点个赞支持一下呗~

相关推荐
程序猿追1 分钟前
深度解码昇腾 AI 算力引擎:CANN Runtime 核心架构与技术演进
人工智能·架构
金融RPA机器人丨实在智能1 分钟前
Android Studio开发App项目进入AI深水区:实在智能Agent引领无代码交互革命
android·人工智能·ai·android studio
lili-felicity4 分钟前
CANN异步推理实战:从Stream管理到流水线优化
大数据·人工智能
做人不要太理性5 分钟前
CANN Runtime 运行时组件深度解析:任务下沉执行、异构内存规划与全栈维测诊断机制
人工智能·神经网络·魔珐星云
不爱学英文的码字机器5 分钟前
破壁者:CANN ops-nn 仓库与昇腾 AI 算子优化的工程哲学
人工智能
晚霞的不甘8 分钟前
CANN 编译器深度解析:TBE 自定义算子开发实战
人工智能·架构·开源·音视频
愚公搬代码9 分钟前
【愚公系列】《AI短视频创作一本通》016-AI短视频的生成(AI短视频运镜方法)
人工智能·音视频
哈__9 分钟前
CANN内存管理与资源优化
人工智能·pytorch
极新10 分钟前
智启新篇,智创未来,“2026智造新IP:AI驱动品牌增长新周期”峰会暨北京电子商务协会第五届第三次会员代表大会成功举办
人工智能·网络协议·tcp/ip
island131411 分钟前
CANN GE(图引擎)深度解析:计算图优化管线、内存静态规划与异构任务的 Stream 调度机制
开发语言·人工智能·深度学习·神经网络