【torch.compile】TorchDynamo 源码深度剖析

TorchDynamo 源码深度剖析

为什么需要 TorchDynamo?与 FX 的对比

FX 的局限性

在学习完 FX 后,我们知道 FX 通过 Proxy 拦截机制可以捕获计算图,但它有一个致命缺陷:

python 复制代码
# FX 无法处理的情况
def forward(self, x):
    if x.sum() > 0:  # [X] 数据依赖的控制流
        return x * 2
    else:
        return x + 1

traced = fx.symbolic_trace(Model())  # [错误] 会报错或只捕获一个分支

FX 的问题

  • 无法处理数据依赖的控制流(if/while/for)
  • 需要实际执行代码才能追踪
  • 对 Python 动态特性支持有限
  • Graph Break 后无法继续追踪

TorchDynamo 的优势

python 复制代码
# TorchDynamo 可以处理
@torch.compile
def forward(self, x):
    if x.sum() > 0:  # [√] 完全支持!
        return x * 2
    else:
        return x + 1

output = forward(torch.randn(10))  # 工作正常

TorchDynamo 的优势

  • √\] 字节码级追踪:不需要实际执行

  • √\] Guard 机制:动态适配不同情况

技术对比

维度 FX TorchDynamo
追踪方式 Proxy 拦截(运行时) 字节码分析(静态+动态)
控制流 有限(只能捕获实际执行的分支) 完整支持(通过 Guard)
Python 兼容性 部分(需要可追踪的代码) 几乎完整
性能 追踪开销小 首次编译慢,但有缓存
适用场景 图优化、模型分析 端到端编译优化

核心设计思想

复制代码
FX:  执行 Python 代码 -> 拦截操作 -> 构建图
     [必须运行]      [Proxy 魔术方法]

TorchDynamo: 分析字节码 -> 模拟执行 -> 构建图 + Guard
              [静态分析]   [符号执行]   [动态适配]

TorchDynamo 是什么?

一句话总结

TorchDynamo 是 PyTorch 的 JIT 编译器前端,通过字节码级别的追踪和 Guard 机制,在保持完整 Python 语义的同时捕获计算图。

核心组成

复制代码
TorchDynamo 架构:
├─ Frame Evaluation Hook  [入口] PEP 523 钩子,劫持函数执行
├─ Bytecode Analysis      [分析] 反汇编字节码指令
├─ Symbolic Execution     [执行] 符号化执行字节码
├─ Guard Generation       [保护] 生成执行条件
├─ FX Graph Construction  [输出] 生成 FX 图
└─ Graph Break Handling   [容错] 处理不可追踪的代码

工作流程

复制代码
Python 函数调用
    |
    v
[Frame Evaluation Hook]
    |
    | 拦截函数帧
    |
    v
[检查缓存]
    |
    +-- 缓存命中 + Guard 通过 --> 直接执行优化代码 [快]
    |
    +-- 缓存未命中/Guard 失败
        |
        v
    [Bytecode Analysis]
        |
        | 反汇编字节码
        |
        v
    [Symbolic Execution]
        |
        | 符号化执行每条指令
        |
        v
    [构建 FX Graph]
        |
        | 生成计算图
        |
        v
    [生成 Guard]
        |
        | 记录执行条件
        |
        v
    [后端编译] (Inductor/其他)
        |
        | 生成优化代码
        |
        v
    [缓存并执行]

源码文件索引

TorchDynamo 的核心代码位于 torch/_dynamo/ 目录:

复制代码
torch/_dynamo/
├── eval_frame.c/py          # [核心] Frame 劫持与缓存管理
├── symbolic_convert.py      # [核心] 字节码 -> FX Graph 转换
├── bytecode_transformation.py  # [核心] 字节码分析与变换
├── guards.py                # [核心] Guard 生成与检查
├── variables/               # 符号化变量系统
│   ├── base.py             # 变量基类
│   ├── tensor.py           # Tensor 变量
│   ├── user_defined.py     # 用户自定义类
│   └── ...
├── output_graph.py          # FX Graph 构建
├── resume_execution.py      # Graph Break 后恢复执行
├── side_effects.py          # 副作用处理
└── utils.py                 # 工具函数

基础概念:什么是执行帧(Frame)

在理解 TorchDynamo 之前,我们需要先理解 Python 的执行模型。

Python 代码的执行过程

当你运行一个 Python 函数时,CPython 解释器会经历这些步骤:

python 复制代码
def add(a, b):
    result = a + b
    return result

x = add(1, 2)

执行流程:

复制代码
Python 源代码
    |
    v
[1] 词法分析 + 语法分析
    |
    v
抽象语法树 (AST)
    |
    v
[2] 编译为字节码
    |
    v
字节码 (Code Object)
    |
    v
[3] 创建执行帧 (Frame)  <-- 这里!
    |
    v
[4] 执行字节码
    |
    v
返回结果

什么是执行帧(Frame)?

执行帧(PyFrameObject)就是一个函数调用的"运行环境",它包含了执行这个函数所需的一切信息。

可以把它想象成一个"工作台":

python 复制代码
class Frame:
    """
    这是执行帧的简化概念模型
    """
    def __init__(self):
        # 1. 要执行的代码
        self.f_code = CodeObject(...)  # 字节码指令
        
        # 2. 变量存储
        self.f_locals = {}   # 局部变量: {'a': 1, 'b': 2, 'result': 3}
        self.f_globals = {}  # 全局变量
        self.f_builtins = {} # 内置函数
        
        # 3. 执行状态
        self.f_lasti = 0     # 当前执行到第几条指令
        self.value_stack = [] # 操作数栈
        
        # 4. 调用链信息
        self.f_back = None   # 调用者的帧(形成调用栈)

真实示例:查看执行帧

python 复制代码
import sys
import inspect

def inner_function(x):
    # 获取当前帧
    frame = sys._getframe()
    
    print("=== 当前帧信息 ===")
    print(f"函数名: {frame.f_code.co_name}")
    print(f"文件名: {frame.f_code.co_filename}")
    print(f"行号: {frame.f_lineno}")
    print(f"局部变量: {frame.f_locals}")
    print(f"字节码指令数: {len(frame.f_code.co_code)}")
    
    return x * 2

def outer_function():
    result = inner_function(10)
    return result

# 运行
outer_function()

输出:

复制代码
=== 当前帧信息 ===
函数名: inner_function
文件名: /path/to/your/file.py
行号: 5
局部变量: {'x': 10, 'frame': <frame object>}
字节码指令数: 42

查看字节码

python 复制代码
import dis

def add(a, b):
    result = a + b
    return result

# 查看字节码
dis.dis(add)

输出:

复制代码
  2           0 LOAD_FAST                0 (a)
              2 LOAD_FAST                1 (b)
              4 BINARY_ADD
              6 STORE_FAST               2 (result)

  3           8 LOAD_FAST                2 (result)
             10 RETURN_VALUE

每一行就是一条字节码指令,执行帧会逐条执行这些指令。

CPython 如何执行帧

在 CPython 内部,有一个核心函数负责执行帧:

c 复制代码
// CPython 源码:Python/ceval.c

PyObject* _PyEval_EvalFrameDefault(
    PyThreadState* tstate,
    PyFrameObject* frame,
    int throwflag
) {
    /*
     * 这是 CPython 的"心脏"!
     * 
     * 它的工作:
     * 1. 读取 frame.f_code.co_code (字节码)
     * 2. 逐条执行指令
     * 3. 操作 value_stack (操作数栈)
     * 4. 更新 f_locals (局部变量)
     * 5. 返回结果
     */
    
    // 简化的执行循环
    for (;;) {
        // 获取当前指令
        int opcode = NEXTOP();
        
        switch (opcode) {
            case LOAD_FAST:
                // 从局部变量加载到栈
                value = GETLOCAL(oparg);
                PUSH(value);
                break;
            
            case BINARY_ADD:
                // 从栈弹出两个值,相加,结果压栈
                right = POP();
                left = POP();
                sum = PyNumber_Add(left, right);
                PUSH(sum);
                break;
            
            case RETURN_VALUE:
                // 返回栈顶值
                retval = POP();
                return retval;
            
            // ... 100+ 其他指令
        }
    }
}

关键问题:如何拦截?

正常情况下的执行流程:

复制代码
Python 调用函数
    |
    v
CPython 创建 Frame
    |
    v
调用 _PyEval_EvalFrameDefault(frame)  <-- 默认执行函数
    |
    v
逐条执行字节码
    |
    v
返回结果

PEP 523 引入的机制:允许替换执行函数

Python 3.6+ 提供了 _PyInterpreterState_SetEvalFrameFunc() 函数,可以替换默认的帧执行函数

c 复制代码
// CPython 提供的 API

// 解释器状态
typedef struct {
    // ... 其他字段
    
    // 帧执行函数指针(可以替换!)
    _PyFrameEvalFunction eval_frame;  // 默认指向 _PyEval_EvalFrameDefault
    
} PyInterpreterState;

// 设置自定义的执行函数
void _PyInterpreterState_SetEvalFrameFunc(
    PyInterpreterState* interp,
    _PyFrameEvalFunction new_eval_frame  // 你的自定义函数!
);

TorchDynamo 的拦截机制

安装拦截器:

c 复制代码
// torch/csrc/dynamo/eval_frame.c

// 1. 保存原始的执行函数
static _PyFrameEvalFunction original_eval_frame = NULL;

// 2. 定义自己的执行函数
static PyObject* custom_eval_frame(
    PyThreadState* tstate,
    PyFrameObject* frame,
    int throwflag
) {
    // 这里可以做任何事情!
    
    // 获取代码对象
    PyCodeObject* code = frame->f_code;
    
    printf("拦截到函数调用: %s\n", PyUnicode_AsUTF8(code->co_name));
    printf("字节码长度: %ld\n", PyBytes_Size(code->co_code));
    
    // 可以选择:
    // A. 调用原始执行(正常运行)
    // B. 调用自己的编译逻辑(优化)
    // C. 混合使用
    
    if (should_compile(frame)) {
        // 编译并执行优化代码
        return compile_and_run(frame);
    } else {
        // 回退到原始执行
        return original_eval_frame(tstate, frame, throwflag);
    }
}

// 3. 安装拦截器
void install_hook() {
    PyInterpreterState* interp = PyThreadState_GET()->interp;
    
    // 保存原始函数
    original_eval_frame = _PyInterpreterState_GetEvalFrameFunc(interp);
    
    // 替换为自定义函数
    _PyInterpreterState_SetEvalFrameFunc(interp, custom_eval_frame);
    
    printf("拦截器已安装!\n");
}

拦截后的执行流程:

复制代码
Python 调用函数
    |
    v
CPython 创建 Frame
    |
    v
调用 interp->eval_frame(frame)
    |
    |  (已经被替换!)
    v
调用 custom_eval_frame(frame)  <-- TorchDynamo 的函数
    |
    v
检查:是否需要编译?
    |
    +-- 是 --> 分析字节码 --> 构建 FX Graph --> 编译 --> 执行优化代码
    |
    +-- 否 --> 调用 original_eval_frame(frame) --> 正常执行
    |
    v
返回结果

完整示例:演示拦截

python 复制代码
# 演示文件: demo_frame_hook.py

import sys
import types

# 模拟的拦截计数器
call_count = 0

def my_custom_eval_frame(frame, exc):
    """
    自定义的帧执行函数(Python 层模拟)
    """
    global call_count
    call_count += 1
    
    code = frame.f_code
    print(f"\n[拦截 #{call_count}]")
    print(f"  函数: {code.co_name}")
    print(f"  参数: {frame.f_locals}")
    
    # 获取字节码
    import dis
    print(f"  字节码:")
    dis.disassemble(code)
    
    # 调用原始执行(在真实实现中,这里会编译)
    # 注意:Python 层无法真正替换 eval_frame,这只是演示概念
    return None

# 实际演示:查看帧信息
def demo_function(x, y):
    """被拦截的函数"""
    frame = sys._getframe()
    
    print("\n=== 帧详细信息 ===")
    print(f"代码对象: {frame.f_code}")
    print(f"  - 函数名: {frame.f_code.co_name}")
    print(f"  - 参数名: {frame.f_code.co_varnames}")
    print(f"  - 常量: {frame.f_code.co_consts}")
    
    print(f"\n执行状态:")
    print(f"  - 局部变量: {frame.f_locals}")
    print(f"  - 当前指令: {frame.f_lasti}")
    
    print(f"\n调用栈:")
    current = frame
    depth = 0
    while current is not None:
        print(f"  [{depth}] {current.f_code.co_name}")
        current = current.f_back
        depth += 1
    
    result = x + y
    return result * 2

# 运行
if __name__ == "__main__":
    result = demo_function(10, 20)
    print(f"\n最终结果: {result}")

运行输出:

复制代码
=== 帧详细信息 ===
代码对象: <code object demo_function at 0x...>
  - 函数名: demo_function
  - 参数名: ('x', 'y', 'frame', 'current', 'depth', 'result')
  - 常量: (None, 0, 1, 2, '  [', '] ', '\n最终结果: ')

执行状态:
  - 局部变量: {'x': 10, 'y': 20, 'frame': <frame object>}
  - 当前指令: 42

调用栈:
  [0] demo_function
  [1] <module>

最终结果: 60

TorchDynamo 真实的 C 层拦截

c 复制代码
// torch/csrc/dynamo/eval_frame.c (简化版)

static PyObject* custom_eval_frame(
    PyThreadState* tstate,
    PyFrameObject* frame,
    int throwflag
) {
    // 【步骤 1】提取信息
    PyCodeObject* code = frame->f_code;
    const char* func_name = PyUnicode_AsUTF8(code->co_name);
    
    // 【步骤 2】决策:是否编译
    if (should_skip_frame(frame)) {
        // 跳过:系统函数、已编译函数等
        return original_eval_frame(tstate, frame, throwflag);
    }
    
    // 【步骤 3】检查缓存
    CacheEntry* cached = lookup_cache(code);
    if (cached != NULL) {
        // 【步骤 4】检查 Guard
        if (check_guards(frame, cached->guards)) {
            // Guard 通过,执行缓存的优化代码
            cached->hit_count++;
            return execute_optimized_code(tstate, frame, cached->code);
        }
        // Guard 失败
        cached->miss_count++;
    }
    
    // 【步骤 5】调用 Python 层编译
    PyObject* result = call_python_compiler(frame);
    
    if (result != NULL) {
        // 编译成功
        return result;
    }
    
    // 【步骤 6】回退
    return original_eval_frame(tstate, frame, throwflag);
}

核心要点总结

  1. 执行帧(Frame)= 函数的运行环境

    • 包含代码(字节码)
    • 包含变量(局部/全局)
    • 包含执行状态(当前指令位置、操作数栈)
  2. CPython 通过 eval_frame 函数执行帧

    • 默认:_PyEval_EvalFrameDefault
    • 逐条执行字节码指令
  3. PEP 523 允许替换 eval_frame 函数

    • 通过 _PyInterpreterState_SetEvalFrameFunc()
    • 在每次函数调用时拦截
  4. TorchDynamo 的拦截流程

    复制代码
    函数调用
      -> custom_eval_frame (拦截!)
        -> 检查缓存
        -> 检查 Guard
        -> 编译 (如果需要)
        -> 执行优化代码
  5. 为什么这么做有用?

    • 无需修改用户代码
    • 无需修改 Python 解释器
    • 在运行时动态优化
    • 保持 Python 的灵活性

直观理解:拦截的本质

把 Python 执行想象成一个流水线:

复制代码
[正常执行]

用户代码: def add(a, b): return a + b
    |
    v
Python 解释器
    |
    v
[解析] -> AST
    |
    v
[编译] -> 字节码
    |
    v
[创建帧] -> Frame{code, locals, stack}
    |
    v
[执行帧] -> _PyEval_EvalFrameDefault(frame)
    |         |
    |         +-> 读取指令
    |         +-> 执行指令
    |         +-> 更新状态
    |         +-> 返回结果
    v
结果: 3

[TorchDynamo 拦截后]

用户代码: @torch.compile
         def add(a, b): return a + b
    |
    v
Python 解释器 (未修改!)
    |
    v
[解析] -> AST
    |
    v
[编译] -> 字节码
    |
    v
[创建帧] -> Frame{code, locals, stack}
    |
    v
[执行帧] -> custom_eval_frame(frame)  <-- 这里被替换了!
    |         |
    |         +-> [检查] 是否需要编译?
    |         |     |
    |         |     +-> 是 -> [提取字节码]
    |         |              [构建 FX Graph]
    |         |              [编译优化]
    |         |              [生成 Guard]
    |         |              [缓存结果]
    |         |              [执行优化代码]
    |         |
    |         +-> 否 -> [调用原始执行]
    |                   _PyEval_EvalFrameDefault(frame)
    v
结果: 3 (但更快!)

关键时刻:拦截发生在何时?

python 复制代码
import torch

@torch.compile
def my_func(x, y):
    return x + y

# 调用函数
result = my_func(torch.randn(3), torch.randn(3))

时间线:

复制代码
t0: 用户调用 my_func()
    |
    v
t1: CPython 准备执行 my_func
    |
    v
t2: CPython 创建 Frame 对象
    |
    |  Frame 内容:
    |  - f_code: my_func 的字节码
    |  - f_locals: {'x': tensor([...]), 'y': tensor([...])}
    |  - f_lasti: 0
    v
t3: CPython 查找 eval_frame 函数
    |
    |  正常情况: _PyEval_EvalFrameDefault
    |  现在: custom_eval_frame (被替换!)
    v
t4: 调用 custom_eval_frame(frame)  [C 层]
    |
    v
t5: custom_eval_frame 检查缓存
    |
    +-> 缓存未命中
    |
    v
t6: 调用 Python 层的 compile_frame()
    |
    v
t7: 分析字节码
    |
    |  LOAD_FAST 0 (x)
    |  LOAD_FAST 1 (y)
    |  BINARY_ADD
    |  RETURN_VALUE
    v
t8: 构建 FX Graph
    |
    |  def forward(x, y):
    |      return torch.add(x, y)
    v
t9: 后端编译 (Inductor)
    |
    |  生成优化的 CUDA kernel
    v
t10: 生成 Guard
    |
    |  Guard 1: x.shape == torch.Size([3])
    |  Guard 2: x.dtype == torch.float32
    |  Guard 3: y.shape == torch.Size([3])
    |  Guard 4: y.dtype == torch.float32
    v
t11: 缓存结果
    |
    v
t12: 执行优化代码
    |
    v
t13: 返回结果

---

第二次调用 (相同输入):

t0: 用户调用 my_func()
    v
t3: 调用 custom_eval_frame(frame)
    v
t5: 检查缓存 -> 命中!
    v
t5.1: 检查 Guard -> 通过!
    v
t12: 直接执行缓存的优化代码 (极快!)
    v
t13: 返回结果

实际运行演示

我已经创建了一个演示程序 Pytorch/examples/demo_frame.py,你可以运行它来亲自查看执行帧的结构:

bash 复制代码
cd Pytorch/examples
python demo_frame.py

程序会展示:

  1. 执行帧的详细信息(代码、变量、状态)
  2. 字节码指令的结构
  3. 调用栈的追踪
  4. 字节码执行的模拟过程
  5. TorchDynamo 拦截的概念
  6. 实际的 torch.compile 例子

用一句话总结

TorchDynamo 就像在 Python 解释器和你的代码之间插入了一个"智能代理",这个代理在每次函数调用时决定:是直接运行,还是先编译优化再运行。而这个"插入"动作是通过 PEP 523 提供的 eval_frame 替换机制实现的。


torch.compile 装饰器完整调用链

在深入 Frame Evaluation Hook 之前,让我们先看看 @torch.compile 装饰器是如何一步步调用到字节码拦截的。

调用链概览

复制代码
@torch.compile 装饰器
    |
    v
torch.compile() 函数 [torch/__init__.py:2422]
    |
    v
torch._dynamo.optimize() [torch/_dynamo/__init__.py]
    |
    v
_optimize_catch_errors() [torch/_dynamo/eval_frame.py]
    |
    v
set_eval_frame() [torch/_dynamo/eval_frame.py]
    |
    v
_C._dynamo.eval_frame.set_eval_frame() [C 扩展]
    |
    v
custom_eval_frame() [C 层拦截器]
    |
    v
Python 函数调用时自动触发

第 1 步:torch.compile() 装饰器

文件位置torch/__init__.py 第 2422 行

python 复制代码
# torch/__init__.py:2422

def compile(
    model: Optional[Callable] = None,
    *,
    fullgraph: bool = False,
    dynamic: Optional[bool] = None,
    backend: Union[str, Callable] = "inductor",
    mode: Union[str, None] = None,
    options: Optional[Dict[str, Union[str, int, bool]]] = None,
    disable: bool = False,
) -> Union[Callable, Callable]:
    """
    主入口函数
    """
    # [1] 日志记录
    _C._log_api_usage_once("torch.compile")
    
    # [2] 版本检查
    if sys.version_info >= (3, 14):
        raise RuntimeError("torch.compile is not supported on Python 3.14+")
    
    # [3] 装饰器模式处理
    if model is None:
        # 返回一个装饰器函数
        def fn(model):
            return compile(
                model,
                fullgraph=fullgraph,
                dynamic=dynamic,
                backend=backend,
                mode=mode,
                options=options,
                disable=disable,
            )
        return fn
    
    # [4] 创建后端编译器包装器
    if backend == "inductor":
        # 默认使用 Inductor 后端
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
        # 其他后端
        backend = _TorchCompileWrapper(backend, mode, options, dynamic)
    
    # [5] 关键调用!调用 TorchDynamo 的 optimize
    return torch._dynamo.optimize(
        backend=backend,
        nopython=fullgraph,
        dynamic=dynamic,
        disable=disable,
    )(model)

第 2 步:torch._dynamo.optimize()

文件位置torch/_dynamo/__init__.py

python 复制代码
# torch/_dynamo/__init__.py

def optimize(
    backend="inductor",
    *,
    nopython=False,
    guard_export_fn=None,
    guard_fail_fn=None,
    disable=False,
    dynamic=None,
):
    """
    TorchDynamo 的主优化入口
    
    Args:
        backend: 编译后端(Inductor, eager, aot_autograd 等)
        nopython: 是否要求整个函数可编译(fullgraph)
        dynamic: 动态形状支持
    
    Returns:
        装饰器函数
    """
    # [1] 检查是否禁用
    if disable or os.environ.get("TORCH_COMPILE_DISABLE"):
        return _disable_decorator
    
    # [2] 查找后端实现
    from torch._dynamo.backends.registry import lookup_backend
    backend_fn = lookup_backend(backend)
    
    # [3] 返回实际的装饰器
    def decorator(fn):
        # 这里会包装原始函数
        from torch._dynamo.eval_frame import OptimizedModule
        
        if isinstance(fn, torch.nn.Module):
            # 处理 nn.Module
            return OptimizedModule(
                fn,
                backend=backend_fn,
                nopython=nopython,
                dynamic=dynamic,
            )
        else:
            # 处理普通函数
            from torch._dynamo.eval_frame import _optimize_catch_errors
            return _optimize_catch_errors(
                fn,
                backend=backend_fn,
                nopython=nopython,
                dynamic=dynamic,
            )
    
    return decorator

第 3 步:_optimize_catch_errors() - 包装并启用钩子

文件位置torch/_dynamo/eval_frame.py

python 复制代码
# torch/_dynamo/eval_frame.py

def _optimize_catch_errors(
    fn,
    backend,
    nopython=False,
    dynamic=None,
):
    """
    包装函数并启用字节码拦截
    
    这是连接用户代码和底层 Frame Hook 的桥梁
    """
    # [1] 创建回调函数
    callback = callback_handler(
        backend=backend,
        nopython=nopython,
        dynamic=dynamic,
    )
    
    # [2] 注册编译回调
    from torch._dynamo.convert_frame import convert_frame
    callback.convert_frame_assert = convert_frame_assert
    
    # [3] 包装原始函数
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # [关键] 在函数执行前安装 Frame Hook
        prior = set_eval_frame(callback)
        try:
            # 执行原始函数(此时会被 Hook 拦截)
            return fn(*args, **kwargs)
        finally:
            # 恢复之前的 Hook
            set_eval_frame(prior)
    
    # [4] 标记这是一个优化过的函数
    wrapper._torchdynamo_orig_callable = fn
    wrapper._torchdynamo_inline = fn
    
    return wrapper

第 4 步:set_eval_frame() - 安装 C 层钩子

文件位置torch/_dynamo/eval_frame.py

python 复制代码
# torch/_dynamo/eval_frame.py

def set_eval_frame(callback):
    """
    设置 Python 帧执行钩子
    
    这是 Python 层到 C 层的桥梁
    
    Args:
        callback: 编译回调函数,接收 (frame, cache_size) 返回优化代码
    
    Returns:
        之前的回调函数
    """
    # [1] 获取之前的回调
    prior = _eval_frame_callback
    
    if callback is None:
        # [2] 移除钩子,恢复默认执行
        _C._dynamo.eval_frame.set_eval_frame(None)
        _eval_frame_callback = None
    else:
        # [3] 创建 C 层可调用的包装器
        def _callback_wrapper(frame, cache_size, frame_state):
            """
            C 层调用的实际回调
            
            Args:
                frame: PyFrameObject (Python 帧对象)
                cache_size: 当前缓存大小
                frame_state: 帧状态信息
            
            Returns:
                优化后的代码对象,或 None(回退到默认执行)
            """
            try:
                # [4] 调用 Python 层的编译逻辑
                return callback(frame, cache_size, frame_state)
            except Exception as e:
                # [5] 编译失败,回退
                log_compilation_error(e)
                return None
        
        # [6] 安装到 C 扩展
        # 这是关键调用!将回调注册到 CPython 解释器
        _C._dynamo.eval_frame.set_eval_frame(_callback_wrapper)
        _eval_frame_callback = callback
    
    return prior

第 5 步:C 扩展层 - custom_eval_frame()

文件位置torch/csrc/dynamo/eval_frame.c

c 复制代码
// torch/csrc/dynamo/eval_frame.c

// 全局变量:保存原始的 eval_frame 函数
static _PyFrameEvalFunction original_eval_frame = NULL;

// 全局变量:Python 层的回调函数
static PyObject* py_callback = NULL;

// TorchDynamo 的自定义 eval_frame 函数
static PyObject* custom_eval_frame(
    PyThreadState* tstate,
    PyFrameObject* frame,
    int throw_flag
) {
    /*
     * 这个函数会在每次 Python 函数调用时被触发
     * 
     * 执行流程:
     * 1. 检查是否需要编译这个函数
     * 2. 检查缓存
     * 3. 调用 Python 层编译
     * 4. 执行优化代码或回退
     */
    
    // [1] 快速路径:检查是否应该跳过这个帧
    if (should_skip_frame(frame)) {
        // 直接使用原始执行
        return original_eval_frame(tstate, frame, throw_flag);
    }
    
    // [2] 获取代码对象
    PyCodeObject* code = frame->f_code;
    
    // [3] 检查编译缓存
    CacheEntry* cached = lookup_cache(code);
    if (cached != NULL) {
        // [4] 检查 Guard
        if (check_guards(frame, cached->guards)) {
            // Guard 通过,执行缓存的代码
            return execute_cached_code(tstate, frame, cached);
        }
        // Guard 失败,需要重新编译
    }
    
    // [5] 缓存未命中,调用 Python 层编译
    if (py_callback != NULL) {
        // 构建参数
        PyObject* args = Py_BuildValue(
            "(OiO)",
            frame,              // PyFrameObject
            get_cache_size(),   // cache_size
            Py_None             // frame_state
        );
        
        // [6] 调用 Python 回调
        PyObject* result = PyObject_CallObject(py_callback, args);
        Py_DECREF(args);
        
        if (result != NULL && result != Py_None) {
            // [7] 编译成功,执行优化代码
            PyObject* optimized_result = execute_optimized_code(
                tstate,
                frame,
                result
            );
            Py_DECREF(result);
            return optimized_result;
        }
        
        Py_XDECREF(result);
    }
    
    // [8] 编译失败或未启用,回退到原始执行
    return original_eval_frame(tstate, frame, throw_flag);
}

// 安装自定义 eval_frame 函数
PyObject* set_eval_frame_py(PyObject* self, PyObject* callback) {
    /*
     * Python 调用:_C._dynamo.eval_frame.set_eval_frame(callback)
     */
    
    if (callback == Py_None) {
        // [1] 移除钩子
        py_callback = NULL;
        
        // [2] 恢复原始 eval_frame
        _PyInterpreterState_SetEvalFrameFunc(
            PyThreadState_GET()->interp,
            original_eval_frame
        );
    } else {
        // [3] 保存回调
        Py_INCREF(callback);
        py_callback = callback;
        
        // [4] 如果是第一次安装,保存原始函数
        if (original_eval_frame == NULL) {
            original_eval_frame = _PyInterpreterState_GetEvalFrameFunc(
                PyThreadState_GET()->interp
            );
        }
        
        // [5] 安装自定义 eval_frame(关键!)
        _PyInterpreterState_SetEvalFrameFunc(
            PyThreadState_GET()->interp,
            custom_eval_frame  // 我们的自定义函数
        );
    }
    
    Py_RETURN_NONE;
}

第 6 步:compile_frame() - Python 层编译逻辑

文件位置torch/_dynamo/convert_frame.py

python 复制代码
# torch/_dynamo/convert_frame.py

def convert_frame_assert(
    compiler_fn,
    frame,
    cache_size,
    frame_state=None,
):
    """
    编译单个函数帧
    
    这是 C 层回调调用的 Python 函数
    
    Args:
        compiler_fn: 后端编译器函数
        frame: Python 帧对象
        cache_size: 缓存大小
        frame_state: 帧状态
    
    Returns:
        GuardedCode: 包含优化代码和 Guard 的对象
    """
    # [1] 创建编译上下文
    code = frame.f_code
    
    # [2] 检查是否可以编译
    if not is_traceable(code):
        return None
    
    # [3] 获取字节码
    instructions = list(dis.get_instructions(code))
    
    # [4] 创建符号化转换器
    from torch._dynamo.symbolic_convert import InstructionTranslator
    
    tracer = InstructionTranslator(
        instructions=instructions,
        f_code=code,
        f_globals=frame.f_globals,
        f_locals=frame.f_locals,
        f_builtins=frame.f_builtins,
        compiler_fn=compiler_fn,
        # ... 其他参数
    )
    
    # [5] 执行符号化转换(字节码 -> FX Graph)
    try:
        tracer.run()
    except Exception as e:
        # 编译失败
        log_graph_break(e)
        return None
    
    # [6] 获取输出图
    output_graph = tracer.output
    
    # [7] 后端编译
    compiled_fn = compiler_fn(
        output_graph.graph,
        output_graph.example_inputs
    )
    
    # [8] 生成 Guard
    guards = output_graph.guards
    
    # [9] 创建 GuardedCode 对象
    guarded_code = GuardedCode(
        code=compiled_fn,
        guards=guards,
        # ... 其他信息
    )
    
    # [10] 缓存
    cache_code(code, guarded_code)
    
    return guarded_code

完整执行流程示例

python 复制代码
import torch

# 定义一个简单的函数
def my_function(x, y):
    z = x + y
    return z * 2

# 应用 @torch.compile 装饰器
compiled_fn = torch.compile(my_function)

# 第一次调用
x = torch.randn(10)
y = torch.randn(10)
result = compiled_fn(x, y)

# 执行流程:
"""
1. compiled_fn(x, y) 被调用

2. wrapper() 函数开始执行 [eval_frame.py]
   -> set_eval_frame(callback) 安装钩子
   -> 调用 my_function(x, y)

3. CPython 准备执行 my_function
   -> 查找 eval_frame 函数
   -> 发现是 custom_eval_frame() [C 层]

4. custom_eval_frame() 被触发 [eval_frame.c]
   -> 检查缓存:未命中
   -> 调用 py_callback (Python 层)

5. convert_frame_assert() 开始编译 [convert_frame.py]
   -> 获取字节码:
      LOAD_FAST x
      LOAD_FAST y
      BINARY_ADD
      STORE_FAST z
      LOAD_FAST z
      LOAD_CONST 2
      BINARY_MULTIPLY
      RETURN_VALUE
   
   -> 创建 InstructionTranslator
   -> 逐条处理字节码指令

6. InstructionTranslator.run() [symbolic_convert.py]
   -> LOAD_FAST x: 创建 TensorVariable(x)
   -> LOAD_FAST y: 创建 TensorVariable(y)
   -> BINARY_ADD: 创建 FX 节点 add(x, y)
   -> STORE_FAST z: 保存符号化变量
   -> ... 继续处理
   
   -> 构建完整的 FX Graph

7. 生成 Guard
   -> x.shape == (10,)
   -> x.dtype == torch.float32
   -> y.shape == (10,)
   -> y.dtype == torch.float32

8. 后端编译 (Inductor)
   -> 优化 FX Graph
   -> 生成 Triton/C++ 代码
   -> 编译为可执行的 kernel

9. 缓存并返回
   -> 保存到代码对象的缓存中
   -> 返回 GuardedCode

10. custom_eval_frame() 执行优化代码
    -> 执行编译后的 kernel
    -> 返回结果

11. 第二次调用(相同形状)
    -> custom_eval_frame() 检查缓存:命中
    -> 检查 Guard:通过
    -> 直接执行缓存的代码(极快!)
"""

关键数据结构

GuardedCode
python 复制代码
# torch/_dynamo/types.py

@dataclasses.dataclass
class GuardedCode:
    """
    编译结果:包含优化代码和 Guard
    """
    # 优化后的可执行代码
    code: Union[types.CodeType, Callable]
    
    # Guard 列表
    guards: List[Guard]
    
    # Guard 检查函数(编译后的)
    check_fn: Callable[[PyFrameObject], bool]
    
    # 原始代码对象
    code_obj: types.CodeType
    
    # 性能统计
    hit_count: int = 0
    miss_count: int = 0
CacheEntry (C 层)
c 复制代码
// torch/csrc/dynamo/eval_frame.c

typedef struct {
    PyObject* code;           // 优化后的代码对象
    PyObject* guards;         // Guard 列表
    PyObject* check_fn;       // Guard 检查函数
    long hit_count;           // 缓存命中次数
    long miss_count;          // Guard 失败次数
} CacheEntry;

总结

@torch.compile 的完整调用链:

  1. 装饰器层 (torch.compile()):用户入口,参数处理
  2. 优化层 (torch._dynamo.optimize()):创建装饰器包装
  3. 钩子安装层 (set_eval_frame()):注册 C 层回调
  4. C 拦截层 (custom_eval_frame()):拦截每次函数调用
  5. 编译层 (convert_frame_assert()):字节码 -> FX Graph
  6. 后端层 (Inductor/其他):FX Graph -> 优化代码

关键技术点

  • PEP 523 的 _PyInterpreterState_SetEvalFrameFunc 允许替换帧执行函数
  • 字节码在 C 层被拦截,但编译逻辑在 Python 层
  • Guard 机制确保优化代码的正确性
  • 缓存避免重复编译

第一层:Frame Evaluation Hook - 入口机制

PEP 523:Python 的编译器钩子

Python 3.6 引入了 PEP 523,允许 C 扩展替换 Python 的帧执行函数。

文件位置torch/_dynamo/eval_frame.c

核心实现

c 复制代码
// torch/_dynamo/eval_frame.c

// 原始的 Python 帧执行函数
static PyObject* (*previous_eval_frame)(PyFrameObject*, int) = NULL;

// TorchDynamo 的自定义帧执行函数
static PyObject* custom_eval_frame(PyFrameObject* frame, int throw_flag) {
    // [1] 检查是否应该编译这个函数
    if (!should_compile_frame(frame)) {
        // 不编译,使用原始执行
        return previous_eval_frame(frame, throw_flag);
    }
    
    // [2] 检查缓存
    PyObject* cached_code = lookup_cache(frame->f_code);
    if (cached_code != NULL) {
        // [3] 检查 Guard
        if (check_guards(frame, cached_code)) {
            // Guard 通过,执行缓存的优化代码
            return execute_cached_code(frame, cached_code);
        }
    }
    
    // [4] 缓存未命中或 Guard 失败,调用 Python 层编译
    PyObject* result = py_compile_function(frame);
    
    if (result != NULL) {
        // 编译成功,执行优化代码
        return result;
    }
    
    // [5] 编译失败,回退到原始执行
    return previous_eval_frame(frame, throw_flag);
}

// 安装钩子
PyAPI_FUNC(void) _PyInterpreterState_SetEvalFrameFunc(
    PyInterpreterState* interp,
    _PyFrameEvalFunction eval_frame
);

Python 层的接口

文件位置torch/_dynamo/eval_frame.py

python 复制代码
# torch/_dynamo/eval_frame.py

import sys
from typing import Callable, Optional

# 全局编译缓存
_compile_cache = {}

def set_eval_frame(callback: Optional[Callable] = None):
    """
    设置帧执行钩子
    
    Args:
        callback: 编译函数,接收 (frame, cache_size) 返回优化后的代码
    """
    if callback is None:
        # 移除钩子,恢复默认执行
        _C._dynamo.eval_frame.set_eval_frame(None)
    else:
        # 安装钩子
        def wrapper(frame, cache_size):
            try:
                # 调用编译器
                return callback(frame, cache_size)
            except Exception as e:
                # 编译失败,回退
                return None
        
        _C._dynamo.eval_frame.set_eval_frame(wrapper)

def compile_frame(frame, cache_size):
    """
    编译单个函数帧
    
    Args:
        frame: Python 帧对象
        cache_size: 当前缓存大小
    
    Returns:
        优化后的代码对象或 None
    """
    code = frame.f_code
    
    # [1] 检查缓存
    cache_key = make_cache_key(frame)
    if cache_key in _compile_cache:
        cached_entry = _compile_cache[cache_key]
        
        # 检查 Guard
        if check_guards(frame, cached_entry.guards):
            return cached_entry.optimized_code
    
    # [2] 开始编译
    try:
        # 字节码分析与转换
        from .symbolic_convert import InstructionTranslator
        
        translator = InstructionTranslator(
            instructions=dis.get_instructions(code),
            f_globals=frame.f_globals,
            f_locals=frame.f_locals,
            f_builtins=frame.f_builtins,
            code_options={},
            symbolic_locals={},
            symbolic_globals={},
        )
        
        # 执行符号化转换
        translator.run()
        
        # [3] 获取生成的 FX 图和 Guard
        output_graph = translator.output
        guards = translator.output.guards
        
        # [4] 编译后端(例如 Inductor)
        from .backends import compile_fx
        optimized_fn = compile_fx(output_graph.graph, output_graph.example_inputs)
        
        # [5] 缓存结果
        _compile_cache[cache_key] = CacheEntry(
            optimized_code=optimized_fn,
            guards=guards
        )
        
        return optimized_fn
        
    except Exception as e:
        # 编译失败
        log_compilation_error(e)
        return None

关键概念:缓存与 Guard

python 复制代码
class CacheEntry:
    """编译缓存条目"""
    def __init__(self, optimized_code, guards):
        self.optimized_code = optimized_code  # 优化后的代码
        self.guards = guards                  # Guard 列表

class Guard:
    """执行条件守卫"""
    def __init__(self, check_fn, verbose_code):
        self.check_fn = check_fn          # 检查函数
        self.verbose_code = verbose_code  # 可读的条件描述
    
    def check(self, frame):
        """检查 Guard 是否满足"""
        return self.check_fn(frame)

# Guard 示例
def make_tensor_guard(name, expected_shape, expected_dtype):
    """生成 Tensor 的 Guard"""
    def check(frame):
        tensor = frame.f_locals[name]
        return (
            tensor.shape == expected_shape and
            tensor.dtype == expected_dtype
        )
    
    verbose = f"{name}.shape == {expected_shape} and {name}.dtype == {expected_dtype}"
    return Guard(check, verbose)

字节码到 FX Graph 的转换过程

这是 TorchDynamo 最核心的部分!让我们详细看看字节码是如何一步步转换为 FX Graph 的。

核心思想:符号化执行

正常执行:用真实的值运行代码

python 复制代码
x = 3
y = 5
z = x + y  # z = 8 (真实计算)

符号化执行:用"符号"代替真实值,记录操作

python 复制代码
x = Symbol("x")  # 不是真的 3,是一个符号
y = Symbol("y")  # 不是真的 5,是一个符号  
z = x + y        # 不真的计算,而是记录"z = add(x, y)"

完整转换流程示例

让我们跟踪一个实际的例子:

python 复制代码
@torch.compile
def my_function(x, y):
    z = x + y
    return z * 2
步骤 1:获取字节码
python 复制代码
import dis
dis.dis(my_function)

输出:

复制代码
  2           0 LOAD_FAST                0 (x)
              2 LOAD_FAST                1 (y)
              4 BINARY_ADD
              6 STORE_FAST               2 (z)
              
  3           8 LOAD_FAST                2 (z)
             10 LOAD_CONST               1 (2)
             12 BINARY_MULTIPLY
             14 RETURN_VALUE
步骤 2:创建 InstructionTranslator

文件位置torch/_dynamo/symbolic_convert.py

python 复制代码
# 简化的初始化过程
translator = InstructionTranslator(
    instructions=instructions,  # 字节码指令列表
    f_code=code,
    f_globals=frame.f_globals,
    f_locals=frame.f_locals,
)

# 初始状态
translator.stack = []           # 操作数栈(空)
translator.symbolic_locals = {} # 符号化局部变量(空)
translator.output = OutputGraph()  # FX Graph 构建器
步骤 3:符号化执行 - 逐条处理字节码
指令 1: LOAD_FAST 0 (x)
python 复制代码
def INSTR_LOAD_FAST(self, inst):
    """从局部变量加载到栈"""
    var_name = inst.argval  # "x"
    
    # 第一次访问 x,需要创建符号化变量
    if var_name not in self.symbolic_locals:
        # 从真实的 frame.f_locals 获取运行时的值
        real_value = self.f_locals[var_name]  # Tensor([...])
        
        # 创建符号化 Tensor 变量
        from torch._dynamo.variables.tensor import TensorVariable
        
        # 同时创建 FX Graph 中的 placeholder 节点
        proxy = self.output.create_proxy(
            "placeholder",  # 节点类型
            var_name,       # 节点名称
            (),             # args
            {}              # kwargs
        )
        
        # 创建符号化变量(保存 proxy 和示例值)
        symbolic_var = TensorVariable(
            proxy=proxy,
            example_value=real_value,  # 用于形状推导
        )
        
        self.symbolic_locals[var_name] = symbolic_var
    
    # 压入栈
    var = self.symbolic_locals[var_name]
    self.stack.append(var)

# 状态更新:
# stack: [TensorVariable(x)]
# FX Graph: placeholder x
指令 2: LOAD_FAST 1 (y)
python 复制代码
# 同理,创建 y 的符号化变量
var_name = "y"
real_value = self.f_locals["y"]

proxy_y = self.output.create_proxy("placeholder", "y", (), {})
symbolic_y = TensorVariable(proxy=proxy_y, example_value=real_value)

self.symbolic_locals["y"] = symbolic_y
self.stack.append(symbolic_y)

# 状态更新:
# stack: [TensorVariable(x), TensorVariable(y)]
# FX Graph: 
#   placeholder x
#   placeholder y
指令 3: BINARY_ADD
python 复制代码
def INSTR_BINARY_ADD(self, inst):
    """二元加法"""
    # 从栈弹出两个操作数
    right = self.stack.pop()  # TensorVariable(y)
    left = self.stack.pop()   # TensorVariable(x)
    
    # 调用左操作数的 add 方法(符号化执行)
    result = left.call_method(
        self,      # translator
        "__add__", # 方法名
        [right],   # 参数
        {}         # 关键字参数
    )
    
    # 压入结果
    self.stack.append(result)

# TensorVariable.call_method 内部实现:
class TensorVariable:
    def call_method(self, tx, name, args, kwargs):
        if name == "__add__":
            other = args[0]
            
            # 在 FX Graph 中创建 add 节点
            result_proxy = self.proxy + other.proxy  
            # 等价于:
            # result_proxy = tx.output.create_proxy(
            #     "call_function",
            #     torch.add,
            #     (self.proxy, other.proxy),
            #     {}
            # )
            
            # 推导结果的形状
            result_shape = torch.broadcast_shapes(
                self.example_value.shape,
                other.example_value.shape
            )
            result_example = torch.empty(
                result_shape,
                dtype=self.example_value.dtype
            )
            
            # 返回新的符号化变量
            return TensorVariable(
                proxy=result_proxy,
                example_value=result_example
            )

# 状态更新:
# stack: [TensorVariable(z=x+y)]
# FX Graph:
#   placeholder x
#   placeholder y
#   call_function torch.add(x, y) -> z_temp
指令 4: STORE_FAST 2 (z)
python 复制代码
def INSTR_STORE_FAST(self, inst):
    """存储到局部变量"""
    var_name = inst.argval  # "z"
    
    # 从栈弹出值
    value = self.stack.pop()  # TensorVariable(x+y)
    
    # 保存到符号化局部变量
    self.symbolic_locals[var_name] = value

# 状态更新:
# stack: []
# symbolic_locals: {'x': TensorVariable(x), 'y': TensorVariable(y), 'z': TensorVariable(x+y)}
# FX Graph: (不变)
指令 5: LOAD_FAST 2 (z)
python 复制代码
# 加载 z
var = self.symbolic_locals["z"]
self.stack.append(var)

# 状态更新:
# stack: [TensorVariable(z)]
指令 6: LOAD_CONST 1 (2)
python 复制代码
def INSTR_LOAD_CONST(self, inst):
    """加载常量"""
    const_value = inst.argval  # 2
    
    from torch._dynamo.variables.constant import ConstantVariable
    symbolic_const = ConstantVariable(const_value)
    
    self.stack.append(symbolic_const)

# 状态更新:
# stack: [TensorVariable(z), ConstantVariable(2)]
指令 7: BINARY_MULTIPLY
python 复制代码
# 类似 BINARY_ADD
right = self.stack.pop()  # ConstantVariable(2)
left = self.stack.pop()   # TensorVariable(z)

result = left.call_method(self, "__mul__", [right], {})

# TensorVariable.__mul__ 内部:
result_proxy = self.proxy * 2  # 在 FX Graph 中创建 mul 节点
result = TensorVariable(proxy=result_proxy, ...)

self.stack.append(result)

# 状态更新:
# stack: [TensorVariable(z*2)]
# FX Graph:
#   placeholder x
#   placeholder y
#   call_function torch.add(x, y) -> z
#   call_function torch.mul(z, 2) -> result
指令 8: RETURN_VALUE
python 复制代码
def INSTR_RETURN_VALUE(self, inst):
    """返回值"""
    ret_val = self.stack.pop()  # TensorVariable(z*2)
    
    # 添加到输出图
    self.output.add_output(ret_val)

# FX Graph 完成:
#   placeholder x
#   placeholder y
#   call_function torch.add(x, y) -> z
#   call_function torch.mul(z, 2) -> result
#   output result

完整的转换过程可视化

复制代码
[字节码]                [符号化执行]              [FX Graph]

LOAD_FAST x       ->   创建 TensorVariable    ->  placeholder x
                       proxy_x = FX.placeholder("x")
                       stack: [x]

LOAD_FAST y       ->   创建 TensorVariable    ->  placeholder y
                       proxy_y = FX.placeholder("y")
                       stack: [x, y]

BINARY_ADD        ->   弹出 y, x               ->  call_function add
                       调用 x + y
                       proxy_z = proxy_x + proxy_y
                       stack: [z]                   - target: torch.add
                                                    - args: (x, y)

STORE_FAST z      ->   locals['z'] = z
                       stack: []

LOAD_FAST z       ->   stack: [z]

LOAD_CONST 2      ->   创建 ConstantVariable
                       stack: [z, 2]

BINARY_MULTIPLY   ->   弹出 2, z               ->  call_function mul
                       调用 z * 2
                       proxy_r = proxy_z * 2       - target: torch.mul
                       stack: [result]              - args: (z, 2)

RETURN_VALUE      ->   弹出 result             ->  output result
                       output_graph.add_output(result)

最终生成的 FX Graph

python 复制代码
# 生成的 FX Graph 代码
def forward(x, y):
    # placeholder nodes
    x_1 = x
    y_1 = y
    
    # call_function nodes
    add_result = torch.add(x_1, y_1)
    mul_result = torch.mul(add_result, 2)
    
    # output
    return mul_result

关键组件详解

1. OutputGraph - FX Graph 构建器
python 复制代码
# torch/_dynamo/output_graph.py

class OutputGraph:
    """构建 FX Graph 的核心类"""
    
    def __init__(self):
        self.graph = torch.fx.Graph()  # FX Graph
        self.graphargs = []            # 输入参数
        self.guards = []               # Guard 列表
    
    def create_proxy(self, kind, target, args, kwargs):
        """
        在 FX Graph 中创建节点
        
        Args:
            kind: 节点类型 (placeholder, call_function, call_method, etc.)
            target: 目标 (函数名、方法名等)
            args: 参数
            kwargs: 关键字参数
        
        Returns:
            Proxy 对象(包装了 FX Node)
        """
        # 创建 FX Node
        node = self.graph.create_node(
            kind,
            target,
            args,
            kwargs
        )
        
        # 返回 Proxy(用于后续操作)
        return torch.fx.Proxy(node, self.graph)
    
    def add_output(self, var):
        """添加输出节点"""
        self.graph.output(var.proxy.node)
2. 符号化变量系统
python 复制代码
# torch/_dynamo/variables/base.py

class VariableTracker:
    """所有符号化变量的基类"""
    
    def call_method(self, tx, name, args, kwargs):
        """调用方法"""
        raise NotImplementedError


# torch/_dynamo/variables/tensor.py

class TensorVariable(VariableTracker):
    """Tensor 的符号化表示"""
    
    def __init__(self, proxy, example_value):
        self.proxy = proxy              # FX Proxy 对象
        self.example_value = example_value  # 示例值(用于形状推导)
    
    def call_method(self, tx, name, args, kwargs):
        """
        调用 Tensor 方法
        
        这里不会真的计算,而是在 FX Graph 中记录操作
        """
        if name == "__add__":
            other = args[0]
            
            # 在 FX Graph 中创建加法节点
            result_proxy = self.proxy + other.proxy
            
            # 推导结果形状
            result_example = self.example_value + other.example_value
            
            return TensorVariable(result_proxy, result_example)
        
        elif name == "__mul__":
            other = args[0]
            
            if isinstance(other, ConstantVariable):
                # Tensor * 常量
                result_proxy = self.proxy * other.value
            else:
                # Tensor * Tensor
                result_proxy = self.proxy * other.proxy
            
            result_example = self.example_value * other.as_python_constant()
            
            return TensorVariable(result_proxy, result_example)
        
        # ... 其他方法


# torch/_dynamo/variables/constant.py

class ConstantVariable(VariableTracker):
    """常量的符号化表示"""
    
    def __init__(self, value):
        self.value = value
    
    def as_python_constant(self):
        return self.value

完整实际代码示例

让我创建一个演示程序:

python 复制代码
# 演示:模拟 TorchDynamo 的符号化执行

import torch
import torch.fx as fx
import dis

class SimpleOutputGraph:
    """简化的 OutputGraph"""
    def __init__(self):
        self.graph = fx.Graph()
    
    def create_placeholder(self, name):
        return self.graph.placeholder(name)
    
    def create_call_function(self, target, args):
        return self.graph.call_function(target, args)
    
    def set_output(self, node):
        self.graph.output(node)

class SimpleTensorVariable:
    """简化的 TensorVariable"""
    def __init__(self, node):
        self.node = node
    
    def __add__(self, other):
        # 创建加法节点
        result_node = graph.create_call_function(
            torch.add,
            (self.node, other.node)
        )
        return SimpleTensorVariable(result_node)
    
    def __mul__(self, other):
        # 创建乘法节点
        if isinstance(other, int):
            # Tensor * 常量
            result_node = graph.create_call_function(
                torch.mul,
                (self.node, other)
            )
        else:
            result_node = graph.create_call_function(
                torch.mul,
                (self.node, other.node)
            )
        return SimpleTensorVariable(result_node)

# 模拟执行
def symbolic_trace_manual(fn):
    """手动符号化追踪"""
    global graph
    graph = SimpleOutputGraph()
    
    print("[字节码]")
    dis.dis(fn)
    
    print("\n[符号化执行过程]")
    
    # 模拟执行
    print("1. 创建 placeholder x")
    node_x = graph.create_placeholder("x")
    x = SimpleTensorVariable(node_x)
    
    print("2. 创建 placeholder y")
    node_y = graph.create_placeholder("y")
    y = SimpleTensorVariable(node_y)
    
    print("3. 执行 z = x + y (符号化)")
    z = x + y  # 触发 __add__,创建 add 节点
    
    print("4. 执行 result = z * 2 (符号化)")
    result = z * 2  # 触发 __mul__,创建 mul 节点
    
    print("5. 设置输出")
    graph.set_output(result.node)
    
    print("\n[生成的 FX Graph]")
    print(graph.graph)
    
    return graph.graph

# 测试
def my_function(x, y):
    z = x + y
    return z * 2

symbolic_trace_manual(my_function)

真实的 TorchDynamo 内部流程

python 复制代码
# torch/_dynamo/symbolic_convert.py (简化版)

class InstructionTranslator:
    """字节码 -> FX Graph 的核心翻译器"""
    
    def run(self):
        """主执行循环"""
        while self.instruction_pointer < len(self.instructions):
            inst = self.instructions[self.instruction_pointer]
            
            # 获取指令处理函数
            handler = getattr(self, f"INSTR_{inst.opname}", None)
            
            if handler is None:
                # 不支持的指令 -> Graph Break
                self.graph_break(f"unsupported: {inst.opname}")
                return
            
            # 执行指令处理
            try:
                handler(inst)
            except Exception as e:
                self.graph_break(str(e))
                return
            
            self.instruction_pointer += 1
        
        # 所有指令处理完毕
        self.output.finalize()
    
    def INSTR_LOAD_FAST(self, inst):
        """处理 LOAD_FAST 指令"""
        var_name = inst.argval
        
        if var_name not in self.symbolic_locals:
            # 第一次访问,创建 placeholder
            real_value = self.f_locals[var_name]
            
            proxy = self.output.create_proxy(
                "placeholder",
                var_name,
                (),
                {}
            )
            
            # 根据真实值的类型创建对应的符号化变量
            if isinstance(real_value, torch.Tensor):
                var = TensorVariable(proxy, real_value)
            elif isinstance(real_value, int):
                var = ConstantVariable(real_value)
            # ... 其他类型
            
            self.symbolic_locals[var_name] = var
        
        self.stack.append(self.symbolic_locals[var_name])
    
    def INSTR_BINARY_ADD(self, inst):
        """处理 BINARY_ADD 指令"""
        right = self.stack.pop()
        left = self.stack.pop()
        
        # 符号化执行加法
        result = left.call_method(self, "__add__", [right], {})
        
        self.stack.append(result)
    
    # ... 其他指令处理函数

核心要点总结

  1. 符号化执行 = 记录操作而不是真的计算

    • 用符号变量代替真实值
    • 操作被记录到 FX Graph 中
  2. 每条字节码指令 -> FX Graph 节点

    • LOAD_FAST x -> placeholder x
    • BINARY_ADD -> call_function torch.add
    • RETURN_VALUE -> output
  3. 符号化变量携带两种信息

    • proxy: FX Graph 中的节点引用
    • example_value: 运行时的示例值(用于形状推导)
  4. 关键组件

    • InstructionTranslator: 字节码翻译器
    • OutputGraph: FX Graph 构建器
    • TensorVariable: Tensor 的符号化表示
    • Proxy: FX Node 的包装
  5. 转换是保守的

    • 遇到不支持的操作 -> Graph Break
    • 保证正确性优先于性能

完整流程总结图

复制代码
[Python 源码]
def my_function(x, y):
    z = x + y
    return z * 2

        |
        | [Python 编译器]
        v

[字节码]
LOAD_FAST 0 (x)        <-- 指令 1
LOAD_FAST 1 (y)        <-- 指令 2
BINARY_ADD             <-- 指令 3
STORE_FAST 2 (z)       <-- 指令 4
LOAD_FAST 2 (z)        <-- 指令 5
LOAD_CONST 1 (2)       <-- 指令 6
BINARY_MULTIPLY        <-- 指令 7
RETURN_VALUE           <-- 指令 8

        |
        | [TorchDynamo 符号化执行]
        v

[符号化执行过程]

指令 1: LOAD_FAST x
  -> 创建 TensorVariable(x)
  -> proxy_x = graph.placeholder("x")
  -> stack: [x]
  
指令 2: LOAD_FAST y
  -> 创建 TensorVariable(y)
  -> proxy_y = graph.placeholder("y")
  -> stack: [x, y]

指令 3: BINARY_ADD
  -> 弹出 y, x
  -> 调用 x.__add__(y)
  -> proxy_z = graph.call_function(torch.add, (x, y))
  -> stack: [z]

指令 4: STORE_FAST z
  -> locals["z"] = stack.pop()
  -> stack: []

指令 5: LOAD_FAST z
  -> stack: [z]

指令 6: LOAD_CONST 2
  -> stack: [z, 2]

指令 7: BINARY_MULTIPLY
  -> 弹出 2, z
  -> 调用 z.__mul__(2)
  -> proxy_r = graph.call_function(torch.mul, (z, 2))
  -> stack: [result]

指令 8: RETURN_VALUE
  -> graph.output(stack.pop())

        |
        v

[FX Graph]
graph():
    %x : placeholder[target=x]
    %y : placeholder[target=y]
    %add : call_function[target=torch.add](args=(%x, %y))
    %mul : call_function[target=torch.mul](args=(%add, 2))
    return mul

        |
        | [代码生成]
        v

[生成的 Python 代码]
def forward(x, y):
    add = torch.add(x, y)
    mul = torch.mul(add, 2)
    return mul

        |
        | [后端编译 (Inductor)]
        v

[优化的代码]
- 算子融合: add + mul -> fused_add_mul
- CUDA Kernel 生成
- 内存优化

关键理解:三层抽象

复制代码
[第一层] Python 源码
    def my_function(x, y):
        z = x + y
        return z * 2
    
    特点: 人类可读,灵活
    问题: 难以优化,执行慢

        ↓ [Python 编译器]

[第二层] 字节码
    LOAD_FAST x
    LOAD_FAST y
    BINARY_ADD
    ...
    
    特点: CPython 内部表示
    问题: 仍然是解释执行,慢

        ↓ [TorchDynamo 符号化执行]

[第三层] FX Graph
    graph():
        %x = placeholder[target=x]
        %y = placeholder[target=y]
        %add = call_function[torch.add](%x, %y)
        %mul = call_function[torch.mul](%add, 2)
        return mul
    
    特点: 显式的计算图,可分析可变换
    优势: 可以做各种优化!

        ↓ [后端编译器]

[第四层] 机器码
    融合的 CUDA Kernel
    
    特点: 直接在 GPU 上执行
    优势: 极快!

为什么需要 FX Graph?

没有 FX Graph(Eager 执行):

python 复制代码
z = x + y     # 立即执行 add,结果写入显存
result = z * 2  # 立即执行 mul,需要从显存读 z

问题:
- 两个独立的 Kernel 调用
- 中间结果 z 需要写回显存(慢!)
- 无法优化

有了 FX Graph:

python 复制代码
# FX Graph:
graph():
    %add = call_function[torch.add](%x, %y)
    %mul = call_function[torch.mul](%add, 2)

后端编译器看到整个计算流程:
"哦,add 的结果立即被 mul 使用,我可以融合!"

优化后的 Kernel:
result = fused_add_mul(x, y, 2)  # 一个 Kernel,z 不用写回显存

优势:
- 一个 Kernel 调用(减少 Kernel 启动开销)
- 中间结果 z 保持在寄存器(极快!)
- 节省显存带宽

演示程序

我创建了 Pytorch/examples/demo_bytecode_to_graph.py,运行它可以看到:

  1. 字节码指令详情
  2. 逐步的符号化执行过程
  3. 生成的 FX Graph
  4. 对比真实的 TorchDynamo 行为
bash 复制代码
cd Pytorch/examples
python demo_bytecode_to_graph.py

最终总结:一句话理解

字节码到 FX Graph 的转换 = 把 Python 的底层指令"翻译"成一个显式的计算图,这样编译器就能看到整个计算流程并进行优化。

核心技巧

  • 不真的执行计算(符号化执行)
  • 用符号变量代替真实值
  • 每个操作都记录到图中
  • 最终得到一个可优化的中间表示(IR)

第二层:字节码分析 - Bytecode Analysis

Python 字节码基础

python 复制代码
import dis

def example(x, y):
    z = x + y
    if z > 0:
        return z * 2
    else:
        return z + 1

dis.dis(example)

输出

复制代码
  2           0 LOAD_FAST                0 (x)
              2 LOAD_FAST                1 (y)
              4 BINARY_ADD
              6 STORE_FAST               2 (z)

  3           8 LOAD_FAST                2 (z)
             10 LOAD_CONST               1 (0)
             12 COMPARE_OP               4 (>)
             14 POP_JUMP_IF_FALSE       24

  4          16 LOAD_FAST                2 (z)
             18 LOAD_CONST               2 (2)
             20 BINARY_MULTIPLY
             22 RETURN_VALUE

  6     >>   24 LOAD_FAST                2 (z)
             26 LOAD_CONST               3 (1)
             28 BINARY_ADD
             30 RETURN_VALUE

TorchDynamo 的字节码翻译器

文件位置torch/_dynamo/symbolic_convert.py

python 复制代码
# torch/_dynamo/symbolic_convert.py

class InstructionTranslator:
    """
    字节码指令翻译器
    
    职责:
    1. 逐条解析字节码指令
    2. 维护符号化的变量栈
    3. 构建 FX Graph
    4. 生成 Guard
    """
    
    def __init__(
        self,
        instructions,
        f_globals,
        f_locals,
        f_builtins,
        code_options,
        symbolic_locals,
        symbolic_globals,
    ):
        # 字节码指令列表
        self.instructions = list(instructions)
        self.instruction_pointer = 0
        
        # 命名空间
        self.f_globals = f_globals
        self.f_locals = f_locals
        self.f_builtins = f_builtins
        
        # 符号化变量栈(模拟 Python 执行栈)
        self.stack = []
        
        # 局部变量(符号化)
        self.symbolic_locals = {}
        
        # 输出图
        self.output = OutputGraph(
            f_globals=f_globals,
            f_locals=f_locals,
            f_builtins=f_builtins,
        )
        
        # Guard 列表
        self.guards = []
    
    def run(self):
        """执行符号化转换"""
        while self.instruction_pointer < len(self.instructions):
            inst = self.instructions[self.instruction_pointer]
            
            # 分发到具体的指令处理函数
            handler_name = f"INSTR_{inst.opname}"
            handler = getattr(self, handler_name, None)
            
            if handler is None:
                # 不支持的指令 -> Graph Break
                self.graph_break(f"unsupported instruction: {inst.opname}")
                return
            
            # 执行指令
            handler(inst)
            
            # 移动指令指针
            self.instruction_pointer += 1
    
    # ========== 指令处理函数 ==========
    
    def INSTR_LOAD_FAST(self, inst):
        """
        LOAD_FAST: 加载局部变量到栈
        """
        var_name = inst.argval
        
        if var_name not in self.symbolic_locals:
            # 变量未定义 -> Graph Break
            self.graph_break(f"undefined variable: {var_name}")
            return
        
        # 获取符号化变量
        symbolic_var = self.symbolic_locals[var_name]
        
        # 压入栈
        self.stack.append(symbolic_var)
    
    def INSTR_LOAD_CONST(self, inst):
        """
        LOAD_CONST: 加载常量到栈
        """
        const_value = inst.argval
        
        # 创建常量变量
        from .variables import ConstantVariable
        symbolic_var = ConstantVariable(const_value)
        
        # 压入栈
        self.stack.append(symbolic_var)
    
    def INSTR_BINARY_ADD(self, inst):
        """
        BINARY_ADD: 二元加法
        """
        # 弹出两个操作数
        right = self.stack.pop()
        left = self.stack.pop()
        
        # 执行符号化加法
        result = left.call_method(self, "add", [right], {})
        
        # 结果压入栈
        self.stack.append(result)
    
    def INSTR_COMPARE_OP(self, inst):
        """
        COMPARE_OP: 比较操作
        """
        comp_op = inst.argval  # 例如 '>'
        
        # 弹出两个操作数
        right = self.stack.pop()
        left = self.stack.pop()
        
        # 符号化比较
        result = left.call_method(self, f"__{comp_op}__", [right], {})
        
        # 压入栈
        self.stack.append(result)
    
    def INSTR_POP_JUMP_IF_FALSE(self, inst):
        """
        POP_JUMP_IF_FALSE: 条件跳转(if 语句)
        """
        # 弹出条件
        condition = self.stack.pop()
        
        # [关键] 这里需要决定走哪个分支
        # 如果条件依赖于 Tensor 的值,需要生成 Guard
        
        if condition.is_python_constant():
            # 常量条件,直接判断
            if not condition.as_python_constant():
                # 跳转
                self.instruction_pointer = inst.argval - 1
        else:
            # 动态条件,需要 Guard
            # 尝试获取当前运行时的值
            runtime_value = condition.get_runtime_value()
            
            # 生成 Guard
            guard = condition.make_guard(runtime_value)
            self.guards.append(guard)
            
            # 根据运行时值决定分支
            if not runtime_value:
                self.instruction_pointer = inst.argval - 1
    
    def INSTR_RETURN_VALUE(self, inst):
        """
        RETURN_VALUE: 返回值
        """
        # 弹出返回值
        ret_val = self.stack.pop()
        
        # 添加到输出图
        self.output.add_output(ret_val)
    
    def INSTR_CALL_FUNCTION(self, inst):
        """
        CALL_FUNCTION: 函数调用
        """
        n_args = inst.argval
        
        # 弹出参数
        args = [self.stack.pop() for _ in range(n_args)]
        args.reverse()
        
        # 弹出函数
        fn = self.stack.pop()
        
        # 符号化调用
        result = fn.call_function(self, args, {})
        
        # 压入结果
        self.stack.append(result)
    
    # ========== Graph Break 处理 ==========
    
    def graph_break(self, reason):
        """
        触发 Graph Break
        
        Args:
            reason: 断开原因
        """
        log_graph_break(reason)
        
        # 结束当前图
        self.output.finalize()
        
        # 生成恢复执行的代码
        # (这部分代码会在后续章节详细解释)
        self.output.create_resume_at(self.instruction_pointer)

第三层:符号化变量系统 - Symbolic Variables

TorchDynamo 为每种 Python 类型定义了对应的"符号化变量"。

文件位置torch/_dynamo/variables/

变量基类

python 复制代码
# torch/_dynamo/variables/base.py

class VariableTracker:
    """
    符号化变量的基类
    
    每个变量代表 Python 执行过程中的一个值
    """
    
    def __init__(self, source=None):
        self.source = source  # 变量来源(用于生成 Guard)
    
    def call_method(self, tx, name, args, kwargs):
        """
        调用方法
        
        Args:
            tx: InstructionTranslator 实例
            name: 方法名
            args: 参数列表
            kwargs: 关键字参数
        
        Returns:
            新的 VariableTracker
        """
        raise NotImplementedError
    
    def call_function(self, tx, args, kwargs):
        """
        作为函数调用
        """
        raise NotImplementedError
    
    def is_python_constant(self):
        """是否是 Python 常量"""
        return False
    
    def as_python_constant(self):
        """转换为 Python 常量"""
        raise NotImplementedError
    
    def make_guard(self, expected_value):
        """
        生成 Guard
        
        Args:
            expected_value: 期望的运行时值
        
        Returns:
            Guard 对象
        """
        raise NotImplementedError

Tensor 变量

python 复制代码
# torch/_dynamo/variables/tensor.py

class TensorVariable(VariableTracker):
    """
    表示 torch.Tensor 的符号化变量
    """
    
    def __init__(self, proxy, example_value, source=None):
        super().__init__(source)
        self.proxy = proxy          # FX Proxy 对象
        self.example_value = example_value  # 示例值(用于形状推导)
    
    def call_method(self, tx, name, args, kwargs):
        """
        调用 Tensor 方法
        
        例如:tensor.add(other)
        """
        if name == "add":
            # 处理 add 方法
            other = args[0]
            
            # 在 FX 图中创建节点
            result_proxy = self.proxy.add(other.proxy)
            
            # 推导结果形状
            result_shape = torch.broadcast_shapes(
                self.example_value.shape,
                other.example_value.shape
            )
            result_example = torch.empty(result_shape, dtype=self.example_value.dtype)
            
            # 返回新的 TensorVariable
            return TensorVariable(result_proxy, result_example, source=None)
        
        # 其他方法...
        return super().call_method(tx, name, args, kwargs)
    
    def make_guard(self, expected_value):
        """
        生成 Tensor 的 Guard
        
        检查:形状、dtype、设备
        """
        guards = []
        
        # 形状 Guard
        shape_guard = Guard(
            lambda frame: frame.f_locals[self.source.name()].shape == expected_value.shape,
            f"{self.source.name()}.shape == {expected_value.shape}"
        )
        guards.append(shape_guard)
        
        # dtype Guard
        dtype_guard = Guard(
            lambda frame: frame.f_locals[self.source.name()].dtype == expected_value.dtype,
            f"{self.source.name()}.dtype == {expected_value.dtype}"
        )
        guards.append(dtype_guard)
        
        return guards

常量变量

python 复制代码
# torch/_dynamo/variables/constant.py

class ConstantVariable(VariableTracker):
    """
    表示 Python 常量(int, float, str, None 等)
    """
    
    def __init__(self, value):
        super().__init__(source=None)
        self.value = value
    
    def is_python_constant(self):
        return True
    
    def as_python_constant(self):
        return self.value
    
    def call_method(self, tx, name, args, kwargs):
        """
        常量方法调用
        
        例如:(1).__add__(2) -> 3
        """
        # 直接执行 Python 操作
        result = getattr(self.value, name)(*[a.as_python_constant() for a in args])
        return ConstantVariable(result)

第四层:Guard 机制 - 动态适配的关键

文件位置torch/_dynamo/guards.py

Guard 的生成

python 复制代码
# torch/_dynamo/guards.py

class GuardBuilder:
    """
    Guard 构建器
    
    为每个符号化变量生成运行时检查条件
    """
    
    def __init__(self):
        self.guards = []
    
    def TYPE_MATCH(self, obj_source, expected_type):
        """
        类型匹配 Guard
        
        例如:type(x) == torch.Tensor
        """
        def check(frame):
            obj = eval_source(frame, obj_source)
            return type(obj) == expected_type
        
        guard = Guard(
            check_fn=check,
            verbose_code=f"type({obj_source}) == {expected_type.__name__}"
        )
        self.guards.append(guard)
    
    def TENSOR_MATCH(self, tensor_source, expected_tensor):
        """
        Tensor 匹配 Guard
        
        检查:shape, dtype, device, requires_grad
        """
        # 形状
        self.SHAPE_MATCH(tensor_source, expected_tensor.shape)
        
        # dtype
        self.DTYPE_MATCH(tensor_source, expected_tensor.dtype)
        
        # device
        self.DEVICE_MATCH(tensor_source, expected_tensor.device)
        
        # requires_grad
        self.REQUIRES_GRAD_MATCH(tensor_source, expected_tensor.requires_grad)
    
    def SHAPE_MATCH(self, tensor_source, expected_shape):
        """形状匹配"""
        def check(frame):
            tensor = eval_source(frame, tensor_source)
            return tensor.shape == expected_shape
        
        guard = Guard(
            check_fn=check,
            verbose_code=f"{tensor_source}.shape == {expected_shape}"
        )
        self.guards.append(guard)
    
    def DYNAMIC_DIM(self, tensor_source, dim_idx):
        """
        动态维度
        
        不固定某个维度的大小(用于 batch size)
        """
        # 只检查其他维度
        def check(frame):
            tensor = eval_source(frame, tensor_source)
            # 检查除了 dim_idx 之外的维度
            return True  # 简化示例
        
        guard = Guard(
            check_fn=check,
            verbose_code=f"{tensor_source}.shape[{dim_idx}] is dynamic"
        )
        self.guards.append(guard)

# Guard 检查
def check_guards(frame, guards):
    """
    检查所有 Guard 是否通过
    
    Args:
        frame: Python 帧对象
        guards: Guard 列表
    
    Returns:
        bool: 所有 Guard 都通过返回 True
    """
    for guard in guards:
        if not guard.check(frame):
            log_guard_failure(guard)
            return False
    
    return True

Guard 优化

python 复制代码
# Guard 数量会影响性能,需要优化

class GuardOptimizer:
    """Guard 优化器"""
    
    def optimize(self, guards):
        """
        优化 Guard 列表
        
        策略:
        1. 去重
        2. 合并相似的 Guard
        3. 提前失败(快速检查)
        """
        # 1. 去重
        unique_guards = self.deduplicate(guards)
        
        # 2. 排序(让快速检查的 Guard 在前面)
        sorted_guards = self.sort_by_cost(unique_guards)
        
        return sorted_guards
    
    def deduplicate(self, guards):
        """去除重复的 Guard"""
        seen = set()
        result = []
        
        for guard in guards:
            key = guard.verbose_code
            if key not in seen:
                seen.add(key)
                result.append(guard)
        
        return result
    
    def sort_by_cost(self, guards):
        """按检查成本排序(低成本优先)"""
        def cost(guard):
            # 类型检查最快
            if "type(" in guard.verbose_code:
                return 1
            # 属性检查次快
            elif ".shape" in guard.verbose_code or ".dtype" in guard.verbose_code:
                return 2
            # 其他
            else:
                return 3
        
        return sorted(guards, key=cost)

第五层:Graph Break 处理 - 容错机制

Graph Break 的原因

python 复制代码
# 常见的 Graph Break 场景

def example1(x):
    x = x + 1
    print(x)  # [Graph Break] 副作用(print)
    return x * 2

def example2(x):
    x = x + 1
    if x.item() > 0:  # [Graph Break] 需要立即读取 Tensor 值
        return x * 2
    return x + 1

def example3(x):
    x = x + 1
    import numpy as np
    y = np.array([1, 2, 3])  # [Graph Break] 不支持的库
    return x + len(y)

Graph Break 的处理

文件位置torch/_dynamo/resume_execution.py

python 复制代码
# torch/_dynamo/resume_execution.py

class GraphBreakHandler:
    """
    Graph Break 处理器
    
    当遇到不可追踪的代码时:
    1. 结束当前图
    2. 生成恢复执行的代码
    3. 继续追踪后续代码
    """
    
    def handle_graph_break(self, translator, reason, instruction_pointer):
        """
        处理 Graph Break
        
        Args:
            translator: InstructionTranslator 实例
            reason: 断开原因
            instruction_pointer: 当前指令位置
        """
        # [1] 记录 Graph Break
        log_graph_break(reason, instruction_pointer)
        
        # [2] 结束当前图
        current_graph = translator.output.graph
        
        # [3] 生成代码:将当前状态传递给 Python 解释器
        resume_code = self.generate_resume_code(
            translator,
            instruction_pointer
        )
        
        # [4] 创建新的翻译器,继续追踪后续代码
        next_translator = InstructionTranslator(
            instructions=translator.instructions[instruction_pointer:],
            f_globals=translator.f_globals,
            f_locals=translator.f_locals,
            f_builtins=translator.f_builtins,
            code_options=translator.code_options,
            symbolic_locals=translator.symbolic_locals,
            symbolic_globals=translator.symbolic_globals,
        )
        
        return resume_code, next_translator
    
    def generate_resume_code(self, translator, instruction_pointer):
        """
        生成恢复执行的代码
        
        作用:将优化的图执行结果传递给后续的 Python 代码
        """
        # 示例生成的代码:
        code = f"""
def resume_function(*args, **kwargs):
    # 执行优化的图
    graph_output = compiled_graph(*args, **kwargs)
    
    # 恢复局部变量
    locals_dict = {{}}
    locals_dict['x'] = graph_output
    
    # 继续执行原始 Python 代码(从 Graph Break 位置开始)
    # ...
    
    return result
"""
        return code

Graph Break 的可视化

python 复制代码
import torch._dynamo as dynamo

def visualize_graph_breaks(fn):
    """可视化函数中的 Graph Break"""
    explanation = dynamo.explain(fn)
    
    print(f"函数: {fn.__name__}")
    print(f"Graph 数量: {explanation.graph_count}")
    print(f"Graph Break 数量: {explanation.graph_break_count}")
    print("\nGraph Break 详情:")
    
    for i, break_reason in enumerate(explanation.break_reasons):
        print(f"  [{i+1}] {break_reason}")

# 使用示例
def my_function(x):
    x = x + 1
    print(x)  # Break 1
    x = x * 2
    if x.item() > 0:  # Break 2
        return x + 1
    return x - 1

visualize_graph_breaks(my_function)

# 输出:
# 函数: my_function
# Graph 数量: 3
# Graph Break 数量: 2
#
# Graph Break 详情:
#   [1] call to print (side effect)
#   [2] call to tensor.item() (data-dependent control flow)

完整示例:从字节码到优化

让我们通过一个完整的例子,看看 TorchDynamo 如何工作:

python 复制代码
import torch
import torch._dynamo as dynamo

def example(x, y):
    z = x + y
    if z.sum() > 0:
        return z * 2
    else:
        return z + 1

# 编译
compiled_fn = torch.compile(example)

# 第一次运行
x1 = torch.randn(10)
y1 = torch.randn(10)
output1 = compiled_fn(x1, y1)

# 内部流程:
"""
1. Frame Evaluation Hook 拦截
   -> custom_eval_frame() 被调用

2. 检查缓存
   -> 缓存为空,开始编译

3. 字节码分析
   指令序列:
   LOAD_FAST x
   LOAD_FAST y
   BINARY_ADD
   STORE_FAST z
   LOAD_FAST z
   LOAD_METHOD sum
   CALL_METHOD 0
   LOAD_CONST 0
   COMPARE_OP >
   POP_JUMP_IF_FALSE [else分支]
   ...

4. 符号化执行
   -> x: TensorVariable(shape=(10,), dtype=float32)
   -> y: TensorVariable(shape=(10,), dtype=float32)
   -> z = x + y: TensorVariable(shape=(10,), dtype=float32)
   -> z.sum(): TensorVariable(shape=(), dtype=float32)
   -> z.sum() > 0: 需要运行时值
      运行时值 = True (假设)
      
5. 生成 Guard
   Guard 1: x.shape == (10,)
   Guard 2: x.dtype == torch.float32
   Guard 3: y.shape == (10,)
   Guard 4: y.dtype == torch.float32
   Guard 5: (x + y).sum() > 0  # 控制流 Guard

6. 构建 FX Graph(if 分支)
   def forward(x, y):
       z = x + y
       return z * 2

7. 后端编译(Inductor)
   -> 生成融合的 CUDA kernel

8. 缓存并执行
"""

# 第二次运行(相同条件)
x2 = torch.randn(10)
y2 = torch.randn(10)  # 假设 sum() > 0
output2 = compiled_fn(x2, y2)

# 内部流程:
"""
1. 检查缓存 -> 命中
2. 检查 Guard -> 全部通过
3. 直接执行缓存的优化代码 [快!]
"""

# 第三次运行(条件改变)
x3 = torch.randn(10) * -10  # 让 sum() < 0
y3 = torch.randn(10) * -10
output3 = compiled_fn(x3, y3)

# 内部流程:
"""
1. 检查缓存 -> 命中
2. 检查 Guard -> Guard 5 失败!(sum() > 0 不满足)
3. 重新编译 else 分支
4. 生成新的 Guard 和优化代码
5. 缓存两个版本(if 和 else)
"""

总结

TorchDynamo 的核心原理

  1. Frame Evaluation Hook:PEP 523 钩子劫持函数执行
  2. 字节码分析:反汇编并逐条解析字节码
  3. 符号化执行:不实际执行,而是构建符号化的计算图
  4. Guard 机制:记录执行条件,动态适配不同情况
  5. FX Graph 构建:生成 FX 图,传递给后端编译器
  6. Graph Break 处理:遇到不可追踪代码时优雅降级

与 FX 的关系

复制代码
TorchDynamo = 前端(捕获图)
    |
    | 输出 FX Graph
    |
    v
FX = 中间表示(图优化)
    |
    | 输出优化后的 Graph
    |
    v
后端编译器(Inductor/AOTAutograd)= 代码生成

关键优势

  1. 完整的 Python 支持:控制流、异常处理、动态特性
  2. 零代码修改 :只需 @torch.compile 装饰器
  3. 渐进式优化:Graph Break 后继续追踪
  4. 动态适配:Guard 机制支持动态形状和控制流

性能特征

复制代码
首次编译:慢(需要分析、编译)
    |
    v
缓存命中 + Guard 通过:极快(直接执行优化代码)
    |
    v
Guard 失败:中等(重新编译,但有部分缓存)

推荐阅读顺序

如果你想深入阅读 TorchDynamo 源码,建议按以下顺序:

  1. torch/_dynamo/eval_frame.py - 理解 Frame Hook 机制 (500 行)
  2. torch/_dynamo/symbolic_convert.py - 理解字节码翻译 (2000+ 行,核心)
  3. torch/_dynamo/variables/ - 理解符号化变量系统 (多个文件)
  4. torch/_dynamo/guards.py - 理解 Guard 生成与检查 (800 行)
  5. torch/_dynamo/output_graph.py - 理解 FX Graph 构建 (600 行)

总计:~5000+ 行核心代码(比 FX 复杂 2 倍)


本文档基于 PyTorch 2.0+ 源码编写,部分实现细节可能因版本而异

相关推荐
Keep_Trying_Go8 小时前
统一的人群计数训练框架(PyTorch)——基于主流的密度图模型训练框架
人工智能·pytorch·python·深度学习·算法·机器学习·人群计数
西猫雷婶10 小时前
卷积运算效果的池化处理|最大值
人工智能·pytorch·python·深度学习·神经网络·机器学习·cnn
像风一样自由202010 小时前
基于PyTorch实现U-Net的路面裂缝检测系统
人工智能·pytorch·python
九死九歌10 小时前
【Sympydantic】使用sympydantic,利用pydantic告别numpy与pytorch编程中,tensor形状带来的烦人痛点!
开发语言·pytorch·python·机器学习·numpy·pydantic
阿正的梦工坊11 小时前
PyTorch 权重剪枝中的阈值计算:深入解读 numel() 和 torch.kthvalue()
人工智能·pytorch·剪枝
爱ZW的小白猿11 小时前
【pytorch】dataset类的使用
pytorch
白茶三许12 小时前
【江鸟中原】集光鸿蒙项目开发
pytorch·深度学习·harmonyos
瀚岳-诸葛弩12 小时前
对比tensorflow,从0开始学pytorch(五)--CBAM
人工智能·pytorch·python
一瞬祈望1 天前
PyTorch 图像分类完整项目模板实战
人工智能·pytorch·python·深度学习·分类