一、概念
首先要知道,在整个AI应用中什么是代码解释器,官方给出的解释为:大模型不擅长数学运算、数据可视化等精确计算任务,您可以使用 Assistant API 预置的代码解释器插件,使智能体编写和运行 Python 程序,从而逐步解决复杂的数据问题。如果智能体编写的代码未能成功运行,它还会不断调整代码,尝试不同的方法,直到代码能够顺利执行。
咱们用生活化的比喻来讲,一下子就明白了:
你可以把大模型想象成一个特别会聊天、懂超多知识的 "文科生" ------ 它能跟你侃天说地、写文章、解逻辑题,但让它算复杂的数学题(比如统计一堆数据的均值、方差)、画图表(比如把销售数据做成柱状图)这些需要 "精确计算 + 实操" 的活儿,它就容易出错,有点 "手忙脚乱"。
而 Assistant API 的代码解释器插件,就相当于给这个 "文科生" 配了一个 "理科小助手"(Python 编程工具):
- 遇到数学运算、数据可视化这些它不擅长的任务,它会先写一段 Python 代码(就像给小助手下指令);
- 让这个小助手去执行代码,精准完成计算、画图这些活儿;
- 如果第一次写的代码跑不通(比如语法错了、逻辑漏了),它还会像个较真的程序员一样,自己检查、修改代码,换不同的写法再试,直到代码能顺利跑起来,把准确的结果给到你为止。
简单说:插件就是给大模型装了个 "计算 + 实操外挂",让它能靠写代码搞定自己不擅长的精确数据任务,还能自动纠错,直到成功为止。
二、代码实现
1. 思路
新增 本地 Java 代码解释 / 执行工具,让大模型能生成 Java 代码并调用本地 JDK 编译运行,核心新增点包括:
- 定义 Java 代码执行工具的 Schema(对齐 Function Calling 规范);
- 实现本地编译 / 运行 Java 代码的核心方法(含临时文件管理、输出捕获、异常处理);
- 整合到现有工具调用逻辑中,与天气工具并列;
- 增加安全防护(临时文件自动清理、代码权限限制)。
2. 代码
2.1 新增 Java 代码执行工具的 Schema 定义
java
String javaCodeParamsSchema = "{" +
"\"type\":\"object\"," +
"\"properties\":{" +
"\"code\":{\"type\":\"string\",\"description\":\"需要执行的Java代码,必须包含main方法,仅使用JDK原生类,禁止文件写入/网络请求等危险操作\"}" +
"}," +
"\"required\":[\"code\"]" +
"}";
- 定义了工具的参数规范:必须传入
code字段(Java 代码字符串); - 明确约束:代码需包含 main 方法、仅用 JDK 原生类、禁止危险操作(安全防护)
2.2 构建 Java 代码执行工具并加入工具列表
java
FunctionDefinition javaCodeFunction = FunctionDefinition.builder()
.name("run_java_code")
.description("当需要执行精确计算、数据处理、逻辑运算等Java代码时调用...")
.parameters(JsonUtils.parseString(javaCodeParamsSchema).getAsJsonObject())
.build();
List<ToolFunction> toolFunctions = new ArrayList<>();
toolFunctions.add(ToolFunction.builder().function(weatherFunction).build());
toolFunctions.add(ToolFunction.builder().function(javaCodeFunction).build());
- 工具名
run_java_code与后续executeLocalFunction中的 case 对应; - 工具列表同时包含天气工具和 Java 执行工具,大模型会自动选择调用。
2.3 配置 Jackson 解析器(处理特殊字符)
java
@Configuration
public class JacksonConfig {
@Bean
public ObjectMapper objectMapper(Jackson2ObjectMapperBuilder builder) {
ObjectMapper objectMapper = builder.build();
// 允许解析包含未转义控制字符的JSON(容错)
objectMapper.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, true);
// 忽略未知字段,避免请求体多字段导致解析失败
objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
return objectMapper;
}
}
2.4 核心:本地 Java 代码执行方法 runJavaCode
- 临时文件管理:生成随机类名和临时文件,避免冲突;
- 编译逻辑 :调用
javac编译 Java 文件,捕获编译错误; - 运行逻辑 :调用
java运行编译后的类,捕获运行输出; - 安全清理 :执行完成后自动删除临时
.java和.class文件; - 超时控制:限制编译 / 运行时间(30 秒),避免卡死。
java
/**
* 新增核心方法:本地编译并运行Java代码
* @param javaCode 完整的Java代码字符串(需包含main方法)
* @return 执行结果/编译错误/运行错误
*/
private String runJavaCode(String javaCode) {
// 1. 初始化临时目录
File tempDir = new File(JAVA_TEMP_DIR);
if (!tempDir.exists()) {
boolean mkdirSuccess = tempDir.mkdirs();
if (!mkdirSuccess) {
return "创建临时目录失败:" + JAVA_TEMP_DIR;
}
}
// 2. 生成随机类名(避免冲突)
String className = "TempJavaCode_" + System.currentTimeMillis() + "_" + new Random().nextInt(1000);
String javaFileName = className + ".java";
File javaFile = new File(tempDir, javaFileName);
try {
// 3. 写入Java代码到临时文件(替换类名为生成的随机名)
String formattedCode = javaCode.replaceFirst("public class\\s+\\w+", "public class " + className);
Files.write(javaFile.toPath(), formattedCode.getBytes(StandardCharsets.UTF_8));
log.info("生成临时Java文件:{},代码:\n{}", javaFile.getAbsolutePath(), formattedCode);
// 4. 编译Java代码(调用javac)
Process compileProcess = new ProcessBuilder()
.command(JDK_COMPILE_CMD, "-encoding", "UTF-8", javaFile.getAbsolutePath())
.directory(tempDir)
.redirectErrorStream(true)
.start();
// 等待编译完成并捕获输出
String compileOutput = readProcessOutput(compileProcess);
compileProcess.waitFor(EXEC_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (compileProcess.exitValue() != 0) {
return "编译失败:\n" + compileOutput;
}
// 5. 运行编译后的类(调用java)
Process runProcess = new ProcessBuilder()
.command(JDK_RUN_CMD, "-cp", tempDir.getAbsolutePath(), className)
.directory(tempDir)
.redirectErrorStream(true)
.start();
// 等待运行完成并捕获输出
String runOutput = readProcessOutput(runProcess);
runProcess.waitFor(EXEC_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (runProcess.exitValue() != 0) {
return "运行失败:\n" + runOutput;
}
return "执行成功:\n" + runOutput;
} catch (Exception e) {
log.error("执行Java代码异常", e);
return "Java代码执行异常:" + e.getMessage();
} finally {
// 6. 清理临时文件(安全防护)
try {
if (javaFile.exists()) {
Files.deleteIfExists(javaFile.toPath());
}
File classFile = new File(tempDir, className + ".class");
if (classFile.exists()) {
Files.deleteIfExists(classFile.toPath());
}
} catch (IOException e) {
log.warn("清理临时Java文件失败", e);
}
}
}
2.5 总的
java
package gzj.spring.ai.Service.ServiceImpl;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.tools.FunctionDefinition;
import com.alibaba.dashscope.tools.ToolCallBase;
import com.alibaba.dashscope.tools.ToolCallFunction;
import com.alibaba.dashscope.tools.ToolFunction;
import com.alibaba.dashscope.utils.JsonUtils;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import gzj.spring.ai.DTO.ChatRequestDTO;
import gzj.spring.ai.Service.DashScopeService;
import gzj.spring.ai.util.ChatContextCacheUtil;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.*;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/**
* 整合上下文缓存 + 通义千问官方 Function Calling + 联网搜索 + 聚合天气接口 + 本地Java代码执行工具
* 完全对齐官方示例 API 规范
* @author DELL
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class DashScopeServiceImpl implements DashScopeService {
// 缓存工具类(原有)
private final ChatContextCacheUtil chatContextCacheUtil;
// 通义千问配置(原有)
@Value("${spring.ai.dashscope.api-key}")
private String apiKey;
@Value("${spring.ai.dashscope.model}")
private String model;
// 聚合数据天气接口配置
@Value("${jute.weather.api-key}")
private String juheWeatherApiKey;
private static final String JUHE_WEATHER_API_URL = "http://apis.juhe.cn/simpleWeather/query";
// Java代码执行配置
private static final String JAVA_TEMP_DIR = System.getProperty("java.io.tmpdir") + "/ai_java_code/";
// 需确保javac在系统PATH中
private static final String JDK_COMPILE_CMD = "javac";
// 需确保java在系统PATH中
private static final String JDK_RUN_CMD = "java";
// 代码执行超时时间
private static final int EXEC_TIMEOUT_SECONDS = 30;
// JSON 解析器(官方示例用)
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
/**
* 流式对话核心方法(整合缓存 + 官方 Function Calling + 联网搜索 + Java代码执行)
*/
@Override
public SseEmitter streamChat(ChatRequestDTO requestDTO, HttpServletResponse response) {
// 校验userId(原有逻辑)
if (requestDTO.getUserId() == null || requestDTO.getUserId().trim().isEmpty()) {
SseEmitter emitter = new SseEmitter(0L);
sendError(emitter, "userId不能为空,请传入用户唯一标识");
return emitter;
}
String userId = requestDTO.getUserId().trim();
SseEmitter emitter = new SseEmitter(TimeUnit.SECONDS.toMillis(60));
initEmitterCallback(emitter);
new Thread(() -> {
try {
// ========== 原有缓存逻辑:保留不变 ==========
List<ChatRequestDTO.MessageDTO> cacheHistory = chatContextCacheUtil.getHistoryByUserId(userId);
List<ChatRequestDTO.MessageDTO> finalHistory = mergeHistory(cacheHistory, requestDTO.getHistoryMessages());
// 转换为官方 Message 格式
List<Message> messages = convertToDashScopeMessages(finalHistory);
// 添加本次用户提问
String cleanNewQuestion = requestDTO.getQuestion().replaceAll("\\s+", " ").trim();
if (cleanNewQuestion.isEmpty()) {
throw new InputRequiredException("用户提问不能为空");
}
Message userMsg = Message.builder()
.role(Role.USER.getValue())
.content(cleanNewQuestion)
.build();
messages.add(userMsg);
// ========== 核心修改:添加Java代码执行工具 + 天气工具 ==========
// 1. 定义天气查询工具的 Schema(原有)
String weatherParamsSchema = "{" +
"\"type\":\"object\"," +
"\"properties\":{" +
"\"city\":{\"type\":\"string\",\"description\":\"城市名称,比如北京、上海、苏州等。\"}" +
"}," +
"\"required\":[\"city\"]" +
"}";
// 2. 定义Java代码执行工具的 Schema(新增)
String javaCodeParamsSchema = "{" +
"\"type\":\"object\"," +
"\"properties\":{" +
"\"code\":{\"type\":\"string\",\"description\":\"需要执行的Java代码,必须包含main方法,仅使用JDK原生类,禁止文件写入/网络请求等危险操作\"}" +
"}," +
"\"required\":[\"code\"]" +
"}";
// 3. 构建天气工具定义(原有)
FunctionDefinition weatherFunction = FunctionDefinition.builder()
.name("get_current_weather")
.description("当你想查询指定城市的天气时非常有用,必须调用该工具获取真实天气数据。")
.parameters(JsonUtils.parseString(weatherParamsSchema).getAsJsonObject())
.build();
// 4. 构建Java代码执行工具定义(新增)
FunctionDefinition javaCodeFunction = FunctionDefinition.builder()
.name("run_java_code")
.description("当需要执行精确计算、数据处理、逻辑运算等Java代码时调用,代码必须包含main方法,仅使用JDK原生类,禁止危险操作。")
.parameters(JsonUtils.parseString(javaCodeParamsSchema).getAsJsonObject())
.build();
// 5. 直接构建 ToolFunction 列表(SDK 官方正确写法)
List<ToolFunction> toolFunctions = new ArrayList<>();
toolFunctions.add(ToolFunction.builder().function(weatherFunction).build());
toolFunctions.add(ToolFunction.builder().function(javaCodeFunction).build());
// 6. 构建 GenerationParam(更新工具列表)
GenerationParam param = GenerationParam.builder()
.model(model) // 如 qwen-plus
.apiKey(apiKey)
.messages(messages)
.tools(Collections.unmodifiableList(toolFunctions)) // 替换为包含Java工具的列表
.enableSearch(true) // 启用联网搜索
.resultFormat(GenerationParam.ResultFormat.MESSAGE)
.incrementalOutput(true) // 流式输出
.build();
// 7. 初始化 Generation 客户端(原有)
Generation generation = new Generation();
StringBuilder aiReplyContent = new StringBuilder();
// ========== 处理工具调用 + 流式返回(原有逻辑,自动适配新工具) ==========
processToolCallAndStreamResponse(generation, param, emitter, aiReplyContent, messages);
// 8. 更新缓存(原有逻辑)
updateChatCache(userId, cleanNewQuestion, aiReplyContent.toString());
} catch (NoApiKeyException e) {
sendError(emitter, "未配置DashScope API Key");
} catch (ApiException | InputRequiredException e) {
sendError(emitter, "AI调用失败:" + e.getMessage());
} catch (Exception e) {
log.error("流式对话核心逻辑异常", e);
sendError(emitter, "系统异常:" + e.getMessage());
}
}).start();
return emitter;
}
/**
* 处理工具调用 + 流式返回结果(原有逻辑,无需修改)
*/
private void processToolCallAndStreamResponse(Generation generation, GenerationParam param,
SseEmitter emitter, StringBuilder aiReplyContent,
List<Message> messages) throws Exception {
// 首次调用模型
GenerationResult result = generation.call(param);
Message assistantOutput = result.getOutput().getChoices().get(0).getMessage();
messages.add(assistantOutput);
// 判断是否需要调用工具(官方逻辑)
if (assistantOutput.getToolCalls() == null || assistantOutput.getToolCalls().isEmpty()) {
// 无需工具调用,直接推送结果
pushContentToEmitter(emitter, assistantOutput.getContent(), aiReplyContent);
} else {
// 需要工具调用,循环处理(支持多次工具调用)
while (assistantOutput.getToolCalls() != null && !assistantOutput.getToolCalls().isEmpty()) {
ToolCallBase toolCall = assistantOutput.getToolCalls().get(0);
ToolCallFunction functionCall = (ToolCallFunction) toolCall;
String funcName = functionCall.getFunction().getName();
String arguments = functionCall.getFunction().getArguments();
log.info("调用工具 [{}],参数:{}", funcName, arguments);
// 执行本地工具方法(新增Java工具会在这里被调用)
String toolResult = executeLocalFunction(funcName, arguments);
log.info("工具返回:{}", toolResult);
// 构建工具返回消息(官方规范:role=tool)
Message toolMessage = Message.builder()
.role("tool")
.toolCallId(toolCall.getId())
.content(toolResult)
.build();
messages.add(toolMessage);
// 再次调用模型,基于工具结果生成最终回复
param.setMessages(messages);
result = generation.call(param);
assistantOutput = result.getOutput().getChoices().get(0).getMessage();
messages.add(assistantOutput);
}
// 推送最终回复
pushContentToEmitter(emitter, assistantOutput.getContent(), aiReplyContent);
}
// 流式结束
emitter.complete();
}
/**
* 执行本地工具函数(核心修改:新增run_java_code分支)
*/
private String executeLocalFunction(String funcName, String arguments) {
try {
switch (funcName) {
case "get_current_weather":
// 解析官方模型传入的参数(JSON 格式)
JsonNode argsNode = OBJECT_MAPPER.readTree(arguments);
String city = argsNode.get("city").asText();
// 调用聚合天气接口(原有逻辑)
return getWeather(city);
case "run_java_code":
// 解析Java代码参数(新增)
JsonNode javaArgsNode = OBJECT_MAPPER.readTree(arguments);
String javaCode = javaArgsNode.get("code").asText();
// 调用本地Java代码执行方法
return runJavaCode(javaCode);
default:
return "未找到工具函数:" + funcName;
}
} catch (Exception e) {
log.error("执行工具函数失败", e);
return "工具执行异常:" + e.getMessage();
}
}
/**
* 新增核心方法:本地编译并运行Java代码
* @param javaCode 完整的Java代码字符串(需包含main方法)
* @return 执行结果/编译错误/运行错误
*/
private String runJavaCode(String javaCode) {
// 1. 初始化临时目录
File tempDir = new File(JAVA_TEMP_DIR);
if (!tempDir.exists()) {
boolean mkdirSuccess = tempDir.mkdirs();
if (!mkdirSuccess) {
return "创建临时目录失败:" + JAVA_TEMP_DIR;
}
}
// 2. 生成随机类名(避免冲突)
String className = "TempJavaCode_" + System.currentTimeMillis() + "_" + new Random().nextInt(1000);
String javaFileName = className + ".java";
File javaFile = new File(tempDir, javaFileName);
try {
// 3. 写入Java代码到临时文件(替换类名为生成的随机名)
String formattedCode = javaCode.replaceFirst("public class\\s+\\w+", "public class " + className);
Files.write(javaFile.toPath(), formattedCode.getBytes(StandardCharsets.UTF_8));
log.info("生成临时Java文件:{},代码:\n{}", javaFile.getAbsolutePath(), formattedCode);
// 4. 编译Java代码(调用javac)
Process compileProcess = new ProcessBuilder()
.command(JDK_COMPILE_CMD, "-encoding", "UTF-8", javaFile.getAbsolutePath())
.directory(tempDir)
.redirectErrorStream(true)
.start();
// 等待编译完成并捕获输出
String compileOutput = readProcessOutput(compileProcess);
compileProcess.waitFor(EXEC_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (compileProcess.exitValue() != 0) {
return "编译失败:\n" + compileOutput;
}
// 5. 运行编译后的类(调用java)
Process runProcess = new ProcessBuilder()
.command(JDK_RUN_CMD, "-cp", tempDir.getAbsolutePath(), className)
.directory(tempDir)
.redirectErrorStream(true)
.start();
// 等待运行完成并捕获输出
String runOutput = readProcessOutput(runProcess);
runProcess.waitFor(EXEC_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (runProcess.exitValue() != 0) {
return "运行失败:\n" + runOutput;
}
return "执行成功:\n" + runOutput;
} catch (Exception e) {
log.error("执行Java代码异常", e);
return "Java代码执行异常:" + e.getMessage();
} finally {
// 6. 清理临时文件(安全防护)
try {
if (javaFile.exists()) {
Files.deleteIfExists(javaFile.toPath());
}
File classFile = new File(tempDir, className + ".class");
if (classFile.exists()) {
Files.deleteIfExists(classFile.toPath());
}
} catch (IOException e) {
log.warn("清理临时Java文件失败", e);
}
}
}
/**
* 辅助方法:读取进程输出(修复编码乱码)
*/
private String readProcessOutput(Process process) throws IOException {
try (BufferedReader reader = new BufferedReader(
// 显式指定UTF-8编码,避免系统默认编码导致乱码
new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
return reader.lines().collect(Collectors.joining("\n"));
}
}
/**
* 真实调用聚合数据天气接口(原有逻辑,修正拼写错误)
*/
private String getWeather(String city) {
try {
Map<String, String> params = new HashMap<>();
params.put("key", juheWeatherApiKey);
params.put("city", city);
String paramStr = params.entrySet().stream()
.map(entry -> {
try {
return entry.getKey() + "=" + URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8);
} catch (Exception e) {
log.error("参数编码失败", e);
return entry.getKey() + "=" + entry.getValue();
}
})
.collect(Collectors.joining("&"));
URL url = new URL(String.format("%s?%s", JUHE_WEATHER_API_URL, paramStr));
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("GET");
connection.setConnectTimeout(5000);
connection.setReadTimeout(5000);
try (BufferedReader in = new BufferedReader(
new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
String inputLine;
StringBuffer response = new StringBuffer();
while ((inputLine = in.readLine()) != null) {
response.append(inputLine);
}
JSONObject resultJson = JSONObject.parseObject(response.toString());
int errorCode = resultJson.getIntValue("error_code");
if (errorCode != 0) {
return "天气查询失败:" + resultJson.getString("reason");
}
JSONObject realtime = resultJson.getJSONObject("result").getJSONObject("realtime");
String weather = realtime.getString("info");
String temp = realtime.getString("temperature");
String windDir = realtime.getString("direct");
String windPower = realtime.getString("power");
String humidity = realtime.getString("humidity");
return String.format("%s当前天气:%s,气温:%s℃,风向:%s,风力:%s级,湿度:%s%%",
city, weather, temp, windDir, windPower, humidity);
} finally {
connection.disconnect();
}
} catch (Exception e) {
log.error("调用聚合天气接口异常", e);
return "天气查询出错:" + e.getMessage();
}
}
// ========== 以下为原有工具方法,无需修改 ==========
private void pushContentToEmitter(SseEmitter emitter, String content, StringBuilder aiReplyContent) {
try {
if (content != null && !content.isEmpty()) {
content = content.replaceAll("(?m)^data:\\s*", "").trim();
aiReplyContent.append(content);
emitter.send(content.getBytes(StandardCharsets.UTF_8), MediaType.TEXT_PLAIN);
log.debug("推送内容:{}", content);
}
} catch (IOException e) {
log.error("推送失败", e);
emitter.complete();
}
}
private List<ChatRequestDTO.MessageDTO> mergeHistory(List<ChatRequestDTO.MessageDTO> cacheHistory,
List<ChatRequestDTO.MessageDTO> manualHistory) {
List<ChatRequestDTO.MessageDTO> merged = new ArrayList<>(cacheHistory);
if (manualHistory == null || manualHistory.isEmpty()) {
return merged;
}
for (ChatRequestDTO.MessageDTO msg : manualHistory) {
if (chatContextCacheUtil.isValidMessage(msg) && !merged.contains(msg)) {
merged.add(msg);
}
}
return merged;
}
private void updateChatCache(String userId, String userQuestion, String aiReply) {
ChatRequestDTO.MessageDTO userMsg = new ChatRequestDTO.MessageDTO();
userMsg.setRole("user");
userMsg.setContent(userQuestion.trim());
if (chatContextCacheUtil.isValidMessage(userMsg)) {
chatContextCacheUtil.updateHistory(userId, userMsg);
}
ChatRequestDTO.MessageDTO aiMsg = new ChatRequestDTO.MessageDTO();
aiMsg.setRole("assistant");
aiMsg.setContent(aiReply.trim());
if (chatContextCacheUtil.isValidMessage(aiMsg)) {
chatContextCacheUtil.updateHistory(userId, aiMsg);
}
}
private List<Message> convertToDashScopeMessages(List<ChatRequestDTO.MessageDTO> dtoMessages) {
List<Message> messages = new ArrayList<>();
if (dtoMessages == null || dtoMessages.isEmpty()) {
return messages;
}
for (ChatRequestDTO.MessageDTO dtoMsg : dtoMessages) {
String role = dtoMsg.getRole().toLowerCase();
if (!List.of("user", "assistant", "tool").contains(role)) {
log.warn("跳过非法角色消息:{}", dtoMsg);
continue;
}
String cleanContent = dtoMsg.getContent()
.replaceAll("\\*\\*", "")
.replaceAll("\\s+", " ")
.trim();
if (cleanContent.isEmpty()) {
continue;
}
messages.add(Message.builder()
.role(role)
.content(cleanContent)
.build());
}
return messages;
}
private void initEmitterCallback(SseEmitter emitter) {
emitter.onTimeout(() -> {
log.warn("SSE连接超时");
emitter.complete();
});
emitter.onError(e -> {
log.error("SSE连接异常", e);
emitter.complete();
});
emitter.onCompletion(() -> log.info("SSE连接正常关闭"));
}
@Override
public GenerationParam buildGenerationParam(Message userMsg) {
return GenerationParam.builder()
.apiKey(apiKey)
.model(model)
.messages(Collections.singletonList(userMsg))
.resultFormat(GenerationParam.ResultFormat.MESSAGE)
.incrementalOutput(true)
.build();
}
private void sendError(SseEmitter emitter, String errorMsg) {
try {
byte[] errorBytes = errorMsg.getBytes(StandardCharsets.UTF_8);
emitter.send(errorBytes, MediaType.TEXT_PLAIN);
} catch (IOException e) {
log.error("推送SSE错误消息失败", e);
} finally {
emitter.complete();
}
}
}
3. 结果


创作不易,请各位靓仔靓女点个关注啵~(◦˙▽˙◦)