PyTorch 学习笔记(18) : lowering.py

本文聚焦 torch/_inductor/lowering.py------一个 8384 行 的核心文件。它是 Inductor 编译器的"翻译中枢",负责将 FX 图中的 ATen 算子翻译为 Inductor 内部 IR(PointwiseReductionFallbackKernel 等),从而驱动后续的 Triton/C++ 代码生成。如果 decompositions.py 是"算子拆解词典",那么 lowering.py 就是"算子 → IR 编译规则大全"。


一、Lowering 在编译栈中的位置

复制代码
用户代码: y = torch.relu(x) + 1
    │
    ▼ TorchDynamo 捕获
FX Graph: [aten.relu, aten.add]
    │
    ▼ AOT Autograd + 算子分解
FX Graph: [aten.relu, aten.add]  (部分算子已分解)
    │
    ▼ ★ Lowering (lowering.py) ★
Inductor IR: [Pointwise(relu), Pointwise(add)]
    │
    ▼ Scheduler (scheduler.py)
FusedSchedulerNode: [Pointwise(relu + add)]  ← kernel 融合
    │
    ▼ Codegen
Triton Kernel: @triton.jit def fused_relu_add(...)

核心职责:为每个 ATen/Prims 算子提供"如何构造 Inductor IR 节点"的规则。这些规则定义了:

  • 输出的 shape 如何计算
  • 数据的读取方式(make_loader / inner_fn
  • 产出的 IR 类型(PointwiseReductionFallbackKernel 等)

二、文件全局架构

复制代码
lowering.py (8384 行)
│
├── [1-130]       全局变量 & 注册表
│  lowerings = {}           ← 核心字典:op → lowering_fn
│  user_lowerings = {}      ← 用户自定义 lowering(优先级最高)
│  fallbacks = OrderedSet() ← 回退到 eager 执行的算子集
│  needs_realized_inputs    ← 需要物化输入的算子(如 conv, mm)
│  foreach_ops              ← foreach 批量算子集
│
├── [130-550]     注册机制 & 核心辅助
│  register_lowering()      ← 装饰器:注册算子 lowering
│  _register_lowering()     ← 实际注册逻辑(类型提升+广播+验证)
│  register_pointwise()     ← 逐元素算子批量注册
│  make_pointwise()         ← 构造 Pointwise IR 的工厂函数
│  make_fallback()          ← 注册 fallback 到 eager 执行
│  transform_args()         ← 参数预处理(广播+类型提升)
│
├── [550-860]     广播 & 类型提升
│  broadcast_symbolic_shapes()  ← 符号形状的广播推导
│  promote_constants()          ← 标量 → Tensor 常量提升
│  make_foreach_pointwise()     ← foreach 批量逐元素工厂
│
├── [860-1130]    类型转换 lowering
│  _convert_element_type()  ← prims.convert_element_type
│  to_dtype() / to_dtype_bitcast() / to_device()
│  _foreach_map()           ← foreach_map 高阶算子
│
├── [1130-2340]   视图 & 形状操作 lowering (~35 个)
│  view, reshape, permute, expand, squeeze, unsqueeze
│  slice_, as_strided, select, split, unbind, unfold
│  cat, diagonal, diagonal_scatter
│
├── [2340-2630]   Fallback 机制
│  fallback_handler()       ← 创建 FallbackKernel 的通用处理器
│  make_fallback()          ← 注册 fallback + 断言无冲突分解
│  fallback_node_due_to_unsupported_type() ← 自动检测需要回退的类型
│
├── [2630-2770]   随机数算子
│  philox_rand, native_dropout, bernoulli_, rand, randn
│  inductor_seed / inductor_seeds / inductor_lookup_seed
│
├── [2770-3700]   Make_fallback 批量注册 (~124 个)
│  按难度分六大类:Easy / Medium / Difficult / Backwards
│  Impossible(sort,linalg) / Pattern-matched(SDPA)
│
├── [3700-4400]   张量创建 & 填充
│  full, empty, empty_strided, zeros_like, ones_like
│  tensor, arange, iota, linspace
│  gather, embedding, index, index_put
│
├── [4400-4700]   Scatter / Index 操作
│  scatter, scatter_add, scatter_reduce
│  index_put_impl_ (核心 ~200 行)
│
├── [4700-5100]   Padding 操作
│  constant_pad_nd (含 pad_as_cat 优化)
│  reflection_pad, replication_pad
│
├── [5100-6500]   池化 & 上采样 lowering
│  avg_pool2d/3d, max_pool2d/3d_with_indices
│  adaptive_avg_pool2d, adaptive_max_pool2d
│  fractional_max_pool2d/3d
│  upsample_nearest2d_backward
│  avg_pool2d_backward, avg_pool3d_backward
│
├── [6500-6900]   归约操作
│  make_reduction()         ← 归约 IR 工厂
│  mean, var_mean, sum, prod
│  cumsum, cumprod, cummax, cummin
│
├── [6900-7550]   数学运算 & 逐元素注册
│  pow, div, mul, add, sub  ← 手写 lowering
│  register_pointwise(aten.relu) ← 批量注册 (~50 个)
│  register_foreach_pointwise() ← foreach 变体
│
├── [7550-7780]   Inplace 变体注册
│  register_inplace() / register_foreach_inplace()
│
├── [7780-8060]   高阶算子 lowering
│  cond, while_loop, invoke_subgraph
│  associative_scan, with_effects
│
└── [8060-8384]   扩展注册 & 外部模块
│  kernel/ 子模块导入 (mm, conv, 等模板化 kernel)
│  quantized_lowerings, mkldnn_lowerings, jagged_lowerings
│  comm_lowering (分布式通信)
│  inline_asm_elementwise, cvt_e8m0_rceil

三、注册机制------三驾马车

3.1 register_lowering:核心装饰器

python 复制代码
lowerings: dict[Callable | str, Callable] = {}  # 全局 lowering 注册表

def register_lowering(
    aten_fn,
    broadcast=False,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
):
    """装饰器:将 Python 函数注册为某个 ATen 算子的 lowering 实现"""
    return functools.partial(
        _register_lowering, aten_fn,
        broadcast=broadcast,
        type_promotion_kind=type_promotion_kind,
        ...
    )

_register_lowering 内部做三件事:

复制代码
1. 参数预处理  →  transform_args(args, kwargs, broadcast, type_promotion)
   ├── 类型提升:INT → FLOAT, FP16 → FP32 等
   └── 广播对齐:broadcast_tensors(*inputs)

2. 调用 decomp_fn  →  out = decomp_fn(*args, **kwargs)

3. 输出验证  →  validate_ir(out)  # 确保产出合法的 IR 节点

3.2 register_pointwise:逐元素算子的批量注册

python 复制代码
def register_pointwise(
    aten_fn, name=None, broadcast=True,
    type_promotion_kind=DEFAULT,
    override_return_dtype=None,
    override_fn_when_input_bool=None,
    allow_alpha=False,
    ...
):
    """一行代码注册一个逐元素算子"""
    name = name or aten_fn.__name__
    fn = ops_wrapper(name)       # → ops.relu, ops.sigmoid, ...
    fn = make_pointwise(fn, ...) # → 包装为 Pointwise IR 构造器
    fn = register_lowering(aten_fn, ...)(fn)  # → 注册到 lowerings 字典
    return fn

# 使用示例------一行注册一个算子:
relu = register_pointwise(aten.relu)
sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
exp = register_pointwise_numeric_ldf64(aten.exp)
sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
abs = register_pointwise(aten.abs)
neg = register_pointwise(aten.neg)
add = register_pointwise(aten.add, allow_alpha=True, use_fma_for_alpha=True)
sub = register_pointwise(aten.sub, allow_alpha=True)
mul = ...  # 手写(含常量优化)

文件中约有 50+ 个算子 通过 register_pointwise 一行注册。

3.3 make_fallback:不能编译的算子回退到 eager

python 复制代码
def make_fallback(op, layout_constraint=None, warn=True):
    """将算子注册为 FallbackKernel ------ 在运行时直接调用 ATen eager 实现"""
    # 1. 检查:如果存在分解,优先用分解,不要 fallback
    assert op not in decompositions or override_decomp

    # 2. CI 中强制报错防止遗漏
    if warn and bool(os.getenv("CI")) and get_decompositions([op]):
        raise AssertionError("a decomposition exists, we should switch to it")

    # 3. 注册
    add_needs_realized_inputs(op)       # fallback 前必须物化所有输入
    add_layout_constraint(op, ...)      # 约束输入的内存布局
    register_lowering(op)(fallback_handler(op))

文件中共有 124 个 make_fallback 调用,按难度分为六档:

级别 说明 代表算子
Easy 有分解但暂未迁移 uniform, exponential, soft_margin_loss_backward
Medium 需要新 IR 或较复杂实现 _trilinear
Difficult 需要新 IR 类型 segment_reduce, histc, histogram
Backwards 前向有实现但反向未实现 _adaptive_avg_pool2d_backward, grid_sampler_2d_backward
Impossible 缺少 Triton/CPU 特性 sort, topk, kthvalue, 全部 linalg.*
Pattern-matched 由模式匹配器处理 全部 _scaled_dot_product_*_attention*

四、make_pointwise------逐元素 IR 构造的核心

这是理解整个 lowering 系统的关键函数(~100 行):

python 复制代码
def make_pointwise(fn, override_return_dtype=None, ...):
    """将一个标量函数 fn 包装为 Pointwise IR 构造器"""

    def inner(*inputs: TensorBox, alpha=None):
        # 1. 常量提升:标量 → 常量 Tensor
        inputs = promote_constants(inputs, override_return_dtype)

        # 2. alpha 处理:a + alpha * b
        if allow_alpha and alpha is not None and alpha != 1:
            if use_fma_for_alpha and is_cuda_float:
                return _add_with_alpha_fma(inputs[0], inputs[1], alpha)
            inputs[-1] = mul(inputs[-1], alpha)

        # 3. 创建加载器
        loaders = [x.make_loader() for x in inputs]
        ranges = inputs[0].get_size()
        dtype = override_return_dtype or inputs[0].get_dtype()

        # 4. 定义逐元素计算函数
        def inner_fn(index):
            return fn(*[load(index) for load in loaders])

        # 5. 构造 Pointwise IR 节点
        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,  # ← 这个闭包就是将来 codegen 的核心
            ranges=ranges,
        )

    return inner

核心思想inner_fn 是一个接受 index 返回标量表达式的闭包。它定义了"给定输出位置 (i,j,k,...),如何计算该位置的值"。Inductor 的 scheduler 和 codegen 会遍历输出的每个索引,调用 inner_fn 来生成计算代码。


五、精度仿真机制

make_pointwise 中嵌入了一个精妙的低精度仿真系统:

python 复制代码
# 在 V.graph.current_node.meta 中标记了 low_precision_pointwise_barrier
emulate_precision_casts = V.graph.current_node.meta.get(
    "low_precision_pointwise_barrier", False
)

def inner_fn(index):
    inputs_loaded = []
    for inp_index, load in enumerate(loaders):
        out = load(index)
        inp_dtype = inputs[inp_index].get_dtype()
        if emulate_precision_casts and inp_dtype in (torch.float16, torch.bfloat16):
            # 1. 先降到低精度(模拟从内存读取)
            downcast = ops.to_dtype(out, inp_dtype, use_compute_types=False)
            # 2. 再升到计算精度(模拟 GPU 寄存器提升)
            out = ops.to_dtype(downcast, inp_dtype)
        inputs_loaded.append(out)

    out = fn(*inputs_loaded)

    if emulate_output_cast:
        # 输出也做同样的降→升,模拟回写内存再读取的精度损失
        downcast = ops.to_dtype(out, dtype, use_compute_types=False)
        return ops.to_dtype(downcast, dtype)
    return out

这保证了编译后的 fused kernel 与非融合版本(每步都回写内存)有一致的数值行为。


六、视图操作的 lowering 机制

视图操作不产生计算,只产生 IR 中的"视图节点":

6.1 直通视图(零开销)

python 复制代码
@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
def nop(x):
    return x  # 什么都不做,直接返回

6.2 形状变换视图

python 复制代码
@register_lowering(aten.view, type_promotion_kind=None)
@register_lowering(aten.reshape, type_promotion_kind=None)
@register_lowering(aten._unsafe_view, type_promotion_kind=None)
def view(x, sizes):
    return TensorBox(View.create(x.data, sizes))  # 创建 View IR

@register_lowering(aten.permute, type_promotion_kind=None)
def permute(x, dims):
    return TensorBox(PermuteView.create(x.data, tuple(dims)))

@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
    ...
    return TensorBox(ExpandView.create(x.data, sizes))

@register_lowering(aten.unsqueeze, type_promotion_kind=None)
def unsqueeze(x, dim):
    return TensorBox(SqueezeView.create(x.data, dim=dim))

6.3 as_strided---最底层的视图

python 复制代码
@register_lowering(aten.as_strided, type_promotion_kind=None)
def as_strided(x, size, stride, storage_offset=None):
    # 如果 x 已经是视图,先解包到底层 storage
    if isinstance(x.data, ir.BaseView):
        new_dtype = x.dtype
        x = x.data.unwrap_view()

    x.realize()  # 物化底层 buffer
    storage, old_layout = ir.as_storage_and_layout(x)

    # 构造新的固定布局
    new_layout = ir.FixedLayout(
        device, dtype, size, stride, storage_offset or 0
    )
    return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout))

6.4 slice --- 包含动态尺寸处理

slice_ 是文件中最长的视图 lowering(~130 行),需要处理动态 unbacked 符号的情况:

复制代码
静态/有后端的符号  →  直接创建 SliceView (快速路径)
                         │
动态 unbacked 符号  →  创建 DynamicSliceSize + DynamicSelectStorageOffset
                       (ExternKernel,在运行时计算 size 和 offset)

七、Cat 操作------Inductor 的两种实现

cat 是文件中第二长的单个 lowering(~170 行),有两条截然不同的路径:

7.1 Pointwise Cat(GPU 优化路径)

python 复制代码
def pointwise_cat(inputs, dim=0):
    """将 cat 实现为一个 Pointwise 节点,通过条件分支选择加载源"""
    def inner_fn(idx):
        idx_dim = ops.index_expr(idx[dim], torch.int64)
        # 对每个输入张量生成 mask 和 masked_load
        for i in range(len(inputs)):
            mask = ops.and_(ops.ge(idx_dim, start), ops.lt(idx_dim, end))
            masked_loads.append(ops.masked(mask, lambda: loaders[i](...), 0.0))
        # 链式 where 选择正确的源
        return ops.where(masks[0], masked_loads[0],
               ops.where(masks[1], masked_loads[1], ...))

优势:cat 与下游计算可以被 融合进同一个 kernel,避免额外的内存读写。

7.2 ConcatKernel(通用路径)

python 复制代码
@register_lowering(aten.cat)
def cat(inputs, dim=0):
    # 决策:用 pointwise 还是 ConcatKernel
    if should_use_pointwise_cat(inputs, dim):
        return pointwise_cat(inputs, dim)
    else:
        # ConcatKernel:分配输出 buffer,逐个 copy 输入
        return TensorBox(ir.ConcatKernel.create(inputs, dim))

八、索引操作------Index / Scatter 体系

8.1 index(高级索引读取)

python 复制代码
@register_lowering(aten.index, type_promotion_kind=None)
def index(x, indices):
    try:
        return index_impl(x, indices, check=True)
    except NotImplementedError:
        # 布尔索引:回退到 ATen
        return fallback_handler(aten.index.Tensor)(x, indices)

index_impl 的核心是构造一个 inner_fn,通过 ops.indirect_indexing 实现间接寻址:

python 复制代码
def fn(idx):
    gather_idx = ops.indirect_indexing(index_loader(idx), size[dim])
    idx[dim] = gather_idx
    return x_loader(idx)

ops.indirect_indexing 在 Triton codegen 中会生成 tl.load(ptr + gather_idx * stride)

8.2 index_put_impl_ ------ 最复杂的 lowering 之一

复制代码
index_put_(self, indices, values, accumulate)
    │
    ├── 单个布尔索引 + 标量 value → index_put_as_masked_fill
    │   (优化为 where 操作)
    │
    ├── 可以用 scatter → ir.ScatterKernel
    │   (accumulate=True 时用 atomic_add)
    │
    └── 复杂情况 → index_put_fallback
        (调用 aten.index_put_ 的 eager 实现)

8.3 scatter 系列

python 复制代码
@register_lowering(aten.scatter, type_promotion_kind=None)
def scatter(x, dim, index, src_or_value, **kwargs):
    return scatter_fallback(..., op=aten.scatter_.value, ...)

@register_lowering(aten.scatter_add, type_promotion_kind=None)
def scatter_add(x, dim, index, src):
    return scatter_fallback(
        ..., reduce="sum",
        op=aten.scatter_add_, ...
    )

特殊情况下 scatter 会激活 atomic 操作的限制检测:

python 复制代码
if needs_fallback_due_to_atomic_add_limitations(dtype):
    # 某些 dtype 没有 Triton 原子操作支持
    return fallback_handler(op)(...)

九、归约操作的 lowering

9.1 make_reduction 工厂

python 复制代码
def make_reduction(reduction_type, override_return_dtype=None):
    def inner(x, axis=None, keepdims=False, *, dtype=None):
        kwargs = _make_reduction_inner(x, axis=axis, keepdims=keepdims, ...)
        result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
        if isinstance(result.data.data, Reduction):
            result.realize()  # 非展开的归约需要物化
        return result
    return inner

# 注册归约算子
reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
reduce_argmax = register_lowering(aten.argmax)(make_reduction("argmax", torch.int64))
reduce_argmin = register_lowering(aten.argmin)(make_reduction("argmin", torch.int64))

9.2 _make_reduction_inner 的轴处理

python 复制代码
def _make_reduction_inner(x, *, axis, keepdims, dtype, ...):
    axis = _validate_reduction_axis(x, axis)
    # 将维度分为"保留维"和"归约维"
    kept_sizes, reduced_sizes = [], []
    for i in range(len(size)):
        if i in axis:
            reduced_sizes.append(size[i])
        else:
            kept_sizes.append(size[i])

    def loader(index, reduction_index):
        # 将 kept_index 和 reduction_index 重新组装为原始索引
        new_index[kept_idx] = index
        new_index[reduced_idx] = reduction_index
        return inner_loader(new_index)

    return dict(
        inner_fn=loader,
        ranges=new_size,               # 输出形状(保留维)
        reduction_ranges=reduced_sizes,  # 归约维度(将被遍历求和/求max等)
    )

9.3 mean 的 lowering------展示组合模式

python 复制代码
@register_lowering(aten.mean)
def mean(x, axis=None, keepdim=False, *, dtype=None):
    if output_dtype in (torch.float16, torch.bfloat16):
        x = to_dtype(x, torch.float)  # 高精度计算
    sum_result = sum_(x, axis, keepdim)
    denom = sympy_product(size[i] for i in axis)  # 符号化除数
    denom = ir.IndexingConstant(index=denom, ...)
    denom = ExpandView.create(denom, list(sum_result.get_size()))
    return to_dtype(div(sum_result, denom), output_dtype)

9.4 累积操作------scan vs fallback

python 复制代码
@register_lowering(aten.cumsum)
def cumsum(x, axis=None, dtype=None):
    if use_triton_scan(x):
        # GPU:使用 Triton 的 associative_scan
        kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
        result = ir.Scan.create(combine_fn=..., **kwargs)
    else:
        # CPU 或不支持 scan:回退到 eager
        return fallback_cumsum(x, axis, dtype=dtype)

9.5 var_mean 的两种策略

python 复制代码
def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
    if use_two_step_variance(x, axis, keepdim):
        # 小归约维度:两步法(先 mean,再 sum((x-mean)^2))
        return var_mean_sum_(x, axis, correction, keepdim, return_mean)
    else:
        # 大归约维度:Welford 在线算法(数值更稳定)
        return var_mean_welford_(x, axis, correction=correction, ...)

十、池化操作------Inductor 的模板模式

池化是 lowering.py 中最长的代码区域(~1400 行),因为每种池化都需要精确的索引计算。

10.1 avg_pool2d 的计算结构

python 复制代码
def _avg_poolnd(x, kernel_size, stride, padding, ceil_mode, count_include_pad, ...):
    # 1. 计算输出尺寸
    h_out, ceil_mode_h = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
    w_out, ceil_mode_w = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)

    # 2. 边界条件处理器
    if any(padding) or any(ceil_mode):
        x_loader = constant_boundary_condition(x, 0.0, dim=2)
    else:
        x_loader = x.make_loader()

    # 3. 定义归约内部函数
    def fn(idx, reduction_idx):
        *prefix, bh, bw = idx
        rh, rw = reduction_idx
        ih = bh * stride[0] + rh - padding[0]
        iw = bw * stride[1] + rw - padding[1]
        return x_loader([*prefix, ih, iw])

    # 4. 创建 sum 归约
    result = Reduction.create(
        reduction_type="sum", inner_fn=fn,
        ranges=new_size,               # [N, C, H_out, W_out]
        reduction_ranges=kernel_size,   # [kH, kW]
    )

    # 5. 除以池化窗口大小
    if divisor_override:
        return div(result, divisor_override)
    elif count_include_pad:
        return div(result, kernel_size[0] * kernel_size[1])
    else:
        # 不含 padding 的实际窗口大小(逐位置不同)
        return div(result, compute_pool_size_without_padding(...))

10.2 adaptive_max_pool2d 的窗口限制

python 复制代码
@register_lowering(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(x, output_size):
    window_size = h_kernel_max * w_kernel_max
    if window_size > 25:
        # 窗口太大 → Triton 展开后代码难以优化
        return fallback_adaptive_max_pool2d(x, output_size)
    # 否则:展开为 Pointwise
    ...

Inductor 对池化窗口大小有硬限制------超过 25 就回退,因为 Triton 展开大循环后产生的 IR 节点过多,编译时间和代码质量都会劣化。


十一、数学运算------特殊优化

11.1 add 的 FMA 优化

python 复制代码
add = register_pointwise(
    aten.add, allow_alpha=True, use_fma_for_alpha=True
)

a + alpha * b 在 CUDA 浮点上调用时,Inductor 自动使用 FMA(Fused Multiply-Add)指令:

python 复制代码
def _add_with_alpha_fma(a, b, alpha):
    def inner_fn(idx):
        a_val = a_loader(idx)
        b_val = b_loader(idx)
        alpha_expr = ops.constant(alpha, dtype)
        return ops.fma(b_val, alpha_expr, a_val)  # b*alpha + a
    return Pointwise.create(...)

11.2 pow 的递归展开

python 复制代码
def pow_recursive(x, y, dtype):
    """将整数幂展开为乘法链,避免调用 pow 函数"""
    if y < 0:
        return ops.reciprocal(pow_recursive(x, -y, dtype))
    if y == 0:
        return ops.constant(1, dtype)
    if y == 1:
        return x
    result = pow_recursive(x, y // 2, dtype)
    result = ops.mul(result, result)
    if y % 2 == 1:
        result = ops.mul(result, x)
    return result

对于 x ** 3,不会调用 pow(x, 3),而是展开为 x * x * x

11.3 mul 的常量折叠

python 复制代码
@register_lowering(aten.mul)
def mul(a, b):
    # 检测 x * 1 = x 和 x * 0 = 0
    if get_constant_value(a) == 1:
        return b
    if get_constant_value(b) == 1:
        return a
    # ... 0 的情况类似
    return _mul(a, b)

11.4 div 的特殊处理

python 复制代码
@register_lowering(aten.div.Tensor)
def div(a, b):
    # 整数除法的特殊语义
    both_integer = is_integer_type(a) and is_integer_type(b)
    if both_integer:
        return truncdiv(a, b)  # 向零截断(而非向下取整)
    return div_prim(a, b)

十二、Foreach 批量操作

Foreach 操作是 PyTorch 优化器的核心------一次 kernel 调用处理整个参数列表:

python 复制代码
def make_foreach_pointwise(pw_fn, allow_alpha=False, scalar_kwarg="alpha"):
    def inner(*inputs, alpha=1, ...):
        # 1. 按 (device, is_dynamic) 分组
        groups = group_foreach_args(zip(*broadcast_inputs))

        # 2. 对每组的每个张量应用 pointwise 函数
        for (device, use_foreach), group in groups.items():
            for output_ind, args in group:
                output = pw_fn(*args, alpha=scalar_val)
                if use_foreach:
                    output.realize()
                    operation_list.append(output.get_operation_name())

        # 3. 注册操作列表,告诉 scheduler 可以水平融合
        V.graph.register_operation_list(operation_list)
    return inner

# 注册示例
foreach_add_list = register_foreach_pointwise(aten._foreach_add.List, add)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)

register_operation_list 告诉 scheduler:这些操作可以打包到同一个 Triton kernel 中。在 Adam 这样的优化器中,这意味着一次 kernel launch 更新所有参数。


十三、Inplace 操作的处理

python 复制代码
def register_inplace(aten_op, outplace_op):
    """将 inplace 操作转为 out-of-place + mutate_to"""
    @register_lowering(aten_op, type_promotion_kind=None)
    def fn(*args, **kwargs):
        result = outplace_op(*args[1:], **kwargs)
        # first arg is self
        return mutate_to(args[0], result)
    return fn

def mutate_to(changed, val, unsafe_alias=False):
    """核心的 mutation 机制------替换底层 data"""
    if isinstance(changed, TensorBox):
        changed_data = changed.data
    else:
        changed_data = changed
    # 替换底层数据指针
    if isinstance(val, TensorBox):
        val = val.data
    ...
    changed_data.data = val  # 底层发生了 mutation
    return changed

所有 inplace 操作(relu_, add_, mul_ 等)都走这条路径。


十四、Layout 约束系统

某些算子对输入的内存布局有严格要求:

python 复制代码
# 约束函数定义
require_contiguous = constrain_to_fx_strides  # 需要连续布局
require_dense = constrain_to_fx_strides       # 需要密集布局

# 使用示例
make_fallback(aten._cudnn_rnn, require_dense)
make_fallback(aten._cdist_forward)            # 无约束

# Tag 系统:从算子元数据自动推导
def tag_to_layout_constraint(tag):
    if tag == Tag.needs_exact_strides:
        return constrain_to_fake_tensors   # 必须匹配 fake tensor 的精确 stride
    if tag == Tag.needs_contiguous_strides:
        return require_contiguous_strides   # 必须行优先连续
    if tag == Tag.needs_fixed_stride_order:
        return constrain_to_fx_strides      # 必须匹配 FX 图中的 stride 顺序
    if tag == Tag.flexible_layout:
        return None                         # 无约束

十五、高阶算子的 lowering

15.1 with_effects(带副作用的算子)

python 复制代码
@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None)
def with_effects(token, op, *args, **kwargs):
    """处理带副作用的算子(如 RNG、IO)"""
    # 1. 获取副作用类型
    effect_type = _get_effect(op)

    # 2. 正常 lower 内部算子
    operation_len = len(V.graph.operations)
    result = lowerings[op](*args, **kwargs)

    # 3. 为新创建的 IR 节点添加 StarDep(顺序依赖)
    prev_effect_buffer = V.graph.effectful_ops.get(effect_type)
    for new_op in V.graph.operations[operation_len:]:
        new_op.has_side_effects = lambda: True
        if prev_effect_buffer:
            V.graph.additional_star_deps[op_name].add(prev_effect_buffer.get_name())

    return (token, result)

15.2 associative_scan(关联扫描)

python 复制代码
@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn, xs, additional_inputs):
    # 1. 将子图 lowering 为 pointwise
    lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs)

    # 2. 创建 Scan IR
    result = ir.Scan.create(combine_fn=wrapped_combine_fn, ...)

十六、扩展性------外部模块注册

文件末尾通过导入子模块来注册更多 lowering:

python 复制代码
from . import kernel           # mm, conv, bmm 等模板化 kernel
import_submodule(kernel)       # 递归导入 kernel/ 下所有子模块

from . import quantized_lowerings
quantized_lowerings.register_quantized_ops()     # 量化算子
quantized_lowerings.register_woq_mm_ops()        # Weight-Only 量化矩阵乘

from . import mkldnn_lowerings
mkldnn_lowerings.register_onednn_fusion_ops()    # Intel oneDNN 融合算子

from . import jagged_lowerings
jagged_lowerings.register_jagged_ops()           # Jagged Tensor 算子

from .comm_lowering import register_comm_lowerings
register_comm_lowerings()                        # 分布式通信算子

这意味着实际可用的 lowering 远不止本文件中的 ------kernel/ 目录下有 mm、conv、bmm 等核心计算密集型算子的模板化实现。


十七、数据流全景图

复制代码
FX Graph (ATen ops)
    │
    ▼ GraphLowering.run()
遍历 FX 图每个节点:
    │
    ├── node.target in lowerings?
    │   ├── YES → 调用 lowerings[target](*args)
    │   │         ├── Pointwise 算子 → Pointwise.create(inner_fn=...)
    │   │         ├── 视图操作 → View/PermuteView/ExpandView/...
    │   │         ├── 归约操作 → Reduction.create(reduction_type=...)
    │   │         ├── Index/Scatter → ops.indirect_indexing(...)
    │   │         └── 池化/卷积 → 组合 Reduction + Pointwise
    │   │
    │   └── NO → fallback_node_due_to_unsupported_type?
    │       ├── YES → FallbackKernel.create(target, *args)
    │       └── NO → RuntimeError("missing lowering")
    │
    ▼
Inductor IR Graph
    │
    ▼ scheduler.py
Fused Kernel Groups → Triton/C++ Codegen

十八、统计概览

维度 数据
文件总行数 8384
@register_lowering 装饰器 168 个
make_fallback 调用 124 个
register_pointwise 调用 ~50 个
register_foreach_pointwise ~15 个
register_inplace ~20 个
外部模块注册 5 个(kernel, quantized, mkldnn, jagged, comm)
实际覆盖 ATen 算子数 ~400+

十九、阅读路线建议

复制代码
Level 1:理解注册机制
    → register_lowering (528行) --- 装饰器
    → _register_lowering (475行) --- 实际逻辑
    → register_pointwise (994行) --- 批量注册

Level 2:理解 Pointwise IR 构造
    → make_pointwise (668行) --- 核心工厂
    → 选 relu 或 add 作为入门示例

Level 3:理解视图 lowering
    → view/reshape (1376行) --- 最简
    → slice_ (1390行) --- 含动态符号处理
    → as_strided (1518行) --- 最底层

Level 4:理解归约 lowering
    → make_reduction (6548行) --- 工厂模式
    → mean (6642行) --- 展示组合模式
    → var_mean (6751行) --- 两种算法策略

Level 5:理解复杂 lowering
    → cat (1941行) --- Pointwise vs ConcatKernel 决策
    → index_put_impl_ --- 最复杂的 mutation 操作
    → avg_pool2d (5920行) --- 典型的池化归约模式

Level 6:理解系统边界
    → make_fallback (2462行) --- 何时放弃编译
    → fallback_handler (2348行) --- FallbackKernel 如何工作
    → with_effects (8109行) --- 副作用处理

二十、总结

lowering.py 的 8384 行代码定义了 Inductor 编译器的算子覆盖边界------它决定了哪些 ATen 算子可以被编译为高效 Triton/C++ kernel,哪些必须回退到 eager 执行。

三个层面的设计哲学:

  1. 分层抽象register_pointwise(一行注册)→ make_pointwise(IR 构造工厂)→ Pointwise.create(IR 节点)→ Triton codegen(最终代码)
  2. 渐进覆盖 :优先实现高频算子的原生 lowering,其余通过 make_fallback 兜底,并在 CI 中持续推进分解/lowering 覆盖
  3. 性能导向:FMA 优化(add+alpha)、pow 递归展开、cat→pointwise 融合、池化窗口限制------每个细节都服务于生成更快的 kernel

这个文件就像 Inductor 的"翻译词典"------左边是 ATen 算子语义,右边是 Inductor IR 的构造方法。掌握它,就掌握了 PyTorch 编译器"如何将算子变成代码"的关键一步。

相关推荐
断眉的派大星4 小时前
PyTorch 计算图与自动求导机制(超通俗精讲)
人工智能·pytorch·python
ACGkaka_4 小时前
ES 学习(七)性能陷阱
大数据·学习·elasticsearch
我的xiaodoujiao4 小时前
API 接口自动化测试详细图文教程学习系列10--Requests模块2--举例说明
python·学习·测试工具·pytest
CHU7290354 小时前
在线教学课堂APP功能版块设计方案:重构学习场景的交互逻辑
java·学习·小程序·重构
一定要AK4 小时前
HTML5 入门到精通全章节学习笔记
笔记·学习·html5
程序员zgh4 小时前
C/C++ 单元测试系统 构建
c语言·开发语言·c++·学习·单元测试
沪漂阿龙4 小时前
PyTorch 深度学习完全指南:从激活函数到房价预测实战
人工智能·pytorch·深度学习
Chef_Chen4 小时前
Agent学习-RAG--上下文压缩与知识库的更新
人工智能·学习·自然语言处理
计算机学姐4 小时前
基于SpringBoot的在线学习网站平台【个性化推荐+数据可视化+课程章节学习】
java·vue.js·spring boot·后端·学习·mysql·信息可视化