Pytorch 学习笔记(9): PyTorch.Compile

一、概述

torch.compiler 是 PyTorch 2.0+ 引入的编译器模块,提供了模型编译、优化和部署的核心 API。本文系统整理所有官方 API 及其用途。


二、核心编译 API

1. torch.compiler.compile

功能 :编译 PyTorch 模型/函数
说明 :详见 torch.compile() 的参数文档,这是最常用的编译入口。

python 复制代码
# 典型用法
model = torch.compile(model, backend="inductor")

2. torch.compiler.reset

功能 :重置编译器状态
说明:清除所有编译缓存,将系统恢复到初始状态。适用于调试内存泄漏或状态异常问题。


3. torch.compiler.disable

功能 :禁用编译的装饰器
说明:为特定函数添加装饰器,强制跳过编译。

python 复制代码
@torch.compiler.disable
def my_function(x):
    # 此函数不会被编译
    return x * 2

4. torch.compiler.set_stance

功能 :设置编译器立场
说明:控制编译器的行为模式(如 eager、compile 等模式切换)。


三、图操作与内联控制

5. torch.compiler.allow_in_graph

功能 :允许函数直接进入计算图
说明:告诉 Dynamo 前端跳过符号内省,直接将函数写入图中。适用于 C 扩展函数等无法被追踪的代码。


6. torch.compiler.substitute_in_graph

功能 :注册 polyfill 处理器
说明:为 C 扩展函数提供替代实现,在图内联时使用自定义版本替换原函数。


7. torch.compiler.nested_compile_region

功能 :标记嵌套编译区域
说明:标识一组可重复使用的操作,编译一次后安全复用。适用于模型中重复出现的子结构。

python 复制代码
with torch.compiler.nested_compile_region():
    # 这部分代码会被识别为可复用区域
    x = self.block(x)

四、常量与假设优化

8. torch.compiler.assume_constant_result

功能 :标记函数返回常量
说明:告知编译器某函数的输出是常量,可进行常量折叠优化。


五、后端与配置

9. torch.compiler.list_backends

功能 :列出可用后端
说明 :返回所有有效的后端名称字符串,如 "inductor""aot_eager" 等。

python 复制代码
print(torch.compiler.list_backends())
# ['inductor', 'cudagraphs', 'aot_eager', ...]

10. torch.compiler.config

功能 :编译器配置
说明:访问和修改编译器的全局配置参数。


六、状态检测 API

API 功能说明
is_compiling() 检测当前是否处于 torch.compile()torch.export() 的编译/追踪过程中
is_dynamo_compiling() 检测是否通过 TorchDynamo 进行图追踪
is_exporting() 检测是否处于模型导出流程中

典型用法

python 复制代码
if torch.compiler.is_compiling():
    # 编译时优化路径
    pass
else:
    # eager 模式路径
    pass

七、分布式与 CUDA Graphs

11. torch.compiler.set_enable_guard_collectives

功能 :启用 guard 评估中的集合通信
说明:允许在 guard 评估时使用 collectives 同步多卡行为。


12. torch.compiler.cudagraph_mark_step_begin

功能 :标记 CUDA Graph 步骤开始
说明:指示新的推理/训练迭代即将开始,用于 CUDA Graph 捕获优化。


八、Guard 控制(⚠️ 不安全 API)

以下 API 均以 _unsafe 结尾,用于高级调优,可能影响正确性

API 功能
keep_portable_guards_unsafe() 仅保留 Python/非 Python 环境通用的 guard
keep_tensor_guards_unsafe() 保留所有张量 guard
skip_guard_on_inbuilt_nn_modules_unsafe() 跳过内置 nn 模块(如 nn.Linear)的 guard
skip_guard_on_all_nn_modules_unsafe() 跳过所有 nn 模块的 guard
skip_guard_on_globals_unsafe() 跳过全局变量的 guard
skip_all_guards_unsafe() 跳过所有 guard

⚠️ 警告:这些 API 会放宽编译器的正确性检查,仅在确信代码行为不变时使用。


九、快速参考表

场景 推荐 API
编译模型 torch.compile()
调试编译问题 torch.compiler.reset()
排除特定函数 @torch.compiler.disable
C 扩展函数兼容 torch.compiler.allow_in_graph()
检测编译状态 torch.compiler.is_compiling()
查看可用后端 torch.compiler.list_backends()
CUDA Graph 优化 torch.compiler.cudagraph_mark_step_begin()

十、相关资源


📌 提示:本文基于 PyTorch 2.11 稳定版文档整理,建议结合官方最新文档使用。

相关推荐
pzx_0012 小时前
【Pytorch】nn.Embedding函数详解
人工智能·pytorch·embedding
Xudde.2 小时前
班级作业笔记报告0x09
笔记·学习·安全·web安全·php
charlie1145141912 小时前
嵌入式Linux驱动开发——模块参数与内核调试:让模块“活“起来的魔法
linux·驱动开发·学习·c
ZzYH222 小时前
文献阅读 260407-Leveraging edge artificial intelligence for sustainable agriculture
笔记
青桔柠薯片2 小时前
I²C 总线协议学习总结:从开漏逻辑到读写事务的工程视角
c语言·开发语言·学习
龙文浩_2 小时前
AI的jieba分词原理与多模式应用解析
人工智能·pytorch·深度学习·神经网络
AI_零食2 小时前
开源鸿蒙跨平台Flutter开发:生物力学与力量周期-臂力训练矩阵架构
学习·flutter·ui·华为·矩阵·开源·harmonyos
sinat_255487812 小时前
泛型:类·学习笔记
java·jvm·笔记·学习
被考核重击2 小时前
计算机网络核心知识点笔记
网络·笔记·计算机网络