思维树
定义 :思维树(Tree of Thoughts, ToT) 是一种先进的推理框架,它通过同时探索多条推理路径对思维链(Chain of Thought)** 进行了扩展。该技术将问题解决视为一个搜索过程 ------ 模型生成不同的中间步骤,评估这些步骤的可行性,并探索最有希望的路径。
Tree of Thoughts (ToT) 是一种大语言模型推理框架,通过树状结构探索多条推理路径,允许模型自我评估路径可行性并回溯调整,模拟人类解决复杂问题时的 "试错 - 评估 - 选择" 过程。
目标:解决传统 LLMs 逐 Token 单向决策的局限,提升在需要探索、战略前瞻或多步规划任务(如数学推理、创意写作、谜题)中的表现。
ToT 框架核心机制
- 核心思路:将问题解决视为树状搜索过程,通过生成 ** 连贯的中间思维单元(Thoughts)** 作为推理的中间步骤,而非单一 Token。
- 关键能力:多路径探索:同时生成多条推理路径(如不同的解题思路)。
- 自我评估:评估每条路径的可行性,选择最有希望的分支继续探索。
- 回溯决策:必要时回溯到之前的思维节点,调整后续策略(类似人类解题的试错过程)。与 Chain of Thought(CoT)的区别:
与COT的对比
CoT 仅生成单一推理链,而 ToT 支持并行探索多条链,并通过评估机制实现全局最优决策。
24点案例
使用数字4、9、10和13以及四种基本运算符(+、-、/、*),生成一个结果为24的表达式。
java
step1
输入:4, 9, 10, 13
可能的下一步操作:
- 4 + 9 = 13(剩余:13, 10, 13)
- 10 - 4 = 6(剩余:6, 9, 13)
- 13 - 10 = 3(剩余:4, 9, 3)
- 9 × 4 = 36(剩余:36, 10, 13)
- 10 ÷ 4 = 2.5(剩余:2.5, 9, 13)
输入:4, 9, 10, 13
请给出可能得下一步操作
输出:
4+9 = 13 (left: 13, 10, 13)
10-4 = 6 (left: 6, 9, 13)
13-9 = 4 (left: 4, 9, 10)
...
...
step2
计算是否可以得到24
10 14: 10+14 = 24 sure
10 7 2: 7*2+10 = 24 sure
11 11: 11 + 11 = 22 impossible
输入第一组结果,请给出可能得结果
13, 10, 13:
输出:
10 + 13 + 13 = 36 impossible
...
...
计算是否可以得到24
10 14: 10+14 = 24 sure
10 7 2: 7*2+10 = 24 sure
11 11: 11 + 11 = 22 impossible
输入第一组结果,请给出可能得结果
6, 9, 13:
输出:
6 * (13-9) = 24 sure
...
...
自动化代码示例
生成思维结点,以树状形式组织;沿着思维结点进行探索,评估结果;根据评估结果选择下一步操作
java
package com.example.tot24;
import ai.spring.ai.client.ChatClient;
import ai.spring.ai.client.Generation;
import ai.spring.ai.client.Message;
import ai.spring.ai.client.chat.ChatResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
public class Tot24Application {
// 思维树节点类
static class TreeNode {
private List<Double> numbers;
private List<String> history;
private List<TreeNode> children;
private double score;
private boolean terminal;
}
// 候选操作类
static class CandidateOperation {
private String operation;
private List<Double> expectedNumbers;
private String reason;
private double score;
private String explanation;
}
// 24点游戏求解器
static class TwentyFourSolver {
private static final double TARGET = 24.0;
private static final double TOLERANCE = 1e-6;
private static final int MAX_STEPS = 5;
private static final int BEAM_WIDTH = 3;
private final ChatClient chatClient;
private final String modelName;
private final String systemPrompt;
public TwentyFourSolver(ChatClient chatClient, String modelName) {
this.chatClient = chatClient;
this.modelName = modelName;
// 构建系统提示
this.systemPrompt = """
你是一个解决24点游戏的专家。给定4个1-13之间的数字,使用加、减、乘、除和括号,使最终计算结果为24。
解决过程中,请遵循以下规则:
1. 每个数字必须且只能使用一次
2. 中间步骤的计算结果可以是分数
3. 最终答案必须是精确的24
当被要求生成下一步操作时,请提供JSON格式的候选操作列表(最多5个有希望的操作):
[
{
"operation": "具体操作(如:4+5=9)",
"expected_numbers": [操作后的数字列表],
"reason": "选择该操作的理由"
},
...
]
当被要求评估状态时,请提供JSON格式的评分和解释:
{
"score": 3,
"explanation": "理由..."
}
评分标准:
- 1分:当前数字组合不可能得到24
- 2分:可能得到24,但难度高
- 3分:有合理可能性得到24
- 4分:非常有希望得到24
- 5分:已得到24
""";
}
public Optional<String> solve(List<Integer> numbers) {
List<Double> initialNumbers = numbers.stream()
.map(Double::valueOf)
.collect(Collectors.toList());
TreeNode root = new TreeNode(initialNumbers, new ArrayList<>());
Queue<TreeNode> queue = new LinkedList<>();
queue.add(root);
while (!queue.isEmpty()) {
TreeNode currentNode = queue.poll();
// 检查是否已解决
if (currentNode.getNumbers().stream()
.anyMatch(n -> Math.abs(n - TARGET) < TOLERANCE)) {
return Optional.of(formatSolution(currentNode));
}
// 生成候选操作
List<CandidateOperation> candidates = generateCandidates(currentNode);
// 评估候选操作
evaluateCandidates(currentNode, candidates);
// 选择最有希望的操作
List<CandidateOperation> topCandidates = candidates.stream()
.sorted(Comparator.comparingDouble(CandidateOperation::getScore).reversed())
.limit(BEAM_WIDTH)
.collect(Collectors.toList());
// 创建子节点
for (CandidateOperation candidate : topCandidates) {
TreeNode childNode = new TreeNode(
candidate.getExpectedNumbers(),
new ArrayList<>(currentNode.getHistory())
);
childNode.getHistory().add(candidate.getOperation());
childNode.setScore(candidate.getScore());
currentNode.getChildren().add(childNode);
// 如果分数足够高,继续探索
if (candidate.getScore() >= 3) {
queue.add(childNode);
}
}
}
return Optional.empty(); // 无解
}
private List<CandidateOperation> generateCandidates(TreeNode node) {
String userPrompt = String.format("""
当前状态:
数字:%s
历史:%s
请生成最多5个有希望的下一步操作。
""", node.getNumbers(), node.getHistory());
String response = callLLM(userPrompt);
try {
// 解析JSON响应
List<CandidateOperation> candidates = new ArrayList<>();
// 实际应用中需要使用真正的JSON解析库
// 这里简化处理,实际代码应使用Jackson等库
return candidates;
} catch (Exception e) {
System.err.println("解析候选操作失败: " + e.getMessage());
System.err.println("LLM响应: " + response);
return Collections.emptyList();
}
}
private void evaluateCandidates(TreeNode node, List<CandidateOperation> candidates) {
for (CandidateOperation candidate : candidates) {
String userPrompt = String.format("""
当前状态:
数字:%s
历史:%s
候选操作:
%s
操作后数字:%s
请评分并解释。
""",
node.getNumbers(),
node.getHistory(),
candidate.getOperation(),
candidate.getExpectedNumbers());
String response = callLLM(userPrompt);
try {
// 解析JSON响应获取评分和解释
// 实际应用中需要使用真正的JSON解析库
// 这里简化处理
double score = 3.0; // 默认值
String explanation = "默认评估";
candidate.setScore(score);
candidate.setExplanation(explanation);
} catch (Exception e) {
System.err.println("解析评估结果失败: " + e.getMessage());
System.err.println("LLM响应: " + response);
candidate.setScore(2.0); // 保守评分
}
}
}
private String callLLM(String userPrompt) {
Message systemMessage = new Message(systemPrompt, "system");
Message userMessage = new Message(userPrompt, "user");
ChatResponse response = chatClient.generate(
List.of(systemMessage, userMessage),
modelName
);
Generation generation = response.getGenerations().get(0);
return generation.getContent();
}
private String formatSolution(TreeNode node) {
StringBuilder sb = new StringBuilder();
for (String step : node.getHistory()) {
sb.append(step).append("\n");
}
return sb.toString();
}
}
}