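PyTorch KernelAgent Source Code Walkthrough (5): Dispatcher
Contents
- [PyTorch KernelAgent Source Code Walkthrough (5): Dispatcher](#PyTorch KernelAgent Source Code Walkthrough (5): Dispatcher)
- [0x00 Overview](#0x00 Overview)
- [0x01 The role of the Dispatch Kernel Agent](#0x01 The role of the Dispatch Kernel Agent)
- [0x02 TritonKernelAgent](#0x02 TritonKernelAgent)
- [0x03 prompt](#0x03 prompt)
- [0x04 The role of the Worker Manager](#0x04 The role of the Worker Manager)
- [0x05 VerificationWorker](#0x05 VerificationWorker)
- [0xFF References](#0xFF References)
0x00 Overview
dispatch_kernel_agent.py is the dispatch component of the KernelAgent system. It converts the subgraphs (in JSON format) produced by subgraph_extractor.py into concrete Triton kernel generation tasks, and dispatches TritonKernelAgent instances to generate and verify those kernels.
The Dispatcher architecture is shown in the diagram below. In short, it reads subgraphs.json, turns each subgraph into a precise Triton generation spec that includes reference code, hands each spec to an independent TritonKernelAgent instance (optionally in parallel), and produces one kernel.py per subgraph plus a summary.json.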
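0x01 The role of the Dispatch Kernel Agent
dispatch_kernel_agent.py acts as a bridge in the KernelAgent system: it turns the high-level subgraph decomposition into concrete Triton kernel generation tasks. Its core value lies in:
- Automated task assignment: breaking a complex fused model into independent subgraph tasks
- Standardized problem descriptions: generating, for each subgraph, a description suited to Triton kernel generation
- Parallel processing: supporting concurrent generation of Triton kernels for multiple subgraphs
- Result consolidation: collecting and organizing the kernel generation results of all subgraphs in preparation for the later composition stage
The position of dispatch_kernel_agent.py in the pipeline is as follows: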
```text
original model → orchestrator.py (fuse) → subgraph_extractor.py (extract) → dispatch_kernel_agent.py → compose_end_to_end.py
```
1.1 Overall functionality
Concurrency mechanism
The run function sends subgraphs to KernelAgent so that Triton kernels can be generated in parallel.
```python
def run(
    subgraphs_path: Path,
    out_dir: Path,
    agent_model: str | None = None,
    jobs: int = 1,
    target_platform: str = "cuda",
    max_iters: int = 10,
) -> Path:
    """Dispatch subgraphs to KernelAgent with optional parallelism.

    jobs controls the number of concurrent subgraph generations. Default=1
    preserves previous behavior and avoids GPU/LLM contention.
    """
    # ... (code omitted)
    # Submit tasks with bounded concurrency
    jobs = max(1, int(jobs or 1))
    ordered_inputs: list[tuple[int, dict[str, Any]]] = list(enumerate(items, start=1))
    results: dict[int, dict[str, Any]] = {}
    if jobs == 1:  # serial path
        for pair in ordered_inputs:
            i, res = _handle_one(pair)
            results[i] = res
    else:  # concurrent path
        with _futures.ThreadPoolExecutor(max_workers=jobs) as ex:
            future_map = {
                ex.submit(_handle_one, pair): pair[0] for pair in ordered_inputs
            }
            for fut in _futures.as_completed(future_map):
                i, res = fut.result()
                results[i] = res
    # ... (code omitted)
```
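The serial and concurrent paths can be tried out in isolation. The following is a minimal sketch rather than the project's code: `handle_one` and `dispatch` are hypothetical stand-ins, and the fake result dict exists only to show how keying results by index decouples completion order from input order.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def handle_one(pair):
    """Hypothetical stand-in for _handle_one: returns (index, result)."""
    idx, item = pair
    return idx, {"id": item["id"], "success": True}


def dispatch(items, jobs=1):
    """Run handle_one over items, serially or with a bounded thread pool."""
    jobs = max(1, int(jobs or 1))
    ordered_inputs = list(enumerate(items, start=1))
    results = {}
    if jobs == 1:
        for pair in ordered_inputs:
            i, res = handle_one(pair)
            results[i] = res
    else:
        with ThreadPoolExecutor(max_workers=jobs) as ex:
            futures = [ex.submit(handle_one, pair) for pair in ordered_inputs]
            for fut in as_completed(futures):
                i, res = fut.result()
                results[i] = res
    # Completion order is arbitrary; sorting by index restores input order.
    return [results[i] for i in sorted(results)]


summary = dispatch([{"id": "sg_a"}, {"id": "sg_b"}, {"id": "sg_c"}], jobs=2)
```

Because each result carries its own index, the summary can be written in input order no matter which subgraph finishes first.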
Task handler
The _handle_one function calls KernelAgent to generate the kernel for one subgraph.
```python
def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, Any]]:
    idx, item = idx_item
    sid = str(item.get("id", f"subgraph_{idx}"))
    pdesc = _synthesize_problem_description(item, target_platform=platform)
    sg_dir = out_dir / sid
    sg_dir.mkdir(parents=True, exist_ok=True)
    (sg_dir / "problem.txt").write_text(pdesc, encoding="utf-8")
    # Pin KernelAgent concurrency defaults: 4 workers, max_iters rounds
    # Create an independent TritonKernelAgent instance for each subgraph
    local_agent = TritonKernelAgent(
        num_workers=4,
        max_rounds=max_iters,
        model_name=agent_model,
        target_platform=platform,
    )
    # Generate the kernel
    try:
        result = local_agent.generate_kernel(
            problem_description=pdesc, test_code=None
        )
    # ... (code omitted)
```
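_handle_one calls _synthesize_problem_description to build the problem description.
Problem description synthesis
This step generates a problem description containing the subgraph metadata, shapes, and the ordered list of operations.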
```python
def _synthesize_problem_description(
    item: dict[str, Any], target_platform: PlatformConfig
) -> str:
    id_ = str(item.get("id", "unknown"))
    type_ = str(item.get("type", ""))
    layout = item.get("data_layout") or "NCHW"
    dtype = item.get("dtype") or "float32"
    input_shape = item.get("input_shape")
    output_shape = item.get("output_shape")
    inputs_multi = item.get("inputs")
    weights_fused = item.get("weights_fused")
    weights_orig = item.get("weights_original")
    source = item.get("source") or {}
    ref_code, _ = _build_reference_code(item)
    # Get device string for the platform
    header = textwrap.dedent(
        f"""
        Implement a Triton kernel that computes the following subgraph end-to-end.
        Subgraph ID: {id_}
        Type: {type_}
        Data layout: {layout}
        DType: {dtype}
        Target Platform: {target_platform.name}
        Device String: {target_platform.device_string}
        Shapes:
        - input: {_fmt_shape(inputs_multi[0]) if isinstance(inputs_multi, list) else _fmt_shape(input_shape)}
        {("- input2: " + _fmt_shape(inputs_multi[1])) if isinstance(inputs_multi, list) and len(inputs_multi) > 1 else ""}
        - output: {_fmt_shape(output_shape)}
        Weights (fused): {json.dumps(weights_fused, indent=2) if isinstance(weights_fused, dict) else "null"}
        Weights (original): {json.dumps(weights_orig, indent=2) if isinstance(weights_orig, dict) else "null"}
        Operations in order (with parameters):
        {json.dumps(item.get("ops", []), indent=2)}
        Requirements:
        - Return a complete Python file with a @triton.jit kernel and a wrapper function named kernel_function(...).
        - kernel_function must accept input tensor(s) and any required weights/bias parameters (match shapes above).
        - Implement the exact semantics of the listed ops in the given order for the provided shapes.
        - Use {layout} layout and {dtype} dtype semantics.
        - Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
        - CPU is acceptable only for metadata, scalars, and export serialization---avoid `.cpu()` or `.to('cpu')` on compute tensors.
        - The test will import kernel_function and compare to the reference implementation below.
        Test tolerance policy (enforced in generated tests):
        - Default tolerances: rtol=1e-3, atol=1e-3.
        - Absolute cap: NEVER exceed rtol=1e-2 or atol=1e-2 in torch.allclose.
        - For float16/bfloat16 inputs: use rtol=1e-2, atol=1e-2 at most (do not go higher).
        - Include a one-line comment if you relax from default; never exceed the cap.
        Reference PyTorch implementation (exact semantics to match):
        """
    ).strip()
    src_code_block = ""  # optional original snippet for context
    if isinstance(source, dict) and source.get("code"):
        mod = source.get("module", "Model")
        code = str(source.get("code"))
        src_code_block = f"\nOriginal source snippet ({mod}):\n```python\n{code}\n```\n"
    problem = header + "\n\n```python\n" + ref_code + "```\n" + src_code_block
    return problem
```
Within it, _build_reference_code produces the reference implementation, that is, it emits the PyTorch code corresponding to each operation type.
```python
def _build_reference_code(item: dict[str, Any]) -> tuple[str, list[str]]:
    """Return (reference_code_str, param_names) implementing the subgraph.

    param_names are additional parameters to reference() beyond the first input(s).
    """
    ops: list[dict[str, Any]] = [
        op for op in (item.get("ops") or []) if isinstance(op, dict)
    ]
    lines: list[str] = ["import torch", "import torch.nn.functional as F", ""]
    params: list[str] = []
    # (remaining code omitted)
```
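Since the body of _build_reference_code is omitted, the idea of turning an op list into reference source can be illustrated with a small sketch. Everything below is an assumption made for illustration: the `"op"` key, the three op names, and the code templates are hypothetical and are not the project's actual mapping.

```python
def build_reference_code(ops):
    """Hypothetical sketch: turn a list of op dicts into reference() source.

    The "op" key and the templates below are illustrative assumptions.
    """
    templates = {
        "relu": "x = F.relu(x)",
        "sigmoid": "x = torch.sigmoid(x)",
        "linear": "x = F.linear(x, weight, bias)",
    }
    lines = ["import torch", "import torch.nn.functional as F", ""]
    params = []
    body = []
    for op in ops:
        if not isinstance(op, dict):
            continue  # mirror the isinstance filter in the excerpt above
        name = op.get("op")
        if name not in templates:
            raise ValueError(f"unsupported op: {name!r}")
        if name == "linear":
            params += ["weight", "bias"]
        body.append("    " + templates[name])
    signature = ", ".join(["x"] + params)
    lines.append(f"def reference({signature}):")
    lines += body
    lines.append("    return x")
    return "\n".join(lines) + "\n", params


code, params = build_reference_code([{"op": "linear"}, {"op": "relu"}])
```

The returned string is what would be embedded in the problem description, and `params` lists the extra arguments the generated `kernel_function` would need to accept.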
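1.2 Interaction with other components
Interaction between dispatch_kernel_agent.py and subgraph_extractor.py
- Input: receives the subgraphs.json file
- Processing: parses the subgraph structure and extracts operation and shape information
- Dependency: relies on the output of the subgraph extraction stage
Interaction between dispatch_kernel_agent.py and TritonKernelAgent
- Invocation: instantiates a TritonKernelAgent for each subgraph
- Passing: passes the synthesized problem description and the platform configuration
- Receiving: receives the generated kernel code and the verification result
Interaction between dispatch_kernel_agent.py and compose_end_to_end.py
- Output: writes summary.json, recording the kernel generation result for each subgraph
- Purpose: provides verified Triton kernels to the composition stage
1.3 Managing generated results
The output directory is structured as follows.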
```text
out_dir/
├─ <subgraph_id_1>/
│  ├─ problem.txt   # synthesized problem description
│  └─ kernel.py     # generated Triton kernel
├─ <subgraph_id_2>/
│  ├─ problem.txt
│  └─ kernel.py
└─ summary.json     # summary of the results for all subgraphs
```
The summary file has the following format.
```json
[
  {
    "id": "subgraph_1",
    "success": true,
    "worker_id": "worker_1",
    "rounds": 3,
    "session_dir": "/path/to/session",
    "kernel_path": "/path/to/kernel.py"
  },
  {
    "id": "subgraph_2",
    "success": false,
    "message": "generation failed...",
    "session_dir": "/path/to/session"
  }
]
```
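A later stage only needs the successful entries. As a minimal sketch (the `split_summary` helper is hypothetical, and the inline JSON simply mirrors the format shown above), a consumer could separate usable kernels from failures like this:

```python
import json

# Inline copy of the summary format shown above, shortened for the example.
summary_text = """
[
  {"id": "subgraph_1", "success": true, "rounds": 3,
   "kernel_path": "/path/to/kernel.py"},
  {"id": "subgraph_2", "success": false, "message": "generation failed..."}
]
"""


def split_summary(entries):
    """Return ({id: kernel_path} for successes, [ids] for failures)."""
    ok = {e["id"]: e["kernel_path"] for e in entries if e.get("success")}
    failed = [e["id"] for e in entries if not e.get("success")]
    return ok, failed


ok, failed = split_summary(json.loads(summary_text))
```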
0x02 TritonKernelAgent
triton_kernel_agent/agent.py implements the TritonKernelAgent class, the main agent of the Triton kernel generation system, which coordinates the entire kernel generation process.
```text
TritonKernelAgent (agent.py)
        ↓
WorkerManager (manager.py)
        ↓
VerificationWorker (worker.py)
        ↓
Kernel Generation & Refinement Loop
```
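2.1 Core functionality
TritonKernelAgent is the central coordinator of the Triton kernel generation system. It is responsible for:
- Configuration management: handling environment variables and default configuration
- Resource initialization: initializing the LLM provider, logging, and sub-components
- Test generation: using the LLM to generate appropriate test code
- Kernel seed generation: generating multiple initial kernel implementation variants
- Verification coordination: coordinating WorkerManager to run verification in parallel
- Result handling: processing the generation result and returning an appropriate response
It is the key link between a problem description and an actual Triton kernel implementation, and it coordinates several sub-components to make kernel generation efficient and reliable.
Main generation method
This method makes the dependency between three pieces explicit. The kernel seeds produced by _generate_kernel_seeds are written against generated_test_code (the standardized test code). run_verification then takes those seeds as the objects to verify and generated_test_code as the acceptance criterion, and checks the seeds in parallel. Together they form a strict chain: standardize the test code → generate kernel seeds → verify and filter in parallel.
Key features
- Test code is always standardized, giving one verification baseline: whether or not the user supplies reference test code, _generate_test produces the standardized test code (the supplied code serves only as a reference). Seed generation and verification therefore share a single test baseline, which avoids verification failures caused by inconsistent test formats.
- Every run is archived as a session, so it can be traced and reproduced: each generation task gets its own timestamped session directory holding the problem description, the standardized test code, all kernel seeds, the final working kernel, and the verification result, which helps with debugging and reproduction.
- Multiple kernel seeds raise the hit rate: _generate_kernel_seeds produces a batch of initial seeds as candidates for parallel verification, so the chance that at least one candidate passes the test is higher than with a single attempt.
- Parallel verification speeds up filtering: manager.run_verification verifies the seeds in parallel, with several workers checking candidates against the standardized test at the same time, which shortens verification time for a batch of kernels.
- Standardized return value: on success the method returns the working kernel code, the worker ID, the number of rounds, and the session directory; on failure it returns a failure status and a reason. The uniform format is easy for upstream modules to consume.
Logic diagram
Code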
```python
def generate_kernel(
    self, problem_description: str, test_code: str | None = None
) -> dict[str, Any]:
    """
    Generate an optimized Triton kernel for the given problem.

    Args:
        problem_description: Description of the kernel to generate
        test_code: Optional test code (generated if not provided)
            The test code should:
            1. Import the kernel function: from kernel import kernel_function
            2. Test the kernel and return True/False
            3. Exit with code 0 on success, 1 on failure

    Returns:
        Dictionary with results including successful kernel
    """
    # Always generate test code using LLM (even if test is provided as reference)
    generated_test_code = self._generate_test(problem_description, test_code)
    # Use the generated test code in standardized format
    test_code = generated_test_code
    # Log inputs
    import time

    # Add microseconds to ensure unique directory names
    timestamp = (
        datetime.now().strftime("%Y%m%d_%H%M%S")
        + f"_{int(time.time() * 1000000) % 1000000}"
    )
    session_dir = self.log_dir / f"session_{timestamp}"
    session_dir.mkdir(exist_ok=True)
    with open(session_dir / "problem.txt", "w") as f:
        f.write(problem_description)
    with open(session_dir / "test.py", "w") as f:
        f.write(test_code)
    # Generate kernel seeds
    kernel_seeds = self._generate_kernel_seeds(problem_description, test_code)
    # Save seeds
    for i, kernel in enumerate(kernel_seeds):
        with open(session_dir / f"seed_{i}.py", "w") as f:
            f.write(kernel)
    # Run parallel verification with session directory for worker logs
    result = self.manager.run_verification(
        kernel_seeds=kernel_seeds,
        test_code=test_code,
        problem_description=problem_description,
        session_log_dir=session_dir,
    )
    # Process results
    if result and result["success"]:
        # Save successful kernel
        with open(session_dir / "final_kernel.py", "w") as f:
            f.write(result["kernel_code"])
        # Save full result
        with open(session_dir / "result.json", "w") as f:
            json.dump(result, f, indent=2)
        return {
            "success": True,
            "kernel_code": result["kernel_code"],
            "worker_id": result["worker_id"],
            "rounds": result["rounds"],
            "session_dir": str(session_dir),
        }
    else:
        return {
            "success": False,
            "message": "Failed to generate working kernel",
            "session_dir": str(session_dir),
        }
```
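The session directory name combines a second-resolution timestamp with the microsecond part of the current time. A minimal sketch of just that naming step follows; `session_name` is a hypothetical helper that reproduces the expression used in generate_kernel.

```python
import re
import time
from datetime import datetime


def session_name():
    """Build a name shaped like session_20240101_120000_123456."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    micros = int(time.time() * 1_000_000) % 1_000_000
    # The microsecond suffix is not zero-padded, so it has 1 to 6 digits.
    return f"session_{stamp}_{micros}"


name = session_name()
```

The suffix makes collisions unlikely but not impossible. Because the method calls `mkdir(exist_ok=True)`, two agents that produced the same name would silently share one directory rather than fail.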
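Test code generation
Purpose
This method uses the LLM to generate the test code that accompanies a Triton/CUDA kernel. Its goal is to produce a standardized, directly runnable test for the GPU kernel that will eventually be written to kernel.py, so that the kernel can be checked for syntax errors, run on a real device, and validated for correctness. It is the core of the verification step in the LLM-driven kernel generation pipeline.
Key features
- LLM-driven generation with optional reference code: the method calls the configured LLM provider (for example OpenAI) and renders a prompt template to generate test code that fits the problem description. If the user supplies reference test code, the generated test is adapted from it; otherwise a general standardized test is generated.
- Strictly standardized output: the generated test must import the kernel function from the kernel module (each worker writes the kernel to kernel.py in its working directory), so the test and the kernel always agree on the import path.
- Code extraction and error handling: after the LLM call, the method extracts the code from the response and raises an exception if none is found. It logs the start of generation, the raw response, and the success or failure status, and it logs and re-raises any exception from the LLM call or the extraction step.
- No mock fallback: if no LLM provider is configured, the method raises RuntimeError instead of producing mock output. As a consequence, in the code shown here the mock test generation that follows the `if self.provider:` block is never reached, because that block either returns or re-raises.
- Shaped for GPU kernel testing: the fallback test code that remains in the source is PyTorch-based, creates its test data on the CUDA device, and calls the kernel function as an ordinary Python function (the launch logic is wrapped inside kernel.py), which matches how Triton/CUDA kernels are normally tested.
Logic diagram
Code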
```python
def _generate_test(
    self, problem_description: str, provided_test_code: str | None = None
) -> str:
    """
    Generate test code for the problem using OpenAI API.
    The test must import from 'kernel' module since each worker writes
    the kernel to 'kernel.py' in their working directory.

    Args:
        problem_description: Description of the problem
        provided_test_code: Optional reference test code provided by user

    Returns:
        Generated test code in standardized format
    """
    # Use LLM provider if available; no mock fallback allowed
    if not self.provider:
        raise RuntimeError(
            "Unable to generate test code: no LLM provider available and mock fallback disabled"
        )
    # Use LLM provider if available
    if self.provider:
        try:
            self.logger.info(f"Generating test code using {self.model_name}")
            # Create prompt for test generation using template
            prompt = self.prompt_manager.render_test_generation_prompt(
                problem_description=problem_description,
                provided_test_code=provided_test_code,
            )
            # Call LLM API
            messages = [{"role": "user", "content": prompt}]
            response_text = self._call_llm(messages, max_tokens=24000)
            self.logger.info("Raw test generation response:\n%s", response_text)
            # Extract test code from response
            test_code = self._extract_code_from_response(response_text)
            if test_code:
                self.logger.info(
                    f"Successfully generated test code using {self.model_name}"
                )
                return test_code
            else:
                self.logger.error("Failed to extract valid code from LLM response")
                raise ValueError("No valid code found in LLM response")
        except Exception as e:
            self.logger.error(f"Error generating test with LLM API: {e}")
            raise
    # Mock test generation (fallback)
    self.logger.info("Generating test code (mock implementation)")
    # If provided test code exists, create a basic wrapper
    if provided_test_code:
        test_code = '''"""
Test for kernel implementation (adapted from provided test).
"""
import torch
def test_kernel():
    """Test the kernel implementation."""
    from kernel import kernel_function
    # Adapted from provided test code
    try:
        # Create test data (standardized format)
        test_input = torch.randn(1024, device='cuda')
        # Call kernel_function as a normal Python function
        result = kernel_function(test_input)
        # Basic validation
        if result is not None:
            print("Test passed!")
            return True
        else:
            print("Test failed: No result returned")
            return False
    except Exception as e:
        print(f"Test failed: {e}")
        return False
if __name__ == "__main__":
    import sys
    success = test_kernel()
    sys.exit(0 if success else 1)
'''
    else:
        test_code = '''"""
Test for kernel implementation.
"""
import torch
def test_kernel():
    """Test the kernel implementation."""
    from kernel import kernel_function
    # Mock test - replace with actual test logic
    try:
        # Create test data
        test_input = torch.randn(1024, device='cuda')
        # Call kernel_function as a normal Python function
        # (kernel launch logic is handled inside kernel.py)
        result = kernel_function(test_input)
        print("Test passed!")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        return False
if __name__ == "__main__":
    import sys
    success = test_kernel()
    sys.exit(0 if success else 1)
'''
    return test_code
```
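Kernel seed generation
Purpose
This method uses the LLM to generate a batch of initial Triton kernel implementations (kernel seeds). It gives KernelAgent several candidate versions to start from. Every seed must work with the given test code and follow the same wrapper convention. It is the initial code generation step of the pipeline and supplies the raw material for later filtering and refinement.
Key features
- Batch generation with a configurable count: the number of seeds can be set through num_seeds. When it is not given, it defaults to the number of workers (self.num_workers), so each worker gets one candidate.
- Generation is tied to the test code: the prompt includes the test code and requires the generated kernel to work with it, which avoids basic interface mismatches in the later test step.
- Native multi-response with a per-call fallback: if the provider supports multiple completions natively, the seeds are requested in one call. Otherwise the method calls the LLM once per seed and raises the temperature for each call (0.8 + 0.1 × i) to increase diversity among the seeds.
- A uniform Triton wrapper convention: every generated kernel must expose kernel_function as the launch wrapper. That function receives the arguments and launches the Triton kernel, which matches exactly how the test code calls it.
- Code extraction and fault tolerance: code is extracted from each LLM response separately. A single failed extraction only logs a warning. If no response yields code, a ValueError is raised inside the try block; like any other exception from the LLM call, it is caught and logged, and the method then falls back to the mock seeds, so the flow is never interrupted.
Logic diagram
Code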
```python
def _generate_kernel_seeds(
    self, problem_description: str, test_code: str, num_seeds: int | None = None
) -> list[str]:
    """
    Generate initial kernel implementations using OpenAI API.

    Args:
        problem_description: Description of the kernel to generate
        test_code: Test code that the kernel must pass
        num_seeds: Number of kernel variations to generate

    Returns:
        List of kernel implementation strings
    """
    if num_seeds is None:
        num_seeds = self.num_workers
    # Use LLM provider if available
    if self.provider:
        try:
            self.logger.info(
                f"Generating {num_seeds} kernel seeds using {self.model_name}"
            )
            # Create prompt with Triton guidelines using template
            prompt = self.prompt_manager.render_kernel_generation_prompt(
                problem_description=problem_description, test_code=test_code
            )
            kernels = []
            messages = [{"role": "user", "content": prompt}]
            # Use provider's multiple response capability
            max_completion_tokens = 20000
            if self.provider.supports_multiple_completions():
                # Provider supports native multiple completions
                responses = self.provider.get_multiple_responses(
                    self.model_name,
                    messages,
                    n=num_seeds,
                    temperature=0.8,
                    max_tokens=max_completion_tokens,
                    high_reasoning_effort=self.high_reasoning_effort,
                )
                for i, response in enumerate(responses):
                    kernel_code = self._extract_code_from_response(response.content)
                    if kernel_code:
                        kernels.append(kernel_code)
                    else:
                        self.logger.warning(
                            f"Failed to extract code from kernel seed {i}"
                        )
            else:
                # Provider doesn't support multiple completions, make individual calls
                for i in range(num_seeds):
                    response_text = self._call_llm(
                        messages,
                        max_tokens=max_completion_tokens,
                        temperature=0.8 + (i * 0.1),
                    )
                    kernel_code = self._extract_code_from_response(response_text)
                    if kernel_code:
                        kernels.append(kernel_code)
                    else:
                        self.logger.warning(
                            f"Failed to extract code from kernel seed {i}"
                        )
            if kernels:
                self.logger.info(
                    f"Successfully generated {len(kernels)} kernel seeds"
                )
                return kernels
            else:
                self.logger.error(
                    "Failed to extract any valid kernels from LLM responses"
                )
                raise ValueError("No valid kernel code found in any LLM response")
        except Exception as e:
            self.logger.error(f"Error generating kernels with LLM API: {e}")
            # Fall back to mock implementation
    # Mock kernel generation (fallback)
    self.logger.info(f"Generating {num_seeds} kernel seeds (mock implementation)")
    kernels = []
    for i in range(num_seeds):
        # Simpler mock that still demonstrates the wrapper pattern
        if i == 2:  # Third kernel will pass
            kernel = '''"""
Kernel implementation - working version.
"""
def kernel_function(*args, **kwargs):
    """Wrapper function that handles kernel launch."""
    # Mock implementation that passes tests
    # In real kernels, this would launch a Triton kernel
    return True
'''
        else:
            kernel = f'''"""
Kernel implementation {i + 1}.
"""
def kernel_function(*args, **kwargs):
    """Wrapper function that handles kernel launch."""
    # Mock implementation that fails
    raise NotImplementedError('Mock kernel not implemented')
'''
        kernels.append(kernel)
    return kernels
```
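Both generation methods depend on _extract_code_from_response, whose body is not shown in this article. One plausible way to do this step is to take the first fenced block from the reply. The sketch below is an assumption, not the project's implementation: `extract_code` and its regular expression are hypothetical.

```python
import re

# Built indirectly so this example can itself sit inside a fenced block.
FENCE = "`" * 3


def extract_code(response_text):
    """Return the body of the first fenced block, or None if there is none."""
    pattern = re.compile(
        FENCE + r"(?:python|py)?[ \t]*\n(.*?)" + FENCE, re.DOTALL
    )
    match = pattern.search(response_text)
    if match is None:
        return None
    code = match.group(1).strip()
    return code or None


reply = "Here is the kernel:\n" + FENCE + "python\nx = 1\n" + FENCE + "\nDone."
replies = [reply, "no code here"]
# Keep only the replies that yielded code, as the seed loop above does.
seeds = [code for r in replies if (code := extract_code(r))]
```

Skipping replies without code, instead of failing on the first one, is what lets a batch of seeds survive a single malformed response.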
2.2 Interaction with sub-components
Interaction with PromptManager
This module interacts with PromptManager at the following points.
Created at initialization
```python
# Initialize prompt manager with resolved config
self.prompt_manager = PromptManager(target_platform=self._platform_config)
```
When generating tests
```python
# Create prompt for test generation using template
prompt = self.prompt_manager.render_test_generation_prompt(
    problem_description=problem_description,
    provided_test_code=provided_test_code,
)
```
When generating kernels
```python
# Create prompt with Triton guidelines using template
prompt = self.prompt_manager.render_kernel_generation_prompt(
    problem_description=problem_description, test_code=test_code
)
```
When refining
```python
# Create refinement prompt using template
prompt = self.prompt_manager.render_kernel_refinement_prompt(
    problem_description=problem_description,
    test_code=test_code,
    kernel_code=kernel_code,
    error_info=error_info,
    history_context=history_context,
)
```
Interaction with WorkerManager
The WorkerManager is created at initialization and is used to verify the seeds during kernel generation.
```python
# Initialize worker manager
self.manager = WorkerManager(
    num_workers=self.num_workers,
    max_rounds=self.max_rounds,
    log_dir=self.log_dir,
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    openai_model=self.model_name,
    high_reasoning_effort=self.high_reasoning_effort,
    target_platform=self._platform_config.name,
)
# Run parallel verification with session directory for worker logs
result = self.manager.run_verification(
    kernel_seeds=kernel_seeds,
    test_code=test_code,
    problem_description=problem_description,
    session_log_dir=session_dir,
)
```
0x03 prompt
3.1 Logic flow
The render_kernel_generation_prompt function builds the prompt used for kernel generation.
```python
def render_kernel_generation_prompt(
    self,
    problem_description: str,
    test_code: str,
    triton_guidelines: str | None = None,
) -> str:
    """
    Render the kernel generation prompt.

    Args:
        problem_description: Description of the kernel to generate
        test_code: Test code that the kernel must pass
        triton_guidelines: Optional guidelines (if None, loads from template)

    Returns:
        Rendered prompt string
    """
    template = self.templates["kernel_generation"]
    # Load triton guidelines if not provided
    if triton_guidelines is None:
        triton_guidelines = self.render_triton_guidelines()
    return template.render(
        problem_description=problem_description,
        test_code=test_code,
        triton_guidelines=triton_guidelines,
        kernel_guidance=self.target_platform.kernel_guidance,
    )
```
The render_triton_guidelines function returns the guidelines for writing Triton code.
```python
def render_triton_guidelines(self) -> str:
    """
    Render the Triton guidelines.

    Returns:
        Rendered guidelines string
    """
    template = self.templates["triton_guidelines"]
    return template.render()
```
The template_files mapping is as follows:
```python
# Define template mappings
template_files = {
    "test_generation": "test_generation.j2",
    "kernel_generation": "kernel_generation.j2",
    "kernel_refinement": "kernel_refinement.j2",
    "triton_guidelines": "triton_guidelines.j2",
}
```
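The composition of the two templates can be demonstrated without the project's files. The sketch below is only an illustration: it uses the standard library's `string.Template` as a stand-in for Jinja2, and the template texts are invented placeholders, not the contents of the real .j2 files.

```python
from string import Template

# Stand-ins for the .j2 files; the real project renders Jinja2 templates.
templates = {
    "triton_guidelines": Template("Use @triton.jit and mask every load."),
    "kernel_generation": Template(
        "Problem:\n$problem_description\n\n"
        "Test:\n$test_code\n\n"
        "Guidelines:\n$triton_guidelines\n"
    ),
}


def render_kernel_generation_prompt(
    problem_description, test_code, triton_guidelines=None
):
    """Render the generation prompt, loading the guidelines when not given."""
    if triton_guidelines is None:
        triton_guidelines = templates["triton_guidelines"].substitute()
    return templates["kernel_generation"].substitute(
        problem_description=problem_description,
        test_code=test_code,
        triton_guidelines=triton_guidelines,
    )


prompt = render_kernel_generation_prompt(
    "Fuse add and relu.", "from kernel import kernel_function"
)
```

The point of the design is that the guidelines are rendered once and embedded into the larger prompt, so the generation prompt always carries the same rules.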
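3.2 The .j2 files
There are four templates:
- test_generation.j2: prompt for generating test code
- kernel_generation.j2: prompt for generating a Triton kernel
- kernel_refinement.j2: prompt for iteratively repairing a kernel
- triton_guidelines.j2: Triton programming guidelines
triton_guidelines.j2
This file defines the basic conventions and best practices for Triton kernel programming.
Contents
Structural conventions:
- Use the @triton.jit decorator
- Function naming conventions
- Constant definitions (tl.constexpr)
Memory access patterns:
- Correct use of tl.load and tl.store
- Masking and boundary handling
Optimization techniques:
- Autotuning configuration
- Block size selection
- Use of tensor cores
- Operation fusion strategy
Runtime restrictions:
- The wrapper function may only validate arguments, allocate memory, and configure the launch
- All computation must run inside the Triton kernel
- High-level PyTorch operations are forbidden
Characteristics
- General: applies to all Triton kernel development
- Authoritative: serves as the base reference for the other templates
kernel_generation.j2
This file is used to generate a brand-new Triton kernel implementation.
Contents
- Task definition: generate a complete Triton kernel implementation based on the test requirements
- Fusion-first strategy: emphasizes fusing multiple operations into a single kernel
- Runtime restrictions: the wrapper only validates arguments; computation happens in the kernel
- Strict constraints: high-level PyTorch operations (such as torch.nn and torch.nn.functional) are forbidden
- Example code: provides a concrete example of the expected structure
Characteristics
- Creation-oriented: focuses on writing a kernel from scratch
- Fusion optimization: emphasizes operation fusion for performance
- Complete structure: requires a complete Python file as output
kernel_refinement.j2
This file is used to repair and improve an existing kernel based on error information.
Contents
- Task definition: fix an existing Triton kernel implementation according to the test results
- Error-driven: includes the error output and the standard output
- History context: provides the context of previous attempts
- Fusion audit: reviews the existing implementation for further fusion opportunities
- Repair focus: analyze the error message and fix the problem
Characteristics
- Repair-oriented: focuses on fixing problems in existing code
- Iterative: improves step by step based on error feedback
- Context-aware: takes the history of previous attempts into account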
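Differences between the three
Difference in purpose
triton_guidelines.j2:
- Provides general Triton programming conventions
- Serves as the underlying reference for the other two templates
- Is not used to generate code directly; it supplies guiding principles
kernel_generation.j2:
- Used to generate a new kernel implementation
- Creates a complete kernel starting from nothing
- Focuses on operation fusion and performance
kernel_refinement.j2:
- Used to repair an existing kernel implementation
- Improves existing code based on error information
- Focuses on iterative improvement and problem solving
Difference in inputs
triton_guidelines.j2:
- Provides general rules and does not depend on any specific input
kernel_generation.j2:
- Requires the test code as input
- Generates a new kernel based on the test requirements
kernel_refinement.j2:
- Requires the existing kernel code, the error information, and the test code
- Includes the history context
Difference in output
triton_guidelines.j2:
- Outputs a guidelines document
- Produces no actual Python code
kernel_generation.j2:
- Outputs a new, complete Triton kernel implementation
- The goal is a functional implementation that satisfies the test
kernel_refinement.j2:
- Outputs the repaired kernel code
- The goal is to fix a specific error while preserving or improving functionality
Difference in workflow
triton_guidelines.j2:
- Referenced during both generation and repair
- Provides the base rules
kernel_generation.j2:
- Normally used at the start of the workflow
- Produces the first candidate version
kernel_refinement.j2:
- Used after a test failure
- Iterates until the test passes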
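0x04 The role of the Worker Manager
manager.py implements the WorkerManager class of the KernelAgent system, which manages the parallel execution of multiple kernel verification worker processes. Its position in the architecture is shown below.
```text
TritonKernelAgent
        ↓
WorkerManager (manager.py)
        ↓
VerificationWorker (worker.py)
        ↓
Kernel Generation & Testing
```
4.1 Core functionality
manager.py acts as the parallel scheduler of the KernelAgent architecture. It runs several kernel verification tasks in separate processes at once, which raises the chance that at least one candidate passes and shortens the time to the first passing kernel. Its design takes fault tolerance, performance, and resource management into account, and it is one of the components that lets the whole system run efficiently.
Use cases
- Faster kernel generation
- Trying several kernel implementation variants in parallel
- Quickly finding a kernel that satisfies the test
- Improved robustness
- Several independent verification flows reduce the risk of overall failure
- Different random seeds may produce different results
Parallel work management
The WorkerManager class coordinates multiple processes. Its workload is distributed as follows:
- Seed distribution: the initial kernel implementations are handed to different workers
- Parallel verification: several verification flows run at the same time
- Resource isolation: each worker runs in its own process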
```python
class WorkerManager:
    def __init__(
        self,
        num_workers: int = 4,
        max_rounds: int = 10,
        history_size: int = 8,
        log_dir: str | None = None,
        openai_api_key: str | None = None,
        openai_model: str = "gpt-5",
        high_reasoning_effort: bool = True,
        target_platform: str = "cuda",
    ):
        # Bookkeeping for the worker processes
        self.num_workers = num_workers
        self.workers: list[mp.Process] = []
```
Lifecycle management
Working directory management
A temporary working directory is created for each worker and removed when verification ends.
```python
@contextmanager
def temp_workdirs(self) -> list[Path]:
    """Create temporary working directories for workers."""
    workdirs = []
    try:
        for i in range(self.num_workers):
            workdir = Path(tempfile.mkdtemp(prefix=f"worker_{i}_"))
            workdirs.append(workdir)
            self.logger.info(f"Created workdir for worker {i}: {workdir}")
        yield workdirs
    finally:
        # Cleanup
        for workdir in workdirs:
            if workdir.exists():
                shutil.rmtree(workdir)
                self.logger.info(f"Cleaned up workdir: {workdir}")
```
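Process lifecycle
- Start: worker processes are started with multiprocessing.Process
- Monitor: the manager watches the state of the worker processes
- Cleanup: processes are terminated on completion or on error
Synchronization and communication
Success signal
```python
self.success_event = mp.Event()  # cross-process success signal
```
- When any worker finds a solution, the event is set
- The other workers stop once they observe the signal
Result collection
```python
self.result_queue = mp.Queue()  # result queue
```
- Collects the results of all workers
- The first successful result received is the one that is returned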
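The event-plus-queue pattern described above can be sketched in a few lines. To keep the example self-contained and deterministic, it uses `threading.Event` and `queue.Queue` as in-process stand-ins for `mp.Event` and `mp.Queue`; the `worker` and `run` functions and the "rounds needed" plan are hypothetical.

```python
import queue
import threading


def worker(worker_id, rounds_needed, max_rounds, success_event, result_queue):
    """Pretend to refine a kernel; succeed once rounds_needed rounds have run."""
    for round_no in range(1, max_rounds + 1):
        if success_event.is_set():  # another worker already won
            break
        if round_no >= rounds_needed:
            result_queue.put(
                {"worker_id": worker_id, "success": True, "rounds": round_no}
            )
            return
    result_queue.put({"worker_id": worker_id, "success": False})


def run(plans, max_rounds=5):
    """Start one worker per plan and return the first successful result."""
    success_event = threading.Event()
    result_queue = queue.Queue()
    threads = [
        threading.Thread(
            target=worker,
            args=(i, need, max_rounds, success_event, result_queue),
        )
        for i, need in enumerate(plans)
    ]
    for t in threads:
        t.start()
    winner = None
    for _ in threads:  # every worker posts exactly one result
        result = result_queue.get(timeout=5.0)
        if result["success"]:
            winner = result
            success_event.set()  # tell the remaining workers to stop
            break
    for t in threads:
        t.join()
    return winner


winner = run([99, 2, 99])  # only worker 1 can succeed within max_rounds
```

With real processes the same structure applies, but the manager additionally needs join timeouts and `terminate()` for workers that do not stop on their own.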
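4.2 Workflow analysis
The key design points are as follows.
Fault tolerance
- Timeout handling: a process that does not exit in time is forcibly terminated
- Exception recovery: exceptions are caught and handled
- Resource cleanup: temporary files and processes are always cleaned up
Performance
- Concurrent execution: verification runs in parallel across cores
- Early exit: as soon as one solution succeeds, the other processes are stopped
- Resource limits: the degree of concurrency is bounded to avoid exhausting resources
Configurability
- Number of workers: the degree of parallelism is configurable
- Maximum rounds: bounds the number of attempts per worker
- History: bounds the size of the error history
The verification flow is as follows.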
```python
def run_verification(
    self,
    kernel_seeds: list[str],
    test_code: str,
    problem_description: str,
    session_log_dir: Path | None = None,
) -> dict[str, Any | None]:
    """
    Run parallel verification on multiple kernel seeds.

    Args:
        kernel_seeds: List of initial kernel implementations
        test_code: Test code to verify kernel correctness
        problem_description: Description of the problem
        session_log_dir: Optional session directory for worker logs

    Returns:
        Dictionary with successful kernel and metadata, or None
    """
    self.logger.info(f"Starting verification with {len(kernel_seeds)} seeds")
    # Reset cross-worker success signal for a fresh run
    try:
        self.success_event.clear()
    except Exception:
        pass
    # Reset workers list to avoid holding stale processes between runs
    self.workers = []
    # Determine where to put worker logs
    if session_log_dir:
        workers_parent_dir = Path(session_log_dir) / "workers"
        workers_parent_dir.mkdir(exist_ok=True)
    else:
        workers_parent_dir = self.workers_dir
    # 1. Prepare the working directories
    with self.temp_workdirs() as workdirs:
        # Start workers
        for i, (kernel, workdir) in enumerate(zip(kernel_seeds, workdirs)):
            worker_log_dir = workers_parent_dir / f"worker_{i}"
            worker_log_dir.mkdir(exist_ok=True)
            args = (
                i,
                kernel,
                test_code,
                problem_description,
                workdir,
                worker_log_dir,
                self.max_rounds,
                self.history_size,
                self.success_event,
                self.result_queue,
                self.openai_api_key,
                self.openai_model,
                self.high_reasoning_effort,
                self.target_platform,
            )
            process = mp.Process(target=worker_process, args=args)
            process.start()
            self.workers.append(process)
            self.logger.info(f"Started worker {i}")
        # Wait for any worker to succeed or all to finish
        successful_result = None
        # Poll for results, one second at a time
        while any(w.is_alive() for w in self.workers):
            try:
                # Check for results with timeout
                result = self.result_queue.get(timeout=1.0)
                if result["success"]:
                    successful_result = result
                    self.logger.info(f"Worker {result['worker_id']} succeeded!")
                    # Signal all workers to stop
                    self.success_event.set()
                    break
            except queue.Empty:
                continue
        # Wait for all workers to finish, terminating any that do not exit
        for worker in self.workers:
            worker.join(timeout=5.0)
            if worker.is_alive():
                self.logger.warning(f"Terminating worker {worker.pid}")
                worker.terminate()
        # Collect any remaining results
        while not self.result_queue.empty():
            try:
                result = self.result_queue.get_nowait()
                if result["success"] and successful_result is None:
                    successful_result = result
            except queue.Empty:
                break
        return successful_result
```
Worker process entry point
```python
def worker_process(
    worker_id: int,
    kernel_code: str,
    test_code: str,
    problem_description: str,
    workdir: Path,
    log_dir: Path,
    max_rounds: int,
    history_size: int,
    success_event: mp.Event,
    result_queue: mp.Queue,
    openai_api_key: str | None,
    openai_model: str,
    high_reasoning_effort: bool,
    target_platform: str,
):
    """
    Worker process for kernel verification and refinement.
    This is run in a separate process.
    """
    # Import here to avoid issues with multiprocessing
    from .worker import VerificationWorker

    worker = VerificationWorker(
        worker_id=worker_id,
        workdir=workdir,
        log_dir=log_dir,
        max_rounds=max_rounds,
        history_size=history_size,
        openai_api_key=openai_api_key,
        openai_model=openai_model,
        high_reasoning_effort=high_reasoning_effort,
        target_platform=target_platform,
    )
    result = worker.run(
        kernel_code=kernel_code,
        test_code=test_code,
        problem_description=problem_description,
        success_event=success_event,
    )
    result_queue.put(result)
```
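4.3 Interaction with other components
Interaction with TritonKernelAgent
- Receives: the initial kernel seeds, the test code, and the problem description
- Returns: a successful kernel implementation, or an indication of failure
Interaction with VerificationWorker
- Start: supplies the arguments for each worker process
- Monitor: manages the worker lifecycle
Interaction with system resources
- File system: manages the temporary working directories
- Process management: creates and terminates child processes
- Memory: passes results through a queue
0x05 VerificationWorker
VerificationWorker implements the workflow that verifies and refines a single Triton kernel implementation.
```text
TritonKernelAgent
        ↓
WorkerManager
        ↓
VerificationWorker
        ↓
Kernel Generation & Refinement Loop
```
5.1 End-to-end execution flow
VerificationWorker is the component of the TritonKernelAgent system responsible for verifying and refining one kernel. It implements a closed verify-and-refine loop in which the LLM improves the kernel automatically, while the result is kept within the Triton programming guidelines and the system constraints.
- Initialization: set up the working directory, logging, and file paths
- Parallel generation
- Cooperate with the other worker processes in a multi-process environment
- Find a working kernel implementation quickly
- Checks on each round:
- Check whether another process has already succeeded
- Detect violations in the kernel code
- Run the test to verify the kernel
- Verify that the generated Triton kernel follows the conventions
- Make sure the kernel passes the given test cases
- Record the result of the current round
- On failure, use the LLM to refine the kernel
- Automatically fix errors in the kernel
- Refine the kernel to satisfy the Triton programming guidelines
- Continue with the next round until success or the maximum number of rounds
5.2 Design
Violation detection
- Static analysis: checks the code for disallowed PyTorch patterns
- Pattern matching: uses regular expressions to detect the various violations
- String stripping: removes comments and strings to avoid false positives
LLM interaction
- History context: includes the errors and output of previous attempts
- Templated prompts: builds the refinement prompt from a Jinja2 template
- Error-driven: guides the LLM with the concrete error information
Concurrency safety
- Shared event: uses multiprocessing.Event to tell other processes that a successful result exists
- Separate working directories: each worker process has its own file system space
Error handling
- Timeout protection: test execution is limited to 30 seconds
- Violation checks: prevent the kernel from using disallowed PyTorch functionality
- History tracking: records each round's attempt for later refinement
Logging
- Per-round logs: saves the full input and output of every iteration
- History retention: keeps the most recent rounds for the LLM context
- Detailed records: records timestamps, success status, and error information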
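5.3 Interaction with other components
Interaction with WorkerManager
- Receives arguments: the kernel code, test code, and problem description from the manager
- Returns the result: puts the final result on the result queue
Interaction with PromptManager
- Templates: builds LLM prompts from the predefined Jinja2 templates
- Refinement prompt: generates the kernel-specific refinement prompt
Interaction with the LLM provider
- API call: calls the configured LLM model to refine the kernel
- Response parsing: extracts valid Python code from the LLM response
5.4 Refine
The _refine_kernel function uses the LLM to repair and improve the current Triton kernel implementation based on the error information.
It is a key part of the kernel generation process: it automates error repair and code improvement, so the system can resolve the problems it runs into on its own.
Role in the verification loop
Feedback loop
- After a test failure: when the kernel test fails, _refine_kernel analyzes the error and produces an improved version
- Multiple rounds: repair is attempted repeatedly until the test passes or the maximum number of rounds is reached
- Iterative improvement: each round refines the kernel based on the reason for the latest failure
- Fusion first: fusion keeps its priority, so a repair should not reduce the degree of fusion
Constraint maintenance
- Rule compliance: the repaired code must still satisfy the Triton programming guidelines and constraints
- Forbidden operations: disallowed PyTorch operations (such as the torch.nn module and torch.nn.functional) must not be introduced
Error handling
- Error awareness: receives the error information (error_info) from the test run, including standard output and standard error
- Context building: uses the history to describe earlier attempts, helping the LLM understand previous problems and changes
Inputs and outputs
Inputs
- kernel_code: the current kernel code to repair
- error_info: a dict with the test failure information (stdout, stderr)
- problem_description: the original problem description
- test_code: the test code used for verification
Outputs
- refined_kernel: the kernel code string refined by the LLM
- Fallback: if the LLM is unavailable, the original code or a simple modification is returned
prompt
_refine_kernel calls render_kernel_refinement_prompt.
Prompt construction: PromptManager generates a dedicated repair prompt that contains:
- The problem description
- The test code
- The current kernel implementation
- The error information
- The record of previous attempts
5.5 Design notes
subprocess
VerificationWorker runs the generated code with subprocess rather than exec, for the following reasons:
- Safety isolation: LLM-generated code is untrusted. Running it in a child process keeps it from polluting the main process (for example by modifying globals or monkey-patching the standard library).
- Crash isolation: fatal errors such as a segfault or CUDA OOM do not take down the whole agent process.
- Timeout control: a child process can be forcibly terminated (SIGTERM/SIGKILL), whereas an infinite loop inside exec() can only be interrupted by a signal.
- Environment whitelist: the env argument controls which environment variables the child process sees, which prevents API keys from leaking.
- Resource restrictions: Python's -I isolated mode makes the child ignore PYTHON* environment variables and the user site directory, and an injected sitecustomize.py can be used to disable network access.
The project also provides _run_test_multiprocess as an alternative (implemented with multiprocessing.Process plus exec).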
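Forbidden patterns
The DISALLOWED_TORCH_PATTERNS list in VerificationWorker defines the forbidden patterns. They are used to detect whether generated kernel code illegally relies on PyTorch for computation, which ensures that the kernel is really implemented in Triton rather than quietly calling PyTorch.
The list also serves as a safety barrier that keeps LLM-generated kernel code from "cheating" through frame introspection:
- Stealing test data: a kernel could use inspect.stack to obtain the caller's stack frame, read the test function's local variables (such as the expected output expected_output), and return that value to "pass" the test without performing any Triton computation.
- Stealing the reference implementation: a kernel could find and call the PyTorch reference function through globals and pass it off as a Triton implementation.
- Breaking the sandbox assumption: frame introspection can cross module boundaries and reach context that should not be accessible.
LLMs sometimes "discover" such shortcuts because they are trained to make tests pass. These rules ensure that the only way to pass the test is to implement a correct Triton kernel.
Examples of forbidden patterns:
- import torch.nn or from torch import nn → importing the nn module is forbidden
- torch.matmul(), torch.mm(), torch.bmm() → PyTorch matrix multiplication is forbidden
- torch.einsum() → einsum is forbidden
- inspect.stack(), sys._getframe() → frame introspection is forbidden (prevents stealing test data from the caller)
- globals(), locals() → access to global/local variables is forbidden
Before matching, _strip_comments_and_strings() removes comments and docstrings to avoid false positives.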
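History
The history of VerificationWorker is a deque(maxlen=history_size) (8 by default). Why bound the history window instead of keeping everything? The reasons are:
- Token budget: every repair round sends the history to the LLM as context, and a full history would quickly exhaust the context window (the history part of the refinement prompt contains a kernel code fragment and the error information for each round).
- Diminishing relevance: early failures are less useful for later repairs, while the most recent failed attempts are the most informative.
- Avoiding loops: too much history may cause the LLM to "overfit" to earlier error patterns instead of exploring new directions.
- Memory efficiency: each round's round_data contains the full kernel code, so keeping all of it would use a lot of memory.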
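The windowing behavior comes directly from `collections.deque`. A minimal sketch, with a window of 3 instead of 8 and made-up round records:

```python
from collections import deque

# A bounded history: appending beyond maxlen drops the oldest entry.
history = deque(maxlen=3)
for round_no in range(1, 6):
    history.append({"round": round_no, "error": f"error in round {round_no}"})

kept_rounds = [entry["round"] for entry in history]  # [3, 4, 5]
```

No explicit trimming code is needed: the deque discards rounds 1 and 2 on its own as rounds 4 and 5 arrive.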
5.6 Code
Kernel verification
Test execution
```python
def _run_test(self) -> tuple[bool, str, str]:
    """
    Run the test script and capture results.

    Returns:
        Tuple of (success, stdout, stderr)
    """
    cmd = [sys.executable, str(self.test_file)]
    try:
        result = subprocess.run(
            cmd,
            cwd=str(self.workdir),
            capture_output=True,
            text=True,
            timeout=30,  # 30 second timeout
        )
        success = result.returncode == 0
        if success:
            self.logger.info("Test passed")
        else:
            self.logger.error(
                "Test failed. Exit code: %s, stderr: %s",
                result.returncode,
                result.stderr[:500],
            )
        return success, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        self.logger.error("Test timed out")
        return False, "", "Test execution timed out after 30 seconds"
    except Exception as e:
        self.logger.error(f"Test execution error: {e}")
        return False, "", str(e)
```
Detecting disallowed PyTorch compute
python
DISALLOWED_TORCH_PATTERNS = [
    (
        re.compile(r"\bimport\s+torch\.nn(\b|\s+as\b)"),
        "importing torch.nn modules is not allowed",
    ),
    (
        re.compile(r"\bfrom\s+torch\s+import\s+nn\b"),
        "importing torch.nn modules is not allowed",
    ),
    (
        re.compile(r"\bimport\s+torch\.nn\.functional\s+as\s+F\b"),
        "aliasing torch.nn.functional as F is not allowed",
    ),
    (re.compile(r"\btorch\.nn\."), "torch.nn module usage is not allowed"),
    (
        re.compile(r"\btorch\.nn\.functional\b"),
        "torch.nn.functional usage is not allowed",
    ),
    (
        re.compile(r"\bF\.[A-Za-z_]+\("),
        "torch.nn.functional alias calls (F.*) are not allowed",
    ),
    (re.compile(r"\btorch\.conv"), "torch convolution helpers are not allowed"),
    (
        re.compile(
            r"\btorch\.(relu|sigmoid|tanh|softmax|gelu|mish|hardtanh|max_pool|avg_pool)[A-Za-z0-9_]*\("
        ),
        "PyTorch activation/pooling helpers are not allowed",
    ),
    (
        re.compile(r"\bclass\s+\w+\s*\(\s*nn\.Module"),
        "Subclassing torch.nn.Module is not allowed",
    ),
    (
        re.compile(r"\.forward\("),
        "Calling .forward() indicates torch.nn module usage and is not allowed",
    ),
    (
        re.compile(r"\btorch\.ops\.aten\b"),
        "Low-level torch.ops.aten.* calls are not allowed; implement these ops directly in Triton kernels instead of relying on PyTorch compute",
    ),
    # Generic tensor-tensor math that must be implemented in Triton kernels
    (
        re.compile(r"\btorch\.(matmul|mm|bmm)\s*\("),
        "PyTorch matmul/mm/bmm tensor-tensor ops are not allowed; implement these in Triton kernels",
    ),
    (
        re.compile(r"\.(matmul|mm|bmm)\s*\("),
        "Tensor.matmul/mm/bmm methods are not allowed; implement these in Triton kernels",
    ),
    (
        re.compile(r"\btorch\.einsum\s*\("),
        "torch.einsum is not allowed; implement this contraction with Triton primitives",
    ),
    (
        re.compile(r"\.einsum\s*\("),
        "Tensor.einsum is not allowed; implement this contraction with Triton primitives",
    ),
    # Introspection / frame inspection that can be used to steal test locals
    (
        re.compile(r"\bimport\s+inspect\b"),
        "inspect-based reflection is not allowed inside kernel files",
    ),
    (
        re.compile(r"\binspect\.(stack|currentframe|getouterframes)\s*\("),
        "inspect stack/frame introspection is not allowed in kernels",
    ),
    (
        re.compile(r"\bsys\._getframe\s*\("),
        "sys._getframe is not allowed in kernels; do not access caller frames",
    ),
    (
        re.compile(r"\.f_locals\b|\.f_globals\b"),
        "Accessing frame locals/globals (f_locals/f_globals) from kernels is not allowed",
    ),
    (
        re.compile(r"\bglobals\s*\("),
        "globals() is not allowed in kernels; avoid depending on ambient test state",
    ),
    (
        re.compile(r"\blocals\s*\("),
        "locals() is not allowed in kernels; avoid depending on caller scopes",
    ),
]
Kernel refinement
LLM-driven refinement
python
def _refine_kernel(
    self,
    kernel_code: str,
    error_info: dict[str, str],
    problem_description: str,
    test_code: str,
) -> str:
    """
    Refine kernel based on error information using OpenAI API.
    Uses multi-turn dialogue by incorporating history of previous attempts.
    """
    if self.provider:
        try:
            self.logger.info(f"Refining kernel using {self.openai_model}")

            # Build context from history
            history_context = ""
            if self.history:
                history_context = "\n\nPREVIOUS ATTEMPTS:\n"
                for i, round_data in enumerate(self.history):
                    # Append one entry per previous attempt (truncated to save tokens)
                    history_context += f"\nAttempt {i + 1}:\n"
                    history_context += f"Kernel code:\n```python\n{round_data['kernel_code'][:500]}...\n```\n"
                    if round_data.get("stderr"):
                        history_context += f"Error: {round_data['stderr'][:200]}\n"
                    if round_data.get("stdout"):
                        history_context += f"Output: {round_data['stdout'][:200]}\n"

            # Create refinement prompt using template
            prompt = self.prompt_manager.render_kernel_refinement_prompt(
                problem_description=problem_description,
                test_code=test_code,
                kernel_code=kernel_code,
                error_info=error_info,
                history_context=history_context,
            )

            # Call LLM API
            messages = [{"role": "user", "content": prompt}]
            response_text = self._call_llm(messages, max_tokens=20000)

            # Extract refined kernel from response
            refined_kernel = self._extract_code_from_response(response_text)
            if refined_kernel:
                self.logger.info(
                    f"Successfully refined kernel using {self.openai_model}"
                )
                return refined_kernel
            else:
                self.logger.error("Failed to extract valid code from LLM response")
                # Return original kernel if extraction fails
                return kernel_code
        except Exception as e:
            self.logger.error(f"Error refining kernel with LLM API: {e}")
            # Fall back to mock refinement

    # Mock refinement (fallback)
    self.logger.info("Refining kernel (mock implementation)")
    # For testing, make a simple modification
    if "error" in error_info.get("stderr", "").lower():
        # Add a comment to show refinement happened
        return f"# Refinement attempt {len(self.history) + 1}\n{kernel_code}"
    return kernel_code
Refinement loop
python
def run(
    self,
    kernel_code: str,
    test_code: str,
    problem_description: str,
    success_event: mp.Event,
) -> dict[str, Any]:
    """
    Run verification and refinement loop.

    Args:
        kernel_code: Initial kernel implementation
        test_code: Test code to verify kernel
        problem_description: Problem description for context
        success_event: Shared event to check if another worker succeeded

    Returns:
        Dictionary with results
    """
    self.logger.info(f"Starting verification for worker {self.worker_id}")

    # Run the verification and refinement loop
    current_kernel = kernel_code
    for round_num in range(self.max_rounds):
        # Check if another worker has succeeded
        if success_event.is_set():
            self.logger.info("Another worker succeeded, stopping")
            return {
                "worker_id": self.worker_id,
                "success": False,
                "stopped_early": True,
                "rounds": round_num,
            }

        self.logger.info(f"Round {round_num + 1}/{self.max_rounds}")

        # Write files - test only on first round, kernel every round
        if round_num == 0:
            # First round: write both kernel and test
            self._write_files(current_kernel, test_code)
        else:
            # Subsequent rounds: only update kernel, test remains unchanged
            self._write_kernel(current_kernel)

        # Check for disallowed PyTorch usage before running the test
        violation = self._detect_pytorch_compute(current_kernel)
        if violation:
            message = f"Disallowed PyTorch usage detected: {violation}"
            self.logger.error(message)
            self._log_round(round_num + 1, False, current_kernel, "", message)
            error_info = {
                "stdout": "",
                "stderr": message,
                "history": list(self.history),
            }
            # Ask the LLM to fix the violation, then move on to the next round
            current_kernel = self._refine_kernel(
                current_kernel, error_info, problem_description, test_code
            )
            continue

        # Run test
        success, stdout, stderr = (
            self._run_test()
            if os.getenv("KA_PROCESS_USE_SYS_EXECUTABLE", "1") == "1"
            else _run_test_multiprocess(self.logger, self.workdir, self.test_file)
        )

        # Log round
        self._log_round(round_num + 1, success, current_kernel, stdout, stderr)

        if success:
            self.logger.info(
                f"Success! Kernel passed test in round {round_num + 1}"
            )
            return {
                "worker_id": self.worker_id,
                "success": True,
                "kernel_code": current_kernel,
                "rounds": round_num + 1,
                "history": list(self.history),
            }

        # Refine kernel for next round
        error_info = {
            "stdout": stdout,
            "stderr": stderr,
            "history": list(self.history),
        }
        current_kernel = self._refine_kernel(
            current_kernel, error_info, problem_description, test_code
        )

    # Max rounds reached without success
    self.logger.warning(f"Max rounds ({self.max_rounds}) reached without success")
    return {
        "worker_id": self.worker_id,
        "success": False,
        "max_rounds_reached": True,
        "rounds": self.max_rounds,
        "history": list(self.history),
    }
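The control flow of run can be boiled down to a small standalone sketch. It is a simplification under stated assumptions: refine_loop, run_test and refine are hypothetical names, the test and refinement steps are plain callables instead of a subprocess and an LLM call, and threading.Event stands in for the shared mp.Event (both expose is_set() and set()).

```python
import threading
from typing import Any, Callable


def refine_loop(
    kernel: str,
    run_test: Callable[[str], tuple[bool, str]],
    refine: Callable[[str, str], str],
    stop_event: threading.Event,
    max_rounds: int = 10,
) -> dict[str, Any]:
    """Test a kernel, refine it on failure, stop early if the event is set."""
    current = kernel
    for round_num in range(max_rounds):
        # Another worker already found a passing kernel: give up early.
        if stop_event.is_set():
            return {"success": False, "stopped_early": True, "rounds": round_num}

        success, stderr = run_test(current)
        if success:
            return {"success": True, "kernel_code": current, "rounds": round_num + 1}

        # Feed the error back and try again in the next round.
        current = refine(current, stderr)

    return {"success": False, "max_rounds_reached": True, "rounds": max_rounds}


# Toy stand-ins: the "test" passes once the kernel has been patched twice.
def toy_test(code: str) -> tuple[bool, str]:
    return code.count("+") >= 2, "needs more fixes"


def toy_refine(code: str, stderr: str) -> str:
    return code + "+"


result = refine_loop("k", toy_test, toy_refine, threading.Event())
print(result)  # {'success': True, 'kernel_code': 'k++', 'rounds': 3}
```

The three return shapes correspond to the three exits of run: success, early stop because another worker won, and max rounds reached.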
0xFF References
KernelFalcon: Autonomous GPU Kernel Generation via Deep Agents
Automating GPU Kernel Generation with DeepSeek-R1 and Inference Time Scaling
DeepSeek-R1自写CUDA内核跑分屠榜!斯坦福学霸狂飙GPU编程自动化挑战人类
大模型能否为不同硬件平台生成高性能内核?南大、浙大提出跨平台内核生成评测框架MultiKernelBench
AKG kernel Agent:利用multi-agent进行kernel的生成和迁移
AKG KERNEL AGENT: A MULTI-AGENT FRAMEWORK FOR CROSS-PLATFORM KERNEL SYNTHESIS
RL 猛刷 CUDA 核:CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning
MultiKernelBench: A Multi-Platform Benchmark for Kernel Generation
Ouyang A, Guo S, Arora S, et al. Kernelbench: Can llms write efficient gpu kernels?[J]. arXiv preprint arXiv:2502.10517, 2025.
Baronio, Carlo, et al. "Kevin: Multi-turn rl for generating cuda kernels." arXiv preprint arXiv:2507.11948 (2025).
Li, Shangzhan, et al. "Autotriton: Automatic triton programming with reinforcement learning in llms." arXiv preprint arXiv:2507.05687 (2025).
Li, Jianling, et al. "Tritonbench: Benchmarking large language model capabilities for generating triton operators." Findings of the Association for Computational Linguistics: ACL 2025. 2025.
Tjarko Lange, Robert, et al. "Towards Robust Agentic CUDA Kernel Benchmarking, Verification, and Optimization." arXiv e-prints (2025): arXiv-2509.
Chen, Wentao, et al. "CUDA-LLM: LLMs Can Write Efficient CUDA Kernels." arXiv preprint arXiv:2506.09092 (2025).