PyTorch KernelAgent 源码解读 ---(5)--- Dispatcher

PyTorch KernelAgent 源码解读 ---(5)--- Dispatcher

目录

0x00 概述

dispatch_kernel_agent.py 是 KernelAgent 系统中的调度组件,负责将 subgraph_extractor.py 生成的子图(JSON 格式)转换为具体的 Triton 内核生成任务,并调度 TritonKernelAgent 来生成和验证这些内核。

Dispatcher架构图如下,其功能概括是:读subgraphs.json,把每个子图转成含 reference code 的精确 Triton 生成spec,交给独立的 TritonKernelAgent 实例并发生成,产出kernel.py + summary.json。

0x01 Dispatch Kernel Agent 的作用

dispatch_kernel_agent.py 在 KernelAgent 系统中扮演着桥梁的角色,它将高层的子图分解结果转化为具体的 Triton 内核生成任务。其核心价值在于:

  • 自动化任务分配:将复杂的融合模型分解为独立的子图任务
  • 标准化问题描述:为每个子图生成适合 Triton 内核生成的描述
  • 并行处理能力:支持并发生成多个子图的 Triton 内核
  • 结果整合:收集和整理所有子图的内核生成结果,为后续的合成阶段做准备

dispatch_kernel_agent.py 在 流水线 中的位置如下:

python 复制代码
原始模型 → orchestrator.py(fuse)→ subgraph_extractor.py(extract)→ dispatch_kernel_agent.py → compose_end_to_end.py

1.1 整体功能

并发处理机制

run 函数会把子图发给 KernelAgent,来并行生成Triton 内核

python 复制代码
def run(
    subgraphs_path: Path,
    out_dir: Path,
    agent_model: str | None = None,
    jobs: int = 1,
    target_platform: str = "cuda",
    max_iters: int = 10,
) -> Path:
    """Dispatch subgraphs to KernelAgent with optional parallelism.

    jobs controls the number of concurrent subgraph generations. Default=1
    preserves previous behavior and avoids GPU/LLM contention.
    """

    # Submit tasks with bounded concurrency
    jobs = max(1, int(jobs or 1))
    ordered_inputs: list[tuple[int, dict[str, Any]]] = list(enumerate(items, start=1))
    results: dict[int, dict[str, Any]] = {}
    if jobs == 1: # 串行处理
        for pair in ordered_inputs:
            i, res = _handle_one(pair)
            results[i] = res
    else: # 并发处理
        with _futures.ThreadPoolExecutor(max_workers=jobs) as ex:
            future_map = {
                ex.submit(_handle_one, pair): pair[0] for pair in ordered_inputs
            }
            for fut in _futures.as_completed(future_map):
                i, res = fut.result()
                results[i] = res

任务处理函数

_handle_one 函数会调用 KernelAgent 生成算子。

python 复制代码
    def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, Any]]:
        idx, item = idx_item
        sid = str(item.get("id", f"subgraph_{idx}"))
        pdesc = _synthesize_problem_description(item, target_platform=platform)
        sg_dir = out_dir / sid
        sg_dir.mkdir(parents=True, exist_ok=True)
        (sg_dir / "problem.txt").write_text(pdesc, encoding="utf-8")

        # Pin KernelAgent concurrency defaults: 4 workers, max_iters rounds
        # 为每个子图创建独立的 TritonKernelAgent 实例
        local_agent = TritonKernelAgent(
            num_workers=4,
            max_rounds=max_iters,
            model_name=agent_model,
            target_platform=platform,
        )
        
        # 生成算子
        try:
            result = local_agent.generate_kernel(
                problem_description=pdesc, test_code=None
            )

_handle_one 函数会调用 _synthesize_problem_description 来生成 问题描述

问题描述合成

生成包含子图信息、形状、操作序列的问题描述

python 复制代码
def _synthesize_problem_description(
    item: dict[str, Any], target_platform: PlatformConfig
) -> str:
    id_ = str(item.get("id", "unknown"))
    type_ = str(item.get("type", ""))
    layout = item.get("data_layout") or "NCHW"
    dtype = item.get("dtype") or "float32"
    input_shape = item.get("input_shape")
    output_shape = item.get("output_shape")
    inputs_multi = item.get("inputs")
    weights_fused = item.get("weights_fused")
    weights_orig = item.get("weights_original")
    source = item.get("source") or {}

    ref_code, _ = _build_reference_code(item)

    # Get device string for the platform
    header = textwrap.dedent(
        f"""
        Implement a Triton kernel that computes the following subgraph end-to-end.

        Subgraph ID: {id_}
        Type: {type_}
        Data layout: {layout}
        DType: {dtype}
        Target Platform: {target_platform.name}
        Device String: {target_platform.device_string}

        Shapes:
        - input: {_fmt_shape(inputs_multi[0]) if isinstance(inputs_multi, list) else _fmt_shape(input_shape)}
        {("- input2: " + _fmt_shape(inputs_multi[1])) if isinstance(inputs_multi, list) and len(inputs_multi) > 1 else ""}
        - output: {_fmt_shape(output_shape)}

        Weights (fused): {json.dumps(weights_fused, indent=2) if isinstance(weights_fused, dict) else "null"}
        Weights (original): {json.dumps(weights_orig, indent=2) if isinstance(weights_orig, dict) else "null"}

        Operations in order (with parameters):
        {json.dumps(item.get("ops", []), indent=2)}

        Requirements:
        - Return a complete Python file with a @triton.jit kernel and a wrapper function named kernel_function(...).
        - kernel_function must accept input tensor(s) and any required weights/bias parameters (match shapes above).
        - Implement the exact semantics of the listed ops in the given order for the provided shapes.
        - Use {layout} layout and {dtype} dtype semantics.
        - Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
        - CPU is acceptable only for metadata, scalars, and export serialization---avoid `.cpu()` or `.to('cpu')` on compute tensors.
        - The test will import kernel_function and compare to the reference implementation below.

        Test tolerance policy (enforced in generated tests):
        - Default tolerances: rtol=1e-3, atol=1e-3.
        - Absolute cap: NEVER exceed rtol=1e-2 or atol=1e-2 in torch.allclose.
        - For float16/bfloat16 inputs: use rtol=1e-2, atol=1e-2 at most (do not go higher).
        - Include a one-line comment if you relax from default; never exceed the cap.

        Reference PyTorch implementation (exact semantics to match):
        """
    ).strip()

    src_code_block = ""  # optional original snippet for context
    if isinstance(source, dict) and source.get("code"):
        mod = source.get("module", "Model")
        code = str(source.get("code"))
        src_code_block = f"\nOriginal source snippet ({mod}):\n```python\n{code}\n```\n"

    problem = header + "\n\n```python\n" + ref_code + "```\n" + src_code_block
    return problem

其中,_build_reference_code 生成参考实现,即根据操作类型生成对应的 PyTorch 代码

python 复制代码
def _build_reference_code(item: dict[str, Any]) -> tuple[str, list[str]]:
    """Return (reference_code_str, param_names) implementing the subgraph.

    param_names are additional parameters to reference() beyond the first input(s).
    """
    ops: list[dict[str, Any]] = [
        op for op in (item.get("ops") or []) if isinstance(op, dict)
    ]
    lines: list[str] = ["import torch", "import torch.nn.functional as F", ""]
    params: list[str] = []

    # 省略其他代码

1.2 与系统其他组件的交互

dispatch_kernel_agent.py 与 subgraph_extractor.py 的交互

  • 输入:接收 subgraphs.json 文件

  • 处理:解析子图结构,提取操作和形状信息

  • 依赖:依赖于子图提取阶段的输出

dispatch_kernel_agent.py 与 TritonKernelAgent 的交互

  • 调用:为每个子图实例化 TritonKernelAgent

  • 传递:传递合成的问题描述和平台配置

  • 接收:接收生成的内核代码和验证结果

dispatch_kernel_agent.py 与 compose_end_to_end.py 的交互

  • 输出:生成 summary.json,记录每个子图的内核生成结果

  • 用途:为合成阶段提供已验证的 Triton 内核

1.3 生成结果管理

输出目录结构如下。

python 复制代码
out_dir/
├─ <subgraph_id_1>/
│  ├─ problem.txt  # 合成的问题描述
│  └─ kernel.py    # 生成的 Triton 内核
├─ <subgraph_id_2>/
│  ├─ problem.txt
│  └─ kernel.py
└─ summary.json    # 所有子图的生成结果汇总

摘要文件格式如下

python 复制代码
[
  {
    "id": "subgraph_1",
    "success": true,
    "worker_id": "worker_1",
    "rounds": 3,
    "session_dir": "/path/to/session",
    "kernel_path": "/path/to/kernel.py"
  },
  {
    "id": "subgraph_2",
    "success": false,
    "message": "generation failed...",
    "session_dir": "/path/to/session"
  }
]

0x02 TritonKernelAgent

triton_kernel_agent/agent.py 实现了 TritonKernelAgent 类,这是 Triton 内核生成系统的主要代理类,负责协调整个内核生成过程。

python 复制代码
TritonKernelAgent (agent.py) 
    ↓
WorkerManager (manager.py)
    ↓
VerificationWorker (worker.py)
    ↓
Kernel Generation & Refinement Loop

2.1 核心功能

TritonKernelAgent 是 Triton 内核生成系统的核心协调者,负责:

  • 配置管理:处理环境变量和默认配置
  • 资源初始化:初始化 LLM 提供商、日志记录和子组件
  • 测试生成:使用 LLM 生成适当的测试代码
  • 内核种子生成:生成多个初始内核实现变体
  • 验证协调:协调 WorkerManager 运行并行验证
  • 结果处理:处理生成结果并返回适当的响应

它是连接问题描述和实际 Triton 内核实现的关键组件,通过协调多个子组件来实现高效、可靠的内核生成。

主生成方法

该方法明确三大核心组件的强依赖逻辑_generate_kernel_seeds 生成的多版本初始内核种子,必须基于 generated_test_code (标准化测试代码)进行开发适配;run_verification 则以 _generate_kernel_seeds 的内核种子为验证对象,以 generated_test_code 为验证标准,完成多版本内核的并行有效性检测,三者形成「测试代码标准化→内核种子生成→并行验证筛选」的严格执行链路。

核心特色
  1. 测试代码强制标准化,统一验证基准 :无论用户是否提供参考测试代码,均通过 _generate_test 生成标准化测试代码(参考代码仅作为适配依据),确保后续内核种子生成、验证环节使用统一的测试基准,避免因测试代码格式不统一导致的验证失效。
  2. 全流程会话化归档,可追溯可复现 :为每次内核生成任务创建唯一时间戳会话目录,归档问题描述、标准化测试代码、所有内核种子、最终有效内核及验证结果,实现全流程可追溯,便于问题排查与结果复现。
  3. 多版本内核种子生成,提升有效率 :调用 _generate_kernel_seeds 生成批量初始内核种子,为并行验证提供多版本候选,相比单版本生成大幅提升「筛选出可通过测试内核」的概率。
  4. 并行验证筛选,提升效率 :通过 manager.run_verification 对多版本内核种子做并行验证,利用多工作器同时检测内核是否通过标准化测试,大幅缩短验证耗时,适配批量内核的快速筛选需求。
  5. 标准化结果返回,贴合工程使用:成功时返回有效内核代码、工作器 ID、验证轮次、会话目录等核心信息;失败时明确返回失败状态与原因,结果格式统一,便于上层模块调用与后续处理。
逻辑关系图
代码
python 复制代码
    def generate_kernel(
        self, problem_description: str, test_code: str | None = None
    ) -> dict[str, Any]:
        """
        Generate an optimized Triton kernel for the given problem.

        Args:
            problem_description: Description of the kernel to generate
            test_code: Optional test code (generated if not provided)
                      The test code should:
                      1. Import the kernel function: from kernel import kernel_function
                      2. Test the kernel and return True/False
                      3. Exit with code 0 on success, 1 on failure

        Returns:
            Dictionary with results including successful kernel
        """
        # Always generate test code using LLM (even if test is provided as reference)
        generated_test_code = self._generate_test(problem_description, test_code)

        # Use the generated test code in standardized format
        test_code = generated_test_code

        # Log inputs
        import time

        # Add microseconds to ensure unique directory names
        timestamp = (
            datetime.now().strftime("%Y%m%d_%H%M%S")
            + f"_{int(time.time() * 1000000) % 1000000}"
        )
        session_dir = self.log_dir / f"session_{timestamp}"
        session_dir.mkdir(exist_ok=True)

        with open(session_dir / "problem.txt", "w") as f:
            f.write(problem_description)
        with open(session_dir / "test.py", "w") as f:
            f.write(test_code)

        # Generate kernel seeds
        kernel_seeds = self._generate_kernel_seeds(problem_description, test_code)

        # Save seeds
        for i, kernel in enumerate(kernel_seeds):
            with open(session_dir / f"seed_{i}.py", "w") as f:
                f.write(kernel)

        # Run parallel verification with session directory for worker logs
        result = self.manager.run_verification(
            kernel_seeds=kernel_seeds,
            test_code=test_code,
            problem_description=problem_description,
            session_log_dir=session_dir,
        )

        # Process results
        if result and result["success"]:

            # Save successful kernel
            with open(session_dir / "final_kernel.py", "w") as f:
                f.write(result["kernel_code"])

            # Save full result
            with open(session_dir / "result.json", "w") as f:
                json.dump(result, f, indent=2)

            return {
                "success": True,
                "kernel_code": result["kernel_code"],
                "worker_id": result["worker_id"],
                "rounds": result["rounds"],
                "session_dir": str(session_dir),
            }
        else:
            return {
                "success": False,
                "message": "Failed to generate working kernel",
                "session_dir": str(session_dir),
            }

测试代码生成

核心作用

该方法是基于 LLM 生成 Triton/CUDA 内核代码配套测试代码的核心功能模块 ,专为 PyTorch KernelAgent 设计,核心目标是为待实现的 GPU 内核(最终写入 kernel.py)自动生成可直接运行的标准化测试代码,支撑内核代码的语法校验、真机运行验证、功能正确性检测,是 LLM 生成 GPU 内核流水线中「验证环节」的关键组成部分。

核心特色
  1. LLM 主导生成,支持参考代码适配:优先调用配置的 LLM 服务商(如 OpenAI),通过 Prompt 模板渲染生成贴合问题描述的测试代码;若用户提供参考测试代码,会基于参考代码适配生成,无参考时则生成通用标准化测试,兼顾灵活性与贴合性。
  2. 强约束的标准化输出 :强制要求生成的测试代码从 kernel 模块导入内核函数(因内核最终写入工作目录的 kernel.py),确保测试代码与内核代码的调用路径一致,无运行路径错误。
  3. 完整的代码提取与异常处理:调用 LLM 后会从返回结果中提取有效代码,无有效代码则直接抛出异常;全流程记录日志(生成开始、原始响应、成功 / 失败状态),并捕获 LLM 调用、代码提取中的所有异常,便于问题排查。
  4. 无 Mock 兜底限制,保证生成有效性:仅当未配置 LLM 服务商时才触发 Mock 兜底,且兜底逻辑禁用「Mock 回退开关」,避免无实际能力的空生成,确保测试代码要么由 LLM 专业生成,要么由兜底逻辑生成基础可用代码。
  5. 适配 GPU 内核测试特性 :兜底测试代码默认基于 PyTorch 实现,针对 CUDA 设备设计测试数据,内核函数以普通 Python 函数方式调用(内核启动逻辑封装在 kernel.py 内部),贴合 Triton/CUDA 内核的测试习惯。
逻辑关系图
代码
python 复制代码
    def _generate_test(
        self, problem_description: str, provided_test_code: str | None = None
    ) -> str:
        """
        Generate test code for the problem using OpenAI API.

        The test must import from 'kernel' module since each worker writes
        the kernel to 'kernel.py' in their working directory.

        Args:
            problem_description: Description of the problem
            provided_test_code: Optional reference test code provided by user

        Returns:
            Generated test code in standardized format
        """
        # Use LLM provider if available; no mock fallback allowed
        if not self.provider:
            raise RuntimeError(
                "Unable to generate test code: no LLM provider available and mock fallback disabled"
            )
        # Use LLM provider if available
        if self.provider:
            try:
                self.logger.info(f"Generating test code using {self.model_name}")

                # Create prompt for test generation using template
                prompt = self.prompt_manager.render_test_generation_prompt(
                    problem_description=problem_description,
                    provided_test_code=provided_test_code,
                )

                # Call LLM API
                messages = [{"role": "user", "content": prompt}]
                response_text = self._call_llm(messages, max_tokens=24000)
                self.logger.info("Raw test generation response:\n%s", response_text)

                # Extract test code from response
                test_code = self._extract_code_from_response(response_text)

                if test_code:
                    self.logger.info(
                        f"Successfully generated test code using {self.model_name}"
                    )
                    return test_code
                else:
                    self.logger.error("Failed to extract valid code from LLM response")
                    raise ValueError("No valid code found in LLM response")

            except Exception as e:
                self.logger.error(f"Error generating test with LLM API: {e}")
                raise

        # Mock test generation (fallback)
        self.logger.info("Generating test code (mock implementation)")

        # If provided test code exists, create a basic wrapper
        if provided_test_code:
            test_code = '''"""
Test for kernel implementation (adapted from provided test).
"""
import torch

def test_kernel():
    """Test the kernel implementation."""
    from kernel import kernel_function

    # Adapted from provided test code
    try:
        # Create test data (standardized format)
        test_input = torch.randn(1024, device='cuda')

        # Call kernel_function as a normal Python function
        result = kernel_function(test_input)

        # Basic validation
        if result is not None:
            print("Test passed!")
            return True
        else:
            print("Test failed: No result returned")
            return False
    except Exception as e:
        print(f"Test failed: {e}")
        return False

if __name__ == "__main__":
    import sys
    success = test_kernel()
    sys.exit(0 if success else 1)
'''
        else:
            test_code = '''"""
Test for kernel implementation.
"""
import torch

def test_kernel():
    """Test the kernel implementation."""
    from kernel import kernel_function

    # Mock test - replace with actual test logic
    try:
        # Create test data
        test_input = torch.randn(1024, device='cuda')

        # Call kernel_function as a normal Python function
        # (kernel launch logic is handled inside kernel.py)
        result = kernel_function(test_input)

        print("Test passed!")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        return False

if __name__ == "__main__":
    import sys
    success = test_kernel()
    sys.exit(0 if success else 1)
'''
        return test_code

内核种子生成

核心作用

该方法是基于 LLM 批量生成 Triton 内核初始实现代码(Kernel Seeds)的核心模块,为 PyTorch KernelAgent 提供多版本的初始内核候选代码,所有生成代码需适配指定测试代码并遵循统一封装规范,是 LLM 生成 Triton 内核流水线中「初始代码生成环节」的核心,为后续内核筛选、调优提供多版本基础素材。

核心特色
  1. 批量生成多版本候选内核,支持数量灵活配置 :可指定生成内核数量(num_seeds),未指定时默认匹配工作器数量(self.num_workers),生成多版本初始内核,为后续筛选可用内核提供样本基础。
  2. LLM 生成强绑定测试代码,确保适配性:生成 Prompt 中融入用户提供的测试代码,要求 LLM 生成的内核必须能对接该测试代码,从源头保证内核与测试的兼容性,避免后续测试环节的基础适配问题。
  3. 原生多响应 + 兜底循环调用,适配不同 LLM 服务商能力 :智能适配 LLM 服务商能力 ------ 支持原生多响应的服务商直接批量生成,不支持的则通过循环单独调用实现批量生成,兼顾调用效率与兼容性,且单独调用时动态提高温度系数(temperature),提升多版本内核的多样性。
  4. 统一 Triton 内核封装规范,贴合工程落地 :强制生成的内核遵循固定封装模式,需实现 kernel_function 作为内核启动包装函数,该函数统一处理参数接收、Triton 内核启动逻辑,与测试代码的调用方式完全匹配,无调用规范冲突。
  5. 完整的代码提取与容错机制:对每个 LLM 响应单独提取有效内核代码,单版本提取失败仅记录警告不中断整体流程,全量提取失败则抛出明确异常;全流程捕获 LLM 调用异常,异常后自动触发 Mock 兜底,保证功能不中断。
逻辑关系图
代码
python 复制代码
    def _generate_kernel_seeds(
        self, problem_description: str, test_code: str, num_seeds: int | None = None
    ) -> list[str]:
        """
        Generate initial kernel implementations using OpenAI API.

        Args:
            problem_description: Description of the kernel to generate
            test_code: Test code that the kernel must pass
            num_seeds: Number of kernel variations to generate

        Returns:
            List of kernel implementation strings
        """
        if num_seeds is None:
            num_seeds = self.num_workers

        # Use LLM provider if available
        if self.provider:
            try:
                self.logger.info(
                    f"Generating {num_seeds} kernel seeds using {self.model_name}"
                )

                # Create prompt with Triton guidelines using template
                prompt = self.prompt_manager.render_kernel_generation_prompt(
                    problem_description=problem_description, test_code=test_code
                )

                kernels = []
                messages = [{"role": "user", "content": prompt}]

                # Use provider's multiple response capability
                max_completion_tokens = 20000

                if self.provider.supports_multiple_completions():
                    # Provider supports native multiple completions
                    responses = self.provider.get_multiple_responses(
                        self.model_name,
                        messages,
                        n=num_seeds,
                        temperature=0.8,
                        max_tokens=max_completion_tokens,
                        high_reasoning_effort=self.high_reasoning_effort,
                    )

                    for i, response in enumerate(responses):
                        kernel_code = self._extract_code_from_response(response.content)
                        if kernel_code:
                            kernels.append(kernel_code)
                        else:
                            self.logger.warning(
                                f"Failed to extract code from kernel seed {i}"
                            )
                else:
                    # Provider doesn't support multiple completions, make individual calls
                    for i in range(num_seeds):
                        response_text = self._call_llm(
                            messages,
                            max_tokens=max_completion_tokens,
                            temperature=0.8 + (i * 0.1),
                        )
                        kernel_code = self._extract_code_from_response(response_text)

                        if kernel_code:
                            kernels.append(kernel_code)
                        else:
                            self.logger.warning(
                                f"Failed to extract code from kernel seed {i}"
                            )

                if kernels:
                    self.logger.info(
                        f"Successfully generated {len(kernels)} kernel seeds"
                    )
                    return kernels
                else:
                    self.logger.error(
                        "Failed to extract any valid kernels from LLM responses"
                    )
                    raise ValueError("No valid kernel code found in any LLM response")

            except Exception as e:
                self.logger.error(f"Error generating kernels with LLM API: {e}")
                # Fall back to mock implementation

        # Mock kernel generation (fallback)
        self.logger.info(f"Generating {num_seeds} kernel seeds (mock implementation)")

        kernels = []
        for i in range(num_seeds):
            # Simpler mock that still demonstrates the wrapper pattern
            if i == 2:  # Third kernel will pass
                kernel = '''"""
Kernel implementation - working version.
"""

def kernel_function(*args, **kwargs):
    """Wrapper function that handles kernel launch."""
    # Mock implementation that passes tests
    # In real kernels, this would launch a Triton kernel
    return True
'''
            else:
                kernel = f'''"""
Kernel implementation {i + 1}.
"""

def kernel_function(*args, **kwargs):
    """Wrapper function that handles kernel launch."""
    # Mock implementation that fails
    raise NotImplementedError('Mock kernel not implemented')
'''
            kernels.append(kernel)

        return kernels

2.2 与子组件的交互

与 PromptManager 的交互

本模块与 PromptManager 在如下时机进行交互。

初始化时创建

python 复制代码
# Initialize prompt manager with resolved config
self.prompt_manager = PromptManager(target_platform=self._platform_config)

生成测试时

python 复制代码
                # Create prompt for test generation using template
                prompt = self.prompt_manager.render_test_generation_prompt(
                    problem_description=problem_description,
                    provided_test_code=provided_test_code,
                )

生成内核时

python 复制代码
                # Create prompt with Triton guidelines using template
                prompt = self.prompt_manager.render_kernel_generation_prompt(
                    problem_description=problem_description, test_code=test_code
               )

优化时

python 复制代码
                # Create refinement prompt using template
                prompt = self.prompt_manager.render_kernel_refinement_prompt(
                    problem_description=problem_description,
                    test_code=test_code,
                    kernel_code=kernel_code,
                    error_info=error_info,
                    history_context=history_context,
                )

与 WorkerManager 的交互

会在生成内核时进行验证

python 复制代码
        # Initialize worker manager
        self.manager = WorkerManager(
            num_workers=self.num_workers,
            max_rounds=self.max_rounds,
            log_dir=self.log_dir,
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_model=self.model_name,
            high_reasoning_effort=self.high_reasoning_effort,
            target_platform=self._platform_config.name,
        )
        
        # Run parallel verification with session directory for worker logs
        result = self.manager.run_verification(
            kernel_seeds=kernel_seeds,
            test_code=test_code,
            problem_description=problem_description,
            session_log_dir=session_dir,
        )        

0x03 prompt

3.1 逻辑流

render_kernel_generation_prompt 函数生成了用于生成kernel的prompt。

python 复制代码
    def render_kernel_generation_prompt(
        self,
        problem_description: str,
        test_code: str,
        triton_guidelines: str | None = None,
    ) -> str:
        """
        Render the kernel generation prompt.

        Args:
            problem_description: Description of the kernel to generate
            test_code: Test code that the kernel must pass
            triton_guidelines: Optional guidelines (if None, loads from template)

        Returns:
            Rendered prompt string
        """
        template = self.templates["kernel_generation"]

        # Load triton guidelines if not provided
        if triton_guidelines is None:
            triton_guidelines = self.render_triton_guidelines()

        return template.render(
            problem_description=problem_description,
            test_code=test_code,
            triton_guidelines=triton_guidelines,
            kernel_guidance=self.target_platform.kernel_guidance,
        )

render_triton_guidelines 函数会获取生成 Triton的指导准则。

python 复制代码
    def render_triton_guidelines(self) -> str:
        """
        Render the Triton guidelines.

        Returns:
            Rendered guidelines string
        """
        template = self.templates["triton_guidelines"]
        return template.render()

template_files 如下:

python 复制代码
        # Define template mappings
        template_files = {
            "test_generation": "test_generation.j2",
            "kernel_generation": "kernel_generation.j2",
            "kernel_refinement": "kernel_refinement.j2",
            "triton_guidelines": "triton_guidelines.j2",
        }

3.2 j2 文件

四个模板如下:

  • test_generation.j2 ---生成测试代码的提示
  • kernel_generation.j2 --- 生成Triton内核的提示
  • kernel_refinement.j2 --- 内核送代修复的提示
  • triton_guidelines.j2 --- Triton 编程指南

triton_guidelines.j2文件

该文件 定义Triton内核编程的基本规范和最佳实践。

具体内容

结构规范:

  • 使用 @triton.jit 装饰器
  • 函数命名规范
  • 常量定义(tl.constexpr)

内存访问模式:

  • tl.load 和 tl.store 的正确使用
  • 掩码(masking)和边界处理

优化技术:

  • 自动调优配置
  • 块大小选择
  • 张量核(tensor cores)使用
  • 操作融合策略

运行时限制:

  • 包装器函数只能做参数验证、分配和启动配置
  • 所有计算必须在 Triton 内核中执行
  • 禁止使用 PyTorch 高级操作
特点
  • 通用性:适用于所有 Triton 内核开发
  • 权威性:作为其他模板的基础参考

kernel_generation.j2 文件

该文件生成全新的 Triton 内核实现

具体内容
  • 任务定义:基于测试要求生成完整的 Triton 内核实现
  • 融合优先策略:强调将多个操作融合为单个内核
  • 运行时限制:包装器仅做参数验证,计算在内核中
  • 严格约束:禁止使用 PyTorch 高级操作(如 torch.nn、torch.nn.functional)
  • 示例代码:提供了具体实现结构示例
特点
  • 创建导向:专注于从零开始创建内核
  • 融合优化:强调操作融合以提高性能
  • 结构完整:要求生成完整的 Python 文件

kernel_refinement.j2 文件

该文件会基于已有内核和错误信息进行修复和改进

具体内容
  • 任务定义:根据测试结果修复现有 Triton 内核实现
  • 错误导向:包含错误输出和标准输出信息
  • 历史上下文:提供之前尝试的上下文信息
  • 融合审计:审查现有实现以寻找更多融合机会
  • 修复重点:分析错误消息并修复问题
特点
  • 修复导向:专注于修复现有代码中的问题
  • 迭代优化:基于错误反馈进行迭代改进
  • 上下文感知:考虑之前尝试的历史信息

三者之间的区别

用途差异、

triton_guidelines.j2:

  • 提供通用的 Triton 编程规范
  • 作为其他两个模板的底层参考
  • 不直接用于生成代码,而是提供指导原则

kernel_generation.j2:

  • 用于生成新的内核实现
  • 从空白状态开始创建完整的内核
  • 重点关注融合操作和性能优化

kernel_refinement.j2:

  • 用于修复现有的内核实现
  • 基于错误信息改进已有代码
  • 注重迭代优化和问题解决
输入数据差异

triton_guidelines.j2:

  • 提供通用准则,不依赖特定输入

kernel_generation.j2:

  • 需要测试代码作为输入
  • 生成基于测试需求的新内核

kernel_refinement.j2:

  • 需要现有内核代码、错误信息和测试代码
  • 包含历史上下文信息

输出目标差异

triton_guidelines.j2:

  • 输出编程准则文档
  • 不产生实际的 Python 代码

kernel_generation.j2:

  • 输出全新的、完整的 Triton 内核实现
  • 目标是满足测试要求的功能实现

kernel_refinement.j2:

  • 输出修复后的内核代码
  • 目标是解决特定错误并保持或改进功能

工作流程差异

triton_guidelines.j2:

  • 在生成和修复过程中都会被引用
  • 提供基础规范

kernel_generation.j2:

  • 通常在工作流程的初始阶段使用
  • 生成第一个可行版本

kernel_refinement.j2:

  • 在测试失败后使用
  • 进行迭代改进直到通过测试

0x04 Worker Manager 的作用

manager.py 是 KernelAgent 系统中的 WorkerManager 类实现文件,负责管理多个内核验证工作者进程的并行执行。在系统架构中的位置如下。

python 复制代码
TritonKernelAgent
    ↓
WorkerManager (manager.py)
    ↓
VerificationWorker (worker.py)
    ↓
Kernel Generation & Testing

4.1 核心功能

manager.py 在 KernelAgent 架构中扮演着并行计算调度器的角色,通过多进程并行执行多个内核验证任务,显著提升了内核生成的成功率和效率。其设计考虑了容错性、性能和资源管理等多个方面,是整个系统能够高效运行的关键组件之一。

使用场景

  • 内核生成加速
    • 并行尝试多种内核实现变体
    • 快速找到满足测试要求的内核
  • 鲁棒性提升
    • 多个独立的验证流程降低失败风险
    • 不同的随机种子可能产生不同的结果

并行工作管理

WorkerManager 类负责多进程协调,其工作负载分配原则如下:

  • 种子内核分发:将多个初始内核实现分发给不同的工作者
  • 并行验证:同时运行多个验证流程
  • 资源隔离:每个工作者在独立的进程中运行
python 复制代码
class WorkerManager:
    def __init__(
        self,
        num_workers: int = 4,
        max_rounds: int = 10,
        history_size: int = 8,
        log_dir: str | None = None,
        openai_api_key: str | None = None,
        openai_model: str = "gpt-5",
        high_reasoning_effort: bool = True,
        target_platform: str = "cuda",
    ):
    # 初始化多个工作者进程
    self.num_workers = num_workers
    self.workers: list[mp.Process] = []

生命周期管理

工作目录管理

会为 为工作者创建临时工作目录。

python 复制代码
    @contextmanager
    def temp_workdirs(self) -> list[Path]:
        """Create temporary working directories for workers."""
        workdirs = []
        try:
            for i in range(self.num_workers):
                workdir = Path(tempfile.mkdtemp(prefix=f"worker_{i}_"))
                workdirs.append(workdir)
                self.logger.info(f"Created workdir for worker {i}: {workdir}")
            yield workdirs
        finally:
            # Cleanup
            for workdir in workdirs:
                if workdir.exists():
                    shutil.rmtree(workdir)
                    self.logger.info(f"Cleaned up workdir: {workdir}")
进程生命周期
  • 启动:使用 multiprocessing.Process 启动工作者进程

  • 监控:监控工作者进程状态

  • 清理:在完成或异常情况下终止进程

同步与通信机制

成功信号

python 复制代码
self.success_event = mp.Event() # 跨进程成功信号
  • 当任一工作者找到解决方案时,设置事件

  • 其他工作者检测到信号后停止工作

结果收集

python 复制代码
self.result_queue = mp.Queue()  # 结果队列
  • 收集各工作者的结果

  • 确保最成功的结果被返回

4.2 工作流程分析

关键设计特点如下:

容错机制

  • 超时处理:进程超时后强制终止

  • 异常恢复:捕获和处理各种异常情况

  • 资源清理:确保临时文件和进程被正确清理

性能优化

  • 并发执行:利用多核并行验证

  • 早期退出:一旦找出成功解,立即停止其他进程

  • 资源限制:控制并发数量,避免资源耗尽

配置灵活性

  • 工作者数量:可配置并行度
  • 最大轮数:控制每个工作者的尝试次数
  • 历史记录:限制错误历史的大小

验证执行流程如下。

python 复制代码
    def run_verification(
        self,
        kernel_seeds: list[str],
        test_code: str,
        problem_description: str,
        session_log_dir: Path | None = None,
    ) -> dict[str, Any | None]:
        """
        Run parallel verification on multiple kernel seeds.

        Args:
            kernel_seeds: List of initial kernel implementations
            test_code: Test code to verify kernel correctness
            problem_description: Description of the problem
            session_log_dir: Optional session directory for worker logs

        Returns:
            Dictionary with successful kernel and metadata, or None
        """
        self.logger.info(f"Starting verification with {len(kernel_seeds)} seeds")
        # Reset cross-worker success signal for a fresh run
        try:
            self.success_event.clear()
        except Exception:
            pass
        # Reset workers list to avoid holding stale processes between runs
        self.workers = []

        # Determine where to put worker logs
        if session_log_dir:
            workers_parent_dir = Path(session_log_dir) / "workers"
            workers_parent_dir.mkdir(exist_ok=True)
        else:
            workers_parent_dir = self.workers_dir

        # 1. 准备工作目录
        with self.temp_workdirs() as workdirs:
            # Start workers 启动工作者进程
            for i, (kernel, workdir) in enumerate(zip(kernel_seeds, workdirs)):
                worker_log_dir = workers_parent_dir / f"worker_{i}"
                worker_log_dir.mkdir(exist_ok=True)

                args = (
                    i,
                    kernel,
                    test_code,
                    problem_description,
                    workdir,
                    worker_log_dir,
                    self.max_rounds,
                    self.history_size,
                    self.success_event,
                    self.result_queue,
                    self.openai_api_key,
                    self.openai_model,
                    self.high_reasoning_effort,
                    self.target_platform,
                )

                process = mp.Process(target=worker_process, args=args)
                process.start()
                self.workers.append(process)
                self.logger.info(f"Started worker {i}")

            # Wait for any worker to succeed or all to finish
            successful_result = None

            # 等待结果或超时
            while any(w.is_alive() for w in self.workers):
                try:
                    # Check for results with timeout
                    result = self.result_queue.get(timeout=1.0)
                    if result["success"]:
                        successful_result = result
                        self.logger.info(f"Worker {result['worker_id']} succeeded!")
                        # Signal all workers to stop
                        # 信号其他进程停止
                        self.success_event.set()
                        break
                except queue.Empty:
                    continue

            # Wait for all workers to finish
            # 清理进程
            for worker in self.workers:
                worker.join(timeout=5.0)
                if worker.is_alive():
                    self.logger.warning(f"Terminating worker {worker.pid}")
                    worker.terminate()

            # Collect any remaining results
            while not self.result_queue.empty():
                try:
                    result = self.result_queue.get_nowait()
                    if result["success"] and successful_result is None:
                        successful_result = result
                except queue.Empty:
                    break

            return successful_result

工作者进程入口

python 复制代码
def worker_process(
    worker_id: int,
    kernel_code: str,
    test_code: str,
    problem_description: str,
    workdir: Path,
    log_dir: Path,
    max_rounds: int,
    history_size: int,
    success_event: mp.Event,
    result_queue: mp.Queue,
    openai_api_key: str | None,
    openai_model: str,
    high_reasoning_effort: bool,
    target_platform: str,
):
    """
    Worker process for kernel verification and refinement.

    This is run in a separate process.
    """
    # Import here to avoid issues with multiprocessing
    from .worker import VerificationWorker

    worker = VerificationWorker(
        worker_id=worker_id,
        workdir=workdir,
        log_dir=log_dir,
        max_rounds=max_rounds,
        history_size=history_size,
        openai_api_key=openai_api_key,
        openai_model=openai_model,
        high_reasoning_effort=high_reasoning_effort,
        target_platform=target_platform,
    )

    result = worker.run(
        kernel_code=kernel_code,
        test_code=test_code,
        problem_description=problem_description,
        success_event=success_event,
    )

    result_queue.put(result)

4.3 与其他组件的交互

与 TritonKernelAgent 的交互

  • 接收:初始内核种子、测试代码、问题描述
  • 返回:成功的内核实现或失败指示

与 VerificationWorker 的交互

  • 启动:为每个工作者进程提供参数
  • 监控:管理工作者的生命周期

与系统资源的交互

  • 文件系统:管理临时工作目录
  • 进程管理:创建和终止子进程
  • 内存管理:通过队列传递结果

0x05 VerificationWorker

VerificationWorker 是专门用于验证和优化单个 Triton 内核实现的工作流程。

python 复制代码
TritonKernelAgent
  ↓
WorkerManager
  ↓
VerificationWorker 
  ↓
Kernel Generation & Refinement Loop

5.1 完整执行流程

VerificationWorker 是 TritonKernelAgent 系统中负责单个内核验证和优化的核心组件,它实现了闭环的验证-优化循环,通过 LLM 驱动的方式自动改进内核实现,同时确保符合 Triton 编程规范和系统约束。

  • 初始化:设置工作目录、日志和文件路径
  • 并行生成
    • 在多进程环境中与其他工作进程协作
    • 快速找到有效的内核实现
  • 检查执行:
    • 检查其他进程是否已经成功
    • 检测内核代码中的违规行为
    • 运行测试验证内核
      • 验证生成的 Triton 内核是否符合规范
      • 确保内核通过给定的测试用例
    • 记录当前轮次结果
    • 如果失败则使用LLM优化内核
      • 自动修复内核中的错误
      • 优化内核以满足 Triton 编程指南
    • 继续下一轮直到成功或达到最大轮次

5.2 设计

违规检测机制

  • 静态分析:检查代码中是否存在不允许的 PyTorch 模式
  • 模式匹配:使用正则表达式检测各种违规用法
  • 字符串剥离:移除注释和字符串以避免误报

LLM 交互机制

  • 历史上下文:包含之前尝试的错误和输出
  • 模板化提示:使用 Jinja2 模板构建优化提示
  • 错误导向:基于具体的错误信息指导 LLM 修复

并发安全

  • 共享事件:使用 multiprocessing.Event 通知其他进程已有成功结果
  • 独立工作目录:每个工作进程有自己独立的文件系统空间

错误处理

  • 超时保护:测试执行有 30 秒超时限制
  • 违规检查:防止内核使用不合规的 PyTorch 功能
  • 历史追踪:记录每轮的尝试结果用于后续优化

日志记录

  • 轮次日志:为每次迭代保存完整的输入输出
  • 历史保留:保留最近几轮的结果用于 LLM 上下文
  • 详细记录:记录时间戳、成功状态和错误信息

5.3 与其他组件的交互

与 WorkerManager 的交互

  • 接收参数:从管理器接收内核代码、测试代码和问题描述
  • 返回结果:将最终结果发送回结果队列

与 PromptManager 的交互

  • 获取模板:使用预定义的 Jinja2 模板构建 LLM 提示
  • 优化提示:生成特定的内核优化提示

与 LLM 提供商的交互

  • API 调用:调用配置的 LLM 模型进行内核优化
  • 响应解析:从 LLM 响应中提取有效的 Python 代码

5.4 Refine

_refine_kernel函数的作用:根据错误信息使用 LLM 来修复和改进当前的Triton内核实现。

这个函数是整个 Triton 内核生成过程中的关键组件,它实现了自动化错误修复和代码优化,使系统能够自主解决遇到的问题。

在验证循环中的角色

反馈循环

  • 测试失败后:当内核测试失败时,_refine_kernel会分析错误并生成改进版本
  • 多轮迭代,自适应优化:支持多轮修复尝试,直到达到最大轮数或测试通过
    • 迭代改进:基于测试失败的原因持续优化内核实现
    • 融合优先:保持融合操作的优先级,在修复时不会降低融合程度

约束维护

  • 规则遵守:确保修复后的代码仍然符合Triton编程指南和约束条件
  • 禁止操作检查:防止使用不允许的PyTorch操作(如torch.nn模块、torch.nn.functional等)

错误处理机制

  • 错误感知:接收来自测试运行的错误信息(error_info),包括标准输出和标准错误内容
  • 上下文构建:利用历史记录构建之前的尝试上下文,帮助LLM理解之前的问题和修改

输入输出

输入参数

  • kernel_code: 需要修复的当前内核代码
  • error_info: 包含测试失败信息的字典(stdout、stderr)
  • problem_description: 原始问题描述
  • test_code: 用于验证的测试代码

输出结果

  • refined_kernel:经过LLM优化的内核代码字符串
  • fallback机制:如果LLM不可用,则返回原始代码或简单修改

prompt

_refine_kernel 调用 render_kernel_refinement_prompt。

提示构建:使用PromptManager生成专门的修复提示,其中包含:

  • 问题描述
  • 测试代码
  • 当前的内核实现
  • 错误信息
  • 历史尝试记录

5.5 设计思考

subprocess

VerificationWorker使用subprocess而非exec来执行生成的代码,其原因如下:

  • 安全隔离:LLM生成的代码不可信,子进程隔离可防止恶意代码污染主进程(如修改全局变量、猴子补丁标准库)
  • 崩溃隔离:segfault、CUDA OOM等致命错误不会拖整个 agent进程。
  • 超时控制:子进程可以被强制终止(SIGTERM/SIGKILL),而exec()中的无限循环只能通过信号打断。
  • 环境白名单:通过env参数控制子进程可见的环境变量,防止泄露 API key。
  • 资源限制:可通过-I 隔离模式或sitecustomize.py注入来禁用网络。

不过项目也提供了_run_test_multiprocess作为替代方案(通过multiprocessing.Process+exec)。

禁止模式

VerificationWorker中的DISALLOWED_TORCH_PATTERNS列表给出3种被禁止的模式。该模式用于检测生成的 Triton 内核代码中是否非法使用了 PyTorch 计算,确保内核是真正用 Triton 实现的,而非偷偷调用 PyTorch。

另外,这也是一种安全防线,防止LLM生成的内核代码通过帧内省"作弊":

  • 窃取测试数据:内核可以通过'inspect.stack获取调用者的栈帧,从中读取测试函数的局部变量(如期望输出expected_output),然后直接返回该值"通过"测试,而根本没有进行真正的Triton计算,
  • 窃取参考实现:可以通过globals找到并调用PyTorch参考函数,伪装为Triton实现。
  • 破坏沙盒假设:帧内省可以跳出模块边界,访问不应被访问的上下文。

LLM有时会"发现"这些捷径,因为它被训练来通过测试一这些规则,确保通过测试的唯一方式是真正实现正确的Triton内核。

被禁止的模式举例如下:

  • import torch.nn或from torch import nn → 禁止导入nn 模块
  • torch.matmul()、torch.mm()、torch.bmm() → 禁止PyTorch 矩阵乘法
  • torch.einsum() → 禁止 einsum
  • inspect.stack()、sys·_getframe() → 禁止帧内省(防止从调用者窃取测试数据)
  • globals()、locals() → 禁止访问全局/局部变量

检测时会先用_strip_comments_and_strings()去掉注释和文档字符串,避免误报。

历史记录

Verificationworker的history使用了deque(maxlen=history_size)(默认 8)。为什么要限制 history 窗口大小而不是保留全部历史?其原因如下:

  • Token 预算控制:每轮修复都需要将history 作为上下文发送给LLM,全部历史会迅速耗尽上下文窗口(refinement prompt 中历史部分包含每轮的kernel代码片段+错误信息)。
  • 信息衰减:早期失败的代码对后续修复的参考价值递减,最近的失败尝试更有指导意义。
  • 避免循环:过多历史可能导致LLM"过拟合"于先前错误模式,而非探索新方向。
  • 内存效率:每轮round_data包含完整kernel代码,保留全部会占用大量内存。

5.6 代码

内核验证功能

测试执行

python 复制代码
    def _run_test(self) -> tuple[bool, str, str]:
        """
        Run the test script and capture results.

        Returns:
            Tuple of (success, stdout, stderr)
        """
        cmd = [sys.executable, str(self.test_file)]

        try:
            result = subprocess.run(
                cmd,
                cwd=str(self.workdir),
                capture_output=True,
                text=True,
                timeout=30,  # 30 second timeout
            )

            success = result.returncode == 0
            if success:
                self.logger.info("Test passed")
            else:
                self.logger.error(
                    "Test failed. Exit code: %s, stderr: %s",
                    result.returncode,
                    result.stderr[:500],
                )

            return success, result.stdout, result.stderr

        except subprocess.TimeoutExpired:
            self.logger.error("Test timed out")
            return False, "", "Test execution timed out after 30 seconds"
        except Exception as e:
            self.logger.error(f"Test execution error: {e}")
            return False, "", str(e)

PyTorch 计算违规检测

python 复制代码
DISALLOWED_TORCH_PATTERNS = [
    (
        re.compile(r"\bimport\s+torch\.nn(\b|\s+as\b)"),
        "importing torch.nn modules is not allowed",
    ),
    (
        re.compile(r"\bfrom\s+torch\s+import\s+nn\b"),
        "importing torch.nn modules is not allowed",
    ),
    (
        re.compile(r"\bimport\s+torch\.nn\.functional\s+as\s+F\b"),
        "aliasing torch.nn.functional as F is not allowed",
    ),
    (re.compile(r"\btorch\.nn\."), "torch.nn module usage is not allowed"),
    (
        re.compile(r"\btorch\.nn\.functional\b"),
        "torch.nn.functional usage is not allowed",
    ),
    (
        re.compile(r"\bF\.[A-Za-z_]+\("),
        "torch.nn.functional alias calls (F.*) are not allowed",
    ),
    (re.compile(r"\btorch\.conv"), "torch convolution helpers are not allowed"),
    (
        re.compile(
            r"\btorch\.(relu|sigmoid|tanh|softmax|gelu|mish|hardtanh|max_pool|avg_pool)[A-Za-z0-9_]*\("
        ),
        "PyTorch activation/pooling helpers are not allowed",
    ),
    (
        re.compile(r"\bclass\s+\w+\s*\(\s*nn\.Module"),
        "Subclassing torch.nn.Module is not allowed",
    ),
    (
        re.compile(r"\.forward\("),
        "Calling .forward() indicates torch.nn module usage and is not allowed",
    ),
    (
        re.compile(r"\btorch\.ops\.aten\b"),
        "Low-level torch.ops.aten.* calls are not allowed; implement these ops directly in Triton kernels instead of relying on PyTorch compute",
    ),
    # Generic tensor-tensor math that must be implemented in Triton kernels
    (
        re.compile(r"\btorch\.(matmul|mm|bmm)\s*\("),
        "PyTorch matmul/mm/bmm tensor-tensor ops are not allowed; implement these in Triton kernels",
    ),
    (
        re.compile(r"\.(matmul|mm|bmm)\s*\("),
        "Tensor.matmul/mm/bmm methods are not allowed; implement these in Triton kernels",
    ),
    (
        re.compile(r"\btorch\.einsum\s*\("),
        "torch.einsum is not allowed; implement this contraction with Triton primitives",
    ),
    (
        re.compile(r"\.einsum\s*\("),
        "Tensor.einsum is not allowed; implement this contraction with Triton primitives",
    ),
    # Introspection / frame inspection that can be used to steal test locals
    (
        re.compile(r"\bimport\s+inspect\b"),
        "inspect-based reflection is not allowed inside kernel files",
    ),
    (
        re.compile(r"\binspect\.(stack|currentframe|getouterframes)\s*\("),
        "inspect stack/frame introspection is not allowed in kernels",
    ),
    (
        re.compile(r"\bsys\._getframe\s*\("),
        "sys._getframe is not allowed in kernels; do not access caller frames",
    ),
    (
        re.compile(r"\.f_locals\b|\.f_globals\b"),
        "Accessing frame locals/globals (f_locals/f_globals) from kernels is not allowed",
    ),
    (
        re.compile(r"\bglobals\s*\("),
        "globals() is not allowed in kernels; avoid depending on ambient test state",
    ),
    (
        re.compile(r"\blocals\s*\("),
        "locals() is not allowed in kernels; avoid depending on caller scopes",
    ),
]

内核优化功能

LLM 驱动优化

python 复制代码
    def _refine_kernel(
        self,
        kernel_code: str,
        error_info: dict[str, str],
        problem_description: str,
        test_code: str,
    ) -> str:
        """
        Refine kernel based on error information using OpenAI API.

        Uses multi-turn dialogue by incorporating history of previous attempts.
        """
        if self.provider:
            try:
                self.logger.info(f"Refining kernel using {self.openai_model}")

                # Build context from history
                # 构建历史上下文
                history_context = ""
                if self.history:
                    history_context = "\n\nPREVIOUS ATTEMPTS:\n"
                    for i, round_data in enumerate(self.history):
                        # 添加历史记录
                        history_context += f"\nAttempt {i + 1}:\n"
                        history_context += f"Kernel code:\n```python\n{round_data['kernel_code'][:500]}...\n```\n"
                        if round_data.get("stderr"):
                            history_context += f"Error: {round_data['stderr'][:200]}\n"
                        if round_data.get("stdout"):
                            history_context += f"Output: {round_data['stdout'][:200]}\n"

                # Create refinement prompt using template
                # 使用模板创建优化提示
                prompt = self.prompt_manager.render_kernel_refinement_prompt(
                    problem_description=problem_description,
                    test_code=test_code,
                    kernel_code=kernel_code,
                    error_info=error_info,
                    history_context=history_context,
                )

                # Call LLM API
                messages = [{"role": "user", "content": prompt}]
                response_text = self._call_llm(messages, max_tokens=20000)

                # Extract refined kernel from response
                # 从响应中提取优化后的内核
                refined_kernel = self._extract_code_from_response(response_text)

                if refined_kernel:
                    self.logger.info(
                        f"Successfully refined kernel using {self.openai_model}"
                    )
                    return refined_kernel
                else:
                    self.logger.error("Failed to extract valid code from LLM response")
                    # Return original kernel if extraction fails
                    return kernel_code

            except Exception as e:
                self.logger.error(f"Error refining kernel with LLM API: {e}")
                # Fall back to mock refinement

        # Mock refinement (fallback)
        self.logger.info("Refining kernel (mock implementation)")

        # For testing, make a simple modification
        if "error" in error_info.get("stderr", "").lower():
            # Add a comment to show refinement happened
            return f"# Refinement attempt {len(self.history) + 1}\n{kernel_code}"

        return kernel_code

优化循环

python 复制代码
    def run(
        self,
        kernel_code: str,
        test_code: str,
        problem_description: str,
        success_event: mp.Event,
    ) -> dict[str, Any]:
        """
        Run verification and refinement loop.

        Args:
            kernel_code: Initial kernel implementation
            test_code: Test code to verify kernel
            problem_description: Problem description for context
            success_event: Shared event to check if another worker succeeded

        Returns:
            Dictionary with results
        """
        self.logger.info(f"Starting verification for worker {self.worker_id}")

        # 运行验证和优化循环
        current_kernel = kernel_code

        for round_num in range(self.max_rounds):
            # Check if another worker has succeeded
            # 检查是否有其他工作进程已经成功
            if success_event.is_set():
                self.logger.info("Another worker succeeded, stopping")
                return {
                    "worker_id": self.worker_id,
                    "success": False,
                    "stopped_early": True,
                    "rounds": round_num,
                }

            self.logger.info(f"Round {round_num + 1}/{self.max_rounds}")

            # 更新内核文件
            # Write files - test only on first round, kernel every round
            if round_num == 0:
                # First round: write both kernel and test
                self._write_files(current_kernel, test_code)
            else:
                # Subsequent rounds: only update kernel, test remains unchanged
                self._write_kernel(current_kernel)

            # 检测违规
            violation = self._detect_pytorch_compute(current_kernel)
            if violation:
                message = f"Disallowed PyTorch usage detected: {violation}"
                self.logger.error(message)
                self._log_round(round_num + 1, False, current_kernel, "", message)
                error_info = {
                    "stdout": "",
                    "stderr": message,
                    "history": list(self.history),
                }
                # 违规处理
                current_kernel = self._refine_kernel(
                    current_kernel, error_info, problem_description, test_code
                )
                continue

            # Run test 执行测试
            success, stdout, stderr = (
                self._run_test()
                if os.getenv("KA_PROCESS_USE_SYS_EXECUTABLE", "1") == "1"
                else _run_test_multiprocess(self.logger, self.workdir, self.test_file)
            )

            # Log round  记录轮次
            self._log_round(round_num + 1, success, current_kernel, stdout, stderr)

            if success:
                self.logger.info(
                    f"Success! Kernel passed test in round {round_num + 1}"
                )
                return {
                    "worker_id": self.worker_id,
                    "success": True,
                    "kernel_code": current_kernel,
                    "rounds": round_num + 1,
                    "history": list(self.history),
                }

            # Refine kernel for next round
            # 为下一轮优化内核
            error_info = {
                "stdout": stdout,
                "stderr": stderr,
                "history": list(self.history),
            }

            current_kernel = self._refine_kernel(
                current_kernel, error_info, problem_description, test_code
            )

        # Max rounds reached without success
        self.logger.warning(f"Max rounds ({self.max_rounds}) reached without success")
        return {
            "worker_id": self.worker_id,
            "success": False,
            "max_rounds_reached": True,
            "rounds": self.max_rounds,
            "history": list(self.history),
        }

0xFF 参考

KernelFalcon: Autonomous GPU Kernel Generation via Deep Agents

基于 LLM 的 GPU 内核代码自动生成相关工作

Automating GPU Kernel Generation with DeepSeek-R1 and Inference Time Scaling

DeepSeek-R1自写CUDA内核跑分屠榜!斯坦福学霸狂飙GPU编程自动化挑战人类

CUDA、Triton 内核生成现状追踪

大模型能否为不同硬件平台生成高性能内核?南大、浙大提出跨平台内核生成评测框架MultiKernelBench

AKG kernel Agent:利用multi-agent进行kernel的生成和迁移

AKG KERNEL AGENT: A MULTI-AGENT FRAMEWORK FOR CROSS-PLATFORM KERNEL SYNTHESIS

AIKG -- 基于AI驱动的算子生成器

RL 猛刷 CUDA 核:CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning

MultiKernelBench: A Multi-Platform Benchmark for Kernel Generation

Ouyang A, Guo S, Arora S, et al. Kernelbench: Can llms write efficient gpu kernels?[J]. arXiv preprint arXiv:2502.10517, 2025.

Baronio, Carlo, et al. "Kevin: Multi-turn rl for generating cuda kernels."arXiv preprint arXiv:2507.11948(2025).

Li, Shangzhan, et al. "Autotriton: Automatic triton programming with reinforcement learning in llms."arXiv preprint arXiv:2507.05687(2025).

Li, Jianling, et al. "Tritonbench: Benchmarking large language model capabilities for generating triton operators."Findings of the Association for Computational Linguistics: ACL 2025. 2025.

Tjarko Lange, Robert, et al. "Towards Robust Agentic CUDA Kernel Benchmarking, Verification, and Optimization."arXiv e-prints(2025): arXiv-2509.

Chen, Wentao, et al. "CUDA-LLM: LLMs Can Write Efficient CUDA Kernels."arXiv preprint arXiv:2506.09092(2025).