AI模型多阶段调用进度追踪系统设计文档
一、系统概述
为解决AI模型处理大型文件时响应时间长的问题,我们设计并实现了一套异步进度追踪系统。该系统采用Server-Sent Events (SSE) 技术,建立从服务器到客户端的单向实时通信通道,使前端能够实时获取后端文件处理的进度信息。
二、核心功能
- 异步文件处理:通过异步方式处理上传的文件,避免阻塞HTTP请求线程
- 实时进度推送:使用SSE技术向前端实时推送处理进度
- 多阶段处理跟踪:精确追踪每个AI处理阶段的进度
- 错误处理与恢复:完善的异常处理机制,确保资源得到释放
三、系统架构
系统架构由以下几个关键组件组成:
- 进度追踪服务:核心服务,管理任务进度和SSE连接
- 进度追踪控制器:提供REST API,供前端获取进度信息
- 文件处理服务增强:在现有服务上增加进度报告功能
- 文件上传控制器适配:修改现有控制器适配异步处理
四、后端实现代码
1. 进度追踪服务接口(IProgressTrackingService.java)
java
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* 进度追踪服务接口
* 用于追踪长时间运行任务的进度并推送给前端
*/
public interface IProgressTrackingService {
/**
* 创建新的进度追踪会话
*
* @param taskId 任务ID
* @return SseEmitter 事件发射器
*/
SseEmitter createProgressTracker(String taskId);
/**
* 更新任务进度
*
* @param taskId 任务ID
* @param currentStage 当前阶段索引(从0开始)
* @param totalStages 总阶段数
* @param message 进度消息
*/
void updateProgress(String taskId, int currentStage, int totalStages, String message);
/**
* 完成任务进度追踪
*
* @param taskId 任务ID
* @param success 是否成功
* @param message 完成消息
*/
void completeProgress(String taskId, boolean success, String message);
}
2. 进度追踪服务实现(ProgressTrackingServiceImpl.java)
java
import com.fasterxml.jackson.databind.ObjectMapper;
import com.greatech.abnormal_monitoring.service.IProgressTrackingService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* 进度追踪服务实现类
* 用于追踪长时间运行任务的进度并推送给前端
*/
@Slf4j
@Service
public class ProgressTrackingServiceImpl implements IProgressTrackingService {
// 存储任务ID到SSE发射器的映射
private final Map<String, SseEmitter> emitterMap = new ConcurrentHashMap<>();
// JSON对象映射器
private final ObjectMapper objectMapper = new ObjectMapper();
/**
* 创建新的进度追踪会话
*
* @param taskId 任务ID
* @return SseEmitter 事件发射器
*/
@Override
public SseEmitter createProgressTracker(String taskId) {
// 创建一个超时时间为30分钟的SSE发射器
SseEmitter emitter = new SseEmitter(1800000L);
// 设置完成回调
emitter.onCompletion(() -> {
log.info("进度追踪会话已完成: {}", taskId);
emitterMap.remove(taskId);
});
// 设置超时回调
emitter.onTimeout(() -> {
log.warn("进度追踪会话超时: {}", taskId);
emitterMap.remove(taskId);
});
// 设置错误回调
emitter.onError(ex -> {
log.error("进度追踪会话发生错误: {}", taskId, ex);
emitterMap.remove(taskId);
});
// 将发射器存储到映射中
emitterMap.put(taskId, emitter);
try {
// 发送初始化事件
Map<String, Object> initialEvent = Map.of(
"type", "INIT",
"taskId", taskId,
"message", "进度追踪会话已创建",
"progress", 0
);
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(initialEvent)));
} catch (IOException e) {
log.error("发送初始化事件失败: {}", taskId, e);
emitter.completeWithError(e);
}
return emitter;
}
/**
* 更新任务进度
*
* @param taskId 任务ID
* @param currentStage 当前阶段索引(从0开始)
* @param totalStages 总阶段数
* @param message 进度消息
*/
@Override
public void updateProgress(String taskId, int currentStage, int totalStages, String message) {
SseEmitter emitter = emitterMap.get(taskId);
if (emitter == null) {
log.warn("尝试更新不存在的进度追踪会话: {}", taskId);
return;
}
try {
// 计算进度百分比
int progress = (int) Math.round((currentStage * 100.0) / totalStages);
// 创建进度事件
Map<String, Object> progressEvent = Map.of(
"type", "PROGRESS",
"taskId", taskId,
"currentStage", currentStage,
"totalStages", totalStages,
"progress", progress,
"message", message
);
// 发送进度事件
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(progressEvent)));
log.debug("已更新进度: taskId={}, stage={}/{}, progress={}%, message={}",
taskId, currentStage, totalStages, progress, message);
} catch (IOException e) {
log.error("发送进度更新失败: {}", taskId, e);
emitter.completeWithError(e);
}
}
/**
* 完成任务进度追踪
*
* @param taskId 任务ID
* @param success 是否成功
* @param message 完成消息
*/
@Override
public void completeProgress(String taskId, boolean success, String message) {
SseEmitter emitter = emitterMap.get(taskId);
if (emitter == null) {
log.warn("尝试完成不存在的进度追踪会话: {}", taskId);
return;
}
try {
// 创建完成事件
Map<String, Object> completeEvent = Map.of(
"type", success ? "COMPLETE" : "ERROR",
"taskId", taskId,
"progress", success ? 100 : -1,
"message", message,
"success", success
);
// 发送完成事件
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(completeEvent)));
log.info("进度追踪已完成: taskId={}, success={}, message={}", taskId, success, message);
// 关闭SSE连接
emitter.complete();
} catch (IOException e) {
log.error("发送完成事件失败: {}", taskId, e);
emitter.completeWithError(e);
} finally {
// 确保从映射中移除
emitterMap.remove(taskId);
}
}
}
3. 进度追踪控制器(ProgressTrackingController.java)
java
import com.greatech.abnormal_monitoring.service.IProgressTrackingService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.UUID;
/**
* 进度追踪控制器
* 提供创建和获取进度追踪会话的接口
*/
@Slf4j
@RestController
@RequestMapping("/api/progress")
@RequiredArgsConstructor
@Tag(name = "进度追踪", description = "提供长时间运行任务的进度追踪功能")
public class ProgressTrackingController {
private final IProgressTrackingService progressTrackingService;
/**
* 创建新的进度追踪会话
*
* @return SseEmitter 事件发射器
*/
@Operation(summary = "创建进度追踪会话", description = "创建一个新的进度追踪会话并返回SSE连接")
@GetMapping(value = "/track", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter createProgressTracker(
@Parameter(description = "任务ID(可选)") @RequestParam(required = false) String taskId) {
// 如果未提供任务ID,则生成一个新ID
if (taskId == null || taskId.isEmpty()) {
taskId = UUID.randomUUID().toString();
}
log.info("创建进度追踪会话: {}", taskId);
return progressTrackingService.createProgressTracker(taskId);
}
/**
* 获取特定任务的进度追踪会话
*
* @param taskId 任务ID
* @return SseEmitter 事件发射器
*/
@Operation(summary = "获取任务进度", description = "根据任务ID获取进度追踪会话并返回SSE连接")
@GetMapping(value = "/track/{taskId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter getProgressTracker(
@Parameter(description = "任务ID", required = true) @PathVariable String taskId) {
log.info("获取进度追踪会话: {}", taskId);
return progressTrackingService.createProgressTracker(taskId);
}
}
4. 修改后的文件上传控制器方法
java
/**
* 上传模板文件并解析数据
*
* @param file 模板文件
* @param templateId 模板ID
* @return 上传结果
*/
@Operation(summary = "上传模板文件并解析", description = "上传模板文件到服务器并解析其内容到数据库")
@PostMapping("/upload-template")
public ServerResponse<Map<String, Object>> uploadAndParseTemplate(
@Parameter(description = "模板文件", required = true) @RequestParam("file") MultipartFile file,
@RequestParam("templateId") Long templateId) {
log.info("开始处理模板文件上传请求,模板ID: {}, 文件名: {}", templateId, file.getOriginalFilename());
try {
// 处理文件并获取结果(现在包含任务ID)
String result = templateService.processTemplateFile(file, templateId);
// 从结果中提取任务ID
String taskId = "";
if (result.contains("任务ID:")) {
taskId = result.substring(result.indexOf("任务ID:") + 5).trim();
} else {
// 对于不支持进度追踪的文件处理(如Excel文件),直接返回处理结果
log.info("模板文件处理成功,模板ID: {}, 文件名: {}", templateId, file.getOriginalFilename());
Map<String, Object> directResult = new HashMap<>();
directResult.put("message", result);
directResult.put("completed", true);
return ServerResponse.success(directResult);
}
log.info("模板文件上传成功并开始异步处理,模板ID: {}, 文件名: {}, 任务ID: {}",
templateId, file.getOriginalFilename(), taskId);
// 构建响应数据
Map<String, Object> responseData = new HashMap<>();
responseData.put("taskId", taskId);
responseData.put("message", "文件上传成功,正在异步处理中");
responseData.put("completed", false);
return ServerResponse.success(responseData);
} catch (Exception e) {
log.error("模板文件处理失败,模板ID: {}, 文件名: {}", templateId, file.getOriginalFilename(), e);
return ServerResponse.error("文件处理失败: " + e.getMessage());
}
}
5. 模板服务处理文本文件方法(TemplateServiceImpl.java中的processTextFile方法)
java
/**
* 处理文本文件
*
* @param file 已保存的文件
* @param templateId 模板ID
* @return 处理结果
*/
public String processTextFile(File file, Long templateId) {
log.info("处理文本文件:{}, 使用会话方式调用大模型API", file.getName());
// 生成唯一任务ID用于进度追踪
String taskId = UUID.randomUUID().toString();
log.info("生成进度追踪任务ID: {}", taskId);
// 创建新的会话ID
String sessionId = aiCommonService.createNewSession();
log.info("创建新会话:{}", sessionId);
// 使用CompletableFuture异步处理
CompletableFuture.runAsync(() -> {
try {
// 将文件转换为MultipartFile以便上传
MultipartFile multipartFile = convertToMultipartFile(file);
// 读取提示词模板并按阶段拆分
String promptTemplateContent = readPromptTemplate();
List<String> promptStages = splitContentIntoStages(promptTemplateContent);
log.info("提示词模板已拆分为{}个阶段", promptStages.size());
// 初始化进度
progressTrackingService.updateProgress(taskId, 0, promptStages.size(),
"开始处理文本文件,共" + promptStages.size() + "个阶段");
String jsonResponse = null;
// 顺序执行每个阶段,只保留最后一个阶段的结果
for (int i = 0; i < promptStages.size(); i++) {
String stagePrompt = promptStages.get(i);
// 更新进度
progressTrackingService.updateProgress(taskId, i, promptStages.size(),
"正在执行阶段 " + (i+1) + "/" + promptStages.size());
log.info("执行阶段 {}/{}", i+1, promptStages.size());
// 对于第一个阶段,需要上传文件
if (i == 0) {
jsonResponse = aiCommonService.callAiModelWithFileContext(stagePrompt, sessionId, multipartFile);
} else {
// 后续阶段只需要发送提示词
jsonResponse = aiCommonService.callAiModelWithFileContext(stagePrompt, sessionId, null);
}
// 仅记录当前处于哪个阶段,不记录完整响应以避免日志过大
log.info("阶段 {}/{} 执行完成", i+1, promptStages.size());
// 更新进度
progressTrackingService.updateProgress(taskId, i+1, promptStages.size(),
"已完成阶段 " + (i+1) + "/" + promptStages.size());
}
log.info("所有阶段执行完成,获取最终JSON响应");
progressTrackingService.updateProgress(taskId, promptStages.size(), promptStages.size(),
"所有阶段执行完成,正在处理最终结果");
// 将JSON响应转换为Excel文件
String excelFilePath = convertJsonToExcelFile(jsonResponse, templateId);
// 使用Excel解析逻辑处理生成的Excel文件
sheetService.parseExcelFile(new File(excelFilePath), templateId);
// 清理会话资源
aiCommonService.clearSession(sessionId);
log.info("会话已清理:{}", sessionId);
// 标记任务完成
progressTrackingService.completeProgress(taskId, true,
"文本文件已成功处理并解析,原始文件路径:" + file.getAbsolutePath());
} catch (Exception e) {
// 出错时也要尝试清理会话资源
try {
aiCommonService.clearSession(sessionId);
log.info("会话清理完成:{}", sessionId);
} catch (Exception ex) {
log.error("清理会话失败:{}", sessionId, ex);
}
log.error("处理文本文件失败", e);
// 标记任务失败
progressTrackingService.completeProgress(taskId, false,
"处理文件失败: " + e.getMessage());
}
});
// 立即返回,不阻塞请求线程
return "文件处理已开始,可通过任务ID: " + taskId + " 追踪进度";
}
6. 提取阶段的灵活实现(splitContentIntoStages方法)
java
/**
* 将文件内容按阶段拆分,使用配置文件中的设置
*
* @param content 文件内容
* @return 拆分后的阶段列表
*/
private List<String> splitContentIntoStages(String content) {
List<String> stages = new ArrayList<>();
// 1. 首先尝试使用配置的标识符列表
String identifiersConfig = environment.getProperty("template.stage.identifiers");
if (identifiersConfig != null && !identifiersConfig.isEmpty()) {
return splitByConfiguredIdentifiers(content, identifiersConfig);
}
// 2. 如果没有配置具体标识符,则使用正则表达式
return splitByRegexPattern(content);
}
/**
* 根据配置的标识符列表拆分内容
*/
private List<String> splitByConfiguredIdentifiers(String content, String identifiersConfig) {
List<String> stages = new ArrayList<>();
String[] identifiers = identifiersConfig.split(",");
log.info("使用配置的 {} 个阶段标识符拆分内容", identifiers.length);
// 寻找每个阶段的起始和结束位置
int currentPos = 0;
for (int i = 0; i < identifiers.length; i++) {
String identifier = identifiers[i].trim();
int stageStart = content.indexOf(identifier, currentPos);
if (stageStart != -1) {
// 找到阶段起始位置
int nextStageStart = -1;
// 查找下一个阶段的起始位置(如果存在)
if (i < identifiers.length - 1) {
nextStageStart = content.indexOf(identifiers[i + 1].trim(), stageStart);
}
// 提取当前阶段内容
String stageContent;
if (nextStageStart != -1) {
stageContent = content.substring(stageStart, nextStageStart).trim();
currentPos = nextStageStart;
} else {
// 如果是最后一个阶段,则截取到文件末尾
stageContent = content.substring(stageStart).trim();
}
log.info("提取阶段 {}: {} 字符", i+1, stageContent.length());
stages.add(stageContent);
} else {
log.warn("未找到阶段标识符: {}", identifier);
}
}
return stages;
}
/**
* 使用正则表达式模式拆分内容
*/
private List<String> splitByRegexPattern(String content) {
List<String> stages = new ArrayList<>();
// 获取配置的正则表达式,如果未配置则使用默认值
String patternStr = environment.getProperty("template.stage.pattern");
if (patternStr == null || patternStr.isEmpty()) {
patternStr = "(?:#{0,3}\\s*【?\\s*阶段\\s*(\\d+)\\s*】?)";
}
log.info("使用正则表达式模式拆分内容: {}", patternStr);
// 使用正则表达式匹配阶段标识符
Pattern stagePattern = Pattern.compile(patternStr, Pattern.CASE_INSENSITIVE);
Matcher matcher = stagePattern.matcher(content);
// 存储找到的所有阶段起始位置和编号
Map<Integer, Integer> positionToStageNumber = new TreeMap<>();
while (matcher.find()) {
int stageNumber = Integer.parseInt(matcher.group(1));
positionToStageNumber.put(matcher.start(), stageNumber);
log.debug("发现阶段 {} 在位置 {}", stageNumber, matcher.start());
}
// 如果没有找到任何阶段标识,将整个内容作为一个阶段
if (positionToStageNumber.isEmpty()) {
log.warn("未在内容中找到阶段标识符,将整个内容作为单一阶段处理");
stages.add(content);
return stages;
}
// 是否严格按照阶段编号排序
boolean strictOrder = Boolean.parseBoolean(
environment.getProperty("template.stage.strict-order", "true")
);
// 获取最大阶段数限制
int maxStageCount = Integer.parseInt(
environment.getProperty("template.stage.max-count", "10")
);
// 提取每个阶段的内容
List<Integer> positions = new ArrayList<>(positionToStageNumber.keySet());
Collections.sort(positions);
Map<Integer, String> numberToContent = new HashMap<>();
for (int i = 0; i < positions.size(); i++) {
int startPos = positions.get(i);
int endPos = (i < positions.size() - 1) ? positions.get(i + 1) : content.length();
String stageContent = content.substring(startPos, endPos).trim();
int stageNumber = positionToStageNumber.get(startPos);
if (stageNumber > maxStageCount) {
log.warn("阶段编号 {} 超过最大限制 {},将被忽略", stageNumber, maxStageCount);
continue;
}
log.info("提取阶段 {}: {} 字符", stageNumber, stageContent.length());
numberToContent.put(stageNumber, stageContent);
}
// 按照阶段编号排序
if (strictOrder) {
List<Integer> sortedNumbers = new ArrayList<>(numberToContent.keySet());
Collections.sort(sortedNumbers);
for (Integer num : sortedNumbers) {
stages.add(numberToContent.get(num));
}
} else {
// 按照在文件中的位置顺序
for (int i = 0; i < positions.size(); i++) {
int stageNumber = positionToStageNumber.get(positions.get(i));
if (stageNumber <= maxStageCount) {
stages.add(numberToContent.get(stageNumber));
}
}
}
log.info("共拆分出 {} 个阶段", stages.size());
return stages;
}
五、前端调用说明
这里提供简要的前端调用说明,以便于前端开发人员理解如何与后端集成。
1. 前端调用流程
前端与后端的交互流程如下:
- 文件上传:使用标准的HTTP POST请求上传文件到
/api/template/upload-template
- 获取任务ID:从上传响应中提取任务ID
- 建立SSE连接:使用
EventSource
接口建立与/api/progress/track/{taskId}
的连接 - 处理进度更新:监听并处理SSE事件,更新UI显示进度
- 处理完成通知:接收处理完成或失败的通知,更新UI状态
2. 代码示例
以下是前端实现的简化示例:
javascript
// 文件上传示例
async function uploadFile(file, templateId) {
const formData = new FormData();
formData.append('file', file);
formData.append('templateId', templateId);
try {
const response = await fetch('/api/template/upload-template', {
method: 'POST',
body: formData
});
const result = await response.json();
if (result.code === 200) {
if (result.data.taskId) {
// 异步处理任务,开始跟踪进度
trackProgress(result.data.taskId);
} else if (result.data.completed) {
// 同步处理已完成
showSuccess(result.data.message);
}
} else {
showError(result.message);
}
} catch (error) {
showError('上传失败: ' + error.message);
}
}
// 进度追踪示例
function trackProgress(taskId) {
// 创建EventSource连接
const eventSource = new EventSource(`/api/progress/track/${taskId}`);
// 处理消息事件
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data);
switch (data.type) {
case 'INIT':
updateProgressUI(0, data.message);
break;
case 'PROGRESS':
updateProgressUI(data.progress, data.message);
break;
case 'COMPLETE':
updateProgressUI(100, data.message);
showSuccess(data.message);
eventSource.close();
break;
case 'ERROR':
showError(data.message);
eventSource.close();
break;
}
};
// 处理错误
eventSource.onerror = (error) => {
showError('连接中断,无法获取进度更新');
eventSource.close();
};
}
// 更新UI显示进度(由前端实现)
function updateProgressUI(progress, message) {
// 更新进度条
document.getElementById('progress-bar').value = progress;
document.getElementById('progress-percentage').textContent = progress + '%';
document.getElementById('progress-message').textContent = message;
}
// 显示成功消息(由前端实现)
function showSuccess(message) {
// 显示成功提示
}
// 显示错误消息(由前端实现)
function showError(message) {
// 显示错误提示
}
在Spring Boot应用程序配置中启用异步支持:
java
@Configuration
@EnableAsync
public class AsyncConfig implements AsyncConfigurer {
@Override
public Executor getAsyncExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(5);
executor.setMaxPoolSize(10);
executor.setQueueCapacity(25);
executor.setThreadNamePrefix("async-");
executor.initialize();
return executor;
}
@Override
public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() {
return new SimpleAsyncUncaughtExceptionHandler();
}
}
确保在应用程序上下文中注册以下组件:
- ProgressTrackingServiceImpl
- ProgressTrackingController
六、安全性考虑
- 任务ID安全:生成的任务ID应使用安全的随机数生成器,避免被猜测
- 资源限制:设置SSE连接和异步处理的超时时间
- 错误处理:确保异常不会泄露敏感信息
- 防止资源泄露:正确清理未使用的连接和资源
- 访问控制:根据需要为进度追踪API添加身份验证
七、性能优化
- 异步线程池调优:根据服务器性能和预期负载调整线程池参数
- 超时控制:为长时间运行的任务设置合理的超时时间
- 内存管理:避免在SSE消息中发送大量数据
- 数据压缩:考虑对SSE消息进行压缩,减少网络流量
八、故障排除
常见问题及解决方案:
SSE连接中断
- 原因:网络问题、服务器重启、超时
- 解决:前端实现自动重连逻辑
任务处理缓慢
- 原因:AI模型处理速度、服务器资源不足
- 解决:增加服务器资源、优化AI调用参数
任务进度更新失败
- 原因:SSE连接已关闭、进度追踪服务异常
- 解决:检查日志、确保任务ID正确
资源泄露
- 原因:未正确关闭SSE连接或清理资源
- 解决:确保所有异常路径都调用资源清理方法
九、扩展方向
系统可以在以下方向进行扩展:
- 任务持久化:将任务进度存储到数据库,支持服务重启后恢复
- 任务队列:使用消息队列管理大量任务,避免服务器过载
- 分布式部署:支持集群环境中的任务进度追踪
- 管理界面:添加管理界面,监控和管理所有正在运行的任务
- 推送通知:集成WebSocket或推送通知,提供更实时的进度更新
十、总结
本系统提供了一种高效、可靠的方式来追踪AI模型多阶段调用的处理进度。通过使用Server-Sent Events技术和异步处理,系统能够在长时间运行的任务中向前端实时推送进度更新,大大提高了用户体验。该设计将文件处理与进度报告解耦,使系统更加灵活和可维护。
系统的核心优势在于:
- 实时进度反馈:通过SSE技术向前端实时推送进度信息
- 多阶段处理:支持AI模型的多阶段调用,精确追踪每个阶段的进度
- 灵活配置:支持通过配置文件自定义阶段标识符和处理行为
- 完善的错误处理:全面的异常处理和资源清理机制
- 易于集成:前端只需少量代码即可集成该功能
通过这一设计,我们解决了长时间运行任务缺乏进度反馈的问题,提供了更好的用户体验,同时提高了系统的可靠性和可维护性。