Pytorch 学习笔记(17):decompositions.py —— 算子分解的百科全书

本文 聚焦 torch/_decomp/decompositions.py------一个 5746 行、包含 171 个算子分解 的核心文件。它是 PyTorch 编译器(torch.compile)和导出系统(torch.export)的"降维利器",将上百个复杂 ATen 算子拆解为更少、更基础的算子集合,大幅降低后端编译器需要支持的算子数量。理解这个文件,就掌握了 PyTorch 算子语义的精确定义。


一、算子分解在编译栈中的位置

复制代码
用户代码: y = F.gelu(x)
    │
    ▼ TorchDynamo 捕获
FX Graph: call_function(aten.gelu, x)
    │
    ▼ ★ 算子分解 (decompositions.py) ★
FX Graph: call_function(aten.mul, x, 0.5 * (1 + aten.erf(x * 0.707...)))
    │
    ▼ Inductor Lowering
Inductor IR: Pointwise(lambda x: x * 0.5 * (1 + erf(x * 0.707...)))
    │
    ▼ Triton/C++ Codegen
最终 kernel 代码

核心价值 :ATen 库有 2000+ 个算子,如果每个后端(Triton/C++/MPS/XPU)都要逐一实现,工程量巨大。算子分解将复杂算子表达为基础算子的组合,使后端只需实现约 200 个核心算子(Core ATen Ops) 即可支撑整个 PyTorch 功能集。


二、_decomp 子系统全景

复制代码
torch/_decomp/
├── __init__.py              ← 注册机制 + decomposition table + API
├── decompositions.py        ← ★ 本文分析对象(171 个 post-autograd 分解)
├── decompositions_for_jvp.py ← 前向自动微分(JVP)专用分解
└── decompositions_for_rng.py ← RNG(随机数生成器)函数化分解

两张分解表

python 复制代码
# torch/_decomp/__init__.py
global_decomposition_table = {
    "post_autograd": {},   # ← decompositions.py 注册到这里
    "pre_autograd": {},    # ← 自动微分之前的分解(少量)
    "meta": {},            # ← Meta tensor 推导(shape/dtype 计算)
}

decomposition_table = global_decomposition_table["post_autograd"]
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
  • post_autograd(主要):在自动微分展开之后应用,分解前向和反向算子
  • pre_autograd:在自动微分展开之前应用,通常用于将高层 API 分解为可微分的底层算子

分解 vs 引用实现 vs 原语

复制代码
复杂 ATen 算子 (aten.gelu, aten.batch_norm, aten.lstm, ...)
      │
      ▼ decompositions.py(算子分解)
核心 ATen 算子 (aten.mul, aten.add, aten.exp, aten.var_mean, ...)
      │
      ▼ torch._refs(引用实现,可选)
原语 _prims (prims.mul, prims.add, prims.exp, ...)
层级 模块 角色 示例
分解 _decomp/ 将复杂算子拆为核心算子 gelu → mul + erf
引用 _refs/ 核心算子的 Python 参考实现 add → prims.add
原语 _prims/ 最小的不可再分操作 prims.add

三、文件整体结构鸟瞰

decompositions.py 共 5746 行,可按功能分为以下区域:

复制代码
decompositions.py (5746 行, 171 个 @register_decomposition)
│
├── [1-120]       import + 辅助函数
│                 type_casts / pw_cast_for_opmath 等类型提升装饰器
│
├── [125-360]     激活函数及其反向 (15 个)
│                 tanh_backward, sigmoid_backward, gelu_backward,
│                 silu, hardswish, mish, prelu, rrelu, leaky_relu...
│
├── [360-700]     损失函数 (12 个)
│                 mse_loss, smooth_l1_loss, huber_loss,
│                 binary_cross_entropy, nll_loss, soft_margin_loss...
│
├── [700-900]     基础数学/张量操作 (10 个)
│                 dist, _euclidean_dist, slice, slice_backward,
│                 slice_scatter, select_scatter...
│
├── [900-1350]    形状 & 内存操作 (15 个)
│                 split, unfold, diag_embed, diagonal_backward,
│                 _chunk_cat, split_with_sizes_copy...
│
├── [1350-1600]   归一化层 (8 个)
│                 native_layer_norm_backward, native_group_norm,
│                 native_group_norm_backward, native_batch_norm,
│                 cudnn_batch_norm, batch_norm...
│
├── [1600-2100]   BatchNorm 完整实现 (~500 行)
│                 native_batch_norm_helper (训练/推理两条路径)
│                 batch_norm_backward 反向传播
│
├── [2100-2600]   Embedding & 索引操作 (8 个)
│                 embedding, embedding_dense_backward,
│                 index_add, index_copy, _unsafe_masked_index...
│
├── [2600-2950]   自适应池化 & 上采样基础 (5 个)
│                 _adaptive_avg_pool2d, im2col, col2im...
│
├── [2950-3500]   上采样 nearest (20+ 个)
│                 upsample_nearest1d/2d/3d 及其 exact 变体
│                 完整的 N 维最近邻上采样实现
│
├── [3500-3930]   RNN 系列 (8 个)
│                 rnn_tanh, rnn_relu, lstm, gru
│                 每种有 input 和 data (PackedSequence) 两个变体
│
├── [3930-4600]   上采样 bilinear/bicubic/trilinear (15 个)
│                 upsample_bilinear2d, upsample_bicubic2d,
│                 upsample_trilinear3d, _upsample_bicubic2d_aa...
│
├── [4600-5100]   混合算子 (15 个)
│                 affine_grid_generator, grid_sampler_2d,
│                 mv, reflection_pad, replication_pad, aminmax,
│                 nll_loss_forward...
│
└── [5100-5746]   杂项 (20+ 个)
│                 arange, floor_divide, squeeze, baddbmm,
│                 scaled_dot_product_attention_for_cpu,
│                 weight_norm, isin, bernoulli, take,
│                 max_pool2d_with_indices_backward, hann_window...

四、注册机制详解

4.1 @register_decomposition 装饰器

python 复制代码
@register_decomposition(aten.gelu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    ...

三层装饰器各司其职:

装饰器 功能
@register_decomposition(aten_op) 将函数注册到 decomposition_table[aten_op]
@out_wrapper("grad_input") 自动处理 out= 参数(inplace 输出语义)
@pw_cast_for_opmath 自动类型提升(如 FP16 → FP32 计算 → FP16 输出)

4.2 注册到全局分解表

python 复制代码
# __init__.py 中的核心逻辑
def register_decomposition(aten_op, registry=None, *, type="post_autograd"):
    def decomposition_decorator(fn):
        if registry is None:
            registry = global_decomposition_table[type]
        # 将 fn 注册到 registry[aten_op]
        _add_op_to_registry(registry, aten_op, fn)
        return fn
    return decomposition_decorator

支持三种 aten_op 形式:

  • 单个 OpOverloadaten.gelu_backward(最常见)
  • OpOverloadPacket:自动注册所有 overload
  • 列表[aten.arange.default, aten.arange.out](一个实现覆盖多个变体)

4.3 消费者如何使用分解表

python 复制代码
# Inductor 获取分解
from torch._decomp import get_decompositions
decomps = get_decompositions([aten.gelu, aten.batch_norm, ...])

# torch.export 获取核心分解
from torch._decomp import core_aten_decompositions
decomps = core_aten_decompositions()  # 约 200+ 个算子

五、类型提升机制

5.1 type_casts 装饰器

这是一个精妙的装饰器工厂,解决低精度计算的数值稳定性问题:

python 复制代码
def type_casts(f, type_promotion, compute_dtype_only=False):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        # 1. 从所有 Tensor 参数推导计算精度和输出精度
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *flat_args, type_promotion_kind=type_promotion
        )
        # 2. 将所有输入提升到计算精度
        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        # 3. 将输出降低到结果精度
        return tree_map(decrease_prec, r)
    return inner

典型场景:用户输入 FP16 Tensor → 提升到 FP32 做 GELU 计算 → 输出降回 FP16。

5.2 四种预定义策略

python 复制代码
# 默认提升:FP16/BF16 → FP32 计算 → 降回原精度
pw_cast_for_opmath = partial(type_casts, type_promotion=DEFAULT)

# 仅提升计算精度,不降回(用于中间计算)
compute_only_pw_cast_for_opmath = partial(type_casts, compute_dtype_only=True)

# 整数 → 浮点提升(如 INT → FLOAT 的除法)
pw_cast_for_int_to_real = partial(type_casts, type_promotion=INT_TO_FLOAT)

# 包含非 Tensor 参数的提升
pw_cast_for_opmath_non_tensor_args = partial(type_casts, include_non_tensor_args=True)

六、典型分解案例深度剖析

6.1 激活函数------以 GELU 反向为例

python 复制代码
@register_decomposition(aten.gelu_backward)
@out_wrapper("grad_input")
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    if approximate == "tanh":
        # tanh 近似版本的精确梯度公式
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)

        left = 0.5 * self
        right = 1 + tanh_inner
        left_derivative = 0.5 * right
        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative

        return grad * (left_derivative + right_derivative)
    else:
        # 精确版本:GELU(x) = x * Φ(x)
        # GELU'(x) = Φ(x) + x * φ(x)
        kAlpha = M_SQRT1_2
        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * kAlpha))   # Φ(x)
        pdf = kBeta * torch.exp(self * self * -0.5)    # φ(x)
        return grad * (cdf + self * pdf)

分解产物aten.gelu_backward{mul, erf, exp, tanh, add, sub} 这些基础算子的组合。Inductor 可以将它们融合为一个 Triton kernel。

6.2 损失函数------以 NLL Loss 为例

NLL Loss 的分解展示了更复杂的逻辑------需要处理 ignore_index、weight、reduction 三个维度的组合:

python 复制代码
@register_decomposition(aten.nll_loss_backward)
@out_wrapper("grad_input")
def nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight):
    # 核心分解逻辑
    def _nll_loss_backward(target, weight, ignore_index):
        # 1. 创建全零梯度
        grad_input = self.new_zeros(self.shape)
        # 2. 处理 ignore_index
        safe_target = torch.where(target != ignore_index, target, 0)
        # 3. 计算逐样本权重
        if weight is not None:
            grad_input_per_sample = -weight.index_select(0, safe_target.view(-1))
        else:
            grad_input_per_sample = self.new_full(...)
        # 4. 根据 reduction 模式调整
        if reduction == Reduction.MEAN.value:
            grad_input_per_sample /= total_weight
        # 5. scatter 回梯度矩阵
        grad_input.scatter_(1, safe_target.unsqueeze(1), grad_input_per_sample.unsqueeze(1))
        return grad_input

6.3 归一化层------BatchNorm 完整实现

BatchNorm 的分解是文件中最长的连续代码段(~500 行),因为它需要处理训练/推理两条完全不同的路径:

python 复制代码
def native_batch_norm_helper(input, weight, bias, running_mean, running_var,
                              training, momentum, eps, functional):
    if training:
        # 训练模式:计算当前 batch 的统计量
        biased_var, mean = torch.var_mean(input, dim=reduction_dims, correction=0)
        rstd = torch.rsqrt(biased_var + eps)
        output = (input - mean) * rstd

        # 更新 running statistics(EMA)
        if running_mean is not None:
            new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
            if not functional:
                running_mean.copy_(new_running_mean)  # inplace 更新
        # ... running_var 类似
    else:
        # 推理模式:使用 running statistics
        output = (input - running_mean) * (1 / sqrt(running_var + eps))

    # 应用 affine 变换
    if weight is not None:
        output = output * weight
    if bias is not None:
        output = output + bias

    return output, save_mean, save_rstd

关键设计决策

  • functional=True 时返回新的 running_mean/var(用于 functorch/export)
  • functional=False 时 inplace 更新(兼容 eager 模式语义)
  • 训练模式使用无偏方差估计(n/(n-1) 校正)更新 running_var

6.4 RNN 系列------递归结构的分解

RNN/LSTM/GRU 的分解是最具教学意义的:

python 复制代码
# 基础 RNN cell
def rnn_cell(nonlinearity):
    def inner(x, hidden, wi, wh, bi, bh):
        igates = x @ wi.t()
        hgates = hidden @ wh.t()
        if bi is not None:
            igates = igates + bi
        if bh is not None:
            hgates = hgates + bh
        return nonlinearity(igates + hgates)
    return inner

# LSTM cell
def lstm_cell(x, hid, wi, wh, bi, bh):
    h, c = hid
    gates = x @ wi.t() + h @ wh.t()  # 4 个 gate 一次矩阵乘法
    if bi is not None:
        gates += bi + bh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    c_new = (forgetgate * c) + (ingate * cellgate)
    h_new = outgate * torch.tanh(c_new)
    return h_new, c_new

然后通过 _rnn_helper 处理多层、双向、dropout 等组合:

python 复制代码
def _rnn_helper(input, hidden, params, has_biases, num_layers,
                dropout, train, bidirectional, batch_first, layer_fn):
    for i in range(num_layers):
        # 前向
        result_fw = layer_fn(input, hidden_fw, params_fw)
        if bidirectional:
            # 反向(翻转输入)
            result_bw = layer_fn(reversed_input, hidden_bw, params_bw)
            result = torch.cat([result_fw, result_bw], -1)
        if dropout > 0 and i < num_layers - 1:
            result = F.dropout(result, p=dropout, training=train)
    return result, final_hiddens

6.5 上采样------算法密集型分解

upsample_bicubic2d 是文件中最长的单个分解(~150 行),实现了完整的双三次插值算法:

python 复制代码
@register_decomposition(aten.upsample_bicubic2d.default)
def upsample_bicubic2d(input, output_size, align_corners, scales_h=None, scales_w=None):
    # 1. 计算源坐标和插值系数
    def compute_source_index(output_index, input_size, output_size, align_corners):
        if align_corners:
            return output_index * (input_size - 1) / (output_size - 1)
        else:
            return (output_index + 0.5) / scale - 0.5

    # 2. 对每个输出像素,取 4x4 邻域并应用 cubic 核
    def cubic_interp1d(x0, x1, x2, x3, t):
        # Cardinal cubic spline
        A = -0.75
        coeffs_0 = ((A * (t + 1) - 5 * A) * (t + 1) + 8 * A) * (t + 1) - 4 * A
        coeffs_1 = ((A + 2) * t - (A + 3)) * t * t + 1
        coeffs_2 = ((A + 2) * (1 - t) - (A + 3)) * (1 - t) * (1 - t) + 1
        coeffs_3 = ...
        return x0 * coeffs_0 + x1 * coeffs_1 + x2 * coeffs_2 + x3 * coeffs_3

    # 3. 先沿 H 方向插值,再沿 W 方向插值(可分离卷积)

6.6 Scaled Dot Product Attention --- 跨层语义对齐

SDPA for CPU 的分解展示了一个独特需求------保证分解后的输出 Tensor 与原始 flash attention 的内存布局完全一致:

python 复制代码
@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default)
def scaled_dot_product_flash_attention_for_cpu(query, key, value, ...):
    output, attn = aten._scaled_dot_product_attention_math.default(
        query, key, value, ...
    )
    # 关键:对齐内存布局
    # Flash attention 返回 [N, H, L, E],但内部 permute 后 contiguous
    # 分解必须产出相同的 strides,否则后续 .view() 会失败
    output = (
        output.permute(2, 0, 1, 3)
        .contiguous(memory_format=torch.contiguous_format)
        .permute(1, 2, 0, 3)
    )
    return output, attn

七、辅助函数体系

文件中有大量不带 @register_decomposition 的辅助函数,它们被多个分解共享:

7.1 通用辅助

函数 功能
_unsqueeze_to_dim(x, dim) 反复 unsqueeze 直到 x.dim() == dim
apply_loss_reduction(loss, reduction) 统一处理 loss 的 NONE/MEAN/SUM 模式
to_real_dtype(dtype) 复数类型 → 对应实数类型
_maybe_cast(x, dtype) 安全的可选 Tensor 类型转换
_normalize_start_end(x, dim, start, end) 标准化 slice 的 start/end 参数

7.2 RNN 专用辅助

函数 功能
rnn_cell(nonlinearity) 工厂函数,创建基础 RNN cell
lstm_cell(x, hid, ...) LSTM cell 的完整实现
gru_cell(x, hid, ...) GRU cell 的完整实现
one_layer_rnn(...) 单层 RNN 循环
_rnn_helper(...) 多层+双向+dropout 的通用 RNN 框架
gather_params(params, has_biases, ...) 从扁平参数列表重组为层级结构
mkldnn_rnn_layer(...) MKL-DNN 后端适配

7.3 上采样专用辅助

函数 功能
get_scale_value(...) 从 scale_factors 或 output_size 计算缩放比
_compute_scale(...) 计算 align_corners 模式的精确缩放
compute_source_index(...) 输出坐标 → 源坐标映射
_upsample_get_steps(...) 生成上采样索引序列

7.4 注册 inplace 变体

python 复制代码
def register_inplace(aten_op, outplace_op):
    """将 out-of-place 分解自动转为 inplace 变体"""
    @register_decomposition(aten_op)
    def inplace_op(*args, **kwargs):
        out = outplace_op(*args, **kwargs)
        return args[0].copy_(out)
    return inplace_op

# 使用示例:
register_inplace(aten.leaky_relu_, aten.leaky_relu)
register_inplace(aten.silu_, aten.silu)

八、分解的分类统计

对 171 个分解按功能域分类:

类别 数量 代表算子
激活函数 + 反向 20 gelu, silu, hardswish, mish, prelu, rrelu, elu, leaky_relu
损失函数 + 反向 15 mse_loss, nll_loss, bce, smooth_l1, huber, soft_margin
归一化层 + 反向 12 batch_norm, layer_norm, group_norm, cudnn_batch_norm
上采样 25+ nearest 1d/2d/3d, bilinear, bicubic, trilinear, lanczos
RNN 系列 8 rnn_tanh, rnn_relu, lstm, gru (input + data 变体)
索引 & 聚集 10 embedding, index_add, index_copy, scatter, take
形状操作 15 split, slice, unfold, diag_embed, chunk_cat, pad_sequence
池化 5 adaptive_avg_pool2d, max_pool2d_backward, im2col, col2im
数学运算 10 dist, euclidean_dist, addmm, baddbmm, mv, floor_divide
Padding 8 reflection_pad 1d/2d/3d, replication_pad 1d/2d/3d + backward
空间变换 5 affine_grid_generator, grid_sampler_2d, upsample_bicubic2d
杂项 38+ arange, bernoulli, dropout, weight_norm, isin, nansum, sdpa...

九、Core ATen Ops --- 分解的目标集

core_aten_decompositions() 定义了 torch.export 使用的标准分解集,其目标是将 2000+ ATen 算子降低到约 200 个核心算子

python 复制代码
# 核心算子示例(分解的目标集,不会被进一步分解)
不分解的核心算子 = {
    aten.add, aten.mul, aten.sub, aten.div,        # 算术
    aten.exp, aten.log, aten.sin, aten.cos,         # 基础数学
    aten.mm, aten.bmm, aten.addmm,                  # 矩阵乘法
    aten.conv2d, aten.max_pool2d,                    # 卷积/池化
    aten.cat, aten.stack, aten.reshape,              # 形状操作
    aten.index, aten.scatter, aten.gather,           # 索引操作
    aten.sum, aten.mean, aten.var_mean,              # 归约操作
    aten.where, aten.clamp,                          # 条件操作
    ...
}

分解示意

复制代码
aten.gelu         →  aten.mul + aten.erf + aten.add
aten.batch_norm   →  aten.var_mean + aten.rsqrt + aten.mul + aten.add
aten.lstm         →  aten.mm + aten.sigmoid + aten.tanh + aten.mul + aten.add
aten.nll_loss     →  aten.gather + aten.where + aten.scatter + aten.sum
aten.upsample_*   →  aten.index + aten.mul + aten.add (插值系数)

十、Inductor 如何使用分解

python 复制代码
# torch/_inductor/decomposition.py(简化)

def select_decomp_table():
    decomps = get_decompositions([
        aten.addcdiv, aten.affine_grid_generator, aten.batch_norm,
        aten.binary_cross_entropy, aten.embedding_dense_backward,
        aten.gelu, aten.gelu_backward, aten.lstm, aten.gru,
        aten.native_batch_norm, aten.native_layer_norm_backward,
        aten.nll_loss_backward, aten.reflection_pad2d,
        aten.upsample_bicubic2d, aten.upsample_nearest2d,
        # ... 200+ 个算子
    ])
    return decomps

这些分解在 AOT Autograd 的 aot_stage1_graph_capture 阶段被应用------FX 图中的算子被替换为分解后的等价表达式。


十一、设计哲学与权衡

11.1 为什么不把所有算子都分解?

  1. 性能:有些算子的专用 kernel(如 cuDNN 的 BatchNorm)比分解后的通用融合更快
  2. 数值精度:分解可能改变数值结果(如 BatchNorm 的方差累加方式)
  3. 编译时间:分解产生更多节点,增加 Inductor 的融合搜索空间

11.2 分解的正确性保证

每个分解都必须满足:

  • 语义等价:对任意输入产出与原算子 bit-exact 相同的结果(或在文档中声明差异)
  • 类型正确:输出 dtype、device、shape 必须与原算子一致
  • 梯度正确:反向分解产出的梯度必须与 autograd 引擎一致

11.3 out_wrapper 的精妙设计

许多 ATen 算子有 out= 变体(如 aten.gelu.out),@out_wrapper 装饰器自动处理这一差异:

python 复制代码
@out_wrapper()
def gelu(self: Tensor) -> Tensor:
    return self * 0.5 * (1 + torch.erf(self * 0.707...))

# 自动生成等价的:
# def gelu_out(self, *, out):
#     result = gelu(self)
#     out.resize_(result.shape)
#     out.copy_(result)
#     return out

十二、阅读路线建议

复制代码
Level 1:理解机制
    → __init__.py 的 register_decomposition + decomposition_table
    → 选 silu (5行) 或 fill_scalar (2行) 作为入门示例

Level 2:理解类型系统
    → type_casts 装饰器 (58-120行)
    → pw_cast_for_opmath 的四种变体

Level 3:读懂核心分解
    → gelu_backward (251-280行) ← 激活函数代表
    → nll_loss_backward (557-607行) ← 损失函数代表
    → native_batch_norm_helper (1944-2040行) ← 归一化代表
    → lstm_cell (3630-3650行) ← RNN 代表

Level 4:读懂复杂分解
    → adaptive_avg_pool2d (2644-2780行) ← 索引密集型
    → upsample_bicubic2d (4811-4900行) ← 算法密集型
    → affine_grid_generator (4415-4595行) ← 空间变换

Level 5:调试技巧
    → TORCH_COMPILE_DEBUG=1 查看分解前后的 FX 图差异
    → torch._decomp.get_decompositions([aten.xxx]) 查询某算子是否有分解
    → torch.fx.experimental.proxy_tensor.make_fx(fn)() 观察分解效果

十三、数据流全景图

复制代码
ATen 算子集(2000+ 个)
    │
    ▼ register_decomposition 注册
global_decomposition_table["post_autograd"]
    │ 171 个分解函数
    │
    ├─── torch.compile 路径:
    │    AOT Autograd → aot_config.decompositions = {...}
    │    → make_fx 追踪时自动替换
    │    → 分解后的 FX 图交给 Inductor
    │
    ├─── torch.export 路径:
    │    core_aten_decompositions()
    │    → 导出标准化 FX 图(仅含核心算子)
    │
    └─── functorch 路径:
         vmap/grad 需要所有算子可函数化
         → 分解将 inplace/mutation 算子转为纯函数

核心 ATen 算子集 (~200 个)
    │
    ▼ Inductor Lowering
Inductor IR → Triton/C++ 代码

十四、总结

decompositions.py 的 5746 行代码浓缩了 PyTorch 算子语义的精确定义。它的价值体现在三个维度:

  1. 工程价值:将 2000+ ATen 算子降维到 ~200 个核心算子,使后端编译器的实现成本降低一个数量级
  2. 教学价值:每个分解都是对应算子的"可执行数学公式"------BatchNorm 的训练/推理路径、LSTM 的门控机制、双三次插值的 Cardinal Spline 系数------全部用 Python + 基础张量操作表达得清清楚楚
  3. 生态价值 :通过 core_aten_decompositions()torch.export,第三方推理引擎只需支持 ~200 个核心算子即可运行任意 PyTorch 模型

这个文件就像一本"ATen 算子翻译词典"------左边是复杂算子,右边是基础算子的组合。掌握它,就掌握了 PyTorch 编译器和导出系统的语义基石。

相关推荐
xian_wwq2 小时前
【学习笔记】大模型如何理解图片
笔记·学习
星马梦缘2 小时前
强化学习实战5——BaseLine3使用自定义环境训练【输入状态向量】
pytorch·python·jupyter·强化学习·baseline3·gymnasium
talen_hx2962 小时前
《零基础入门Spark》学习笔记 Day 13
笔记·学习·spark
Flittly2 小时前
【SpringAIAlibaba新手村系列】(15)MCP Client 调用本地服务
java·笔记·spring·ai·springboot
SteveSenna2 小时前
强化学习4.1:基于价值——Q-learning
人工智能·学习·算法·机器人
少许极端2 小时前
算法奇妙屋(四十四)-贪心算法学习之路11
java·学习·算法·贪心算法
艾莉丝努力练剑2 小时前
C++ 核心编程练习:从基础语法到递归、重载与宏定义
linux·运维·服务器·c语言·c++·学习
鱼鳞_2 小时前
Java学习笔记_Day24(HashMAap)
java·笔记·学习
AI视觉网奇2 小时前
ChatTutor 部署笔记
笔记