单模型成本高、风险大?Spring AI多模型路由实战:成本降70%,可用性更稳
很多 AI 应用刚上线时,架构都很简单:
text
用户请求 -> GPT-4o -> 返回结果
这条链路能跑,但流量一上来就会暴露三个问题:
- 简单 FAQ 也走强模型,成本不划算
- 单模型异常时,业务入口跟着不可用
- 想接本地模型或低价模型,却不知道怎么控制质量边界
更合理的做法是:按任务复杂度路由到不同模型。
本文用 Spring AI 做一个可落地的多模型路由器:
- SIMPLE:走 Ollama 本地模型,适合 FAQ 和短问答
- MEDIUM:走 GPT-4o-mini,适合摘要、改写、普通文案
- COMPLEX:走 GPT-4o,适合代码、架构分析、多步骤推理
- 异常时按
COMPLEX -> MEDIUM -> SIMPLE逐级 fallback - 预算快超限时强制降级到 SIMPLE
1. 架构思路
Spring AI 的 ChatClient 支持在一个应用里创建多个 ChatClient。官方文档也把"不同任务用不同模型""fallback""A/B 测试"列为典型场景。
本文的架构是:
text
Request
-> Complexity Assessor
-> Model Router
-> SIMPLE -> Ollama(qwen3:4b)
-> MEDIUM -> GPT-4o-mini
-> COMPLEX -> GPT-4o
-> Fallback Chain
-> Cost Monitor
路由器不关心"哪个供应商最强",只关心"这个任务需要什么能力"。这点很重要:后面你要把 SIMPLE 换成 DeepSeek、本地 Qwen,或者把 COMPLEX 换成其他强模型,都不应该影响业务 Controller。
2. 定义任务复杂度
java
public enum TaskComplexity {
SIMPLE, // FAQ、短问答、知识查询
MEDIUM, // 摘要、改写、普通文案
COMPLEX // 代码生成、长文分析、多步骤推理
}
第一版可以先用规则判断,后面再升级成轻量分类模型。
3. 配置多个 ChatClient
application.yml:
yaml
spring:
ai:
chat:
client:
enabled: false
openai:
api-key: ${OPENAI_API_KEY}
ollama:
base-url: http://localhost:11434
ai:
budget:
monthly: 1000
本地模型先拉取:
bash
ollama pull qwen3:4b
配置类:
java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaChatOptions;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class MultiModelConfig {
@Bean("simpleClient")
public ChatClient simpleClient(OllamaChatModel ollamaChatModel) {
return ChatClient.builder(ollamaChatModel)
.defaultOptions(OllamaChatOptions.builder()
.model("qwen3:4b")
.temperature(0.2)
.build())
.build();
}
@Bean("mediumClient")
public ChatClient mediumClient(OpenAiChatModel openAiChatModel) {
return ChatClient.builder(openAiChatModel)
.defaultOptions(OpenAiChatOptions.builder()
.model("gpt-4o-mini")
.temperature(0.3)
.build())
.build();
}
@Bean("complexClient")
public ChatClient complexClient(OpenAiChatModel openAiChatModel) {
return ChatClient.builder(openAiChatModel)
.defaultOptions(OpenAiChatOptions.builder()
.model("gpt-4o")
.temperature(0.2)
.build())
.build();
}
}
4. 实现路由器
java
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class ModelRouter {
private final ChatClient simpleClient;
private final ChatClient mediumClient;
private final ChatClient complexClient;
public ModelRouter(
@Qualifier("simpleClient") ChatClient simpleClient,
@Qualifier("mediumClient") ChatClient mediumClient,
@Qualifier("complexClient") ChatClient complexClient) {
this.simpleClient = simpleClient;
this.mediumClient = mediumClient;
this.complexClient = complexClient;
}
public String route(String userInput) {
TaskComplexity complexity = assessComplexity(userInput);
return routeTo(complexity, userInput);
}
public String routeTo(TaskComplexity complexity, String userInput) {
return executeWithFallback(fallbackChain(complexity), userInput);
}
public String routeToSimple(String userInput) {
return executeWithFallback(List.of(simpleClient), userInput);
}
public TaskComplexity assessComplexity(String input) {
int length = input.length();
String lower = input.toLowerCase();
if (length < 50 &&
(lower.contains("怎么") || lower.contains("什么是") || lower.contains("请问"))) {
return TaskComplexity.SIMPLE;
}
if (length > 200 ||
lower.contains("实现") || lower.contains("编写") ||
lower.contains("分析") || lower.contains("比较") ||
lower.contains("架构") || lower.contains("性能优化")) {
return TaskComplexity.COMPLEX;
}
return TaskComplexity.MEDIUM;
}
private List<ChatClient> fallbackChain(TaskComplexity complexity) {
return switch (complexity) {
case COMPLEX -> List.of(complexClient, mediumClient, simpleClient);
case MEDIUM -> List.of(mediumClient, simpleClient);
case SIMPLE -> List.of(simpleClient);
};
}
private String executeWithFallback(List<ChatClient> clients, String input) {
for (int i = 0; i < clients.size(); i++) {
try {
return clients.get(i)
.prompt()
.system("你是企业级 AI 助手。回答要准确、简洁;不确定时说明不确定。")
.user(input)
.call()
.content();
} catch (Exception e) {
log.warn("模型调用失败,index={}, reason={}", i, e.getMessage());
}
}
throw new IllegalStateException("所有模型调用失败,AI 服务暂时不可用");
}
}
这里的关键不是规则有多复杂,而是 fallback 链必须单向、有限、可观测。不要写成 A 失败跳 B,B 失败又跳回 A。
5. 加上预算降级
java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public class CostMonitoredRouter {
private static final Map<TaskComplexity, Double> ESTIMATED_COST = Map.of(
TaskComplexity.SIMPLE, 0.0,
TaskComplexity.MEDIUM, 0.0003,
TaskComplexity.COMPLEX, 0.003
);
private final ModelRouter router;
private final double monthlyBudget;
private final Map<TaskComplexity, AtomicLong> callCounts = new ConcurrentHashMap<>();
public CostMonitoredRouter(ModelRouter router,
@Value("${ai.budget.monthly:1000}") double monthlyBudget) {
this.router = router;
this.monthlyBudget = monthlyBudget;
}
public String routeWithBudgetCheck(String userInput) {
TaskComplexity complexity = router.assessComplexity(userInput);
double nextCost = ESTIMATED_COST.getOrDefault(complexity, 0.001);
if (currentMonthlySpend() + nextCost > monthlyBudget) {
log.warn("预算即将超限,降级到 simpleClient");
record(TaskComplexity.SIMPLE);
return router.routeToSimple(userInput);
}
String response = router.routeTo(complexity, userInput);
record(complexity);
return response;
}
private void record(TaskComplexity complexity) {
callCounts.computeIfAbsent(complexity, key -> new AtomicLong()).incrementAndGet();
}
private double currentMonthlySpend() {
return callCounts.entrySet().stream()
.mapToDouble(entry -> entry.getValue().get()
* ESTIMATED_COST.getOrDefault(entry.getKey(), 0.001))
.sum();
}
}
生产环境不要只用内存 Map,建议把 token usage、模型名、用户、业务线都打到可观测系统或账单表里。
6. Controller 接入
java
@RestController
@RequestMapping("/api/ai")
public class AiController {
private final ModelRouter router;
private final CostMonitoredRouter costRouter;
public AiController(ModelRouter router, CostMonitoredRouter costRouter) {
this.router = router;
this.costRouter = costRouter;
}
@GetMapping("/chat")
public ResponseEntity<String> chat(@RequestParam String message) {
return ResponseEntity.ok(router.route(message));
}
@GetMapping("/chat-with-budget")
public ResponseEntity<String> chatWithBudget(@RequestParam String message) {
return ResponseEntity.ok(costRouter.routeWithBudgetCheck(message));
}
}
构造器注入比在方法参数里写 @Autowired 稳得多,也更符合 Spring MVC 的使用习惯。
7. 成本收益怎么判断
多模型路由的收益取决于任务分布。一个比较常见的情况是:
- 60% 是简单 FAQ / 短问答
- 30% 是摘要、改写、普通文案
- 10% 是复杂代码或分析
如果简单任务从强模型迁到本地模型,月度成本下降 50% - 70% 是有机会的。但这个数字必须用你自己的真实 token、真实流量、真实模型价格重新算。
上线前建议压三类指标:
- 路由准确率:复杂任务有没有被错分到简单模型
- fallback 成功率:主模型失败时备用模型是否真的接住
- 单位请求成本:按业务线统计,而不是只看总账单
8. 踩坑
第一,便宜模型不是强模型平替。
本地模型适合低风险任务。涉及法律、财务、复杂代码生成时,最好继续走强模型或加人工审核。
第二,fallback 要可观测。
每次降级都要记录原模型、目标模型、异常原因。否则线上回复质量波动时,你很难定位。
第三,预算降级要真的调用模型。
不要预算超了就返回一句"已降级"。用户要的是答案,不是系统状态。
第四,版本要跟官方文档对齐。
Spring AI 版本更新很快,starter 名称、配置项、Options 类都可能变化。写进生产代码前,一定按当前版本文档确认。
结尾
多模型路由不是炫技,而是企业 AI 应用从 demo 走向生产的必经步骤。
简单任务交给低成本模型,复杂任务保留强模型,异常时有 fallback,预算快爆时能自动降级。这样系统不会因为某一个模型慢、贵、挂而整体失控。
讨论话题:你们在生产环境里会把所有 AI 请求都交给同一个模型吗?有没有做过模型降级、fallback 或成本监控?