前沿
CoT(Chain-of-Thought)和 ReACT(Reasoning and Acting)是两种先进的 Prompt Engineering(提示工程) 技术,旨在提升大语言模型(LLM)的推理、规划和执行能力。
- CoT(Chain-of-Thought,思维链)
核心思想:通过引导模型生成中间推理步骤(step-by-step reasoning),模拟人类的渐进式思考过程,从而解决复杂问题。
特点:
显式推理:要求模型"展示思考过程",例如先分解问题,再逐步推导答案。
适用场景:数学题、逻辑推理、需要多步分析的任务(如"如果A比B高,B比C高,谁最矮?")。
实现方式:在Prompt中加入类似 "Let's think step by step" 的指令,或提供少量CoT示例(Few-shot CoT)。
优势:显著提升模型在复杂任务中的准确性(如GSM8K数学数据集上效果提升明显)。
局限性:仅生成文本推理,不涉及外部动作(如搜索、调用工具)
实例 数学问题
html
问题:小明有5个苹果,吃了2个,又买了3个,现在有多少个?
模型输出(CoT):
1. 初始有5个苹果。
2. 吃掉2个后剩余:5 - 2 = 3个。
3. 买了3个后总数:3 + 3 = 6个。
最终答案:6个苹果。
- ReACT(Reasoning + Acting,推理与行动)
核心思想:结合模型的推理能力和外部行动能力(如调用API、搜索、执行代码),通过交互式决策解决动态问题。
特点:
推理与行动交替:
模型先思考(Reason),再决定是否采取行动(Act),最后观察结果并迭代。
适用场景:
需要实时信息的任务(如问答)、工具调用(如计算器、数据库查询)、多步骤规划(如订机票+酒店)。
实现方式:
Prompt中明确模板(如"Thought: ... Action: ... Observation: ..."),并集成外部工具。
优势: 突破纯文本限制,通过工具扩展模型能力。
局限性: 依赖外部工具接口的设计和稳定性
示例(实时问答):
html
用户:2023年诺贝尔文学奖得主是谁?
模型输出(ReACT):
Thought: 我需要查找最新诺贝尔奖信息,应使用搜索工具。
Action: Search[2023 Nobel Literature Prize winner]
Observation: 2023年诺贝尔文学奖授予挪威作家Jon Fosse。
最终答案:Jon Fosse。
Java当中的实现
在Java中实现CoT(Chain-of-Thought)和ReACT(Reasoning and Acting)需要结合大语言模型(LLM)的API调用、动态Prompt生成以及外部工具集成。虽然Java生态的LLM工具链不如Python丰富,但通过HTTP客户端(如OkHttp)、JSON处理库(如Gson/Jackson)和规则引擎(如Drools)仍可构建完整流程。以下是具体实现方案:
1. 基础架构设计
2. 实现CoT(Chain-of-Thought)
方案1:静态模板(硬编码)
java
package com.prompt;
import java.util.Map;
import java.util.HashMap;
/**
* 实现CoT(Chain-of-Thought) 静态模板
*/
public class CotExample {
public static String generateCotPrompt(String question) {
// 定义CoT模板(可配置化)
String template = """
请逐步解决以下问题:
问题:{question}
步骤1:分析问题的关键点...
步骤2:拆解子问题...
步骤3:逐步计算...
最终答案:""";
// 变量替换
Map<String, String> variables = new HashMap<>();
variables.put("question", question);
return replacePlaceholders(template, variables);
}
private static String replacePlaceholders(String template, Map<String, String> vars) {
String result = template;
for (Map.Entry<String, String> entry : vars.entrySet()) {
result = result.replace("{" + entry.getKey() + "}", entry.getValue());
}
return result;
}
public static void main(String[] args) {
String prompt = generateCotPrompt("小明有5个苹果,吃了2个,又买了3个,现在有多少个?");
System.out.println(prompt);
System.out.println(generateCotPrompt("鸡兔共有头10个,腿30个,求鸡和兔的数量?"));
// 调用LLM API(见后续步骤)
}
}
运行结果
go
步骤1:分析问题的关键点
- 小明最开始有5个苹果。
- 他吃掉了2个苹果。
- 然后他又买了3个苹果。
- 最后我们要计算他现在总共有多少个苹果。
步骤2:拆解子问题
- 小明吃掉了多少个苹果?
- 吃掉苹果后还剩下多少个?
- 买了多少个苹果?
- 最后总共有多少个苹果?
步骤3:逐步计算
1. 小明最开始有5个苹果。
2. 吃掉2个苹果后:
5 - 2 = 3个苹果
3. 买了3个苹果后:
3 + 3 = 6个苹果
最终答案:小明现在有6个苹果。
方案2:动态模板引擎(Apache Velocity)
xml
<!-- pom.xml 依赖 -->
<dependency>
<groupId>org.apache.velocity</groupId>
<artifactId>velocity-engine-core</artifactId>
<version>2.4.1</version>
</dependency>
java
package com.prompt;
import org.apache.velocity.VelocityContext;
import org.apache.velocity.app.Velocity;
import java.io.StringWriter;
public class CotVelocity {
public static String generateWithVelocity(String question) {
Velocity.init();
VelocityContext context = new VelocityContext();
context.put("question", question);
String template = """
请逐步解决以下问题:
#set($steps = ["分析问题的关键点", "拆解子问题", "逐步计算"])
问题:$question
#foreach($step in $steps)
步骤${foreach.count}:$step...
#end
最终答案:""";
StringWriter writer = new StringWriter();
Velocity.evaluate(context, `writer`, "cot", template);
return writer.toString();
}
public static void main(String[] args) {
String promptStr =CotVelocity.generateWithVelocity("天上有多少颗星星");
// 调用LLM API(见后续步骤)
OpenAiChatModel model = OpenAiChatModel.builder()
.baseUrl("http://langchain4j.dev/demo/openai/v1")
.apiKey("demo")
.modelName(GPT_4_O_MINI)
// .httpClientBuilder(new SpringRestClientBuilder())
.logRequests(true)
.logResponses(true)
.build();
System.out.println(model.generate(promptStr));
}
}
运行结果
java
DEBUG: Request:
- method: POST
- url: http://langchain4j.dev/demo/openai/v1/chat/completions
- headers: [Authorization: Bearer ...], [User-Agent: langchain4j-openai]
- body: {
"model" : "gpt-4o-mini",
"messages" : [ {
"role" : "user",
"content" : "问题:天上有多少颗星星\n步骤1:分析问题的关键点...\n步骤2:拆解子问题...\n步骤3:逐步计算...\n最终答案:"
} ],
"temperature" : 0.7
}
2025-06-08 00:07:30 [main] dev.ai4j.openai4j.ResponseLoggingInterceptor.logDebug()
DEBUG: Response:
- status code: 200
- headers: [Date: Sat, 07 Jun 2025 16:07:22 GMT], [Content-Type: text/html;charset=utf-8], [Transfer-Encoding: chunked], [Server: Jetty(9.4.48.v20220622)]
- body: {"id":"chatcmpl-BfqMlXJE9BA8dVq2YpT04NHyHTNjV","created":1749312443,"model":"gpt-4o-mini-2024-07-18","choices":[{"index":0,"message":{"role":"assistant","content":"问题:天上有多少颗星星\n\n步骤1:分析问题的关键点\n- 这个问题实际上是一个开放性的问题,因为"星星"的定义和可见范围会影响答案。\n- 我们需要考虑可见星星的数量、银河系中的星星数量,以及宇宙中星星的总体估计。\n\n步骤2:拆解子问题\n1. 在晴朗的夜空中,肉眼可见的星星数量大约是多少?\n2. 银河系中大约有多少颗星星?\n3. 宇宙中大约有多少个银河系?每个银河系又大约有多少颗星星?\n\n步骤3:逐步计算\n1. 肉眼可见的星星数量:\n - 在理想条件下,肉眼可见的星星大约在2000到3000颗之间。\n\n2. 银河系中的星星数量:\n - 银河系中估计有1000亿到4000亿颗星星。我们可以取一个平均值,大约为2000亿颗。\n\n3. 宇宙中的银河系数量及其星星数量:\n - 当前的研究表明,宇宙中大约有2000亿到3000亿个银河系。我们取一个中间值2500亿个银河系。\n - 每个银河系平均约有2000亿颗星星。\n - 因此,宇宙中的星星数量大约为:2500亿 × 2000亿 \u003d 5 × 10^22颗星星(即50000000000000000000000颗星星)。\n\n最终答案:\n在晴朗的夜空中,肉眼可见大约2000到3000颗星星;而在整个宇宙中,估计有大约50000000000000000000000颗星星。"},"finish_reason":"stop"}],"usage":{"prompt_tokens":44,"completion_tokens":403,"total_tokens":447},"system_fingerprint":"fp_34a54ae93c"}
问题:天上有多少颗星星
步骤1:分析问题的关键点
- 这个问题实际上是一个开放性的问题,因为"星星"的定义和可见范围会影响答案。
- 我们需要考虑可见星星的数量、银河系中的星星数量,以及宇宙中星星的总体估计。
步骤2:拆解子问题
1. 在晴朗的夜空中,肉眼可见的星星数量大约是多少?
2. 银河系中大约有多少颗星星?
3. 宇宙中大约有多少个银河系?每个银河系又大约有多少颗星星?
步骤3:逐步计算
1. 肉眼可见的星星数量:
- 在理想条件下,肉眼可见的星星大约在2000到3000颗之间。
2. 银河系中的星星数量:
- 银河系中估计有1000亿到4000亿颗星星。我们可以取一个平均值,大约为2000亿颗。
3. 宇宙中的银河系数量及其星星数量:
- 当前的研究表明,宇宙中大约有2000亿到3000亿个银河系。我们取一个中间值2500亿个银河系。
- 每个银河系平均约有2000亿颗星星。
- 因此,宇宙中的星星数量大约为:2500亿 × 2000亿 = 5 × 10^22颗星星(即50000000000000000000000颗星星)。
最终答案:
在晴朗的夜空中,肉眼可见大约2000到3000颗星星;而在整个宇宙中,估计有大约50000000000000000000000颗星星。
3. 实现ReACT(Reasoning + Acting)
关键组件
LLM API调用:通过HTTP客户端访问OpenAI/Gemini等API。
动作调度器:根据LLM输出调用工具(如搜索、数据库)。
循环控制器:管理Thought→Action→Observation流程
案例一
模拟一个旅行规划AI Agent,需要协调航班查询、酒店预订和天气检查
1、首先定义一个工具类 模拟第三方接口的数据 (必须用 @Tool 注解)
java
package com.prompt.react;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.ToolSpecifications;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import java.lang.reflect.Method;
import java.time.LocalDate;
import java.util.*;
public class TravelTools {
// 模拟航班数据库
private static final Map<String, List<Flight>> FLIGHTS_DB = Map.of(
"NYC-LON", List.of(
new Flight("DL123", "NYC", "LON", LocalDate.of(2023,12,25).plusDays(0), 599),
new Flight("DL123", "NYC", "LON", LocalDate.of(2023,12,24).plusDays(1), 799),
new Flight("BA456", "NYC", "LON", LocalDate.of(2023,12,24).plusDays(2), 699)
),
"纽约-伦敦", List.of(
new Flight("DL128", "JFK", "LHR", LocalDate.of(2023,12,24).plusDays(1), 799),
new Flight("BA459", "EWR", "LGW", LocalDate.of(2023,12,24).plusDays(2), 699)
)
);
// 模拟酒店数据库
private static final Map<String, List<Hotel>> HOTELS_DB = Map.of(
"LON", List.of(
new Hotel("Hilton", 250, 4.5),
new Hotel("Premier Inn", 120, 4.0)
),
"London", List.of(
new Hotel("Hilton", 250, 4.5),
new Hotel("Premier Inn", 120, 4.0)
)
);
// 模拟天气API
private static final Map<String, String> WEATHER_DB = Map.of(
"LON-3days", "Sunny, 22°C / Rainy, 18°C / Cloudy, 20°C",
"London-3days", "Sunny, 22°C / Rainy, 18°C / Cloudy, 20°C"
);
@Tool("查询航班信息 (参数示例: {'route':'NYC-LON','date':'2023-12-25'})")
public String searchFlights(String departure,String destination, String date) {
System.out.println("=================="+destination+" "+date);
return FLIGHTS_DB.getOrDefault(departure+"-"+destination, Collections.emptyList()).stream()
.filter(f -> f.date.equals(LocalDate.parse(date)))
.findFirst()
.map(f -> String.format(
"航班 %s: %s→%s, 价格 $%d",
f.number, f.from, f.to, f.price
))
.orElse("无可用航班");
}
@Tool("查询酒店信息 (参数示例: {'city':'LON','budget':200})")
public String searchHotels(String city, int budget) {
System.out.println("----------------------酒店信息"+city);
return HOTELS_DB.getOrDefault(city, Collections.emptyList()).stream()
.filter(h -> h.price <= budget)
.sorted(Comparator.comparingDouble(h -> -h.rating))
.findFirst()
.map(h -> String.format(
"酒店 %s: $%d/晚, 评分 %.1f",
h.name, h.price, h.rating
))
.orElse("无符合预算的酒店");
}
@Tool("查询目的地天气 (参数示例: {'city':'LON','days':3})")
public String checkWeather(String city, int days) {
System.out.println("--------------------------city weather: " + city);
return WEATHER_DB.getOrDefault(city + "-" + days + "days", "天气数据暂不可用");
}
// 数据类
record Flight(String number, String from, String to, LocalDate date, int price) {}
record Hotel(String name, int price, double rating) {}
public static List<ToolSpecification> tools(List<Object> objectsWithTools) {
if(objectsWithTools==null) return Collections.emptyList();
List<ToolSpecification> tools = new ArrayList<>();
Iterator var2 = objectsWithTools.iterator();
while(var2.hasNext()) {
Object objectWithTool = var2.next();
if (objectWithTool instanceof Class) {
throw IllegalConfigurationException.illegalConfiguration("Tool '%s' must be an object, not a class", new Object[]{objectWithTool});
}
Method[] var4 = objectWithTool.getClass().getDeclaredMethods();
int var5 = var4.length;
for(int var6 = 0; var6 < var5; ++var6) {
Method method = var4[var6];
if (method.isAnnotationPresent(Tool.class)) {
ToolSpecification toolSpecification = ToolSpecifications.toolSpecificationFrom(method);
tools.add(toolSpecification);
}
}
}
return tools;
}
}
2. ReACT Agent 实现(核心逻辑)
java
package com.prompt.react;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.*;
import dev.langchain4j.data.message.*;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
public class TravelPlanner {
private final OpenAiChatModel model;
private final ChatMemory memory;
private final ToolExecutor toolExecutor;
String uuid = UUID.randomUUID().toString();
public TravelPlanner(String apiKey) {
// this.model = OpenAiChatModel.builder()
// .baseUrl("http://langchain4j.dev/demo/openai/v1")
// .apiKey("demo")
// .modelName(GPT_4_O_MINI)
// // .httpClientBuilder(new SpringRestClientBuilder())
// .logRequests(true)
// .logResponses(true)
// .build();
this.model = OpenAiChatModel.builder()
.apiKey(apiKey)
.baseUrl("https://api.deepseek.com/v1")
.modelName("deepseek-chat")
// .httpClientBuilder(new SpringRestClientBuilder())
.strictTools(true) // https://docs.langchain4j.dev/integrations/language-models/open-ai#structured-outputs-for-tools
.logRequests(true)
.logResponses(true)
.build();
this.memory = MessageWindowChatMemory.withMaxMessages(20);
this.toolExecutor = new ParallelToolExecutor(new TravelTools());
// 关键:注入系统提示(强制ReACT格式)
memory.add(SystemMessage.systemMessage(
"你是一个旅行规划AI,必须按以下步骤处理请求:\n" +
"1. 分析用户需求\n" +
"2. 调用工具获取数据\n" +
"3. 综合结果给出建议\n\n" +
"必须使用以下格式响应:\n" +
"Thought: 思考过程\n" +
"Action: 工具名\n" +
"Action Input: {\"参数\":\"值\"}\n\n" +
"可用工具:\n" +
"- searchFlights (参数: departure, destination, date, maxBudget)\n" +
"- searchHotels (参数: city, maxPrice)\n" +
"- checkWeather (参数: city, days)"
));
}
public String plan(String userRequest) throws InterruptedException {
memory.add(UserMessage.userMessage(userRequest));
System.out.println("=== 用户请求 ===\n" + userRequest);
List<ToolSpecification> toolSpecifications = ToolSpecifications.toolSpecificationsFrom(TravelTools.class);
for (int step = 1; step <= 5; step++) {
// 1. 调用模型生成响应
Response<AiMessage> response = model.generate(memory.messages(),toolSpecifications);
AiMessage aiMessage = response.content();
memory.add(aiMessage);
// 2. 打印模型思考过程
System.out.println("\n[Step " + step + "]");
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
// 3. 检查工具调用
if (toolExecutionRequests ==null || toolExecutionRequests.isEmpty()) {
return "\n=== 最终建议 ===\n" + aiMessage.text();
}
// 4. 执行所有工具调用
toolExecutionRequests.forEach(req -> {
String output = toolExecutor.execute(req,uuid);
memory.add(ToolExecutionResultMessage.from(req, output));
System.out.println("[工具执行] " + req.name() + " → " + output);
});
Thread.sleep(2000);
}
throw new RuntimeException("超过最大推理步数");
}
// 并行工具执行器
static class ParallelToolExecutor implements ToolExecutor {
private final TravelTools tools;
ParallelToolExecutor(TravelTools tools) {
this.tools = tools;
}
@Override
public String execute(ToolExecutionRequest request,Object o) {
return switch (request.name()) {
case "searchFlights" -> parseAndCall(
request.arguments(),
args -> tools.searchFlights(args.get("departure"),args.get("destination"), args.get("date"))
);
case "searchHotels" -> parseAndCall(
request.arguments(),
args -> tools.searchHotels(args.get("city"), Integer.parseInt(args.get("budget")))
);
case "checkWeather" -> parseAndCall(
request.arguments(),
args -> tools.checkWeather(args.get("city"), Integer.parseInt(args.get("days")))
);
default -> throw new IllegalArgumentException("未知工具: " + request.name());
};
}
private String parseAndCall(String jsonArgs, java.util.function.Function<Map<String, String>, String> func) {
try {
Map<String, String> args = new ObjectMapper().readValue(jsonArgs, new TypeReference<>() {});
return func.apply(args);
} catch (Exception e) {
return "工具调用失败: " + e.getMessage();
}
}
}
}
注意:不同大模型 返回的参数名略有差异 需要好好调试·
3. 主程序运行
java
package com.prompt.react;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
public class TravelPlanningDemo {
public static void main(String[] args) throws InterruptedException {
String apiKey="sk-xx";
TravelPlanner planner = new TravelPlanner(apiKey);
String request = "帮我规划2023年12月25日从纽约到伦敦的旅行:" +
"1. 查找预算$700以内的航班\n" +
"2. 预订$200/晚以下的酒店\n" +
"3. 检查3天内的天气";
String result = planner.plan(request);
System.out.println("\n=== 最终规划结果 ===");
System.out.println(result);
}
}
运行结果如下:
java
Connected to the target VM, address: '127.0.0.1:11304', transport: 'socket'
=== 用户请求 ===
帮我规划2023年12月25日从纽约到伦敦的旅行:1. 查找预算$700以内的航班
2. 预订$200/晚以下的酒店
3. 检查3天内的天气
[Step 1]
==================LON 2023-12-25
[工具执行] searchFlights → 航班 DL123: NYC→LON, 价格 $599
----------------------酒店信息London
[工具执行] searchHotels → 酒店 Premier Inn: $120/晚, 评分 4.0
--------------------------city weather: London
[工具执行] checkWeather → Sunny, 22°C / Rainy, 18°C / Cloudy, 20°C
[Step 2]
=== 最终规划结果 ===
=== 最终建议 ===
Thought: 已获取航班、酒店和天气信息,现在综合结果给出建议。
### 旅行规划建议:
1. **航班**:
- **航班号**:DL123
- **路线**:纽约 (NYC) → 伦敦 (LON)
- **日期**:2023年12月25日
- **价格**:$599(在您的预算$700以内)
2. **酒店**:
- **名称**:Premier Inn
- **价格**:$120/晚(在您的预算$200/晚以内)
- **评分**:4.0
3. **天气**(伦敦未来3天):
- 第一天:晴天,22°C
- 第二天:雨天,18°C
- 第三天:多云,20°C
### 下一步建议:
- 确认航班和酒店预订。
- 根据天气情况准备衣物(建议带雨具)。
- 如需进一步帮助,请告知!
Process finished with exit code 0
案例二
该案例模拟一个 金融投资顾问AI,需要协同调用股票查询、风险评估和投资组合优化工具
1、 复杂工具定义
模拟三个关键服务:股票数据、风险评估、组合优化
java
package com.prompt.invest;
import dev.langchain4j.agent.tool.Tool;
import com.google.gson.Gson;
public class InvestmentTools {
private static final Gson gson = new Gson();
public static void main(String[] args) {
System.out.println(gson.fromJson("{\"request\": {\"symbol\": \"AAPL\", \"days\": 30}}".replace("{\"request\":","").replace("}","")+"}", StockRequest.class));
}
// 工具1:股票数据查询
@Tool("查询股票实时数据 (参数示例: {'symbol':'AAPL','days':30})")
public String getStockData(StockRequest request) {
System.out.println("======================getStockData");
return String.format("""
%s股票最近%d天数据:
- 当前价: $%.2f
- 波动率: %.2f%%
- 市盈率: %.1f""",
request.symbol, request.days,
Math.random() * 100 + 100, // 模拟价格
Math.random() * 20 + 5, // 模拟波动率
Math.random() * 10 + 15 // 模拟PE
);
}
// 工具2:风险评估
@Tool("评估投资组合风险 (参数示例: {'stocks':['AAPL','MSFT'],'total_amount':10000})")
public String assessRisk(PortfolioRequest request) {
System.out.println("======================assessRisk");
return String.format("""
组合风险评估 (总金额: $%,d):
- 最大回撤: %.1f%%
- 夏普比率: %.2f
- 风险等级: %s""",
request.total_amount,
Math.random() * 15 + 5,
Math.random() * 2 + 0.5,
new String[]{"A","B","C"}[(int)(Math.random()*3)]
);
}
// 工具3:组合优化
@Tool("生成优化投资组合 (参数示例: {'stocks':['AAPL','MSFT'],'risk_level':'B','amount':10000})")
public String optimizePortfolio(OptimizationRequest request) {
System.out.println("======================optimizePortfolio");
return String.format("""
优化建议 (风险等级 %s):
- AAPL: %.0f%%
- MSFT: %.0f%%
- 现金: %.0f%%""",
request.risk_level,
(0.6 + Math.random() * 0.2) * 100,
(0.3 + Math.random() * 0.1) * 100,
Math.random() * 10
);
}
// 参数类(强制JSON结构)
public static class StockRequest {
public String symbol;
public int days;
@Override
public String toString() {
return "StockRequest{" +
"symbol='" + symbol + '\'' +
", days=" + days +
'}';
}
public StockRequest() {
}
public StockRequest(String symbol, int days) {
this.symbol = symbol;
this.days = days;
}
public String getSymbol() {
return symbol;
}
public void setSymbol(String symbol) {
this.symbol = symbol;
}
public int getDays() {
return days;
}
public void setDays(int days) {
this.days = days;
}
}
public static class PortfolioRequest {
public String[] stocks;
public int total_amount;
public PortfolioRequest(String[] stocks, int total_amount) {
this.stocks = stocks;
this.total_amount = total_amount;
}
public PortfolioRequest() {
}
public String[] getStocks() {
return stocks;
}
public void setStocks(String[] stocks) {
this.stocks = stocks;
}
public int getTotal_amount() {
return total_amount;
}
public void setTotal_amount(int total_amount) {
this.total_amount = total_amount;
}
}
public static class OptimizationRequest {
public String[] stocks;
public String risk_level;
public int amount;
public OptimizationRequest() {
}
public OptimizationRequest(String[] stocks, String risk_level, int amount) {
this.stocks = stocks;
this.risk_level = risk_level;
this.amount = amount;
}
public String[] getStocks() {
return stocks;
}
public void setStocks(String[] stocks) {
this.stocks = stocks;
}
public int getAmount() {
return amount;
}
public void setAmount(int amount) {
this.amount = amount;
}
public String getRisk_level() {
return risk_level;
}
public void setRisk_level(String risk_level) {
this.risk_level = risk_level;
}
}
}
2. ReACT Agent核心实现
严格保证工具调用顺序和参数传递
java
package com.prompt.invest;
import com.google.gson.Gson;
import com.prompt.react.TravelTools;
import dev.langchain4j.agent.tool.*;
import dev.langchain4j.data.message.*;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.tool.ToolExecutor;
import java.util.*;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
public class InvestmentAdvisor {
private final OpenAiChatModel model;
private final ChatMemory memory;
private final ToolExecutor toolExecutor;
public InvestmentAdvisor(String apiKey) {
this.model = OpenAiChatModel.builder()
.apiKey(apiKey)
.baseUrl("https://api.deepseek.com/v1")
.modelName("deepseek-chat")
// .httpClientBuilder(new SpringRestClientBuilder())
.strictTools(true) // https://docs.langchain4j.dev/integrations/language-models/open-ai#structured-outputs-for-tools
.logRequests(true)
.logResponses(true)
.build();
this.memory = MessageWindowChatMemory.withMaxMessages(30);
this.toolExecutor = new StrictToolExecutor(new InvestmentTools());
// 注入强约束系统提示
memory.add(SystemMessage.systemMessage("""
你是一个高级投资顾问,必须按以下步骤操作:
1. 先查询股票数据
2. 然后评估当前组合风险
3. 最后生成优化建议
必须使用以下工具和格式:
Action: 工具名
Action Input: {"参数":"值"}
"""));
}
public String advise(String request) {
memory.add(UserMessage.userMessage(request));
System.out.println("\n=== 开始处理 ===");
System.out.println("[用户请求] " + request);
List<ToolSpecification> toolSpecifications = ToolSpecifications.toolSpecificationsFrom(InvestmentTools.class);
for (int i = 1; i <= 6; i++) {
// 1. 生成响应
Response<AiMessage> response = model.generate(memory.messages(),toolSpecifications);
AiMessage aiMessage = response.content();
memory.add(aiMessage);
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
// 2. 处理工具调用
if (toolExecutionRequests!=null && !toolExecutionRequests.isEmpty()) {
toolExecutionRequests.forEach(req -> {
String output = toolExecutor.execute(req,null);
memory.add(ToolExecutionResultMessage.from(req, output));
logToolCall(req, output);
});
} else {
if(aiMessage.text() ==null){
continue;
}
return formatFinalAnswer(aiMessage.text());
}
if(aiMessage.text() ==null){
continue;
}
logThought(aiMessage.text());
}
throw new RuntimeException("超过最大处理步数");
}
// 严格参数校验的执行器
private static class StrictToolExecutor implements ToolExecutor {
private final InvestmentTools tools;
private final Gson gson = new Gson();
StrictToolExecutor(InvestmentTools tools) {
this.tools = tools;
}
@Override
public String execute(ToolExecutionRequest request,Object memoryId) {
try {
return switch (request.name()) {
case "getStockData" ->
tools.getStockData(gson.fromJson(request.arguments().replace("{\"request\":","").replace("}","")+"}", InvestmentTools.StockRequest.class));
case "assessRisk" ->
tools.assessRisk(gson.fromJson(request.arguments().replace("{\"request\":","").replace("}","")+"}", InvestmentTools.PortfolioRequest.class));
case "optimizePortfolio" ->
tools.optimizePortfolio(gson.fromJson(request.arguments().replace("{\"request\":","").replace("}","")+"}", InvestmentTools.OptimizationRequest.class));
default -> throw new IllegalArgumentException("未知工具: " + request.name());
};
} catch (Exception e) {
return "执行失败: " + e.getMessage();
}
}
}
private void logThought(String thought) {
System.out.println("[AI思考] " + thought.split("\n")[0]); // 只打印第一行
}
private void logToolCall(ToolExecutionRequest req, String output) {
System.out.printf("[Step %d] 调用工具: %s(%s)\n 结果: %s\n",
memory.messages().size() / 2, // 粗略计算步数
req.name(),
req.arguments().replace(",", ",\n "),
output.replace("\n", "\n ")
);
}
private String formatFinalAnswer(String text) {
return "\n=== 最终建议 ===\n" + text;
}
}
3、主程序测试
java
public static void main(String[] args) {
String apiKey="sk-xx";
InvestmentAdvisor advisor = new InvestmentAdvisor(apiKey);
String result = advisor.advise(
"请分析AAPL和MSFT的投资组合,总金额50000美元,风险等级B"
);
System.out.println(result);
}
运行结果如下:
java
=== 开始处理 ===
[用户请求] 请分析AAPL和MSFT的投资组合,总金额50000美元,风险等级B
======================getStockData
[Step 2] 调用工具: getStockData({"request": {"symbol": "AAPL",
"days": 30}})
结果: AAPL股票最近30天数据:
- 当前价: $121.81
- 波动率: 21.18%
- 市盈率: 20.6
======================getStockData
[Step 2] 调用工具: getStockData({"request": {"symbol": "MSFT",
"days": 30}})
结果: MSFT股票最近30天数据:
- 当前价: $142.23
- 波动率: 8.33%
- 市盈率: 22.1
======================assessRisk
[Step 3] 调用工具: assessRisk({"request":{"stocks":["AAPL",
"MSFT"],
"total_amount":50000}})
结果: 组合风险评估 (总金额: $50,000):
- 最大回撤: 5.5%
- 夏普比率: 0.76
- 风险等级: B
======================optimizePortfolio
[Step 4] 调用工具: optimizePortfolio({"request":{"stocks":["AAPL",
"MSFT"],
"risk_level":"B",
"amount":50000}})
结果: 优化建议 (风险等级 B):
- AAPL: 68%
- MSFT: 40%
- 现金: 1%
=== 最终建议 ===
### 分析结果
#### 1. 股票数据查询
- **AAPL**:
- 当前价: $121.81
- 波动率: 21.18%
- 市盈率: 20.6
- **MSFT**:
- 当前价: $142.23
- 波动率: 8.33%
- 市盈率: 22.1
#### 2. 组合风险评估
- **总金额**: $50,000
- **最大回撤**: 5.5%
- **夏普比率**: 0.76
- **风险等级**: B
#### 3. 优化建议
根据风险等级 **B**,优化后的投资组合建议如下:
- **AAPL**: 68% (约 $34,000)
- **MSFT**: 40% (约 $20,000)
- **现金**: 1% (约 $500)
### 说明
- **AAPL** 的波动率较高,适合在组合中承担更高的风险收益比。
- **MSFT** 的波动率较低,适合作为稳定收益的部分。
- 保留少量现金以应对市场波动。
如果需要进一步调整或优化,请告知!
感谢 各位看官的支持 【zwpflc】
有问题可以给我留言 多交流