RAG核心升级|多LLM模型动态切换方案
一、多模型切换核心价值
传统RAG架构绑定单一LLM模型,无法适配多样化的问答场景;而多模型动态切换方案 可根据场景/需求灵活选择最优模型,核心价值如下: 场景化适配:技术文档问答用deepseek-r1(代码/技术理解强),通用闲聊用qwen2-7b(中文流畅度高),精准匹配不同场景需求; 性能与成本平衡:简单问答用轻量模型(如7B),复杂推理用高性能模型(如8B/70B),降低整体推理成本; 容错与容灾:单个模型服务异常时,可快速切换到备用模型,提升系统可用性; 灵活扩展:新增模型仅需配置YAML,无需修改核心代码,支持无缝接入Llama3、Gemini等新模型; 个性化调优:不同模型独立配置参数(温度、最大生成长度),针对场景精细化调优,提升回答质量。
二、多模型切换 vs 单模型绑定(核心优势对比)
| 对比维度 | 单模型绑定(传统方案) | 多模型动态切换(优化方案) |
|---|---|---|
| 场景适配性 | 单一模型无法兼顾所有场景(如技术问答/通用闲聊) | 按场景选择最优模型,精准匹配不同问答需求 |
| 参数调优 | 全局参数配置,无法针对场景精细化调优 | 每个模型独立配置参数(温度、max_tokens),灵活调优 |
| 系统可用性 | 模型服务异常则整体不可用 | 支持模型降级/切换,单个模型故障不影响整体服务 |
| 扩展成本 | 新增模型需修改代码、重新部署 | 仅需新增YAML配置,热扩展无需重启服务 |
| 资源利用效率 | 轻量场景浪费大模型资源,复杂场景轻量模型能力不足 | 按需分配模型资源,平衡性能与成本 |
| 迭代灵活性 | 模型升级需全量替换,风险高 | 可灰度验证新模型,逐步切换,降低迭代风险 |
三、完整实现方案
1. YAML配置(结构化+可扩展)
yaml
spring:
ai:
ollama:
# Ollama 服务基础地址(统一配置,避免硬编码)
base-url: http://127.0.0.1:11434
# 多Chat模型配置列表(核心:支持动态扩展)
chats:
# 模型1:通义千问(默认模型,中文通用场景)
- code: qwen2-7b
model: qwen2:7b-instruct
options:
temperature: 0.3 # 低随机性,保证回答精准
max_tokens: 2000 # 适配通用问答长度
connect_timeout: 60s # 连接超时
read_timeout: 5m # 读取超时
context:
expire-minutes: 120 # 会话上下文过期时间
max-message-count: 20 # 最大上下文消息数
# 模型2:深度求索(技术/代码场景专用)
- code: deepseek-r1
model: deepseek-r1:8b
options:
temperature: 0.5 # 中等随机性,兼顾创意与精准
max_tokens: 3000 # 适配长文本技术回答
connect_timeout: 60s
read_timeout: 5m
context:
expire-minutes: 120
max-message-count: 20
# 扩展:新增模型仅需添加此节点,无需修改代码
# - code: llama3-8b
# model: llama3:8b-instruct
# options:
# temperature: 0.4
# max_tokens: 2500
# context:
# expire-minutes: 120
# max-message-count: 20
2. 多模型配置绑定类
java
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import java.util.List;
/**
* Ollama多模型配置绑定类【生产级】
* 职责:绑定YAML中的多模型配置,提供类型安全的配置访问
*/
@Data
@Component
@ConfigurationProperties(prefix = "spring.ai.ollama")
public class OllamaMultiModelProperties {
// Ollama基础服务地址(必填)
private String baseUrl;
// 多Chat模型配置列表(至少配置1个模型)
private List<OllamaChatModelConfig> chats;
/**
* 初始化校验:确保核心配置非空
*/
public void afterPropertiesSet() {
Assert.hasText(baseUrl, "spring.ai.ollama.base-url 不能为空");
Assert.notEmpty(chats, "spring.ai.ollama.chats 至少配置一个模型");
// 校验每个模型的核心配置
for (OllamaChatModelConfig config : chats) {
Assert.hasText(config.getCode(), "模型编码(code)不能为空");
Assert.hasText(config.getModel(), "模型名称(model)不能为空");
Assert.notNull(config.getOptions(), "模型参数(options)不能为空");
}
}
/**
* 单个Chat模型的完整配置项
*/
@Data
public static class OllamaChatModelConfig {
// 模型唯一编码(如qwen2-7b,用于前端/业务层指定模型)
private String code;
// Ollama中注册的模型名称(如qwen2:7b-instruct)
private String model;
// 模型推理参数(温度、最大生成长度等)
private OllamaChatOptionsConfig options;
// 会话上下文配置(过期时间、最大消息数)
private OllamaChatContextConfig context;
/**
* 模型推理参数配置
*/
@Data
public static class OllamaChatOptionsConfig {
// 温度:0-1,值越高回答越随机,越低越精准
private Double temperature = 0.3; // 默认值
// 最大生成长度:限制单次回答的token数
private Integer maxTokens = 2000; // 默认值
// 连接超时(如60s)
private String connectTimeout = "60s"; // 默认值
// 读取超时(如5m)
private String readTimeout = "5m"; // 默认值
}
/**
* 会话上下文配置
*/
@Data
public static class OllamaChatContextConfig {
// 上下文过期时间(分钟)
private Integer expireMinutes = 120; // 默认值
// 最大上下文消息数(避免上下文过长)
private Integer maxMessageCount = 20; // 默认值
}
}
}
3. 多模型工厂配置类
java
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.ai.ollama.OllamaApi;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaOptions;
import org.springframework.ai.observation.ObservationRegistry;
import org.springframework.ai.tool.ToolCallingManager;
import org.springframework.ai.model.ModelManagementOptions;
import org.springframework.beans.factory.annotation.Autowired;
import javax.annotation.Resource;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
/**
* 多LLM模型配置类【生产级】
* 职责:注册多模型实例、提供动态获取ChatClient的工厂,支撑模型切换
*/
@Configuration
@Slf4j
public class ChatClientConfig {
@Resource
private OllamaMultiModelProperties ollamaMultiModelProperties;
// ========== 1. 注册默认Chat模型(兼容原有业务逻辑) ==========
@Bean
@Primary
public OllamaChatModel defaultOllamaChatModel() {
log.info("初始化默认Chat模型:qwen2-7b");
return createOllamaChatModel("qwen2-7b");
}
// ========== 2. 注册多模型映射表(核心:存储所有模型实例) ==========
@Bean
public Map<String, OllamaChatModel> ollamaChatModelMap() {
Map<String, OllamaChatModel> modelMap = new HashMap<>();
// 遍历配置,初始化所有模型实例
for (OllamaMultiModelProperties.OllamaChatModelConfig config : ollamaMultiModelProperties.getChats()) {
String modelCode = config.getCode();
try {
OllamaChatModel chatModel = createOllamaChatModel(modelCode);
modelMap.put(modelCode, chatModel);
log.info("成功初始化模型:{},对应Ollama模型名:{}", modelCode, config.getModel());
} catch (Exception e) {
log.error("初始化模型【{}】失败,跳过该模型", modelCode, e);
// 单个模型初始化失败不影响整体流程
continue;
}
}
// 校验模型映射表非空
if (modelMap.isEmpty()) {
throw new RuntimeException("未成功初始化任何LLM模型,请检查配置");
}
return modelMap;
}
// ========== 3. 注册ChatClient工厂(对外提供模型切换能力) ==========
@Bean
public ChatClientFactory chatClientFactory(Map<String, OllamaChatModel> ollamaChatModelMap) {
return new ChatClientFactory(ollamaChatModelMap);
}
// ========== 4. 兼容原有ChatClient.Builder(避免业务代码改造) ==========
@Bean
public org.springframework.ai.chat.ChatClient.Builder defaultChatClientBuilder(OllamaChatModel defaultOllamaChatModel) {
return org.springframework.ai.chat.ChatClient.builder(defaultOllamaChatModel);
}
// ========== 私有方法:根据模型编码创建OllamaChatModel实例 ==========
private OllamaChatModel createOllamaChatModel(String modelCode) {
// 1. 根据编码查找模型配置(优化:使用流式API,避免循环)
Optional<OllamaMultiModelProperties.OllamaChatModelConfig> modelConfigOpt =
ollamaMultiModelProperties.getChats().stream()
.filter(config -> modelCode.equals(config.getCode()))
.findFirst();
// 兜底:使用第一个模型配置
OllamaMultiModelProperties.OllamaChatModelConfig modelConfig = modelConfigOpt.orElseGet(() -> {
log.warn("未找到模型【{}】的配置,使用第一个模型作为兜底", modelCode);
return ollamaMultiModelProperties.getChats().get(0);
});
// 2. 构建Ollama API客户端(基础地址统一配置)
OllamaApi ollamaApi = OllamaApi.builder()
.baseUrl(ollamaMultiModelProperties.getBaseUrl())
.build();
// 3. 构建模型参数(动态读取配置,支持所有Ollama参数)
OllamaOptions options = OllamaOptions.builder()
.model(modelConfig.getModel())
.temperature(modelConfig.getOptions().getTemperature())
.build();
options.setMaxTokens(modelConfig.getOptions().getMaxTokens());
// 4. 固定配置(工具调用、观测、模型管理)
ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();
ModelManagementOptions modelManagementOptions = ModelManagementOptions.builder().build();
ObservationRegistry observationRegistry = ObservationRegistry.create();
// 5. 创建并返回模型实例
return new OllamaChatModel(ollamaApi, options, toolCallingManager, observationRegistry, modelManagementOptions);
}
// ========== 核心内部类:ChatClient工厂(对外暴露模型切换能力) ==========
public static class ChatClientFactory {
private final Map<String, OllamaChatModel> chatModelMap;
// 默认模型编码(可配置化)
private static final String DEFAULT_MODEL_CODE = "qwen2-7b";
public ChatClientFactory(Map<String, OllamaChatModel> chatModelMap) {
this.chatModelMap = chatModelMap;
}
/**
* 根据模型编码动态获取ChatClient(核心方法)
* @param modelCode 模型编码(qwen2-7b/deepseek-r1)
* @return 绑定对应模型的ChatClient
*/
public org.springframework.ai.chat.ChatClient getChatClient(String modelCode) {
// 空值处理:使用默认模型
if (modelCode == null || modelCode.isBlank()) {
log.info("模型编码为空,使用默认模型:{}", DEFAULT_MODEL_CODE);
return getDefaultChatClient();
}
// 获取指定模型
OllamaChatModel chatModel = chatModelMap.get(modelCode);
if (chatModel == null) {
log.warn("模型【{}】不存在,降级到默认模型:{}", modelCode, DEFAULT_MODEL_CODE);
return getDefaultChatClient();
}
log.debug("使用模型【{}】创建ChatClient", modelCode);
return org.springframework.ai.chat.ChatClient.builder(chatModel).build();
}
/**
* 获取默认ChatClient(兼容原有逻辑)
*/
public org.springframework.ai.chat.ChatClient getDefaultChatClient() {
OllamaChatModel defaultModel = chatModelMap.get(DEFAULT_MODEL_CODE);
return org.springframework.ai.chat.ChatClient.builder(defaultModel).build();
}
/**
* 获取所有可用模型编码(用于前端展示/校验)
*/
public Map<String, String> listAllModels() {
Map<String, String> modelList = new HashMap<>();
for (Map.Entry<String, OllamaChatModel> entry : chatModelMap.entrySet()) {
String modelCode = entry.getKey();
String modelName = entry.getValue().getDefaultOptions().getModel();
modelList.put(modelCode, modelName);
}
return modelList;
}
}
}
4. 多模型调用核心业务代码(生产级优化)
java
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.ChatMessage;
import org.springframework.ai.chat.prompt.Prompt;
import javax.annotation.Resource;
import java.util.List;
/**
* 多模型问答核心Controller【生产级】
* 职责:接收问答请求,根据模型编码切换LLM模型,返回回答结果
*/
@RestController
@RequestMapping("/api/rag")
@Slf4j
public class RagChatController {
@Resource
private ChatClientConfig.ChatClientFactory chatClientFactory;
@Resource
private RedisChatContextManager redisChatContextManager;
@Resource
private AgentThinker agentThinker;
@Resource
private AgentExecutor agentExecutor;
@Resource
private AgentSummarizer agentSummarizer;
/**
* 核心问答接口(支持多模型切换)
* @param param 问答参数(包含query、sessionId、modelCode)
*/
@PostMapping("/chat")
public Result<String> chat(@RequestBody AgentChatParam param) {
// 1. 参数校验
String query = param.getQuery();
String sessionId = param.getSessionId();
String modelCode = param.getModelCode();
if (query == null || query.isBlank()) {
return Result.fail("提问内容不能为空");
}
if (sessionId == null || sessionId.isBlank()) {
return Result.fail("会话ID不能为空");
}
log.info("处理问答请求:sessionId={}, query={}, 目标模型={}", sessionId, query, modelCode);
try {
// 2. Agent思考:决策调用工具/直接回答
List<ThinkResult> thinkResults = agentThinker.think(query, sessionId);
StringBuilder toolResults = new StringBuilder();
for (ThinkResult thinkResult : thinkResults) {
if ("DIRECT_ANSWER".equals(thinkResult.getAction())) {
toolResults.append(thinkResult.getContent());
} else {
// 3. Agent执行:调用工具(如知识库检索)
ToolResult toolResult = agentExecutor.execute(thinkResult);
if (toolResult.isSuccess()) {
toolResults.append(toolResult.getContent());
} else {
log.error("工具调用失败:{}", toolResult.getErrorMsg());
return Result.fail("处理失败:" + toolResult.getErrorMsg());
}
}
}
// 4. 动态获取ChatClient(根据模型编码切换)
ChatClient chatClient = chatClientFactory.getChatClient(modelCode);
// 5. Agent总结:调用指定模型生成最终回答
String answer = agentSummarizer.summarize(query, toolResults.toString(), sessionId, chatClient);
// 6. 保存会话上下文到Redis
redisChatContextManager.updateChatContext(sessionId, ChatMessage.user(query));
redisChatContextManager.updateChatContext(sessionId, ChatMessage.assistant(answer));
log.info("问答请求处理完成:sessionId={}, 模型={}, 回答长度={}", sessionId, modelCode, answer.length());
return Result.success(answer);
} catch (Exception e) {
log.error("问答请求处理失败:sessionId={}", sessionId, e);
return Result.fail("处理失败:" + e.getMessage());
}
}
/**
* 模型列表查询接口(用于前端选择模型)
*/
@PostMapping("/models/list")
public Result<Map<String, String>> listModels() {
Map<String, String> modelList = chatClientFactory.listAllModels();
return Result.success(modelList);
}
}
5. 模型总结器适配多模型
java
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
/**
* Agent总结器(适配多模型)
*/
@Service
public class AgentSummarizer {
@Resource
private ChatClientConfig.ChatClientFactory chatClientFactory;
/**
* 生成最终回答(支持指定模型)
*/
public String summarize(String query, String toolResults, String sessionId, ChatClient chatClient) {
// 构建提示词(结合工具结果+上下文)
String promptContent = buildSummaryPrompt(query, toolResults, sessionId);
// 调用指定模型生成回答
return chatClient.prompt(new Prompt(prompt))
.call().chatResponse()
.getResult()
.getOutput()
.getText();
}
/**
* 兼容原有方法(使用默认模型)
*/
public String summarize(String query, String toolResults, String sessionId) {
ChatClient defaultClient = chatClientFactory.getDefaultChatClient();
return summarize(query, toolResults, sessionId, defaultClient);
}
/**
* 构建总结提示词
*/
private String buildSummaryPrompt(String query, String toolResults, String sessionId) {
// 补充上下文、工具结果,构建完整提示词
return String.format("根据以下信息回答问题:\n工具结果:%s\n问题:%s\n要求:回答准确、简洁,符合中文表达习惯",
toolResults, query);
}
}
6. 核心参数实体类
java
import lombok.Data;
/**
* 问答请求参数实体
*/
@Data
public class AgentChatParam {
/** 用户提问内容 */
private String query;
/** 会话ID(用于上下文管理) */
private String sessionId;
/** 模型编码(如qwen2-7b/deepseek-r1) */
private String modelCode;
}
四、多模型切换完整链路(RAG集成)
markdown
1. 前端请求:传入query + sessionId + modelCode(如deepseek-r1)
2. 参数校验:验证必填参数,日志记录请求信息
3. Agent思考:决策调用知识库检索/直接回答
4. 工具执行:调用知识库检索,获取相关文本片段
5. 模型工厂:根据modelCode获取对应的ChatClient
6. 生成回答:调用指定模型,结合工具结果生成回答
7. 上下文保存:将问答记录存入Redis,支撑多轮对话
8. 返回结果:返回回答内容,前端展示
五、生产级进阶扩展(可选)
扩展1:模型路由策略
根据提问内容自动选择最优模型(无需前端指定):
java
/**
* 模型路由策略(示例)
*/
@Service
public class ModelRouter {
/**
* 根据提问内容自动选择模型
*/
public String routeModel(String query) {
// 技术相关问题→deepseek-r1
if (query.contains("代码") || query.contains("技术") || query.contains("接口")) {
return "deepseek-r1";
}
// 通用问题→qwen2-7b
return "qwen2-7b";
}
}
扩展2:模型健康检查
定时检查模型可用性,异常时自动切换:
java
/**
* 模型健康检查(示例)
*/
@Scheduled(fixedRate = 60000) // 每分钟检查一次
public void checkModelHealth() {
Map<String, OllamaChatModel> modelMap = ollamaChatModelMap();
for (Map.Entry<String, OllamaChatModel> entry : modelMap.entrySet()) {
String modelCode = entry.getKey();
try {
// 发送测试请求,检查模型是否可用
entry.getValue().call(new Prompt("hello")).getResult();
log.info("模型【{}】健康检查通过", modelCode);
} catch (Exception e) {
log.error("模型【{}】健康检查失败", modelCode, e);
// 标记模型不可用,路由时跳过
unhealthyModels.add(modelCode);
}
}
}
扩展3:模型调用限流
针对不同模型设置限流规则,避免资源耗尽:
java
// 使用Sentinel实现模型级限流
@SentinelResource(value = "chat_qwen2_7b", fallback = "defaultFallback")
public String chatWithQwen2(String query, String sessionId) {
// 调用qwen2-7b模型的逻辑
}
@SentinelResource(value = "chat_deepseek_r1", fallback = "defaultFallback")
public String chatWithDeepseek(String query, String sessionId) {
// 调用deepseek-r1模型的逻辑
}
六、关键注意事项(生产落地必看)
模型资源隔离 :不同模型的推理资源(CPU/GPU)需隔离,避免单个模型占用全部资源; 上下文兼容 :不同模型的上下文格式/长度限制不同,需适配(如Llama3的上下文长度为8192,Qwen2为4096); 参数调优 :每个模型的最优参数不同(如deepseek-r1的temperature=0.5更优),需单独调优; 监控告警 :监控各模型的调用量、响应时间、失败率,异常时及时告警; 版本管理:模型升级时(如qwen2-7b→qwen2-14b),需保留旧模型配置,支持灰度切换。