Pytorch 学习笔记(9): PyTorch.Compile

一、概述

torch.compiler 是 PyTorch 2.0+ 引入的编译器模块,提供了模型编译、优化和部署的核心 API。本文系统整理所有官方 API 及其用途。


二、核心编译 API

1. torch.compiler.compile

功能 :编译 PyTorch 模型/函数
说明 :详见 torch.compile() 的参数文档,这是最常用的编译入口。

python 复制代码
# 典型用法
model = torch.compile(model, backend="inductor")

2. torch.compiler.reset

功能 :重置编译器状态
说明:清除所有编译缓存,将系统恢复到初始状态。适用于调试内存泄漏或状态异常问题。


3. torch.compiler.disable

功能 :禁用编译的装饰器
说明:为特定函数添加装饰器,强制跳过编译。

python 复制代码
@torch.compiler.disable
def my_function(x):
    # 此函数不会被编译
    return x * 2

4. torch.compiler.set_stance

功能 :设置编译器立场
说明:控制编译器的行为模式(如 eager、compile 等模式切换)。


三、图操作与内联控制

5. torch.compiler.allow_in_graph

功能 :允许函数直接进入计算图
说明:告诉 Dynamo 前端跳过符号内省,直接将函数写入图中。适用于 C 扩展函数等无法被追踪的代码。


6. torch.compiler.substitute_in_graph

功能 :注册 polyfill 处理器
说明:为 C 扩展函数提供替代实现,在图内联时使用自定义版本替换原函数。


7. torch.compiler.nested_compile_region

功能 :标记嵌套编译区域
说明:标识一组可重复使用的操作,编译一次后安全复用。适用于模型中重复出现的子结构。

python 复制代码
with torch.compiler.nested_compile_region():
    # 这部分代码会被识别为可复用区域
    x = self.block(x)

四、常量与假设优化

8. torch.compiler.assume_constant_result

功能 :标记函数返回常量
说明:告知编译器某函数的输出是常量,可进行常量折叠优化。


五、后端与配置

9. torch.compiler.list_backends

功能 :列出可用后端
说明 :返回所有有效的后端名称字符串,如 "inductor""aot_eager" 等。

python 复制代码
print(torch.compiler.list_backends())
# ['inductor', 'cudagraphs', 'aot_eager', ...]

10. torch.compiler.config

功能 :编译器配置
说明:访问和修改编译器的全局配置参数。


六、状态检测 API

API 功能说明
is_compiling() 检测当前是否处于 torch.compile()torch.export() 的编译/追踪过程中
is_dynamo_compiling() 检测是否通过 TorchDynamo 进行图追踪
is_exporting() 检测是否处于模型导出流程中

典型用法

python 复制代码
if torch.compiler.is_compiling():
    # 编译时优化路径
    pass
else:
    # eager 模式路径
    pass

七、分布式与 CUDA Graphs

11. torch.compiler.set_enable_guard_collectives

功能 :启用 guard 评估中的集合通信
说明:允许在 guard 评估时使用 collectives 同步多卡行为。


12. torch.compiler.cudagraph_mark_step_begin

功能 :标记 CUDA Graph 步骤开始
说明:指示新的推理/训练迭代即将开始,用于 CUDA Graph 捕获优化。


八、Guard 控制(⚠️ 不安全 API)

以下 API 均以 _unsafe 结尾,用于高级调优,可能影响正确性

API 功能
keep_portable_guards_unsafe() 仅保留 Python/非 Python 环境通用的 guard
keep_tensor_guards_unsafe() 保留所有张量 guard
skip_guard_on_inbuilt_nn_modules_unsafe() 跳过内置 nn 模块(如 nn.Linear)的 guard
skip_guard_on_all_nn_modules_unsafe() 跳过所有 nn 模块的 guard
skip_guard_on_globals_unsafe() 跳过全局变量的 guard
skip_all_guards_unsafe() 跳过所有 guard

⚠️ 警告:这些 API 会放宽编译器的正确性检查,仅在确信代码行为不变时使用。


九、快速参考表

场景 推荐 API
编译模型 torch.compile()
调试编译问题 torch.compiler.reset()
排除特定函数 @torch.compiler.disable
C 扩展函数兼容 torch.compiler.allow_in_graph()
检测编译状态 torch.compiler.is_compiling()
查看可用后端 torch.compiler.list_backends()
CUDA Graph 优化 torch.compiler.cudagraph_mark_step_begin()

十、相关资源


📌 提示:本文基于 PyTorch 2.11 稳定版文档整理,建议结合官方最新文档使用。

相关推荐
AI技术增长10 分钟前
Pytorch图像去噪实战(二):用UNet解决DnCNN细节丢失问题(结构解析+完整代码+踩坑总结)
人工智能·pytorch·python
摇滚侠1 小时前
Java 零基础全套视频教程,面向对象(高级),笔记 105-120
java·开发语言·笔记
tq10861 小时前
程序行为的效应构成:约束、规则与延迟固化的统一视角
笔记
Alice-YUE1 小时前
前端图片优化完全指南:从格式到加载的全面提速方案
前端·笔记·学习
沉默-_-2 小时前
备战蓝桥杯-哈希
c++·学习·算法·蓝桥杯·哈希算法
AI技术增长2 小时前
Pytorch图像去噪实战(五):FFDNet可控图像去噪实战,用噪声强度图解决不同噪声等级问题
pytorch·python·深度学习
我想我不够好。2 小时前
监控学习 4.28 1.5 hour
学习
Stella Blog2 小时前
狂神Java基础学习笔记Day05
java·笔记·学习
枷锁—sha2 小时前
【CTFshow-pwn系列】03_栈溢出【pwn 073】详解:静态编译下的自动化 ROP 链构建
网络·汇编·笔记·安全·网络安全·自动化
Alice-YUE2 小时前
前端性能优化完全指南:从指标到实战
前端·学习·性能优化