PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent
目录
- [PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent](#PyTorch KernelAgent 源码解读 ---(4)--- ExtractorAgent)
- [0x00 摘要](#0x00 摘要)
- [0x01 功能详解](#0x01 功能详解)
- [0x02 Prompt](#0x02 Prompt)
- [2.1 概括](#2.1 概括)
- [2.2 Prompt 的基本结构](#2.2 Prompt 的基本结构)
- [System Prompt](#System Prompt)
- [User Prompt 详细结构](#User Prompt 详细结构)
- [2.3 使用时机](#2.3 使用时机)
- [0x03 实现](#0x03 实现)
- [3.1 特色](#3.1 特色)
- [3.2 代码](#3.2 代码)
- [0xFF 参考](#0xFF 参考)
0x00 摘要
subgraph_extractor.py 是 KernelFalcon 实现 "PyTorch 模型子图提取 + 形状签名去重" 的关键组件,核心职责是通过 Fuser 生成融合代码后,借助 LLM 解析并提取模型中唯一的计算子图(按形状 / 算子 / 权重特征去重),最终输出标准化 JSON 格式的子图信息。这一模块体现了 "Agent 端到端优化" 中 "精准子图识别" 的关键能力。
Extractor 的架构如下,其三阶段总结如下:
- 融合重写:调Orchestrator(多worker并行),让LLM把原PyTorch改写成可融合子模块,沙箱验证数值等价,产出code.py.tgz.
- 子图识别:把原问题+融合代码一起喂给LLM,要求输出每个唯一子图的 JSON描述(ops、shapes、weights、来源代码片段)
- 去重合并:按ops+shapes+weights+layout+dtype算签名,相同签名合并count,确保Dispatcher 只为每个独特子图生成一次内核
0x01 功能详解
subgraph_extractor.py 是子图提取的关键组件,用于从融合的 PyTorch 代码中识别和提取可融合的子图及其形状信息。这个组件使得复杂的融合操作可以分解为更小、更易管理的子图,每个子图都有精确的形状和语义信息,为后续的 Triton 内核生成奠定了坚实基础。
1.1 核心作用
subgraph_extractor.py 的核心作用如下:
-
分析融合代码:从 Fuser 生成的融合 PyTorch 代码中提取语义信息
-
识别可以独立优化的子图
-
形状感知:提取精确的输入 / 输出形状用于后续优化
-
生成结构化描述:创建精确的 JSON 格式子图描述
-
去重机制:基于形状签名消除重复子图
-
为后续阶段提供输入:为 Triton 内核生成和最终合成提供基础数据
主要功能流程如下。
python
问题文件
↓
Fuser Orchestrator 生成融合代码
↓
subgraph_extractor.py 分析融合代码并提取子图
↓
subgraphs.json 结构化子图描述
1.2 详细分析
代码提取功能
_load_code_from_tar 函数完成代码提取功能。
python
def _load_code_from_tar(artifact_path: Path) -> str:
"""从tar.gz压缩包中读取code.py文件内容(Fuser生成的融合代码)"""
# 检查压缩包文件是否存在,不存在则返回空字符串
if not artifact_path.is_file():
return ""
# 以只读模式打开gzip压缩的tar包
with tarfile.open(artifact_path, "r:gz") as tf:
try:
# 获取压缩包中名为code.py的文件成员
member = tf.getmember("code.py")
except KeyError:
# 若code.py不存在,返回空字符串
return ""
# 提取code.py文件内容
extracted = tf.extractfile(member)
# 若提取失败(文件为空),返回空字符串
if extracted is None:
return ""
# 读取文件内容并解码为UTF-8字符串返回
return extracted.read().decode("utf-8")
LLM提示构建
_build_llm_prompt_for_shapes 完成了 prompt构建功能,其关键特点是:
- 精确性:要求精确的形状签名
- 结构化:强制返回特定的 JSON 格式
- 完整性:包含操作、权重、布局等所有相关信息
python
def _build_llm_prompt_for_shapes(fused_code: str, problem_code: str) -> tuple[str, str]:
"""构建LLM提示词:引导LLM分析融合代码和原始代码,提取子图信息"""
# System Prompt:强制要求仅返回JSON数组
system = "Return a single JSON array only."
user_lines: list[str] = []
# 角色与背景说明:告知LLM输入内容(原始问题代码+融合代码)
user_lines.append(
"You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules)."
)
# 核心任务说明:按形状签名识别唯一子图,输出指定Schema的JSON数组
user_lines.append(
"Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema):"
)
# 详细Schema定义:明确每个字段的含义和格式
user_lines.append(
"{\n"
' "id": <string>,\n'
' "type": <string>,\n'
' "data_layout": <\\"NCHW\\"|\\"NHWC\\"|null>,\n'
' "dtype": <string|null>,\n'
' "ops": [ {"op": <string>, ... op-specific fields ... } ],\n'
' "input_shape": [<int|sym>, ...] // OR \\"inputs\\": [[...], [...]] for multi-input\n'
' "output_shape": [<int|sym>, ...],\n'
' "weights_fused": { <name>: [<int|sym>, ...], ... } | null,\n'
' "weights_original": { <name>: [<int|sym>, ...], ... } | null,\n'
' "count": <int>,\n'
' "where": <string>,\n'
' "source": { "module": <string>, "code": <string> }\n'
"}"
)
# 关键注意事项:细化提取规则,提升准确性
user_lines.append("Notes:")
user_lines.append(
"- Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences."
)
user_lines.append(
"- Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim."
)
user_lines.append(
"- Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable."
)
user_lines.append(
"- Provide a short \"where\" string (e.g., 'Model.forward stem' or 'layer2.block3.conv')."
)
user_lines.append(
'- Provide "source" with the smallest contiguous code snippet implementing the subgraph.'
)
user_lines.append(
"- Use data_layout and dtype when clear (default conv layout is NCHW)."
)
user_lines.append(
'- For binary ops like residual add, use "inputs": [[...],[...]].'
)
user_lines.append(
"- Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W."
)
user_lines.append("")
# 输入代码:原始问题代码
user_lines.append("PROBLEM_FILE:\n```python")
user_lines.append(problem_code)
user_lines.append("```")
user_lines.append("")
# 输入代码:Fuser生成的融合代码
user_lines.append("FUSED_CODE:")
user_lines.append("""```python""")
user_lines.append(fused_code)
user_lines.append("```")
user_lines.append("")
# 最终要求:仅返回包含数组的JSON代码块,无其他文本
user_lines.append(
"Now return only one fenced JSON block containing the array. No prose."
)
# 返回System Prompt和User Prompt
return system, "\n".join(user_lines)
形状签名去重机制
_dedup_by_shape_signature 实现了去重代码。
- 基于输入 / 权重 / 输出形状的标准化表示
- 忽略名称但保留维度和数据类型
- 确保相同语义的子图被合并
python
def _dedup_by_shape_signature(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Deduplicate items by a stable shape signature.
The signature is based on sorted lists of input/weight/output shapes content,
ignoring names but preserving dimensions and dtypes.
基于稳定的形状签名对子图列表去重:
- 签名基于输入/权重/输出形状的标准化内容(忽略名称,保留维度和数据类型)
- 保证相同形状特征的子图只保留一个
"""
def norm_shapes(arr: Any) -> Any:
"""内部函数:标准化形状数组(统一不同格式的形状描述)"""
# 非列表类型直接返回空列表
if not isinstance(arr, list):
return []
normed: list[Any] = []
# 遍历数组中的每个元素
for e in arr:
if isinstance(e, dict):
# 标准化形状字典的键(兼容不同命名方式:shape/dims/size)
shape = e.get("shape") or e.get("dims") or e.get("size")
dtype = e.get("dtype")
kind = e.get("kind") or e.get("role")
# 标准化维度:优先int/str类型,其他类型转为字符串
if isinstance(shape, list):
dims = [str(x) for x in shape]
elif isinstance(shape, (int, str)):
dims = [str(shape)]
else:
dims = [str(shape)] if shape is not None else []
# 构建标准化的形状描述字典
normed.append(
{"dims": dims, "dtype": str(dtype) if dtype else None, "k": kind}
)
else:
# 非字典元素直接转为字符串
normed.append(str(e))
# 排序以保证签名的稳定性(避免顺序不同导致签名不同)
return sorted(normed, key=lambda x: json.dumps(x, sort_keys=True))
# 存储已见过的签名,避免重复
seen: set[str] = set()
out: list[dict[str, Any]] = []
# 遍历所有子图项
for it in items:
# 构建签名对象:包含输入/权重/输出的标准化形状
sig_obj = {
"inputs": norm_shapes(it.get("input_shapes")),
"weights": norm_shapes(it.get("weight_shapes") or it.get("weights")),
"outputs": norm_shapes(it.get("output_shapes")),
}
# 转为JSON字符串作为唯一签名(排序保证稳定性)
sig = json.dumps(sig_obj, sort_keys=True)
# 若签名未见过,则保留该子图
if sig in seen:
continue
seen.add(sig)
out.append(it)
# 返回去重后的子图列表
return out
1.3 流程图
subgraph_extractor.py 的流程如下:
初始化阶段:
- 创建OrchestratorConfig配置对象
- 生成唯一的运行ID并创建运行目录结构
- 初始化Orchestrator对象
代码提取阶段:
- 运行Orchestrator.run()获取融合后的PyTorch代码
- 检查是否成功找到解决方案
- 加载原始问题代码和融合后的代码
LLM分析阶段:
- 构建包含原始问题和融合代码的提示
- 根据提供商类型选择不同的API调用方式
- 提取并解析LLM返回的JSON格式的子图描述
后处理阶段:
- 验证JSON结构的有效性
- 通过形状签名对子图进行去重和合并
- 保存最终的subgraphs.json文件
- 返回运行目录和JSON文件路径
具体流程图如下:
1.4 与系统其他组件的交互
与 Orchestrator 的交互
调用 Fuser Orchestrator 生成融合代码
python
orch = Orchestrator(...)
summary = orch.run()
fused_code = _load_code_from_tar(Path(summary.artifact_path))
与 Dispatch Kernel Agent 的交互
生成的 subgraphs.json 作为 dispatch_kernel_agent.py 的输入,为每个子图生成 Triton 内核。
与 Composer 的交互
subgraphs.json 作为 compose_end_to_end.py 的输入之一,用于最终的端到端合成
LLM 交互机制
python
# Provider 选择
provider = get_model_provider(model_name)
if provider.name != "openai":
# 直接调用提供商
result = provider.get_response(...)
else:
# 通过 EventAdapter 流式处理
adapter = EventAdapter(...)
result = adapter.stream(
system_prompt=SYSTEM_PROMPT, user_prompt=rp.user, extras=rp.extras)
0x02 Prompt
我们来分析 subgraph_extractor.py 中的 Prompt 构建机制。
2.1 概括
subgraph_extractor.py 使用的是 LLM 提示,专门用于从融合的 PyTorch 代码中提取子图及其形状信息。
这条 prompt 可以一句话概括:"把话说到编译器级别,不给自由发挥留缝隙。"具体特点拆解如下:
- 极端结构化
- 用 JSON Schema 把字段名、类型、取值范围、嵌套层级一次性钉死,连
null能出现在哪都标好。 - 要求"只返回一个 fenced code block",直接把自然语言出口焊死,防止模型"顺便聊聊"。
- 用 JSON Schema 把字段名、类型、取值范围、嵌套层级一次性钉死,连
- 双重代码上下文
- 同时给出"原始 PyTorch 代码"和"融合后的代码",让模型既能看到"改名前的权重"也能看到"融合后的权重",相当于开卷考试但限定只能写标准答案格式。
- 微观操作级说明书
- 对每一类算子(conv、pool、linear、add)都列出必须出现的 key(
kernel_size/stride/padding/groups...),把"该抄哪几行"写成 checklist,模型只要漏一项就能被后处理脚本一键拒收。 - 明确"形状不同就算新子图",避免模型把不同 block 的同名层合并。
- 对每一类算子(conv、pool、linear、add)都列出必须出现的 key(
- 符号系统与优先级双重约束
- 先拿
get_inputs()的 concrete shape 当"硬数",找不到才允许用B/H/W符号,既保证可静态检查,又留一条退路。 - 权重必须同时给
weights_original和weights_fused,逼模型把"融合前后张量对应关系"显式写出来,防止"黑箱合并"。
- 先拿
- Zero-shot 但 Zero-creativity
- 没有 few-shot 示例,却用 12 条"Notes"把边界情况全部穷举,等于告诉模型"你不需要创新,只需要当一台会数数的扫描仪"。
- 最后用"No prose"把寒暄、总结、解释统统 ban 掉,输出直接变成可
json.loads的"机器口粮"。
2.2 Prompt 的基本结构
System Prompt
System Prompt 的内容如下:"Return a single JSON array only."
其要求 LLM 只返回单个 JSON 数组,避免返回额外的文本说明。
User Prompt 详细结构
首先是背景介绍
python
user_lines.append(
"You are given:\n- The original problem (PyTorch).\n- A fused refactor produced by Fuser (PyTorch subgraph modules)."
)
其次是任务描述,从融合代码中识别所有独特的子图,提取精确的形状信息(输入 / 输出 / 权重),为后续的 Triton 内核生成提供结构化输入
python
user_lines.append(
"Task: Identify every unique subgraph by exact shape signature and emit a JSON array matching this schema (and only this schema):"
)
接下来会说明期望的JSON schema
python
user_lines.append(
"{\n"
' "id": <string>,\n'
' "type": <string>,\n'
' "data_layout": <\\"NCHW\\"|\\"NHWC\\"|null>,\n'
' "dtype": <string|null>,\n'
' "ops": [ {"op": <string>, ... op-specific fields ... } ],\n'
' "input_shape": [<int|sym>, ...] // OR \\"inputs\\": [[...], [...]] for multi-input\n'
' "output_shape": [<int|sym>, ...],\n'
' "weights_fused": { <name>: [<int|sym>, ...], ... } | null,\n'
' "weights_original": { <name>: [<int|sym>, ...], ... } | null,\n'
' "count": <int>,\n'
' "where": <string>,\n'
' "source": { "module": <string>, "code": <string> }\n'
"}"
)
然后是详细说明和注意事项
python
user_lines.append("Notes:")
user_lines.append(
"- Treat any shape difference (inputs/outputs/weights) as a distinct subgraph. Count occurrences."
)
user_lines.append(
"- Populate op-specific fields for conv/pool/linear, e.g., kernel_size/stride/padding/groups, bn_fused, output_size, start_dim."
)
user_lines.append(
"- Include both weights_original (pre-fusion params like BN gamma/beta/running stats) and weights_fused (post-fusion conv/bias). Use null if not applicable."
)
user_lines.append(
"- Provide a short \"where\" string (e.g., 'Model.forward stem' or 'layer2.block3.conv')."
)
user_lines.append(
'- Provide "source" with the smallest contiguous code snippet implementing the subgraph.'
)
user_lines.append(
"- Use data_layout and dtype when clear (default conv layout is NCHW)."
)
user_lines.append(
'- For binary ops like residual add, use "inputs": [[...],[...]].'
)
user_lines.append(
"- Prefer concrete integers from get_inputs() shapes in the problem; otherwise use symbols like B, H, W."
)
最后是输入代码示例
python
user_lines.append("PROBLEM_FILE:\n```python")
user_lines.append(problem_code)
user_lines.append("```")
user_lines.append("")
user_lines.append("FUSED_CODE:")
user_lines.append("""```python""")
user_lines.append(fused_code)
user_lines.append("```")
user_lines.append("")
user_lines.append(
"Now return only one fenced JSON block containing the array. No prose."
)
2.3 使用时机
在 extract_subgraphs_to_json 函数中会调用 prompt:
python
# Ask LLM for shapes JSON
system, user = _build_llm_prompt_for_shapes(fused_code, problem_code)
"""
Temporary MUX to support Relay while we migrate to OpenAI Responses API.
Uses EventAdapter for OpenAI, otherwise Provider inferface
"""
provider = get_model_provider(model_name)
if provider.name != "openai":
# 直接调用提供商
result = provider.get_response(...)
else:
# 通过 EventAdapter 流式处理
adapter = EventAdapter(...)
result = adapter.stream(...)
0x03 实现
subgraph_extractor.py 实现了 KernelFalcon 的 "子图识别" 核心能力 ------ 通过 Fuser 生成融合代码→LLM 解析代码提取子图→签名去重合并→输出标准化 JSON,为后续 Triton 算子自动生成提供精准的子图粒度输入。
3.1 特色
- LLM 驱动的智能子图识别:放弃传统的 "静态代码解析 + 规则匹配",改用 LLM 理解 PyTorch 代码语义,精准识别卷积 / 池化 / 线性层等算子的子图边界、形状、权重特征,适配复杂的融合代码场景;
- 鲁棒的签名去重机制:基于 "算子 + 输入 / 输出形状 + 权重结构 + 数据布局 + 数据类型" 构建稳定签名,避免因命名 / 格式差异导致的重复子图,保证子图识别的唯一性;
- 全链路容错设计:针对 LLM 输出格式异常、JSON 解析失败、代码文件缺失等场景,均有明确的容错逻辑和诊断文件输出,提升工业级可用性;
- 标准化输出格式:定义统一的子图 JSON Schema,包含 id、类型、形状、权重、计数等核心字段,为后续算子生成和模型优化提供标准化输入;
- 适配多 LLM 提供商:兼容 OpenAI Responses API 和其他 LLM 提供商的接口,通过适配层统一调用逻辑,保证灵活性。
3.2 代码
完整链路为:
- 从 Fuser 生成的融合代码压缩包中读取核心代码文件;
- 构造精准的 LLM Prompt,引导 LLM 分析原始 PyTorch 问题代码和融合代码,识别所有计算子图并输出 JSON 格式信息;
- 提取 LLM 输出中的 JSON 区块,处理格式异常并做容错;
- 基于 "形状 + 算子 + 权重" 的稳定签名对子图去重,合并重复子图的计数;
- 输出标准化的子图 JSON 文件,为后续 Triton 算子自动生成和模型优化提供精准的子图粒度信息。
提取时,从LLM输出中提取代码时选择最后一个python 代码块。这个思路如下:
- LLM的回复通常遵循"分析--->实现"的模式:先给出代码片段解释思路,最后给出完整实现
- 中间的代码块可能是不完整的示例、伪代码、或者修改前的版本
- 最后一个代码块最可能是LLM的"最终答案"一完整的、经过修正的实现
多代码块的处理优先级:
- 优先找带 python语言标签的最后一个代码块
- 如果没有python标签,回退到最后一个任意语言的代码块
- 提取后用ast.parse()验证是否为合法Python
- 如果解析失败,抛出ValueError
潜在风险如下:
- 如果 LLM 先给完整实现再给修改后的局部片段,会错误地选择不完整的代码
python
# 正则表达式:匹配Markdown中的JSON代码块(```json ... ```)
# 匹配规则:
# - 开头:``` 后可跟空格/tab + json(可选) + 空格/tab + 换行
# - 中间:任意字符(非贪婪匹配)
# - 结尾:换行 + ``` + 空格/tab + 行尾
# - 修饰符:MULTILINE(多行匹配)、IGNORECASE(忽略大小写)
_JSON_BLOCK_RE = re.compile(
r"^```[ \t]*(json)?[ \t]*\n([\s\S]*?)^```[ \t]*$",
re.MULTILINE | re.IGNORECASE,
)
def _extract_json_block(text: str) -> str:
"""Extract the last fenced JSON block or fallback to best-effort slice.
从LLM输出文本中提取JSON区块:
1. 优先提取最后一个标记为json的代码块;
2. 若无则提取最后一个任意代码块;
3. 若仍无则尝试截取第一个[和最后一个]之间的内容;
4. 均失败则返回空字符串
"""
# 查找所有匹配的代码块
matches = list(_JSON_BLOCK_RE.finditer(text))
chosen: re.Match[str] | None = None
# 逆序遍历匹配结果,优先选择标记为json的代码块
for m in reversed(matches):
lang = (m.group(1) or "").strip().lower()
if lang == "json":
chosen = m
break
# 若无json标记的代码块,但有其他代码块,选择最后一个
if chosen is None and matches:
chosen = matches[-1]
# 若找到代码块,返回其中的内容(group(2))
if chosen is not None:
return chosen.group(2)
# 兜底方案:尝试截取第一个[和最后一个]之间的内容(JSON数组)
start = text.find("[")
end = text.rfind("]")
if start != -1 and end != -1 and end > start:
return text[start : end + 1]
# 所有方案均失败,返回空字符串
return ""
def extract_subgraphs_to_json(
problem_path: Path,
model_name: str,
workers: int,
max_iters: int,
llm_timeout_s: int,
run_timeout_s: int,
target_platform: str = "cuda",
) -> tuple[Path, Path]:
"""Run Fuser to produce fused code, then use LLM to emit subgraphs JSON.
Returns (run_dir, json_path).
核心函数:
1. 运行Orchestrator执行Fuser生成融合代码;
2. 调用LLM分析融合代码和原始代码,提取子图信息;
3. 解析LLM输出,去重并合并计数;
4. 输出标准化的子图JSON文件;
返回值:(运行目录路径, JSON文件路径)
"""
# 1. 配置Orchestrator:设置Fuser运行参数
cfg = OrchestratorConfig(
problem_path=problem_path, # KernelBench问题文件路径
model=model_name, # LLM模型名称
workers=workers, # 工作进程数
max_iters=max_iters, # 最大迭代次数
llm_timeout_s=llm_timeout_s, # LLM调用超时时间(秒)
run_timeout_s=run_timeout_s, # 整体运行超时时间(秒)
stream_mode="winner", # 流模式:仅保留最优结果
store_responses=False, # 不存储LLM响应
isolated=False, # 非隔离模式
deny_network=False, # 允许网络访问
enable_reasoning_extras=True, # 启用推理增强功能
target_platform=target_platform, # 目标平台(默认CUDA)
)
# 生成唯一运行ID
run_id = new_run_id()
# 构建运行目录(当前目录/.fuse)
base_dir = Path.cwd() / ".fuse"
base_dir.mkdir(exist_ok=True) # 目录不存在则创建
# 创建运行相关的目录结构
dirs = make_run_dirs(base_dir, run_id)
# 初始化Orchestrator:负责调度Fuser和LLM任务
orch = Orchestrator(
cfg,
run_dir=dirs["run_dir"], # 运行目录
workers_dir=dirs["workers"], # 工作进程目录
orchestrator_dir=dirs["orchestrator"], # 调度器目录
)
# 运行Orchestrator,获取执行摘要
summary = orch.run()
# 检查是否生成有效融合代码:无最优结果则退出
if summary.winner_worker_id is None or not summary.artifact_path:
raise SystemExit(f"No passing fused code: {summary.reason}")
# 2. 加载融合代码:从Orchestrator输出的压缩包中读取code.py
fused_code = _load_code_from_tar(Path(summary.artifact_path))
# 检查融合代码是否有效:为空则退出
if not fused_code.strip():
raise SystemExit("Winner artifact missing code.py or empty")
# 加载原始问题代码
problem_code = problem_path.read_text(encoding="utf-8")
# 3. 构建LLM Prompt并调用LLM
# 构建System和User Prompt
system, user = _build_llm_prompt_for_shapes(fused_code, problem_code)
"""
Temporary MUX to support Relay while we migrate to OpenAI Responses API.
Uses EventAdapter for OpenAI, otherwise Provider inferface
临时适配层:兼容OpenAI Responses API和其他LLM提供商
"""
# 获取LLM提供商实例
provider = get_model_provider(model_name)
if provider.name != "openai":
# 非OpenAI提供商:使用标准Provider接口
messages: list[dict[str, str]] = [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
# 调用LLM获取响应
result = provider.get_response(
model_name,
messages,
max_tokens=16000, # 最大生成Token数
text={"format": {"type": "text"}}, # 输出格式:纯文本
)
# 提取LLM输出文本
output_text = result.content or ""
else:
# OpenAI提供商:使用EventAdapter适配Responses API
jsonl_path = dirs["orchestrator"] / "subgraphs.stream.jsonl"
adapter = EventAdapter(
model=model_name,
store_responses=False,
timeout_s=llm_timeout_s,
jsonl_path=jsonl_path,
)
# 流式调用LLM
result = adapter.stream(
system_prompt=system,
user_prompt=user,
extras={"text": {"format": {"type": "text"}}},
)
# 提取输出文本
output_text = result.get("output_text", "")
# 4. 解析LLM输出:提取JSON区块
raw_json = _extract_json_block(output_text)
try:
# 解析JSON字符串为Python对象
data = json.loads(raw_json)
except Exception as e:
# 解析失败:写入诊断文件并退出
diag = dirs["orchestrator"] / "subgraphs.raw.txt"
diag.write_text(output_text, encoding="utf-8")
raise SystemExit(f"Failed to parse LLM JSON: {e}")
# 检查解析结果是否为列表:非列表则退出
if not isinstance(data, list):
raise SystemExit("LLM output JSON is not a list")
# 5. 子图去重:按签名合并重复子图并累加计数
grouped: dict[str, dict[str, Any]] = {}
def sig_of(it: dict[str, Any]) -> str:
"""内部函数:为子图生成唯一签名(基于算子+形状+权重)"""
# 提取算子列表并标准化(排序每个算子字典的键)
ops = it.get("ops") or []
ops_norm = []
for op in ops:
if isinstance(op, dict):
ops_norm.append(json.loads(json.dumps(op, sort_keys=True)))
else:
ops_norm.append(op)
# 提取输入形状(兼容单输入/多输入格式)
inputs_single = it.get("input_shape")
inputs_multi = it.get("inputs")
# 提取输出形状
outputs = it.get("output_shape")
# 提取权重信息(兼容不同命名方式)
weights = it.get("weights") or {}
weights_fused = it.get("weights_fused") or {}
weights_original = it.get("weights_original") or {}
# 内部函数:标准化权重字典(按键排序保证稳定性)
def sort_w(obj: Any) -> dict[str, Any]:
if isinstance(obj, dict):
return {k: obj[k] for k in sorted(obj.keys())}
return {}
# 标准化权重字典
weights_norm = sort_w(weights)
wf_norm = sort_w(weights_fused)
wo_norm = sort_w(weights_original)
# 构建签名对象:包含算子、形状、权重、布局、数据类型
sig_obj = {
"ops": ops_norm,
"in": inputs_multi if inputs_multi is not None else inputs_single,
"out": outputs,
"w": weights_norm,
"wf": wf_norm,
"wo": wo_norm,
"layout": it.get("data_layout"),
"dtype": it.get("dtype"),
}
# 转为JSON字符串作为唯一签名
return json.dumps(sig_obj, sort_keys=True)
# 遍历所有子图,按签名分组
for it in data:
s = sig_of(it)
if s not in grouped:
# 确保子图ID存在:无则基于签名哈希生成
if not it.get("id"):
it["id"] = f"sg_{hash(s) & 0xFFFFFFFF:08x}"
# 标准化计数:确保为整数,默认1
c = it.get("count")
try:
count_val = int(c) if c is not None else 1
except Exception:
count_val = 1
it["count"] = count_val
# 新增分组
grouped[s] = it
else:
# 合并重复子图:累加计数
try:
grouped[s]["count"] += int(it.get("count") or 1)
except Exception:
grouped[s]["count"] += 1
# 转换为去重后的列表
deduped = list(grouped.values())
# 构建输出JSON文件路径
out_path = dirs["run_dir"] / "subgraphs.json"
# 写入标准化JSON文件(缩进2,保证可读性)
out_path.write_text(json.dumps(deduped, indent=2), encoding="utf-8")
# 返回运行目录和JSON文件路径
return dirs["run_dir"], out_path
def main(argv: list[str] | None = None) -> int:
"""主函数:命令行入口,解析参数并执行子图提取"""
# 加载环境变量(若存在.env文件)
_load_dotenv_if_present()
# 初始化命令行参数解析器
p = argparse.ArgumentParser(
description="Extract unique subgraphs with shapes (JSON)"
)
# 必选参数:KernelBench问题文件的绝对路径
p.add_argument(
"--problem", required=True, help="Absolute path to KernelBench problem file"
)
# 可选参数:LLM模型名称(默认gpt-5)
p.add_argument("--model", default="gpt-5", help="OpenAI model name (Responses API)")
# 可选参数:工作进程数(默认4)
p.add_argument("--workers", type=int, default=4)
# 可选参数:最大迭代次数(默认5)
p.add_argument("--max-iters", type=int, default=5)
# 可选参数:LLM超时时间(秒,默认2400)
p.add_argument("--llm-timeout-s", type=int, default=2400)
# 可选参数:整体运行超时时间(秒,默认2400)
p.add_argument("--run-timeout-s", type=int, default=2400)
# 可选参数:目标平台(默认cuda,可选值由get_platform_choices返回)
p.add_argument(
"--target-platform",
default="cuda",
choices=get_platform_choices(),
help="Target platform",
)
# 解析命令行参数
args = p.parse_args(argv)
try:
# 验证并转换问题文件路径为绝对路径
problem_path = ensure_abs_regular_file(args.problem)
except PathSafetyError as e:
# 路径验证失败:输出错误并返回退出码2
print(str(e), file=sys.stderr)
return 2
# 执行子图提取
run_dir, json_path = extract_subgraphs_to_json(
problem_path=problem_path,
model_name=args.model,
workers=args.workers,
max_iters=args.max_iters,
llm_timeout_s=args.llm_timeout_s,
run_timeout_s=args.run_timeout_s,
target_platform=args.target_platform,
)
# 输出JSON文件路径
print(str(json_path))
# 正常退出:返回0
return 0
0xFF 参考
KernelFalcon: Autonomous GPU Kernel Generation via Deep Agents
Automating GPU Kernel Generation with DeepSeek-R1 and Inference Time Scaling
DeepSeek-R1自写CUDA内核跑分屠榜!斯坦福学霸狂飙GPU编程自动化挑战人类
大模型能否为不同硬件平台生成高性能内核?南大、浙大提出跨平台内核生成评测框架MultiKernelBench
AKG kernel Agent:利用multi-agent进行kernel的生成和迁移
AKG KERNEL AGENT: A MULTI-AGENT FRAMEWORK FOR CROSS-PLATFORM KERNEL SYNTHESIS
RL 猛刷 CUDA 核:CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning
MultiKernelBench: A Multi-Platform Benchmark for Kernel Generation
Ouyang A, Guo S, Arora S, et al. Kernelbench: Can llms write efficient gpu kernels?[J]. arXiv preprint arXiv:2502.10517, 2025.
Baronio, Carlo, et al. "Kevin: Multi-turn rl for generating cuda kernels."arXiv preprint arXiv:2507.11948(2025).
Li, Shangzhan, et al. "Autotriton: Automatic triton programming with reinforcement learning in llms."arXiv preprint arXiv:2507.05687(2025).
Li, Jianling, et al. "Tritonbench: Benchmarking large language model capabilities for generating triton operators."Findings of the Association for Computational Linguistics: ACL 2025. 2025.
Tjarko Lange, Robert, et al. "Towards Robust Agentic CUDA Kernel Benchmarking, Verification, and Optimization."arXiv e-prints(2025): arXiv-2509.
Chen, Wentao, et al. "CUDA-LLM: LLMs Can Write Efficient CUDA Kernels."arXiv preprint arXiv:2506.09092(2025).