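Core Selector Logic in Trae-Agent

Trae Agent Selector: An Analysis of the Selection Logic

Overview

The Selector is a core component of Trae Agent's evaluation module: given several candidate patches for the same issue, it picks the one most likely to be correct. This document walks through its core logic and implementation.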


1. Overall Architecture

┌────────────────────────────────────────────────────────────────┐
│                    Patch Selection Flow                        │
└────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────┐
│  Phase 1: Preparation                                          │
│   ┌────────────────────────────────────────────────────────┐   │
│   │  - Read candidate patch list (candidates.jsonl)        │   │
│   │  - Split into groups                                   │   │
│   │  - Filter by regression tests                          │   │
│   │  - Deduplicate patches                                 │   │
│   └────────────────────────────────────────────────────────┘   │
└────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌────────────────────────────────────────────────────────────────┐
│  Phase 2: Selection                                            │
│   ┌────────────────────────────────────────────────────────┐   │
│   │  Run SelectorAgent                                     │   │
│   │  - Analyze each candidate patch                        │   │
│   │  - Validate in sandbox (optional)                      │   │
│   │  - Select the best patch                               │   │
│   └────────────────────────────────────────────────────────┘   │
└────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌────────────────────────────────────────────────────────────────┐
│  Phase 3: Voting (optional)                                    │
│   ┌────────────────────────────────────────────────────────┐   │
│   │  - Run SelectorAgent multiple times                    │   │
│   │  - Count selection frequency                           │   │
│   │  - Take the majority vote as final result              │   │
│   └────────────────────────────────────────────────────────┘   │
└────────────────────────────────────────────────────────────────┘

2. Core Components

2.1 Entry point: selector.py

File: evaluation/patch_selection/selector.py

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--instances_path", required=True)   # list of instances to process
    parser.add_argument("--candidate_path", required=True)   # candidate patch file (JSONL)
    parser.add_argument("--num_candidate", type=int, default=10)
    parser.add_argument("--group_size", type=int, default=10)      # candidates per group
    parser.add_argument("--majority_voting", action="store_true")  # enable majority voting
    # Excerpt: the definitions of the other arguments referenced below
    # (max_retry, max_turn, log_path, ...) are not shown here.
    args = parser.parse_args()

    # Load the candidate patches: one JSON object per line, keyed by instance_id
    candidate_dic = {}
    with open(args.candidate_path, "r") as file:
        for line in file:
            candidate = json.loads(line.strip())
            candidate_dic[candidate["instance_id"]] = candidate

    # Create the evaluator (llm_config, instance_list and tools_path are
    # defined outside this excerpt)
    evaluation = SelectorEvaluation(
        llm_config,
        args.num_candidate,
        args.max_retry,
        args.max_turn,
        args.log_path,
        args.output_path,
        args.patches_path,
        instance_list,
        candidate_dic,
        tools_path,
        args.statistics_path,
        args.group_size,
        majority_voting=args.majority_voting,
    )

    # Run every instance
    evaluation.run_all(max_workers=args.max_workers)
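
The loading loop in main() can be tried on its own. The following is a minimal sketch, assuming each line of the candidate file is one JSON object with an `instance_id` key as in the excerpt. `load_candidates` is a hypothetical helper name, and the sample record is invented for illustration.

```python
import json
import os
import tempfile


def load_candidates(path: str) -> dict:
    """Read a JSONL file and index its records by instance_id."""
    candidates = {}
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:  # tolerate blank lines in the file
                continue
            record = json.loads(line)
            candidates[record["instance_id"]] = record
    return candidates


# Illustrative record; the field values are made up.
sample = {
    "instance_id": "demo__repo-1",
    "issue": "Example issue text",
    "patches": ["diff A"],
    "success_id": [1],
    "regressions": [[]],
}

with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "candidates.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(sample) + "\n")
    loaded = load_candidates(demo_path)

print(sorted(loaded))
```

Because the dictionary is keyed by `instance_id`, a later line with the same ID silently replaces an earlier one, which is also how the loop in the excerpt behaves.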

2.2 The SelectorEvaluation class

File: evaluation/patch_selection/trae_selector/selector_evaluation.py

2.2.1 Grouping

def run_instance(
    instance,
    candidate_log,  # candidate patch record for this instance
    num_candidate,
    group_size,
    ...
):
    """Split the candidate patches into groups and process each group."""
    # Split the N candidates into groups of at most group_size
    groups = []
    for i in range(0, num_candidate, group_size):
        this_group = {
            "instance_id": candidate_log["instance_id"],
            "issue": candidate_log["issue"],
            "patches": candidate_log["patches"][i:i + group_size],
            "regressions": candidate_log["regressions"][i:i + group_size],
            "success_id": candidate_log["success_id"][i:i + group_size],
        }
        groups.append(this_group)

    # Run the selection independently for each group
    for group_id, group in enumerate(groups):
        run_instance_by_group(instance=instance, candidate_log=group, ...)

Grouping strategy example

50 candidate patches, group_size=10
  ↓
Group 0: patches 0-9
Group 1: patches 10-19
Group 2: patches 20-29
Group 3: patches 30-39
Group 4: patches 40-49
  ↓
Each group selects 1 → 5 candidates in total
  ↓
Final selection of 1
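
The slicing behind this grouping can be sketched in a few lines. `split_into_groups` is a hypothetical helper, not a function from the repository, and unlike the excerpt it iterates over `len(items)` rather than a separate `num_candidate` argument.

```python
def split_into_groups(items: list, group_size: int) -> list:
    """Chunk items into consecutive groups of at most group_size elements."""
    if group_size <= 0:
        raise ValueError("group_size must be positive")
    return [items[i:i + group_size] for i in range(0, len(items), group_size)]


patch_ids = list(range(50))
groups = split_into_groups(patch_ids, 10)
print(len(groups), groups[4][0], groups[4][-1])
```

When the number of candidates is not a multiple of `group_size`, the last group is simply shorter, because Python slices stop at the end of the list.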

2.2.2 Processing a single group

def run_instance_by_group(instance, candidate_log, ...):
    """Process the candidate patches of a single group."""

    # 1. Skip the group if a non-empty statistics file already exists
    file_path = statistics_path + f"/group_{group_id}/{instance['instance_id']}.json"
    if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
        print("Already processed. Skipping...")
        return

    # 2. Edge case: every candidate failed, or every candidate succeeded
    all_failed = all(success_id == 0 for success_id in candidate_log["success_id"])
    all_success = all(success_id == 1 for success_id in candidate_log["success_id"])
    if all_failed or all_success:
        # Nothing to choose between: save the result and skip selection
        save_patches(...)
        save_selection_success(...)
        return
    
    # 3. Build the candidate list
    candidate_list = []
    for idx in range(len(candidate_log["patches"])):
        candidate_list.append(CandidatePatch(
            id=idx,
            patch=candidate_log["patches"][idx],
            cleaned_patch=clean_patch(candidate_log["patches"][idx]),
            is_success_regression=len(candidate_log["regressions"][idx]) == 0,
            is_success_patch=candidate_log["success_id"][idx],
        ))

    # 4. Regression filter: keep only candidates with no failing regression
    #    tests, unless that would leave nothing to choose from
    candidate_list_regression = [
        c for c in candidate_list if c.is_success_regression
    ]
    if len(candidate_list_regression) > 0:
        candidate_list = candidate_list_regression

    # 5. Deduplicate by cleaned patch text, keeping the first occurrence
    candidate_list_deduplication = []
    cleaned_candidate_set = set()
    for candidate in candidate_list:
        if candidate.cleaned_patch not in cleaned_candidate_set:
            cleaned_candidate_set.add(candidate.cleaned_patch)
            candidate_list_deduplication.append(candidate)
    candidate_list = candidate_list_deduplication

    # 6. Create the sandbox
    sandbox = Sandbox(namespace, image_name, tag, instance, tools_path)
    sandbox.start_container()
    project_path = sandbox.get_project_path()
    
    # 7. Run the selection
    if majority_voting:
        # Majority-voting mode: run the agent up to num_candidate times
        final_id_list, final_patch_list = [], []
        for idx in range(num_candidate):
            select_agent = SelectorAgent(...)
            final_id, final_patch = select_agent.run()
            final_id_list.append(final_id)
            final_patch_list.append(final_patch)

            # Stop early once one ID holds more than half of the possible votes
            if max(Counter(final_id_list).values()) > num_candidate / 2:
                break

        # Tally the votes; on a tie, the ID that was voted for first wins
        counter = Counter(final_id_list)
        max_count = max(counter.values())
        most_common_ids = [elem for elem, count in counter.items() if count == max_count]
        final_id = most_common_ids[0]
        final_patch = final_patch_list[final_id_list.index(final_id)]
    else:
        # Single-run mode
        select_agent = SelectorAgent(...)
        final_id, final_patch = select_agent.run()

    # 8. Save the result
    save_patches(instance_id=instance["instance_id"], patches_path=patches_path,
                 patches=final_patch, group_id=group_id)
    save_selection_success(instance_id=instance["instance_id"], ...)

    sandbox.stop_container()
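
The voting step in the function above can be isolated into a small, testable function. This is a sketch under assumptions: `majority_vote` and `run_once` are hypothetical names, and a scripted iterator stands in for repeated `SelectorAgent.run()` calls so that the example needs no LLM or sandbox.

```python
from collections import Counter
from typing import Callable


def majority_vote(run_once: Callable[[], int], max_runs: int) -> tuple:
    """Call run_once up to max_runs times and return (winning id, all votes)."""
    if max_runs < 1:
        raise ValueError("max_runs must be at least 1")
    votes = []
    for _ in range(max_runs):
        votes.append(run_once())
        # Early stop: once one id holds a strict majority of the maximum
        # number of runs, the remaining runs cannot change the winner.
        if max(Counter(votes).values()) > max_runs / 2:
            break
    counter = Counter(votes)
    max_count = max(counter.values())
    # Counter preserves first-insertion order, so a tie goes to the id seen first.
    winner = [pid for pid, count in counter.items() if count == max_count][0]
    return winner, votes


# Scripted results stand in for real selector runs in this demo.
scripted = iter([2, 0, 2, 2, 1])
winner, votes = majority_vote(lambda: next(scripted), max_runs=5)
print(winner, votes)
```

In the scripted run, the fifth call is never made: after four runs, candidate 2 already has three of the five possible votes.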

2.3 The SelectorAgent class

File: evaluation/patch_selection/trae_selector/selector_agent.py

class SelectorAgent:
    def __init__(
        self,
        llm_config: ModelConfig,
        sandbox: Sandbox,
        project_path: str,
        issue_description: str,
        trajectory_file_name: str,
        candidate_list: list[CandidatePatch],
        max_turn: int = 50,
    ):
        self.llm_config = llm_config
        self.max_turn = max_turn
        self.candidate_list = candidate_list  # read later by run()
        self.sandbox = sandbox
        self.sandbox_session = self.sandbox.get_session()
        # (excerpt: the setup of self.trajectory_recorder, used in run(), is not shown)

        # Discard any uncommitted changes so the code is back at its baseline
        self.sandbox_session.execute("git reset --hard HEAD")

        # Initialize the tools (only bash and str_replace_based_edit_tool)
        self.tools = [
            tools_registry[tool_name](model_provider=llm_config.model_provider.provider)
            for tool_name in ["bash", "str_replace_based_edit_tool"]
        ]

        # Initialize the LLM client
        self.llm_client = LLMClient(llm_config)

        # Build the initial messages, starting with the system prompt
        self.initial_messages = [
            LLMMessage(role="system", content=build_system_prompt(len(candidate_list)))
        ]

        # Append the user prompt
        user_prompt = f"""
[Codebase path]: {project_path}
[Github issue description]:

{issue_description}

[Candidate Patches]:
"""
        for idx, candidate in enumerate(candidate_list):
            user_prompt += f"\nPatch-{idx + 1}:\n```\n{candidate.patch}\n```"
        
        self.initial_messages.append(LLMMessage(role="user", content=user_prompt))

2.4 System prompt design

def build_system_prompt(candidate_length: int) -> str:
    return f"""\
# ROLE: Act as an expert code evaluator. 
Given a codebase, an github issue and **{candidate_length} candidate patches** 
proposed by your colleagues, your responsibility is to **select the correct one** 
to solve the issue.

# WORK PROCESS:
1. Understand the Issue and Codebase
2. Analyze the Candidate Patches
3. Validate Functionality (Optional but Recommended)
4. Select the Best Patch

# FINAL REPORT:
### Status: succeed
### Result: Patch-x
### Analysis: [Explain why Patch-x is correct.]

# IMPORTANT TIPS:
1. Never avoid making a selection.
2. Do not propose new patches.
3. There must be at least one correct patch.
"""

3. Execution Flow

3.1 The SelectorAgent.run() method

def run(self):
    """Main loop of the selector agent."""
    turn = 0
    # Default to the first candidate in case the model never reports a selection
    final_id, final_patch = self.candidate_list[0].id, self.candidate_list[0].patch
    messages = self.initial_messages

    while turn < self.max_turn:
        turn += 1

        # 1. Call the LLM
        llm_response = self.llm_client.chat(messages, self.llm_config, self.tools)

        # 2. Record the trajectory
        self.trajectory_recorder.record_llm_interaction(...)

        # 3. Check whether the model has reported its selection
        match = re.search(
            r"Status:\s*(success|succeed).*\n.*Result:\s*Patch-(\d+)",
            llm_response.content,
        )

        if match:
            # Extract the selected patch; "Patch-N" in the report is 1-based
            selected_idx = int(match.group(2)) - 1
            # An index past the end of the list keeps the default (first candidate)
            if selected_idx < len(self.candidate_list):
                final_id = self.candidate_list[selected_idx].id
                final_patch = self.candidate_list[selected_idx].patch
            break

        # 4. Execute the requested tool calls and append their results
        messages += parse_tool_response(
            llm_response, self.sandbox_session
        )

    # Clean up
    self.trajectory_recorder.finalize_recording(True, final_patch)
    self.sandbox_session.execute("git reset --hard HEAD")
    self.sandbox_session.close()
    
    return final_id, final_patch
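
The report-parsing step of run() can be exercised without an LLM. The sketch below reuses the regular expression from the excerpt; `parse_selection` is a hypothetical helper that returns a 0-based index, or `None` when no valid selection is found. Unlike the excerpt, it also rejects `Patch-0` (which would map to index -1) and leaves the fallback to the first candidate to the caller.

```python
import re
from typing import Optional

# Same pattern as in run(): a Status line followed directly by a Result line.
REPORT_PATTERN = re.compile(r"Status:\s*(success|succeed).*\n.*Result:\s*Patch-(\d+)")


def parse_selection(content: str, num_candidates: int) -> Optional[int]:
    """Return the 0-based index of the selected patch, or None."""
    match = REPORT_PATTERN.search(content)
    if match is None:
        return None
    index = int(match.group(2)) - 1  # "Patch-N" is 1-based
    if 0 <= index < num_candidates:
        return index
    return None


report = "### Status: succeed\n### Result: Patch-3\n### Analysis: fixes the root cause."
print(parse_selection(report, num_candidates=3))
print(parse_selection("Still inspecting the codebase.", num_candidates=3))
```

Since `.` does not match a newline and the pattern contains exactly one `\n`, the Result line has to follow the Status line immediately; a blank line between them prevents the match.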

3.2 Parsing tool responses

def parse_tool_response(answer: LLMResponse, sandbox_session):
    """Parse the LLM's tool calls and execute them in the sandbox."""
    result = []

    for tool_call in answer.tool_calls:
        tool_call_id = tool_call.call_id
        tool_name = tool_call.name

        # 1. Pick the executor script for the requested tool
        if tool_name == "str_replace_based_edit_tool":
            cmd = "cd /home/swe-bench/tools/ && /home/swe-bench/py312/bin/python3 execute_str_replace_editor.py"
        elif tool_name == "bash":
            cmd = "cd /home/swe-bench/tools/ && /home/swe-bench/py312/bin/python3 execute_bash.py"
        else:
            # Unknown tool: report the failure back to the model
            result.append(LLMMessage(
                role="user",
                content="Tool not available...",
                tool_result=ToolResult(success=False, ...)
            ))
            continue

        # 2. Append the arguments as shell-quoted --key value pairs
        for key, value in tool_call.arguments.items():
            cmd += f" --{key} {shlex.quote(str(value))}"

        # 3. Run the command in the sandbox, capturing stdout and stderr in a log file
        cmd += " > /home/swe-bench/tools/log.out 2>&1"
        sandbox_session.execute(cmd)
        sandbox_res = sandbox_session.execute("cat /home/swe-bench/tools/log.out")

        # 4. Interpret the result: success unless the log reports status -1
        status = "Tool Call Status: 0"
        if "Tool Call Status: -1" in sandbox_res:
            status = "Tool Call Status: -1"
        
        result.append(LLMMessage(
            role="user",
            content=sandbox_res,
            tool_result=ToolResult(
                success=status == "Tool Call Status: 0",
                ...
            )
        ))
    
    return result
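
The argument quoting in step 2 is what keeps model-generated values from being interpreted by the shell. A minimal sketch of that idea follows; `build_tool_command` is a hypothetical helper and the script name is only a placeholder, but `shlex.quote` and `shlex.split` are the standard-library functions.

```python
import shlex


def build_tool_command(executor_script: str, arguments: dict) -> str:
    """Build a shell command with each argument value safely quoted."""
    cmd = f"python3 {shlex.quote(executor_script)}"
    for key, value in arguments.items():
        cmd += f" --{key} {shlex.quote(str(value))}"
    return cmd


# A value containing quotes and a command separator stays a single argument.
cmd = build_tool_command("execute_bash.py", {"command": "echo 'hi'; ls"})
print(cmd)
print(shlex.split(cmd))
```

Splitting the command again with `shlex.split` recovers the original value unchanged, which shows that the `;` never reaches the shell as a separator.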

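4. Key Data Structures

4.1 CandidatePatch

class CandidatePatch:
    def __init__(
        self,
        id,                       # patch ID (index within its group)
        patch,                    # original patch text
        cleaned_patch,            # normalized patch text, used for deduplication
        is_success_regression,    # True if no regression test failed
        is_success_patch,         # whether the patch is correct (ground truth)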
    ):
        self.id = id
        self.patch = patch
        self.cleaned_patch = cleaned_patch
        self.is_success_regression = is_success_regression
        self.is_success_patch = is_success_patch
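
To see how these fields drive the regression filter and the deduplication step, here is a self-contained sketch. It restates the class as a dataclass for brevity. `normalize` is a hypothetical stand-in for `clean_patch`, whose real implementation is not shown in this document, and `filter_and_deduplicate` is a hypothetical name for steps 4 and 5 of the single-group flow.

```python
from dataclasses import dataclass


@dataclass
class CandidatePatch:
    id: int
    patch: str
    cleaned_patch: str
    is_success_regression: bool
    is_success_patch: int


def normalize(patch: str) -> str:
    """Hypothetical cleaner: drop blank lines and trailing whitespace."""
    return "\n".join(line.rstrip() for line in patch.splitlines() if line.strip())


def filter_and_deduplicate(candidates: list) -> list:
    """Prefer regression-clean candidates, then drop duplicate patches."""
    passing = [c for c in candidates if c.is_success_regression]
    if passing:  # fall back to all candidates if none pass
        candidates = passing
    seen = set()
    unique = []
    for candidate in candidates:
        if candidate.cleaned_patch not in seen:
            seen.add(candidate.cleaned_patch)
            unique.append(candidate)
    return unique


patches = ["+a = 1\n", "+a = 1  \n\n", "+b = 2\n"]
regressions = [[], [], ["test_module.py::test_func"]]
candidates = [
    CandidatePatch(i, p, normalize(p), len(regressions[i]) == 0, 0)
    for i, p in enumerate(patches)
]
print([c.id for c in filter_and_deduplicate(candidates)])
```

The fallback matters: if every candidate breaks some regression test, the filter keeps all of them instead of returning an empty list, so the agent always has something to choose from.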

4.2 Input data format

{
    "instance_id": "astropy__astropy-14369",
    "issue": "Issue description...",
    "patches": [
        "patch diff 1",
        "patch diff 2",
        "patch diff N"
    ],
    "success_id": [1, 0, 1],
    "regressions": [
        [],
        ["test_module.py::test_func"],
        []
    ]
}
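
Field         Description
instance_id   Instance ID
issue         Issue description
patches       List of candidate patches
success_id    Per-patch correctness flag (1 = correct, 0 = incorrect)
regressions   Per-patch list of regression tests that failed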
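
Because `patches`, `success_id` and `regressions` are sliced and indexed in parallel, a record is only usable when the three lists have the same length. The following check is a sketch of that invariant; `validate_record` is a hypothetical helper, not part of the repository.

```python
REQUIRED_KEYS = ("instance_id", "issue", "patches", "success_id", "regressions")


def validate_record(record: dict) -> int:
    """Return the number of candidates, or raise ValueError if malformed."""
    missing = [key for key in REQUIRED_KEYS if key not in record]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    count = len(record["patches"])
    for key in ("success_id", "regressions"):
        if len(record[key]) != count:
            raise ValueError(f"{key} has {len(record[key])} entries, expected {count}")
    if any(flag not in (0, 1) for flag in record["success_id"]):
        raise ValueError("success_id entries must be 0 or 1")
    return count


record = {
    "instance_id": "astropy__astropy-14369",
    "issue": "Issue description...",
    "patches": ["patch diff 1", "patch diff 2", "patch diff N"],
    "success_id": [1, 0, 1],
    "regressions": [[], ["test_module.py::test_func"], []],
}
print(validate_record(record))
```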

5. Output Layout

results/
├── log/
│   └── group_0/
│       └── instance_id_voting_0_trail_1.json   # LLM interaction log
├── output/
│   └── group_0/
│       └── instance_id.log                      # standard output
├── patch/
│   └── group_0/
│       └── instance_id_1.patch                  # selected patch
└── statistics/
    └── group_0/
        └── instance_id.json                     # selection statistics
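
The statistics file doubles as a resume marker: a group is skipped when its file already exists and is non-empty. A sketch of that check, assuming the `group_<id>/<instance_id>.json` layout shown above, could look like this. `is_already_processed` is a hypothetical name, and it uses `pathlib` where the original uses `os.path`.

```python
import tempfile
from pathlib import Path


def is_already_processed(statistics_path: str, group_id: int, instance_id: str) -> bool:
    """True if a non-empty statistics file exists for this group and instance."""
    result_file = Path(statistics_path) / f"group_{group_id}" / f"{instance_id}.json"
    return result_file.is_file() and result_file.stat().st_size > 0


with tempfile.TemporaryDirectory() as tmp:
    before = is_already_processed(tmp, 0, "demo__repo-1")
    group_dir = Path(tmp) / "group_0"
    group_dir.mkdir()
    (group_dir / "demo__repo-1.json").write_text("")            # empty file: not done
    empty = is_already_processed(tmp, 0, "demo__repo-1")
    (group_dir / "demo__repo-1.json").write_text('{"ok": 1}')   # non-empty: done
    after = is_already_processed(tmp, 0, "demo__repo-1")

print(before, empty, after)
```

Requiring a non-empty file means that a run interrupted after creating the file but before writing to it is retried rather than skipped.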

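6. Design Highlights

Feature              Description
Grouping             Keeps the context short, which improves selection accuracy
Regression filter    Prefers patches that pass the regression tests
Deduplication        Avoids analyzing identical patches more than once
Majority voting      Makes the selection more stable
Sandbox validation   Validates patches in an isolated environment
Trajectory logging   Records the full selection process for later analysis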

7. Key Files

File                                                               Purpose
evaluation/patch_selection/selector.py                             Main entry point
evaluation/patch_selection/trae_selector/selector_evaluation.py    Evaluation orchestration
evaluation/patch_selection/trae_selector/selector_agent.py         Agent implementation
evaluation/patch_selection/trae_selector/sandbox.py                Sandbox environment

Last updated: 2026-03-16
