一、概述
torch.compiler 是 PyTorch 2.0+ 引入的编译器模块,提供了模型编译、优化和部署的核心 API。本文系统整理所有官方 API 及其用途。
二、核心编译 API
1. torch.compiler.compile
功能 :编译 PyTorch 模型/函数
说明 :详见 torch.compile() 的参数文档,这是最常用的编译入口。
python
# 典型用法
model = torch.compile(model, backend="inductor")
2. torch.compiler.reset
功能 :重置编译器状态
说明:清除所有编译缓存,将系统恢复到初始状态。适用于调试内存泄漏或状态异常问题。
3. torch.compiler.disable
功能 :禁用编译的装饰器
说明:为特定函数添加装饰器,强制跳过编译。
python
@torch.compiler.disable
def my_function(x):
# 此函数不会被编译
return x * 2
4. torch.compiler.set_stance
功能 :设置编译器立场
说明:控制编译器的行为模式(如 eager、compile 等模式切换)。
三、图操作与内联控制
5. torch.compiler.allow_in_graph
功能 :允许函数直接进入计算图
说明:告诉 Dynamo 前端跳过符号内省,直接将函数写入图中。适用于 C 扩展函数等无法被追踪的代码。
6. torch.compiler.substitute_in_graph
功能 :注册 polyfill 处理器
说明:为 C 扩展函数提供替代实现,在图内联时使用自定义版本替换原函数。
7. torch.compiler.nested_compile_region
功能 :标记嵌套编译区域
说明:标识一组可重复使用的操作,编译一次后安全复用。适用于模型中重复出现的子结构。
python
with torch.compiler.nested_compile_region():
# 这部分代码会被识别为可复用区域
x = self.block(x)
四、常量与假设优化
8. torch.compiler.assume_constant_result
功能 :标记函数返回常量
说明:告知编译器某函数的输出是常量,可进行常量折叠优化。
五、后端与配置
9. torch.compiler.list_backends
功能 :列出可用后端
说明 :返回所有有效的后端名称字符串,如 "inductor"、"aot_eager" 等。
python
print(torch.compiler.list_backends())
# ['inductor', 'cudagraphs', 'aot_eager', ...]
10. torch.compiler.config
功能 :编译器配置
说明:访问和修改编译器的全局配置参数。
六、状态检测 API
| API | 功能说明 |
|---|---|
is_compiling() |
检测当前是否处于 torch.compile() 或 torch.export() 的编译/追踪过程中 |
is_dynamo_compiling() |
检测是否通过 TorchDynamo 进行图追踪 |
is_exporting() |
检测是否处于模型导出流程中 |
典型用法:
python
if torch.compiler.is_compiling():
# 编译时优化路径
pass
else:
# eager 模式路径
pass
七、分布式与 CUDA Graphs
11. torch.compiler.set_enable_guard_collectives
功能 :启用 guard 评估中的集合通信
说明:允许在 guard 评估时使用 collectives 同步多卡行为。
12. torch.compiler.cudagraph_mark_step_begin
功能 :标记 CUDA Graph 步骤开始
说明:指示新的推理/训练迭代即将开始,用于 CUDA Graph 捕获优化。
八、Guard 控制(⚠️ 不安全 API)
以下 API 均以 _unsafe 结尾,用于高级调优,可能影响正确性:
| API | 功能 |
|---|---|
keep_portable_guards_unsafe() |
仅保留 Python/非 Python 环境通用的 guard |
keep_tensor_guards_unsafe() |
保留所有张量 guard |
skip_guard_on_inbuilt_nn_modules_unsafe() |
跳过内置 nn 模块(如 nn.Linear)的 guard |
skip_guard_on_all_nn_modules_unsafe() |
跳过所有 nn 模块的 guard |
skip_guard_on_globals_unsafe() |
跳过全局变量的 guard |
skip_all_guards_unsafe() |
跳过所有 guard |
⚠️ 警告:这些 API 会放宽编译器的正确性检查,仅在确信代码行为不变时使用。
九、快速参考表
| 场景 | 推荐 API |
|---|---|
| 编译模型 | torch.compile() |
| 调试编译问题 | torch.compiler.reset() |
| 排除特定函数 | @torch.compiler.disable |
| C 扩展函数兼容 | torch.compiler.allow_in_graph() |
| 检测编译状态 | torch.compiler.is_compiling() |
| 查看可用后端 | torch.compiler.list_backends() |
| CUDA Graph 优化 | torch.compiler.cudagraph_mark_step_begin() |
十、相关资源
📌 提示:本文基于 PyTorch 2.11 稳定版文档整理,建议结合官方最新文档使用。