本文 聚焦
torch/_decomp/decompositions.py------一个 5746 行、包含 171 个算子分解 的核心文件。它是 PyTorch 编译器(torch.compile)和导出系统(torch.export)的"降维利器",将上百个复杂 ATen 算子拆解为更少、更基础的算子集合,大幅降低后端编译器需要支持的算子数量。理解这个文件,就掌握了 PyTorch 算子语义的精确定义。
一、算子分解在编译栈中的位置
用户代码: y = F.gelu(x)
│
▼ TorchDynamo 捕获
FX Graph: call_function(aten.gelu, x)
│
▼ ★ 算子分解 (decompositions.py) ★
FX Graph: call_function(aten.mul, x, 0.5 * (1 + aten.erf(x * 0.707...)))
│
▼ Inductor Lowering
Inductor IR: Pointwise(lambda x: x * 0.5 * (1 + erf(x * 0.707...)))
│
▼ Triton/C++ Codegen
最终 kernel 代码
核心价值 :ATen 库有 2000+ 个算子,如果每个后端(Triton/C++/MPS/XPU)都要逐一实现,工程量巨大。算子分解将复杂算子表达为基础算子的组合,使后端只需实现约 200 个核心算子(Core ATen Ops) 即可支撑整个 PyTorch 功能集。
二、_decomp 子系统全景
torch/_decomp/
├── __init__.py ← 注册机制 + decomposition table + API
├── decompositions.py ← ★ 本文分析对象(171 个 post-autograd 分解)
├── decompositions_for_jvp.py ← 前向自动微分(JVP)专用分解
└── decompositions_for_rng.py ← RNG(随机数生成器)函数化分解
两张分解表
python
# torch/_decomp/__init__.py
global_decomposition_table = {
"post_autograd": {}, # ← decompositions.py 注册到这里
"pre_autograd": {}, # ← 自动微分之前的分解(少量)
"meta": {}, # ← Meta tensor 推导(shape/dtype 计算)
}
decomposition_table = global_decomposition_table["post_autograd"]
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
- post_autograd(主要):在自动微分展开之后应用,分解前向和反向算子
- pre_autograd:在自动微分展开之前应用,通常用于将高层 API 分解为可微分的底层算子
分解 vs 引用实现 vs 原语
复杂 ATen 算子 (aten.gelu, aten.batch_norm, aten.lstm, ...)
│
▼ decompositions.py(算子分解)
核心 ATen 算子 (aten.mul, aten.add, aten.exp, aten.var_mean, ...)
│
▼ torch._refs(引用实现,可选)
原语 _prims (prims.mul, prims.add, prims.exp, ...)
| 层级 | 模块 | 角色 | 示例 |
|---|---|---|---|
| 分解 | _decomp/ |
将复杂算子拆为核心算子 | gelu → mul + erf |
| 引用 | _refs/ |
核心算子的 Python 参考实现 | add → prims.add |
| 原语 | _prims/ |
最小的不可再分操作 | prims.add |
三、文件整体结构鸟瞰
decompositions.py 共 5746 行,可按功能分为以下区域:
decompositions.py (5746 行, 171 个 @register_decomposition)
│
├── [1-120] import + 辅助函数
│ type_casts / pw_cast_for_opmath 等类型提升装饰器
│
├── [125-360] 激活函数及其反向 (15 个)
│ tanh_backward, sigmoid_backward, gelu_backward,
│ silu, hardswish, mish, prelu, rrelu, leaky_relu...
│
├── [360-700] 损失函数 (12 个)
│ mse_loss, smooth_l1_loss, huber_loss,
│ binary_cross_entropy, nll_loss, soft_margin_loss...
│
├── [700-900] 基础数学/张量操作 (10 个)
│ dist, _euclidean_dist, slice, slice_backward,
│ slice_scatter, select_scatter...
│
├── [900-1350] 形状 & 内存操作 (15 个)
│ split, unfold, diag_embed, diagonal_backward,
│ _chunk_cat, split_with_sizes_copy...
│
├── [1350-1600] 归一化层 (8 个)
│ native_layer_norm_backward, native_group_norm,
│ native_group_norm_backward, native_batch_norm,
│ cudnn_batch_norm, batch_norm...
│
├── [1600-2100] BatchNorm 完整实现 (~500 行)
│ native_batch_norm_helper (训练/推理两条路径)
│ batch_norm_backward 反向传播
│
├── [2100-2600] Embedding & 索引操作 (8 个)
│ embedding, embedding_dense_backward,
│ index_add, index_copy, _unsafe_masked_index...
│
├── [2600-2950] 自适应池化 & 上采样基础 (5 个)
│ _adaptive_avg_pool2d, im2col, col2im...
│
├── [2950-3500] 上采样 nearest (20+ 个)
│ upsample_nearest1d/2d/3d 及其 exact 变体
│ 完整的 N 维最近邻上采样实现
│
├── [3500-3930] RNN 系列 (8 个)
│ rnn_tanh, rnn_relu, lstm, gru
│ 每种有 input 和 data (PackedSequence) 两个变体
│
├── [3930-4600] 上采样 bilinear/bicubic/trilinear (15 个)
│ upsample_bilinear2d, upsample_bicubic2d,
│ upsample_trilinear3d, _upsample_bicubic2d_aa...
│
├── [4600-5100] 混合算子 (15 个)
│ affine_grid_generator, grid_sampler_2d,
│ mv, reflection_pad, replication_pad, aminmax,
│ nll_loss_forward...
│
└── [5100-5746] 杂项 (20+ 个)
│ arange, floor_divide, squeeze, baddbmm,
│ scaled_dot_product_attention_for_cpu,
│ weight_norm, isin, bernoulli, take,
│ max_pool2d_with_indices_backward, hann_window...
四、注册机制详解
4.1 @register_decomposition 装饰器
python
@register_decomposition(aten.gelu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
...
三层装饰器各司其职:
| 装饰器 | 功能 |
|---|---|
@register_decomposition(aten_op) |
将函数注册到 decomposition_table[aten_op] |
@out_wrapper("grad_input") |
自动处理 out= 参数(inplace 输出语义) |
@pw_cast_for_opmath |
自动类型提升(如 FP16 → FP32 计算 → FP16 输出) |
4.2 注册到全局分解表
python
# __init__.py 中的核心逻辑
def register_decomposition(aten_op, registry=None, *, type="post_autograd"):
def decomposition_decorator(fn):
if registry is None:
registry = global_decomposition_table[type]
# 将 fn 注册到 registry[aten_op]
_add_op_to_registry(registry, aten_op, fn)
return fn
return decomposition_decorator
支持三种 aten_op 形式:
- 单个 OpOverload :
aten.gelu_backward(最常见) - OpOverloadPacket:自动注册所有 overload
- 列表 :
[aten.arange.default, aten.arange.out](一个实现覆盖多个变体)
4.3 消费者如何使用分解表
python
# Inductor 获取分解
from torch._decomp import get_decompositions
decomps = get_decompositions([aten.gelu, aten.batch_norm, ...])
# torch.export 获取核心分解
from torch._decomp import core_aten_decompositions
decomps = core_aten_decompositions() # 约 200+ 个算子
五、类型提升机制
5.1 type_casts 装饰器
这是一个精妙的装饰器工厂,解决低精度计算的数值稳定性问题:
python
def type_casts(f, type_promotion, compute_dtype_only=False):
@functools.wraps(f)
def inner(*args, **kwargs):
# 1. 从所有 Tensor 参数推导计算精度和输出精度
computation_dtype, result_dtype = utils.elementwise_dtypes(
*flat_args, type_promotion_kind=type_promotion
)
# 2. 将所有输入提升到计算精度
r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
# 3. 将输出降低到结果精度
return tree_map(decrease_prec, r)
return inner
典型场景:用户输入 FP16 Tensor → 提升到 FP32 做 GELU 计算 → 输出降回 FP16。
5.2 四种预定义策略
python
# 默认提升:FP16/BF16 → FP32 计算 → 降回原精度
pw_cast_for_opmath = partial(type_casts, type_promotion=DEFAULT)
# 仅提升计算精度,不降回(用于中间计算)
compute_only_pw_cast_for_opmath = partial(type_casts, compute_dtype_only=True)
# 整数 → 浮点提升(如 INT → FLOAT 的除法)
pw_cast_for_int_to_real = partial(type_casts, type_promotion=INT_TO_FLOAT)
# 包含非 Tensor 参数的提升
pw_cast_for_opmath_non_tensor_args = partial(type_casts, include_non_tensor_args=True)
六、典型分解案例深度剖析
6.1 激活函数------以 GELU 反向为例
python
@register_decomposition(aten.gelu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
if approximate == "tanh":
# tanh 近似版本的精确梯度公式
kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
kKappa = 0.044715
x_sq = self * self
x_cube = x_sq * self
inner = kBeta * (self + kKappa * x_cube)
tanh_inner = torch.tanh(inner)
left = 0.5 * self
right = 1 + tanh_inner
left_derivative = 0.5 * right
tanh_derivative = 1 - tanh_inner * tanh_inner
inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
right_derivative = left * tanh_derivative * inner_derivative
return grad * (left_derivative + right_derivative)
else:
# 精确版本:GELU(x) = x * Φ(x)
# GELU'(x) = Φ(x) + x * φ(x)
kAlpha = M_SQRT1_2
kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
cdf = 0.5 * (1 + torch.erf(self * kAlpha)) # Φ(x)
pdf = kBeta * torch.exp(self * self * -0.5) # φ(x)
return grad * (cdf + self * pdf)
分解产物 :aten.gelu_backward → {mul, erf, exp, tanh, add, sub} 这些基础算子的组合。Inductor 可以将它们融合为一个 Triton kernel。
6.2 损失函数------以 NLL Loss 为例
NLL Loss 的分解展示了更复杂的逻辑------需要处理 ignore_index、weight、reduction 三个维度的组合:
python
@register_decomposition(aten.nll_loss_backward)
@out_wrapper("grad_input")
def nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight):
# 核心分解逻辑
def _nll_loss_backward(target, weight, ignore_index):
# 1. 创建全零梯度
grad_input = self.new_zeros(self.shape)
# 2. 处理 ignore_index
safe_target = torch.where(target != ignore_index, target, 0)
# 3. 计算逐样本权重
if weight is not None:
grad_input_per_sample = -weight.index_select(0, safe_target.view(-1))
else:
grad_input_per_sample = self.new_full(...)
# 4. 根据 reduction 模式调整
if reduction == Reduction.MEAN.value:
grad_input_per_sample /= total_weight
# 5. scatter 回梯度矩阵
grad_input.scatter_(1, safe_target.unsqueeze(1), grad_input_per_sample.unsqueeze(1))
return grad_input
6.3 归一化层------BatchNorm 完整实现
BatchNorm 的分解是文件中最长的连续代码段(~500 行),因为它需要处理训练/推理两条完全不同的路径:
python
def native_batch_norm_helper(input, weight, bias, running_mean, running_var,
training, momentum, eps, functional):
if training:
# 训练模式:计算当前 batch 的统计量
biased_var, mean = torch.var_mean(input, dim=reduction_dims, correction=0)
rstd = torch.rsqrt(biased_var + eps)
output = (input - mean) * rstd
# 更新 running statistics(EMA)
if running_mean is not None:
new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
if not functional:
running_mean.copy_(new_running_mean) # inplace 更新
# ... running_var 类似
else:
# 推理模式:使用 running statistics
output = (input - running_mean) * (1 / sqrt(running_var + eps))
# 应用 affine 变换
if weight is not None:
output = output * weight
if bias is not None:
output = output + bias
return output, save_mean, save_rstd
关键设计决策:
functional=True时返回新的 running_mean/var(用于 functorch/export)functional=False时 inplace 更新(兼容 eager 模式语义)- 训练模式使用无偏方差估计(
n/(n-1)校正)更新 running_var
6.4 RNN 系列------递归结构的分解
RNN/LSTM/GRU 的分解是最具教学意义的:
python
# 基础 RNN cell
def rnn_cell(nonlinearity):
def inner(x, hidden, wi, wh, bi, bh):
igates = x @ wi.t()
hgates = hidden @ wh.t()
if bi is not None:
igates = igates + bi
if bh is not None:
hgates = hgates + bh
return nonlinearity(igates + hgates)
return inner
# LSTM cell
def lstm_cell(x, hid, wi, wh, bi, bh):
h, c = hid
gates = x @ wi.t() + h @ wh.t() # 4 个 gate 一次矩阵乘法
if bi is not None:
gates += bi + bh
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
c_new = (forgetgate * c) + (ingate * cellgate)
h_new = outgate * torch.tanh(c_new)
return h_new, c_new
然后通过 _rnn_helper 处理多层、双向、dropout 等组合:
python
def _rnn_helper(input, hidden, params, has_biases, num_layers,
dropout, train, bidirectional, batch_first, layer_fn):
for i in range(num_layers):
# 前向
result_fw = layer_fn(input, hidden_fw, params_fw)
if bidirectional:
# 反向(翻转输入)
result_bw = layer_fn(reversed_input, hidden_bw, params_bw)
result = torch.cat([result_fw, result_bw], -1)
if dropout > 0 and i < num_layers - 1:
result = F.dropout(result, p=dropout, training=train)
return result, final_hiddens
6.5 上采样------算法密集型分解
upsample_bicubic2d 是文件中最长的单个分解(~150 行),实现了完整的双三次插值算法:
python
@register_decomposition(aten.upsample_bicubic2d.default)
def upsample_bicubic2d(input, output_size, align_corners, scales_h=None, scales_w=None):
# 1. 计算源坐标和插值系数
def compute_source_index(output_index, input_size, output_size, align_corners):
if align_corners:
return output_index * (input_size - 1) / (output_size - 1)
else:
return (output_index + 0.5) / scale - 0.5
# 2. 对每个输出像素,取 4x4 邻域并应用 cubic 核
def cubic_interp1d(x0, x1, x2, x3, t):
# Cardinal cubic spline
A = -0.75
coeffs_0 = ((A * (t + 1) - 5 * A) * (t + 1) + 8 * A) * (t + 1) - 4 * A
coeffs_1 = ((A + 2) * t - (A + 3)) * t * t + 1
coeffs_2 = ((A + 2) * (1 - t) - (A + 3)) * (1 - t) * (1 - t) + 1
coeffs_3 = ...
return x0 * coeffs_0 + x1 * coeffs_1 + x2 * coeffs_2 + x3 * coeffs_3
# 3. 先沿 H 方向插值,再沿 W 方向插值(可分离卷积)
6.6 Scaled Dot Product Attention --- 跨层语义对齐
SDPA for CPU 的分解展示了一个独特需求------保证分解后的输出 Tensor 与原始 flash attention 的内存布局完全一致:
python
@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default)
def scaled_dot_product_flash_attention_for_cpu(query, key, value, ...):
output, attn = aten._scaled_dot_product_attention_math.default(
query, key, value, ...
)
# 关键:对齐内存布局
# Flash attention 返回 [N, H, L, E],但内部 permute 后 contiguous
# 分解必须产出相同的 strides,否则后续 .view() 会失败
output = (
output.permute(2, 0, 1, 3)
.contiguous(memory_format=torch.contiguous_format)
.permute(1, 2, 0, 3)
)
return output, attn
七、辅助函数体系
文件中有大量不带 @register_decomposition 的辅助函数,它们被多个分解共享:
7.1 通用辅助
| 函数 | 功能 |
|---|---|
_unsqueeze_to_dim(x, dim) |
反复 unsqueeze 直到 x.dim() == dim |
apply_loss_reduction(loss, reduction) |
统一处理 loss 的 NONE/MEAN/SUM 模式 |
to_real_dtype(dtype) |
复数类型 → 对应实数类型 |
_maybe_cast(x, dtype) |
安全的可选 Tensor 类型转换 |
_normalize_start_end(x, dim, start, end) |
标准化 slice 的 start/end 参数 |
7.2 RNN 专用辅助
| 函数 | 功能 |
|---|---|
rnn_cell(nonlinearity) |
工厂函数,创建基础 RNN cell |
lstm_cell(x, hid, ...) |
LSTM cell 的完整实现 |
gru_cell(x, hid, ...) |
GRU cell 的完整实现 |
one_layer_rnn(...) |
单层 RNN 循环 |
_rnn_helper(...) |
多层+双向+dropout 的通用 RNN 框架 |
gather_params(params, has_biases, ...) |
从扁平参数列表重组为层级结构 |
mkldnn_rnn_layer(...) |
MKL-DNN 后端适配 |
7.3 上采样专用辅助
| 函数 | 功能 |
|---|---|
get_scale_value(...) |
从 scale_factors 或 output_size 计算缩放比 |
_compute_scale(...) |
计算 align_corners 模式的精确缩放 |
compute_source_index(...) |
输出坐标 → 源坐标映射 |
_upsample_get_steps(...) |
生成上采样索引序列 |
7.4 注册 inplace 变体
python
def register_inplace(aten_op, outplace_op):
"""将 out-of-place 分解自动转为 inplace 变体"""
@register_decomposition(aten_op)
def inplace_op(*args, **kwargs):
out = outplace_op(*args, **kwargs)
return args[0].copy_(out)
return inplace_op
# 使用示例:
register_inplace(aten.leaky_relu_, aten.leaky_relu)
register_inplace(aten.silu_, aten.silu)
八、分解的分类统计
对 171 个分解按功能域分类:
| 类别 | 数量 | 代表算子 |
|---|---|---|
| 激活函数 + 反向 | 20 | gelu, silu, hardswish, mish, prelu, rrelu, elu, leaky_relu |
| 损失函数 + 反向 | 15 | mse_loss, nll_loss, bce, smooth_l1, huber, soft_margin |
| 归一化层 + 反向 | 12 | batch_norm, layer_norm, group_norm, cudnn_batch_norm |
| 上采样 | 25+ | nearest 1d/2d/3d, bilinear, bicubic, trilinear, lanczos |
| RNN 系列 | 8 | rnn_tanh, rnn_relu, lstm, gru (input + data 变体) |
| 索引 & 聚集 | 10 | embedding, index_add, index_copy, scatter, take |
| 形状操作 | 15 | split, slice, unfold, diag_embed, chunk_cat, pad_sequence |
| 池化 | 5 | adaptive_avg_pool2d, max_pool2d_backward, im2col, col2im |
| 数学运算 | 10 | dist, euclidean_dist, addmm, baddbmm, mv, floor_divide |
| Padding | 8 | reflection_pad 1d/2d/3d, replication_pad 1d/2d/3d + backward |
| 空间变换 | 5 | affine_grid_generator, grid_sampler_2d, upsample_bicubic2d |
| 杂项 | 38+ | arange, bernoulli, dropout, weight_norm, isin, nansum, sdpa... |
九、Core ATen Ops --- 分解的目标集
core_aten_decompositions() 定义了 torch.export 使用的标准分解集,其目标是将 2000+ ATen 算子降低到约 200 个核心算子:
python
# 核心算子示例(分解的目标集,不会被进一步分解)
不分解的核心算子 = {
aten.add, aten.mul, aten.sub, aten.div, # 算术
aten.exp, aten.log, aten.sin, aten.cos, # 基础数学
aten.mm, aten.bmm, aten.addmm, # 矩阵乘法
aten.conv2d, aten.max_pool2d, # 卷积/池化
aten.cat, aten.stack, aten.reshape, # 形状操作
aten.index, aten.scatter, aten.gather, # 索引操作
aten.sum, aten.mean, aten.var_mean, # 归约操作
aten.where, aten.clamp, # 条件操作
...
}
分解示意:
aten.gelu → aten.mul + aten.erf + aten.add
aten.batch_norm → aten.var_mean + aten.rsqrt + aten.mul + aten.add
aten.lstm → aten.mm + aten.sigmoid + aten.tanh + aten.mul + aten.add
aten.nll_loss → aten.gather + aten.where + aten.scatter + aten.sum
aten.upsample_* → aten.index + aten.mul + aten.add (插值系数)
十、Inductor 如何使用分解
python
# torch/_inductor/decomposition.py(简化)
def select_decomp_table():
decomps = get_decompositions([
aten.addcdiv, aten.affine_grid_generator, aten.batch_norm,
aten.binary_cross_entropy, aten.embedding_dense_backward,
aten.gelu, aten.gelu_backward, aten.lstm, aten.gru,
aten.native_batch_norm, aten.native_layer_norm_backward,
aten.nll_loss_backward, aten.reflection_pad2d,
aten.upsample_bicubic2d, aten.upsample_nearest2d,
# ... 200+ 个算子
])
return decomps
这些分解在 AOT Autograd 的 aot_stage1_graph_capture 阶段被应用------FX 图中的算子被替换为分解后的等价表达式。
十一、设计哲学与权衡
11.1 为什么不把所有算子都分解?
- 性能:有些算子的专用 kernel(如 cuDNN 的 BatchNorm)比分解后的通用融合更快
- 数值精度:分解可能改变数值结果(如 BatchNorm 的方差累加方式)
- 编译时间:分解产生更多节点,增加 Inductor 的融合搜索空间
11.2 分解的正确性保证
每个分解都必须满足:
- 语义等价:对任意输入产出与原算子 bit-exact 相同的结果(或在文档中声明差异)
- 类型正确:输出 dtype、device、shape 必须与原算子一致
- 梯度正确:反向分解产出的梯度必须与 autograd 引擎一致
11.3 out_wrapper 的精妙设计
许多 ATen 算子有 out= 变体(如 aten.gelu.out),@out_wrapper 装饰器自动处理这一差异:
python
@out_wrapper()
def gelu(self: Tensor) -> Tensor:
return self * 0.5 * (1 + torch.erf(self * 0.707...))
# 自动生成等价的:
# def gelu_out(self, *, out):
# result = gelu(self)
# out.resize_(result.shape)
# out.copy_(result)
# return out
十二、阅读路线建议
Level 1:理解机制
→ __init__.py 的 register_decomposition + decomposition_table
→ 选 silu (5行) 或 fill_scalar (2行) 作为入门示例
Level 2:理解类型系统
→ type_casts 装饰器 (58-120行)
→ pw_cast_for_opmath 的四种变体
Level 3:读懂核心分解
→ gelu_backward (251-280行) ← 激活函数代表
→ nll_loss_backward (557-607行) ← 损失函数代表
→ native_batch_norm_helper (1944-2040行) ← 归一化代表
→ lstm_cell (3630-3650行) ← RNN 代表
Level 4:读懂复杂分解
→ adaptive_avg_pool2d (2644-2780行) ← 索引密集型
→ upsample_bicubic2d (4811-4900行) ← 算法密集型
→ affine_grid_generator (4415-4595行) ← 空间变换
Level 5:调试技巧
→ TORCH_COMPILE_DEBUG=1 查看分解前后的 FX 图差异
→ torch._decomp.get_decompositions([aten.xxx]) 查询某算子是否有分解
→ torch.fx.experimental.proxy_tensor.make_fx(fn)() 观察分解效果
十三、数据流全景图
ATen 算子集(2000+ 个)
│
▼ register_decomposition 注册
global_decomposition_table["post_autograd"]
│ 171 个分解函数
│
├─── torch.compile 路径:
│ AOT Autograd → aot_config.decompositions = {...}
│ → make_fx 追踪时自动替换
│ → 分解后的 FX 图交给 Inductor
│
├─── torch.export 路径:
│ core_aten_decompositions()
│ → 导出标准化 FX 图(仅含核心算子)
│
└─── functorch 路径:
vmap/grad 需要所有算子可函数化
→ 分解将 inplace/mutation 算子转为纯函数
核心 ATen 算子集 (~200 个)
│
▼ Inductor Lowering
Inductor IR → Triton/C++ 代码
十四、总结
decompositions.py 的 5746 行代码浓缩了 PyTorch 算子语义的精确定义。它的价值体现在三个维度:
- 工程价值:将 2000+ ATen 算子降维到 ~200 个核心算子,使后端编译器的实现成本降低一个数量级
- 教学价值:每个分解都是对应算子的"可执行数学公式"------BatchNorm 的训练/推理路径、LSTM 的门控机制、双三次插值的 Cardinal Spline 系数------全部用 Python + 基础张量操作表达得清清楚楚
- 生态价值 :通过
core_aten_decompositions()和torch.export,第三方推理引擎只需支持 ~200 个核心算子即可运行任意 PyTorch 模型
这个文件就像一本"ATen 算子翻译词典"------左边是复杂算子,右边是基础算子的组合。掌握它,就掌握了 PyTorch 编译器和导出系统的语义基石。