一、FX 是什么?
FX 是 PyTorch 提供的模型转换工具包 ,核心功能是将 nn.Module 转换为可分析、可修改的中间表示(IR),再生成新的 Python 代码。
FX 三大核心组件
| 组件 | 作用 |
|---|---|
| 符号追踪器 (Symbolic Tracer) | 通过"符号执行"捕获模型语义 |
| 中间表示 (Graph/IR) | 用 Graph 结构表示计算流程 |
| Python 代码生成 | 将修改后的 Graph 转回可执行代码 |
快速示例
python
import torch
from torch.fx import symbolic_trace
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
# 1. 符号追踪
module = MyModule()
symbolic_traced = symbolic_trace(module)
# 2. 查看中间表示(Graph)
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](...)
%linear : [num_users=1] = call_module[target=linear](...)
%clamp : [num_users=1] = call_method[target=clamp](...)
return clamp
"""
# 3. 查看生成的代码
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min=0.0, max=1.0); linear = None
return clamp
"""
二、Graph 结构详解
FX 的 Graph 由 Node 组成,每个 Node 代表一个操作。
Node 的六种操作类型(opcode)
| opcode | 含义 | 示例 |
|---|---|---|
placeholder |
函数输入参数 | %x = placeholder[target=x] |
get_attr |
获取模块属性/参数 | %param = get_attr[target=param] |
call_function |
调用函数(如 torch.add) |
call_function[target=operator.add] |
call_module |
调用子模块的 forward | call_module[target=linear] |
call_method |
调用 Tensor 方法 | call_method[target=clamp] |
output |
返回值 | return clamp |
打印 Graph 表格
python
gm.graph.print_tabular()
输出:
opcode name target args kwargs
------------- ------ ----------------------- ---------- --------
placeholder x x () {}
get_attr linear_weight linear.weight () {}
call_function add_1 <built-in function add> (x, ...) {}
call_module linear_1 linear (add_1,) {}
...
三、编写 FX 转换(Transformation)
标准转换模板
python
import torch
import torch.fx as fx
def transform(m: torch.nn.Module,
tracer_class: type = fx.Tracer) -> torch.nn.Module:
# Step 1: 获取 Graph
graph = tracer_class().trace(m)
# Step 2: 修改 Graph
# ... 转换逻辑 ...
# Step 3: 返回新的 GraphModule
return fx.GraphModule(m, graph)
直接修改 Graph 示例:替换算子
python
def transform(m: torch.nn.Module) -> torch.nn.Module:
graph = fx.Tracer().trace(m)
for node in graph.nodes:
# 找到 torch.add 调用,替换为 torch.mul
if node.op == 'call_function' and node.target == torch.add:
node.target = torch.mul # 直接修改目标函数
graph.lint() # 检查 Graph 合法性
return fx.GraphModule(m, graph)
插入新节点示例
python
# 在指定节点后插入 ReLU
with traced.graph.inserting_after(node):
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# 将所有使用原节点的地方替换为新节点
node.replace_all_uses_with(new_node)
四、高级转换技巧
1. 子图重写(replace_pattern)
FX 提供"查找-替换"功能,自动匹配并替换子图:
python
from torch.fx import replace_pattern
# 定义模式(要查找的子图)
def pattern(w1, w2):
return torch.cat([w1, w2])
# 定义替换(新的子图)
def replacement(w1, w2):
return torch.stack([w1, w2])
# 执行替换
matches = replace_pattern(gm, pattern, replacement)
2. Proxy 重追踪
用 Proxy 机制自动记录操作,避免手动 Graph 操作:
python
def relu_decomposition(x):
"""将 ReLU 分解为 (x > 0) * x"""
return (x > 0) * x
decomposition_rules = {F.relu: relu_decomposition}
def decompose(model):
graph = fx.Tracer().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# 用 Proxy 包装参数,自动记录操作
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x
for x in node.args
]
output_proxy = decomposition_rules[node.target](*proxy_args)
env[node.name] = output_proxy.node
else:
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
3. Interpreter 模式
逐节点执行 Graph,适合分析和转换:
python
class ShapeProp:
"""形状传播:记录每个节点的 shape 和 dtype"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
# ... 其他操作类型 ...
# 记录 shape 和 dtype
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
五、Transformer 类
Transformer 是 Interpreter 的子类,用于生成新 Graph:
python
class NegSigmSwapXformer(fx.Transformer):
def call_function(self, target, args, kwargs):
if target is torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)
def call_method(self, target, args, kwargs):
if target == "neg":
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(target, args, kwargs)
# 使用
gm = fx.symbolic_trace(fn)
transformed = NegSigmSwapXformer(gm).transform()
六、调试技巧
1. 检查转换正确性
不要用 == 比较 Tensor,用 torch.allclose():
python
# ❌ 错误
assert original(input) == transformed(input)
# ✅ 正确
assert torch.allclose(original(input), transformed(input))
2. 调试生成的代码
python
# 方法1:打印代码
print(traced.code)
# 方法2:导出到文件夹
traced.to_folder("output_folder", "ModuleName")
from output_folder import ModuleName
# 方法3:使用 pdb
import pdb; pdb.set_trace()
traced(input) # 单步调试
3. 可视化 Graph
python
# 打印表格形式
traced.graph.print_tabular()
# 打印 Graph 结构
print(traced.graph)
七、符号追踪的限制
❌ 动态控制流(不支持)
python
def func_to_trace(x):
if x.sum() > 0: # ❌ 错误:条件依赖输入值
return torch.relu(x)
else:
return torch.neg(x)
# 报错:TraceError: symbolically traced variables cannot be used as inputs to control flow
✅ 静态控制流(支持)
python
class MyModule(torch.nn.Module):
def __init__(self, do_activation=False):
super().__init__()
self.do_activation = do_activation # 超参数,非输入
def forward(self, x):
x = self.linear(x)
if self.do_activation: # ✅ 正确:条件不依赖输入
x = torch.relu(x)
return x
解决方案:concrete_args
python
# 用 concrete_args 绑定具体值
fx.symbolic_trace(f, concrete_args={'flag': True})
八、常用 API 速查
| API | 功能 |
|---|---|
symbolic_trace(root, concrete_args=None) |
符号追踪 |
wrap(fn_or_name) |
注册叶子函数 |
GraphModule(root, graph) |
从 Graph 创建模块 |
graph.call_function(fn, args, kwargs) |
插入函数调用节点 |
graph.call_module(module_name, args, kwargs) |
插入模块调用节点 |
node.replace_all_uses_with(new_node) |
替换所有使用 |
graph.lint() |
检查 Graph 合法性 |
gm.recompile() |
重新编译 forward |
replace_pattern(gm, pattern, replacement) |
子图替换 |
Interpreter.run(*args) |
解释执行 Graph |
Transformer.transform() |
转换并返回新模块 |
九、最佳实践
- 转换后调用
graph.lint()- 确保 Graph 结构合法 - 修改后调用
gm.recompile()- 同步生成 forward 代码 - 用
torch.allclose()验证正确性 - 浮点数比较 - 避免 set 迭代 - 用 dict 保持确定性顺序
- 标记叶子模块 - 对训练标志敏感的模块用
is_leaf_module
结语
FX 是 PyTorch 模型优化的基础设施,广泛应用于:
- 量化(Quantization)
- 算子融合(Operator Fusion)
- 剪枝(Pruning)
- 分布式训练优化
掌握 FX 可以让你深入理解 PyTorch 模型的内部结构,实现自定义的编译优化流程。
📌 参考资源