第五篇:生产篇
第 12 章:性能优化与成本控制
12.1 缓存策略
12.1.1 响应缓存
java
/**
 * Chat service that consults a response cache before calling the AI model.
 *
 * <p>Responses live in the {@code ai-responses} cache. Messages that look
 * context-independent share one global entry; every other message is keyed per session.
 *
 * <p>NOTE(review): the final fields are not assigned anywhere in this excerpt; presumably
 * constructor injection is elided - confirm.
 */
@Service
public class CachedChatService {
    private final ChatClient chatClient;
    private final CacheManager cacheManager;

    /**
     * Returns the cached answer for {@code message} when one exists, otherwise asks the model
     * and stores the answer if it is worth caching.
     *
     * @param message   user message sent to the model
     * @param sessionId conversation id, passed to the advisors as {@code conversationId}
     * @return the cached or freshly generated response; may be {@code null} if the model
     *         returned no content
     */
    public String chatWithCache(String message, String sessionId) {
        String cacheKey = generateCacheKey(message, sessionId);

        // A missing cache is tolerated: the call simply goes straight to the model.
        Cache cache = cacheManager.getCache("ai-responses");
        if (cache != null) {
            String cached = cache.get(cacheKey, String.class);
            if (cached != null) {
                return cached;
            }
        }

        String response = chatClient.prompt()
            .user(message)
            .advisors(a -> a.param("conversationId", sessionId))
            .call()
            .content();

        if (cache != null && shouldCache(message, response)) {
            cache.put(cacheKey, response);
        }
        return response;
    }

    /**
     * Builds the cache key: a global key for context-independent questions, a session-scoped
     * key otherwise.
     *
     * <p>NOTE(review): a response stored under a global key was still generated with the
     * caller's {@code conversationId}, so session context may influence an answer that is
     * later served to other sessions - confirm this is acceptable.
     */
    private String generateCacheKey(String message, String sessionId) {
        if (isContextIndependent(message)) {
            return "global:" + DigestUtils.md5Hex(message);
        }
        return sessionId + ":" + DigestUtils.md5Hex(message);
    }

    /**
     * Decides whether a response may be stored.
     *
     * @return {@code false} for missing/blank responses, responses that look like errors or
     *         apologies, and responses longer than 2000 characters
     */
    private boolean shouldCache(String message, String response) {
        // The model may return no content; never cache that (and avoid an NPE below).
        if (response == null || response.isBlank()) {
            return false;
        }
        // Do not cache error-like or apologetic answers.
        if (response.contains("错误") || response.contains("抱歉")) {
            return false;
        }
        // Do not cache overly long responses.
        return response.length() <= 2000;
    }

    /** Heuristic: factual questions with these prefixes usually do not depend on context. */
    private boolean isContextIndependent(String message) {
        return message.startsWith("什么是")
            || message.startsWith("谁发明了")
            || message.startsWith("如何计算");
    }
}
Redis 缓存配置:
yaml
# Cache settings for the AI response cache (Redis-backed Spring cache).
# NOTE(review): the nesting/indentation of this snippet is not visible here; restore it before use.
spring:
cache:
type: redis
redis:
time-to-live: 3600000 # 1 hour (3,600,000 ms)
cache-null-values: false
ai:
cache:
enabled: true
# presumably seconds (3600 s = 1 hour) - TODO confirm the unit expected by the consumer
ttl: 3600
12.1.2 嵌入缓存
java
/**
 * Embedding service that stores computed vectors in the {@code embeddings} cache as raw
 * big-endian float bytes, keyed by the MD5 hex digest of the input text.
 */
@Service
public class CachedEmbeddingService {
    private final EmbeddingModel embeddingModel;
    private final CacheManager cacheManager;

    /**
     * Returns the embedding for {@code text}, reusing a cached vector when available.
     *
     * @param text text to embed
     * @return the cached or freshly computed embedding
     */
    public Embedding embedWithCache(String text) {
        final String key = "embedding:" + DigestUtils.md5Hex(text);
        final Cache store = cacheManager.getCache("embeddings");

        // Without a cache there is nothing to look up or to store.
        if (store == null) {
            return embeddingModel.embed(text).content();
        }

        final byte[] hit = store.get(key, byte[].class);
        if (hit != null) {
            return deserializeEmbedding(hit);
        }

        final Embedding fresh = embeddingModel.embed(text).content();
        store.put(key, serializeEmbedding(fresh));
        return fresh;
    }

    /** Packs the vector into a byte array, four bytes per float. */
    private byte[] serializeEmbedding(Embedding embedding) {
        final float[] values = embedding.vector();
        final ByteBuffer bytes = ByteBuffer.allocate(values.length * 4);
        // The float view writes through to the same backing array.
        bytes.asFloatBuffer().put(values);
        return bytes.array();
    }

    /** Rebuilds an embedding from bytes produced by {@link #serializeEmbedding}. */
    private Embedding deserializeEmbedding(byte[] data) {
        final float[] values = new float[data.length / 4];
        ByteBuffer.wrap(data).asFloatBuffer().get(values);
        return new Embedding(values);
    }
}
12.2 Token 优化
12.2.1 Prompt 压缩
java
/**
 * Utilities for keeping prompts and chat history inside a token budget.
 *
 * <p>Token counts are a rough estimate (characters / 4), not a real tokenizer.
 */
@Service
public class PromptOptimizationService {
    private final ChatClient optimizationClient;

    /**
     * Asks the model to compress {@code originalPrompt} when it exceeds the budget.
     *
     * @param originalPrompt prompt to shrink
     * @param targetTokens   desired upper bound, in estimated tokens
     * @return the original prompt if it already fits, otherwise the model's compressed version
     */
    public String optimizePrompt(String originalPrompt, int targetTokens) {
        int currentTokens = countTokens(originalPrompt);
        if (currentTokens <= targetTokens) {
            return originalPrompt;
        }
        String compressionPrompt = """
            请压缩以下 prompt,保留核心信息,去除冗余表达。
            目标长度:%d tokens 以内。
            原始 prompt:
            %s
            只返回压缩后的 prompt。
            """.formatted(targetTokens, originalPrompt);
        return optimizationClient.prompt(compressionPrompt).call().content();
    }

    /**
     * Drops older messages so the formatted history fits into {@code maxTokens}.
     *
     * <p>The first system message (if any) is always kept and stays at the front; the most
     * recent non-system messages are then added, newest first, until the budget is reached.
     *
     * @param history   full conversation, oldest first
     * @param maxTokens budget in estimated tokens
     * @return the retained messages formatted as {@code TYPE: content} lines
     */
    public String pruneContext(List<ChatMessage> history, int maxTokens) {
        int totalTokens = countTokens(history);
        if (totalTokens <= maxTokens) {
            return formatMessages(history);
        }

        List<ChatMessage> pruned = new ArrayList<>();
        history.stream()
            .filter(msg -> msg.type() == MessageType.SYSTEM)
            .findFirst()
            .ifPresent(pruned::add);

        // Insert right after the system message (index 0 or 1) so it remains first while the
        // kept messages end up in chronological order.
        final int insertAt = pruned.size();
        int currentTokens = countTokens(pruned);
        for (int i = history.size() - 1; i >= 0; i--) {
            ChatMessage msg = history.get(i);
            if (msg.type() == MessageType.SYSTEM) {
                continue;
            }
            int msgTokens = countTokens(msg);
            if (currentTokens + msgTokens > maxTokens) {
                break;
            }
            pruned.add(insertAt, msg);
            currentTokens += msgTokens;
        }
        return formatMessages(pruned);
    }

    /**
     * Estimates the token count of a string, a chat message, or a list of either.
     *
     * @return estimated tokens; 0 for unsupported content types
     */
    private int countTokens(Object content) {
        // Placeholder estimate; use tiktoken or a similar library for real counts.
        if (content instanceof String) {
            return ((String) content).length() / 4;
        }
        if (content instanceof ChatMessage) {
            // Count the message text; previously messages fell through and counted as 0,
            // which made every history look like it fit the budget.
            return countTokens(((ChatMessage) content).content());
        }
        if (content instanceof List) {
            return ((List<?>) content).stream()
                .mapToInt(this::countTokens)
                .sum();
        }
        return 0;
    }

    /** Renders messages one per line as {@code type: content}. */
    private String formatMessages(List<ChatMessage> messages) {
        return messages.stream()
            .map(msg -> msg.type() + ": " + msg.content())
            .collect(Collectors.joining("\n"));
    }
}
12.2.2 模型路由
java
/**
 * Routes a prompt to a cheap or an expensive model based on a complexity label produced by
 * a small classifier model.
 */
@Service
public class ModelRouterService {
    private final ChatClient cheapModel;      // e.g. GPT-3.5
    private final ChatClient expensiveModel;  // e.g. GPT-4
    private final ChatClient classifierModel; // small model used for classification

    /**
     * Classifies {@code prompt} as simple/medium/complex and generates the answer with the
     * matching model. Any unrecognised or missing label is treated as complex.
     *
     * @param prompt task to execute
     * @return the generated content
     */
    public String routeAndGenerate(String prompt) {
        String verdict = classifierModel.prompt("""
            评估以下任务的复杂度(simple/medium/complex):
            任务:%s
            只返回一个词。
            """.formatted(prompt)).call().content();

        // The classifier may return no content; an empty label falls through to the default.
        // Locale.ROOT keeps the comparison stable regardless of the JVM default locale.
        String complexity = verdict == null
            ? ""
            : verdict.trim().toLowerCase(java.util.Locale.ROOT);

        switch (complexity) {
            case "simple":
                return cheapModel.prompt(prompt).call().content();
            case "medium":
                // Try the cheap model first and escalate only if it fails.
                try {
                    return cheapModel.prompt(prompt).call().content();
                } catch (Exception e) {
                    return expensiveModel.prompt(prompt).call().content();
                }
            case "complex":
            default:
                return expensiveModel.prompt(prompt).call().content();
        }
    }
}
成本对比(以 OpenAI 为例):
| 模型 | 输入价格 ($/1K tokens) | 输出价格 ($/1K tokens) | 适用场景 |
|---|---|---|---|
| GPT-4 | $0.03 | $0.06 | 复杂推理、代码生成 |
| GPT-4o | $0.01 | $0.03 | 通用任务、多模态 |
| GPT-3.5-Turbo | $0.0005 | $0.0015 | 简单问答、分类 |
通过智能路由,可节省 40-70% 的成本。
第 13 章:监控、日志与可观测性
13.1 Micrometer 集成
13.1.1 自定义指标
java
/**
 * Registers the Micrometer meters used for AI monitoring and exposes a metered
 * {@code ChatClient} wrapper.
 */
@Configuration
public class AiMetricsConfig {

    /**
     * Pre-registers the AI meters so they exist before the first request.
     *
     * @param chatClient not used by the binder itself
     * @return a binder that registers the request counter, duration timer, token summaries
     *         and error counter
     */
    @Bean
    public MeterBinder aiMetricsBinder(ChatClient chatClient) {
        return (registry) -> {
            // Request counter
            Counter.builder("ai.requests.total")
                .description("Total AI requests")
                .tag("provider", "openai")
                .register(registry);

            // Latency timer; histogram buckets are required for histogram_quantile queries
            // such as the P95 latency panel.
            Timer.builder("ai.request.duration")
                .description("AI request duration")
                .tag("provider", "openai")
                .publishPercentileHistogram()
                .register(registry);

            // Token usage, split by direction
            DistributionSummary.builder("ai.tokens.used")
                .description("Tokens used per request")
                .tag("type", "input")
                .register(registry);
            DistributionSummary.builder("ai.tokens.used")
                .description("Tokens used per request")
                .tag("type", "output")
                .register(registry);

            // Error counter (numerator of the error rate)
            Counter.builder("ai.errors.total")
                .description("Total AI errors")
                .tag("provider", "openai")
                .register(registry);
        };
    }

    /**
     * Wraps a chat client so that each call is counted and timed.
     *
     * <p>NOTE(review): this bean is itself a {@code ChatClient} and also requires one;
     * presumably a qualifier/primary marker is needed to avoid an ambiguous or circular
     * injection - confirm.
     */
    @Bean
    public ChatClient monitoredChatClient(ChatClient delegate, MeterRegistry registry) {
        return new MonitoredChatClient(delegate, registry);
    }
}
/**
 * {@code ChatClient} decorator that counts and times every call made through it.
 *
 * <p>The meters carry the same {@code provider} tag as the meters registered in
 * {@code AiMetricsConfig}; without the tag these would be separate series (and registries
 * such as Prometheus reject one meter name with differing tag keys).
 */
class MonitoredChatClient implements ChatClient {
    private final ChatClient delegate;
    private final MeterRegistry registry;
    private final Timer timer;
    private final Counter requestCounter;
    private final Counter errorCounter;

    /**
     * @param delegate client that performs the actual calls
     * @param registry registry in which the meters are registered (or looked up)
     */
    public MonitoredChatClient(ChatClient delegate, MeterRegistry registry) {
        this.delegate = delegate;
        this.registry = registry;
        this.timer = Timer.builder("ai.request.duration")
            .tag("provider", "openai")
            // Buckets are needed for histogram_quantile-based latency panels.
            .publishPercentileHistogram()
            .register(registry);
        this.requestCounter = Counter.builder("ai.requests.total")
            .tag("provider", "openai")
            .register(registry);
        this.errorCounter = Counter.builder("ai.errors.total")
            .tag("provider", "openai")
            .register(registry);
    }

    /** Starts a prompt on the delegate and wraps the spec so the call is metered. */
    @Override
    public ResponseSpec prompt() {
        return new MonitoredResponseSpec(delegate.prompt(), timer, requestCounter, errorCounter);
    }
    // ... remaining methods delegate to the wrapped client
}
/**
 * Response-spec decorator that records request count, error count and duration around
 * {@link #call()}.
 */
class MonitoredResponseSpec implements ChatClient.ResponseSpec {
    private final ChatClient.ResponseSpec delegate;
    private final Timer timer;
    private final Counter requestCounter;
    private final Counter errorCounter;

    /**
     * @param delegate       spec that performs the real call
     * @param timer          timer receiving the call duration
     * @param requestCounter incremented once per call
     * @param errorCounter   incremented when the call throws
     */
    public MonitoredResponseSpec(ChatClient.ResponseSpec delegate, Timer timer,
                                 Counter requestCounter, Counter errorCounter) {
        this.delegate = delegate;
        this.timer = timer;
        this.requestCounter = requestCounter;
        this.errorCounter = errorCounter;
    }

    /**
     * Executes the delegate call, counting it and recording its duration whether it succeeds
     * or throws.
     *
     * @return the delegate's result
     */
    @Override
    public ContentSpec call() {
        // Timer.start() takes no Timer argument; the target timer is supplied to stop().
        Timer.Sample sample = Timer.start();
        requestCounter.increment();
        try {
            return delegate.call();
        } catch (Exception e) {
            errorCounter.increment();
            throw e;
        } finally {
            // Record the duration exactly once on both the success and the failure path.
            sample.stop(timer);
        }
    }
    // ... remaining methods delegate to the wrapped spec
}
13.1.2 Grafana 仪表盘配置
json
{
"dashboard": {
"title": "Spring AI 监控",
"panels": [
{
"title": "请求速率",
"type": "graph",
"targets": [
{
"expr": "rate(ai_requests_total[5m])",
"legendFormat": "Requests/sec"
}
]
},
{
"title": "P95 延迟",
"type": "graph",
"targets": [
{
"expr": "histogram_quantile(0.95, sum(rate(ai_request_duration_seconds_bucket[5m])) by (le))",
"legendFormat": "P95 Latency"
}
]
},
{
"title": "Token 使用量",
"type": "graph",
"targets": [
{
"expr": "sum(rate(ai_tokens_used_sum[5m])) by (type)",
"legendFormat": "{{type}} tokens"
}
]
},
{
"title": "错误率",
"type": "graph",
"targets": [
{
"expr": "rate(ai_errors_total[5m]) / rate(ai_requests_total[5m]) * 100",
"legendFormat": "Error Rate %"
}
]
}
]
}
}
13.2 结构化日志
java
/**
 * Advisor that writes one structured (JSON) log line per prompt sent and per response
 * received. Previews are sanitized and limited to {@value #PREVIEW_LIMIT} characters.
 */
@Component
public class AiLoggingAdvisor implements PromptCallAdvisor {
    private static final Logger log = LoggerFactory.getLogger(AiLoggingAdvisor.class);
    private static final int PREVIEW_LIMIT = 200;
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Logs the outgoing prompt and returns it unchanged.
     *
     * @param prompt  prompt about to be sent
     * @param context advisor context; {@code conversationId} and {@code model} are read
     * @return the same prompt instance
     */
    @Override
    public Prompt apply(Prompt prompt, Map<String, Object> context) {
        String contents = prompt.getContents();

        Map<String, Object> logData = new HashMap<>();
        logData.put("event", "ai_prompt_sent");
        // Stored as an ISO-8601 string: a plain ObjectMapper has no java.time support.
        logData.put("timestamp", Instant.now().toString());
        logData.put("conversationId", context.get("conversationId"));
        logData.put("promptLength", contents == null ? 0 : contents.length());
        logData.put("model", context.get("model"));
        logData.put("promptPreview", preview(contents));

        writeLog(logData);
        return prompt;
    }

    /**
     * Logs token usage and a sanitized preview of the response and returns it unchanged.
     *
     * @param response response received from the model
     * @param context  advisor context; {@code conversationId} is read
     * @return the same response instance
     */
    @Override
    public Response apply(Response response, Map<String, Object> context) {
        Map<String, Object> logData = new HashMap<>();
        logData.put("event", "ai_response_received");
        logData.put("timestamp", Instant.now().toString());
        logData.put("conversationId", context.get("conversationId"));

        if (response.getResult() != null) {
            TokenUsage usage = response.getResult().getMetadata().getTokenUsage();
            if (usage != null) {
                logData.put("inputTokens", usage.inputTokens());
                logData.put("outputTokens", usage.outputTokens());
                logData.put("totalTokens", usage.totalTokens());
            }
            logData.put("responsePreview",
                preview(response.getResult().getOutput().getContent()));
        }

        writeLog(logData);
        return response;
    }

    /** Serializes the entry to JSON and logs it; serialization failures are only warned about. */
    private void writeLog(Map<String, Object> logData) {
        try {
            log.info(objectMapper.writeValueAsString(logData));
        } catch (JsonProcessingException e) {
            log.warn("日志序列化失败", e);
        }
    }

    /** Returns the sanitized content truncated to the preview limit. */
    private String preview(String content) {
        String sanitized = sanitize(content);
        return sanitized.substring(0, Math.min(PREVIEW_LIMIT, sanitized.length()));
    }

    /**
     * Redacts API-key-like tokens, 16-digit numbers and e-mail addresses.
     *
     * @return the redacted text, or an empty string when {@code content} is {@code null}
     */
    private String sanitize(String content) {
        if (content == null) {
            return "";
        }
        return content.replaceAll("sk-\\w+", "[REDACTED]")
            .replaceAll("\\b\\d{16}\\b", "[CARD_REDACTED]")
            .replaceAll("[\\w.-]+@[\\w.-]+", "[EMAIL_REDACTED]");
    }
}
第 14 章:安全与合规
14.1 输入输出过滤
14.1.1 提示注入防护
java
/**
 * Advisor that rejects prompts containing known prompt-injection or jailbreak phrases.
 *
 * <p>NOTE(review): {@code log} is not declared in this excerpt; presumably a logger field is
 * elided - confirm.
 */
@Component
public class PromptInjectionFilter implements PromptCallAdvisor {
    private static final List<String> DANGEROUS_PATTERNS = Arrays.asList(
        "忽略上述指令",
        "无视之前的规则",
        "你现在是一个",
        "系统提示是",
        "忘记你所受的",
        "突破所有限制",
        "不要遵守",
        "绕过安全"
    );

    /**
     * Scans the prompt text and rejects it when a dangerous phrase or jailbreak marker is found.
     *
     * @param prompt  prompt to inspect
     * @param context advisor context (not used)
     * @return the unchanged prompt when nothing suspicious is found
     * @throws SecurityException if an injection or jailbreak attempt is detected
     */
    @Override
    public Prompt apply(Prompt prompt, Map<String, Object> context) {
        String content = prompt.getContents();
        if (content == null || content.isEmpty()) {
            // Nothing to inspect.
            return prompt;
        }

        // Lower-case once, locale-independently, instead of once per pattern.
        String normalized = content.toLowerCase(java.util.Locale.ROOT);
        for (String pattern : DANGEROUS_PATTERNS) {
            if (normalized.contains(pattern.toLowerCase(java.util.Locale.ROOT))) {
                log.warn("检测到潜在的提示注入攻击:{}", pattern);
                throw new SecurityException("检测到恶意的 prompt 注入尝试");
            }
        }

        // Additional heuristic detection
        if (detectJailbreakAttempt(content)) {
            log.warn("检测到越狱尝试");
            throw new SecurityException("检测到越狱尝试");
        }
        return prompt;
    }

    /**
     * Heuristic for common jailbreak patterns such as DAN ("Do Anything Now").
     * Currently matches only when both markers appear (case-sensitive).
     */
    private boolean detectJailbreakAttempt(String content) {
        return content.contains("DAN") && content.contains("不受限制");
    }
}
14.1.2 输出内容审核
java
/**
 * Moderates generated content with a dedicated moderation model.
 *
 * <p>NOTE(review): {@code log} is not declared in this excerpt; presumably a logger field is
 * elided - confirm.
 */
@Service
public class ContentModerationService {
    private final ChatClient moderationClient;

    /**
     * Asks the moderation model to classify {@code content} and returns it only when the
     * verdict is {@code SAFE} (case-insensitive).
     *
     * @param content text to moderate
     * @return the unchanged content when it passes moderation
     * @throws ContentModerationException if the verdict is anything other than SAFE,
     *         including when the model returns no verdict at all
     */
    public String moderateAndReturn(String content) {
        String moderationPrompt = """
            审核以下内容是否包含:
            1. 仇恨言论
            2. 暴力内容
            3. 色情内容
            4. 个人隐私信息
            5. 虚假信息
            内容:%s
            如果有问题,返回具体问题类型;否则返回 SAFE。
            只返回一个问题类型或 SAFE。
            """.formatted(content);

        String verdict = moderationClient.prompt(moderationPrompt).call().content();
        // Fail closed: a missing verdict is treated as "not safe" instead of crashing with an NPE.
        String result = verdict == null ? "" : verdict.trim();

        if (!"SAFE".equalsIgnoreCase(result)) {
            log.warn("内容审核未通过:{}", result);
            throw new ContentModerationException("内容包含不当信息:" + result);
        }
        return content;
    }

    /**
     * Hook for a third-party moderation API (e.g. Azure Content Safety).
     *
     * <p>Placeholder: no service is called yet and every input is reported as safe.
     *
     * @return always {@code true} in the current implementation
     */
    public boolean isContentSafe(String content) {
        return true;
    }
}
14.2 数据隐私保护
14.2.1 PII 检测和脱敏
java
/**
 * Detects and masks common PII (e-mail addresses, 11-digit phone numbers, 18-character ID
 * card numbers, 16-19 digit bank card numbers) in free text.
 *
 * <p>Number patterns are delimited with ASCII-word lookarounds rather than {@code \b}.
 * On JDKs before 19, {@code \b} treats CJK characters as word characters, so a number
 * written directly next to Chinese text (e.g. "电话13812345678") was not detected.
 */
@Component
public class PiiDetectionService {
    private static final Pattern EMAIL_PATTERN = Pattern.compile("[\\w.-]+@[\\w.-]+");
    private static final Pattern PHONE_PATTERN =
        Pattern.compile("(?<!\\w)\\d{11}(?!\\w)");
    private static final Pattern ID_CARD_PATTERN =
        Pattern.compile("(?<!\\w)\\d{17}[\\dXx](?!\\w)");
    private static final Pattern BANK_CARD_PATTERN =
        Pattern.compile("(?<!\\w)\\d{16,19}(?!\\w)");

    /**
     * Masks every detected PII value in {@code content}.
     *
     * <p>Order matters: ID card numbers are masked before bank cards, because an all-digit
     * 18-character ID would otherwise also match the bank-card pattern.
     *
     * @param content text to scan; may be {@code null}
     * @return the text with PII masked; {@code null}/empty input is returned unchanged
     */
    public String detectAndMask(String content) {
        if (content == null || content.isEmpty()) {
            return content;
        }
        String masked = content;
        masked = EMAIL_PATTERN.matcher(masked)
            .replaceAll(m -> maskEmail(m.group()));
        masked = PHONE_PATTERN.matcher(masked)
            .replaceAll(m -> maskPhone(m.group()));
        masked = ID_CARD_PATTERN.matcher(masked)
            .replaceAll(m -> maskIdCard(m.group()));
        masked = BANK_CARD_PATTERN.matcher(masked)
            .replaceAll(m -> maskBankCard(m.group()));
        return masked;
    }

    /** Keeps the first two characters of the local part (none if it is that short) and the domain. */
    private String maskEmail(String email) {
        int atIdx = email.indexOf('@');
        if (atIdx <= 2) {
            return "***" + email.substring(atIdx);
        }
        return email.substring(0, 2) + "***" + email.substring(atIdx);
    }

    /** Keeps the first three and last four digits of an 11-digit phone number. */
    private String maskPhone(String phone) {
        return phone.substring(0, 3) + "****" + phone.substring(7);
    }

    /** Keeps the first six and last four characters of an 18-character ID number. */
    private String maskIdCard(String idCard) {
        return idCard.substring(0, 6) + "********" + idCard.substring(14);
    }

    /** Keeps only the last four digits of a bank card number. */
    private String maskBankCard(String card) {
        return "**** **** **** " + card.substring(card.length() - 4);
    }
}
第六篇:实战篇
第 15 章:智能客服系统实战
15.1 系统架构设计
scss
┌─────────────────┐
│ 用户界面 │
│ (Web/Mobile) │
└────────┬────────┘
│
┌────────▼────────┐
│ API Gateway │
│ (认证/限流) │
└────────┬────────┘
│
┌────────▼────────┐
│ 意图识别服务 │
│ (Intent Router) │
└────────┬────────┘
│
┌────┴────┐
│ │
┌───▼───┐ ┌──▼──────┐
│ FAQ │ │ 人工客服 │
│ 问答 │ │ 转接 │
└───┬───┘ └─────────┘
│
┌───▼──────────┐
│ RAG 引擎 │
│ (知识库检索) │
└───┬──────────┘
│
┌───▼──────────┐
│ 对话管理 │
│ (上下文/记忆) │
└───┬──────────┘
│
┌───▼──────────┐
│ LLM 生成 │
│ (Spring AI/ │
│ LangChain4j)│
└──────────────┘
15.2 核心代码实现
15.2.1 意图识别器
java
/**
 * Entry point of the customer-service bot: classifies the user's intent with an LLM and
 * dispatches to the matching backend service.
 */
@Service
public class CustomerServiceIntentRouter {
    private final ChatClient classifierClient;
    private final FaqService faqService;
    private final OrderQueryService orderQueryService;
    private final ComplaintService complaintService;
    private final HumanHandoffService humanHandoffService;

    /**
     * Classifies the message and produces the response for the detected intent.
     *
     * @param userId              id of the requesting user
     * @param message             current user message
     * @param conversationHistory prior conversation, passed to the classifier
     * @return the response; unknown intents and unanswered FAQ lookups get a fallback answer
     *         with a suggestion to hand off to a human
     */
    public CustomerServiceResponse handleRequest(
        String userId,
        String message,
        String conversationHistory
    ) {
        String intent = classifyIntent(message, conversationHistory);
        CustomerServiceResponse response = new CustomerServiceResponse();
        response.setIntent(intent);

        switch (intent) {
            case "FAQ": {
                String faqAnswer = faqService.answer(message);
                if (faqAnswer == null || faqAnswer.isBlank()) {
                    // No FAQ matched: do not send an empty reply, fall back instead.
                    response.setAnswer(generateFallbackResponse(message));
                    response.setSuggestHumanHandoff(true);
                } else {
                    response.setAnswer(faqAnswer);
                    response.setConfidence(0.9);
                }
                break;
            }
            case "ORDER_QUERY":
                response.setAnswer(orderQueryService.query(userId, message));
                response.setConfidence(0.85);
                break;
            case "COMPLAINT":
                response.setAnswer(complaintService.handle(userId, message));
                response.setRequiresHumanFollowup(true);
                response.setConfidence(0.8);
                break;
            case "HUMAN_HANDOFF":
                // The handoff service builds the whole response itself.
                response = humanHandoffService.initiate(userId, message);
                break;
            default:
                // Unknown intent (including GENERAL): generic answer plus handoff suggestion.
                response.setAnswer(generateFallbackResponse(message));
                response.setSuggestHumanHandoff(true);
        }
        return response;
    }

    /**
     * Asks the classifier model for the intent label.
     *
     * @return the upper-cased, trimmed label; {@code "GENERAL"} when the model returns nothing
     */
    private String classifyIntent(String message, String history) {
        String prompt = """
            你是客服意图分类器。将用户请求分类到以下类别之一:
            - FAQ: 常见问题咨询(产品、服务、政策等)
            - ORDER_QUERY: 订单查询(状态、物流、退款等)
            - COMPLAINT: 投诉建议
            - HUMAN_HANDOFF: 明确要求人工客服
            - GENERAL: 其他闲聊
            对话历史:%s
            当前消息:%s
            只返回类别名称。
            """.formatted(history, message);
        String label = classifierClient.prompt(prompt).call().content();
        if (label == null) {
            return "GENERAL";
        }
        // Normalise case so a label such as "faq" still reaches the intended branch.
        return label.trim().toUpperCase(java.util.Locale.ROOT);
    }

    /** Static reply used when the intent is unknown or no answer could be produced. */
    private String generateFallbackResponse(String message) {
        return """
            抱歉,我还没有完全理解您的问题。您可以:
            1. 换一种方式描述您的问题
            2. 询问常见问题(如"如何退货")
            3. 输入"转人工"联系人工客服
            请问还有什么我可以帮您的吗?
            """;
    }
}
15.2.2 FAQ RAG 系统
java
/**
 * FAQ question answering backed by a vector store: FAQ entries are embedded at startup and
 * the closest entries are handed to the chat model as context.
 */
@Service
public class FaqRagService {
    private final VectorStore faqVectorStore;
    private final EmbeddingModel embeddingModel;
    private final ChatClient chatClient;

    /** Embeds every FAQ loaded from the database and stores it with its metadata. */
    @PostConstruct
    public void initFaqKnowledgeBase() {
        for (FaqDocument entry : loadFaqsFromDatabase()) {
            String body = "问题:" + entry.getQuestion() + "\n答案:" + entry.getAnswer();
            Embedding vector = embeddingModel.embed(body).content();
            Map<String, Object> meta = Map.of(
                "faqId", entry.getId(),
                "category", entry.getCategory(),
                "tags", String.join(",", entry.getTags())
            );
            faqVectorStore.add(vector, TextSegment.from(body, meta));
        }
    }

    /**
     * Answers {@code question} from the FAQ knowledge base.
     *
     * @param question user question
     * @return the generated answer, or {@code null} when no FAQ is found or the best match
     *         scores below 0.6
     */
    public String answer(String question) {
        Embedding probe = embeddingModel.embed(question).content();
        List<EmbeddingMatch<TextSegment>> hits = faqVectorStore.findRelevant(probe, 3);

        boolean noUsableHit = hits.isEmpty() || hits.get(0).score() < 0.6;
        if (noUsableHit) {
            return null;
        }

        // One paragraph per hit, each annotated with its similarity score.
        List<String> paragraphs = new ArrayList<>();
        for (EmbeddingMatch<TextSegment> hit : hits) {
            paragraphs.add(hit.embedded().text() + " (相似度:" + hit.score() + ")");
        }
        String knowledge = String.join("\n\n", paragraphs);

        String prompt = """
            基于以下 FAQ 知识回答问题。如果知识中没有答案,请说"抱歉,我的知识库中没有相关信息"。
            知识库:
            %s
            用户问题:%s
            请用友好、专业的语气回答。
            """.formatted(knowledge, question);
        return chatClient.prompt(prompt).call().content();
    }

    /** Loads FAQ entries from the database. Placeholder: currently returns an empty list. */
    private List<FaqDocument> loadFaqsFromDatabase() {
        return new ArrayList<>();
    }
}
第 16 章:企业知识库问答系统
16.1 系统架构
scss
文档源 (PDF/Word/Excel/Wiki)
↓
ETL Pipeline (解析 → 分块 → 清洗)
↓
向量化 (Embedding Model)
↓
向量数据库 (PGVector/Redis)
↓
RAG 查询引擎
↓
LLM 生成回答
↓
用户界面
16.2 完整实现
16.2.1 文档 ETL 流水线
java
/**
 * ETL pipeline that parses uploaded documents, splits them into chunks, embeds the chunks
 * and stores them in the vector store.
 *
 * <p>NOTE(review): {@code documentRepository}, {@code lastSyncTime} and {@code log} are used
 * but not declared in this excerpt; presumably their declarations are elided - confirm.
 */
@Service
public class DocumentEtlPipeline {
    private final DocumentParser documentParser;
    private final TextSplitter textSplitter;
    private final EmbeddingModel embeddingModel;
    private final VectorStore vectorStore;
    private final DocumentMetadataExtractor metadataExtractor;

    /**
     * Ingests one document: parse, extract metadata, split, embed, store.
     *
     * <p>Failures are caught and reported through the returned result rather than thrown, so
     * the surrounding transaction is not rolled back by them.
     *
     * @param file     uploaded document
     * @param category category stored on every chunk
     * @return result whose status is {@code SUCCESS} or {@code FAILED}
     */
    @Transactional
    public IngestionResult ingestDocument(MultipartFile file, String category) {
        IngestionResult result = new IngestionResult();
        try {
            // 1. Parse the document; close the upload stream once parsing is done.
            Document document;
            try (var input = file.getInputStream()) {
                document = documentParser.parse(input);
            }
            result.setOriginalContent(document.content());

            // 2. Extract metadata
            Map<String, Object> metadata = metadataExtractor.extract(file, document);
            metadata.put("category", category);
            metadata.put("uploadTime", System.currentTimeMillis());
            metadata.put("uploadedBy", SecurityContextHolder.getCurrentUser().getName());

            // 3. Split the text into chunks
            List<TextSegment> segments = textSplitter.split(document.content());
            result.setChunkCount(segments.size());

            // 4. Attach the metadata to every chunk
            for (int i = 0; i < segments.size(); i++) {
                TextSegment segment = segments.get(i);
                segment.metadata().putAll(metadata);
                segment.metadata().put("chunkIndex", i);
                segment.metadata().put("totalChunks", segments.size());
            }

            // 5. Embed all chunks in one batch
            List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
            result.setEmbeddingCount(embeddings.size());
            if (embeddings.size() != segments.size()) {
                // Refuse to store anything rather than writing a partial document.
                throw new IllegalStateException("embedding count " + embeddings.size()
                    + " does not match chunk count " + segments.size());
            }

            // 6. Store in the vector database
            for (int i = 0; i < segments.size(); i++) {
                vectorStore.add(embeddings.get(i), segments.get(i));
            }
            result.setStatus("SUCCESS");
            result.setMessage("文档入库成功");
        } catch (Exception e) {
            result.setStatus("FAILED");
            result.setMessage("入库失败:" + e.getMessage());
            log.error("文档入库失败", e);
        }
        return result;
    }

    /**
     * Nightly incremental sync (02:00): removes deleted documents and re-ingests changed ones.
     */
    @Scheduled(cron = "0 0 2 * * ?")
    public void incrementalSync() {
        // Take the watermark before querying so that changes made while this run is in
        // progress are picked up by the next run instead of being skipped.
        LocalDateTime syncStartedAt = LocalDateTime.now();
        List<DocumentUpdate> updates = documentRepository.findUpdatedSince(lastSyncTime);
        for (DocumentUpdate update : updates) {
            // Remove existing chunks first; otherwise a changed document leaves stale
            // duplicates next to its new chunks.
            deleteFromVectorStore(update.getDocumentId());
            if (!update.isDeleted()) {
                ingestDocument(update.getFile(), update.getCategory());
            }
        }
        lastSyncTime = syncStartedAt;
    }

    /** Deletes all chunks of the given document. Placeholder: not implemented yet. */
    private void deleteFromVectorStore(String documentId) {
        // TODO: implement the deletion
    }
}
16.2.2 高级 RAG 查询
java
/**
 * Enterprise RAG query service: query expansion, access-filtered multi-query retrieval,
 * de-duplication, re-ranking and answer generation with citations.
 *
 * <p>NOTE(review): {@code queryLogRepository} is used but not declared in this excerpt;
 * presumably its declaration is elided - confirm.
 */
@Service
public class EnterpriseRagService {
    private final VectorStore vectorStore;
    private final EmbeddingModel embeddingModel;
    private final ChatClient chatClient;
    private final QueryExpansionService queryExpansionService;
    private final RerankingService rerankingService;
    private final AccessControlService accessControlService;

    /**
     * Answers the request from documents in the categories the user may access.
     *
     * @param request carries the user id and the query text
     * @return answer, citations (in the same order as the numbered context snippets), the
     *         query and a timestamp
     */
    public RagResponse query(RagQueryRequest request) {
        String userId = request.getUserId();
        String query = request.getQuery();
        List<String> allowedCategories = accessControlService.getAllowedCategories(userId);

        // The access filter does not depend on the individual query, so build it once.
        // NOTE(review): behaviour of isIn(...) for an empty category list is not visible
        // here - confirm that it matches nothing.
        FilterMetadata filter = MetadataFilterBuilder.metadata("category")
            .isIn(allowedCategories)
            .build();

        // 1. Query expansion
        List<String> expandedQueries = queryExpansionService.expand(query);

        // 2. Retrieve for every expanded query
        List<EmbeddingMatch<TextSegment>> allMatches = new ArrayList<>();
        for (String q : expandedQueries) {
            Embedding queryEmbedding = embeddingModel.embed(q).content();
            allMatches.addAll(vectorStore.findRelevant(queryEmbedding, 10, filter));
        }

        // 3. De-duplicate by chunk text, keeping the first occurrence
        Map<String, EmbeddingMatch<TextSegment>> unique = new LinkedHashMap<>();
        for (EmbeddingMatch<TextSegment> match : allMatches) {
            unique.putIfAbsent(match.embedded().text(), match);
        }

        // 4. Re-rank and keep the top 5
        List<EmbeddingMatch<TextSegment>> reranked =
            rerankingService.rerank(query, new ArrayList<>(unique.values()), 5);

        // 5. Citations, in the same order as the context snippets
        List<Citation> citations = reranked.stream()
            .map(match -> new Citation(
                metadataValue(match.embedded(), "source"),
                metadataValue(match.embedded(), "category"),
                match.score()))
            .collect(Collectors.toList());

        // 6. Context: number each snippet so the [1], [2] markers requested in the prompt
        //    refer to something the model can actually see.
        StringBuilder context = new StringBuilder();
        for (int i = 0; i < reranked.size(); i++) {
            if (i > 0) {
                context.append("\n\n---\n\n");
            }
            context.append('[').append(i + 1).append("] ")
                .append(reranked.get(i).embedded().text());
        }

        // 7. Generate the answer
        String prompt = """
            你是企业知识库助手。基于以下文档片段回答问题。
            要求:
            1. 只基于提供的文档回答
            2. 如果文档中没有答案,明确说明
            3. 引用信息来源(使用 [1], [2] 等标记)
            4. 保持专业和准确
            文档片段:
            %s
            问题:%s
            回答:
            """.formatted(context, query);
        String answer = chatClient.prompt(prompt)
            .advisors(a -> a.param("userId", userId))
            .call()
            .content();

        // 8. Build the response
        RagResponse response = new RagResponse();
        response.setAnswer(answer);
        response.setCitations(citations);
        response.setQuery(query);
        response.setTimestamp(System.currentTimeMillis());

        // 9. Record the query for later analysis
        logQuery(userId, query, answer, citations);
        return response;
    }

    /** Returns the metadata value as text, or an empty string when the key is absent. */
    private String metadataValue(TextSegment segment, String key) {
        Object value = segment.metadata().get(key);
        return value == null ? "" : value.toString();
    }

    /** Persists the query, answer and citations for analysis and tuning. */
    private void logQuery(String userId, String query, String answer, List<Citation> citations) {
        QueryLog entry = new QueryLog();
        entry.setUserId(userId);
        entry.setQuery(query);
        entry.setAnswer(answer);
        entry.setCitations(citations);
        entry.setTimestamp(LocalDateTime.now());
        queryLogRepository.save(entry);
    }
}
第 17 章:多 Agent 协作系统
17.1 场景:自动化软件开发助手
markdown
用户需求
↓
需求分析 Agent → 技术选型 Agent
↓ ↓
架构设计 Agent ←────────┘
↓
代码生成 Agent → 代码审查 Agent
↓
测试生成 Agent
↓
文档生成 Agent
↓
最终交付物
17.2 实现
java
/**
 * Orchestrates a chain of LLM "agents" that turn a project requirement into requirements
 * analysis, technology stack, architecture, code, review, tests and documentation.
 *
 * <p>NOTE(review): list results are requested with {@code TypeReference}; verify that the
 * {@code entity(...)} overload of the ChatClient in use accepts this type.
 */
@Service
public class SoftwareDevelopmentAgentOrchestrator {
    private final ChatClient requirementsAgent;
    private final ChatClient architectureAgent;
    private final ChatClient codingAgent;
    private final ChatClient reviewAgent;
    private final ChatClient testingAgent;
    private final ChatClient documentationAgent;

    /**
     * Runs the full pipeline for one project.
     *
     * <p>If the review does not pass, the code is fixed once based on the review comments;
     * the fixed code is not reviewed again.
     *
     * @param requirement project requirement; its description seeds the first step
     * @return all artifacts produced along the way
     */
    public DevelopmentDeliverable developProject(ProjectRequirement requirement) {
        // Step 1: requirements analysis
        RequirementsAnalysis requirements = requirementsAgent.prompt("""
            分析以下项目需求:
            %s
            输出:
            1. 功能列表
            2. 非功能需求
            3. 用户故事
            4. 验收标准
            以 JSON 格式返回。
            """.formatted(requirement.getDescription()))
            .call()
            .entity(RequirementsAnalysis.class);

        // Step 2: technology selection
        TechnologyStack techStack = requirementsAgent.prompt("""
            基于以下需求推荐技术栈:
            %s
            考虑因素:
            1. 项目规模
            2. 团队技能
            3. 性能要求
            4. 成本预算
            输出:前端、后端、数据库、部署方案。
            """.formatted(requirements))
            .call()
            .entity(TechnologyStack.class);

        // Step 3: architecture design. Mapped to the declared type via entity(...);
        // content() yields text and cannot be assigned to ArchitectureDesign.
        ArchitectureDesign architecture = architectureAgent.prompt("""
            设计系统架构:
            需求:%s
            技术栈:%s
            输出:
            1. 系统组件图
            2. 数据流图
            3. API 设计
            4. 数据库 schema
            详细描述。
            """.formatted(requirements, techStack))
            .call()
            .entity(ArchitectureDesign.class);

        // Step 4: code generation
        List<CodeFile> codeFiles = codingAgent.prompt("""
            根据以下架构设计生成代码:
            %s
            技术栈:%s
            生成完整的源代码文件,每个文件包含:
            - 文件路径
            - 文件内容
            以 JSON 数组格式返回。
            """.formatted(architecture, techStack))
            .call()
            .entity(new TypeReference<List<CodeFile>>(){});

        // Step 5: code review
        CodeReview review = reviewAgent.prompt("""
            审查以下代码:
            %s
            检查:
            1. 代码规范
            2. 潜在 bug
            3. 安全漏洞
            4. 性能问题
            5. 改进建议
            输出审查报告。
            """.formatted(formatCodeFiles(codeFiles)))
            .call()
            .entity(CodeReview.class);

        // One fix iteration when the review fails
        if (!review.isPassed()) {
            codeFiles = fixCodeBasedOnReview(codeFiles, review);
        }

        // Step 6: test generation. The literal percent sign is escaped as %% because the
        // text is a format string; a bare "%+" makes formatted() throw at runtime.
        List<TestFile> testFiles = testingAgent.prompt("""
            为以下代码生成单元测试:
            %s
            要求:
            1. 覆盖率达到 80%%+
            2. 包含边界测试
            3. 使用 JUnit 5
            输出测试文件。
            """.formatted(formatCodeFiles(codeFiles)))
            .call()
            .entity(new TypeReference<List<TestFile>>(){});

        // Step 7: documentation
        Documentation docs = documentationAgent.prompt("""
            生成项目文档:
            需求:%s
            架构:%s
            代码:%s
            输出:
            1. README.md
            2. API 文档
            3. 部署指南
            4. 用户手册
            """.formatted(requirements, architecture, formatCodeFiles(codeFiles)))
            .call()
            .entity(Documentation.class);

        // Assemble the deliverable
        DevelopmentDeliverable deliverable = new DevelopmentDeliverable();
        deliverable.setRequirements(requirements);
        deliverable.setArchitecture(architecture);
        deliverable.setCodeFiles(codeFiles);
        deliverable.setTestFiles(testFiles);
        deliverable.setDocumentation(docs);
        deliverable.setReviewReport(review);
        return deliverable;
    }

    /**
     * Asks the coding agent to fix the code according to the review comments.
     *
     * @return the regenerated code files
     */
    private List<CodeFile> fixCodeBasedOnReview(List<CodeFile> codeFiles, CodeReview review) {
        String fixPrompt = """
            原代码:%s
            审查意见:%s
            请修复所有指出的问题。
            """.formatted(formatCodeFiles(codeFiles), review.getComments());
        return codingAgent.prompt(fixPrompt)
            .call()
            .entity(new TypeReference<List<CodeFile>>(){});
    }

    /** Renders the files as path headers followed by fenced content, separated by blank lines. */
    private String formatCodeFiles(List<CodeFile> files) {
        return files.stream()
            .map(f -> "文件:" + f.getPath() + "\n```\n" + f.getContent() + "\n```")
            .collect(Collectors.joining("\n\n"));
    }
}