PyTorch KernelAgent 源码解读 ---(6)--- Composer
目录
- [PyTorch KernelAgent 源码解读 ---(6)--- Composer](#PyTorch KernelAgent 源码解读 ---(6)--- Composer)
- [0x00 概要](#0x00 概要)
- [0x01 核心功能](#0x01 核心功能)
- [0x02 详细功能](#0x02 详细功能)
- [2.1 使用](#2.1 使用)
- [2.2 错误处理与迭代](#2.2 错误处理与迭代)
- [2.3 _load_kernels_from_summary](#2.3 _load_kernels_from_summary)
- [2.4 Prompt](#2.4 Prompt)
- [2.5 组合生成](#2.5 组合生成)
- [0xFF 参考](#0xFF 参考)
0x00 概要
Fuser/compose_end_to_end.py 是 Fuser 管道中的最后一个关键步骤,它将分散的、针对特定子图优化的 Triton 内核无缝地整合成一个单一的、高性能的端到端 Triton 内核,同时确保其功能与原始 PyTorch 实现的数值等价性。
Composer的架构图如下,其功能概括是一句话:把所有验证通过的子图内核+原问题喂给LLM拼成单文件。也会把错误日志让LLM修,最多重试 max_iters轮,并自动 patch 常见 Triton 陷阱。
0x01 核心功能
1.1 核心作用
Composer的核心功能如下:
- 子图内核整合:将 Fuser 流程中拆分的子图及其对应的、已验证的 Triton 内核,重新组合为一个端到端的 Triton 实现,替代原始 PyTorch 代码的前向传播逻辑。
- LLM 驱动的代码生成:以原始问题代码、子图信息、子图 Triton 内核为输入,通过定制化 Prompt 调用 LLM 生成完整的 Triton 内核代码。
- 功能验证与迭代优化:支持自动验证生成的内核(对比 PyTorch 参考结果),若验证失败则基于错误信息迭代调用 LLM 修正代码,直到通过验证或达到最大迭代次数。
- 约束保障 :通过严格的代码规范(如必须包含
kernel_function、禁止 PyTorch 计算逻辑)和数值校验,确保生成的 Triton 内核可用、正确。
1.2 核心特色
| 特色方向 | 具体说明 |
|---|---|
| 严格的代码约束 | 强制要求生成的代码包含 kernel_function 顶层函数(与原始模型输入一致)、@triton.jit 内核、自测函数(输出 PASS/0 退出码);禁止内核中使用 PyTorch 计算逻辑(仅允许自测时对比)。 |
| 智能错误迭代 | 捕获编译 / 运行错误(stderr/stdout),构建精细化 Prompt 让 LLM 定位并修正问题(如 Triton 常见的 tl.broadcast 误用)。 |
| 自动补丁修复 | 内置 Triton 常见问题的文本补丁(如替换 tl.broadcast(0.0) 为 0.0),减少无意义的 LLM 迭代。 |
| 完整的日志留存 | 保存每一轮的 Prompt、生成的代码、验证结果,便于调试和追溯生成过程。 |
| 数值等价性保障 | 要求自测函数使用 allclose 校验数值(fp32: rtol≤1e-3/atol≤1e-3;fp16/bf16: ≤2e-2),确保 Triton 实现与 PyTorch 结果一致。 |
1.3 流程图
核心逻辑关系图
完整执行流程图
0x02 详细功能
2.1 使用
compose_end_to_end.py 会合成(compose)一个端到端的 Triton 内核,用来解决原始 KernelBench 问题。
compose_end_to_end.py 会将原始问题文件、子图分解信息 + 各张量形状(来自 subgraphs.json)、以及已生成的子图 Triton 内核作为输入,构建提示(prompt)发送给 LLM。然后LLM 会把这些碎片拼成一个语义与原始问题完全一致 的完整内核,最终返回一个 Python 文件(/composed_kernel.py),里面提供:
- 一个或多个使用
@triton.jit装饰的 Triton 内核。 - 一个名为
kernel_function(...)的顶层 Python 包装函数,它接受与原始模型相同的输入张量,并协调 Triton 内核的执行,返回最终输出。 - 一个自测函数(如
test_kernel或run_tests),该函数比较 Triton 实现的结果与原始 PyTorch 问题代码的参考结果,并在成功时打印'PASS'并退出。 - 生成过程的元数据和验证结果会被记录在一个 JSON 格式的摘要文件中(如
composition_summary.json)。
compose_end_to_end.py 用法如下:
bash
python -m Fuser.compose_end_to_end \
--problem /abs/path/to/kernelbench_problem.py \
--subgraphs /abs/path/to/subgraphs.json \
--kernels-summary /abs/path/to/kernels_out/summary.json \
[--model gpt-5] \
[--out-dir ./compose_out] \
[--verify]
可以通过 --verify 标志启用自动验证。在此模式下,每个 LLM 生成的组合尝试(无论是初始的还是经过修正的)都会被传递给 Fuser/runner.py 中的 run_candidate 函数来执行。
验证的成功与否取决于执行是否正常退出(exit code 0)并且输出中包含 'PASS' 字符串或 ALL_TESTS_PASSED 字符串。
2.2 错误处理与迭代
如果 LLM 生成的第一个组合内核无法通过验证(运行 / 编译失败),该脚本会捕获错误信息(stderr,stdout)。它会构建一个新的提示(_build_refinement_prompt),将错误信息作为上下文提供给 LLM,要求其修正代码。这个过程可以重复多次(max_iters 参数控制最大迭代次数),直到生成的内核通过验证或达到最大迭代次数。
2.3 _load_kernels_from_summary
_load_kernels_from_summary 为构建 prompt 提供 「代码素材」(各子图的有效 Triton 内核代码),具体而言,_load_kernels_from_summary从调度阶段生成的内核汇总 JSON 文件中,过滤并加载所有成功生成的有效子图 Triton 内核,校验数据格式与文件有效性,封装为标准化 KernelItem 对象列表,为 LLM 组合内核提供可直接复用的有效代码素材,过滤失败、无效的内核产物,避免无效素材干扰后续组合流程。
其特殊如下:
-
多维度有效内核过滤:依次过滤「非列表格式汇总数据、非字典格式子项、标记为失败的内核、无 ID / 无内核路径的子项、内核文件不存在的子项」,仅保留全量校验通过的成功内核,从源头保证代码素材的有效性;
-
关键字段强制校验:子图 ID(sid)和内核文件路径(kernel_path)为必选字段,缺失任一则直接过滤,确保每个有效内核都能关联到唯一子图且存在实际代码文件;
-
标准化对象封装:将内核的子图 ID、文件路径、代码内容封装为 KernelItem 对象,而非原始字典 / 字符串,提升后续代码处理的可读性与可维护性;
-
无有效内核直接终止:若汇总文件中无任何有效内核,直接抛出 SystemExit 异常终止流程,避免后续流程因无有效素材而无意义执行;
-
兼容调度阶段输出格式:严格适配 dispatch 步骤生成的 summary.json 格式,实现上下游流程的无缝衔接。
2.4 Prompt
以下三个函数是 PyTorch KernelAgent 中LLM 生成端到端 Triton 内核的 Prompt 构建核心模块 ,为 LLM 提供标准化、高指导性、场景适配的精准输入提示,是连接「分散的子图 / 内核 / 问题数据」与「LLM 可理解的生成指令」的关键枢纽。其中
_summarize_subgraphs_for_prompt为基础数据处理函数,负责将子图信息格式化;_build_composition_prompt和_build_refinement_prompt为双 Prompt 构建主函数,分别支撑首次端到端内核组合生成 和基于错误的迭代精修两大核心场景,通过严格的指令约束、完整的上下文信息、针对性的优化指导,确保 LLM 生成符合工程要求、硬件适配、可直接运行的 Triton 内核代码。
核心作用
- _summarize_subgraphs_for_prompt :作为基础支撑函数,将模型分解后的子图信息列表结构化、简洁化转换为文本摘要,提取子图 ID、类型、数据布局、数据类型、输入输出形状、核心算子等关键约束信息,按统一格式拼接为易读字符串,为两个 Prompt 构建函数提供标准化的子图信息描述,让 LLM 快速理解各子图的功能与计算约束。
- _build_composition_prompt :为 LLM 首次生成构建全量上下文、强约束的组合型 Prompt,整合原始 PyTorch 问题代码、子图信息摘要、各子图有效 Triton 内核代码、目标硬件平台配置四大核心信息,明确 LLM 的核心任务是融合子图内核生成端到端 Triton 实现,并制定严格的工程要求、硬件适配规则、Triton 开发规范,指导 LLM 完成从「分散子图」到「一体化内核」的组合与优化。
- _build_refinement_prompt :为 LLM 迭代优化构建错误导向、针对性的精修型 Prompt,在保留核心基础信息的前提下,新增前次生成的错误日志(stdout/stderr)、上一轮失败的代码实现,明确 LLM 的核心任务是基于错误信息定位问题并修正代码,同时追加更具体的错误修复要求,确保精修后的代码能解决编译 / 运行问题,且不违背原有工程规范。
_build_composition_prompt 专属特色
- 四大核心信息全量整合:完整融入「原始问题代码(需求基准)、子图摘要(逻辑约束)、子图内核(代码素材)、平台配置(硬件规则)」,让 LLM 既理解「要做什么(PyTorch 模型功能)」,又知道「有什么可用(子图内核)」,还明确「该怎么适配(硬件平台)」;
- 核心理念明确:融合与融合:明确要求 LLM 优先将多子图融合为尽可能少的 Triton 内核启动,在保证数值语义准确的前提下提升运行效率,贴合 Triton 内核「大粒度融合、减少内核启动开销」的优化核心理念;
- 超详细的硬性要求:制定 10 余项不可违背的硬规则,覆盖代码输出格式、设备张量管理、函数命名与封装、计算接口限制(禁用 PyTorch 计算)、数据格式与算子顺序、数值验证要求、导入与行为限制等,从源头规范代码生成;
- 实用的 Triton 开发指导:提供针对性的 Triton 实现技巧,包括子图融合的形状匹配、常量权重优化、内存访问规范(tl.load/tl.store 带掩码)、网格与分块设计,同时明确列出 Triton 常见开发陷阱及规避方法,降低 LLM 生成错误代码的概率;
- 数值等价性强制要求:明确要求生成的代码必须包含自测试函数,通过与 PyTorch 参考结果的对比验证数值正确性,并制定严格的误差容忍度(fp32/_fp16/bf16 区分),确保 Triton 内核的功能正确性与数值准确性。
_build_refinement_prompt 专属特色
- 错误导向的精准精修:将前次生成的 stderr/stdout 错误日志作为核心参考,让 LLM 聚焦于「定位问题→修复问题」,避免无目的的重生成,大幅提升迭代优化的效率与针对性;
- 保留原有约束,追加修复要求:明确「原有所有要求保持不变」,仅针对错误场景追加更具体的修复规则(如禁止 tl.broadcast 滥用标量),确保精修后的代码不违背原有工程规范,同时解决具体问题;
- 全量代码重生成,拒绝差分输出:强制要求 LLM 返回完整的修正后代码,而非代码差分或修改建议,避免后续代码拼接、整合的额外工作,确保输出可直接替换使用;
- 关键标识强制保留 :明确要求保留顶层
kernel_function函数名、自测试函数及「PASS」打印 / 退出码规则,确保精修后的代码能无缝对接后续的自动化验证流程,无需修改验证逻辑; - 失败代码参考,定位问题更高效:将上一轮的失败代码完整融入 Prompt,让 LLM 能直接对比错误日志与代码实现,快速定位问题所在(如编译错误的行号、运行错误的逻辑),提升修复的准确性。
_summarize_subgraphs_for_prompt 专属特色
_summarize_subgraphs_for_prompt 提供 「逻辑约束」 (各子图的功能、形状、布局等约束信息)。具体而言,_summarize_subgraphs_for_prompt:对模型分解后的子图信息列表进行结构化、简洁化的文本汇总,提取子图 ID、类型、数据布局、数据类型、输入输出形状、核心算子等关键信息,按统一格式拼接为易读的文本字符串,为构建 LLM 组合 Prompt 提供标准化的子图信息描述,让 LLM 快速理解各子图的功能、形状约束与计算要求。
- 关键信息精准提取,剔除冗余:仅保留 LLM 组合 / 精修内核所需的核心约束信息(ID、类型、布局、dtype、输入输出形状、核心算子),剔除无关冗余信息,减少 Prompt 的 Token 占用,提升 LLM 处理效率;
- 合理默认值兜底,保证鲁棒性:对数据布局(默认 NCHW)、数据类型(默认 float32)、算子列表(默认空列表)等易缺失字段设置合理默认值,避免因字段缺失导致 Prompt 构建失败,提升流程的容错性;
- 层级化紧凑格式,易读易解析:采用「一级行标注子图基础属性 + 二级行标注核心算子」的层级格式,既保证信息紧凑(适配 Prompt 长度限制),又结构清晰,让 LLM 能快速关联子图 ID 与对应的功能、约束;
- 算子信息长度控制,避免超限:将算子列表序列化后截取前 400 字符,避免因算子过多导致 Prompt 过长超出 LLM 上下文窗口,同时兼容 JSON 序列化失败的情况(降级为直接字符串截取),保证信息完整性;
- 形状信息灵活适配 :优先使用
inputs字段描述输入,无inputs则降级为input_shape,适配不同子图分解工具的输出格式差异,保证输入输出形状信息的有效传递。
协同工作关系
三个函数形成 「基础数据处理→首次生成 Prompt 构建→迭代精修 Prompt 构建」 的层级支撑关系,为 LLM 生成端到端 Triton 内核提供全流程的 Prompt 支撑:
- 基础层 :
_summarize_subgraphs_for_prompt对原始子图信息做统一格式化处理,生成标准化子图摘要,为上层两个 Prompt 构建函数提供一致、易解析的子图约束信息,实现数据的一次处理、多次复用; - 首次生成层 :
_build_composition_prompt基于标准化子图摘要,整合问题代码、子图内核、平台配置,构建全量约束的组合 Prompt,指导 LLM 完成首次端到端 Triton 内核生成,是「从无到有」的核心指导; - 迭代精修层 :
_build_refinement_prompt复用标准化子图摘要与核心基础信息,新增错误日志和前次失败代码,构建错误导向的精修 Prompt,指导 LLM 完成「从错到对」的迭代优化,是提升代码有效性的关键; - 闭环支撑:三个函数的输出共同支撑 KernelAgent 的「生成→验证→精修→再验证」闭环流程,确保每一轮 LLM 生成都有明确的指导、充足的依据、严格的约束,大幅提升端到端 Triton 内核的生成成功率与工程质量。
代码
python
def _build_composition_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
target_platform: PlatformConfig,
) -> str:
"""Create a single user message to instruct composition by the LLM.
构建初始合成Prompt:
为LLM生成结构化指令,引导其基于子图和参考内核,合成端到端的Triton算子代码
参数:
- problem_code: 原始KernelBench问题代码(PyTorch)
- subgraphs: 子图信息列表(JSON格式)
- kernel_items: 参考内核列表(包含子图ID和对应的Triton代码)
- target_platform: 目标平台配置(如CUDA/XPU)
返回值:完整的LLM用户指令字符串
"""
# 第一步:生成子图摘要(压缩子图信息,控制Token消耗,便于LLM快速理解核心特征)
sg_summary = _summarize_subgraphs_for_prompt(subgraphs)
# 第二步:构建参考内核代码区块(仅保留核心代码,避免Token溢出)
# 注释说明:暂时保留完整文件内容,调用方可根据模型窗口限制进一步裁剪
# 初始化内核区块的文本片段列表
kernels_section_parts: list[str] = []
# 遍历每个参考内核项
for ki in kernel_items:
# 为每个子图内核构建带格式的代码片段(Markdown Python代码块)
kernels_section_parts.append(
f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n"
)
# 拼接所有内核片段,形成完整的参考内核区块
kernels_section = "\n".join(kernels_section_parts)
# 第三步:获取平台专属的指导规则(如CUDA的内存访问规则、XPU的编译要求)
platform_guidance = target_platform.guidance_block
# 第四步:构建核心指导语(包含任务背景、平台信息、硬性要求、实现技巧)
# 使用textwrap.dedent去除缩进,保证Prompt格式整洁
guidance = textwrap.dedent(
f"""
You are given:
- The original problem file (PyTorch module and helpers).
- A decomposition of the model into fusable subgraphs with exact shapes.
- Working Triton kernels generated for some subgraphs.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
{platform_guidance}
Task:
- Compose an end-to-end Triton implementation that matches the original
model's forward pass for the provided shapes. You may inline, adapt,
or reuse the given subgraph kernels. Prefer fusing into as few kernel
launches as possible while preserving exact numerical semantics.
Hard requirements:
- Return ONE complete Python file only, fenced as a single ```python block.
- Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
- CPU is acceptable only for metadata, scalars, and export serialization---avoid `.cpu()` or `.to('cpu')` on compute tensors.
- Provide at least one @triton.jit kernel and a top-level Python wrapper
named kernel_function(...). This wrapper must accept the same primary
input tensor(s) as the model and any required weights/biases with shapes
implied by the problem; it should orchestrate Triton kernel(s) and
return the final output tensor.
- No PyTorch math path: kernel_function MUST compute the final outputs
using your Triton kernels only. Do NOT implement or fall back to
torch.nn / torch.nn.functional / torch.* ops
sigmoid, etc.) for producing the final result. Using PyTorch for
reference comparisons is allowed only inside the self-test.
- Use the data layout and dtype semantics indicated by subgraphs, defaulting
to NCHW + float32 if unspecified. Respect stride/padding/dilation/groups,
and exact op order.
- Numerical equivalence: include a self-test (test_kernel or run_tests)
that compares your Triton-based result to a PyTorch reference computed
from the original problem code below (use get_init_inputs() and
get_inputs() if present to instantiate the Model). The test must print
'PASS' on success and exit with code 0. Use allclose with rtol<=1e-3,
atol<=1e-3 for fp32; for fp16/bf16 allow up to 2e-2.
- No imports beyond torch, triton, triton.language as tl, and stdlib. No I/O.
- Do NOT monkey-patch PyTorch device functions or torch.cuda.is_available()
- Do NOT manipulate TRITON_BACKENDS environment variable
- Do NOT disable or mock XPU/CUDA drivers
Implementation tips:
- If merging multiple subgraphs, ensure intermediate tensor shapes match.
- Hoist constant weights or parameters to avoid reloading per block.
- Use tl.load/tl.store with masks for boundary conditions.
- Favor coalesced memory access; tile by blocks; compute grid from shape.
- Common Triton pitfalls to avoid:
* Do NOT call tl.broadcast on Python scalars; tl.maximum(x, 0.0) works.
* Prefer scalar constants directly in elementwise ops (no explicit broadcast needed).
* Keep BLOCK_SIZE power-of-two; mask stores at tail.
"""
).strip() # 去除首尾空白字符
# 第五步:拼接完整的用户指令(按逻辑组织各部分内容)
user_lines: list[str] = []
user_lines.append(guidance) # 核心指导语
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPHS (summary):") # 子图摘要标题
user_lines.append(sg_summary) # 子图摘要内容
user_lines.append("") # 空行分隔
user_lines.append("ORIGINAL PROBLEM FILE:") # 原始问题代码标题
user_lines.append("```python") # Python代码块开始标记
user_lines.append(problem_code) # 原始问题代码内容
user_lines.append("```") # Python代码块结束标记
user_lines.append("") # 空行分隔
user_lines.append("SUBGRAPH KERNELS (reference implementations):") # 参考内核标题
user_lines.append(kernels_section) # 参考内核代码内容
user_lines.append("") # 空行分隔
# 最终要求:仅返回一个包含完整代码的Python代码块
user_lines.append(
"Return only one fenced Python code block with your final composed implementation."
)
# 拼接所有行,形成完整的Prompt
return "\n".join(user_lines)
def _build_refinement_prompt(
problem_code: str,
subgraphs: list[dict[str, Any]],
kernel_items: list[KernelItem],
previous_code: str,
error_info: dict[str, str],
target_platform: PlatformConfig,
) -> str:
"""Prompt the LLM to refine the previously produced code based on errors.
构建迭代优化Prompt:
基于上一轮代码的错误信息,引导LLM修复Triton算子代码中的编译/运行错误
参数:
- previous_code: 上一轮生成的错误代码
- error_info: 错误信息字典(包含stderr_tail/stdout_tail)
其他参数同_build_composition_prompt
返回值:针对性的优化指令字符串
"""
# 提取错误日志尾部(最后2000字符,聚焦核心错误)
err_tail = error_info.get("stderr_tail", "")
# 提取标准输出尾部(辅助分析错误)
out_tail = error_info.get("stdout_tail", "")
# 构建优化指导语(聚焦错误修复,保留原有核心要求)
guidance = textwrap.dedent(
f"""
You previously produced a composed Triton implementation, but it failed
to run/compile. Analyze the ERROR_CONTEXT below and re-emit the entire
corrected single-file implementation as one ```python block.
TARGET PLATFORM: {target_platform.name}
DEVICE STRING: {target_platform.device_string}
Requirements remain the same. Additionally:
- Fix any Triton compilation/runtime errors. For scalar constants in
elementwise ops (e.g., ReLU), do not use tl.broadcast. Use direct
scalars like 0.0 in tl.maximum(x, 0.0).
- Keep function name kernel_function(...) unchanged and retain the
self-test that prints PASS on success and exits 0.
- Do NOT reintroduce any PyTorch math path in kernel_function. The final
outputs must be computed via your Triton kernels only (no fallback to
torch.nn / torch.nn.functional ops).
- Return the complete corrected file; do not send diffs.
"""
).strip()
# 拼接完整的优化指令(按"指导语→错误信息→原始代码→子图摘要→上一轮代码"组织)
lines: list[str] = []
lines.append(guidance) # 优化指导语
lines.append("") # 空行分隔
# 添加标准错误日志(核心错误信息)
lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```")
# 若标准输出非空,添加标准输出日志(辅助分析)
if out_tail.strip():
lines.append("STDOUT tail:\n```\n" + out_tail + "\n```")
lines.append("") # 空行分隔
# 添加原始问题代码(保证上下文完整)
lines.append("ORIGINAL PROBLEM FILE:\n```python\n" + problem_code + "\n```")
lines.append("") # 空行分隔
# 添加子图摘要(避免LLM遗忘核心特征)
lines.append("SUBGRAPHS (summary):\n" + _summarize_subgraphs_for_prompt(subgraphs))
lines.append("") # 空行分隔
# 添加上一轮错误代码(让LLM对比分析问题)
lines.append("PREVIOUS_ATTEMPT:\n```python\n" + previous_code + "\n```")
lines.append("") # 空行分隔
# 最终要求:仅返回修正后的完整Python代码块
lines.append(
"Return only one fenced Python code block with the corrected implementation."
)
# 拼接所有行,形成优化Prompt
return "\n".join(lines)
def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str:
"""
生成子图摘要:将复杂的子图JSON信息压缩为简洁的文本格式
核心目标:控制Token消耗,同时保留子图的ID、类型、布局、 dtype、形状、算子等核心特征
"""
# 初始化摘要行列表
lines: list[str] = []
# 遍历每个子图
for it in subgraphs:
# 提取子图ID(默认unknown)
sid = str(it.get("id", "unknown"))
# 提取子图类型(如conv2d、linear)
typ = str(it.get("type", ""))
# 提取数据布局(默认NCHW)
layout = it.get("data_layout") or "NCHW"
# 提取数据类型(默认float32)
dtype = it.get("dtype") or "float32"
# 提取输入形状(兼容多输入/单输入格式)
inputs = it.get("inputs")
in_shape = it.get("input_shape")
# 提取输出形状
out_shape = it.get("output_shape")
# 提取算子列表
ops = it.get("ops") or []
# 构建形状描述行(优先多输入格式,其次单输入格式)
shapes_line = (
f"inputs={inputs if inputs is not None else in_shape}, output={out_shape}"
)
# 构建子图核心信息行
lines.append(
f"- ID={sid} type={typ} layout={layout} dtype={dtype} {shapes_line}"
)
# 构建算子摘要(限制长度为400字符,避免Token溢出)
try:
# 尝试转为JSON字符串(结构化)
ops_short = json.dumps(ops)[:400]
except Exception:
# 转换失败则直接转为字符串
ops_short = str(ops)[:400]
# 添加算子摘要行(缩进2空格,提升可读性)
lines.append(f" ops={ops_short}")
# 拼接所有行,形成最终的子图摘要
return "\n".join(lines)
2.5 组合生成
以下两个函数是 PyTorch KernelAgent 中端到端 Triton 内核组合生成的核心执行模块。
_auto_patch_common_triton_issues作为轻量自动化代码修复工具 ,为 LLM 生成的 Triton 代码做前置避坑补丁;_auto_patch_common_triton_issues:对 LLM 生成的 Triton 代码做无侵入式文本级自动化补丁,专门修复 Triton 开发中高频、易犯的基础错误,在代码运行 / 验证前提前规避低级错误,减少无效迭代,提升内核生成效率。compose作为顶层入口与流程调度核心 ,串联起「Prompt 构建、LLM 生成、代码补丁、真机验证、迭代精修、产物归档」的全流程,是 KernelAgent 从「分散子图 / 内核素材」到「可运行端到端 Triton 内核」的端到端驱动核心,根据迭代轮次动态构建「组合 / 精修 Prompt」,调用 LLM 生成代码,对生成代码做自动化补丁,支持真机验证驱动的多轮迭代精修,最终输出硬件适配、可直接部署的一体化 Triton 内核及标准化结果汇总。
_auto_patch_common_triton_issues 函数
核心作用
针对 LLM 生成 Triton 代码时易出现的两类高频低级错误做自动化文本修复,在代码运行前提前避坑,减少因基础错误导致的验证失败,提升迭代效率:
- 修复标量滥用
tl.broadcast问题(如tl.broadcast(0.0, ...)替换为直接标量0.0),贴合 Triton 标量运算规范; - 剥离目标平台不兼容的 CUDA 专属 hack 代码,避免跨平台运行报错,保证代码与目标硬件匹配。
核心特色
- 保守式文本补丁,无逻辑侵入 :仅做字符串级简单替换 / 过滤,不修改代码的核心计算逻辑,避免补丁引入新的逻辑错误,保证修复的安全性;
- 精准针对 Triton 高频坑,修复效率高:补丁规则聚焦 LLM 最易犯的 Triton 基础错误,不处理复杂逻辑问题,修复速度快、命中率高;
- 平台定制化补丁,适配性强 :根据目标平台配置的
cuda_hacks_to_strip动态剥离不兼容代码,支持多平台扩展,无需硬编码平台规则; - 支持整段代码块过滤,处理更全面 :不仅过滤单行 CUDA hack 代码,还能识别
_fake_torch_device这类函数定义,实现整段代码块的跳过过滤,处理更彻底; - 返回修改状态,流程可感知 :以
(patched_code, changed)元组返回结果,明确告知调用方是否对代码做了修改,便于后续日志记录与流程监控; - 鲁棒性强,无代码侵入风险:所有补丁操作均基于原始代码做副本修改,不改变原代码内容,避免副作用。
compose 函数
核心作用
作为内核组合生成的全流程调度核心,完成从输入加载到最终产物输出的所有步骤:
- 初始化环境(创建目录、加载 LLM 服务商、目标平台配置),校验并加载三大核心输入(问题代码、子图信息、有效子图内核);
- 按轮次动态构建 Prompt(首次用组合 Prompt,后续用错误驱动的精修 Prompt),调用 LLM 生成 Triton 代码;
- 对生成代码做自动化补丁,归档中间产物,支持真机验证驱动的多轮迭代精修;
- 最终输出端到端 Triton 内核文件,生成标准化结果汇总并持久化,为上层模块提供统一的调用接口。
核心特色
- 双 Prompt 动态切换,迭代精修更精准 :根据迭代轮次自动切换 Prompt 类型 ------ 首次生成用全量约束的组合 Prompt (从无到有构建内核),后续迭代用错误驱动的精修 Prompt(基于 stderr/stdout 日志针对性修正),避免无目的重生成,大幅提升迭代有效性;
- 验证驱动的闭环迭代,成功率高 :开启验证后,每轮生成的代码均通过
run_candidate做真机运行验证,失败则提取错误日志反馈给 LLM 精修,形成 「生成→补丁→验证→报错→精修」 的闭环,直至验证通过或达到最大迭代轮次,大幅提升最终内核的可用性; - 灵活的验证开关,兼顾效率与有效性 :支持
verify开关控制是否开启真机验证 ------ 开启时保证产物有效性,关闭时仅做一次生成即终止,提升轻量使用场景的效率; - 全流程产物结构化归档,可追溯性强:为每轮尝试创建独立文件(Prompt 文本、生成代码),验证日志、最终内核、结果汇总均按目录结构化存储,所有环节可追溯、可复现,便于问题排查;
- 严格的输入校验,流程鲁棒性高:对输入的子图信息格式、LLM 服务商可用性做严格校验,格式错误直接终止流程并抛出明确提示,避免无效执行;
- 标准化结果输出,工程友好:返回包含「成功状态、内核路径、LLM 用量、迭代轮次、验证结果、目标平台」的标准化字典,同时生成 JSON 格式的汇总文件,便于上层模块调用、结果统计与集成;
- 迭代轮次可配置,适配不同场景 :通过
max_iters控制最大迭代轮次(默认 5 轮),可根据需求调整,平衡生成效率与成功率; - LLM 用量追踪,成本可控 :记录每轮 LLM 调用的用量信息(
last_usage)并纳入最终结果,支持 LLM 生成成本的统计与管控; - 无缝衔接上下游模块 :严格适配上游子图分解、内核生成的输出格式,下游可直接调用生成的
composed_kernel.py,实现全流程的无缝衔接; - 超时与环境控制,验证更可靠 :调用
run_candidate时设置 2400 秒超时,同时支持隔离运行、禁止网络等配置,避免验证过程因环境问题卡死,提升验证的稳定性。
协同工作关系
两个函数形成 「前置补丁防护 + 全流程调度执行」 的紧密协同关系,是 KernelAgent 组合生成端到端 Triton 内核的核心支撑:
- _auto_patch_common_triton_issues 作为
compose函数的前置子步骤 ,在 LLM 生成代码后、真机验证前执行,提前修复 Triton 基础错误,减少因低级问题导致的验证失败,为compose的迭代流程「减负」,提升整体执行效率; - compose 作为顶层调度者 ,负责调用
_auto_patch_common_triton_issues,并为其传递目标平台配置,让补丁操作贴合硬件要求,同时记录补丁后的代码并推进后续验证、迭代流程; - 二者结合实现「LLM 生成的灵活性 + 自动化补丁的避坑能力 + 验证驱动的闭环迭代」,既发挥 LLM 融合子图、构建复杂内核的能力,又通过自动化补丁和多轮精修规避低级错误、解决运行问题,最终输出高质量、可运行的端到端 Triton 内核。
代码
python
def _auto_patch_common_triton_issues(
code: str, target_platform: PlatformConfig
) -> tuple[str, bool]:
"""Apply tiny safe textual patches for known Triton pitfalls.
- Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants.
Returns (patched_code, changed).
自动修复Triton算子代码中的常见问题:
- 核心修复点:将tl.broadcast(0.0/1.0/0/1, ...)替换为标量常量(避免Triton广播操作的性能/语法问题)
- 返回值:(修复后的代码, 是否发生修改)
"""
# 初始化修复后的代码为原始代码
patched = code
# 标记是否发生修改(默认未修改)
changed = False
# 修复规则:采用保守的简单启发式规则,仅处理无歧义的常见问题
patterns = [
# 规则1:tl.broadcast(0.0 → 替换为0.0(移除不必要的广播操作)
("tl.broadcast(0.0", "0.0"),
# 规则2:tl.broadcast(1.0 → 替换为1.0
("tl.broadcast(1.0", "1.0"),
# 规则3:tl.broadcast(0, → 替换为0.0(统一数值类型为浮点数)
("tl.broadcast(0,", "0.0"),
# 规则4:tl.broadcast(1, → 替换为1.0
("tl.broadcast(1,", "1.0"),
]
# 遍历所有修复规则
for old, new in patterns:
# 若原始代码包含待修复的模式
if old in patched:
# 执行文本替换
patched = patched.replace(old, new)
# 标记为已修改
changed = True
# 移除CUDA相关的冗余hack代码(平台适配)
# 获取当前目标平台需要剥离的CUDA hack模式列表
cuda_hacks = target_platform.cuda_hacks_to_strip
if cuda_hacks:
# 将代码按行拆分
lines = patched.split("\n")
# 存储过滤后的代码行
filtered_lines = []
# 标记是否需要跳过直到空行(用于移除整个函数块)
skip_until_blank = False
# 逐行处理代码
for line in lines:
# 若处于跳过模式:跳过当前行,直到遇到空行
if skip_until_blank:
if line.strip() == "":
# 遇到空行,退出跳过模式
skip_until_blank = False
continue
# 检查当前行是否包含需要剥离的CUDA hack模式
if any(hack in line for hack in cuda_hacks):
# 标记为已修改
changed = True
# 特殊处理:若为_fake_torch_device函数定义,需跳过整个函数块(直到空行)
if "def _fake_torch_device" in line:
skip_until_blank = True
# 跳过当前行(移除hack代码)
continue
# 保留当前行
filtered_lines.append(line)
# 重新拼接过滤后的代码行
patched = "\n".join(filtered_lines)
# 返回修复后的代码和是否修改的标记
return patched, changed
def compose(
problem_path: Path,
subgraphs_path: Path,
kernels_summary_path: Path,
out_dir: Path,
model_name: str,
verify: bool = False,
max_iters: int = 5,
target_platform: str = "cuda",
) -> dict[str, Any]:
"""
核心函数:基于子图信息和内核摘要,合成/优化Triton算子代码
参数说明:
- problem_path: KernelBench问题文件路径
- subgraphs_path: 子图信息JSON文件路径(由之前的子图提取模块生成)
- kernels_summary_path: 内核摘要文件路径
- out_dir: 输出目录路径
- model_name: LLM模型名称
- verify: 是否验证生成的代码(默认False)
- max_iters: 最大迭代次数(默认5)
- target_platform: 目标平台(默认cuda)
返回值:包含合成结果的字典(成功状态、代码路径、迭代次数等)
"""
# 前置检查:确保LLM提供商模块可导入
if get_model_provider is None:
raise SystemExit(
"KernelAgent providers unavailable; ensure package import and dependencies"
)
# 创建输出目录(递归创建父目录,已存在则忽略)
out_dir.mkdir(parents=True, exist_ok=True)
# 获取LLM提供商实例(用于调用LLM生成代码)
provider = get_model_provider(model_name)
# 初始化目标平台配置(加载平台相关的修复规则、hack列表等)
platform = get_platform(target_platform)
# 加载输入文件:
# 1. 加载KernelBench问题代码(原始PyTorch代码)
problem_code = _read_text(problem_path)
# 2. 加载子图信息JSON(由子图提取模块生成)
subgraphs = json.loads(_read_text(subgraphs_path))
# 检查子图信息是否为列表:非列表则退出
if not isinstance(subgraphs, list):
raise SystemExit("subgraphs.json must be a JSON array")
# 3. 从内核摘要文件加载内核信息
kernels = _load_kernels_from_summary(kernels_summary_path)
# 创建迭代尝试目录(存储每一轮的生成代码)
attempts_dir = out_dir / "attempts"
attempts_dir.mkdir(parents=True, exist_ok=True)
# 初始化变量:
# - last_usage: 最后一次LLM调用的token使用信息
# - last_code: 上一轮生成的代码
# - verify_info: 验证信息字典
last_usage = None
last_code = None
verify_info: dict[str, Any] = {}
# 多轮迭代生成/优化代码(最多max_iters轮)
for i in range(1, max_iters + 1):
# 构建LLM Prompt:
# 第一轮/上一轮代码为空 → 构建初始合成Prompt
if i == 1 or last_code is None:
prompt = _build_composition_prompt(
problem_code, subgraphs, kernels, target_platform=platform
)
else:
# 非第一轮 → 基于上一轮的错误信息构建优化Prompt
# 提取验证日志的最后2000字符(便于LLM定位问题)
stderr_tail = ""
stdout_tail = ""
try:
# 读取标准错误日志尾部
if verify_info.get("stderr_path"):
with open(
verify_info["stderr_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stderr_tail = f.read()[-2000:]
# 读取标准输出日志尾部
if verify_info.get("stdout_path"):
with open(
verify_info["stdout_path"],
"r",
encoding="utf-8",
errors="ignore",
) as f:
stdout_tail = f.read()[-2000:]
except Exception:
# 日志读取失败则忽略(避免中断迭代)
pass
# 构建优化Prompt(包含上一轮代码和错误信息)
prompt = _build_refinement_prompt(
problem_code,
subgraphs,
kernels,
previous_code=last_code,
error_info={"stderr_tail": stderr_tail, "stdout_tail": stdout_tail},
target_platform=platform,
)
# 保存当前轮次的Prompt(便于追溯和调试)
(attempts_dir / f"attempt_{i}.prompt.txt").write_text(prompt, encoding="utf-8")
# 调用LLM生成代码:
# - 消息格式:仅包含user角色的Prompt
# - 最大生成Token数:50000(适配长代码生成)
response = provider.get_response(
model_name, [{"role": "user", "content": prompt}], max_tokens=50000
)
# 记录最后一次LLM调用的token使用信息
last_usage = response.usage
# 提取LLM输出的原始文本
raw_text = response.content or ""
# 从LLM输出中提取Python代码(剥离无关文本)
extracted = extract_single_python_file(raw_text)
code = extracted.code
# 自动修复Triton常见问题(文本补丁)
code, changed = _auto_patch_common_triton_issues(code, platform)
# 保存当前轮次的生成代码
(attempts_dir / f"attempt_{i}.py").write_text(code, encoding="utf-8")
# 更新last_code为当前轮次的代码
last_code = code
# 验证当前轮次的代码(若开启验证)
if verify:
# 运行候选代码验证:
# - artifacts_code_path: 当前轮次的代码路径
# - run_root: 验证运行目录
# - timeout_s: 验证超时时间(2400秒)
# - isolated: 非隔离模式
# - deny_network: 允许网络访问
rr = run_candidate(
artifacts_code_path=attempts_dir / f"attempt_{i}.py",
run_root=out_dir / "runs",
timeout_s=2400,
isolated=False,
deny_network=False,
)
# 记录验证信息
verify_info = {
"verify_rc": rr.rc, # 验证退出码
"verify_passed": rr.passed, # 是否验证通过
"verify_reason": rr.reason, # 验证结果原因
"validator": rr.validator_used, # 使用的验证器
"stdout_path": str(rr.stdout_path), # 标准输出日志路径
"stderr_path": str(rr.stderr_path), # 标准错误日志路径
}
# 若验证通过,终止迭代(无需继续优化)
if rr.passed:
break
else:
# 若未开启验证,仅执行第一轮后终止
break
# 保存最终合成的算子代码(取最后一轮的代码)
composed_path = out_dir / "composed_kernel.py"
composed_path.write_text(last_code or "", encoding="utf-8")
# 构建结果字典:包含核心合成信息
result: dict[str, Any] = {
# 成功状态:验证模式下为是否通过验证,非验证模式下默认成功
"success": bool(verify_info.get("verify_passed", not verify)),
# 最终合成代码的绝对路径
"composed_path": str(composed_path.resolve()),
# 使用的LLM模型名称
"model": model_name,
# LLM token使用信息
"usage": last_usage,
# 实际迭代次数
"rounds": i,
# 目标平台
"target_platform": target_platform,
}
# 合并验证信息到结果字典
result.update(verify_info)
# 保存合成摘要(结构化JSON文件)
(out_dir / "composition_summary.json").write_text(
json.dumps(result, indent=2), encoding="utf-8"
)
# 返回结果字典
return result