Pytorch 学习笔记(8): PyTorch FX

一、FX 是什么?

FX 是 PyTorch 提供的模型转换工具包 ,核心功能是将 nn.Module 转换为可分析、可修改的中间表示(IR),再生成新的 Python 代码。

FX 三大核心组件

组件 作用
符号追踪器 (Symbolic Tracer) 通过"符号执行"捕获模型语义
中间表示 (Graph/IR) 用 Graph 结构表示计算流程
Python 代码生成 将修改后的 Graph 转回可执行代码

快速示例

python 复制代码
import torch
from torch.fx import symbolic_trace

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)
    
    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

# 1. 符号追踪
module = MyModule()
symbolic_traced = symbolic_trace(module)

# 2. 查看中间表示(Graph)
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](...)
    %linear : [num_users=1] = call_module[target=linear](...)
    %clamp : [num_users=1] = call_method[target=clamp](...)
    return clamp
"""

# 3. 查看生成的代码
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min=0.0, max=1.0);  linear = None
    return clamp
"""

二、Graph 结构详解

FX 的 Graph 由 Node 组成,每个 Node 代表一个操作。

Node 的六种操作类型(opcode)

opcode 含义 示例
placeholder 函数输入参数 %x = placeholder[target=x]
get_attr 获取模块属性/参数 %param = get_attr[target=param]
call_function 调用函数(如 torch.add call_function[target=operator.add]
call_module 调用子模块的 forward call_module[target=linear]
call_method 调用 Tensor 方法 call_method[target=clamp]
output 返回值 return clamp

打印 Graph 表格

python 复制代码
gm.graph.print_tabular()

输出:

复制代码
opcode         name    target                   args        kwargs
-------------  ------  -----------------------  ----------  --------
placeholder    x       x                        ()          {}
get_attr       linear_weight  linear.weight      ()          {}
call_function  add_1   <built-in function add>    (x, ...)    {}
call_module    linear_1  linear                 (add_1,)    {}
...

三、编写 FX 转换(Transformation)

标准转换模板

python 复制代码
import torch
import torch.fx as fx

def transform(m: torch.nn.Module, 
              tracer_class: type = fx.Tracer) -> torch.nn.Module:
    # Step 1: 获取 Graph
    graph = tracer_class().trace(m)
    
    # Step 2: 修改 Graph
    # ... 转换逻辑 ...
    
    # Step 3: 返回新的 GraphModule
    return fx.GraphModule(m, graph)

直接修改 Graph 示例:替换算子

python 复制代码
def transform(m: torch.nn.Module) -> torch.nn.Module:
    graph = fx.Tracer().trace(m)
    
    for node in graph.nodes:
        # 找到 torch.add 调用,替换为 torch.mul
        if node.op == 'call_function' and node.target == torch.add:
            node.target = torch.mul  # 直接修改目标函数
    
    graph.lint()  # 检查 Graph 合法性
    return fx.GraphModule(m, graph)

插入新节点示例

python 复制代码
# 在指定节点后插入 ReLU
with traced.graph.inserting_after(node):
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))
    
    # 将所有使用原节点的地方替换为新节点
    node.replace_all_uses_with(new_node)

四、高级转换技巧

1. 子图重写(replace_pattern)

FX 提供"查找-替换"功能,自动匹配并替换子图:

python 复制代码
from torch.fx import replace_pattern

# 定义模式(要查找的子图)
def pattern(w1, w2):
    return torch.cat([w1, w2])

# 定义替换(新的子图)
def replacement(w1, w2):
    return torch.stack([w1, w2])

# 执行替换
matches = replace_pattern(gm, pattern, replacement)

2. Proxy 重追踪

用 Proxy 机制自动记录操作,避免手动 Graph 操作:

python 复制代码
def relu_decomposition(x):
    """将 ReLU 分解为 (x > 0) * x"""
    return (x > 0) * x

decomposition_rules = {F.relu: relu_decomposition}

def decompose(model):
    graph = fx.Tracer().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # 用 Proxy 包装参数,自动记录操作
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x 
                for x in node.args
            ]
            output_proxy = decomposition_rules[node.target](*proxy_args)
            env[node.name] = output_proxy.node
        else:
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    
    return fx.GraphModule(model, new_graph)

3. Interpreter 模式

逐节点执行 Graph,适合分析和转换:

python 复制代码
class ShapeProp:
    """形状传播:记录每个节点的 shape 和 dtype"""
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())
    
    def propagate(self, *args):
        args_iter = iter(args)
        env = {}
        
        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])
        
        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            # ... 其他操作类型 ...
            
            # 记录 shape 和 dtype
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype
            
            env[node.name] = result
        
        return load_arg(self.graph.result)

五、Transformer 类

TransformerInterpreter 的子类,用于生成新 Graph

python 复制代码
class NegSigmSwapXformer(fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target is torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)
    
    def call_method(self, target, args, kwargs):
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)

# 使用
gm = fx.symbolic_trace(fn)
transformed = NegSigmSwapXformer(gm).transform()

六、调试技巧

1. 检查转换正确性

不要用 == 比较 Tensor,用 torch.allclose()

python 复制代码
# ❌ 错误
assert original(input) == transformed(input)

# ✅ 正确
assert torch.allclose(original(input), transformed(input))

2. 调试生成的代码

python 复制代码
# 方法1:打印代码
print(traced.code)

# 方法2:导出到文件夹
traced.to_folder("output_folder", "ModuleName")
from output_folder import ModuleName

# 方法3:使用 pdb
import pdb; pdb.set_trace()
traced(input)  # 单步调试

3. 可视化 Graph

python 复制代码
# 打印表格形式
traced.graph.print_tabular()

# 打印 Graph 结构
print(traced.graph)

七、符号追踪的限制

❌ 动态控制流(不支持)

python 复制代码
def func_to_trace(x):
    if x.sum() > 0:      # ❌ 错误:条件依赖输入值
        return torch.relu(x)
    else:
        return torch.neg(x)

# 报错:TraceError: symbolically traced variables cannot be used as inputs to control flow

✅ 静态控制流(支持)

python 复制代码
class MyModule(torch.nn.Module):
    def __init__(self, do_activation=False):
        super().__init__()
        self.do_activation = do_activation  # 超参数,非输入
    
    def forward(self, x):
        x = self.linear(x)
        if self.do_activation:  # ✅ 正确:条件不依赖输入
            x = torch.relu(x)
        return x

解决方案:concrete_args

python 复制代码
# 用 concrete_args 绑定具体值
fx.symbolic_trace(f, concrete_args={'flag': True})

八、常用 API 速查

API 功能
symbolic_trace(root, concrete_args=None) 符号追踪
wrap(fn_or_name) 注册叶子函数
GraphModule(root, graph) 从 Graph 创建模块
graph.call_function(fn, args, kwargs) 插入函数调用节点
graph.call_module(module_name, args, kwargs) 插入模块调用节点
node.replace_all_uses_with(new_node) 替换所有使用
graph.lint() 检查 Graph 合法性
gm.recompile() 重新编译 forward
replace_pattern(gm, pattern, replacement) 子图替换
Interpreter.run(*args) 解释执行 Graph
Transformer.transform() 转换并返回新模块

九、最佳实践

  1. 转换后调用 graph.lint() - 确保 Graph 结构合法
  2. 修改后调用 gm.recompile() - 同步生成 forward 代码
  3. torch.allclose() 验证正确性 - 浮点数比较
  4. 避免 set 迭代 - 用 dict 保持确定性顺序
  5. 标记叶子模块 - 对训练标志敏感的模块用 is_leaf_module

结语

FX 是 PyTorch 模型优化的基础设施,广泛应用于:

  • 量化(Quantization)
  • 算子融合(Operator Fusion)
  • 剪枝(Pruning)
  • 分布式训练优化

掌握 FX 可以让你深入理解 PyTorch 模型的内部结构,实现自定义的编译优化流程。


📌 参考资源

相关推荐
xuhaoyu_cpp_java2 小时前
Boyer-Moore 投票算法
java·经验分享·笔记·学习·算法
雨浓YN2 小时前
OPC UA 通讯开发笔记 - 基于Opc.Ua.Client
笔记·c#
沪漂阿龙3 小时前
深度剖析神经网络学习:从损失函数到SGD,手写数字识别完整实战
人工智能·神经网络·学习
迷路爸爸1803 小时前
Docker 入门学习笔记 06:用一个可复现的 Python 项目真正理解 Dockerfile
笔记·学习·docker
ghie90903 小时前
基于学习的模型预测控制(LBMPC)MATLAB实现指南
开发语言·学习·matlab
Engineer邓祥浩3 小时前
JVM学习笔记(6) 第二部分 自动内存管理 第5章节 调优案例分析与实战
jvm·笔记·学习
ysa0510303 小时前
斐波那契上斐波那契【矩阵快速幂】
数据结构·c++·笔记·算法
倒酒小生3 小时前
4月7日算法学习小结
linux·服务器·学习
xinzheng新政3 小时前
Javascript·深入学习基础知识2
开发语言·javascript·学习