Spring AI多模型架构:OpenAI/DeepSeek/MiniMax一键切换实战
一、前言
1.1 为什么需要多模型支持
在AI应用开发中,单一模型往往无法满足所有场景需求:
| 场景 | 推荐模型 | 原因 |
|---|---|---|
| 通用对话 | GPT-4 / DeepSeek-V3 | 能力强、理解准确 |
| 代码生成 | GPT-4 / Claude 3.5 | 代码质量高 |
| 中文优化 | DeepSeek / MiniMax | 中文理解更好 |
| 成本敏感 | GPT-3.5 / DeepSeek-V2 | 价格便宜 |
| 离线部署 | Ollama本地模型 | 数据隐私保护 |
多模型架构的核心价值:
- 成本优化:不同场景选择性价比最优的模型
- 能力互补:利用各模型优势处理特定任务
- 风险分散:避免单一供应商依赖
- 用户体验:根据用户套餐提供差异化服务
1.2 Spring AI 简介
Spring AI 是 Spring 官方推出的 AI 应用开发框架,目标是简化 Java 开发者集成 LLM 的过程:
Spring AI 核心特性:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1. 模型无关的抽象层
- ChatClient:统一聊天接口
- EmbeddingModel:统一向量接口(1.0.0-M1 之前名为 EmbeddingClient)
- ImageModel:统一图像接口(1.0.0-M1 之前名为 ImageClient)
2. Prompt 管理
- PromptTemplate:模板引擎
- 支持变量替换和条件渲染
3. 向量存储
- PgVector、Redis、Chroma等
- 统一的Document和VectorStore接口
4. 函数调用
- 支持模型调用本地Java方法
- 实现Agent能力
5. 自动配置
- Spring Boot Starter支持
- 声明式配置
二、Spring AI 核心概念
2.1 ChatClient 与 ChatModel
java
/**
* Spring AI core interfaces: how the application-facing ChatClient relates to the model-facing ChatModel.
* NOTE(review): this is a simplified sketch, not the exact Spring AI declaration — check the method set
* against the Spring AI version in use.
*/
// ChatClient - high-level, application-facing API
public interface ChatClient {
ChatClientRequestSpec prompt(); // Starts a fluent prompt builder (system/user/options/call/stream)
String call(String message); // One-shot synchronous call with a plain user message
Flux<String> stream(String message); // One-shot streaming call; emits text chunks as they arrive
}
// ChatModel - low-level, model-facing API: takes a full Prompt, returns the raw ChatResponse
public interface ChatModel {
ChatResponse call(Prompt prompt); // Synchronous call; returns the complete response
Flux<ChatResponse> stream(Prompt prompt); // Streaming call; emits partial responses
}
// In practice ChatClient keeps call sites short
@Autowired
private ChatClient chatClient;

/**
 * Sends {@code message} as a single user turn and returns the reply text.
 *
 * @param message user message
 * @return the model's reply content
 */
public String chat(String message) {
    var response = chatClient.prompt()
            .user(message)
            .call();
    return response.content();
}
2.2 Prompt 模板系统
java
/*
 * Prompt template examples.
 */

// 1. Simple template
/**
 * Asks the model for five interview questions derived from a resume.
 *
 * @param jobTitle role of the interviewer persona; concatenated directly into the system text
 * @param resume resume text, bound to the {resume} placeholder of the user template
 * @return the model's reply content
 */
public String simplePrompt(String jobTitle, String resume) {
    String interviewerPersona = "你是一位专业的" + jobTitle + "面试官";
    String questionTemplate = """
            请根据以下简历内容,生成5个面试问题:
            简历内容:
            {resume}
            """;
    return chatClient.prompt()
            .system(interviewerPersona)
            .user(spec -> spec.text(questionTemplate).param("resume", resume))
            .call()
            .content();
}
// 2. Structured prompt (recommended)
/**
 * Generates interview questions for a candidate profile and maps the model's JSON reply
 * onto {@link InterviewResponse}.
 *
 * @param request job type, experience and interview round used to fill the user template
 * @return the parsed response entity
 */
public InterviewResponse structuredPrompt(InterviewRequest request) {
    String profileTemplate = """
            岗位类型:{jobType}
            工作经验:{experience}
            面试轮次:{round}
            请生成适合该候选人的面试问题。
            """;
    ChatOptions generationOptions = ChatOptions.builder()
            .temperature(0.7)
            .maxTokens(2000)
            .build();
    return chatClient.prompt()
            .system(systemPromptProvider.getInterviewSystemPrompt())
            .user(spec -> spec.text(profileTemplate)
                    .param("jobType", request.getJobType())
                    .param("experience", request.getExperience())
                    .param("round", request.getRound()))
            .options(generationOptions)
            .call()
            .entity(InterviewResponse.class); // converts the JSON reply into the entity
}
// 3. Multi-turn conversation
/**
 * Continues a conversation: replays {@code history} with each turn's original role, then sends
 * {@code newMessage} as the latest user turn.
 *
 * <p>History goes through {@code messages(...)} because calling {@code user(...)} or
 * {@code system(...)} repeatedly on the request spec does not append turns — each call replaces
 * the single user/system text — and replaying assistant replies as system text would
 * misattribute them.
 *
 * @param history prior turns in chronological order; {@code null} is treated as empty
 * @param newMessage the latest user message
 * @return the model's reply content
 */
public String multiTurnChat(List<Message> history, String newMessage) {
    ChatClient.ChatClientRequestSpec spec = chatClient.prompt();
    if (history != null && !history.isEmpty()) {
        spec = spec.messages(history);
    }
    return spec.user(newMessage)
            .call()
            .content();
}
2.3 流式响应处理
java
/**
* 流式响应处理 - 实现打字机效果
*/
@RestController
@RequestMapping("/api/ai")
@Slf4j
public class StreamingController {
@Autowired
private ChatClient chatClient;
/**
* Server-Sent Events 流式输出
*/
@GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<String>> streamChat(@RequestParam String message) {
return chatClient.prompt()
.user(message)
.stream()
.content()
.map(content -> ServerSentEvent.builder(content).build())
.onErrorResume(e -> {
log.error("Stream error: ", e);
return Flux.just(ServerSentEvent.builder("Error: " + e.getMessage()).build());
});
}
/**
* WebFlux 流式输出(前端使用EventSource接收)
*/
@PostMapping(value = "/interview/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> interviewStream(@RequestBody InterviewRequest request) {
String prompt = buildInterviewPrompt(request);
return chatClient.prompt()
.system("你是一位专业的技术面试官")
.user(prompt)
.options(ChatOptions.builder()
.temperature(0.8)
.maxTokens(4000)
.build())
.stream()
.content()
.doOnNext(chunk -> log.debug("Received chunk: {}", chunk))
.doOnComplete(() -> log.info("Stream completed"));
}
}
三、多模型配置管理
3.1 ModelProvider 枚举定义
java
/**
 * Catalog of chat models the application can route to.
 *
 * <p>Each constant carries a provider code, the model identifier sent to that provider, a display
 * label, a cloud/local flag and a list price. The catalog spans six provider codes (openai,
 * deepseek, minimax, anthropic, google, ollama) and 13 models.
 */
@Getter
@AllArgsConstructor
public enum ModelProvider {
    // OpenAI
    OPENAI_GPT4("openai", "gpt-4-turbo-preview", "OpenAI GPT-4", true, 50),
    OPENAI_GPT35("openai", "gpt-3.5-turbo", "OpenAI GPT-3.5", true, 10),
    // DeepSeek
    DEEPSEEK_V3("deepseek", "deepseek-chat", "DeepSeek V3", true, 8),
    DEEPSEEK_CODER("deepseek", "deepseek-coder", "DeepSeek Coder", true, 8),
    // MiniMax
    MINIMAX_ABAB6("minimax", "abab6-chat", "MiniMax abab6", true, 15),
    MINIMAX_ABAB5("minimax", "abab5.5-chat", "MiniMax abab5.5", true, 10),
    // Anthropic Claude
    CLAUDE_3_OPUS("anthropic", "claude-3-opus-20240229", "Claude 3 Opus", true, 80),
    CLAUDE_3_SONNET("anthropic", "claude-3-sonnet-20240229", "Claude 3 Sonnet", true, 25),
    // Google Gemini
    GEMINI_PRO("google", "gemini-pro", "Gemini Pro", true, 12),
    GEMINI_ULTRA("google", "gemini-ultra", "Gemini Ultra", true, 100),
    // Local models served by Ollama
    OLLAMA_LLAMA3("ollama", "llama3", "Llama 3 (Local)", false, 0),
    OLLAMA_QWEN("ollama", "qwen:14b", "Qwen 14B (Local)", false, 0),
    OLLAMA_MISTRAL("ollama", "mistral", "Mistral (Local)", false, 0);

    /** Provider code, e.g. {@code "openai"}. */
    private final String provider;
    /** Model identifier sent to the provider API. */
    private final String modelName;
    /** Human-readable label. */
    private final String displayName;
    /** {@code true} for cloud-hosted models, {@code false} for local ones. */
    private final boolean isCloud;
    /** Price per 1,000 tokens, in fen (分). */
    private final int pricePer1K;

    /**
     * Finds the constant whose {@link #modelName} equals {@code modelName}.
     *
     * @return the first match, or empty when none matches (including {@code null})
     */
    public static Optional<ModelProvider> fromModelName(String modelName) {
        return Arrays.stream(values())
                .filter(candidate -> candidate.modelName.equals(modelName))
                .findFirst();
    }

    /**
     * Lists every model served by the given provider code, in declaration order.
     */
    public static List<ModelProvider> byProvider(String provider) {
        return collectWhere(candidate -> candidate.provider.equals(provider));
    }

    /**
     * Lists every model whose {@link #isEnabled()} returns {@code true}, in declaration order.
     */
    public static List<ModelProvider> availableModels() {
        return collectWhere(ModelProvider::isEnabled);
    }

    /**
     * Whether this model is enabled. Currently always {@code true}; intended to be backed by a
     * configuration source.
     */
    public boolean isEnabled() {
        return true;
    }

    /** Collects, in declaration order, the constants satisfying {@code condition}. */
    private static List<ModelProvider> collectWhere(java.util.function.Predicate<ModelProvider> condition) {
        return Arrays.stream(values())
                .filter(condition)
                .collect(Collectors.toList());
    }
}
3.2 多模型 YAML 配置
yaml
# application.yml - AI model configuration
spring:
  ai:
    # OpenAI
    openai:
      api-key: ${OPENAI_API_KEY}
      base-url: https://api.openai.com # may be replaced with a proxy URL
      chat:
        options:
          model: gpt-3.5-turbo
          temperature: 0.7
          max-tokens: 2000
    # DeepSeek
    deepseek:
      api-key: ${DEEPSEEK_API_KEY}
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
          temperature: 0.7
          max-tokens: 4000
    # Ollama (local models)
    ollama:
      base-url: http://localhost:11434
      chat:
        options:
          model: llama3
          temperature: 0.8
# Custom AI settings, bound to AiModelProperties (prefix "ai.model")
ai:
  model:
    # Default model (a ModelProvider constant name). Relaxed binding maps "default-model"
    # onto the AiModelProperties.defaultModel field.
    default-model: OPENAI_GPT35
    # Scenario -> model routing
    scenarios:
      interview: OPENAI_GPT4 # interviews use GPT-4
      resume: DEEPSEEK_V3 # resume optimisation uses DeepSeek
      code: DEEPSEEK_CODER # code generation uses DeepSeek Coder
      chat: OPENAI_GPT35 # general chat uses GPT-3.5
    # Models each subscription plan may use
    plan-limits:
      FREE:
        - OPENAI_GPT35
        - DEEPSEEK_V3
      PREMIUM:
        - OPENAI_GPT35
        - OPENAI_GPT4
        - DEEPSEEK_V3
        - DEEPSEEK_CODER
        - MINIMAX_ABAB5
      ENTERPRISE:
        - all # every model
    # Per-scenario generation parameters
    parameters:
      interview:
        temperature: 0.7
        max-tokens: 4000
        top-p: 0.9
      resume:
        temperature: 0.5
        max-tokens: 3000
      code:
        temperature: 0.3
        max-tokens: 4000
3.3 模型配置属性类
java
/**
 * Binds the custom {@code ai.model.*} settings: default model, scenario-to-model routing,
 * per-plan model allow-lists and per-scenario generation parameters.
 *
 * <p>Model values are {@link ModelProvider} constant names such as {@code OPENAI_GPT4}.
 */
@Data
@Configuration
@ConfigurationProperties(prefix = "ai.model")
public class AiModelProperties {

    /** Plan whose allow-list is used when the requested plan has no entry. */
    private static final String FALLBACK_PLAN = "FREE";
    /** Allow-list entry meaning "every model". */
    private static final String ALL_MODELS = "all";

    /** Model used for scenarios without an explicit mapping (key {@code ai.model.default-model}). */
    private String defaultModel;
    /** Scenario key -> ModelProvider constant name. */
    private Map<String, String> scenarios = new HashMap<>();
    /** Plan name -> allowed ModelProvider constant names (or the single entry "all"). */
    private Map<String, List<String>> planLimits = new HashMap<>();
    /** Scenario key -> generation parameters. */
    private Map<String, ModelParameters> parameters = new HashMap<>();

    /** Optional generation parameters; {@code null} means "not configured". */
    @Data
    public static class ModelParameters {
        private Double temperature;
        private Integer maxTokens;
        private Double topP;
        private Double frequencyPenalty;
        private Double presencePenalty;
    }

    /**
     * Alias so the key {@code ai.model.default} also binds. "default" is a Java keyword and cannot
     * be a field name, so without this setter that key would be silently ignored and
     * {@code defaultModel} would stay {@code null}.
     *
     * @param defaultModel ModelProvider constant name
     */
    public void setDefault(String defaultModel) {
        this.defaultModel = defaultModel;
    }

    /**
     * Resolves the model for {@code scenario}, falling back to {@link #defaultModel}.
     *
     * @throws IllegalStateException if the scenario is unmapped and no default model is configured
     * @throws IllegalArgumentException if the configured name is not a ModelProvider constant
     */
    public ModelProvider getModelForScenario(String scenario) {
        String modelName = scenarios.getOrDefault(scenario, defaultModel);
        if (modelName == null) {
            throw new IllegalStateException("No model configured for scenario '" + scenario
                    + "' and no default model is set (ai.model.default-model)");
        }
        return ModelProvider.valueOf(modelName);
    }

    /**
     * Returns the models {@code plan} may use. Unknown plans fall back to the FREE allow-list;
     * if neither is configured, the result is empty.
     *
     * @throws IllegalArgumentException if an allow-list entry is not a ModelProvider constant
     */
    public List<ModelProvider> getModelsForPlan(String plan) {
        List<String> modelNames = planLimits.getOrDefault(plan, planLimits.get(FALLBACK_PLAN));
        if (modelNames == null) {
            return List.of();
        }
        if (modelNames.contains(ALL_MODELS)) {
            return Arrays.asList(ModelProvider.values());
        }
        return modelNames.stream()
                .map(ModelProvider::valueOf)
                .collect(Collectors.toList());
    }

    /**
     * Returns the parameters configured for {@code scenario}, or an all-null instance.
     */
    public ModelParameters getParameters(String scenario) {
        return parameters.getOrDefault(scenario, new ModelParameters());
    }
}
四、动态模型切换机制
4.1 ModelSwitchService 实现
java
/**
* 模型切换服务
* 实现运行时动态切换模型,无需重启服务
*/
@Service
@Slf4j
public class ModelSwitchService {
@Autowired
private AiModelProperties modelProperties;
@Autowired
private ApplicationContext applicationContext;
// 缓存ChatClient实例
private final Map<ModelProvider, ChatClient> clientCache = new ConcurrentHashMap<>();
// 当前默认模型
private volatile ModelProvider currentDefault = ModelProvider.OPENAI_GPT35;
/**
* 获取指定模型的ChatClient
*/
public ChatClient getChatClient(ModelProvider provider) {
return clientCache.computeIfAbsent(provider, this::createChatClient);
}
/**
* 获取当前默认模型的ChatClient
*/
public ChatClient getDefaultChatClient() {
return getChatClient(currentDefault);
}
/**
* 根据场景获取ChatClient
*/
public ChatClient getChatClientForScenario(String scenario) {
ModelProvider provider = modelProperties.getModelForScenario(scenario);
return getChatClient(provider);
}
/**
* 创建ChatClient实例
*/
private ChatClient createChatClient(ModelProvider provider) {
log.info("Creating ChatClient for provider: {}", provider);
ChatModel chatModel = createChatModel(provider);
return ChatClient.builder(chatModel)
.defaultSystem(systemText -> systemText.text(
"你是AI面试助手,帮助用户准备技术面试。"))
.defaultOptions(ChatOptions.builder()
.temperature(0.7)
.maxTokens(2000)
.build())
.build();
}
/**
* 创建ChatModel实例
*/
private ChatModel createChatModel(ModelProvider provider) {
return switch (provider.getProvider()) {
case "openai" -> createOpenAiModel(provider);
case "deepseek" -> createDeepSeekModel(provider);
case "minimax" -> createMiniMaxModel(provider);
case "ollama" -> createOllamaModel(provider);
default -> throw new IllegalArgumentException("Unknown provider: " + provider);
};
}
private ChatModel createOpenAiModel(ModelProvider provider) {
OpenAiApi openAiApi = new OpenAiApi(
System.getenv("OPENAI_API_KEY"),
"https://api.openai.com"
);
OpenAiChatModel chatModel = new OpenAiChatModel(
openAiApi,
OpenAiChatOptions.builder()
.model(provider.getModelName())
.temperature(0.7)
.build()
);
return chatModel;
}
private ChatModel createDeepSeekModel(ModelProvider provider) {
// DeepSeek使用OpenAI兼容接口
OpenAiApi deepSeekApi = new OpenAiApi(
System.getenv("DEEPSEEK_API_KEY"),
"https://api.deepseek.com"
);
return new OpenAiChatModel(
deepSeekApi,
OpenAiChatOptions.builder()
.model(provider.getModelName())
.temperature(0.7)
.build()
);
}
private ChatModel createMiniMaxModel(ModelProvider provider) {
// MiniMax需要自定义实现
return new MiniMaxChatModel(
System.getenv("MINIMAX_API_KEY"),
provider.getModelName()
);
}
private ChatModel createOllamaModel(ModelProvider provider) {
OllamaApi ollamaApi = new OllamaApi("http://localhost:11434");
return OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder()
.model(provider.getModelName())
.temperature(0.8)
.build())
.build();
}
/**
* 切换默认模型(热切换)
*/
public void switchDefaultModel(ModelProvider provider) {
log.info("Switching default model from {} to {}", currentDefault, provider);
this.currentDefault = provider;
}
/**
* 获取当前默认模型
*/
public ModelProvider getCurrentDefault() {
return currentDefault;
}
/**
* 获取所有可用模型
*/
public List<ModelProvider> getAvailableModels() {
return Arrays.stream(ModelProvider.values())
.filter(this::isModelAvailable)
.collect(Collectors.toList());
}
/**
* 检查模型是否可用
*/
public boolean isModelAvailable(ModelProvider provider) {
try {
ChatClient client = getChatClient(provider);
// 发送测试请求
String response = client.prompt()
.user("Hi")
.call()
.content();
return StringUtils.isNotBlank(response);
} catch (Exception e) {
log.warn("Model {} is not available: {}", provider, e.getMessage());
return false;
}
}
/**
* 清除模型缓存(用于配置更新后)
*/
public void clearCache() {
clientCache.clear();
log.info("ChatClient cache cleared");
}
}
4.2 模型选择策略
java
/**
 * Model selection strategy: picks a model from the preferred one, the user's plan allow-list,
 * model availability and a cost heuristic.
 */
@Service
@Slf4j
public class ModelSelectionStrategy {

    @Autowired
    private ModelSwitchService modelSwitchService;

    @Autowired
    private TokenUsageService tokenUsageService;

    /** Source of the per-plan model allow-lists (ai.model.plan-limits). */
    @Autowired
    private AiModelProperties modelProperties;

    /**
     * Chooses the model for a request.
     *
     * <ol>
     *   <li>If the plan may not use the preferred model, returns the best permitted model for the scenario.</li>
     *   <li>If the preferred model does not answer a probe, returns its fallback.</li>
     *   <li>For cost-sensitive simple tasks, returns a cheaper reachable alternative if one exists.</li>
     * </ol>
     */
    public ModelProvider selectModel(ModelSelectionContext context) {
        if (!hasPermission(context.getUserPlan(), context.getPreferredModel())) {
            log.warn("User plan {} does not have access to model {}",
                    context.getUserPlan(), context.getPreferredModel());
            return getBestAvailableModel(context.getUserPlan(), context.getScenario());
        }
        if (!modelSwitchService.isModelAvailable(context.getPreferredModel())) {
            log.warn("Preferred model {} is not available, falling back",
                    context.getPreferredModel());
            return getFallbackModel(context.getPreferredModel());
        }
        if (context.isCostSensitive() && isSimpleTask(context)) {
            ModelProvider cheaperModel = getCheaperAlternative(context.getPreferredModel());
            if (cheaperModel != null && modelSwitchService.isModelAvailable(cheaperModel)) {
                log.info("Using cheaper model {} instead of {}",
                        cheaperModel, context.getPreferredModel());
                return cheaperModel;
            }
        }
        return context.getPreferredModel();
    }

    /**
     * Whether {@code plan} may use {@code model}, according to the ai.model.plan-limits allow-list.
     *
     * <p>Only the plan is checked here; availability is probed separately by the caller. This
     * avoids sending a probe request to every model on each permission check.
     */
    private boolean hasPermission(String plan, ModelProvider model) {
        return model != null && modelProperties.getModelsForPlan(plan).contains(model);
    }

    /**
     * Returns the first scenario candidate that the plan permits and that answers a probe;
     * OPENAI_GPT35 is the final fallback.
     */
    private ModelProvider getBestAvailableModel(String plan, String scenario) {
        // switch on a null String throws NPE, so a missing scenario uses the default list
        String scenarioKey = scenario == null ? "" : scenario;
        List<ModelProvider> candidates = switch (scenarioKey) {
            case "interview" -> List.of(
                    ModelProvider.OPENAI_GPT4,
                    ModelProvider.DEEPSEEK_V3,
                    ModelProvider.OPENAI_GPT35
            );
            case "code" -> List.of(
                    ModelProvider.DEEPSEEK_CODER,
                    ModelProvider.OPENAI_GPT4,
                    ModelProvider.CLAUDE_3_SONNET
            );
            default -> List.of(
                    ModelProvider.OPENAI_GPT35,
                    ModelProvider.DEEPSEEK_V3
            );
        };
        for (ModelProvider model : candidates) {
            if (hasPermission(plan, model) && modelSwitchService.isModelAvailable(model)) {
                return model;
            }
        }
        return ModelProvider.OPENAI_GPT35;
    }

    /** Static fallback chain used when the preferred model is unreachable. */
    private ModelProvider getFallbackModel(ModelProvider original) {
        return switch (original) {
            case OPENAI_GPT4 -> ModelProvider.DEEPSEEK_V3;
            case DEEPSEEK_V3 -> ModelProvider.OPENAI_GPT35;
            case CLAUDE_3_OPUS -> ModelProvider.CLAUDE_3_SONNET;
            default -> ModelProvider.OPENAI_GPT35;
        };
    }

    /**
     * Heuristic: a prompt under 500 characters that mentions none of 代码/算法/设计 is "simple".
     * A null prompt is treated like an empty one.
     */
    private boolean isSimpleTask(ModelSelectionContext context) {
        String prompt = context.getPrompt();
        if (prompt == null) {
            return true;
        }
        return prompt.length() < 500 &&
                !prompt.contains("代码") &&
                !prompt.contains("算法") &&
                !prompt.contains("设计");
    }

    /** Returns a cheaper model of similar kind, or null when none is defined. */
    private ModelProvider getCheaperAlternative(ModelProvider original) {
        return switch (original) {
            case OPENAI_GPT4 -> ModelProvider.OPENAI_GPT35;
            case CLAUDE_3_OPUS -> ModelProvider.CLAUDE_3_SONNET;
            case DEEPSEEK_CODER -> ModelProvider.DEEPSEEK_V3;
            default -> null;
        };
    }
}
/**
* Inputs to ModelSelectionStrategy.selectModel: the model to start from plus the plan,
* scenario and prompt used to decide whether to downgrade or fall back.
*/
@Data
@Builder
public class ModelSelectionContext {
private ModelProvider preferredModel; // Model selection starts from; returned unchanged if permitted, reachable and not downgraded
private String userPlan; // Plan name used as the plan-limits key, e.g. "FREE" / "PREMIUM" / "ENTERPRISE"
private String scenario; // Scenario key, e.g. "interview" / "code"; picks the candidate list when falling back
private String prompt; // Prompt text; its length and keywords drive the "simple task" heuristic
private boolean costSensitive; // When true, simple tasks may be routed to a cheaper model
}
五、统一 LLM 调用封装
5.1 LlmService 抽象层
java
/**
* 统一LLM调用服务
* 封装所有模型调用细节,对外提供统一接口
*/
@Service
@Slf4j
public class LlmService {
@Autowired
private ModelSwitchService modelSwitchService;
@Autowired
private ModelSelectionStrategy modelSelectionStrategy;
@Autowired
private TokenUsageService tokenUsageService;
@Autowired
private PromptTemplateService promptTemplateService;
/**
* 同步调用 - 简单场景
*/
public String chat(String message) {
return chat(message, modelSwitchService.getCurrentDefault());
}
/**
* 同步调用 - 指定模型
*/
public String chat(String message, ModelProvider model) {
ChatClient client = modelSwitchService.getChatClient(model);
long startTime = System.currentTimeMillis();
try {
String response = client.prompt()
.user(message)
.call()
.content();
// 记录Token使用
recordUsage(model, message, response, System.currentTimeMillis() - startTime);
return response;
} catch (Exception e) {
log.error("Chat failed with model {}: {}", model, e.getMessage());
throw new BusinessException(ResultCode.AI_MODEL_ERROR, e.getMessage());
}
}
/**
* 场景化调用 - 自动选择模型
*/
public String chatWithScenario(String message, String scenario, String userPlan) {
// 选择最优模型
ModelSelectionContext context = ModelSelectionContext.builder()
.preferredModel(modelSwitchService.getCurrentDefault())
.userPlan(userPlan)
.scenario(scenario)
.prompt(message)
.costSensitive(!"ENTERPRISE".equals(userPlan))
.build();
ModelProvider selectedModel = modelSelectionStrategy.selectModel(context);
log.info("Selected model {} for scenario {}", selectedModel, scenario);
// 获取场景参数
AiModelProperties.ModelParameters params =
promptTemplateService.getParametersForScenario(scenario);
ChatClient client = modelSwitchService.getChatClient(selectedModel);
long startTime = System.currentTimeMillis();
String response = client.prompt()
.user(message)
.options(ChatOptions.builder()
.temperature(params.getTemperature())
.maxTokens(params.getMaxTokens())
.build())
.call()
.content();
recordUsage(selectedModel, message, response, System.currentTimeMillis() - startTime);
return response;
}
/**
* 结构化输出 - 自动解析JSON
*/
public <T> T chatForEntity(String message, Class<T> responseType) {
return chatForEntity(message, responseType,
modelSwitchService.getCurrentDefault());
}
public <T> T chatForEntity(String message, Class<T> responseType,
ModelProvider model) {
ChatClient client = modelSwitchService.getChatClient(model);
try {
return client.prompt()
.user(message)
.call()
.entity(responseType);
} catch (Exception e) {
log.error("Entity parsing failed: {}", e.getMessage());
throw new BusinessException(ResultCode.AI_RESPONSE_ERROR, e.getMessage());
}
}
/**
* 流式调用
*/
public Flux<String> streamChat(String message) {
return streamChat(message, modelSwitchService.getCurrentDefault());
}
public Flux<String> streamChat(String message, ModelProvider model) {
ChatClient client = modelSwitchService.getChatClient(model);
return client.prompt()
.user(message)
.stream()
.content()
.doOnNext(chunk -> log.debug("Stream chunk: {}", chunk))
.doOnError(e -> log.error("Stream error: ", e));
}
/**
* 多轮对话
*/
public String chatWithHistory(List<ChatMessage> history, String newMessage) {
return chatWithHistory(history, newMessage,
modelSwitchService.getCurrentDefault());
}
public String chatWithHistory(List<ChatMessage> history, String newMessage,
ModelProvider model) {
ChatClient.ChatClientRequestSpec spec =
modelSwitchService.getChatClient(model).prompt();
// 添加系统提示
spec.system("你是一个专业的AI面试助手");
// 添加历史消息
for (ChatMessage msg : history) {
switch (msg.getRole()) {
case USER -> spec.user(msg.getContent());
case ASSISTANT -> spec.system(msg.getContent());
}
}
return spec.user(newMessage)
.call()
.content();
}
/**
* 带重试的调用
*/
public String chatWithRetry(String message, int maxRetries) {
int attempts = 0;
Exception lastException = null;
while (attempts < maxRetries) {
try {
return chat(message);
} catch (Exception e) {
lastException = e;
attempts++;
log.warn("Chat attempt {} failed: {}", attempts, e.getMessage());
if (attempts < maxRetries) {
// 指数退避
try {
Thread.sleep((long) Math.pow(2, attempts) * 1000);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
}
}
}
}
throw new BusinessException(ResultCode.AI_MODEL_ERROR,
"Failed after " + maxRetries + " attempts: " + lastException.getMessage());
}
/**
* 记录Token使用
*/
private void recordUsage(ModelProvider model, String prompt, String response,
long latency) {
try {
// 估算Token数(实际应从API响应中获取)
int promptTokens = estimateTokens(prompt);
int completionTokens = estimateTokens(response);
int totalTokens = promptTokens + completionTokens;
TokenUsageRecord record = TokenUsageRecord.builder()
.model(model)
.promptTokens(promptTokens)
.completionTokens(completionTokens)
.totalTokens(totalTokens)
.cost(calculateCost(model, totalTokens))
.latencyMs(latency)
.timestamp(LocalDateTime.now())
.build();
tokenUsageService.recordUsage(record);
} catch (Exception e) {
log.error("Failed to record token usage: {}", e.getMessage());
}
}
/**
* 估算Token数(简化版)
*/
private int estimateTokens(String text) {
// 粗略估算:中文约1.5字符/Token,英文约4字符/Token
int chineseChars = (int) text.chars().filter(c -> c > 127).count();
int englishChars = text.length() - chineseChars;
return (int) (chineseChars / 1.5 + englishChars / 4);
}
/**
* 计算成本
*/
private BigDecimal calculateCost(ModelProvider model, int tokens) {
// 价格单位:分/千Token
BigDecimal pricePer1K = BigDecimal.valueOf(model.getPricePer1K());
return pricePer1K.multiply(BigDecimal.valueOf(tokens))
.divide(BigDecimal.valueOf(1000), 4, RoundingMode.HALF_UP);
}
}
5.2 Prompt 模板服务
java
/**
 * Prompt template service: holds the prompt templates for each scenario and renders
 * {name}-style placeholders.
 */
@Service
@Slf4j
public class PromptTemplateService {

    /** Matches a {name} placeholder; group 1 is the variable name. */
    private static final java.util.regex.Pattern PLACEHOLDER =
            java.util.regex.Pattern.compile("\\{([^{}]+)}");

    @Autowired
    private AiModelProperties modelProperties;

    /** Template key -> template text. Filled once in {@link #init()} and only read afterwards. */
    private final Map<String, String> templates = new HashMap<>();

    /** Loads the built-in templates, then any custom ones. */
    @PostConstruct
    public void init() {
        loadBuiltinTemplates();
        loadCustomTemplates();
    }

    private void loadBuiltinTemplates() {
        // Interview
        templates.put("interview-system", """
                你是一位经验丰富的技术面试官,正在面试{jobType}岗位的候选人。
                候选人背景:
                - 工作年限:{experience}
                - 面试轮次:{round}
                你的任务是:
                1. 根据候选人背景提出有针对性的技术问题
                2. 评估候选人的回答质量
                3. 给出建设性的反馈和改进建议
                请保持专业、友善的态度,问题难度要适中。
                """);
        // Resume optimisation
        templates.put("resume-optimize", """
                你是一位资深HR和职业规划师,擅长简历优化。
                请分析以下简历,并给出优化建议:
                目标岗位JD:
                {jobDescription}
                简历内容:
                {resumeContent}
                请从以下维度分析:
                1. 简历与岗位的匹配度评分(0-100)
                2. 关键词匹配情况
                3. 具体优化建议
                4. 修改后的简历内容
                """);
        // Code review
        templates.put("code-review", """
                你是一位资深软件工程师,请对以下代码进行评审。
                编程语言:{language}
                代码:
                ```{language}
                {code}
                ```
                请从以下方面给出评审意见:
                1. 代码质量和可读性
                2. 潜在的问题和Bug
                3. 性能优化建议
                4. 最佳实践建议
                """);
    }

    private void loadCustomTemplates() {
        // Placeholder: load from a config file or database
        // templates.putAll(loadFromDatabase());
    }

    /**
     * Renders a template by replacing each {name} placeholder with {@code String.valueOf} of the
     * matching variable.
     *
     * <p>Substitution is a single pass over the template: inserted values are never re-scanned,
     * so a value that itself contains "{someKey}" (e.g. user-supplied resume or code text) is kept
     * literally. Placeholders with no matching variable are left unchanged.
     *
     * @param variables placeholder values; {@code null} is treated as empty
     * @throws BusinessException if no template is registered under {@code templateKey}
     */
    public String renderTemplate(String templateKey, Map<String, Object> variables) {
        String template = templates.get(templateKey);
        if (template == null) {
            throw new BusinessException("Template not found: " + templateKey);
        }
        Map<String, Object> values = variables == null ? Map.of() : variables;
        return PLACEHOLDER.matcher(template).replaceAll(match -> {
            String name = match.group(1);
            String replacement = values.containsKey(name)
                    ? String.valueOf(values.get(name))
                    : match.group();
            // Values may contain '$' or '\', which replaceAll would otherwise interpret
            return java.util.regex.Matcher.quoteReplacement(replacement);
        });
    }

    /**
     * Returns the configured parameters for {@code scenario}, or a default of temperature 0.7 and
     * maxTokens 2000 when the scenario has no entry.
     */
    public AiModelProperties.ModelParameters getParametersForScenario(String scenario) {
        AiModelProperties.ModelParameters params =
                modelProperties.getParameters().get(scenario);
        if (params == null) {
            params = new AiModelProperties.ModelParameters();
            params.setTemperature(0.7);
            params.setMaxTokens(2000);
        }
        return params;
    }

    /** Returns the "&lt;scenario&gt;-system" template, or a generic assistant prompt. */
    public String getSystemPrompt(String scenario) {
        return templates.getOrDefault(scenario + "-system",
                "你是一个专业的AI助手。");
    }
}
六、Token 计费与配额管理
6.1 Token 使用记录
java
/**
* JPA entity mapped to token_usage_records: one row per recorded AI call with token counts,
* cost, latency and outcome.
* NOTE(review): timestamps presumably come from BaseEntity (the repository queries by createdAt) — confirm.
*/
@Entity
@Table(name = "token_usage_records")
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class TokenUsageRecord extends BaseEntity {
@Column(name = "user_id")
private Long userId; // Nullable: rows without a user id are not added to the per-user monthly counter
@Enumerated(EnumType.STRING)
@Column(nullable = false)
private ModelProvider model; // Stored as the enum constant name
@Column(name = "prompt_tokens")
private Integer promptTokens;
@Column(name = "completion_tokens")
private Integer completionTokens;
@Column(name = "total_tokens")
private Integer totalTokens;
@Column(name = "cost", precision = 10, scale = 4)
private BigDecimal cost; // Cost in fen (分). NOTE(review): LlmService.calculateCost derives it from ModelProvider.pricePer1K (fen per 1K tokens) with no /100 conversion, so this is not yuan — confirm the intended unit
@Column(name = "latency_ms")
private Long latencyMs; // Wall-clock call duration in milliseconds
@Column(name = "scenario")
private String scenario; // Scenario key, e.g. "interview"
@Column(name = "request_id")
private String requestId;
@Column(name = "status")
@Enumerated(EnumType.STRING)
private UsageStatus status;
/** Outcome of the recorded call. */
public enum UsageStatus {
SUCCESS, FAILED, CACHED
}
}
6.2 Token 使用服务
java
/**
 * Token usage service: persists per-call usage rows and keeps a per-user monthly token counter
 * in Redis (key {@code user:usage:<userId>:<yyyy-MM>}) that {@link #checkQuota} reads.
 */
@Service
@Slf4j
public class TokenUsageService {

    @Autowired
    private TokenUsageRecordRepository usageRepository;

    @Autowired
    private StringRedisTemplate redisTemplate;

    private static final String USER_QUOTA_KEY = "user:quota:";
    private static final String USER_USAGE_KEY = "user:usage:";
    /** Month suffix of the usage key; DateTimeFormatter is immutable, so one shared instance is safe. */
    private static final DateTimeFormatter MONTH_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM");

    /**
     * Saves {@code record} and adds its tokens to the user's monthly counter. Annotated
     * {@code @Async}; failures are logged, not thrown.
     */
    @Async
    public void recordUsage(TokenUsageRecord record) {
        try {
            usageRepository.save(record);
            updateUsageStats(record);
            log.debug("Token usage recorded: {} tokens, cost: {}",
                    record.getTotalTokens(), record.getCost());
        } catch (Exception e) {
            log.error("Failed to record token usage: {}", e.getMessage());
        }
    }

    /**
     * Returns whether the tokens the user has used this month are still below the plan quota.
     */
    public boolean checkQuota(Long userId, UserPlan plan) {
        String usedStr = redisTemplate.opsForValue().get(monthlyUsageKey(userId, LocalDate.now()));
        // Parsed as long: the counter is a Redis integer and may exceed the int range
        long used = usedStr != null ? Long.parseLong(usedStr) : 0L;
        return used < plan.getQuota();
    }

    /**
     * Adds {@code tokens} to the user's monthly counter, which expires at the end of the month.
     *
     * <p>NOTE(review): recordUsage increments the same key via updateUsageStats, so calling both
     * for one request counts its tokens twice.
     */
    public void deductQuota(Long userId, int tokens) {
        incrementMonthlyUsage(userId, tokens);
    }

    /**
     * Aggregates a user's usage between {@code startDate} and {@code endDate} (both inclusive).
     * Null token, cost or latency values are ignored rather than failing the aggregation.
     */
    public UsageStats getUsageStats(Long userId, LocalDate startDate, LocalDate endDate) {
        List<TokenUsageRecord> records = usageRepository
                .findByUserIdAndCreatedAtBetween(userId, startDate.atStartOfDay(),
                        endDate.atTime(LocalTime.MAX));

        int totalTokens = records.stream()
                .mapToInt(r -> nullToZero(r.getTotalTokens()))
                .sum();

        BigDecimal totalCost = sumCosts(records);

        Map<ModelProvider, Integer> modelUsage = records.stream()
                .collect(Collectors.groupingBy(
                        TokenUsageRecord::getModel,
                        Collectors.summingInt(r -> nullToZero(r.getTotalTokens()))
                ));

        return UsageStats.builder()
                .totalTokens(totalTokens)
                .totalCost(totalCost)
                .requestCount(records.size())
                .modelUsage(modelUsage)
                .averageLatency(records.stream()
                        .map(TokenUsageRecord::getLatencyMs)
                        .filter(latency -> latency != null)
                        .mapToLong(Long::longValue)
                        .average()
                        .orElse(0))
                .build();
    }

    /** Aggregates all usage for one calendar day. */
    public SystemUsageStats getSystemStats(LocalDate date) {
        LocalDateTime start = date.atStartOfDay();
        LocalDateTime end = date.atTime(LocalTime.MAX);

        List<TokenUsageRecord> records = usageRepository
                .findByCreatedAtBetween(start, end);

        return SystemUsageStats.builder()
                .date(date)
                .totalRequests(records.size())
                .totalTokens(records.stream()
                        .mapToInt(r -> nullToZero(r.getTotalTokens()))
                        .sum())
                .totalCost(sumCosts(records))
                .modelDistribution(records.stream()
                        .collect(Collectors.groupingBy(
                                TokenUsageRecord::getModel,
                                Collectors.counting()
                        )))
                .build();
    }

    /** Adds a record's tokens to its user's monthly counter; skips rows without user or tokens. */
    private void updateUsageStats(TokenUsageRecord record) {
        if (record.getUserId() == null || record.getTotalTokens() == null) {
            return;
        }
        incrementMonthlyUsage(record.getUserId(), record.getTotalTokens());
    }

    /** Increments the monthly counter and (re)sets it to expire at the end of the current month. */
    private void incrementMonthlyUsage(Long userId, long tokens) {
        LocalDateTime now = LocalDateTime.now();
        String key = monthlyUsageKey(userId, now.toLocalDate());
        redisTemplate.opsForValue().increment(key, tokens);
        LocalDateTime endOfMonth = now
                .with(TemporalAdjusters.lastDayOfMonth())
                .with(LocalTime.MAX);
        // At least 1 s: a non-positive EXPIRE deletes the key at once (possible in the month's last second)
        long seconds = Math.max(1L, Duration.between(now, endOfMonth).getSeconds());
        redisTemplate.expire(key, seconds, TimeUnit.SECONDS);
    }

    private String monthlyUsageKey(Long userId, LocalDate date) {
        return USER_USAGE_KEY + userId + ":" + date.format(MONTH_FORMAT);
    }

    /** Sums non-null costs. */
    private static BigDecimal sumCosts(List<TokenUsageRecord> records) {
        return records.stream()
                .map(TokenUsageRecord::getCost)
                .filter(cost -> cost != null)
                .reduce(BigDecimal.ZERO, BigDecimal::add);
    }

    private static int nullToZero(Integer value) {
        return value == null ? 0 : value;
    }
}
6.3 配额限制切面
java
/**
 * Quota aspect: before any method annotated with {@code @CheckQuota} runs, rejects callers whose
 * monthly AI usage has reached their plan quota. Callers without a resolvable user id pass through.
 *
 * <p>NOTE(review): usage is recorded asynchronously after each call, so concurrent requests can
 * all pass this check and overshoot the quota slightly.
 */
@Aspect
@Component
@Slf4j
public class QuotaLimitAspect {

    @Autowired
    private TokenUsageService tokenUsageService;

    @Autowired
    private UserService userService;

    /**
     * Checks the current user's quota, then proceeds.
     *
     * @throws BusinessException (INSUFFICIENT_BALANCE) when the quota is exhausted
     */
    @Around("@annotation(checkQuota)")
    public Object around(ProceedingJoinPoint point, CheckQuota checkQuota)
            throws Throwable {
        Long userId = getCurrentUserId();
        if (userId != null) {
            User user = userService.findById(userId);
            boolean withinQuota = tokenUsageService.checkQuota(userId, user.getPlan());
            if (!withinQuota) {
                log.warn("User {} quota exceeded", userId);
                throw new BusinessException(ResultCode.INSUFFICIENT_BALANCE,
                        "本月AI调用额度已用完,请升级套餐");
            }
        }
        return point.proceed();
    }

    /**
     * Id of the authenticated principal, or null when unauthenticated.
     * NOTE(review): relies on a UserDetails type exposing getId(); Spring Security's own
     * interface has no such method, so this is presumably a project type — confirm.
     */
    private Long getCurrentUserId() {
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (auth == null || !(auth.getPrincipal() instanceof UserDetails details)) {
            return null;
        }
        return details.getId();
    }
}
七、代码示例
7.1 ModelSwitchController
java
/**
 * Admin endpoints for listing, switching and testing models, clearing the client cache and
 * reading usage statistics.
 */
@RestController
@RequestMapping("/api/admin/ai")
@PreAuthorize("hasRole('ADMIN')")
@Slf4j
public class ModelSwitchController {

    @Autowired
    private ModelSwitchService modelSwitchService;

    @Autowired
    private LlmService llmService;

    /** Usage statistics source; getUsageStats depends on it (it was referenced but never injected). */
    @Autowired
    private TokenUsageService tokenUsageService;

    /** Lists models that currently answer a probe (one billable request per reachable model). */
    @GetMapping("/models")
    public Result<List<ModelInfo>> getAvailableModels() {
        List<ModelInfo> models = modelSwitchService.getAvailableModels().stream()
                .map(this::toModelInfo)
                .collect(Collectors.toList());
        return Result.success(models);
    }

    /**
     * Switches the default model.
     *
     * @throws IllegalArgumentException if {@code request.model} is not a ModelProvider constant name
     */
    @PostMapping("/models/default")
    public Result<Void> switchDefaultModel(@RequestBody @Valid SwitchModelRequest request) {
        ModelProvider provider = ModelProvider.valueOf(request.getModel());
        modelSwitchService.switchDefaultModel(provider);
        log.info("Default model switched to {} by admin", provider);
        return Result.success("模型切换成功");
    }

    /** Returns the current default model. */
    @GetMapping("/models/default")
    public Result<ModelInfo> getCurrentDefaultModel() {
        ModelProvider current = modelSwitchService.getCurrentDefault();
        return Result.success(toModelInfo(current));
    }

    /** Probes one model and reports availability plus round-trip latency. */
    @PostMapping("/models/{model}/test")
    public Result<ModelTestResult> testModel(@PathVariable String model) {
        ModelProvider provider = ModelProvider.valueOf(model);
        long startTime = System.currentTimeMillis();
        boolean available = modelSwitchService.isModelAvailable(provider);
        long latency = System.currentTimeMillis() - startTime;
        ModelTestResult result = ModelTestResult.builder()
                .model(model)
                .available(available)
                .latencyMs(latency)
                .build();
        return Result.success(result);
    }

    /** Clears the cached ChatClient instances. */
    @PostMapping("/cache/clear")
    public Result<Void> clearCache() {
        modelSwitchService.clearCache();
        return Result.success("缓存已清除");
    }

    /**
     * Returns system-wide usage statistics for today.
     *
     * <p>NOTE(review): {@code startDate}/{@code endDate} are required but currently ignored — only
     * today's stats are returned. Either aggregate over the range or drop the parameters.
     */
    @GetMapping("/usage/stats")
    public Result<UsageStatsResponse> getUsageStats(
            @RequestParam @DateTimeFormat(iso = DateTimeFormat.ISO.DATE) LocalDate startDate,
            @RequestParam @DateTimeFormat(iso = DateTimeFormat.ISO.DATE) LocalDate endDate) {
        SystemUsageStats stats = tokenUsageService.getSystemStats(LocalDate.now());
        UsageStatsResponse response = UsageStatsResponse.builder()
                .todayRequests(stats.getTotalRequests())
                .todayTokens(stats.getTotalTokens())
                .todayCost(stats.getTotalCost())
                .modelDistribution(stats.getModelDistribution())
                .build();
        return Result.success(response);
    }

    private ModelInfo toModelInfo(ModelProvider provider) {
        return ModelInfo.builder()
                .code(provider.name())
                .name(provider.getDisplayName())
                .provider(provider.getProvider())
                .modelName(provider.getModelName())
                .pricePer1K(provider.getPricePer1K())
                .isCloud(provider.isCloud())
                .isDefault(provider == modelSwitchService.getCurrentDefault())
                .build();
    }
}
7.2 面试服务中使用多模型
java
/**
 * Interview service: an example of calling several models from one feature.
 */
@Service
@Slf4j
public class InterviewService {

    @Autowired
    private LlmService llmService;

    @Autowired
    private PromptTemplateService promptTemplateService;

    @Autowired
    private InterviewSessionRepository sessionRepository;

    /**
     * Starts an interview. The model is chosen per scenario and the user's plan, and the first
     * question is generated.
     */
    @CheckQuota
    public InterviewResponse startInterview(StartInterviewRequest request, User user) {
        String scenario = "interview";
        String prompt = promptTemplateService.renderTemplate("interview-system",
                Map.of(
                        "jobType", request.getJobType(),
                        "experience", request.getExperience(),
                        "round", request.getRound()
                ));
        String response = llmService.chatWithScenario(prompt, scenario,
                user.getPlan().name());

        InterviewSession session = InterviewSession.builder()
                .userId(user.getId())
                .jobType(request.getJobType())
                .status(InterviewStatus.IN_PROGRESS)
                .build();
        // Use the entity returned by save(): repositories may return a different instance than the
        // argument, and only the returned one is guaranteed to carry the generated id
        InterviewSession saved = sessionRepository.save(session);

        return InterviewResponse.builder()
                .sessionId(saved.getId())
                .question(response)
                .build();
    }

    /**
     * Streams the interviewer's reaction to {@code answer}. Quota-checked like startInterview,
     * since it is also an AI call.
     *
     * @throws BusinessException (INTERVIEW_NOT_FOUND) if the session does not exist or is not the user's
     */
    @CheckQuota
    public Flux<String> answerQuestionStream(Long sessionId, String answer, User user) {
        InterviewSession session = loadOwnedSession(sessionId, user);
        String prompt = buildContextualPrompt(session, answer);
        return llmService.streamChat(prompt);
    }

    /**
     * Evaluates the interview with GPT-4 and maps the JSON reply onto scores and feedback.
     * Quota-checked because it always uses the most expensive model here.
     *
     * @throws BusinessException (INTERVIEW_NOT_FOUND) if the session does not exist or is not the user's
     */
    @CheckQuota
    public InterviewEvaluation evaluateInterview(Long sessionId, User user) {
        InterviewSession session = loadOwnedSession(sessionId, user);
        String prompt = buildEvaluationPrompt(session);
        InterviewEvaluationResult result = llmService.chatForEntity(
                prompt,
                InterviewEvaluationResult.class,
                ModelProvider.OPENAI_GPT4
        );
        return InterviewEvaluation.builder()
                .sessionId(sessionId)
                .overallScore(result.getOverallScore())
                .technicalScore(result.getTechnicalScore())
                .communicationScore(result.getCommunicationScore())
                .feedback(result.getFeedback())
                .build();
    }

    /**
     * Loads a session and verifies it belongs to {@code user}. Sessions owned by someone else are
     * reported as not found, so callers cannot read other users' interviews or probe for their ids.
     */
    private InterviewSession loadOwnedSession(Long sessionId, User user) {
        InterviewSession session = sessionRepository.findById(sessionId)
                .orElseThrow(() -> new BusinessException(ResultCode.INTERVIEW_NOT_FOUND));
        if (user == null || session.getUserId() == null || !session.getUserId().equals(user.getId())) {
            throw new BusinessException(ResultCode.INTERVIEW_NOT_FOUND);
        }
        return session;
    }

    /**
     * Builds the follow-up prompt from the job type, the stored conversation and the new answer.
     * NOTE(review): nothing in this service appends questions or answers to session.getMessages(),
     * so the history section stays empty unless another component persists it.
     */
    private String buildContextualPrompt(InterviewSession session, String answer) {
        List<ChatMessage> history = session.getMessages();
        StringBuilder prompt = new StringBuilder();
        prompt.append("面试岗位:").append(session.getJobType()).append("\n\n");
        prompt.append("历史对话:\n");
        if (history != null) {
            for (ChatMessage msg : history) {
                prompt.append(msg.getRole() == Role.USER ? "候选人:" : "面试官:")
                        .append(msg.getContent())
                        .append("\n");
            }
        }
        prompt.append("候选人最新回答:").append(answer).append("\n\n");
        prompt.append("请作为面试官给出回应,可以是追问或评价。");
        return prompt.toString();
    }

    /** Builds the evaluation prompt that asks for scores and feedback as JSON. */
    private String buildEvaluationPrompt(InterviewSession session) {
        return """
                请对以下技术面试进行全面评估:
                岗位类型:%s
                面试对话记录:
                %s
                请从以下维度给出评分(0-100)和详细反馈:
                1. 整体表现
                2. 技术能力
                3. 沟通表达
                4. 改进建议
                请以JSON格式返回结果。
                """.formatted(session.getJobType(), formatConversation(session.getMessages()));
    }
}
八、总结
8.1 本文要点
本文详细介绍了Spring AI多模型架构的完整实现方案:
- Spring AI核心概念:ChatClient、Prompt模板、流式响应
- 多模型配置管理:ModelProvider枚举、YAML配置、属性类
- 动态模型切换:ModelSwitchService实现运行时热切换
- 统一LLM调用:LlmService封装所有模型调用细节
- Token计费管理:使用记录、配额检查、成本统计
- 代码示例:ModelSwitchController、InterviewService实战
8.2 架构优势
- 灵活性:运行时切换模型,无需重启服务
- 成本优化:根据场景和用户套餐智能选择模型
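- 可扩展性:同一提供商新增模型只需添加枚举和配置;接入新的提供商还需在 ModelSwitchService.createChatModel 中增加对应的 ChatModel 创建分支(本文中 anthropic、google 尚未实现)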
- 可观测性:完整的Token使用统计和成本分析
8.3 后续优化方向
- 模型缓存:响应结果缓存,减少重复调用
- A/B测试:不同模型效果对比
- 智能路由:基于历史数据自动选择最优模型
- 模型微调:针对特定场景微调本地模型
参考资料
- Spring AI Documentation
- OpenAI API Reference
- DeepSeek API Documentation
- MiniMax API Documentation
- Ollama Documentation
本系列文章到此结束,感谢阅读!如果对你有帮助,欢迎点赞、收藏、评论交流!