PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent

PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent

目录

  • [PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent](#PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent)
    • [0x00 摘要](#0x00 摘要)
    • [0x01 功能详解](#0x01 功能详解)
      • [1.1 核心作用](#1.1 核心作用)
      • [1.2 详细分析](#1.2 详细分析)
      • [1.3 流程图](#1.3 流程图)
      • [1.4 与系统其他组件的交互](#1.4 与系统其他组件的交互)
        • [与 Orchestrator 的交互](#与 Orchestrator 的交互)
        • [与 Dispatch Kernel Agent 的交互](#与 Dispatch Kernel Agent 的交互)
        • [与 Composer 的交互](#与 Composer 的交互)
        • [LLM 交互机制](#LLM 交互机制)
    • [0x02 Prompt](#0x02 Prompt)
      • [2.1 概括](#2.1 概括)
      • [2.2 Prompt 的基本结构](#2.2 Prompt 的基本结构)
        • [System Prompt](#System Prompt)
        • [User Prompt 详细结构](#User Prompt 详细结构)
      • [2.3 使用时机](#2.3 使用时机)
    • [0x03 实现](#0x03 实现)
      • [3.1 特色](#3.1 特色)
      • [3.2 代码](#3.2 代码)
    • [0xFF 参考](#0xFF 参考)

0x00 摘要

subgraph_extractor.py 是 KernelFalcon 实现 "PyTorch 模型子图提取 + 形状签名去重" 的关键组件,核心职责是通过 Fuser 生成融合代码后,借助 LLM 解析并提取模型中唯一的计算子图(按形状 / 算子 / 权重特征去重),最终输出标准化 JSON 格式的子图信息。这一模块体现了 "Agent 端到端优化" 中 "精准子图识别" 的关键能力。

Extractor 的架构如下,其三阶段总结如下:

  • 融合重写:调Orchestrator(多worker并行),让LLM把原PyTorch改写成可融合子模块,沙箱验证数值等价,产出code.py.tgz.
  • 子图识别:把原问题+融合代码一起喂给LLM,要求输出每个唯一子图的 JSON描述(ops、shapes、weights、来源代码片段)
  • 去重合并:按ops+shapes+weights+layout+dtype算签名,相同签名合并count,确保Dispatcher 只为每个独特子图生成一次内核

0x01 功能详解

subgraph_extractor.py 是子图提取的关键组件,用于从融合的 PyTorch 代码中识别和提取可融合的子图及其形状信息。这个组件使得复杂的融合操作可以分解为更小、更易管理的子图,每个子图都有精确的形状和语义信息,为后续的 Triton 内核生成奠定了坚实基础。

1.1 核心作用

subgraph_extractor.py 的核心作用如下:

  • 分析融合代码:从 Fuser 生成的融合 PyTorch 代码中提取语义信息

  • 识别可以独立优化的子图

  • 形状感知:提取精确的输入 / 输出形状用于后续优化

  • 生成结构化描述:创建精确的 JSON 格式子图描述

  • 去重机制:基于形状签名消除重复子图

  • 为后续阶段提供输入:为 Triton 内核生成和最终合成提供基础数据

主要功能流程如下。

python 复制代码
问题文件
   ↓
Fuser Orchestrator 生成融合代码
   ↓
subgraph_extractor.py 分析融合代码并提取子图
    ↓
subgraphs.json 结构化子图描述

1.2 详细分析

代码提取功能

_load_code_from_tar 函数完成代码提取功能。

python 复制代码
  def _load_code_from_tar(artifact_path: Path) -> str:
    """从tar.gz压缩包中读取code.py文件内容(Fuser生成的融合代码)"""
    # 检查压缩包文件是否存在,不存在则返回空字符串
    if not artifact_path.is_file():
        return ""
    # 以只读模式打开gzip压缩的tar包
    with tarfile.open(artifact_path, "r:gz") as tf:
        try:
            # 获取压缩包中名为code.py的文件成员
            member = tf.getmember("code.py")
        except KeyError:
            # 若code.py不存在,返回空字符串
            return ""
        # 提取code.py文件内容
        extracted = tf.extractfile(member)
        # 若提取失败(文件为空),返回空字符串
        if extracted is None:
            return ""
        # 读取文件内容并解码为UTF-8字符串返回
        return extracted.read().decode("utf-8")    

LLM提示构建

_build_llm_prompt_for_shapes 完成了 prompt构建功能,其关键特点是:

  • 精确性:要求精确的形状签名
  • 结构化:强制返回特定的 JSON 格式
  • 完整性:包含操作、权重、布局等所有相关信息
python 复制代码
def _build_llm_prompt_for_shapes(fused_code: str, problem_code: str) -> tuple[str, str]:
    """构建LLM提示词:引导LLM分析融合代码和原始代码,提取子图信息"""
    # System Prompt:强制要求仅返回JSON数组
    system = "Return a single JSON array only."
    user_lines: list[str] = []
    # 角色与背景说明:告知LLM输入内容(原始问题代码+融合代码)
    user_lines.append(
        "You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules)."
    )
    # 核心任务说明:按形状签名识别唯一子图,输出指定Schema的JSON数组
    user_lines.append(
        "Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema):"
    )
    # 详细Schema定义:明确每个字段的含义和格式
    user_lines.append(
        "{\n"
        '  "id": <string>,\n'
        '  "type": <string>,\n'
        '  "data_layout": <\\"NCHW\\"|\\"NHWC\\"|null>,\n'
        '  "dtype": <string|null>,\n'
        '  "ops": [ {"op": <string>, ... op-specific fields ... } ],\n'
        '  "input_shape": [<int|sym>, ...]  // OR \\"inputs\\": [[...], [...]] for multi-input\n'
        '  "output_shape": [<int|sym>, ...],\n'
        '  "weights_fused": { <name>: [<int|sym>, ...], ... } | null,\n'
        '  "weights_original": { <name>: [<int|sym>, ...], ... } | null,\n'
        '  "count": <int>,\n'
        '  "where": <string>,\n'
        '  "source": { "module": <string>, "code": <string> }\n'
        "}"
    )
    # 关键注意事项:细化提取规则,提升准确性
    user_lines.append("Notes:")
    user_lines.append(
        "- Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences."
    )
    user_lines.append(
        "- Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim."
    )
    user_lines.append(
        "- Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable."
    )
    user_lines.append(
        "- Provide a short \"where\" string (e.g., 'Model.forward stem' or 'layer2.block3.conv')."
    )
    user_lines.append(
        '- Provide "source" with the smallest contiguous code snippet implementing the subgraph.'
    )
    user_lines.append(
        "- Use data_layout and dtype when clear (default conv layout is NCHW)."
    )
    user_lines.append(
        '- For binary ops like residual add, use "inputs": [[...],[...]].'
    )
    user_lines.append(
        "- Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W."
    )
    user_lines.append("")
    # 输入代码:原始问题代码
    user_lines.append("PROBLEM_FILE:\n```python")
    user_lines.append(problem_code)
    user_lines.append("```")
    user_lines.append("")
    # 输入代码:Fuser生成的融合代码
    user_lines.append("FUSED_CODE:")
    user_lines.append("""```python""")
    user_lines.append(fused_code)
    user_lines.append("```")
    user_lines.append("")
    # 最终要求:仅返回包含数组的JSON代码块,无其他文本
    user_lines.append(
        "Now return only one fenced JSON block containing the array. No prose."
    )
    # 返回System Prompt和User Prompt
    return system, "\n".join(user_lines)

形状签名去重机制

_dedup_by_shape_signature 实现了去重代码。

  • 基于输入 / 权重 / 输出形状的标准化表示
  • 忽略名称但保留维度和数据类型
  • 确保相同语义的子图被合并
python 复制代码
def _dedup_by_shape_signature(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Deduplicate items by a stable shape signature.

    The signature is based on sorted lists of input/weight/output shapes content,
    ignoring names but preserving dimensions and dtypes.
    基于稳定的形状签名对子图列表去重:
    - 签名基于输入/权重/输出形状的标准化内容(忽略名称,保留维度和数据类型)
    - 保证相同形状特征的子图只保留一个
    """

    def norm_shapes(arr: Any) -> Any:
        """内部函数:标准化形状数组(统一不同格式的形状描述)"""
        # 非列表类型直接返回空列表
        if not isinstance(arr, list):
            return []
        normed: list[Any] = []
        # 遍历数组中的每个元素
        for e in arr:
            if isinstance(e, dict):
                # 标准化形状字典的键(兼容不同命名方式:shape/dims/size)
                shape = e.get("shape") or e.get("dims") or e.get("size")
                dtype = e.get("dtype")
                kind = e.get("kind") or e.get("role")
                # 标准化维度:优先int/str类型,其他类型转为字符串
                if isinstance(shape, list):
                    dims = [str(x) for x in shape]
                elif isinstance(shape, (int, str)):
                    dims = [str(shape)]
                else:
                    dims = [str(shape)] if shape is not None else []
                # 构建标准化的形状描述字典
                normed.append(
                    {"dims": dims, "dtype": str(dtype) if dtype else None, "k": kind}
                )
            else:
                # 非字典元素直接转为字符串
                normed.append(str(e))
        # 排序以保证签名的稳定性(避免顺序不同导致签名不同)
        return sorted(normed, key=lambda x: json.dumps(x, sort_keys=True))

    # 存储已见过的签名,避免重复
    seen: set[str] = set()
    out: list[dict[str, Any]] = []
    # 遍历所有子图项
    for it in items:
        # 构建签名对象:包含输入/权重/输出的标准化形状
        sig_obj = {
            "inputs": norm_shapes(it.get("input_shapes")),
            "weights": norm_shapes(it.get("weight_shapes") or it.get("weights")),
            "outputs": norm_shapes(it.get("output_shapes")),
        }
        # 转为JSON字符串作为唯一签名(排序保证稳定性)
        sig = json.dumps(sig_obj, sort_keys=True)
        # 若签名未见过,则保留该子图
        if sig in seen:
            continue
        seen.add(sig)
        out.append(it)
    # 返回去重后的子图列表
    return out

1.3 流程图

subgraph_extractor.py 的流程如下:

初始化阶段:

  • 创建OrchestratorConfig配置对象
  • 生成唯一的运行ID并创建运行目录结构
  • 初始化Orchestrator对象

代码提取阶段:

  • 运行Orchestrator.run()获取融合后的PyTorch代码
  • 检查是否成功找到解决方案
  • 加载原始问题代码和融合后的代码

LLM分析阶段:

  • 构建包含原始问题和融合代码的提示
  • 根据提供商类型选择不同的API调用方式
  • 提取并解析LLM返回的JSON格式的子图描述

后处理阶段:

  • 验证JSON结构的有效性
  • 通过形状签名对子图进行去重和合并
  • 保存最终的subgraphs.json文件
  • 返回运行目录和JSON文件路径

具体流程图如下:

1.4 与系统其他组件的交互

与 Orchestrator 的交互

调用 Fuser Orchestrator 生成融合代码

python 复制代码
orch = Orchestrator(...)
summary = orch.run()
fused_code = _load_code_from_tar(Path(summary.artifact_path))

与 Dispatch Kernel Agent 的交互

生成的 subgraphs.json 作为 dispatch_kernel_agent.py 的输入,为每个子图生成 Triton 内核。

与 Composer 的交互

subgraphs.json 作为 compose_end_to_end.py 的输入之一,用于最终的端到端合成

LLM 交互机制

python 复制代码
# Provider 选择
provider = get_model_provider(model_name)
if provider.name != "openai":
    # 直接调用提供商
    result = provider.get_response(...)
else:
    # 通过 EventAdapter 流式处理
    adapter = EventAdapter(...)
    result = adapter.stream(
                    system_prompt=SYSTEM_PROMPT, user_prompt=rp.user, extras=rp.extras)

0x02 Prompt

我们来分析 subgraph_extractor.py 中的 Prompt 构建机制。

2.1 概括

subgraph_extractor.py 使用的是 LLM 提示,专门用于从融合的 PyTorch 代码中提取子图及其形状信息。

这条 prompt 可以一句话概括:"把话说到编译器级别,不给自由发挥留缝隙。"具体特点拆解如下:

  1. 极端结构化
    • 用 JSON Schema 把字段名、类型、取值范围、嵌套层级一次性钉死,连 null 能出现在哪都标好。
    • 要求"只返回一个 fenced code block",直接把自然语言出口焊死,防止模型"顺便聊聊"。
  2. 双重代码上下文
    • 同时给出"原始 PyTorch 代码"和"融合后的代码",让模型既能看到"改名前的权重"也能看到"融合后的权重",相当于开卷考试但限定只能写标准答案格式。
  3. 微观操作级说明书
    • 对每一类算子(conv、pool、linear、add)都列出必须出现的 key(kernel_size/stride/padding/groups...),把"该抄哪几行"写成 checklist,模型只要漏一项就能被后处理脚本一键拒收。
    • 明确"形状不同就算新子图",避免模型把不同 block 的同名层合并。
  4. 符号系统与优先级双重约束
    • 先拿 get_inputs() 的 concrete shape 当"硬数",找不到才允许用 B/H/W 符号,既保证可静态检查,又留一条退路。
    • 权重必须同时给 weights_originalweights_fused,逼模型把"融合前后张量对应关系"显式写出来,防止"黑箱合并"。
  5. Zero-shot 但 Zero-creativity
    • 没有 few-shot 示例,却用 12 条"Notes"把边界情况全部穷举,等于告诉模型"你不需要创新,只需要当一台会数数的扫描仪"。
    • 最后用"No prose"把寒暄、总结、解释统统 ban 掉,输出直接变成可 json.loads 的"机器口粮"。

2.2 Prompt 的基本结构

System Prompt

System Prompt 的内容如下:"Return a single JSON array only."

其要求 LLM 只返回单个 JSON 数组,避免返回额外的文本说明。

User Prompt 详细结构

首先是背景介绍

python 复制代码
    user_lines.append(
        "You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules)."
    )

其次是任务描述,从融合代码中识别所有独特的子图,提取精确的形状信息(输入 / 输出 / 权重),为后续的 Triton 内核生成提供结构化输入

python 复制代码
    user_lines.append(
        "Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema):"
    )

接下来会说明期望的JSON schema

python 复制代码
    user_lines.append(
        "{\n"
        '  "id": <string>,\n'
        '  "type": <string>,\n'
        '  "data_layout": <\\"NCHW\\"|\\"NHWC\\"|null>,\n'
        '  "dtype": <string|null>,\n'
        '  "ops": [ {"op": <string>, ... op-specific fields ... } ],\n'
        '  "input_shape": [<int|sym>, ...]  // OR \\"inputs\\": [[...], [...]] for multi-input\n'
        '  "output_shape": [<int|sym>, ...],\n'
        '  "weights_fused": { <name>: [<int|sym>, ...], ... } | null,\n'
        '  "weights_original": { <name>: [<int|sym>, ...], ... } | null,\n'
        '  "count": <int>,\n'
        '  "where": <string>,\n'
        '  "source": { "module": <string>, "code": <string> }\n'
        "}"
    )

然后是详细说明和注意事项

python 复制代码
    user_lines.append("Notes:")
    user_lines.append(
        "- Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences."
    )
    user_lines.append(
        "- Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim."
    )
    user_lines.append(
        "- Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable."
    )
    user_lines.append(
        "- Provide a short \"where\" string (e.g., 'Model.forward stem' or 'layer2.block3.conv')."
    )
    user_lines.append(
        '- Provide "source" with the smallest contiguous code snippet implementing the subgraph.'
    )
    user_lines.append(
        "- Use data_layout and dtype when clear (default conv layout is NCHW)."
    )
    user_lines.append(
        '- For binary ops like residual add, use "inputs": [[...],[...]].'
    )
    user_lines.append(
        "- Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W."
    )

最后是输入代码示例

python 复制代码
    user_lines.append("PROBLEM_FILE:\n```python")
    user_lines.append(problem_code)
    user_lines.append("```")
    user_lines.append("")
    user_lines.append("FUSED_CODE:")
    user_lines.append("""```python""")
    user_lines.append(fused_code)
    user_lines.append("```")
    user_lines.append("")
    user_lines.append(
        "Now return only one fenced JSON block containing the array. No prose."
    )

2.3 使用时机

在 extract_subgraphs_to_json 函数中会调用 prompt:

python 复制代码
# Ask LLM for shapes JSON
system, user = _build_llm_prompt_for_shapes(fused_code, problem_code)
"""
Temporary MUX to support Relay while we migrate to OpenAI Responses API.
Uses EventAdapter for OpenAI, otherwise Provider inferface
"""
provider = get_model_provider(model_name)
if provider.name != "openai":
    # 直接调用提供商
    result = provider.get_response(...)
else:
    # 通过 EventAdapter 流式处理
    adapter = EventAdapter(...)
	result = adapter.stream(...)

0x03 实现

subgraph_extractor.py 实现了 KernelFalcon 的 "子图识别" 核心能力 ------ 通过 Fuser 生成融合代码→LLM 解析代码提取子图→签名去重合并→输出标准化 JSON,为后续 Triton 算子自动生成提供精准的子图粒度输入。

3.1 特色

  • LLM 驱动的智能子图识别:放弃传统的 "静态代码解析 + 规则匹配",改用 LLM 理解 PyTorch 代码语义,精准识别卷积 / 池化 / 线性层等算子的子图边界、形状、权重特征,适配复杂的融合代码场景;
  • 鲁棒的签名去重机制:基于 "算子 + 输入 / 输出形状 + 权重结构 + 数据布局 + 数据类型" 构建稳定签名,避免因命名 / 格式差异导致的重复子图,保证子图识别的唯一性;
  • 全链路容错设计:针对 LLM 输出格式异常、JSON 解析失败、代码文件缺失等场景,均有明确的容错逻辑和诊断文件输出,提升工业级可用性;
  • 标准化输出格式:定义统一的子图 JSON Schema,包含 id、类型、形状、权重、计数等核心字段,为后续算子生成和模型优化提供标准化输入;
  • 适配多 LLM 提供商:兼容 OpenAI Responses API 和其他 LLM 提供商的接口,通过适配层统一调用逻辑,保证灵活性。

3.2 代码

完整链路为:

  1. 从 Fuser 生成的融合代码压缩包中读取核心代码文件;
  2. 构造精准的 LLM Prompt,引导 LLM 分析原始 PyTorch 问题代码和融合代码,识别所有计算子图并输出 JSON 格式信息;
  3. 提取 LLM 输出中的 JSON 区块,处理格式异常并做容错;
  4. 基于 "形状 + 算子 + 权重" 的稳定签名对子图去重,合并重复子图的计数;
  5. 输出标准化的子图 JSON 文件,为后续 Triton 算子自动生成和模型优化提供精准的子图粒度信息。

提取时,从LLM输出中提取代码时选择最后一个python 代码块。这个思路如下:

  • LLM的回复通常遵循"分析--->实现"的模式:先给出代码片段解释思路,最后给出完整实现
  • 中间的代码块可能是不完整的示例、伪代码、或者修改前的版本
  • 最后一个代码块最可能是LLM的"最终答案"一完整的、经过修正的实现

多代码块的处理优先级:

  • 优先找带 python语言标签的最后一个代码块
  • 如果没有python标签,回退到最后一个任意语言的代码块
  • 提取后用ast.parse()验证是否为合法Python
  • 如果解析失败,抛出ValueError

潜在风险如下:

  • 如果 LLM 先给完整实现再给修改后的局部片段,会错误地选择不完整的代码
python 复制代码
# 正则表达式:匹配Markdown中的JSON代码块(```json ... ```)
# 匹配规则:
# - 开头:``` 后可跟空格/tab + json(可选) + 空格/tab + 换行
# - 中间:任意字符(非贪婪匹配)
# - 结尾:换行 + ``` + 空格/tab + 行尾
# - 修饰符:MULTILINE(多行匹配)、IGNORECASE(忽略大小写)
_JSON_BLOCK_RE = re.compile(
    r"^```[ \t]*(json)?[ \t]*\n([\s\S]*?)^```[ \t]*$",
    re.MULTILINE | re.IGNORECASE,
)

def _extract_json_block(text: str) -> str:
    """Extract the last fenced JSON block or fallback to best-effort slice.
    从LLM输出文本中提取JSON区块:
    1. 优先提取最后一个标记为json的代码块;
    2. 若无则提取最后一个任意代码块;
    3. 若仍无则尝试截取第一个[和最后一个]之间的内容;
    4. 均失败则返回空字符串
    """
    # 查找所有匹配的代码块
    matches = list(_JSON_BLOCK_RE.finditer(text))
    chosen: re.Match[str] | None = None
    # 逆序遍历匹配结果,优先选择标记为json的代码块
    for m in reversed(matches):
        lang = (m.group(1) or "").strip().lower()
        if lang == "json":
            chosen = m
            break
    # 若无json标记的代码块,但有其他代码块,选择最后一个
    if chosen is None and matches:
        chosen = matches[-1]
    # 若找到代码块,返回其中的内容(group(2))
    if chosen is not None:
        return chosen.group(2)
    # 兜底方案:尝试截取第一个[和最后一个]之间的内容(JSON数组)
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end != -1 and end > start:
        return text[start : end + 1]
    # 所有方案均失败,返回空字符串
    return ""

def extract_subgraphs_to_json(
    problem_path: Path,
    model_name: str,
    workers: int,
    max_iters: int,
    llm_timeout_s: int,
    run_timeout_s: int,
    target_platform: str = "cuda",
) -> tuple[Path, Path]:
    """Run Fuser to produce fused code, then use LLM to emit subgraphs JSON.

    Returns (run_dir, json_path).
    核心函数:
    1. 运行Orchestrator执行Fuser生成融合代码;
    2. 调用LLM分析融合代码和原始代码,提取子图信息;
    3. 解析LLM输出,去重并合并计数;
    4. 输出标准化的子图JSON文件;
    返回值:(运行目录路径, JSON文件路径)
    """
    # 1. 配置Orchestrator:设置Fuser运行参数
    cfg = OrchestratorConfig(
        problem_path=problem_path,  # KernelBench问题文件路径
        model=model_name,  # LLM模型名称
        workers=workers,  # 工作进程数
        max_iters=max_iters,  # 最大迭代次数
        llm_timeout_s=llm_timeout_s,  # LLM调用超时时间(秒)
        run_timeout_s=run_timeout_s,  # 整体运行超时时间(秒)
        stream_mode="winner",  # 流模式:仅保留最优结果
        store_responses=False,  # 不存储LLM响应
        isolated=False,  # 非隔离模式
        deny_network=False,  # 允许网络访问
        enable_reasoning_extras=True,  # 启用推理增强功能
        target_platform=target_platform,  # 目标平台(默认CUDA)
    )
    # 生成唯一运行ID
    run_id = new_run_id()
    # 构建运行目录(当前目录/.fuse)
    base_dir = Path.cwd() / ".fuse"
    base_dir.mkdir(exist_ok=True)  # 目录不存在则创建
    # 创建运行相关的目录结构
    dirs = make_run_dirs(base_dir, run_id)

    # 初始化Orchestrator:负责调度Fuser和LLM任务
    orch = Orchestrator(
        cfg,
        run_dir=dirs["run_dir"],  # 运行目录
        workers_dir=dirs["workers"],  # 工作进程目录
        orchestrator_dir=dirs["orchestrator"],  # 调度器目录
    )
    # 运行Orchestrator,获取执行摘要
    summary = orch.run()
    # 检查是否生成有效融合代码:无最优结果则退出
    if summary.winner_worker_id is None or not summary.artifact_path:
        raise SystemExit(f"No passing fused code: {summary.reason}")

    # 2. 加载融合代码:从Orchestrator输出的压缩包中读取code.py
    fused_code = _load_code_from_tar(Path(summary.artifact_path))
    # 检查融合代码是否有效:为空则退出
    if not fused_code.strip():
        raise SystemExit("Winner artifact missing code.py or empty")
    # 加载原始问题代码
    problem_code = problem_path.read_text(encoding="utf-8")

    # 3. 构建LLM Prompt并调用LLM
    # 构建System和User Prompt
    system, user = _build_llm_prompt_for_shapes(fused_code, problem_code)

    """
    Temporary MUX to support Relay while we migrate to OpenAI Responses API.

    Uses EventAdapter for OpenAI, otherwise Provider inferface
    临时适配层:兼容OpenAI Responses API和其他LLM提供商
    """
    # 获取LLM提供商实例
    provider = get_model_provider(model_name)
    if provider.name != "openai":
        # 非OpenAI提供商:使用标准Provider接口
        messages: list[dict[str, str]] = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        # 调用LLM获取响应
        result = provider.get_response(
            model_name,
            messages,
            max_tokens=16000,  # 最大生成Token数
            text={"format": {"type": "text"}},  # 输出格式:纯文本
        )
        # 提取LLM输出文本
        output_text = result.content or ""
    else:
        # OpenAI提供商:使用EventAdapter适配Responses API
        jsonl_path = dirs["orchestrator"] / "subgraphs.stream.jsonl"
        adapter = EventAdapter(
            model=model_name,
            store_responses=False,
            timeout_s=llm_timeout_s,
            jsonl_path=jsonl_path,
        )
        # 流式调用LLM
        result = adapter.stream(
            system_prompt=system,
            user_prompt=user,
            extras={"text": {"format": {"type": "text"}}},
        )
        # 提取输出文本
        output_text = result.get("output_text", "")

    # 4. 解析LLM输出:提取JSON区块
    raw_json = _extract_json_block(output_text)
    try:
        # 解析JSON字符串为Python对象
        data = json.loads(raw_json)
    except Exception as e:
        # 解析失败:写入诊断文件并退出
        diag = dirs["orchestrator"] / "subgraphs.raw.txt"
        diag.write_text(output_text, encoding="utf-8")
        raise SystemExit(f"Failed to parse LLM JSON: {e}")

    # 检查解析结果是否为列表:非列表则退出
    if not isinstance(data, list):
        raise SystemExit("LLM output JSON is not a list")

    # 5. 子图去重:按签名合并重复子图并累加计数
    grouped: dict[str, dict[str, Any]] = {}

    def sig_of(it: dict[str, Any]) -> str:
        """内部函数:为子图生成唯一签名(基于算子+形状+权重)"""
        # 提取算子列表并标准化(排序每个算子字典的键)
        ops = it.get("ops") or []
        ops_norm = []
        for op in ops:
            if isinstance(op, dict):
                ops_norm.append(json.loads(json.dumps(op, sort_keys=True)))
            else:
                ops_norm.append(op)
        # 提取输入形状(兼容单输入/多输入格式)
        inputs_single = it.get("input_shape")
        inputs_multi = it.get("inputs")
        # 提取输出形状
        outputs = it.get("output_shape")
        # 提取权重信息(兼容不同命名方式)
        weights = it.get("weights") or {}
        weights_fused = it.get("weights_fused") or {}
        weights_original = it.get("weights_original") or {}

        # 内部函数:标准化权重字典(按键排序保证稳定性)
        def sort_w(obj: Any) -> dict[str, Any]:
            if isinstance(obj, dict):
                return {k: obj[k] for k in sorted(obj.keys())}
            return {}

        # 标准化权重字典
        weights_norm = sort_w(weights)
        wf_norm = sort_w(weights_fused)
        wo_norm = sort_w(weights_original)
        # 构建签名对象:包含算子、形状、权重、布局、数据类型
        sig_obj = {
            "ops": ops_norm,
            "in": inputs_multi if inputs_multi is not None else inputs_single,
            "out": outputs,
            "w": weights_norm,
            "wf": wf_norm,
            "wo": wo_norm,
            "layout": it.get("data_layout"),
            "dtype": it.get("dtype"),
        }
        # 转为JSON字符串作为唯一签名
        return json.dumps(sig_obj, sort_keys=True)

    # 遍历所有子图,按签名分组
    for it in data:
        s = sig_of(it)
        if s not in grouped:
            # 确保子图ID存在:无则基于签名哈希生成
            if not it.get("id"):
                it["id"] = f"sg_{hash(s) & 0xFFFFFFFF:08x}"
            # 标准化计数:确保为整数,默认1
            c = it.get("count")
            try:
                count_val = int(c) if c is not None else 1
            except Exception:
                count_val = 1
            it["count"] = count_val
            # 新增分组
            grouped[s] = it
        else:
            # 合并重复子图:累加计数
            try:
                grouped[s]["count"] += int(it.get("count") or 1)
            except Exception:
                grouped[s]["count"] += 1

    # 转换为去重后的列表
    deduped = list(grouped.values())
    # 构建输出JSON文件路径
    out_path = dirs["run_dir"] / "subgraphs.json"
    # 写入标准化JSON文件(缩进2,保证可读性)
    out_path.write_text(json.dumps(deduped, indent=2), encoding="utf-8")
    # 返回运行目录和JSON文件路径
    return dirs["run_dir"], out_path


def main(argv: list[str] | None = None) -> int:
    """主函数:命令行入口,解析参数并执行子图提取"""
    # 加载环境变量(若存在.env文件)
    _load_dotenv_if_present()
    # 初始化命令行参数解析器
    p = argparse.ArgumentParser(
        description="Extract unique subgraphs with shapes (JSON)"
    )
    # 必选参数:KernelBench问题文件的绝对路径
    p.add_argument(
        "--problem", required=True, help="Absolute path to KernelBench problem file"
    )
    # 可选参数:LLM模型名称(默认gpt-5)
    p.add_argument("--model", default="gpt-5", help="OpenAI model name (Responses API)")
    # 可选参数:工作进程数(默认4)
    p.add_argument("--workers", type=int, default=4)
    # 可选参数:最大迭代次数(默认5)
    p.add_argument("--max-iters", type=int, default=5)
    # 可选参数:LLM超时时间(秒,默认2400)
    p.add_argument("--llm-timeout-s", type=int, default=2400)
    # 可选参数:整体运行超时时间(秒,默认2400)
    p.add_argument("--run-timeout-s", type=int, default=2400)
    # 可选参数:目标平台(默认cuda,可选值由get_platform_choices返回)
    p.add_argument(
        "--target-platform",
        default="cuda",
        choices=get_platform_choices(),
        help="Target platform",
    )
    # 解析命令行参数
    args = p.parse_args(argv)

    try:
        # 验证并转换问题文件路径为绝对路径
        problem_path = ensure_abs_regular_file(args.problem)
    except PathSafetyError as e:
        # 路径验证失败:输出错误并返回退出码2
        print(str(e), file=sys.stderr)
        return 2

    # 执行子图提取
    run_dir, json_path = extract_subgraphs_to_json(
        problem_path=problem_path,
        model_name=args.model,
        workers=args.workers,
        max_iters=args.max_iters,
        llm_timeout_s=args.llm_timeout_s,
        run_timeout_s=args.run_timeout_s,
        target_platform=args.target_platform,
    )
    # 输出JSON文件路径
    print(str(json_path))
    # 正常退出:返回0
    return 0

0xFF 参考

KernelFalcon: Autonomous GPU Kernel Generation via Deep Agents

基于 LLM 的 GPU 内核代码自动生成相关工作

Automating GPU Kernel Generation with DeepSeek-R1 and Inference Time Scaling

DeepSeek-R1自写CUDA内核跑分屠榜!斯坦福学霸狂飙GPU编程自动化挑战人类

CUDA、Triton 内核生成现状追踪

大模型能否为不同硬件平台生成高性能内核?南大、浙大提出跨平台内核生成评测框架MultiKernelBench

AKG kernel Agent:利用multi-agent进行kernel的生成和迁移

AKG KERNEL AGENT: A MULTI-AGENT FRAMEWORK FOR CROSS-PLATFORM KERNEL SYNTHESIS

AIKG -- 基于AI驱动的算子生成器

RL 猛刷 CUDA 核:CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning

MultiKernelBench: A Multi-Platform Benchmark for Kernel Generation

Ouyang A, Guo S, Arora S, et al. Kernelbench: Can llms write efficient gpu kernels?[J]. arXiv preprint arXiv:2502.10517, 2025.

Baronio, Carlo, et al. "Kevin: Multi-turn rl for generating cuda kernels."arXiv preprint arXiv:2507.11948(2025).

Li, Shangzhan, et al. "Autotriton: Automatic triton programming with reinforcement learning in llms."arXiv preprint arXiv:2507.05687(2025).

Li, Jianling, et al. "Tritonbench: Benchmarking large language model capabilities for generating triton operators."Findings of the Association for Computational Linguistics: ACL 2025. 2025.

Tjarko Lange, Robert, et al. "Towards Robust Agentic CUDA Kernel Benchmarking, Verification, and Optimization."arXiv e-prints(2025): arXiv-2509.

Chen, Wentao, et al. "CUDA-LLM: LLMs Can Write Efficient CUDA Kernels."arXiv preprint arXiv:2506.09092(2025).