TorchDynamo 源码深度剖析
为什么需要 TorchDynamo?与 FX 的对比
FX 的局限性
在学习完 FX 后,我们知道 FX 通过 Proxy 拦截机制可以捕获计算图,但它有一个致命缺陷:
python
# FX 无法处理的情况
def forward(self, x):
if x.sum() > 0: # [X] 数据依赖的控制流
return x * 2
else:
return x + 1
traced = fx.symbolic_trace(Model()) # [错误] 会报错或只捕获一个分支
FX 的问题:
- 无法处理数据依赖的控制流(if/while/for)
- 需要实际执行代码才能追踪
- 对 Python 动态特性支持有限
- Graph Break 后无法继续追踪
TorchDynamo 的优势
python
# TorchDynamo 可以处理
@torch.compile
def forward(self, x):
if x.sum() > 0: # [√] 完全支持!
return x * 2
else:
return x + 1
output = forward(torch.randn(10)) # 工作正常
TorchDynamo 的优势:
-
√\] 字节码级追踪:不需要实际执行
-
√\] Guard 机制:动态适配不同情况
技术对比
| 维度 | FX | TorchDynamo |
|---|---|---|
| 追踪方式 | Proxy 拦截(运行时) | 字节码分析(静态+动态) |
| 控制流 | 有限(只能捕获实际执行的分支) | 完整支持(通过 Guard) |
| Python 兼容性 | 部分(需要可追踪的代码) | 几乎完整 |
| 性能 | 追踪开销小 | 首次编译慢,但有缓存 |
| 适用场景 | 图优化、模型分析 | 端到端编译优化 |
核心设计思想
FX: 执行 Python 代码 -> 拦截操作 -> 构建图
[必须运行] [Proxy 魔术方法]
TorchDynamo: 分析字节码 -> 模拟执行 -> 构建图 + Guard
[静态分析] [符号执行] [动态适配]
TorchDynamo 是什么?
一句话总结
TorchDynamo 是 PyTorch 的 JIT 编译器前端,通过字节码级别的追踪和 Guard 机制,在保持完整 Python 语义的同时捕获计算图。
核心组成
TorchDynamo 架构:
├─ Frame Evaluation Hook [入口] PEP 523 钩子,劫持函数执行
├─ Bytecode Analysis [分析] 反汇编字节码指令
├─ Symbolic Execution [执行] 符号化执行字节码
├─ Guard Generation [保护] 生成执行条件
├─ FX Graph Construction [输出] 生成 FX 图
└─ Graph Break Handling [容错] 处理不可追踪的代码
工作流程
Python 函数调用
|
v
[Frame Evaluation Hook]
|
| 拦截函数帧
|
v
[检查缓存]
|
+-- 缓存命中 + Guard 通过 --> 直接执行优化代码 [快]
|
+-- 缓存未命中/Guard 失败
|
v
[Bytecode Analysis]
|
| 反汇编字节码
|
v
[Symbolic Execution]
|
| 符号化执行每条指令
|
v
[构建 FX Graph]
|
| 生成计算图
|
v
[生成 Guard]
|
| 记录执行条件
|
v
[后端编译] (Inductor/其他)
|
| 生成优化代码
|
v
[缓存并执行]
源码文件索引
TorchDynamo 的核心代码位于 torch/_dynamo/ 目录:
torch/_dynamo/
├── eval_frame.c/py # [核心] Frame 劫持与缓存管理
├── symbolic_convert.py # [核心] 字节码 -> FX Graph 转换
├── bytecode_transformation.py # [核心] 字节码分析与变换
├── guards.py # [核心] Guard 生成与检查
├── variables/ # 符号化变量系统
│ ├── base.py # 变量基类
│ ├── tensor.py # Tensor 变量
│ ├── user_defined.py # 用户自定义类
│ └── ...
├── output_graph.py # FX Graph 构建
├── resume_execution.py # Graph Break 后恢复执行
├── side_effects.py # 副作用处理
└── utils.py # 工具函数
基础概念:什么是执行帧(Frame)
在理解 TorchDynamo 之前,我们需要先理解 Python 的执行模型。
Python 代码的执行过程
当你运行一个 Python 函数时,CPython 解释器会经历这些步骤:
python
def add(a, b):
result = a + b
return result
x = add(1, 2)
执行流程:
Python 源代码
|
v
[1] 词法分析 + 语法分析
|
v
抽象语法树 (AST)
|
v
[2] 编译为字节码
|
v
字节码 (Code Object)
|
v
[3] 创建执行帧 (Frame) <-- 这里!
|
v
[4] 执行字节码
|
v
返回结果
什么是执行帧(Frame)?
执行帧(PyFrameObject)就是一个函数调用的"运行环境",它包含了执行这个函数所需的一切信息。
可以把它想象成一个"工作台":
python
class Frame:
"""
这是执行帧的简化概念模型
"""
def __init__(self):
# 1. 要执行的代码
self.f_code = CodeObject(...) # 字节码指令
# 2. 变量存储
self.f_locals = {} # 局部变量: {'a': 1, 'b': 2, 'result': 3}
self.f_globals = {} # 全局变量
self.f_builtins = {} # 内置函数
# 3. 执行状态
self.f_lasti = 0 # 当前执行到第几条指令
self.value_stack = [] # 操作数栈
# 4. 调用链信息
self.f_back = None # 调用者的帧(形成调用栈)
真实示例:查看执行帧
python
import sys
import inspect
def inner_function(x):
# 获取当前帧
frame = sys._getframe()
print("=== 当前帧信息 ===")
print(f"函数名: {frame.f_code.co_name}")
print(f"文件名: {frame.f_code.co_filename}")
print(f"行号: {frame.f_lineno}")
print(f"局部变量: {frame.f_locals}")
print(f"字节码指令数: {len(frame.f_code.co_code)}")
return x * 2
def outer_function():
result = inner_function(10)
return result
# 运行
outer_function()
输出:
=== 当前帧信息 ===
函数名: inner_function
文件名: /path/to/your/file.py
行号: 5
局部变量: {'x': 10, 'frame': <frame object>}
字节码指令数: 42
查看字节码
python
import dis
def add(a, b):
result = a + b
return result
# 查看字节码
dis.dis(add)
输出:
2 0 LOAD_FAST 0 (a)
2 LOAD_FAST 1 (b)
4 BINARY_ADD
6 STORE_FAST 2 (result)
3 8 LOAD_FAST 2 (result)
10 RETURN_VALUE
每一行就是一条字节码指令,执行帧会逐条执行这些指令。
CPython 如何执行帧
在 CPython 内部,有一个核心函数负责执行帧:
c
// CPython 源码:Python/ceval.c
PyObject* _PyEval_EvalFrameDefault(
PyThreadState* tstate,
PyFrameObject* frame,
int throwflag
) {
/*
* 这是 CPython 的"心脏"!
*
* 它的工作:
* 1. 读取 frame.f_code.co_code (字节码)
* 2. 逐条执行指令
* 3. 操作 value_stack (操作数栈)
* 4. 更新 f_locals (局部变量)
* 5. 返回结果
*/
// 简化的执行循环
for (;;) {
// 获取当前指令
int opcode = NEXTOP();
switch (opcode) {
case LOAD_FAST:
// 从局部变量加载到栈
value = GETLOCAL(oparg);
PUSH(value);
break;
case BINARY_ADD:
// 从栈弹出两个值,相加,结果压栈
right = POP();
left = POP();
sum = PyNumber_Add(left, right);
PUSH(sum);
break;
case RETURN_VALUE:
// 返回栈顶值
retval = POP();
return retval;
// ... 100+ 其他指令
}
}
}
关键问题:如何拦截?
正常情况下的执行流程:
Python 调用函数
|
v
CPython 创建 Frame
|
v
调用 _PyEval_EvalFrameDefault(frame) <-- 默认执行函数
|
v
逐条执行字节码
|
v
返回结果
PEP 523 引入的机制:允许替换执行函数
Python 3.6+ 提供了 _PyInterpreterState_SetEvalFrameFunc() 函数,可以替换默认的帧执行函数。
c
// CPython 提供的 API
// 解释器状态
typedef struct {
// ... 其他字段
// 帧执行函数指针(可以替换!)
_PyFrameEvalFunction eval_frame; // 默认指向 _PyEval_EvalFrameDefault
} PyInterpreterState;
// 设置自定义的执行函数
void _PyInterpreterState_SetEvalFrameFunc(
PyInterpreterState* interp,
_PyFrameEvalFunction new_eval_frame // 你的自定义函数!
);
TorchDynamo 的拦截机制
安装拦截器:
c
// torch/csrc/dynamo/eval_frame.c
// 1. 保存原始的执行函数
static _PyFrameEvalFunction original_eval_frame = NULL;
// 2. 定义自己的执行函数
static PyObject* custom_eval_frame(
PyThreadState* tstate,
PyFrameObject* frame,
int throwflag
) {
// 这里可以做任何事情!
// 获取代码对象
PyCodeObject* code = frame->f_code;
printf("拦截到函数调用: %s\n", PyUnicode_AsUTF8(code->co_name));
printf("字节码长度: %ld\n", PyBytes_Size(code->co_code));
// 可以选择:
// A. 调用原始执行(正常运行)
// B. 调用自己的编译逻辑(优化)
// C. 混合使用
if (should_compile(frame)) {
// 编译并执行优化代码
return compile_and_run(frame);
} else {
// 回退到原始执行
return original_eval_frame(tstate, frame, throwflag);
}
}
// 3. 安装拦截器
void install_hook() {
PyInterpreterState* interp = PyThreadState_GET()->interp;
// 保存原始函数
original_eval_frame = _PyInterpreterState_GetEvalFrameFunc(interp);
// 替换为自定义函数
_PyInterpreterState_SetEvalFrameFunc(interp, custom_eval_frame);
printf("拦截器已安装!\n");
}
拦截后的执行流程:
Python 调用函数
|
v
CPython 创建 Frame
|
v
调用 interp->eval_frame(frame)
|
| (已经被替换!)
v
调用 custom_eval_frame(frame) <-- TorchDynamo 的函数
|
v
检查:是否需要编译?
|
+-- 是 --> 分析字节码 --> 构建 FX Graph --> 编译 --> 执行优化代码
|
+-- 否 --> 调用 original_eval_frame(frame) --> 正常执行
|
v
返回结果
完整示例:演示拦截
python
# 演示文件: demo_frame_hook.py
import sys
import types
# 模拟的拦截计数器
call_count = 0
def my_custom_eval_frame(frame, exc):
"""
自定义的帧执行函数(Python 层模拟)
"""
global call_count
call_count += 1
code = frame.f_code
print(f"\n[拦截 #{call_count}]")
print(f" 函数: {code.co_name}")
print(f" 参数: {frame.f_locals}")
# 获取字节码
import dis
print(f" 字节码:")
dis.disassemble(code)
# 调用原始执行(在真实实现中,这里会编译)
# 注意:Python 层无法真正替换 eval_frame,这只是演示概念
return None
# 实际演示:查看帧信息
def demo_function(x, y):
"""被拦截的函数"""
frame = sys._getframe()
print("\n=== 帧详细信息 ===")
print(f"代码对象: {frame.f_code}")
print(f" - 函数名: {frame.f_code.co_name}")
print(f" - 参数名: {frame.f_code.co_varnames}")
print(f" - 常量: {frame.f_code.co_consts}")
print(f"\n执行状态:")
print(f" - 局部变量: {frame.f_locals}")
print(f" - 当前指令: {frame.f_lasti}")
print(f"\n调用栈:")
current = frame
depth = 0
while current is not None:
print(f" [{depth}] {current.f_code.co_name}")
current = current.f_back
depth += 1
result = x + y
return result * 2
# 运行
if __name__ == "__main__":
result = demo_function(10, 20)
print(f"\n最终结果: {result}")
运行输出:
=== 帧详细信息 ===
代码对象: <code object demo_function at 0x...>
- 函数名: demo_function
- 参数名: ('x', 'y', 'frame', 'current', 'depth', 'result')
- 常量: (None, 0, 1, 2, ' [', '] ', '\n最终结果: ')
执行状态:
- 局部变量: {'x': 10, 'y': 20, 'frame': <frame object>}
- 当前指令: 42
调用栈:
[0] demo_function
[1] <module>
最终结果: 60
TorchDynamo 真实的 C 层拦截
c
// torch/csrc/dynamo/eval_frame.c (简化版)
static PyObject* custom_eval_frame(
PyThreadState* tstate,
PyFrameObject* frame,
int throwflag
) {
// 【步骤 1】提取信息
PyCodeObject* code = frame->f_code;
const char* func_name = PyUnicode_AsUTF8(code->co_name);
// 【步骤 2】决策:是否编译
if (should_skip_frame(frame)) {
// 跳过:系统函数、已编译函数等
return original_eval_frame(tstate, frame, throwflag);
}
// 【步骤 3】检查缓存
CacheEntry* cached = lookup_cache(code);
if (cached != NULL) {
// 【步骤 4】检查 Guard
if (check_guards(frame, cached->guards)) {
// Guard 通过,执行缓存的优化代码
cached->hit_count++;
return execute_optimized_code(tstate, frame, cached->code);
}
// Guard 失败
cached->miss_count++;
}
// 【步骤 5】调用 Python 层编译
PyObject* result = call_python_compiler(frame);
if (result != NULL) {
// 编译成功
return result;
}
// 【步骤 6】回退
return original_eval_frame(tstate, frame, throwflag);
}
核心要点总结
-
执行帧(Frame)= 函数的运行环境
- 包含代码(字节码)
- 包含变量(局部/全局)
- 包含执行状态(当前指令位置、操作数栈)
-
CPython 通过
eval_frame函数执行帧- 默认:
_PyEval_EvalFrameDefault - 逐条执行字节码指令
- 默认:
-
PEP 523 允许替换
eval_frame函数- 通过
_PyInterpreterState_SetEvalFrameFunc() - 在每次函数调用时拦截
- 通过
-
TorchDynamo 的拦截流程
函数调用 -> custom_eval_frame (拦截!) -> 检查缓存 -> 检查 Guard -> 编译 (如果需要) -> 执行优化代码 -
为什么这么做有用?
- 无需修改用户代码
- 无需修改 Python 解释器
- 在运行时动态优化
- 保持 Python 的灵活性
直观理解:拦截的本质
把 Python 执行想象成一个流水线:
[正常执行]
用户代码: def add(a, b): return a + b
|
v
Python 解释器
|
v
[解析] -> AST
|
v
[编译] -> 字节码
|
v
[创建帧] -> Frame{code, locals, stack}
|
v
[执行帧] -> _PyEval_EvalFrameDefault(frame)
| |
| +-> 读取指令
| +-> 执行指令
| +-> 更新状态
| +-> 返回结果
v
结果: 3
[TorchDynamo 拦截后]
用户代码: @torch.compile
def add(a, b): return a + b
|
v
Python 解释器 (未修改!)
|
v
[解析] -> AST
|
v
[编译] -> 字节码
|
v
[创建帧] -> Frame{code, locals, stack}
|
v
[执行帧] -> custom_eval_frame(frame) <-- 这里被替换了!
| |
| +-> [检查] 是否需要编译?
| | |
| | +-> 是 -> [提取字节码]
| | [构建 FX Graph]
| | [编译优化]
| | [生成 Guard]
| | [缓存结果]
| | [执行优化代码]
| |
| +-> 否 -> [调用原始执行]
| _PyEval_EvalFrameDefault(frame)
v
结果: 3 (但更快!)
关键时刻:拦截发生在何时?
python
import torch
@torch.compile
def my_func(x, y):
return x + y
# 调用函数
result = my_func(torch.randn(3), torch.randn(3))
时间线:
t0: 用户调用 my_func()
|
v
t1: CPython 准备执行 my_func
|
v
t2: CPython 创建 Frame 对象
|
| Frame 内容:
| - f_code: my_func 的字节码
| - f_locals: {'x': tensor([...]), 'y': tensor([...])}
| - f_lasti: 0
v
t3: CPython 查找 eval_frame 函数
|
| 正常情况: _PyEval_EvalFrameDefault
| 现在: custom_eval_frame (被替换!)
v
t4: 调用 custom_eval_frame(frame) [C 层]
|
v
t5: custom_eval_frame 检查缓存
|
+-> 缓存未命中
|
v
t6: 调用 Python 层的 compile_frame()
|
v
t7: 分析字节码
|
| LOAD_FAST 0 (x)
| LOAD_FAST 1 (y)
| BINARY_ADD
| RETURN_VALUE
v
t8: 构建 FX Graph
|
| def forward(x, y):
| return torch.add(x, y)
v
t9: 后端编译 (Inductor)
|
| 生成优化的 CUDA kernel
v
t10: 生成 Guard
|
| Guard 1: x.shape == torch.Size([3])
| Guard 2: x.dtype == torch.float32
| Guard 3: y.shape == torch.Size([3])
| Guard 4: y.dtype == torch.float32
v
t11: 缓存结果
|
v
t12: 执行优化代码
|
v
t13: 返回结果
---
第二次调用 (相同输入):
t0: 用户调用 my_func()
v
t3: 调用 custom_eval_frame(frame)
v
t5: 检查缓存 -> 命中!
v
t5.1: 检查 Guard -> 通过!
v
t12: 直接执行缓存的优化代码 (极快!)
v
t13: 返回结果
实际运行演示
我已经创建了一个演示程序 Pytorch/examples/demo_frame.py,你可以运行它来亲自查看执行帧的结构:
bash
cd Pytorch/examples
python demo_frame.py
程序会展示:
- 执行帧的详细信息(代码、变量、状态)
- 字节码指令的结构
- 调用栈的追踪
- 字节码执行的模拟过程
- TorchDynamo 拦截的概念
- 实际的 torch.compile 例子
用一句话总结
TorchDynamo 就像在 Python 解释器和你的代码之间插入了一个"智能代理",这个代理在每次函数调用时决定:是直接运行,还是先编译优化再运行。而这个"插入"动作是通过 PEP 523 提供的 eval_frame 替换机制实现的。
torch.compile 装饰器完整调用链
在深入 Frame Evaluation Hook 之前,让我们先看看 @torch.compile 装饰器是如何一步步调用到字节码拦截的。
调用链概览
@torch.compile 装饰器
|
v
torch.compile() 函数 [torch/__init__.py:2422]
|
v
torch._dynamo.optimize() [torch/_dynamo/__init__.py]
|
v
_optimize_catch_errors() [torch/_dynamo/eval_frame.py]
|
v
set_eval_frame() [torch/_dynamo/eval_frame.py]
|
v
_C._dynamo.eval_frame.set_eval_frame() [C 扩展]
|
v
custom_eval_frame() [C 层拦截器]
|
v
Python 函数调用时自动触发
第 1 步:torch.compile() 装饰器
文件位置 :torch/__init__.py 第 2422 行
python
# torch/__init__.py:2422
def compile(
model: Optional[Callable] = None,
*,
fullgraph: bool = False,
dynamic: Optional[bool] = None,
backend: Union[str, Callable] = "inductor",
mode: Union[str, None] = None,
options: Optional[Dict[str, Union[str, int, bool]]] = None,
disable: bool = False,
) -> Union[Callable, Callable]:
"""
主入口函数
"""
# [1] 日志记录
_C._log_api_usage_once("torch.compile")
# [2] 版本检查
if sys.version_info >= (3, 14):
raise RuntimeError("torch.compile is not supported on Python 3.14+")
# [3] 装饰器模式处理
if model is None:
# 返回一个装饰器函数
def fn(model):
return compile(
model,
fullgraph=fullgraph,
dynamic=dynamic,
backend=backend,
mode=mode,
options=options,
disable=disable,
)
return fn
# [4] 创建后端编译器包装器
if backend == "inductor":
# 默认使用 Inductor 后端
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
# 其他后端
backend = _TorchCompileWrapper(backend, mode, options, dynamic)
# [5] 关键调用!调用 TorchDynamo 的 optimize
return torch._dynamo.optimize(
backend=backend,
nopython=fullgraph,
dynamic=dynamic,
disable=disable,
)(model)
第 2 步:torch._dynamo.optimize()
文件位置 :torch/_dynamo/__init__.py
python
# torch/_dynamo/__init__.py
def optimize(
backend="inductor",
*,
nopython=False,
guard_export_fn=None,
guard_fail_fn=None,
disable=False,
dynamic=None,
):
"""
TorchDynamo 的主优化入口
Args:
backend: 编译后端(Inductor, eager, aot_autograd 等)
nopython: 是否要求整个函数可编译(fullgraph)
dynamic: 动态形状支持
Returns:
装饰器函数
"""
# [1] 检查是否禁用
if disable or os.environ.get("TORCH_COMPILE_DISABLE"):
return _disable_decorator
# [2] 查找后端实现
from torch._dynamo.backends.registry import lookup_backend
backend_fn = lookup_backend(backend)
# [3] 返回实际的装饰器
def decorator(fn):
# 这里会包装原始函数
from torch._dynamo.eval_frame import OptimizedModule
if isinstance(fn, torch.nn.Module):
# 处理 nn.Module
return OptimizedModule(
fn,
backend=backend_fn,
nopython=nopython,
dynamic=dynamic,
)
else:
# 处理普通函数
from torch._dynamo.eval_frame import _optimize_catch_errors
return _optimize_catch_errors(
fn,
backend=backend_fn,
nopython=nopython,
dynamic=dynamic,
)
return decorator
第 3 步:_optimize_catch_errors() - 包装并启用钩子
文件位置 :torch/_dynamo/eval_frame.py
python
# torch/_dynamo/eval_frame.py
def _optimize_catch_errors(
fn,
backend,
nopython=False,
dynamic=None,
):
"""
包装函数并启用字节码拦截
这是连接用户代码和底层 Frame Hook 的桥梁
"""
# [1] 创建回调函数
callback = callback_handler(
backend=backend,
nopython=nopython,
dynamic=dynamic,
)
# [2] 注册编译回调
from torch._dynamo.convert_frame import convert_frame
callback.convert_frame_assert = convert_frame_assert
# [3] 包装原始函数
@functools.wraps(fn)
def wrapper(*args, **kwargs):
# [关键] 在函数执行前安装 Frame Hook
prior = set_eval_frame(callback)
try:
# 执行原始函数(此时会被 Hook 拦截)
return fn(*args, **kwargs)
finally:
# 恢复之前的 Hook
set_eval_frame(prior)
# [4] 标记这是一个优化过的函数
wrapper._torchdynamo_orig_callable = fn
wrapper._torchdynamo_inline = fn
return wrapper
第 4 步:set_eval_frame() - 安装 C 层钩子
文件位置 :torch/_dynamo/eval_frame.py
python
# torch/_dynamo/eval_frame.py
def set_eval_frame(callback):
"""
设置 Python 帧执行钩子
这是 Python 层到 C 层的桥梁
Args:
callback: 编译回调函数,接收 (frame, cache_size) 返回优化代码
Returns:
之前的回调函数
"""
# [1] 获取之前的回调
prior = _eval_frame_callback
if callback is None:
# [2] 移除钩子,恢复默认执行
_C._dynamo.eval_frame.set_eval_frame(None)
_eval_frame_callback = None
else:
# [3] 创建 C 层可调用的包装器
def _callback_wrapper(frame, cache_size, frame_state):
"""
C 层调用的实际回调
Args:
frame: PyFrameObject (Python 帧对象)
cache_size: 当前缓存大小
frame_state: 帧状态信息
Returns:
优化后的代码对象,或 None(回退到默认执行)
"""
try:
# [4] 调用 Python 层的编译逻辑
return callback(frame, cache_size, frame_state)
except Exception as e:
# [5] 编译失败,回退
log_compilation_error(e)
return None
# [6] 安装到 C 扩展
# 这是关键调用!将回调注册到 CPython 解释器
_C._dynamo.eval_frame.set_eval_frame(_callback_wrapper)
_eval_frame_callback = callback
return prior
第 5 步:C 扩展层 - custom_eval_frame()
文件位置 :torch/csrc/dynamo/eval_frame.c
c
// torch/csrc/dynamo/eval_frame.c
// 全局变量:保存原始的 eval_frame 函数
static _PyFrameEvalFunction original_eval_frame = NULL;
// 全局变量:Python 层的回调函数
static PyObject* py_callback = NULL;
// TorchDynamo 的自定义 eval_frame 函数
static PyObject* custom_eval_frame(
PyThreadState* tstate,
PyFrameObject* frame,
int throw_flag
) {
/*
* 这个函数会在每次 Python 函数调用时被触发
*
* 执行流程:
* 1. 检查是否需要编译这个函数
* 2. 检查缓存
* 3. 调用 Python 层编译
* 4. 执行优化代码或回退
*/
// [1] 快速路径:检查是否应该跳过这个帧
if (should_skip_frame(frame)) {
// 直接使用原始执行
return original_eval_frame(tstate, frame, throw_flag);
}
// [2] 获取代码对象
PyCodeObject* code = frame->f_code;
// [3] 检查编译缓存
CacheEntry* cached = lookup_cache(code);
if (cached != NULL) {
// [4] 检查 Guard
if (check_guards(frame, cached->guards)) {
// Guard 通过,执行缓存的代码
return execute_cached_code(tstate, frame, cached);
}
// Guard 失败,需要重新编译
}
// [5] 缓存未命中,调用 Python 层编译
if (py_callback != NULL) {
// 构建参数
PyObject* args = Py_BuildValue(
"(OiO)",
frame, // PyFrameObject
get_cache_size(), // cache_size
Py_None // frame_state
);
// [6] 调用 Python 回调
PyObject* result = PyObject_CallObject(py_callback, args);
Py_DECREF(args);
if (result != NULL && result != Py_None) {
// [7] 编译成功,执行优化代码
PyObject* optimized_result = execute_optimized_code(
tstate,
frame,
result
);
Py_DECREF(result);
return optimized_result;
}
Py_XDECREF(result);
}
// [8] 编译失败或未启用,回退到原始执行
return original_eval_frame(tstate, frame, throw_flag);
}
// 安装自定义 eval_frame 函数
PyObject* set_eval_frame_py(PyObject* self, PyObject* callback) {
/*
* Python 调用:_C._dynamo.eval_frame.set_eval_frame(callback)
*/
if (callback == Py_None) {
// [1] 移除钩子
py_callback = NULL;
// [2] 恢复原始 eval_frame
_PyInterpreterState_SetEvalFrameFunc(
PyThreadState_GET()->interp,
original_eval_frame
);
} else {
// [3] 保存回调
Py_INCREF(callback);
py_callback = callback;
// [4] 如果是第一次安装,保存原始函数
if (original_eval_frame == NULL) {
original_eval_frame = _PyInterpreterState_GetEvalFrameFunc(
PyThreadState_GET()->interp
);
}
// [5] 安装自定义 eval_frame(关键!)
_PyInterpreterState_SetEvalFrameFunc(
PyThreadState_GET()->interp,
custom_eval_frame // 我们的自定义函数
);
}
Py_RETURN_NONE;
}
第 6 步:compile_frame() - Python 层编译逻辑
文件位置 :torch/_dynamo/convert_frame.py
python
# torch/_dynamo/convert_frame.py
def convert_frame_assert(
compiler_fn,
frame,
cache_size,
frame_state=None,
):
"""
编译单个函数帧
这是 C 层回调调用的 Python 函数
Args:
compiler_fn: 后端编译器函数
frame: Python 帧对象
cache_size: 缓存大小
frame_state: 帧状态
Returns:
GuardedCode: 包含优化代码和 Guard 的对象
"""
# [1] 创建编译上下文
code = frame.f_code
# [2] 检查是否可以编译
if not is_traceable(code):
return None
# [3] 获取字节码
instructions = list(dis.get_instructions(code))
# [4] 创建符号化转换器
from torch._dynamo.symbolic_convert import InstructionTranslator
tracer = InstructionTranslator(
instructions=instructions,
f_code=code,
f_globals=frame.f_globals,
f_locals=frame.f_locals,
f_builtins=frame.f_builtins,
compiler_fn=compiler_fn,
# ... 其他参数
)
# [5] 执行符号化转换(字节码 -> FX Graph)
try:
tracer.run()
except Exception as e:
# 编译失败
log_graph_break(e)
return None
# [6] 获取输出图
output_graph = tracer.output
# [7] 后端编译
compiled_fn = compiler_fn(
output_graph.graph,
output_graph.example_inputs
)
# [8] 生成 Guard
guards = output_graph.guards
# [9] 创建 GuardedCode 对象
guarded_code = GuardedCode(
code=compiled_fn,
guards=guards,
# ... 其他信息
)
# [10] 缓存
cache_code(code, guarded_code)
return guarded_code
完整执行流程示例
python
import torch
# 定义一个简单的函数
def my_function(x, y):
z = x + y
return z * 2
# 应用 @torch.compile 装饰器
compiled_fn = torch.compile(my_function)
# 第一次调用
x = torch.randn(10)
y = torch.randn(10)
result = compiled_fn(x, y)
# 执行流程:
"""
1. compiled_fn(x, y) 被调用
2. wrapper() 函数开始执行 [eval_frame.py]
-> set_eval_frame(callback) 安装钩子
-> 调用 my_function(x, y)
3. CPython 准备执行 my_function
-> 查找 eval_frame 函数
-> 发现是 custom_eval_frame() [C 层]
4. custom_eval_frame() 被触发 [eval_frame.c]
-> 检查缓存:未命中
-> 调用 py_callback (Python 层)
5. convert_frame_assert() 开始编译 [convert_frame.py]
-> 获取字节码:
LOAD_FAST x
LOAD_FAST y
BINARY_ADD
STORE_FAST z
LOAD_FAST z
LOAD_CONST 2
BINARY_MULTIPLY
RETURN_VALUE
-> 创建 InstructionTranslator
-> 逐条处理字节码指令
6. InstructionTranslator.run() [symbolic_convert.py]
-> LOAD_FAST x: 创建 TensorVariable(x)
-> LOAD_FAST y: 创建 TensorVariable(y)
-> BINARY_ADD: 创建 FX 节点 add(x, y)
-> STORE_FAST z: 保存符号化变量
-> ... 继续处理
-> 构建完整的 FX Graph
7. 生成 Guard
-> x.shape == (10,)
-> x.dtype == torch.float32
-> y.shape == (10,)
-> y.dtype == torch.float32
8. 后端编译 (Inductor)
-> 优化 FX Graph
-> 生成 Triton/C++ 代码
-> 编译为可执行的 kernel
9. 缓存并返回
-> 保存到代码对象的缓存中
-> 返回 GuardedCode
10. custom_eval_frame() 执行优化代码
-> 执行编译后的 kernel
-> 返回结果
11. 第二次调用(相同形状)
-> custom_eval_frame() 检查缓存:命中
-> 检查 Guard:通过
-> 直接执行缓存的代码(极快!)
"""
关键数据结构
GuardedCode
python
# torch/_dynamo/types.py
@dataclasses.dataclass
class GuardedCode:
"""
编译结果:包含优化代码和 Guard
"""
# 优化后的可执行代码
code: Union[types.CodeType, Callable]
# Guard 列表
guards: List[Guard]
# Guard 检查函数(编译后的)
check_fn: Callable[[PyFrameObject], bool]
# 原始代码对象
code_obj: types.CodeType
# 性能统计
hit_count: int = 0
miss_count: int = 0
CacheEntry (C 层)
c
// torch/csrc/dynamo/eval_frame.c
typedef struct {
PyObject* code; // 优化后的代码对象
PyObject* guards; // Guard 列表
PyObject* check_fn; // Guard 检查函数
long hit_count; // 缓存命中次数
long miss_count; // Guard 失败次数
} CacheEntry;
总结
@torch.compile 的完整调用链:
- 装饰器层 (
torch.compile()):用户入口,参数处理 - 优化层 (
torch._dynamo.optimize()):创建装饰器包装 - 钩子安装层 (
set_eval_frame()):注册 C 层回调 - C 拦截层 (
custom_eval_frame()):拦截每次函数调用 - 编译层 (
convert_frame_assert()):字节码 -> FX Graph - 后端层 (
Inductor/其他):FX Graph -> 优化代码
关键技术点:
- PEP 523 的
_PyInterpreterState_SetEvalFrameFunc允许替换帧执行函数 - 字节码在 C 层被拦截,但编译逻辑在 Python 层
- Guard 机制确保优化代码的正确性
- 缓存避免重复编译
第一层:Frame Evaluation Hook - 入口机制
PEP 523:Python 的编译器钩子
Python 3.6 引入了 PEP 523,允许 C 扩展替换 Python 的帧执行函数。
文件位置 :torch/_dynamo/eval_frame.c
核心实现
c
// torch/_dynamo/eval_frame.c
// 原始的 Python 帧执行函数
static PyObject* (*previous_eval_frame)(PyFrameObject*, int) = NULL;
// TorchDynamo 的自定义帧执行函数
static PyObject* custom_eval_frame(PyFrameObject* frame, int throw_flag) {
// [1] 检查是否应该编译这个函数
if (!should_compile_frame(frame)) {
// 不编译,使用原始执行
return previous_eval_frame(frame, throw_flag);
}
// [2] 检查缓存
PyObject* cached_code = lookup_cache(frame->f_code);
if (cached_code != NULL) {
// [3] 检查 Guard
if (check_guards(frame, cached_code)) {
// Guard 通过,执行缓存的优化代码
return execute_cached_code(frame, cached_code);
}
}
// [4] 缓存未命中或 Guard 失败,调用 Python 层编译
PyObject* result = py_compile_function(frame);
if (result != NULL) {
// 编译成功,执行优化代码
return result;
}
// [5] 编译失败,回退到原始执行
return previous_eval_frame(frame, throw_flag);
}
// 安装钩子
PyAPI_FUNC(void) _PyInterpreterState_SetEvalFrameFunc(
PyInterpreterState* interp,
_PyFrameEvalFunction eval_frame
);
Python 层的接口
文件位置 :torch/_dynamo/eval_frame.py
python
# torch/_dynamo/eval_frame.py
import sys
from typing import Callable, Optional
# 全局编译缓存
_compile_cache = {}
def set_eval_frame(callback: Optional[Callable] = None):
"""
设置帧执行钩子
Args:
callback: 编译函数,接收 (frame, cache_size) 返回优化后的代码
"""
if callback is None:
# 移除钩子,恢复默认执行
_C._dynamo.eval_frame.set_eval_frame(None)
else:
# 安装钩子
def wrapper(frame, cache_size):
try:
# 调用编译器
return callback(frame, cache_size)
except Exception as e:
# 编译失败,回退
return None
_C._dynamo.eval_frame.set_eval_frame(wrapper)
def compile_frame(frame, cache_size):
"""
编译单个函数帧
Args:
frame: Python 帧对象
cache_size: 当前缓存大小
Returns:
优化后的代码对象或 None
"""
code = frame.f_code
# [1] 检查缓存
cache_key = make_cache_key(frame)
if cache_key in _compile_cache:
cached_entry = _compile_cache[cache_key]
# 检查 Guard
if check_guards(frame, cached_entry.guards):
return cached_entry.optimized_code
# [2] 开始编译
try:
# 字节码分析与转换
from .symbolic_convert import InstructionTranslator
translator = InstructionTranslator(
instructions=dis.get_instructions(code),
f_globals=frame.f_globals,
f_locals=frame.f_locals,
f_builtins=frame.f_builtins,
code_options={},
symbolic_locals={},
symbolic_globals={},
)
# 执行符号化转换
translator.run()
# [3] 获取生成的 FX 图和 Guard
output_graph = translator.output
guards = translator.output.guards
# [4] 编译后端(例如 Inductor)
from .backends import compile_fx
optimized_fn = compile_fx(output_graph.graph, output_graph.example_inputs)
# [5] 缓存结果
_compile_cache[cache_key] = CacheEntry(
optimized_code=optimized_fn,
guards=guards
)
return optimized_fn
except Exception as e:
# 编译失败
log_compilation_error(e)
return None
关键概念:缓存与 Guard
python
class CacheEntry:
"""编译缓存条目"""
def __init__(self, optimized_code, guards):
self.optimized_code = optimized_code # 优化后的代码
self.guards = guards # Guard 列表
class Guard:
"""执行条件守卫"""
def __init__(self, check_fn, verbose_code):
self.check_fn = check_fn # 检查函数
self.verbose_code = verbose_code # 可读的条件描述
def check(self, frame):
"""检查 Guard 是否满足"""
return self.check_fn(frame)
# Guard 示例
def make_tensor_guard(name, expected_shape, expected_dtype):
"""生成 Tensor 的 Guard"""
def check(frame):
tensor = frame.f_locals[name]
return (
tensor.shape == expected_shape and
tensor.dtype == expected_dtype
)
verbose = f"{name}.shape == {expected_shape} and {name}.dtype == {expected_dtype}"
return Guard(check, verbose)
字节码到 FX Graph 的转换过程
这是 TorchDynamo 最核心的部分!让我们详细看看字节码是如何一步步转换为 FX Graph 的。
核心思想:符号化执行
正常执行:用真实的值运行代码
python
x = 3
y = 5
z = x + y # z = 8 (真实计算)
符号化执行:用"符号"代替真实值,记录操作
python
x = Symbol("x") # 不是真的 3,是一个符号
y = Symbol("y") # 不是真的 5,是一个符号
z = x + y # 不真的计算,而是记录"z = add(x, y)"
完整转换流程示例
让我们跟踪一个实际的例子:
python
@torch.compile
def my_function(x, y):
z = x + y
return z * 2
步骤 1:获取字节码
python
import dis
dis.dis(my_function)
输出:
2 0 LOAD_FAST 0 (x)
2 LOAD_FAST 1 (y)
4 BINARY_ADD
6 STORE_FAST 2 (z)
3 8 LOAD_FAST 2 (z)
10 LOAD_CONST 1 (2)
12 BINARY_MULTIPLY
14 RETURN_VALUE
步骤 2:创建 InstructionTranslator
文件位置 :torch/_dynamo/symbolic_convert.py
python
# 简化的初始化过程
translator = InstructionTranslator(
instructions=instructions, # 字节码指令列表
f_code=code,
f_globals=frame.f_globals,
f_locals=frame.f_locals,
)
# 初始状态
translator.stack = [] # 操作数栈(空)
translator.symbolic_locals = {} # 符号化局部变量(空)
translator.output = OutputGraph() # FX Graph 构建器
步骤 3:符号化执行 - 逐条处理字节码
指令 1: LOAD_FAST 0 (x)
python
def INSTR_LOAD_FAST(self, inst):
"""从局部变量加载到栈"""
var_name = inst.argval # "x"
# 第一次访问 x,需要创建符号化变量
if var_name not in self.symbolic_locals:
# 从真实的 frame.f_locals 获取运行时的值
real_value = self.f_locals[var_name] # Tensor([...])
# 创建符号化 Tensor 变量
from torch._dynamo.variables.tensor import TensorVariable
# 同时创建 FX Graph 中的 placeholder 节点
proxy = self.output.create_proxy(
"placeholder", # 节点类型
var_name, # 节点名称
(), # args
{} # kwargs
)
# 创建符号化变量(保存 proxy 和示例值)
symbolic_var = TensorVariable(
proxy=proxy,
example_value=real_value, # 用于形状推导
)
self.symbolic_locals[var_name] = symbolic_var
# 压入栈
var = self.symbolic_locals[var_name]
self.stack.append(var)
# 状态更新:
# stack: [TensorVariable(x)]
# FX Graph: placeholder x
指令 2: LOAD_FAST 1 (y)
python
# 同理,创建 y 的符号化变量
var_name = "y"
real_value = self.f_locals["y"]
proxy_y = self.output.create_proxy("placeholder", "y", (), {})
symbolic_y = TensorVariable(proxy=proxy_y, example_value=real_value)
self.symbolic_locals["y"] = symbolic_y
self.stack.append(symbolic_y)
# 状态更新:
# stack: [TensorVariable(x), TensorVariable(y)]
# FX Graph:
# placeholder x
# placeholder y
指令 3: BINARY_ADD
python
def INSTR_BINARY_ADD(self, inst):
"""二元加法"""
# 从栈弹出两个操作数
right = self.stack.pop() # TensorVariable(y)
left = self.stack.pop() # TensorVariable(x)
# 调用左操作数的 add 方法(符号化执行)
result = left.call_method(
self, # translator
"__add__", # 方法名
[right], # 参数
{} # 关键字参数
)
# 压入结果
self.stack.append(result)
# TensorVariable.call_method 内部实现:
class TensorVariable:
def call_method(self, tx, name, args, kwargs):
if name == "__add__":
other = args[0]
# 在 FX Graph 中创建 add 节点
result_proxy = self.proxy + other.proxy
# 等价于:
# result_proxy = tx.output.create_proxy(
# "call_function",
# torch.add,
# (self.proxy, other.proxy),
# {}
# )
# 推导结果的形状
result_shape = torch.broadcast_shapes(
self.example_value.shape,
other.example_value.shape
)
result_example = torch.empty(
result_shape,
dtype=self.example_value.dtype
)
# 返回新的符号化变量
return TensorVariable(
proxy=result_proxy,
example_value=result_example
)
# 状态更新:
# stack: [TensorVariable(z=x+y)]
# FX Graph:
# placeholder x
# placeholder y
# call_function torch.add(x, y) -> z_temp
指令 4: STORE_FAST 2 (z)
python
def INSTR_STORE_FAST(self, inst):
"""存储到局部变量"""
var_name = inst.argval # "z"
# 从栈弹出值
value = self.stack.pop() # TensorVariable(x+y)
# 保存到符号化局部变量
self.symbolic_locals[var_name] = value
# 状态更新:
# stack: []
# symbolic_locals: {'x': TensorVariable(x), 'y': TensorVariable(y), 'z': TensorVariable(x+y)}
# FX Graph: (不变)
指令 5: LOAD_FAST 2 (z)
python
# 加载 z
var = self.symbolic_locals["z"]
self.stack.append(var)
# 状态更新:
# stack: [TensorVariable(z)]
指令 6: LOAD_CONST 1 (2)
python
def INSTR_LOAD_CONST(self, inst):
"""加载常量"""
const_value = inst.argval # 2
from torch._dynamo.variables.constant import ConstantVariable
symbolic_const = ConstantVariable(const_value)
self.stack.append(symbolic_const)
# 状态更新:
# stack: [TensorVariable(z), ConstantVariable(2)]
指令 7: BINARY_MULTIPLY
python
# 类似 BINARY_ADD
right = self.stack.pop() # ConstantVariable(2)
left = self.stack.pop() # TensorVariable(z)
result = left.call_method(self, "__mul__", [right], {})
# TensorVariable.__mul__ 内部:
result_proxy = self.proxy * 2 # 在 FX Graph 中创建 mul 节点
result = TensorVariable(proxy=result_proxy, ...)
self.stack.append(result)
# 状态更新:
# stack: [TensorVariable(z*2)]
# FX Graph:
# placeholder x
# placeholder y
# call_function torch.add(x, y) -> z
# call_function torch.mul(z, 2) -> result
指令 8: RETURN_VALUE
python
def INSTR_RETURN_VALUE(self, inst):
"""返回值"""
ret_val = self.stack.pop() # TensorVariable(z*2)
# 添加到输出图
self.output.add_output(ret_val)
# FX Graph 完成:
# placeholder x
# placeholder y
# call_function torch.add(x, y) -> z
# call_function torch.mul(z, 2) -> result
# output result
完整的转换过程可视化
[字节码] [符号化执行] [FX Graph]
LOAD_FAST x -> 创建 TensorVariable -> placeholder x
proxy_x = FX.placeholder("x")
stack: [x]
LOAD_FAST y -> 创建 TensorVariable -> placeholder y
proxy_y = FX.placeholder("y")
stack: [x, y]
BINARY_ADD -> 弹出 y, x -> call_function add
调用 x + y
proxy_z = proxy_x + proxy_y
stack: [z] - target: torch.add
- args: (x, y)
STORE_FAST z -> locals['z'] = z
stack: []
LOAD_FAST z -> stack: [z]
LOAD_CONST 2 -> 创建 ConstantVariable
stack: [z, 2]
BINARY_MULTIPLY -> 弹出 2, z -> call_function mul
调用 z * 2
proxy_r = proxy_z * 2 - target: torch.mul
stack: [result] - args: (z, 2)
RETURN_VALUE -> 弹出 result -> output result
output_graph.add_output(result)
最终生成的 FX Graph
python
# 生成的 FX Graph 代码
def forward(x, y):
# placeholder nodes
x_1 = x
y_1 = y
# call_function nodes
add_result = torch.add(x_1, y_1)
mul_result = torch.mul(add_result, 2)
# output
return mul_result
关键组件详解
1. OutputGraph - FX Graph 构建器
python
# torch/_dynamo/output_graph.py
class OutputGraph:
"""构建 FX Graph 的核心类"""
def __init__(self):
self.graph = torch.fx.Graph() # FX Graph
self.graphargs = [] # 输入参数
self.guards = [] # Guard 列表
def create_proxy(self, kind, target, args, kwargs):
"""
在 FX Graph 中创建节点
Args:
kind: 节点类型 (placeholder, call_function, call_method, etc.)
target: 目标 (函数名、方法名等)
args: 参数
kwargs: 关键字参数
Returns:
Proxy 对象(包装了 FX Node)
"""
# 创建 FX Node
node = self.graph.create_node(
kind,
target,
args,
kwargs
)
# 返回 Proxy(用于后续操作)
return torch.fx.Proxy(node, self.graph)
def add_output(self, var):
"""添加输出节点"""
self.graph.output(var.proxy.node)
2. 符号化变量系统
python
# torch/_dynamo/variables/base.py
class VariableTracker:
"""所有符号化变量的基类"""
def call_method(self, tx, name, args, kwargs):
"""调用方法"""
raise NotImplementedError
# torch/_dynamo/variables/tensor.py
class TensorVariable(VariableTracker):
"""Tensor 的符号化表示"""
def __init__(self, proxy, example_value):
self.proxy = proxy # FX Proxy 对象
self.example_value = example_value # 示例值(用于形状推导)
def call_method(self, tx, name, args, kwargs):
"""
调用 Tensor 方法
这里不会真的计算,而是在 FX Graph 中记录操作
"""
if name == "__add__":
other = args[0]
# 在 FX Graph 中创建加法节点
result_proxy = self.proxy + other.proxy
# 推导结果形状
result_example = self.example_value + other.example_value
return TensorVariable(result_proxy, result_example)
elif name == "__mul__":
other = args[0]
if isinstance(other, ConstantVariable):
# Tensor * 常量
result_proxy = self.proxy * other.value
else:
# Tensor * Tensor
result_proxy = self.proxy * other.proxy
result_example = self.example_value * other.as_python_constant()
return TensorVariable(result_proxy, result_example)
# ... 其他方法
# torch/_dynamo/variables/constant.py
class ConstantVariable(VariableTracker):
"""常量的符号化表示"""
def __init__(self, value):
self.value = value
def as_python_constant(self):
return self.value
完整实际代码示例
让我创建一个演示程序:
python
# 演示:模拟 TorchDynamo 的符号化执行
import torch
import torch.fx as fx
import dis
class SimpleOutputGraph:
"""简化的 OutputGraph"""
def __init__(self):
self.graph = fx.Graph()
def create_placeholder(self, name):
return self.graph.placeholder(name)
def create_call_function(self, target, args):
return self.graph.call_function(target, args)
def set_output(self, node):
self.graph.output(node)
class SimpleTensorVariable:
"""简化的 TensorVariable"""
def __init__(self, node):
self.node = node
def __add__(self, other):
# 创建加法节点
result_node = graph.create_call_function(
torch.add,
(self.node, other.node)
)
return SimpleTensorVariable(result_node)
def __mul__(self, other):
# 创建乘法节点
if isinstance(other, int):
# Tensor * 常量
result_node = graph.create_call_function(
torch.mul,
(self.node, other)
)
else:
result_node = graph.create_call_function(
torch.mul,
(self.node, other.node)
)
return SimpleTensorVariable(result_node)
# 模拟执行
def symbolic_trace_manual(fn):
"""手动符号化追踪"""
global graph
graph = SimpleOutputGraph()
print("[字节码]")
dis.dis(fn)
print("\n[符号化执行过程]")
# 模拟执行
print("1. 创建 placeholder x")
node_x = graph.create_placeholder("x")
x = SimpleTensorVariable(node_x)
print("2. 创建 placeholder y")
node_y = graph.create_placeholder("y")
y = SimpleTensorVariable(node_y)
print("3. 执行 z = x + y (符号化)")
z = x + y # 触发 __add__,创建 add 节点
print("4. 执行 result = z * 2 (符号化)")
result = z * 2 # 触发 __mul__,创建 mul 节点
print("5. 设置输出")
graph.set_output(result.node)
print("\n[生成的 FX Graph]")
print(graph.graph)
return graph.graph
# 测试
def my_function(x, y):
z = x + y
return z * 2
symbolic_trace_manual(my_function)
真实的 TorchDynamo 内部流程
python
# torch/_dynamo/symbolic_convert.py (简化版)
class InstructionTranslator:
"""字节码 -> FX Graph 的核心翻译器"""
def run(self):
"""主执行循环"""
while self.instruction_pointer < len(self.instructions):
inst = self.instructions[self.instruction_pointer]
# 获取指令处理函数
handler = getattr(self, f"INSTR_{inst.opname}", None)
if handler is None:
# 不支持的指令 -> Graph Break
self.graph_break(f"unsupported: {inst.opname}")
return
# 执行指令处理
try:
handler(inst)
except Exception as e:
self.graph_break(str(e))
return
self.instruction_pointer += 1
# 所有指令处理完毕
self.output.finalize()
def INSTR_LOAD_FAST(self, inst):
"""处理 LOAD_FAST 指令"""
var_name = inst.argval
if var_name not in self.symbolic_locals:
# 第一次访问,创建 placeholder
real_value = self.f_locals[var_name]
proxy = self.output.create_proxy(
"placeholder",
var_name,
(),
{}
)
# 根据真实值的类型创建对应的符号化变量
if isinstance(real_value, torch.Tensor):
var = TensorVariable(proxy, real_value)
elif isinstance(real_value, int):
var = ConstantVariable(real_value)
# ... 其他类型
self.symbolic_locals[var_name] = var
self.stack.append(self.symbolic_locals[var_name])
def INSTR_BINARY_ADD(self, inst):
"""处理 BINARY_ADD 指令"""
right = self.stack.pop()
left = self.stack.pop()
# 符号化执行加法
result = left.call_method(self, "__add__", [right], {})
self.stack.append(result)
# ... 其他指令处理函数
核心要点总结
-
符号化执行 = 记录操作而不是真的计算
- 用符号变量代替真实值
- 操作被记录到 FX Graph 中
-
每条字节码指令 -> FX Graph 节点
LOAD_FAST x->placeholder xBINARY_ADD->call_function torch.addRETURN_VALUE->output
-
符号化变量携带两种信息
proxy: FX Graph 中的节点引用example_value: 运行时的示例值(用于形状推导)
-
关键组件
InstructionTranslator: 字节码翻译器OutputGraph: FX Graph 构建器TensorVariable: Tensor 的符号化表示Proxy: FX Node 的包装
-
转换是保守的
- 遇到不支持的操作 -> Graph Break
- 保证正确性优先于性能
完整流程总结图
[Python 源码]
def my_function(x, y):
z = x + y
return z * 2
|
| [Python 编译器]
v
[字节码]
LOAD_FAST 0 (x) <-- 指令 1
LOAD_FAST 1 (y) <-- 指令 2
BINARY_ADD <-- 指令 3
STORE_FAST 2 (z) <-- 指令 4
LOAD_FAST 2 (z) <-- 指令 5
LOAD_CONST 1 (2) <-- 指令 6
BINARY_MULTIPLY <-- 指令 7
RETURN_VALUE <-- 指令 8
|
| [TorchDynamo 符号化执行]
v
[符号化执行过程]
指令 1: LOAD_FAST x
-> 创建 TensorVariable(x)
-> proxy_x = graph.placeholder("x")
-> stack: [x]
指令 2: LOAD_FAST y
-> 创建 TensorVariable(y)
-> proxy_y = graph.placeholder("y")
-> stack: [x, y]
指令 3: BINARY_ADD
-> 弹出 y, x
-> 调用 x.__add__(y)
-> proxy_z = graph.call_function(torch.add, (x, y))
-> stack: [z]
指令 4: STORE_FAST z
-> locals["z"] = stack.pop()
-> stack: []
指令 5: LOAD_FAST z
-> stack: [z]
指令 6: LOAD_CONST 2
-> stack: [z, 2]
指令 7: BINARY_MULTIPLY
-> 弹出 2, z
-> 调用 z.__mul__(2)
-> proxy_r = graph.call_function(torch.mul, (z, 2))
-> stack: [result]
指令 8: RETURN_VALUE
-> graph.output(stack.pop())
|
v
[FX Graph]
graph():
%x : placeholder[target=x]
%y : placeholder[target=y]
%add : call_function[target=torch.add](args=(%x, %y))
%mul : call_function[target=torch.mul](args=(%add, 2))
return mul
|
| [代码生成]
v
[生成的 Python 代码]
def forward(x, y):
add = torch.add(x, y)
mul = torch.mul(add, 2)
return mul
|
| [后端编译 (Inductor)]
v
[优化的代码]
- 算子融合: add + mul -> fused_add_mul
- CUDA Kernel 生成
- 内存优化
关键理解:三层抽象
[第一层] Python 源码
def my_function(x, y):
z = x + y
return z * 2
特点: 人类可读,灵活
问题: 难以优化,执行慢
↓ [Python 编译器]
[第二层] 字节码
LOAD_FAST x
LOAD_FAST y
BINARY_ADD
...
特点: CPython 内部表示
问题: 仍然是解释执行,慢
↓ [TorchDynamo 符号化执行]
[第三层] FX Graph
graph():
%x = placeholder[target=x]
%y = placeholder[target=y]
%add = call_function[torch.add](%x, %y)
%mul = call_function[torch.mul](%add, 2)
return mul
特点: 显式的计算图,可分析可变换
优势: 可以做各种优化!
↓ [后端编译器]
[第四层] 机器码
融合的 CUDA Kernel
特点: 直接在 GPU 上执行
优势: 极快!
为什么需要 FX Graph?
没有 FX Graph(Eager 执行):
python
z = x + y # 立即执行 add,结果写入显存
result = z * 2 # 立即执行 mul,需要从显存读 z
问题:
- 两个独立的 Kernel 调用
- 中间结果 z 需要写回显存(慢!)
- 无法优化
有了 FX Graph:
python
# FX Graph:
graph():
%add = call_function[torch.add](%x, %y)
%mul = call_function[torch.mul](%add, 2)
后端编译器看到整个计算流程:
"哦,add 的结果立即被 mul 使用,我可以融合!"
优化后的 Kernel:
result = fused_add_mul(x, y, 2) # 一个 Kernel,z 不用写回显存
优势:
- 一个 Kernel 调用(减少 Kernel 启动开销)
- 中间结果 z 保持在寄存器(极快!)
- 节省显存带宽
演示程序
我创建了 Pytorch/examples/demo_bytecode_to_graph.py,运行它可以看到:
- 字节码指令详情
- 逐步的符号化执行过程
- 生成的 FX Graph
- 对比真实的 TorchDynamo 行为
bash
cd Pytorch/examples
python demo_bytecode_to_graph.py
最终总结:一句话理解
字节码到 FX Graph 的转换 = 把 Python 的底层指令"翻译"成一个显式的计算图,这样编译器就能看到整个计算流程并进行优化。
核心技巧:
- 不真的执行计算(符号化执行)
- 用符号变量代替真实值
- 每个操作都记录到图中
- 最终得到一个可优化的中间表示(IR)
第二层:字节码分析 - Bytecode Analysis
Python 字节码基础
python
import dis
def example(x, y):
z = x + y
if z > 0:
return z * 2
else:
return z + 1
dis.dis(example)
输出:
2 0 LOAD_FAST 0 (x)
2 LOAD_FAST 1 (y)
4 BINARY_ADD
6 STORE_FAST 2 (z)
3 8 LOAD_FAST 2 (z)
10 LOAD_CONST 1 (0)
12 COMPARE_OP 4 (>)
14 POP_JUMP_IF_FALSE 24
4 16 LOAD_FAST 2 (z)
18 LOAD_CONST 2 (2)
20 BINARY_MULTIPLY
22 RETURN_VALUE
6 >> 24 LOAD_FAST 2 (z)
26 LOAD_CONST 3 (1)
28 BINARY_ADD
30 RETURN_VALUE
TorchDynamo 的字节码翻译器
文件位置 :torch/_dynamo/symbolic_convert.py
python
# torch/_dynamo/symbolic_convert.py
class InstructionTranslator:
"""
字节码指令翻译器
职责:
1. 逐条解析字节码指令
2. 维护符号化的变量栈
3. 构建 FX Graph
4. 生成 Guard
"""
def __init__(
self,
instructions,
f_globals,
f_locals,
f_builtins,
code_options,
symbolic_locals,
symbolic_globals,
):
# 字节码指令列表
self.instructions = list(instructions)
self.instruction_pointer = 0
# 命名空间
self.f_globals = f_globals
self.f_locals = f_locals
self.f_builtins = f_builtins
# 符号化变量栈(模拟 Python 执行栈)
self.stack = []
# 局部变量(符号化)
self.symbolic_locals = {}
# 输出图
self.output = OutputGraph(
f_globals=f_globals,
f_locals=f_locals,
f_builtins=f_builtins,
)
# Guard 列表
self.guards = []
def run(self):
"""执行符号化转换"""
while self.instruction_pointer < len(self.instructions):
inst = self.instructions[self.instruction_pointer]
# 分发到具体的指令处理函数
handler_name = f"INSTR_{inst.opname}"
handler = getattr(self, handler_name, None)
if handler is None:
# 不支持的指令 -> Graph Break
self.graph_break(f"unsupported instruction: {inst.opname}")
return
# 执行指令
handler(inst)
# 移动指令指针
self.instruction_pointer += 1
# ========== 指令处理函数 ==========
def INSTR_LOAD_FAST(self, inst):
"""
LOAD_FAST: 加载局部变量到栈
"""
var_name = inst.argval
if var_name not in self.symbolic_locals:
# 变量未定义 -> Graph Break
self.graph_break(f"undefined variable: {var_name}")
return
# 获取符号化变量
symbolic_var = self.symbolic_locals[var_name]
# 压入栈
self.stack.append(symbolic_var)
def INSTR_LOAD_CONST(self, inst):
"""
LOAD_CONST: 加载常量到栈
"""
const_value = inst.argval
# 创建常量变量
from .variables import ConstantVariable
symbolic_var = ConstantVariable(const_value)
# 压入栈
self.stack.append(symbolic_var)
def INSTR_BINARY_ADD(self, inst):
"""
BINARY_ADD: 二元加法
"""
# 弹出两个操作数
right = self.stack.pop()
left = self.stack.pop()
# 执行符号化加法
result = left.call_method(self, "add", [right], {})
# 结果压入栈
self.stack.append(result)
def INSTR_COMPARE_OP(self, inst):
"""
COMPARE_OP: 比较操作
"""
comp_op = inst.argval # 例如 '>'
# 弹出两个操作数
right = self.stack.pop()
left = self.stack.pop()
# 符号化比较
result = left.call_method(self, f"__{comp_op}__", [right], {})
# 压入栈
self.stack.append(result)
def INSTR_POP_JUMP_IF_FALSE(self, inst):
"""
POP_JUMP_IF_FALSE: 条件跳转(if 语句)
"""
# 弹出条件
condition = self.stack.pop()
# [关键] 这里需要决定走哪个分支
# 如果条件依赖于 Tensor 的值,需要生成 Guard
if condition.is_python_constant():
# 常量条件,直接判断
if not condition.as_python_constant():
# 跳转
self.instruction_pointer = inst.argval - 1
else:
# 动态条件,需要 Guard
# 尝试获取当前运行时的值
runtime_value = condition.get_runtime_value()
# 生成 Guard
guard = condition.make_guard(runtime_value)
self.guards.append(guard)
# 根据运行时值决定分支
if not runtime_value:
self.instruction_pointer = inst.argval - 1
def INSTR_RETURN_VALUE(self, inst):
"""
RETURN_VALUE: 返回值
"""
# 弹出返回值
ret_val = self.stack.pop()
# 添加到输出图
self.output.add_output(ret_val)
def INSTR_CALL_FUNCTION(self, inst):
"""
CALL_FUNCTION: 函数调用
"""
n_args = inst.argval
# 弹出参数
args = [self.stack.pop() for _ in range(n_args)]
args.reverse()
# 弹出函数
fn = self.stack.pop()
# 符号化调用
result = fn.call_function(self, args, {})
# 压入结果
self.stack.append(result)
# ========== Graph Break 处理 ==========
def graph_break(self, reason):
"""
触发 Graph Break
Args:
reason: 断开原因
"""
log_graph_break(reason)
# 结束当前图
self.output.finalize()
# 生成恢复执行的代码
# (这部分代码会在后续章节详细解释)
self.output.create_resume_at(self.instruction_pointer)
第三层:符号化变量系统 - Symbolic Variables
TorchDynamo 为每种 Python 类型定义了对应的"符号化变量"。
文件位置 :torch/_dynamo/variables/
变量基类
python
# torch/_dynamo/variables/base.py
class VariableTracker:
"""
符号化变量的基类
每个变量代表 Python 执行过程中的一个值
"""
def __init__(self, source=None):
self.source = source # 变量来源(用于生成 Guard)
def call_method(self, tx, name, args, kwargs):
"""
调用方法
Args:
tx: InstructionTranslator 实例
name: 方法名
args: 参数列表
kwargs: 关键字参数
Returns:
新的 VariableTracker
"""
raise NotImplementedError
def call_function(self, tx, args, kwargs):
"""
作为函数调用
"""
raise NotImplementedError
def is_python_constant(self):
"""是否是 Python 常量"""
return False
def as_python_constant(self):
"""转换为 Python 常量"""
raise NotImplementedError
def make_guard(self, expected_value):
"""
生成 Guard
Args:
expected_value: 期望的运行时值
Returns:
Guard 对象
"""
raise NotImplementedError
Tensor 变量
python
# torch/_dynamo/variables/tensor.py
class TensorVariable(VariableTracker):
"""
表示 torch.Tensor 的符号化变量
"""
def __init__(self, proxy, example_value, source=None):
super().__init__(source)
self.proxy = proxy # FX Proxy 对象
self.example_value = example_value # 示例值(用于形状推导)
def call_method(self, tx, name, args, kwargs):
"""
调用 Tensor 方法
例如:tensor.add(other)
"""
if name == "add":
# 处理 add 方法
other = args[0]
# 在 FX 图中创建节点
result_proxy = self.proxy.add(other.proxy)
# 推导结果形状
result_shape = torch.broadcast_shapes(
self.example_value.shape,
other.example_value.shape
)
result_example = torch.empty(result_shape, dtype=self.example_value.dtype)
# 返回新的 TensorVariable
return TensorVariable(result_proxy, result_example, source=None)
# 其他方法...
return super().call_method(tx, name, args, kwargs)
def make_guard(self, expected_value):
"""
生成 Tensor 的 Guard
检查:形状、dtype、设备
"""
guards = []
# 形状 Guard
shape_guard = Guard(
lambda frame: frame.f_locals[self.source.name()].shape == expected_value.shape,
f"{self.source.name()}.shape == {expected_value.shape}"
)
guards.append(shape_guard)
# dtype Guard
dtype_guard = Guard(
lambda frame: frame.f_locals[self.source.name()].dtype == expected_value.dtype,
f"{self.source.name()}.dtype == {expected_value.dtype}"
)
guards.append(dtype_guard)
return guards
常量变量
python
# torch/_dynamo/variables/constant.py
class ConstantVariable(VariableTracker):
"""
表示 Python 常量(int, float, str, None 等)
"""
def __init__(self, value):
super().__init__(source=None)
self.value = value
def is_python_constant(self):
return True
def as_python_constant(self):
return self.value
def call_method(self, tx, name, args, kwargs):
"""
常量方法调用
例如:(1).__add__(2) -> 3
"""
# 直接执行 Python 操作
result = getattr(self.value, name)(*[a.as_python_constant() for a in args])
return ConstantVariable(result)
第四层:Guard 机制 - 动态适配的关键
文件位置 :torch/_dynamo/guards.py
Guard 的生成
python
# torch/_dynamo/guards.py
class GuardBuilder:
"""
Guard 构建器
为每个符号化变量生成运行时检查条件
"""
def __init__(self):
self.guards = []
def TYPE_MATCH(self, obj_source, expected_type):
"""
类型匹配 Guard
例如:type(x) == torch.Tensor
"""
def check(frame):
obj = eval_source(frame, obj_source)
return type(obj) == expected_type
guard = Guard(
check_fn=check,
verbose_code=f"type({obj_source}) == {expected_type.__name__}"
)
self.guards.append(guard)
def TENSOR_MATCH(self, tensor_source, expected_tensor):
"""
Tensor 匹配 Guard
检查:shape, dtype, device, requires_grad
"""
# 形状
self.SHAPE_MATCH(tensor_source, expected_tensor.shape)
# dtype
self.DTYPE_MATCH(tensor_source, expected_tensor.dtype)
# device
self.DEVICE_MATCH(tensor_source, expected_tensor.device)
# requires_grad
self.REQUIRES_GRAD_MATCH(tensor_source, expected_tensor.requires_grad)
def SHAPE_MATCH(self, tensor_source, expected_shape):
"""形状匹配"""
def check(frame):
tensor = eval_source(frame, tensor_source)
return tensor.shape == expected_shape
guard = Guard(
check_fn=check,
verbose_code=f"{tensor_source}.shape == {expected_shape}"
)
self.guards.append(guard)
def DYNAMIC_DIM(self, tensor_source, dim_idx):
"""
动态维度
不固定某个维度的大小(用于 batch size)
"""
# 只检查其他维度
def check(frame):
tensor = eval_source(frame, tensor_source)
# 检查除了 dim_idx 之外的维度
return True # 简化示例
guard = Guard(
check_fn=check,
verbose_code=f"{tensor_source}.shape[{dim_idx}] is dynamic"
)
self.guards.append(guard)
# Guard 检查
def check_guards(frame, guards):
"""
检查所有 Guard 是否通过
Args:
frame: Python 帧对象
guards: Guard 列表
Returns:
bool: 所有 Guard 都通过返回 True
"""
for guard in guards:
if not guard.check(frame):
log_guard_failure(guard)
return False
return True
Guard 优化
python
# Guard 数量会影响性能,需要优化
class GuardOptimizer:
"""Guard 优化器"""
def optimize(self, guards):
"""
优化 Guard 列表
策略:
1. 去重
2. 合并相似的 Guard
3. 提前失败(快速检查)
"""
# 1. 去重
unique_guards = self.deduplicate(guards)
# 2. 排序(让快速检查的 Guard 在前面)
sorted_guards = self.sort_by_cost(unique_guards)
return sorted_guards
def deduplicate(self, guards):
"""去除重复的 Guard"""
seen = set()
result = []
for guard in guards:
key = guard.verbose_code
if key not in seen:
seen.add(key)
result.append(guard)
return result
def sort_by_cost(self, guards):
"""按检查成本排序(低成本优先)"""
def cost(guard):
# 类型检查最快
if "type(" in guard.verbose_code:
return 1
# 属性检查次快
elif ".shape" in guard.verbose_code or ".dtype" in guard.verbose_code:
return 2
# 其他
else:
return 3
return sorted(guards, key=cost)
第五层:Graph Break 处理 - 容错机制
Graph Break 的原因
python
# 常见的 Graph Break 场景
def example1(x):
x = x + 1
print(x) # [Graph Break] 副作用(print)
return x * 2
def example2(x):
x = x + 1
if x.item() > 0: # [Graph Break] 需要立即读取 Tensor 值
return x * 2
return x + 1
def example3(x):
x = x + 1
import numpy as np
y = np.array([1, 2, 3]) # [Graph Break] 不支持的库
return x + len(y)
Graph Break 的处理
文件位置 :torch/_dynamo/resume_execution.py
python
# torch/_dynamo/resume_execution.py
class GraphBreakHandler:
"""
Graph Break 处理器
当遇到不可追踪的代码时:
1. 结束当前图
2. 生成恢复执行的代码
3. 继续追踪后续代码
"""
def handle_graph_break(self, translator, reason, instruction_pointer):
"""
处理 Graph Break
Args:
translator: InstructionTranslator 实例
reason: 断开原因
instruction_pointer: 当前指令位置
"""
# [1] 记录 Graph Break
log_graph_break(reason, instruction_pointer)
# [2] 结束当前图
current_graph = translator.output.graph
# [3] 生成代码:将当前状态传递给 Python 解释器
resume_code = self.generate_resume_code(
translator,
instruction_pointer
)
# [4] 创建新的翻译器,继续追踪后续代码
next_translator = InstructionTranslator(
instructions=translator.instructions[instruction_pointer:],
f_globals=translator.f_globals,
f_locals=translator.f_locals,
f_builtins=translator.f_builtins,
code_options=translator.code_options,
symbolic_locals=translator.symbolic_locals,
symbolic_globals=translator.symbolic_globals,
)
return resume_code, next_translator
def generate_resume_code(self, translator, instruction_pointer):
"""
生成恢复执行的代码
作用:将优化的图执行结果传递给后续的 Python 代码
"""
# 示例生成的代码:
code = f"""
def resume_function(*args, **kwargs):
# 执行优化的图
graph_output = compiled_graph(*args, **kwargs)
# 恢复局部变量
locals_dict = {{}}
locals_dict['x'] = graph_output
# 继续执行原始 Python 代码(从 Graph Break 位置开始)
# ...
return result
"""
return code
Graph Break 的可视化
python
import torch._dynamo as dynamo
def visualize_graph_breaks(fn):
"""可视化函数中的 Graph Break"""
explanation = dynamo.explain(fn)
print(f"函数: {fn.__name__}")
print(f"Graph 数量: {explanation.graph_count}")
print(f"Graph Break 数量: {explanation.graph_break_count}")
print("\nGraph Break 详情:")
for i, break_reason in enumerate(explanation.break_reasons):
print(f" [{i+1}] {break_reason}")
# 使用示例
def my_function(x):
x = x + 1
print(x) # Break 1
x = x * 2
if x.item() > 0: # Break 2
return x + 1
return x - 1
visualize_graph_breaks(my_function)
# 输出:
# 函数: my_function
# Graph 数量: 3
# Graph Break 数量: 2
#
# Graph Break 详情:
# [1] call to print (side effect)
# [2] call to tensor.item() (data-dependent control flow)
完整示例:从字节码到优化
让我们通过一个完整的例子,看看 TorchDynamo 如何工作:
python
import torch
import torch._dynamo as dynamo
def example(x, y):
z = x + y
if z.sum() > 0:
return z * 2
else:
return z + 1
# 编译
compiled_fn = torch.compile(example)
# 第一次运行
x1 = torch.randn(10)
y1 = torch.randn(10)
output1 = compiled_fn(x1, y1)
# 内部流程:
"""
1. Frame Evaluation Hook 拦截
-> custom_eval_frame() 被调用
2. 检查缓存
-> 缓存为空,开始编译
3. 字节码分析
指令序列:
LOAD_FAST x
LOAD_FAST y
BINARY_ADD
STORE_FAST z
LOAD_FAST z
LOAD_METHOD sum
CALL_METHOD 0
LOAD_CONST 0
COMPARE_OP >
POP_JUMP_IF_FALSE [else分支]
...
4. 符号化执行
-> x: TensorVariable(shape=(10,), dtype=float32)
-> y: TensorVariable(shape=(10,), dtype=float32)
-> z = x + y: TensorVariable(shape=(10,), dtype=float32)
-> z.sum(): TensorVariable(shape=(), dtype=float32)
-> z.sum() > 0: 需要运行时值
运行时值 = True (假设)
5. 生成 Guard
Guard 1: x.shape == (10,)
Guard 2: x.dtype == torch.float32
Guard 3: y.shape == (10,)
Guard 4: y.dtype == torch.float32
Guard 5: (x + y).sum() > 0 # 控制流 Guard
6. 构建 FX Graph(if 分支)
def forward(x, y):
z = x + y
return z * 2
7. 后端编译(Inductor)
-> 生成融合的 CUDA kernel
8. 缓存并执行
"""
# 第二次运行(相同条件)
x2 = torch.randn(10)
y2 = torch.randn(10) # 假设 sum() > 0
output2 = compiled_fn(x2, y2)
# 内部流程:
"""
1. 检查缓存 -> 命中
2. 检查 Guard -> 全部通过
3. 直接执行缓存的优化代码 [快!]
"""
# 第三次运行(条件改变)
x3 = torch.randn(10) * -10 # 让 sum() < 0
y3 = torch.randn(10) * -10
output3 = compiled_fn(x3, y3)
# 内部流程:
"""
1. 检查缓存 -> 命中
2. 检查 Guard -> Guard 5 失败!(sum() > 0 不满足)
3. 重新编译 else 分支
4. 生成新的 Guard 和优化代码
5. 缓存两个版本(if 和 else)
"""
总结
TorchDynamo 的核心原理
- Frame Evaluation Hook:PEP 523 钩子劫持函数执行
- 字节码分析:反汇编并逐条解析字节码
- 符号化执行:不实际执行,而是构建符号化的计算图
- Guard 机制:记录执行条件,动态适配不同情况
- FX Graph 构建:生成 FX 图,传递给后端编译器
- Graph Break 处理:遇到不可追踪代码时优雅降级
与 FX 的关系
TorchDynamo = 前端(捕获图)
|
| 输出 FX Graph
|
v
FX = 中间表示(图优化)
|
| 输出优化后的 Graph
|
v
后端编译器(Inductor/AOTAutograd)= 代码生成
关键优势
- 完整的 Python 支持:控制流、异常处理、动态特性
- 零代码修改 :只需
@torch.compile装饰器 - 渐进式优化:Graph Break 后继续追踪
- 动态适配:Guard 机制支持动态形状和控制流
性能特征
首次编译:慢(需要分析、编译)
|
v
缓存命中 + Guard 通过:极快(直接执行优化代码)
|
v
Guard 失败:中等(重新编译,但有部分缓存)
推荐阅读顺序
如果你想深入阅读 TorchDynamo 源码,建议按以下顺序:
torch/_dynamo/eval_frame.py- 理解 Frame Hook 机制 (500 行)torch/_dynamo/symbolic_convert.py- 理解字节码翻译 (2000+ 行,核心)torch/_dynamo/variables/- 理解符号化变量系统 (多个文件)torch/_dynamo/guards.py- 理解 Guard 生成与检查 (800 行)torch/_dynamo/output_graph.py- 理解 FX Graph 构建 (600 行)
总计:~5000+ 行核心代码(比 FX 复杂 2 倍)
本文档基于 PyTorch 2.0+ 源码编写,部分实现细节可能因版本而异