PyTorch学习笔记(16):scheduler.py

本文聚焦 torch/_inductor/scheduler.py------一个近 8000 行的核心文件。它是 Inductor 编译器后端的"大脑",负责将 IR 节点组织为可执行的 kernel 序列,并通过激进的**算子融合(Operator Fusion)**策略最大化 GPU/CPU 吞吐量。理解这个文件,就掌握了 PyTorch 编译器从"计算图"到"可执行代码"的关键转化环节。

一、Scheduler 在 Inductor 编译管线中的位置

复制代码
torch.compile(fn)
    │
    ▼ TorchDynamo
FX Graph (torch 级算子)
    │
    ▼ AOT Autograd
FX Graph (ATen 级算子, decomposed)
    │
    ▼ Inductor Lowering (lowering.py)
Inductor IR 节点 (ComputedBuffer, TemplateBuffer, ExternKernel...)
    │
    ▼ ★ Scheduler (scheduler.py) ★
    │   ├─ 依赖分析 (compute_dependencies)
    │   ├─ 拓扑排序 (topological_sort)
    │   ├─ 死代码消除 (dead_node_elimination)
    │   ├─ Kernel 融合 (fuse_nodes)
    │   ├─ 内存优化 (reorder_for_peak_memory)
    │   ├─ 通信重叠 (reorder_compute_and_comm_for_overlap)
    │   ├─ 图分区 (graph_partition → CUDAGraph)
    │   └─ 代码生成调度 (codegen)
    │
    ▼ Triton/C++/CUDA 源代码
    │
    ▼ 编译为可执行 kernel

核心定位 :Scheduler 接收 Inductor IR 的 ir.Operation 列表,经过一系列优化 Pass 后,调度各后端(Triton/C++/CUDA)生成最终代码。它是编译器中端与后端的交汇处


二、文件整体结构鸟瞰

scheduler.py 共 7999 行,包含 16 个类 和大量辅助函数:

复制代码
scheduler.py (7999 行)
│
├── [1-109]       import + 全局变量
├── [110-425]     辅助类:FusionResult, PendingFusion, MixOrderReduction
├── [427-560]     SchedulerBuffer / SchedulerDonatedBuffer  ← 缓冲区节点
├── [539-1413]    BaseSchedulerNode                         ← 调度节点基类(核心)
├── [1414-1505]   WhyNoFuse / OutputNode                    ← 融合诊断 + 输出哨兵
├── [1506-1937]   ExternKernelSchedulerNode / NopKernelSchedulerNode / SchedulerNode
│                                                            ← 三种基础节点
├── [1938-2962]   FusedSchedulerNode 及其子类                ← 融合节点家族
│                 ├─ FusedMixOrderReductions
│                 ├─ FusedExternTritonKernelSchedulerNode
│                 ├─ ForeachKernelSchedulerNode
│                 └─ GroupedSchedulerNode
├── [2963-3077]   NodeUser                                   ← 依赖边
├── [3078-7814]   Scheduler                                  ← 主调度器(3700+ 行)
└── [7815-7999]   BaseScheduling                             ← 后端调度接口

类层次结构

复制代码
BaseSchedulerNode
├── SchedulerNode               ← 可计算节点(Pointwise / Reduction / Template)
├── ExternKernelSchedulerNode   ← 外部 kernel(cuBLAS / 自定义 Triton)
├── NopKernelSchedulerNode      ← 空操作(视图变换等)
└── FusedSchedulerNode          ← 融合节点(包含多个子节点)
    ├── FusedMixOrderReductions ← 混合维度 reduction 融合
    ├── FusedExternTritonKernelSchedulerNode  ← 用户 Triton kernel 的 epilogue 融合
    ├── ForeachKernelSchedulerNode            ← foreach / combo kernel
    └── GroupedSchedulerNode                  ← 分组节点(通信重叠)

三、核心数据结构

3.1 SchedulerBuffer --- 缓冲区抽象

python 复制代码
@dataclasses.dataclass
class SchedulerBuffer:
    scheduler: Scheduler
    node: ir.Buffer              # 底层 IR 缓冲区
    defining_op: BaseSchedulerNode | None  # 哪个操作产生了这个 buffer
    users: list[NodeUser]        # 所有消费者
    mpi_buffer: MemoryPlanningInfoForBuffer  # 内存规划信息

    def allocate(self) -> None: ...    # 在 wrapper code 中生成分配代码
    def can_free(self) -> bool: ...    # 是否可以释放
    def set_users(self, users) -> None: ...  # 设置消费者(去重)

每个 ir.Buffer 对应一个 SchedulerBuffer,通过 Scheduler.name_to_buf 字典索引。

3.2 BaseSchedulerNode --- 调度节点基类

python 复制代码
class BaseSchedulerNode:
    # 核心字段
    node: ir.Operation | None        # 底层 IR 操作
    outputs: list[SchedulerBuffer]   # 输出缓冲区列表
    read_writes: dependencies.ReadWrites  # 读写依赖
    unmet_dependencies: OrderedSet[Dep]   # 未满足的依赖
    ancestors: OrderedSet[str]       # 祖先节点集合
    last_usage: OrderedSet[str]      # 本节点后不再需要的 buffer
    group: tuple[device, tuple[tuple[sympy.Expr, ...], ...]]  # 分组信息
    min_order: int                   # 拓扑排序中的位置
    max_order: int

核心方法:

方法 功能
decide_inplace_update() 决定是否原地更新(buffer reuse)
get_estimated_runtime() 估算运行时间(用于融合决策)
get_read_write_buffers_sizes() 计算读写数据量(用于融合评分)
estimate_flops() 估算浮点运算量
prune_deps() 裁剪冗余依赖

3.3 NodeUser --- 依赖边

python 复制代码
@dataclasses.dataclass
class NodeUser:
    node: BaseSchedulerNode | OutputNode
    can_inplace: bool = False    # 消费者是否可以原地复用 buffer
    is_weak: bool = False        # 弱依赖(仅要求顺序,不使用数据)

弱依赖(WeakDep)是一个重要概念------它表示 mutation ordering 约束,但不构成真正的数据依赖,在融合时可以被安全移除。

3.4 FusionResult 与 PendingFusion

python 复制代码
@dataclasses.dataclass
class FusionResult:
    should_fuse: bool | None = None    # 直接决策
    callable_fn: Callable | None = None  # 延迟决策(异步 benchmark)
    future: LambdaFuture | None = None   # 异步编译结果

@dataclasses.dataclass
class PendingFusion:
    callable_fn: Callable[[], bool]   # 融合判定回调
    node1: BaseSchedulerNode
    node2: BaseSchedulerNode
    future: LambdaFuture | None = None

这两个类支持异步融合决策 ------当 config.benchmark_kernel=True 时,Scheduler 先异步编译融合和未融合两种版本,然后通过 benchmark 决定是否采纳融合。


四、Scheduler 主类:初始化流水线

Scheduler.__init__() 是整个调度过程的入口,它执行了一条 12 步优化管线

python 复制代码
def _init(self, nodes: list[ir.Operation]) -> None:
    # 1. 创建调度节点
    self.nodes = [self.create_scheduler_node(n) for n in nodes]

    # 2. 裁剪无效依赖
    for node in self.nodes:
        node.prune_deps()

    # 3. 构建 name → node/buffer 索引
    self.name_to_node = {n.get_name(): n for n in self.nodes}
    self.name_to_buf = {buf.get_name(): buf ...}

    # 4. 通信排序(分布式训练)
    self.nodes = comms.decide_global_ordering_of_comms(...)

    # 5. 依赖分析
    self.compute_dependencies()

    # 6. 拓扑排序
    self.nodes = self.topological_sort_schedule(self.nodes)

    # 7. 死代码消除
    self.dead_node_elimination()

    # 8. 计算祖先 + 输入距离
    self.compute_ancestors()
    self.compute_input_distances()

    # 9. 流分配(CUDA 多流调度)
    self._populate_stream_assignments()

    # 10. ★ Kernel 融合 ★ (核心优化 Pass)
    self.nodes = self.fuse_nodes(self.nodes)

    # 11. 内存优化 + 通信重叠
    if config.reorder_for_peak_memory:
        self.nodes = reorder_for_peak_memory(...)
    if config.reorder_for_compute_comm_overlap:
        self.nodes = comms.reorder_compute_and_comm_for_overlap(...)

    # 12. 图分区(CUDAGraph 兼容性)
    if config.triton.reorder_for_reducing_graph_partitions:
        self.nodes = self.maybe_reorder_for_minimizing_partition(...)

    # 后处理
    self.compute_last_usage()

五、依赖分析:compute_dependencies()

这是调度器最复杂的方法之一(约 300 行),构建完整的数据依赖图。它处理四类依赖:

5.1 普通读写依赖

python 复制代码
for read in node.read_writes.reads:
    if not isinstance(read, WeakDep):
        add_user(read.name, node, node.can_inplace(read))

节点读取某个 buffer → 该 buffer 的生产者是此节点的依赖。

5.2 Mutation 依赖

python 复制代码
for alt_name in buf.get_mutations():
    # 此节点必须在之前的写入者之后执行
    add_user(alt_name, node)
    # 此节点必须在所有之前的读取者之后执行
    for user in name_to_users[alt_name].items:
        node.add_fake_dep(WeakDep(other_name, mutating_buf=buf.get_name()))

当节点 A mutation 了 buffer X,则 A 必须在 X 的所有之前使用者之后执行。

5.3 Alias(别名)依赖

python 复制代码
for buf2_name in buf1.get_aliases():
    # 合并 alias buffer 的用户列表(Python 对象别名)
    name_to_users[buf1_name] = name_to_users[buf2_name]

如果两个 buffer 是 alias 关系,它们共享同一个用户列表------对一个的 mutation 自动影响另一个。

5.4 Unbacked SymInt 依赖

python 复制代码
# 动态 shape 中的符号整数可能由某个节点定义
for s in unbacked_symbol_uses:
    if (r := unbacked_symbol_to_origin_node[s]) is not None:
        node.add_fake_dep(StarDep(buf.get_name()))

处理动态 shape 时,符号整数的定义-使用关系也需要作为依赖边。


六、Kernel 融合:核心优化 Pass

Kernel 融合是 Scheduler 最重要的优化------它将多个独立的 IR 节点合并为一个 FusedSchedulerNode,在代码生成时产出单个 kernel,从而减少 GPU kernel launch 开销和内存往返。

6.1 融合流程总览

python 复制代码
def fuse_nodes(self, nodes):
    """多轮融合直到收敛"""
    for _ in range(10):  # 最多 10 轮
        new_nodes = self.fuse_nodes_once(nodes)
        if len(new_nodes) == len(nodes):
            break  # 无新融合,收敛
        nodes = new_nodes
    return nodes

每轮 fuse_nodes_once() 的流程:

复制代码
所有节点
    │
    ▼ get_possible_fusions()
收集所有合法融合候选对
    │
    ▼ score_fusion_key() 排序
按融合收益排序(高分优先)
    │
    ▼ _try_fusion_pairs()
逐对检查:
    ├─ can_fuse() → 合法性检查
    ├─ will_fusion_create_cycle() → 环检测
    ├─ can_fusion_increase_peak_memory() → 峰值内存检查
    └─ speedup_by_fusion() → benchmark 验证(可选)
    │
    ▼ FusedSchedulerNode.fuse(node1, node2)
执行融合
    │
    ▼ topological_sort_schedule()
重新拓扑排序

6.2 can_fuse() --- 融合合法性判断

can_fuse() 是一个约 250 行的方法,检查两个节点是否可以被融合。核心判断逻辑:

条件 说明
不是同一节点 node1 is not node2
同一 CUDA 流 多流调度时禁止跨流融合
无祖先冲突 node2.get_operation_names() & node1.ancestors == ∅
非 GroupedNode 分组节点不可二次融合
非 NopKernel 空操作不参与融合
设备匹配 同一 GPU/CPU 设备
分组兼容 (numel, rnumel) 分组匹配
后端允许 backend.can_fuse_vertical() / can_fuse_horizontal()

融合方向有两种:

  • 垂直融合(Vertical):node1 的输出被 node2 消费------生产者-消费者关系
  • 水平融合(Horizontal):node1 和 node2 无依赖,但在相同循环范围上操作

6.3 score_fusion_memory() --- 融合评分

python 复制代码
def score_fusion_memory(self, node1, node2) -> int:
    """估算融合可节省的内存操作数"""
    # 策略 1:精确匹配------找到相同 buffer + 相同索引的 dep
    shared_deps = node1.read_writes.reads & node2.read_writes.reads
    score = sum(dep_size_hint(dep) for dep in shared_deps)

    # 策略 2:Buffer 重叠------同 buffer 不同索引(如 split 操作的不同 slice)
    if score == 0:
        score = self._score_fusion_memory_by_buffer_overlap(...)

    return score

分数越高,表示融合节省越多的内存读写------这是融合排序的核心指标。

6.4 will_fusion_create_cycle() --- 环检测

融合 A 和 B 后,如果从 A∪B 的某个依赖出发能回到 A∪B,就会形成环:

python 复制代码
def will_fusion_create_cycle(self, node1, node2) -> bool:
    """BFS 检测融合是否会创建依赖环"""
    combined = node1.get_operation_names() | node2.get_operation_names()
    # 从 node2 的消费者出发 BFS,看能否到达 node1
    def found_path(node):
        for user in node.users:
            if user.get_name() in combined:
                return True  # 环!
            if found_path(user.node):
                return True
        return False
    return found_path(node2)

6.5 speedup_by_fusion() --- Benchmark 验证

config.benchmark_kernel=True 时,Scheduler 不是盲目融合,而是实际 benchmark

python 复制代码
def speedup_by_fusion(self, node1, node2) -> FusionResult:
    # 1. 生成融合版 kernel 代码
    fused_code = generate_kernel_code([node1, node2])
    # 2. 生成独立版 kernel 代码
    code1 = generate_kernel_code([node1])
    code2 = generate_kernel_code([node2])
    # 3. 异步编译两个版本
    fused_future = async_compile(fused_code)
    separate_futures = [async_compile(code1), async_compile(code2)]
    # 4. 实际 benchmark
    ms_fused = benchmark(fused_future)
    ms_separate = benchmark(separate_futures[0]) + benchmark(separate_futures[1])
    # 5. 融合版更快才采纳
    return FusionResult.fuse(ms_fused < ms_separate)

七、特殊融合策略

7.1 Template 融合(Epilogue / Prologue)

Template 节点(如 CUTLASS GEMM)可以与 Pointwise 节点融合:

复制代码
Epilogue 融合:GEMM → ReLU  →  GEMM+ReLU(合并为一个 kernel)
Prologue 融合:Scale → GEMM  →  Scale+GEMM(输入预处理融入 kernel)
python 复制代码
def is_epilogue_fusion(node1, node2):
    return node1.is_template() and not node2.is_template()

def is_prologue_fusion(node1, node2):
    return node2.is_template() and not node1.is_template()

7.2 MixOrderReduction --- 混合维度 Reduction 融合

两个 reduction 操作在不同维度上归约同一个输入张量时,可以融合为单个 kernel:

python 复制代码
class MixOrderReduction:
    @staticmethod
    def can_fuse(node1, node2) -> bool:
        """检查两个 reduction 是否在不同维度归约同一输入"""
        # 条件:共享读取 + 不同归约维度
        return has_common_read(node1, node2) and has_mix_reduction_orders(node1, node2)

例如:x.sum(dim=0)x.sum(dim=1) 可以融合为一次读取、两个归约通道的单 kernel。

7.3 Foreach / Combo Kernel

ForeachKernelSchedulerNode 将多个独立的逐元素操作合并为一个 kernel(减少 launch 开销):

python 复制代码
def create_combo_kernel_nodes(self):
    """将可并行的独立节点组合为 ComboKernel"""
    for node_list in group_nodes_for_combo_kernels(self):
        if self.speedup_by_combo_kernel(node_list):
            group_snode = ForeachKernelSchedulerNode(self, node_list)
            # 替换原节点

7.4 ExternTritonKernel Epilogue 融合

用户自定义 Triton kernel 支持 epilogue 融合------将下游 Pointwise 操作直接融入 Triton kernel 的输出写入逻辑:

python 复制代码
class FusedExternTritonKernelSchedulerNode(FusedSchedulerNode):
    """用户 Triton kernel + epilogue pointwise 的融合节点"""
    @classmethod
    def epilogue_fuse(cls, extern_node, pw_node):
        return cls(extern_node.scheduler, [extern_node, pw_node])

八、内存优化

8.1 Inplace Update(原地更新)

decide_inplace_update() 检查一个输出 buffer 是否可以复用某个输入 buffer 的内存(避免额外分配):

python 复制代码
def decide_inplace_update(self):
    for buf in self.outputs:
        for user in buf.users:
            if user.can_inplace:
                # 检查:此 buffer 后续没有其他活跃使用者
                # 检查:size/stride 匹配
                # 检查:经过的所有融合节点中索引一致
                V.kernel.inplace_update_buffers[buf.name] = input_buf.name

8.2 Peak Memory Reordering

通过重排节点顺序降低峰值内存:

python 复制代码
if config.reorder_for_peak_memory:
    self.nodes = reorder_for_peak_memory(
        self.nodes,
        self.name_to_buf,
        self.name_to_fused_node,
        graph_inputs,
        graph_outputs,
    )

该 Pass 来自 memory.py 模块,使用贪心策略优先调度能释放大量内存的节点。

8.3 Dead Node Elimination

python 复制代码
def dead_node_elimination(self):
    """反向遍历,移除没有活跃消费者的节点"""
    for node in reversed(self.nodes):
        can_eliminate = all(
            user.is_weak or user.get_name() in removed_operations
            for buf in node.get_outputs()
            for user in buf.users
        )
        if can_eliminate and not node.has_side_effects():
            V.graph.removed_operations.add(node.get_name())

九、代码生成调度

9.1 codegen() 入口

python 复制代码
def codegen(self):
    if config.graph_partition:
        self._codegen_partitions()  # 分区模式(CUDAGraph 兼容)
    else:
        self._codegen(self.nodes)   # 直接生成

9.2 _codegen() --- 逐节点代码生成

python 复制代码
def _codegen(self, nodes):
    for node in nodes:
        self.enter_context(node)        # 设备/流上下文切换

        if node.is_template():
            backend.codegen_template(template_node, epilogue, prologue)
        elif node.is_extern():
            self.codegen_extern_call(node)
        elif node.is_foreach():
            backend.codegen_combo_kernel(node)
        elif isinstance(node, FusedMixOrderReductions):
            backend.codegen_mix_order_reduction(node)
        elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
            backend.codegen_node(node)   # ← 最常走的路径
        else:
            node.mark_run()              # NopKernel

        self.available_buffer_names.update(node.get_buffer_names())
        self.buffer_names_to_free.update(node.last_usage)

代码生成委托给 BaseScheduling 子类(如 SIMDScheduling 负责 Triton 代码生成)。

9.3 Graph Partition --- CUDAGraph 兼容

CUDAGraph 不支持某些操作(如动态 shape、CPU 操作、非确定性输入),graph_partition() 将节点分割为多个子图:

复制代码
原始节点序列: [GPU_op1, GPU_op2, CPU_op, GPU_op3, GPU_op4]
                                    ↓ graph_partition
分区结果:
    Partition 0: [GPU_op1, GPU_op2]   → CUDAGraph 捕获
    Partition 1: [CPU_op]             → 直接执行
    Partition 2: [GPU_op3, GPU_op4]   → CUDAGraph 捕获

每个可 CUDAGraph 的分区被包装为独立函数,允许 CUDAGraph 只捕获兼容的部分。

9.4 多流调度

python 复制代码
def _populate_stream_assignments(self):
    """从 FX 节点的 custom.stream 元数据读取流分配"""
    for node in self.nodes:
        for fx_node in node.node.get_origins():
            if "stream" in fx_node.meta.get("custom", {}):
                stream_idx = ...
                self.node_to_stream[node] = stream_idx

Scheduler 在融合时禁止跨流融合,在代码生成时插入流切换代码。


十、BaseScheduling --- 后端调度接口

python 复制代码
class BaseScheduling:
    """所有后端调度器的基类"""

    def can_fuse_vertical(self, node1, node2) -> bool: ...
    def can_fuse_horizontal(self, node1, node2) -> bool: ...
    def fuse(self, node1, node2) -> FusedSchedulerNode: ...
    def group_fn(self, sizes) -> tuple: ...
    def codegen_node(self, node) -> None: ...
    def codegen_template(self, template, epilogue, prologue) -> str | None: ...
    def codegen_sync(self) -> None: ...
    def flush(self) -> None: ...
    def benchmark_fused_nodes(self, nodes) -> tuple[float, str]: ...

具体后端实现包括:

  • SIMDSchedulingcodegen/simd.py):Triton 代码生成
  • CppSchedulingcodegen/cpp.py):C++/OpenMP 代码生成
  • CUDACombinedSchedulingcodegen/cuda_combined_scheduling.py):CUDA C++ 代码生成

十一、WhyNoFuse --- 融合诊断工具

WhyNoFuse 是一个精巧的调试工具,记录每对节点未能融合的原因:

python 复制代码
class WhyNoFuse:
    def __init__(self, node1, node2):
        self.name1 = node1.get_name()
        self.name2 = node2.get_name()

    def __call__(self, reason, *args):
        self.reason = reason
        fusion_log.debug(self)

    def __str__(self):
        return f"cannot fuse {self.name1} with {self.name2}: {self.reason % self.args}"

使用方式:

python 复制代码
why = WhyNoFuse(node1, node2)
if node1.is_cpu() and node2.is_gpu():
    why("device mismatch: %s vs %s", node1.get_device(), node2.get_device())
    return False

通过 TORCH_LOGS="+fusion" 可以查看所有融合决策的详细日志。


十二、数据流全景图

复制代码
ir.Operation 列表(来自 Inductor Lowering)
    │
    ▼ create_scheduler_node()
    │  ComputedBuffer → SchedulerNode
    │  TemplateBuffer → SchedulerNode
    │  ExternKernel   → ExternKernelSchedulerNode
    │
    ▼ compute_dependencies()
    │  构建 name_to_users 依赖图
    │  处理 alias / mutation / unbacked symint
    │
    ▼ topological_sort_schedule()
    ▼ dead_node_elimination()
    ▼ compute_ancestors() + compute_input_distances()
    │
    ▼ _populate_stream_assignments()
    │  读取 FX 节点的 stream 元信息
    │
    ▼ fuse_nodes()  【核心优化】
    │  ┌── fuse_nodes_once() 循环(最多 10 轮)
    │  │   ├─ get_possible_fusions()      ← 枚举候选对
    │  │   ├─ can_fuse()                   ← 合法性检查
    │  │   ├─ score_fusion_memory()        ← 评分排序
    │  │   ├─ will_fusion_create_cycle()   ← 环检测
    │  │   ├─ speedup_by_fusion()          ← benchmark(可选)
    │  │   └─ FusedSchedulerNode.fuse()    ← 执行融合
    │  └── 直到无新融合
    │
    ▼ reorder_for_peak_memory()        ← 内存优化
    ▼ reorder_compute_and_comm_for_overlap()  ← 通信重叠
    ▼ graph_partition()                 ← CUDAGraph 分区
    │
    ▼ codegen()
    │  for node in nodes:
    │      backend.codegen_node(node)   ← 委托给 SIMDScheduling 等
    │
    ▼ Triton / C++ / CUDA 源代码

十三、关键配置项速查

配置项 默认值 功能
config.benchmark_kernel False 通过实际 benchmark 决定是否融合
config.epilogue_fusion True 允许 template 的 epilogue 融合
config.prologue_fusion True 允许 template 的 prologue 融合
config.combo_kernels True 启用 ComboKernel(合并并行算子)
config.inplace_buffers True 启用 buffer 原地复用
config.reorder_for_peak_memory True 启用峰值内存优化重排
config.reorder_for_compute_comm_overlap True 启用计算/通信重叠
config.graph_partition True 启用 CUDAGraph 图分区
config.loop_ordering_after_fusion True 融合后再做循环排序
config.max_fusion_buffer_group_pairwise_attempts 128 每组最大候选对数
config.use_dce True 启用死代码消除

十四、阅读路线建议

复制代码
Level 1:理解全局架构
    → Scheduler.__init__() 的 12 步管线 (3078-3200)
    → 类层次结构 (BaseSchedulerNode 及其子类)

Level 2:理解依赖系统
    → compute_dependencies() (3478-3730)
    → NodeUser / WeakDep / StarDep 语义

Level 3:理解融合引擎
    → fuse_nodes_once() (4957-5025) ← 主循环
    → can_fuse() (5785-6030) ← 合法性判断(250+ 行)
    → score_fusion_memory() (6207-6320) ← 评分逻辑
    → will_fusion_create_cycle() (5134-5180) ← 环检测

Level 4:理解代码生成
    → _codegen() (7481-7640) ← 逐节点调度
    → BaseScheduling 接口 (7815-7999) ← 后端契约
    → _codegen_partitions() ← CUDAGraph 分区

Level 5:调试技巧
    → TORCH_LOGS="+fusion" 查看融合决策
    → INDUCTOR_WRITE_SCHEDULER_GRAPH=1 生成图可视化
    → TORCH_COMPILE_DEBUG=1 查看所有编译产物

十五、总结

scheduler.py 的 8000 行代码浓缩了编译器优化的精华。它扮演了三个关键角色:

  1. 依赖分析器 :通过 compute_dependencies() 构建完整的数据流图,正确处理 mutation、aliasing、unbacked symint 等边界情况
  2. 融合决策器 :通过多轮迭代的 fuse_nodes() 循环,在合法性检查(环检测、设备匹配、后端约束)和收益评估(内存评分、benchmark 验证)之间找到最优融合方案
  3. 代码生成调度器:将最终的融合节点序列按设备、流、CUDAGraph 分区组织,委托各后端生成高效的 Triton/C++/CUDA 代码

这三个角色共同构成了从"IR 视图"到"可执行代码"的最后一公里。理解 scheduler.py,就理解了 PyTorch Inductor 为什么能在不牺牲灵活性的前提下,逼近手写 CUDA kernel 的性能。

相关推荐
半壶清水2 小时前
[软考网规考点笔记]-局域网之高速以太网
网络·笔记·网络协议·考试
一定要AK10 小时前
Spring 入门核心笔记
java·笔记·spring
AI成长日志10 小时前
【Agentic RL】1.1 什么是Agentic RL:从传统RL到智能体学习
人工智能·学习·算法
_李小白11 小时前
【OSG学习笔记】Day 38: TextureVisitor(纹理访问器)
android·笔记·学习
杨云龙UP11 小时前
从0到1快速学会Linux操作系统(基础),这一篇就够了!
linux·运维·服务器·学习·ubuntu·centos·ssh
头疼的程序员12 小时前
计算机网络:自顶向下方法(第七版)第八章 学习分享(三)
网络·学习·计算机网络
_李小白13 小时前
【OSG学习笔记】Day 37: NodeVisitor(顶点访问器)
笔记·学习
断眉的派大星13 小时前
pytorch中view和reshape的区别
人工智能·pytorch·python
程序员雷欧13 小时前
大模型应用开发学习第八天
大数据·人工智能·学习