本文聚焦
torch/_inductor/lowering.py------一个 8384 行 的核心文件。它是 Inductor 编译器的"翻译中枢",负责将 FX 图中的 ATen 算子翻译为 Inductor 内部 IR(Pointwise、Reduction、FallbackKernel等),从而驱动后续的 Triton/C++ 代码生成。如果decompositions.py是"算子拆解词典",那么lowering.py就是"算子 → IR 编译规则大全"。
一、Lowering 在编译栈中的位置
用户代码: y = torch.relu(x) + 1
│
▼ TorchDynamo 捕获
FX Graph: [aten.relu, aten.add]
│
▼ AOT Autograd + 算子分解
FX Graph: [aten.relu, aten.add] (部分算子已分解)
│
▼ ★ Lowering (lowering.py) ★
Inductor IR: [Pointwise(relu), Pointwise(add)]
│
▼ Scheduler (scheduler.py)
FusedSchedulerNode: [Pointwise(relu + add)] ← kernel 融合
│
▼ Codegen
Triton Kernel: @triton.jit def fused_relu_add(...)
核心职责:为每个 ATen/Prims 算子提供"如何构造 Inductor IR 节点"的规则。这些规则定义了:
- 输出的 shape 如何计算
- 数据的读取方式(
make_loader/inner_fn) - 产出的 IR 类型(
Pointwise、Reduction、FallbackKernel等)
二、文件全局架构
lowering.py (8384 行)
│
├── [1-130] 全局变量 & 注册表
│ lowerings = {} ← 核心字典:op → lowering_fn
│ user_lowerings = {} ← 用户自定义 lowering(优先级最高)
│ fallbacks = OrderedSet() ← 回退到 eager 执行的算子集
│ needs_realized_inputs ← 需要物化输入的算子(如 conv, mm)
│ foreach_ops ← foreach 批量算子集
│
├── [130-550] 注册机制 & 核心辅助
│ register_lowering() ← 装饰器:注册算子 lowering
│ _register_lowering() ← 实际注册逻辑(类型提升+广播+验证)
│ register_pointwise() ← 逐元素算子批量注册
│ make_pointwise() ← 构造 Pointwise IR 的工厂函数
│ make_fallback() ← 注册 fallback 到 eager 执行
│ transform_args() ← 参数预处理(广播+类型提升)
│
├── [550-860] 广播 & 类型提升
│ broadcast_symbolic_shapes() ← 符号形状的广播推导
│ promote_constants() ← 标量 → Tensor 常量提升
│ make_foreach_pointwise() ← foreach 批量逐元素工厂
│
├── [860-1130] 类型转换 lowering
│ _convert_element_type() ← prims.convert_element_type
│ to_dtype() / to_dtype_bitcast() / to_device()
│ _foreach_map() ← foreach_map 高阶算子
│
├── [1130-2340] 视图 & 形状操作 lowering (~35 个)
│ view, reshape, permute, expand, squeeze, unsqueeze
│ slice_, as_strided, select, split, unbind, unfold
│ cat, diagonal, diagonal_scatter
│
├── [2340-2630] Fallback 机制
│ fallback_handler() ← 创建 FallbackKernel 的通用处理器
│ make_fallback() ← 注册 fallback + 断言无冲突分解
│ fallback_node_due_to_unsupported_type() ← 自动检测需要回退的类型
│
├── [2630-2770] 随机数算子
│ philox_rand, native_dropout, bernoulli_, rand, randn
│ inductor_seed / inductor_seeds / inductor_lookup_seed
│
├── [2770-3700] Make_fallback 批量注册 (~124 个)
│ 按难度分六大类:Easy / Medium / Difficult / Backwards
│ Impossible(sort,linalg) / Pattern-matched(SDPA)
│
├── [3700-4400] 张量创建 & 填充
│ full, empty, empty_strided, zeros_like, ones_like
│ tensor, arange, iota, linspace
│ gather, embedding, index, index_put
│
├── [4400-4700] Scatter / Index 操作
│ scatter, scatter_add, scatter_reduce
│ index_put_impl_ (核心 ~200 行)
│
├── [4700-5100] Padding 操作
│ constant_pad_nd (含 pad_as_cat 优化)
│ reflection_pad, replication_pad
│
├── [5100-6500] 池化 & 上采样 lowering
│ avg_pool2d/3d, max_pool2d/3d_with_indices
│ adaptive_avg_pool2d, adaptive_max_pool2d
│ fractional_max_pool2d/3d
│ upsample_nearest2d_backward
│ avg_pool2d_backward, avg_pool3d_backward
│
├── [6500-6900] 归约操作
│ make_reduction() ← 归约 IR 工厂
│ mean, var_mean, sum, prod
│ cumsum, cumprod, cummax, cummin
│
├── [6900-7550] 数学运算 & 逐元素注册
│ pow, div, mul, add, sub ← 手写 lowering
│ register_pointwise(aten.relu) ← 批量注册 (~50 个)
│ register_foreach_pointwise() ← foreach 变体
│
├── [7550-7780] Inplace 变体注册
│ register_inplace() / register_foreach_inplace()
│
├── [7780-8060] 高阶算子 lowering
│ cond, while_loop, invoke_subgraph
│ associative_scan, with_effects
│
└── [8060-8384] 扩展注册 & 外部模块
│ kernel/ 子模块导入 (mm, conv, 等模板化 kernel)
│ quantized_lowerings, mkldnn_lowerings, jagged_lowerings
│ comm_lowering (分布式通信)
│ inline_asm_elementwise, cvt_e8m0_rceil
三、注册机制------三驾马车
3.1 register_lowering:核心装饰器
python
lowerings: dict[Callable | str, Callable] = {} # 全局 lowering 注册表
def register_lowering(
aten_fn,
broadcast=False,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
convert_input_to_bool=False,
):
"""装饰器:将 Python 函数注册为某个 ATen 算子的 lowering 实现"""
return functools.partial(
_register_lowering, aten_fn,
broadcast=broadcast,
type_promotion_kind=type_promotion_kind,
...
)
_register_lowering 内部做三件事:
1. 参数预处理 → transform_args(args, kwargs, broadcast, type_promotion)
├── 类型提升:INT → FLOAT, FP16 → FP32 等
└── 广播对齐:broadcast_tensors(*inputs)
2. 调用 decomp_fn → out = decomp_fn(*args, **kwargs)
3. 输出验证 → validate_ir(out) # 确保产出合法的 IR 节点
3.2 register_pointwise:逐元素算子的批量注册
python
def register_pointwise(
aten_fn, name=None, broadcast=True,
type_promotion_kind=DEFAULT,
override_return_dtype=None,
override_fn_when_input_bool=None,
allow_alpha=False,
...
):
"""一行代码注册一个逐元素算子"""
name = name or aten_fn.__name__
fn = ops_wrapper(name) # → ops.relu, ops.sigmoid, ...
fn = make_pointwise(fn, ...) # → 包装为 Pointwise IR 构造器
fn = register_lowering(aten_fn, ...)(fn) # → 注册到 lowerings 字典
return fn
# 使用示例------一行注册一个算子:
relu = register_pointwise(aten.relu)
sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
exp = register_pointwise_numeric_ldf64(aten.exp)
sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
abs = register_pointwise(aten.abs)
neg = register_pointwise(aten.neg)
add = register_pointwise(aten.add, allow_alpha=True, use_fma_for_alpha=True)
sub = register_pointwise(aten.sub, allow_alpha=True)
mul = ... # 手写(含常量优化)
文件中约有 50+ 个算子 通过 register_pointwise 一行注册。
3.3 make_fallback:不能编译的算子回退到 eager
python
def make_fallback(op, layout_constraint=None, warn=True):
"""将算子注册为 FallbackKernel ------ 在运行时直接调用 ATen eager 实现"""
# 1. 检查:如果存在分解,优先用分解,不要 fallback
assert op not in decompositions or override_decomp
# 2. CI 中强制报错防止遗漏
if warn and bool(os.getenv("CI")) and get_decompositions([op]):
raise AssertionError("a decomposition exists, we should switch to it")
# 3. 注册
add_needs_realized_inputs(op) # fallback 前必须物化所有输入
add_layout_constraint(op, ...) # 约束输入的内存布局
register_lowering(op)(fallback_handler(op))
文件中共有 124 个 make_fallback 调用,按难度分为六档:
| 级别 | 说明 | 代表算子 |
|---|---|---|
| Easy | 有分解但暂未迁移 | uniform, exponential, soft_margin_loss_backward |
| Medium | 需要新 IR 或较复杂实现 | _trilinear |
| Difficult | 需要新 IR 类型 | segment_reduce, histc, histogram |
| Backwards | 前向有实现但反向未实现 | _adaptive_avg_pool2d_backward, grid_sampler_2d_backward |
| Impossible | 缺少 Triton/CPU 特性 | sort, topk, kthvalue, 全部 linalg.* |
| Pattern-matched | 由模式匹配器处理 | 全部 _scaled_dot_product_*_attention* |
四、make_pointwise------逐元素 IR 构造的核心
这是理解整个 lowering 系统的关键函数(~100 行):
python
def make_pointwise(fn, override_return_dtype=None, ...):
"""将一个标量函数 fn 包装为 Pointwise IR 构造器"""
def inner(*inputs: TensorBox, alpha=None):
# 1. 常量提升:标量 → 常量 Tensor
inputs = promote_constants(inputs, override_return_dtype)
# 2. alpha 处理:a + alpha * b
if allow_alpha and alpha is not None and alpha != 1:
if use_fma_for_alpha and is_cuda_float:
return _add_with_alpha_fma(inputs[0], inputs[1], alpha)
inputs[-1] = mul(inputs[-1], alpha)
# 3. 创建加载器
loaders = [x.make_loader() for x in inputs]
ranges = inputs[0].get_size()
dtype = override_return_dtype or inputs[0].get_dtype()
# 4. 定义逐元素计算函数
def inner_fn(index):
return fn(*[load(index) for load in loaders])
# 5. 构造 Pointwise IR 节点
return Pointwise.create(
device=device,
dtype=dtype,
inner_fn=inner_fn, # ← 这个闭包就是将来 codegen 的核心
ranges=ranges,
)
return inner
核心思想 :inner_fn 是一个接受 index 返回标量表达式的闭包。它定义了"给定输出位置 (i,j,k,...),如何计算该位置的值"。Inductor 的 scheduler 和 codegen 会遍历输出的每个索引,调用 inner_fn 来生成计算代码。
五、精度仿真机制
make_pointwise 中嵌入了一个精妙的低精度仿真系统:
python
# 在 V.graph.current_node.meta 中标记了 low_precision_pointwise_barrier
emulate_precision_casts = V.graph.current_node.meta.get(
"low_precision_pointwise_barrier", False
)
def inner_fn(index):
inputs_loaded = []
for inp_index, load in enumerate(loaders):
out = load(index)
inp_dtype = inputs[inp_index].get_dtype()
if emulate_precision_casts and inp_dtype in (torch.float16, torch.bfloat16):
# 1. 先降到低精度(模拟从内存读取)
downcast = ops.to_dtype(out, inp_dtype, use_compute_types=False)
# 2. 再升到计算精度(模拟 GPU 寄存器提升)
out = ops.to_dtype(downcast, inp_dtype)
inputs_loaded.append(out)
out = fn(*inputs_loaded)
if emulate_output_cast:
# 输出也做同样的降→升,模拟回写内存再读取的精度损失
downcast = ops.to_dtype(out, dtype, use_compute_types=False)
return ops.to_dtype(downcast, dtype)
return out
这保证了编译后的 fused kernel 与非融合版本(每步都回写内存)有一致的数值行为。
六、视图操作的 lowering 机制
视图操作不产生计算,只产生 IR 中的"视图节点":
6.1 直通视图(零开销)
python
@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
def nop(x):
return x # 什么都不做,直接返回
6.2 形状变换视图
python
@register_lowering(aten.view, type_promotion_kind=None)
@register_lowering(aten.reshape, type_promotion_kind=None)
@register_lowering(aten._unsafe_view, type_promotion_kind=None)
def view(x, sizes):
return TensorBox(View.create(x.data, sizes)) # 创建 View IR
@register_lowering(aten.permute, type_promotion_kind=None)
def permute(x, dims):
return TensorBox(PermuteView.create(x.data, tuple(dims)))
@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
...
return TensorBox(ExpandView.create(x.data, sizes))
@register_lowering(aten.unsqueeze, type_promotion_kind=None)
def unsqueeze(x, dim):
return TensorBox(SqueezeView.create(x.data, dim=dim))
6.3 as_strided---最底层的视图
python
@register_lowering(aten.as_strided, type_promotion_kind=None)
def as_strided(x, size, stride, storage_offset=None):
# 如果 x 已经是视图,先解包到底层 storage
if isinstance(x.data, ir.BaseView):
new_dtype = x.dtype
x = x.data.unwrap_view()
x.realize() # 物化底层 buffer
storage, old_layout = ir.as_storage_and_layout(x)
# 构造新的固定布局
new_layout = ir.FixedLayout(
device, dtype, size, stride, storage_offset or 0
)
return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout))
6.4 slice --- 包含动态尺寸处理
slice_ 是文件中最长的视图 lowering(~130 行),需要处理动态 unbacked 符号的情况:
静态/有后端的符号 → 直接创建 SliceView (快速路径)
│
动态 unbacked 符号 → 创建 DynamicSliceSize + DynamicSelectStorageOffset
(ExternKernel,在运行时计算 size 和 offset)
七、Cat 操作------Inductor 的两种实现
cat 是文件中第二长的单个 lowering(~170 行),有两条截然不同的路径:
7.1 Pointwise Cat(GPU 优化路径)
python
def pointwise_cat(inputs, dim=0):
"""将 cat 实现为一个 Pointwise 节点,通过条件分支选择加载源"""
def inner_fn(idx):
idx_dim = ops.index_expr(idx[dim], torch.int64)
# 对每个输入张量生成 mask 和 masked_load
for i in range(len(inputs)):
mask = ops.and_(ops.ge(idx_dim, start), ops.lt(idx_dim, end))
masked_loads.append(ops.masked(mask, lambda: loaders[i](...), 0.0))
# 链式 where 选择正确的源
return ops.where(masks[0], masked_loads[0],
ops.where(masks[1], masked_loads[1], ...))
优势:cat 与下游计算可以被 融合进同一个 kernel,避免额外的内存读写。
7.2 ConcatKernel(通用路径)
python
@register_lowering(aten.cat)
def cat(inputs, dim=0):
# 决策:用 pointwise 还是 ConcatKernel
if should_use_pointwise_cat(inputs, dim):
return pointwise_cat(inputs, dim)
else:
# ConcatKernel:分配输出 buffer,逐个 copy 输入
return TensorBox(ir.ConcatKernel.create(inputs, dim))
八、索引操作------Index / Scatter 体系
8.1 index(高级索引读取)
python
@register_lowering(aten.index, type_promotion_kind=None)
def index(x, indices):
try:
return index_impl(x, indices, check=True)
except NotImplementedError:
# 布尔索引:回退到 ATen
return fallback_handler(aten.index.Tensor)(x, indices)
index_impl 的核心是构造一个 inner_fn,通过 ops.indirect_indexing 实现间接寻址:
python
def fn(idx):
gather_idx = ops.indirect_indexing(index_loader(idx), size[dim])
idx[dim] = gather_idx
return x_loader(idx)
ops.indirect_indexing 在 Triton codegen 中会生成 tl.load(ptr + gather_idx * stride)。
8.2 index_put_impl_ ------ 最复杂的 lowering 之一
index_put_(self, indices, values, accumulate)
│
├── 单个布尔索引 + 标量 value → index_put_as_masked_fill
│ (优化为 where 操作)
│
├── 可以用 scatter → ir.ScatterKernel
│ (accumulate=True 时用 atomic_add)
│
└── 复杂情况 → index_put_fallback
(调用 aten.index_put_ 的 eager 实现)
8.3 scatter 系列
python
@register_lowering(aten.scatter, type_promotion_kind=None)
def scatter(x, dim, index, src_or_value, **kwargs):
return scatter_fallback(..., op=aten.scatter_.value, ...)
@register_lowering(aten.scatter_add, type_promotion_kind=None)
def scatter_add(x, dim, index, src):
return scatter_fallback(
..., reduce="sum",
op=aten.scatter_add_, ...
)
特殊情况下 scatter 会激活 atomic 操作的限制检测:
python
if needs_fallback_due_to_atomic_add_limitations(dtype):
# 某些 dtype 没有 Triton 原子操作支持
return fallback_handler(op)(...)
九、归约操作的 lowering
9.1 make_reduction 工厂
python
def make_reduction(reduction_type, override_return_dtype=None):
def inner(x, axis=None, keepdims=False, *, dtype=None):
kwargs = _make_reduction_inner(x, axis=axis, keepdims=keepdims, ...)
result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
if isinstance(result.data.data, Reduction):
result.realize() # 非展开的归约需要物化
return result
return inner
# 注册归约算子
reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
reduce_argmax = register_lowering(aten.argmax)(make_reduction("argmax", torch.int64))
reduce_argmin = register_lowering(aten.argmin)(make_reduction("argmin", torch.int64))
9.2 _make_reduction_inner 的轴处理
python
def _make_reduction_inner(x, *, axis, keepdims, dtype, ...):
axis = _validate_reduction_axis(x, axis)
# 将维度分为"保留维"和"归约维"
kept_sizes, reduced_sizes = [], []
for i in range(len(size)):
if i in axis:
reduced_sizes.append(size[i])
else:
kept_sizes.append(size[i])
def loader(index, reduction_index):
# 将 kept_index 和 reduction_index 重新组装为原始索引
new_index[kept_idx] = index
new_index[reduced_idx] = reduction_index
return inner_loader(new_index)
return dict(
inner_fn=loader,
ranges=new_size, # 输出形状(保留维)
reduction_ranges=reduced_sizes, # 归约维度(将被遍历求和/求max等)
)
9.3 mean 的 lowering------展示组合模式
python
@register_lowering(aten.mean)
def mean(x, axis=None, keepdim=False, *, dtype=None):
if output_dtype in (torch.float16, torch.bfloat16):
x = to_dtype(x, torch.float) # 高精度计算
sum_result = sum_(x, axis, keepdim)
denom = sympy_product(size[i] for i in axis) # 符号化除数
denom = ir.IndexingConstant(index=denom, ...)
denom = ExpandView.create(denom, list(sum_result.get_size()))
return to_dtype(div(sum_result, denom), output_dtype)
9.4 累积操作------scan vs fallback
python
@register_lowering(aten.cumsum)
def cumsum(x, axis=None, dtype=None):
if use_triton_scan(x):
# GPU:使用 Triton 的 associative_scan
kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
result = ir.Scan.create(combine_fn=..., **kwargs)
else:
# CPU 或不支持 scan:回退到 eager
return fallback_cumsum(x, axis, dtype=dtype)
9.5 var_mean 的两种策略
python
def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
if use_two_step_variance(x, axis, keepdim):
# 小归约维度:两步法(先 mean,再 sum((x-mean)^2))
return var_mean_sum_(x, axis, correction, keepdim, return_mean)
else:
# 大归约维度:Welford 在线算法(数值更稳定)
return var_mean_welford_(x, axis, correction=correction, ...)
十、池化操作------Inductor 的模板模式
池化是 lowering.py 中最长的代码区域(~1400 行),因为每种池化都需要精确的索引计算。
10.1 avg_pool2d 的计算结构
python
def _avg_poolnd(x, kernel_size, stride, padding, ceil_mode, count_include_pad, ...):
# 1. 计算输出尺寸
h_out, ceil_mode_h = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
w_out, ceil_mode_w = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
# 2. 边界条件处理器
if any(padding) or any(ceil_mode):
x_loader = constant_boundary_condition(x, 0.0, dim=2)
else:
x_loader = x.make_loader()
# 3. 定义归约内部函数
def fn(idx, reduction_idx):
*prefix, bh, bw = idx
rh, rw = reduction_idx
ih = bh * stride[0] + rh - padding[0]
iw = bw * stride[1] + rw - padding[1]
return x_loader([*prefix, ih, iw])
# 4. 创建 sum 归约
result = Reduction.create(
reduction_type="sum", inner_fn=fn,
ranges=new_size, # [N, C, H_out, W_out]
reduction_ranges=kernel_size, # [kH, kW]
)
# 5. 除以池化窗口大小
if divisor_override:
return div(result, divisor_override)
elif count_include_pad:
return div(result, kernel_size[0] * kernel_size[1])
else:
# 不含 padding 的实际窗口大小(逐位置不同)
return div(result, compute_pool_size_without_padding(...))
10.2 adaptive_max_pool2d 的窗口限制
python
@register_lowering(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(x, output_size):
window_size = h_kernel_max * w_kernel_max
if window_size > 25:
# 窗口太大 → Triton 展开后代码难以优化
return fallback_adaptive_max_pool2d(x, output_size)
# 否则:展开为 Pointwise
...
Inductor 对池化窗口大小有硬限制------超过 25 就回退,因为 Triton 展开大循环后产生的 IR 节点过多,编译时间和代码质量都会劣化。
十一、数学运算------特殊优化
11.1 add 的 FMA 优化
python
add = register_pointwise(
aten.add, allow_alpha=True, use_fma_for_alpha=True
)
当 a + alpha * b 在 CUDA 浮点上调用时,Inductor 自动使用 FMA(Fused Multiply-Add)指令:
python
def _add_with_alpha_fma(a, b, alpha):
def inner_fn(idx):
a_val = a_loader(idx)
b_val = b_loader(idx)
alpha_expr = ops.constant(alpha, dtype)
return ops.fma(b_val, alpha_expr, a_val) # b*alpha + a
return Pointwise.create(...)
11.2 pow 的递归展开
python
def pow_recursive(x, y, dtype):
"""将整数幂展开为乘法链,避免调用 pow 函数"""
if y < 0:
return ops.reciprocal(pow_recursive(x, -y, dtype))
if y == 0:
return ops.constant(1, dtype)
if y == 1:
return x
result = pow_recursive(x, y // 2, dtype)
result = ops.mul(result, result)
if y % 2 == 1:
result = ops.mul(result, x)
return result
对于 x ** 3,不会调用 pow(x, 3),而是展开为 x * x * x。
11.3 mul 的常量折叠
python
@register_lowering(aten.mul)
def mul(a, b):
# 检测 x * 1 = x 和 x * 0 = 0
if get_constant_value(a) == 1:
return b
if get_constant_value(b) == 1:
return a
# ... 0 的情况类似
return _mul(a, b)
11.4 div 的特殊处理
python
@register_lowering(aten.div.Tensor)
def div(a, b):
# 整数除法的特殊语义
both_integer = is_integer_type(a) and is_integer_type(b)
if both_integer:
return truncdiv(a, b) # 向零截断(而非向下取整)
return div_prim(a, b)
十二、Foreach 批量操作
Foreach 操作是 PyTorch 优化器的核心------一次 kernel 调用处理整个参数列表:
python
def make_foreach_pointwise(pw_fn, allow_alpha=False, scalar_kwarg="alpha"):
def inner(*inputs, alpha=1, ...):
# 1. 按 (device, is_dynamic) 分组
groups = group_foreach_args(zip(*broadcast_inputs))
# 2. 对每组的每个张量应用 pointwise 函数
for (device, use_foreach), group in groups.items():
for output_ind, args in group:
output = pw_fn(*args, alpha=scalar_val)
if use_foreach:
output.realize()
operation_list.append(output.get_operation_name())
# 3. 注册操作列表,告诉 scheduler 可以水平融合
V.graph.register_operation_list(operation_list)
return inner
# 注册示例
foreach_add_list = register_foreach_pointwise(aten._foreach_add.List, add)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
register_operation_list 告诉 scheduler:这些操作可以打包到同一个 Triton kernel 中。在 Adam 这样的优化器中,这意味着一次 kernel launch 更新所有参数。
十三、Inplace 操作的处理
python
def register_inplace(aten_op, outplace_op):
"""将 inplace 操作转为 out-of-place + mutate_to"""
@register_lowering(aten_op, type_promotion_kind=None)
def fn(*args, **kwargs):
result = outplace_op(*args[1:], **kwargs)
# first arg is self
return mutate_to(args[0], result)
return fn
def mutate_to(changed, val, unsafe_alias=False):
"""核心的 mutation 机制------替换底层 data"""
if isinstance(changed, TensorBox):
changed_data = changed.data
else:
changed_data = changed
# 替换底层数据指针
if isinstance(val, TensorBox):
val = val.data
...
changed_data.data = val # 底层发生了 mutation
return changed
所有 inplace 操作(relu_, add_, mul_ 等)都走这条路径。
十四、Layout 约束系统
某些算子对输入的内存布局有严格要求:
python
# 约束函数定义
require_contiguous = constrain_to_fx_strides # 需要连续布局
require_dense = constrain_to_fx_strides # 需要密集布局
# 使用示例
make_fallback(aten._cudnn_rnn, require_dense)
make_fallback(aten._cdist_forward) # 无约束
# Tag 系统:从算子元数据自动推导
def tag_to_layout_constraint(tag):
if tag == Tag.needs_exact_strides:
return constrain_to_fake_tensors # 必须匹配 fake tensor 的精确 stride
if tag == Tag.needs_contiguous_strides:
return require_contiguous_strides # 必须行优先连续
if tag == Tag.needs_fixed_stride_order:
return constrain_to_fx_strides # 必须匹配 FX 图中的 stride 顺序
if tag == Tag.flexible_layout:
return None # 无约束
十五、高阶算子的 lowering
15.1 with_effects(带副作用的算子)
python
@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None)
def with_effects(token, op, *args, **kwargs):
"""处理带副作用的算子(如 RNG、IO)"""
# 1. 获取副作用类型
effect_type = _get_effect(op)
# 2. 正常 lower 内部算子
operation_len = len(V.graph.operations)
result = lowerings[op](*args, **kwargs)
# 3. 为新创建的 IR 节点添加 StarDep(顺序依赖)
prev_effect_buffer = V.graph.effectful_ops.get(effect_type)
for new_op in V.graph.operations[operation_len:]:
new_op.has_side_effects = lambda: True
if prev_effect_buffer:
V.graph.additional_star_deps[op_name].add(prev_effect_buffer.get_name())
return (token, result)
15.2 associative_scan(关联扫描)
python
@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn, xs, additional_inputs):
# 1. 将子图 lowering 为 pointwise
lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs)
# 2. 创建 Scan IR
result = ir.Scan.create(combine_fn=wrapped_combine_fn, ...)
十六、扩展性------外部模块注册
文件末尾通过导入子模块来注册更多 lowering:
python
from . import kernel # mm, conv, bmm 等模板化 kernel
import_submodule(kernel) # 递归导入 kernel/ 下所有子模块
from . import quantized_lowerings
quantized_lowerings.register_quantized_ops() # 量化算子
quantized_lowerings.register_woq_mm_ops() # Weight-Only 量化矩阵乘
from . import mkldnn_lowerings
mkldnn_lowerings.register_onednn_fusion_ops() # Intel oneDNN 融合算子
from . import jagged_lowerings
jagged_lowerings.register_jagged_ops() # Jagged Tensor 算子
from .comm_lowering import register_comm_lowerings
register_comm_lowerings() # 分布式通信算子
这意味着实际可用的 lowering 远不止本文件中的 ------kernel/ 目录下有 mm、conv、bmm 等核心计算密集型算子的模板化实现。
十七、数据流全景图
FX Graph (ATen ops)
│
▼ GraphLowering.run()
遍历 FX 图每个节点:
│
├── node.target in lowerings?
│ ├── YES → 调用 lowerings[target](*args)
│ │ ├── Pointwise 算子 → Pointwise.create(inner_fn=...)
│ │ ├── 视图操作 → View/PermuteView/ExpandView/...
│ │ ├── 归约操作 → Reduction.create(reduction_type=...)
│ │ ├── Index/Scatter → ops.indirect_indexing(...)
│ │ └── 池化/卷积 → 组合 Reduction + Pointwise
│ │
│ └── NO → fallback_node_due_to_unsupported_type?
│ ├── YES → FallbackKernel.create(target, *args)
│ └── NO → RuntimeError("missing lowering")
│
▼
Inductor IR Graph
│
▼ scheduler.py
Fused Kernel Groups → Triton/C++ Codegen
十八、统计概览
| 维度 | 数据 |
|---|---|
| 文件总行数 | 8384 |
@register_lowering 装饰器 |
168 个 |
make_fallback 调用 |
124 个 |
register_pointwise 调用 |
~50 个 |
register_foreach_pointwise |
~15 个 |
register_inplace |
~20 个 |
| 外部模块注册 | 5 个(kernel, quantized, mkldnn, jagged, comm) |
| 实际覆盖 ATen 算子数 | ~400+ |
十九、阅读路线建议
Level 1:理解注册机制
→ register_lowering (528行) --- 装饰器
→ _register_lowering (475行) --- 实际逻辑
→ register_pointwise (994行) --- 批量注册
Level 2:理解 Pointwise IR 构造
→ make_pointwise (668行) --- 核心工厂
→ 选 relu 或 add 作为入门示例
Level 3:理解视图 lowering
→ view/reshape (1376行) --- 最简
→ slice_ (1390行) --- 含动态符号处理
→ as_strided (1518行) --- 最底层
Level 4:理解归约 lowering
→ make_reduction (6548行) --- 工厂模式
→ mean (6642行) --- 展示组合模式
→ var_mean (6751行) --- 两种算法策略
Level 5:理解复杂 lowering
→ cat (1941行) --- Pointwise vs ConcatKernel 决策
→ index_put_impl_ --- 最复杂的 mutation 操作
→ avg_pool2d (5920行) --- 典型的池化归约模式
Level 6:理解系统边界
→ make_fallback (2462行) --- 何时放弃编译
→ fallback_handler (2348行) --- FallbackKernel 如何工作
→ with_effects (8109行) --- 副作用处理
二十、总结
lowering.py 的 8384 行代码定义了 Inductor 编译器的算子覆盖边界------它决定了哪些 ATen 算子可以被编译为高效 Triton/C++ kernel,哪些必须回退到 eager 执行。
三个层面的设计哲学:
- 分层抽象 :
register_pointwise(一行注册)→make_pointwise(IR 构造工厂)→Pointwise.create(IR 节点)→ Triton codegen(最终代码) - 渐进覆盖 :优先实现高频算子的原生 lowering,其余通过
make_fallback兜底,并在 CI 中持续推进分解/lowering 覆盖 - 性能导向:FMA 优化(add+alpha)、pow 递归展开、cat→pointwise 融合、池化窗口限制------每个细节都服务于生成更快的 kernel
这个文件就像 Inductor 的"翻译词典"------左边是 ATen 算子语义,右边是 Inductor IR 的构造方法。掌握它,就掌握了 PyTorch 编译器"如何将算子变成代码"的关键一步。