RAG核心升级|多LLM模型动态切换方案

RAG核心升级|多LLM模型动态切换方案

一、多模型切换核心价值

传统RAG架构绑定单一LLM模型,无法适配多样化的问答场景;而多模型动态切换方案 可根据场景/需求灵活选择最优模型,核心价值如下: 场景化适配:技术文档问答用deepseek-r1(代码/技术理解强),通用闲聊用qwen2-7b(中文流畅度高),精准匹配不同场景需求; 性能与成本平衡:简单问答用轻量模型(如7B),复杂推理用高性能模型(如8B/70B),降低整体推理成本; 容错与容灾:单个模型服务异常时,可快速切换到备用模型,提升系统可用性; 灵活扩展:新增模型仅需配置YAML,无需修改核心代码,支持无缝接入Llama3、Gemini等新模型; 个性化调优:不同模型独立配置参数(温度、最大生成长度),针对场景精细化调优,提升回答质量。

二、多模型切换 vs 单模型绑定(核心优势对比)

对比维度 单模型绑定(传统方案) 多模型动态切换(优化方案)
场景适配性 单一模型无法兼顾所有场景(如技术问答/通用闲聊) 按场景选择最优模型,精准匹配不同问答需求
参数调优 全局参数配置,无法针对场景精细化调优 每个模型独立配置参数(温度、max_tokens),灵活调优
系统可用性 模型服务异常则整体不可用 支持模型降级/切换,单个模型故障不影响整体服务
扩展成本 新增模型需修改代码、重新部署 仅需新增YAML配置,热扩展无需重启服务
资源利用效率 轻量场景浪费大模型资源,复杂场景轻量模型能力不足 按需分配模型资源,平衡性能与成本
迭代灵活性 模型升级需全量替换,风险高 可灰度验证新模型,逐步切换,降低迭代风险

三、完整实现方案

1. YAML配置(结构化+可扩展)

yaml 复制代码
spring:
  ai:
    ollama:
      # Ollama 服务基础地址(统一配置,避免硬编码)
      base-url: http://127.0.0.1:11434
      # 多Chat模型配置列表(核心:支持动态扩展)
      chats:
        # 模型1:通义千问(默认模型,中文通用场景)
        - code: qwen2-7b
          model: qwen2:7b-instruct
          options:
            temperature: 0.3        # 低随机性,保证回答精准
            max_tokens: 2000        # 适配通用问答长度
            connect_timeout: 60s    # 连接超时
            read_timeout: 5m        # 读取超时
          context:
            expire-minutes: 120     # 会话上下文过期时间
            max-message-count: 20   # 最大上下文消息数
        # 模型2:深度求索(技术/代码场景专用)
        - code: deepseek-r1
          model: deepseek-r1:8b
          options:
            temperature: 0.5        # 中等随机性,兼顾创意与精准
            max_tokens: 3000        # 适配长文本技术回答
            connect_timeout: 60s
            read_timeout: 5m
          context:
            expire-minutes: 120
            max-message-count: 20
        # 扩展:新增模型仅需添加此节点,无需修改代码
        # - code: llama3-8b
        #   model: llama3:8b-instruct
        #   options:
        #     temperature: 0.4
        #     max_tokens: 2500
        #   context:
        #     expire-minutes: 120
        #     max-message-count: 20

2. 多模型配置绑定类

java 复制代码
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;

import java.util.List;

/**
 * Ollama多模型配置绑定类【生产级】
 * 职责:绑定YAML中的多模型配置,提供类型安全的配置访问
 */
@Data
@Component
@ConfigurationProperties(prefix = "spring.ai.ollama")
public class OllamaMultiModelProperties {
    // Ollama基础服务地址(必填)
    private String baseUrl;
    // 多Chat模型配置列表(至少配置1个模型)
    private List<OllamaChatModelConfig> chats;

    /**
     * 初始化校验:确保核心配置非空
     */
    public void afterPropertiesSet() {
        Assert.hasText(baseUrl, "spring.ai.ollama.base-url 不能为空");
        Assert.notEmpty(chats, "spring.ai.ollama.chats 至少配置一个模型");
        // 校验每个模型的核心配置
        for (OllamaChatModelConfig config : chats) {
            Assert.hasText(config.getCode(), "模型编码(code)不能为空");
            Assert.hasText(config.getModel(), "模型名称(model)不能为空");
            Assert.notNull(config.getOptions(), "模型参数(options)不能为空");
        }
    }

    /**
     * 单个Chat模型的完整配置项
     */
    @Data
    public static class OllamaChatModelConfig {
        // 模型唯一编码(如qwen2-7b,用于前端/业务层指定模型)
        private String code;
        // Ollama中注册的模型名称(如qwen2:7b-instruct)
        private String model;
        // 模型推理参数(温度、最大生成长度等)
        private OllamaChatOptionsConfig options;
        // 会话上下文配置(过期时间、最大消息数)
        private OllamaChatContextConfig context;

        /**
         * 模型推理参数配置
         */
        @Data
        public static class OllamaChatOptionsConfig {
            // 温度:0-1,值越高回答越随机,越低越精准
            private Double temperature = 0.3; // 默认值
            // 最大生成长度:限制单次回答的token数
            private Integer maxTokens = 2000; // 默认值
            // 连接超时(如60s)
            private String connectTimeout = "60s"; // 默认值
            // 读取超时(如5m)
            private String readTimeout = "5m"; // 默认值
        }

        /**
         * 会话上下文配置
         */
        @Data
        public static class OllamaChatContextConfig {
            // 上下文过期时间(分钟)
            private Integer expireMinutes = 120; // 默认值
            // 最大上下文消息数(避免上下文过长)
            private Integer maxMessageCount = 20; // 默认值
        }
    }
}

3. 多模型工厂配置类

java 复制代码
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.ai.ollama.OllamaApi;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaOptions;
import org.springframework.ai.observation.ObservationRegistry;
import org.springframework.ai.tool.ToolCallingManager;
import org.springframework.ai.model.ModelManagementOptions;
import org.springframework.beans.factory.annotation.Autowired;
import javax.annotation.Resource;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
 * 多LLM模型配置类【生产级】
 * 职责:注册多模型实例、提供动态获取ChatClient的工厂,支撑模型切换
 */
@Configuration
@Slf4j
public class ChatClientConfig {

    @Resource
    private OllamaMultiModelProperties ollamaMultiModelProperties;

    // ========== 1. 注册默认Chat模型(兼容原有业务逻辑) ==========
    @Bean
    @Primary
    public OllamaChatModel defaultOllamaChatModel() {
        log.info("初始化默认Chat模型:qwen2-7b");
        return createOllamaChatModel("qwen2-7b");
    }

    // ========== 2. 注册多模型映射表(核心:存储所有模型实例) ==========
    @Bean
    public Map<String, OllamaChatModel> ollamaChatModelMap() {
        Map<String, OllamaChatModel> modelMap = new HashMap<>();
        // 遍历配置,初始化所有模型实例
        for (OllamaMultiModelProperties.OllamaChatModelConfig config : ollamaMultiModelProperties.getChats()) {
            String modelCode = config.getCode();
            try {
                OllamaChatModel chatModel = createOllamaChatModel(modelCode);
                modelMap.put(modelCode, chatModel);
                log.info("成功初始化模型:{},对应Ollama模型名:{}", modelCode, config.getModel());
            } catch (Exception e) {
                log.error("初始化模型【{}】失败,跳过该模型", modelCode, e);
                // 单个模型初始化失败不影响整体流程
                continue;
            }
        }
        // 校验模型映射表非空
        if (modelMap.isEmpty()) {
            throw new RuntimeException("未成功初始化任何LLM模型,请检查配置");
        }
        return modelMap;
    }

    // ========== 3. 注册ChatClient工厂(对外提供模型切换能力) ==========
    @Bean
    public ChatClientFactory chatClientFactory(Map<String, OllamaChatModel> ollamaChatModelMap) {
        return new ChatClientFactory(ollamaChatModelMap);
    }

    // ========== 4. 兼容原有ChatClient.Builder(避免业务代码改造) ==========
    @Bean
    public org.springframework.ai.chat.ChatClient.Builder defaultChatClientBuilder(OllamaChatModel defaultOllamaChatModel) {
        return org.springframework.ai.chat.ChatClient.builder(defaultOllamaChatModel);
    }

    // ========== 私有方法:根据模型编码创建OllamaChatModel实例 ==========
    private OllamaChatModel createOllamaChatModel(String modelCode) {
        // 1. 根据编码查找模型配置(优化:使用流式API,避免循环)
        Optional<OllamaMultiModelProperties.OllamaChatModelConfig> modelConfigOpt =
                ollamaMultiModelProperties.getChats().stream()
                        .filter(config -> modelCode.equals(config.getCode()))
                        .findFirst();
        // 兜底:使用第一个模型配置
        OllamaMultiModelProperties.OllamaChatModelConfig modelConfig = modelConfigOpt.orElseGet(() -> {
            log.warn("未找到模型【{}】的配置,使用第一个模型作为兜底", modelCode);
            return ollamaMultiModelProperties.getChats().get(0);
        });

        // 2. 构建Ollama API客户端(基础地址统一配置)
        OllamaApi ollamaApi = OllamaApi.builder()
                .baseUrl(ollamaMultiModelProperties.getBaseUrl())
                .build();

        // 3. 构建模型参数(动态读取配置,支持所有Ollama参数)
        OllamaOptions options = OllamaOptions.builder()
                .model(modelConfig.getModel())
                .temperature(modelConfig.getOptions().getTemperature())
                .build();
        options.setMaxTokens(modelConfig.getOptions().getMaxTokens());

        // 4. 固定配置(工具调用、观测、模型管理)
        ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();
        ModelManagementOptions modelManagementOptions = ModelManagementOptions.builder().build();
        ObservationRegistry observationRegistry = ObservationRegistry.create();

        // 5. 创建并返回模型实例
        return new OllamaChatModel(ollamaApi, options, toolCallingManager, observationRegistry, modelManagementOptions);
    }

    // ========== 核心内部类:ChatClient工厂(对外暴露模型切换能力) ==========
    public static class ChatClientFactory {
        private final Map<String, OllamaChatModel> chatModelMap;
        // 默认模型编码(可配置化)
        private static final String DEFAULT_MODEL_CODE = "qwen2-7b";

        public ChatClientFactory(Map<String, OllamaChatModel> chatModelMap) {
            this.chatModelMap = chatModelMap;
        }

        /**
         * 根据模型编码动态获取ChatClient(核心方法)
         * @param modelCode 模型编码(qwen2-7b/deepseek-r1)
         * @return 绑定对应模型的ChatClient
         */
        public org.springframework.ai.chat.ChatClient getChatClient(String modelCode) {
            // 空值处理:使用默认模型
            if (modelCode == null || modelCode.isBlank()) {
                log.info("模型编码为空,使用默认模型:{}", DEFAULT_MODEL_CODE);
                return getDefaultChatClient();
            }
            // 获取指定模型
            OllamaChatModel chatModel = chatModelMap.get(modelCode);
            if (chatModel == null) {
                log.warn("模型【{}】不存在,降级到默认模型:{}", modelCode, DEFAULT_MODEL_CODE);
                return getDefaultChatClient();
            }
            log.debug("使用模型【{}】创建ChatClient", modelCode);
            return org.springframework.ai.chat.ChatClient.builder(chatModel).build();
        }

        /**
         * 获取默认ChatClient(兼容原有逻辑)
         */
        public org.springframework.ai.chat.ChatClient getDefaultChatClient() {
            OllamaChatModel defaultModel = chatModelMap.get(DEFAULT_MODEL_CODE);
            return org.springframework.ai.chat.ChatClient.builder(defaultModel).build();
        }

        /**
         * 获取所有可用模型编码(用于前端展示/校验)
         */
        public Map<String, String> listAllModels() {
            Map<String, String> modelList = new HashMap<>();
            for (Map.Entry<String, OllamaChatModel> entry : chatModelMap.entrySet()) {
                String modelCode = entry.getKey();
                String modelName = entry.getValue().getDefaultOptions().getModel();
                modelList.put(modelCode, modelName);
            }
            return modelList;
        }
    }
}

4. 多模型调用核心业务代码(生产级优化)

java 复制代码
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.ChatMessage;
import org.springframework.ai.chat.prompt.Prompt;
import javax.annotation.Resource;
import java.util.List;

/**
 * 多模型问答核心Controller【生产级】
 * 职责:接收问答请求,根据模型编码切换LLM模型,返回回答结果
 */
@RestController
@RequestMapping("/api/rag")
@Slf4j
public class RagChatController {

    @Resource
    private ChatClientConfig.ChatClientFactory chatClientFactory;
    @Resource
    private RedisChatContextManager redisChatContextManager;
    @Resource
    private AgentThinker agentThinker;
    @Resource
    private AgentExecutor agentExecutor;
    @Resource
    private AgentSummarizer agentSummarizer;

    /**
     * 核心问答接口(支持多模型切换)
     * @param param 问答参数(包含query、sessionId、modelCode)
     */
    @PostMapping("/chat")
    public Result<String> chat(@RequestBody AgentChatParam param) {
        // 1. 参数校验
        String query = param.getQuery();
        String sessionId = param.getSessionId();
        String modelCode = param.getModelCode();
        if (query == null || query.isBlank()) {
            return Result.fail("提问内容不能为空");
        }
        if (sessionId == null || sessionId.isBlank()) {
            return Result.fail("会话ID不能为空");
        }
        log.info("处理问答请求:sessionId={}, query={}, 目标模型={}", sessionId, query, modelCode);

        try {
            // 2. Agent思考:决策调用工具/直接回答
            List<ThinkResult> thinkResults = agentThinker.think(query, sessionId);
            StringBuilder toolResults = new StringBuilder();
            for (ThinkResult thinkResult : thinkResults) {
                if ("DIRECT_ANSWER".equals(thinkResult.getAction())) {
                    toolResults.append(thinkResult.getContent());
                } else {
                    // 3. Agent执行:调用工具(如知识库检索)
                    ToolResult toolResult = agentExecutor.execute(thinkResult);
                    if (toolResult.isSuccess()) {
                        toolResults.append(toolResult.getContent());
                    } else {
                        log.error("工具调用失败:{}", toolResult.getErrorMsg());
                        return Result.fail("处理失败:" + toolResult.getErrorMsg());
                    }
                }
            }

            // 4. 动态获取ChatClient(根据模型编码切换)
            ChatClient chatClient = chatClientFactory.getChatClient(modelCode);
            // 5. Agent总结:调用指定模型生成最终回答
            String answer = agentSummarizer.summarize(query, toolResults.toString(), sessionId, chatClient);

            // 6. 保存会话上下文到Redis
            redisChatContextManager.updateChatContext(sessionId, ChatMessage.user(query));
            redisChatContextManager.updateChatContext(sessionId, ChatMessage.assistant(answer));

            log.info("问答请求处理完成:sessionId={}, 模型={}, 回答长度={}", sessionId, modelCode, answer.length());
            return Result.success(answer);
        } catch (Exception e) {
            log.error("问答请求处理失败:sessionId={}", sessionId, e);
            return Result.fail("处理失败:" + e.getMessage());
        }
    }

    /**
     * 模型列表查询接口(用于前端选择模型)
     */
    @PostMapping("/models/list")
    public Result<Map<String, String>> listModels() {
        Map<String, String> modelList = chatClientFactory.listAllModels();
        return Result.success(modelList);
    }
}

5. 模型总结器适配多模型

java 复制代码
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;

/**
 * Agent总结器(适配多模型)
 */
@Service
public class AgentSummarizer {

    @Resource
    private ChatClientConfig.ChatClientFactory chatClientFactory;

    /**
     * 生成最终回答(支持指定模型)
     */
    public String summarize(String query, String toolResults, String sessionId, ChatClient chatClient) {
        // 构建提示词(结合工具结果+上下文)
        String promptContent = buildSummaryPrompt(query, toolResults, sessionId);
        // 调用指定模型生成回答
        return chatClient.prompt(new Prompt(prompt))
                .call().chatResponse()
                .getResult()
                .getOutput()
                .getText();
    }

    /**
     * 兼容原有方法(使用默认模型)
     */
    public String summarize(String query, String toolResults, String sessionId) {
        ChatClient defaultClient = chatClientFactory.getDefaultChatClient();
        return summarize(query, toolResults, sessionId, defaultClient);
    }

    /**
     * 构建总结提示词
     */
    private String buildSummaryPrompt(String query, String toolResults, String sessionId) {
        // 补充上下文、工具结果,构建完整提示词
        return String.format("根据以下信息回答问题:\n工具结果:%s\n问题:%s\n要求:回答准确、简洁,符合中文表达习惯",
                toolResults, query);
    }
}

6. 核心参数实体类

java 复制代码
import lombok.Data;

/**
 * 问答请求参数实体
 */
@Data
public class AgentChatParam {
    /** 用户提问内容 */
    private String query;
    /** 会话ID(用于上下文管理) */
    private String sessionId;
    /** 模型编码(如qwen2-7b/deepseek-r1) */
    private String modelCode;
}

四、多模型切换完整链路(RAG集成)

markdown 复制代码
1. 前端请求:传入query + sessionId + modelCode(如deepseek-r1)
2. 参数校验:验证必填参数,日志记录请求信息
3. Agent思考:决策调用知识库检索/直接回答
4. 工具执行:调用知识库检索,获取相关文本片段
5. 模型工厂:根据modelCode获取对应的ChatClient
6. 生成回答:调用指定模型,结合工具结果生成回答
7. 上下文保存:将问答记录存入Redis,支撑多轮对话
8. 返回结果:返回回答内容,前端展示

五、生产级进阶扩展(可选)

扩展1:模型路由策略

根据提问内容自动选择最优模型(无需前端指定):

java 复制代码
/**
 * 模型路由策略(示例)
 */
@Service
public class ModelRouter {
    /**
     * 根据提问内容自动选择模型
     */
    public String routeModel(String query) {
        // 技术相关问题→deepseek-r1
        if (query.contains("代码") || query.contains("技术") || query.contains("接口")) {
            return "deepseek-r1";
        }
        // 通用问题→qwen2-7b
        return "qwen2-7b";
    }
}

扩展2:模型健康检查

定时检查模型可用性,异常时自动切换:

java 复制代码
/**
 * 模型健康检查(示例)
 */
@Scheduled(fixedRate = 60000) // 每分钟检查一次
public void checkModelHealth() {
    Map<String, OllamaChatModel> modelMap = ollamaChatModelMap();
    for (Map.Entry<String, OllamaChatModel> entry : modelMap.entrySet()) {
        String modelCode = entry.getKey();
        try {
            // 发送测试请求,检查模型是否可用
            entry.getValue().call(new Prompt("hello")).getResult();
            log.info("模型【{}】健康检查通过", modelCode);
        } catch (Exception e) {
            log.error("模型【{}】健康检查失败", modelCode, e);
            // 标记模型不可用,路由时跳过
            unhealthyModels.add(modelCode);
        }
    }
}

扩展3:模型调用限流

针对不同模型设置限流规则,避免资源耗尽:

java 复制代码
// 使用Sentinel实现模型级限流
@SentinelResource(value = "chat_qwen2_7b", fallback = "defaultFallback")
public String chatWithQwen2(String query, String sessionId) {
    // 调用qwen2-7b模型的逻辑
}

@SentinelResource(value = "chat_deepseek_r1", fallback = "defaultFallback")
public String chatWithDeepseek(String query, String sessionId) {
    // 调用deepseek-r1模型的逻辑
}

六、关键注意事项(生产落地必看)

模型资源隔离 :不同模型的推理资源(CPU/GPU)需隔离,避免单个模型占用全部资源; 上下文兼容 :不同模型的上下文格式/长度限制不同,需适配(如Llama3的上下文长度为8192,Qwen2为4096); 参数调优 :每个模型的最优参数不同(如deepseek-r1的temperature=0.5更优),需单独调优; 监控告警 :监控各模型的调用量、响应时间、失败率,异常时及时告警; 版本管理:模型升级时(如qwen2-7b→qwen2-14b),需保留旧模型配置,支持灰度切换。

相关推荐
yunni816 小时前
知识库 × AI写作:打通公文写作的“最后一公里”
大数据·人工智能
Baihai_IDP16 小时前
Andrej Karpathy:2025 年 LLM 领域的六项范式转变
人工智能·面试·llm
EntyIU16 小时前
自己实现mybatisplus的批量插入
java·后端
踩着两条虫16 小时前
VTJ.PRO「AI + 低代码」应用开发平台的后端模块系统
前端·人工智能·低代码
人工智能AI技术16 小时前
开源模型落地指南:DeepSeek微调实战,在垂直场景打造差异化竞争力
人工智能
一个会的不多的人16 小时前
人工智能基础篇:概念性名词浅谈(第二十二讲)
人工智能·制造·数字化转型
用户6174332731016 小时前
MySQL 表的类 Git 版本控制
后端
极新16 小时前
新看点/818AI创始人冷煜:AI落地,决胜“最后100米” | 2025极新AIGC峰会演讲实录
大数据·人工智能
环黄金线HHJX.16 小时前
《QuantumTuan ⇆ QT:Qt》
人工智能·qt·算法·编辑器·量子计算