04-Spring-AI多模型架构

Spring AI多模型架构:OpenAI/DeepSeek/MiniMax一键切换实战

一、前言

1.1 为什么需要多模型支持

在AI应用开发中,单一模型往往无法满足所有场景需求:

场景 推荐模型 原因
通用对话 GPT-4 / DeepSeek-V3 能力强、理解准确
代码生成 GPT-4 / Claude 3.5 代码质量高
中文优化 DeepSeek / MiniMax 中文理解更好
成本敏感 GPT-3.5 / DeepSeek-V2 价格便宜
离线部署 Ollama本地模型 数据隐私保护

多模型架构的核心价值

  • 成本优化:不同场景选择性价比最优的模型
  • 能力互补:利用各模型优势处理特定任务
  • 风险分散:避免单一供应商依赖
  • 用户体验:根据用户套餐提供差异化服务

1.2 Spring AI 简介

Spring AI 是 Spring 官方推出的 AI 应用开发框架,目标是简化 Java 开发者集成 LLM 的过程:

复制代码
Spring AI 核心特性:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. 模型无关的抽象层
   - ChatClient:统一聊天接口
   - EmbeddingClient:统一向量接口
   - ImageClient:统一图像接口

2. Prompt 管理
   - PromptTemplate:模板引擎
   - 支持变量替换和条件渲染

3. 向量存储
   - PgVector、Redis、Chroma等
   - 统一的Document和VectorStore接口

4. 函数调用
   - 支持模型调用本地Java方法
   - 实现Agent能力

5. 自动配置
   - Spring Boot Starter支持
   - 声明式配置

二、Spring AI 核心概念

2.1 ChatClient 与 ChatModel

java 复制代码
/**
 * Spring AI 核心接口关系
 */

// ChatClient - 面向应用的高级接口
public interface ChatClient {
    ChatClientRequestSpec prompt();  // 创建Prompt构建器
    String call(String message);     // 简单调用
    Flux<String> stream(String message);  // 流式调用
}

// ChatModel - 面向模型的底层接口
public interface ChatModel {
    ChatResponse call(Prompt prompt);     // 同步调用
    Flux<ChatResponse> stream(Prompt prompt);  // 流式调用
}

// 实际使用 - ChatClient更简洁
@Autowired
private ChatClient chatClient;

public String chat(String message) {
    return chatClient.prompt()
        .user(message)
        .call()
        .content();
}

2.2 Prompt 模板系统

java 复制代码
/**
 * Prompt 模板使用示例
 */

// 1. 简单模板
public String simplePrompt(String jobTitle, String resume) {
    return chatClient.prompt()
        .system("你是一位专业的" + jobTitle + "面试官")
        .user(u -> u.text("""
            请根据以下简历内容,生成5个面试问题:
            
            简历内容:
            {resume}
            """).param("resume", resume))
        .call()
        .content();
}

// 2. 结构化Prompt(推荐)
public InterviewResponse structuredPrompt(InterviewRequest request) {
    return chatClient.prompt()
        .system(systemPromptProvider.getInterviewSystemPrompt())
        .user(u -> u.text("""
            岗位类型:{jobType}
            工作经验:{experience}
            面试轮次:{round}
            
            请生成适合该候选人的面试问题。
            """)
            .param("jobType", request.getJobType())
            .param("experience", request.getExperience())
            .param("round", request.getRound()))
        .options(ChatOptions.builder()
            .temperature(0.7)
            .maxTokens(2000)
            .build())
        .call()
        .entity(InterviewResponse.class);  // 自动JSON解析
}

// 3. 多轮对话
public String multiTurnChat(List<Message> history, String newMessage) {
    ChatClient.ChatClientRequestSpec spec = chatClient.prompt();
    
    // 添加历史消息
    for (Message msg : history) {
        if (msg.getRole() == Role.USER) {
            spec.user(msg.getContent());
        } else {
            spec.system(msg.getContent());
        }
    }
    
    return spec.user(newMessage)
        .call()
        .content();
}

2.3 流式响应处理

java 复制代码
/**
 * 流式响应处理 - 实现打字机效果
 */
@RestController
@RequestMapping("/api/ai")
@Slf4j
public class StreamingController {
    
    @Autowired
    private ChatClient chatClient;
    
    /**
     * Server-Sent Events 流式输出
     */
    @GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> streamChat(@RequestParam String message) {
        return chatClient.prompt()
            .user(message)
            .stream()
            .content()
            .map(content -> ServerSentEvent.builder(content).build())
            .onErrorResume(e -> {
                log.error("Stream error: ", e);
                return Flux.just(ServerSentEvent.builder("Error: " + e.getMessage()).build());
            });
    }
    
    /**
     * WebFlux 流式输出(前端使用EventSource接收)
     */
    @PostMapping(value = "/interview/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> interviewStream(@RequestBody InterviewRequest request) {
        String prompt = buildInterviewPrompt(request);
        
        return chatClient.prompt()
            .system("你是一位专业的技术面试官")
            .user(prompt)
            .options(ChatOptions.builder()
                .temperature(0.8)
                .maxTokens(4000)
                .build())
            .stream()
            .content()
            .doOnNext(chunk -> log.debug("Received chunk: {}", chunk))
            .doOnComplete(() -> log.info("Stream completed"));
    }
}

三、多模型配置管理

3.1 ModelProvider 枚举定义

java 复制代码
/**
 * 模型提供商枚举
 * 支持10+主流模型提供商
 */
@Getter
@AllArgsConstructor
public enum ModelProvider {
    
    // OpenAI 系列
    OPENAI_GPT4("openai", "gpt-4-turbo-preview", "OpenAI GPT-4", true, 50),
    OPENAI_GPT35("openai", "gpt-3.5-turbo", "OpenAI GPT-3.5", true, 10),
    
    // DeepSeek 系列(深度求索)
    DEEPSEEK_V3("deepseek", "deepseek-chat", "DeepSeek V3", true, 8),
    DEEPSEEK_CODER("deepseek", "deepseek-coder", "DeepSeek Coder", true, 8),
    
    // MiniMax 系列(海螺AI)
    MINIMAX_ABAB6("minimax", "abab6-chat", "MiniMax abab6", true, 15),
    MINIMAX_ABAB5("minimax", "abab5.5-chat", "MiniMax abab5.5", true, 10),
    
    // Anthropic Claude
    CLAUDE_3_OPUS("anthropic", "claude-3-opus-20240229", "Claude 3 Opus", true, 80),
    CLAUDE_3_SONNET("anthropic", "claude-3-sonnet-20240229", "Claude 3 Sonnet", true, 25),
    
    // Google Gemini
    GEMINI_PRO("google", "gemini-pro", "Gemini Pro", true, 12),
    GEMINI_ULTRA("google", "gemini-ultra", "Gemini Ultra", true, 100),
    
    // 本地模型(Ollama)
    OLLAMA_LLAMA3("ollama", "llama3", "Llama 3 (Local)", false, 0),
    OLLAMA_QWEN("ollama", "qwen:14b", "Qwen 14B (Local)", false, 0),
    OLLAMA_MISTRAL("ollama", "mistral", "Mistral (Local)", false, 0);
    
    private final String provider;      // 提供商代码
    private final String modelName;     // 模型名称
    private final String displayName;   // 显示名称
    private final boolean isCloud;      // 是否为云端模型
    private final int pricePer1K;       // 每千Token价格(分)
    
    /**
     * 根据模型名称查找
     */
    public static Optional<ModelProvider> fromModelName(String modelName) {
        return Arrays.stream(values())
            .filter(p -> p.modelName.equals(modelName))
            .findFirst();
    }
    
    /**
     * 根据提供商获取所有模型
     */
    public static List<ModelProvider> byProvider(String provider) {
        return Arrays.stream(values())
            .filter(p -> p.provider.equals(provider))
            .collect(Collectors.toList());
    }
    
    /**
     * 获取所有可用模型
     */
    public static List<ModelProvider> availableModels() {
        return Arrays.stream(values())
            .filter(ModelProvider::isEnabled)
            .collect(Collectors.toList());
    }
    
    /**
     * 检查模型是否启用(从配置读取)
     */
    public boolean isEnabled() {
        // 实际实现从配置中心读取
        return true;
    }
}

3.2 多模型 YAML 配置

yaml 复制代码
# application.yml - AI模型配置
spring:
  ai:
    # OpenAI 配置
    openai:
      api-key: ${OPENAI_API_KEY}
      base-url: https://api.openai.com  # 可替换为代理地址
      chat:
        options:
          model: gpt-3.5-turbo
          temperature: 0.7
          max-tokens: 2000
    
    # DeepSeek 配置
    deepseek:
      api-key: ${DEEPSEEK_API_KEY}
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
          temperature: 0.7
          max-tokens: 4000
    
    # Ollama 本地模型配置
    ollama:
      base-url: http://localhost:11434
      chat:
        options:
          model: llama3
          temperature: 0.8

# 自定义AI配置
ai:
  model:
    # 默认模型
    default: OPENAI_GPT35
    
    # 场景模型映射
    scenarios:
      interview: OPENAI_GPT4        # 面试场景用GPT-4
      resume: DEEPSEEK_V3           # 简历优化用DeepSeek
      code: DEEPSEEK_CODER          # 代码生成用DeepSeek Coder
      chat: OPENAI_GPT35            # 普通对话用GPT-3.5
    
    # 用户套餐模型限制
    plan-limits:
      FREE:
        - OPENAI_GPT35
        - DEEPSEEK_V3
      PREMIUM:
        - OPENAI_GPT35
        - OPENAI_GPT4
        - DEEPSEEK_V3
        - DEEPSEEK_CODER
        - MINIMAX_ABAB5
      ENTERPRISE:
        - all  # 所有模型
    
    # 模型参数配置
    parameters:
      interview:
        temperature: 0.7
        max-tokens: 4000
        top-p: 0.9
      resume:
        temperature: 0.5
        max-tokens: 3000
      code:
        temperature: 0.3
        max-tokens: 4000

3.3 模型配置属性类

java 复制代码
/**
 * AI模型配置属性
 */
@Data
@Configuration
@ConfigurationProperties(prefix = "ai.model")
public class AiModelProperties {
    
    private String defaultModel;
    private Map<String, String> scenarios = new HashMap<>();
    private Map<String, List<String>> planLimits = new HashMap<>();
    private Map<String, ModelParameters> parameters = new HashMap<>();
    
    @Data
    public static class ModelParameters {
        private Double temperature;
        private Integer maxTokens;
        private Double topP;
        private Double frequencyPenalty;
        private Double presencePenalty;
    }
    
    /**
     * 获取场景对应的模型
     */
    public ModelProvider getModelForScenario(String scenario) {
        String modelName = scenarios.getOrDefault(scenario, defaultModel);
        return ModelProvider.valueOf(modelName);
    }
    
    /**
     * 获取用户套餐允许的模型列表
     */
    public List<ModelProvider> getModelsForPlan(String plan) {
        List<String> modelNames = planLimits.getOrDefault(plan, 
            planLimits.get("FREE"));
        
        if (modelNames.contains("all")) {
            return Arrays.asList(ModelProvider.values());
        }
        
        return modelNames.stream()
            .map(ModelProvider::valueOf)
            .collect(Collectors.toList());
    }
    
    /**
     * 获取场景参数
     */
    public ModelParameters getParameters(String scenario) {
        return parameters.getOrDefault(scenario, new ModelParameters());
    }
}

四、动态模型切换机制

4.1 ModelSwitchService 实现

java 复制代码
/**
 * 模型切换服务
 * 实现运行时动态切换模型,无需重启服务
 */
@Service
@Slf4j
public class ModelSwitchService {
    
    @Autowired
    private AiModelProperties modelProperties;
    
    @Autowired
    private ApplicationContext applicationContext;
    
    // 缓存ChatClient实例
    private final Map<ModelProvider, ChatClient> clientCache = new ConcurrentHashMap<>();
    
    // 当前默认模型
    private volatile ModelProvider currentDefault = ModelProvider.OPENAI_GPT35;
    
    /**
     * 获取指定模型的ChatClient
     */
    public ChatClient getChatClient(ModelProvider provider) {
        return clientCache.computeIfAbsent(provider, this::createChatClient);
    }
    
    /**
     * 获取当前默认模型的ChatClient
     */
    public ChatClient getDefaultChatClient() {
        return getChatClient(currentDefault);
    }
    
    /**
     * 根据场景获取ChatClient
     */
    public ChatClient getChatClientForScenario(String scenario) {
        ModelProvider provider = modelProperties.getModelForScenario(scenario);
        return getChatClient(provider);
    }
    
    /**
     * 创建ChatClient实例
     */
    private ChatClient createChatClient(ModelProvider provider) {
        log.info("Creating ChatClient for provider: {}", provider);
        
        ChatModel chatModel = createChatModel(provider);
        
        return ChatClient.builder(chatModel)
            .defaultSystem(systemText -> systemText.text(
                "你是AI面试助手,帮助用户准备技术面试。"))
            .defaultOptions(ChatOptions.builder()
                .temperature(0.7)
                .maxTokens(2000)
                .build())
            .build();
    }
    
    /**
     * 创建ChatModel实例
     */
    private ChatModel createChatModel(ModelProvider provider) {
        return switch (provider.getProvider()) {
            case "openai" -> createOpenAiModel(provider);
            case "deepseek" -> createDeepSeekModel(provider);
            case "minimax" -> createMiniMaxModel(provider);
            case "ollama" -> createOllamaModel(provider);
            default -> throw new IllegalArgumentException("Unknown provider: " + provider);
        };
    }
    
    private ChatModel createOpenAiModel(ModelProvider provider) {
        OpenAiApi openAiApi = new OpenAiApi(
            System.getenv("OPENAI_API_KEY"),
            "https://api.openai.com"
        );
        
        OpenAiChatModel chatModel = new OpenAiChatModel(
            openAiApi,
            OpenAiChatOptions.builder()
                .model(provider.getModelName())
                .temperature(0.7)
                .build()
        );
        
        return chatModel;
    }
    
    private ChatModel createDeepSeekModel(ModelProvider provider) {
        // DeepSeek使用OpenAI兼容接口
        OpenAiApi deepSeekApi = new OpenAiApi(
            System.getenv("DEEPSEEK_API_KEY"),
            "https://api.deepseek.com"
        );
        
        return new OpenAiChatModel(
            deepSeekApi,
            OpenAiChatOptions.builder()
                .model(provider.getModelName())
                .temperature(0.7)
                .build()
        );
    }
    
    private ChatModel createMiniMaxModel(ModelProvider provider) {
        // MiniMax需要自定义实现
        return new MiniMaxChatModel(
            System.getenv("MINIMAX_API_KEY"),
            provider.getModelName()
        );
    }
    
    private ChatModel createOllamaModel(ModelProvider provider) {
        OllamaApi ollamaApi = new OllamaApi("http://localhost:11434");
        
        return OllamaChatModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(OllamaOptions.builder()
                .model(provider.getModelName())
                .temperature(0.8)
                .build())
            .build();
    }
    
    /**
     * 切换默认模型(热切换)
     */
    public void switchDefaultModel(ModelProvider provider) {
        log.info("Switching default model from {} to {}", currentDefault, provider);
        this.currentDefault = provider;
    }
    
    /**
     * 获取当前默认模型
     */
    public ModelProvider getCurrentDefault() {
        return currentDefault;
    }
    
    /**
     * 获取所有可用模型
     */
    public List<ModelProvider> getAvailableModels() {
        return Arrays.stream(ModelProvider.values())
            .filter(this::isModelAvailable)
            .collect(Collectors.toList());
    }
    
    /**
     * 检查模型是否可用
     */
    public boolean isModelAvailable(ModelProvider provider) {
        try {
            ChatClient client = getChatClient(provider);
            // 发送测试请求
            String response = client.prompt()
                .user("Hi")
                .call()
                .content();
            return StringUtils.isNotBlank(response);
        } catch (Exception e) {
            log.warn("Model {} is not available: {}", provider, e.getMessage());
            return false;
        }
    }
    
    /**
     * 清除模型缓存(用于配置更新后)
     */
    public void clearCache() {
        clientCache.clear();
        log.info("ChatClient cache cleared");
    }
}

4.2 模型选择策略

java 复制代码
/**
 * 模型选择策略
 * 根据场景、成本、可用性自动选择最优模型
 */
@Service
@Slf4j
public class ModelSelectionStrategy {
    
    @Autowired
    private ModelSwitchService modelSwitchService;
    
    @Autowired
    private TokenUsageService tokenUsageService;
    
    /**
     * 智能模型选择
     */
    public ModelProvider selectModel(ModelSelectionContext context) {
        // 1. 检查用户是否有权限使用该模型
        if (!hasPermission(context.getUserPlan(), context.getPreferredModel())) {
            log.warn("User plan {} does not have access to model {}", 
                context.getUserPlan(), context.getPreferredModel());
            // 降级到套餐允许的最佳模型
            return getBestAvailableModel(context.getUserPlan(), context.getScenario());
        }
        
        // 2. 检查模型是否可用
        if (!modelSwitchService.isModelAvailable(context.getPreferredModel())) {
            log.warn("Preferred model {} is not available, falling back", 
                context.getPreferredModel());
            return getFallbackModel(context.getPreferredModel());
        }
        
        // 3. 成本优化:如果是简单任务,使用更便宜的模型
        if (context.isCostSensitive() && isSimpleTask(context)) {
            ModelProvider cheaperModel = getCheaperAlternative(context.getPreferredModel());
            if (cheaperModel != null && modelSwitchService.isModelAvailable(cheaperModel)) {
                log.info("Using cheaper model {} instead of {}", 
                    cheaperModel, context.getPreferredModel());
                return cheaperModel;
            }
        }
        
        return context.getPreferredModel();
    }
    
    /**
     * 检查用户是否有权限使用模型
     */
    private boolean hasPermission(String plan, ModelProvider model) {
        List<ModelProvider> allowedModels = modelSwitchService.getAvailableModels();
        return allowedModels.contains(model);
    }
    
    /**
     * 获取最佳可用模型
     */
    private ModelProvider getBestAvailableModel(String plan, String scenario) {
        List<ModelProvider> candidates = switch (scenario) {
            case "interview" -> List.of(
                ModelProvider.OPENAI_GPT4,
                ModelProvider.DEEPSEEK_V3,
                ModelProvider.OPENAI_GPT35
            );
            case "code" -> List.of(
                ModelProvider.DEEPSEEK_CODER,
                ModelProvider.OPENAI_GPT4,
                ModelProvider.CLAUDE_3_SONNET
            );
            default -> List.of(
                ModelProvider.OPENAI_GPT35,
                ModelProvider.DEEPSEEK_V3
            );
        };
        
        for (ModelProvider model : candidates) {
            if (hasPermission(plan, model) && modelSwitchService.isModelAvailable(model)) {
                return model;
            }
        }
        
        return ModelProvider.OPENAI_GPT35;  // 最终兜底
    }
    
    /**
     * 获取降级模型
     */
    private ModelProvider getFallbackModel(ModelProvider original) {
        return switch (original) {
            case OPENAI_GPT4 -> ModelProvider.DEEPSEEK_V3;
            case DEEPSEEK_V3 -> ModelProvider.OPENAI_GPT35;
            case CLAUDE_3_OPUS -> ModelProvider.CLAUDE_3_SONNET;
            default -> ModelProvider.OPENAI_GPT35;
        };
    }
    
    /**
     * 判断是否简单任务
     */
    private boolean isSimpleTask(ModelSelectionContext context) {
        String prompt = context.getPrompt();
        // 简单启发式判断
        return prompt.length() < 500 && 
               !prompt.contains("代码") &&
               !prompt.contains("算法") &&
               !prompt.contains("设计");
    }
    
    /**
     * 获取更便宜的替代模型
     */
    private ModelProvider getCheaperAlternative(ModelProvider original) {
        return switch (original) {
            case OPENAI_GPT4 -> ModelProvider.OPENAI_GPT35;
            case CLAUDE_3_OPUS -> ModelProvider.CLAUDE_3_SONNET;
            case DEEPSEEK_CODER -> ModelProvider.DEEPSEEK_V3;
            default -> null;
        };
    }
}

/**
 * 模型选择上下文
 */
@Data
@Builder
public class ModelSelectionContext {
    private ModelProvider preferredModel;  // 用户偏好的模型
    private String userPlan;               // 用户套餐
    private String scenario;               // 使用场景
    private String prompt;                 // 提示内容
    private boolean costSensitive;         // 是否成本敏感
}

五、统一 LLM 调用封装

5.1 LlmService 抽象层

java 复制代码
/**
 * 统一LLM调用服务
 * 封装所有模型调用细节,对外提供统一接口
 */
@Service
@Slf4j
public class LlmService {
    
    @Autowired
    private ModelSwitchService modelSwitchService;
    
    @Autowired
    private ModelSelectionStrategy modelSelectionStrategy;
    
    @Autowired
    private TokenUsageService tokenUsageService;
    
    @Autowired
    private PromptTemplateService promptTemplateService;
    
    /**
     * 同步调用 - 简单场景
     */
    public String chat(String message) {
        return chat(message, modelSwitchService.getCurrentDefault());
    }
    
    /**
     * 同步调用 - 指定模型
     */
    public String chat(String message, ModelProvider model) {
        ChatClient client = modelSwitchService.getChatClient(model);
        
        long startTime = System.currentTimeMillis();
        try {
            String response = client.prompt()
                .user(message)
                .call()
                .content();
            
            // 记录Token使用
            recordUsage(model, message, response, System.currentTimeMillis() - startTime);
            
            return response;
        } catch (Exception e) {
            log.error("Chat failed with model {}: {}", model, e.getMessage());
            throw new BusinessException(ResultCode.AI_MODEL_ERROR, e.getMessage());
        }
    }
    
    /**
     * 场景化调用 - 自动选择模型
     */
    public String chatWithScenario(String message, String scenario, String userPlan) {
        // 选择最优模型
        ModelSelectionContext context = ModelSelectionContext.builder()
            .preferredModel(modelSwitchService.getCurrentDefault())
            .userPlan(userPlan)
            .scenario(scenario)
            .prompt(message)
            .costSensitive(!"ENTERPRISE".equals(userPlan))
            .build();
        
        ModelProvider selectedModel = modelSelectionStrategy.selectModel(context);
        
        log.info("Selected model {} for scenario {}", selectedModel, scenario);
        
        // 获取场景参数
        AiModelProperties.ModelParameters params = 
            promptTemplateService.getParametersForScenario(scenario);
        
        ChatClient client = modelSwitchService.getChatClient(selectedModel);
        
        long startTime = System.currentTimeMillis();
        String response = client.prompt()
            .user(message)
            .options(ChatOptions.builder()
                .temperature(params.getTemperature())
                .maxTokens(params.getMaxTokens())
                .build())
            .call()
            .content();
        
        recordUsage(selectedModel, message, response, System.currentTimeMillis() - startTime);
        
        return response;
    }
    
    /**
     * 结构化输出 - 自动解析JSON
     */
    public <T> T chatForEntity(String message, Class<T> responseType) {
        return chatForEntity(message, responseType, 
            modelSwitchService.getCurrentDefault());
    }
    
    public <T> T chatForEntity(String message, Class<T> responseType, 
                               ModelProvider model) {
        ChatClient client = modelSwitchService.getChatClient(model);
        
        try {
            return client.prompt()
                .user(message)
                .call()
                .entity(responseType);
        } catch (Exception e) {
            log.error("Entity parsing failed: {}", e.getMessage());
            throw new BusinessException(ResultCode.AI_RESPONSE_ERROR, e.getMessage());
        }
    }
    
    /**
     * 流式调用
     */
    public Flux<String> streamChat(String message) {
        return streamChat(message, modelSwitchService.getCurrentDefault());
    }
    
    public Flux<String> streamChat(String message, ModelProvider model) {
        ChatClient client = modelSwitchService.getChatClient(model);
        
        return client.prompt()
            .user(message)
            .stream()
            .content()
            .doOnNext(chunk -> log.debug("Stream chunk: {}", chunk))
            .doOnError(e -> log.error("Stream error: ", e));
    }
    
    /**
     * 多轮对话
     */
    public String chatWithHistory(List<ChatMessage> history, String newMessage) {
        return chatWithHistory(history, newMessage, 
            modelSwitchService.getCurrentDefault());
    }
    
    public String chatWithHistory(List<ChatMessage> history, String newMessage,
                                  ModelProvider model) {
        ChatClient.ChatClientRequestSpec spec = 
            modelSwitchService.getChatClient(model).prompt();
        
        // 添加系统提示
        spec.system("你是一个专业的AI面试助手");
        
        // 添加历史消息
        for (ChatMessage msg : history) {
            switch (msg.getRole()) {
                case USER -> spec.user(msg.getContent());
                case ASSISTANT -> spec.system(msg.getContent());
            }
        }
        
        return spec.user(newMessage)
            .call()
            .content();
    }
    
    /**
     * 带重试的调用
     */
    public String chatWithRetry(String message, int maxRetries) {
        int attempts = 0;
        Exception lastException = null;
        
        while (attempts < maxRetries) {
            try {
                return chat(message);
            } catch (Exception e) {
                lastException = e;
                attempts++;
                log.warn("Chat attempt {} failed: {}", attempts, e.getMessage());
                
                if (attempts < maxRetries) {
                    // 指数退避
                    try {
                        Thread.sleep((long) Math.pow(2, attempts) * 1000);
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                    }
                }
            }
        }
        
        throw new BusinessException(ResultCode.AI_MODEL_ERROR, 
            "Failed after " + maxRetries + " attempts: " + lastException.getMessage());
    }
    
    /**
     * 记录Token使用
     */
    private void recordUsage(ModelProvider model, String prompt, String response, 
                            long latency) {
        try {
            // 估算Token数(实际应从API响应中获取)
            int promptTokens = estimateTokens(prompt);
            int completionTokens = estimateTokens(response);
            int totalTokens = promptTokens + completionTokens;
            
            TokenUsageRecord record = TokenUsageRecord.builder()
                .model(model)
                .promptTokens(promptTokens)
                .completionTokens(completionTokens)
                .totalTokens(totalTokens)
                .cost(calculateCost(model, totalTokens))
                .latencyMs(latency)
                .timestamp(LocalDateTime.now())
                .build();
            
            tokenUsageService.recordUsage(record);
        } catch (Exception e) {
            log.error("Failed to record token usage: {}", e.getMessage());
        }
    }
    
    /**
     * 估算Token数(简化版)
     */
    private int estimateTokens(String text) {
        // 粗略估算:中文约1.5字符/Token,英文约4字符/Token
        int chineseChars = (int) text.chars().filter(c -> c > 127).count();
        int englishChars = text.length() - chineseChars;
        return (int) (chineseChars / 1.5 + englishChars / 4);
    }
    
    /**
     * 计算成本
     */
    private BigDecimal calculateCost(ModelProvider model, int tokens) {
        // 价格单位:分/千Token
        BigDecimal pricePer1K = BigDecimal.valueOf(model.getPricePer1K());
        return pricePer1K.multiply(BigDecimal.valueOf(tokens))
            .divide(BigDecimal.valueOf(1000), 4, RoundingMode.HALF_UP);
    }
}

5.2 Prompt 模板服务

java 复制代码
/**
 * Prompt 模板服务
 * 管理各类场景的Prompt模板
 */
@Service
@Slf4j
public class PromptTemplateService {
    
    @Autowired
    private AiModelProperties modelProperties;
    
    private final Map<String, String> templates = new HashMap<>();
    
    @PostConstruct
    public void init() {
        // 初始化内置模板
        loadBuiltinTemplates();
        
        // 从文件加载自定义模板
        loadCustomTemplates();
    }
    
    private void loadBuiltinTemplates() {
        // 面试场景模板
        templates.put("interview-system", """
            你是一位经验丰富的技术面试官,正在面试{jobType}岗位的候选人。
            
            候选人背景:
            - 工作年限:{experience}
            - 面试轮次:{round}
            
            你的任务是:
            1. 根据候选人背景提出有针对性的技术问题
            2. 评估候选人的回答质量
            3. 给出建设性的反馈和改进建议
            
            请保持专业、友善的态度,问题难度要适中。
            """);
        
        // 简历优化模板
        templates.put("resume-optimize", """
            你是一位资深HR和职业规划师,擅长简历优化。
            
            请分析以下简历,并给出优化建议:
            
            目标岗位JD:
            {jobDescription}
            
            简历内容:
            {resumeContent}
            
            请从以下维度分析:
            1. 简历与岗位的匹配度评分(0-100)
            2. 关键词匹配情况
            3. 具体优化建议
            4. 修改后的简历内容
            """);
        
        // 代码评审模板
        templates.put("code-review", """
            你是一位资深软件工程师,请对以下代码进行评审。
            
            编程语言:{language}
            
            代码:
            ```{language}
            {code}
            ```
            
            请从以下方面给出评审意见:
            1. 代码质量和可读性
            2. 潜在的问题和Bug
            3. 性能优化建议
            4. 最佳实践建议
            """);
    }
    
    private void loadCustomTemplates() {
        // 从配置文件或数据库加载
        // templates.putAll(loadFromDatabase());
    }
    
    /**
     * 渲染模板
     */
    public String renderTemplate(String templateKey, Map<String, Object> variables) {
        String template = templates.get(templateKey);
        if (template == null) {
            throw new BusinessException("Template not found: " + templateKey);
        }
        
        // 简单变量替换
        String result = template;
        for (Map.Entry<String, Object> entry : variables.entrySet()) {
            result = result.replace("{" + entry.getKey() + "}", 
                String.valueOf(entry.getValue()));
        }
        
        return result;
    }
    
    /**
     * 获取场景参数
     */
    public AiModelProperties.ModelParameters getParametersForScenario(String scenario) {
        AiModelProperties.ModelParameters params = 
            modelProperties.getParameters().get(scenario);
        
        if (params == null) {
            // 返回默认参数
            params = new AiModelProperties.ModelParameters();
            params.setTemperature(0.7);
            params.setMaxTokens(2000);
        }
        
        return params;
    }
    
    /**
     * 获取系统Prompt
     */
    public String getSystemPrompt(String scenario) {
        return templates.getOrDefault(scenario + "-system", 
            "你是一个专业的AI助手。");
    }
}

六、Token 计费与配额管理

6.1 Token 使用记录

java 复制代码
/**
 * Token使用记录实体
 */
@Entity
@Table(name = "token_usage_records")
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class TokenUsageRecord extends BaseEntity {
    
    @Column(name = "user_id")
    private Long userId;
    
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    private ModelProvider model;
    
    @Column(name = "prompt_tokens")
    private Integer promptTokens;
    
    @Column(name = "completion_tokens")
    private Integer completionTokens;
    
    @Column(name = "total_tokens")
    private Integer totalTokens;
    
    @Column(name = "cost", precision = 10, scale = 4)
    private BigDecimal cost;  // 成本(元)
    
    @Column(name = "latency_ms")
    private Long latencyMs;
    
    @Column(name = "scenario")
    private String scenario;
    
    @Column(name = "request_id")
    private String requestId;
    
    @Column(name = "status")
    @Enumerated(EnumType.STRING)
    private UsageStatus status;
    
    public enum UsageStatus {
        SUCCESS, FAILED, CACHED
    }
}

6.2 Token 使用服务

java 复制代码
/**
 * Token使用记录服务
 */
@Service
@Slf4j
public class TokenUsageService {
    
    @Autowired
    private TokenUsageRecordRepository usageRepository;
    
    @Autowired
    private StringRedisTemplate redisTemplate;
    
    private static final String USER_QUOTA_KEY = "user:quota:";
    private static final String USER_USAGE_KEY = "user:usage:";
    
    /**
     * 记录Token使用
     */
    @Async
    public void recordUsage(TokenUsageRecord record) {
        try {
            usageRepository.save(record);
            
            // 更新Redis使用统计
            updateUsageStats(record);
            
            log.debug("Token usage recorded: {} tokens, cost: {}", 
                record.getTotalTokens(), record.getCost());
        } catch (Exception e) {
            log.error("Failed to record token usage: {}", e.getMessage());
        }
    }
    
    /**
     * 检查用户配额
     */
    public boolean checkQuota(Long userId, UserPlan plan) {
        String key = USER_USAGE_KEY + userId + ":" + getCurrentMonth();
        String usedStr = redisTemplate.opsForValue().get(key);
        int used = usedStr != null ? Integer.parseInt(usedStr) : 0;
        
        return used < plan.getQuota();
    }
    
    /**
     * 扣除配额
     */
    public void deductQuota(Long userId, int tokens) {
        String key = USER_USAGE_KEY + userId + ":" + getCurrentMonth();
        redisTemplate.opsForValue().increment(key, tokens);
        
        // 设置过期时间为月底
        LocalDateTime endOfMonth = LocalDateTime.now()
            .with(TemporalAdjusters.lastDayOfMonth())
            .with(LocalTime.MAX);
        long seconds = Duration.between(LocalDateTime.now(), endOfMonth).getSeconds();
        redisTemplate.expire(key, seconds, TimeUnit.SECONDS);
    }
    
    /**
     * 获取用户使用统计
     */
    public UsageStats getUsageStats(Long userId, LocalDate startDate, LocalDate endDate) {
        List<TokenUsageRecord> records = usageRepository
            .findByUserIdAndCreatedAtBetween(userId, startDate.atStartOfDay(), 
                endDate.atTime(LocalTime.MAX));
        
        int totalTokens = records.stream()
            .mapToInt(TokenUsageRecord::getTotalTokens)
            .sum();
        
        BigDecimal totalCost = records.stream()
            .map(TokenUsageRecord::getCost)
            .reduce(BigDecimal.ZERO, BigDecimal::add);
        
        Map<ModelProvider, Integer> modelUsage = records.stream()
            .collect(Collectors.groupingBy(
                TokenUsageRecord::getModel,
                Collectors.summingInt(TokenUsageRecord::getTotalTokens)
            ));
        
        return UsageStats.builder()
            .totalTokens(totalTokens)
            .totalCost(totalCost)
            .requestCount(records.size())
            .modelUsage(modelUsage)
            .averageLatency(records.stream()
                .mapToLong(TokenUsageRecord::getLatencyMs)
                .average()
                .orElse(0))
            .build();
    }
    
    /**
     * 获取系统整体使用统计
     */
    public SystemUsageStats getSystemStats(LocalDate date) {
        LocalDateTime start = date.atStartOfDay();
        LocalDateTime end = date.atTime(LocalTime.MAX);
        
        List<TokenUsageRecord> records = usageRepository
            .findByCreatedAtBetween(start, end);
        
        return SystemUsageStats.builder()
            .date(date)
            .totalRequests(records.size())
            .totalTokens(records.stream()
                .mapToInt(TokenUsageRecord::getTotalTokens)
                .sum())
            .totalCost(records.stream()
                .map(TokenUsageRecord::getCost)
                .reduce(BigDecimal.ZERO, BigDecimal::add))
            .modelDistribution(records.stream()
                .collect(Collectors.groupingBy(
                    TokenUsageRecord::getModel,
                    Collectors.counting()
                )))
            .build();
    }
    
    private void updateUsageStats(TokenUsageRecord record) {
        if (record.getUserId() == null) return;
        
        String key = USER_USAGE_KEY + record.getUserId() + ":" + getCurrentMonth();
        redisTemplate.opsForValue().increment(key, record.getTotalTokens());
    }
    
    private String getCurrentMonth() {
        return LocalDate.now().format(DateTimeFormatter.ofPattern("yyyy-MM"));
    }
}

6.3 配额限制切面

java 复制代码
/**
 * 配额限制切面
 * 在AI调用前检查用户配额
 */
@Aspect
@Component
@Slf4j
public class QuotaLimitAspect {
    
    @Autowired
    private TokenUsageService tokenUsageService;
    
    @Autowired
    private UserService userService;
    
    @Around("@annotation(checkQuota)")
    public Object around(ProceedingJoinPoint point, CheckQuota checkQuota) 
            throws Throwable {
        
        // 获取当前用户
        Long userId = getCurrentUserId();
        if (userId == null) {
            return point.proceed();
        }
        
        User user = userService.findById(userId);
        
        // 检查配额
        if (!tokenUsageService.checkQuota(userId, user.getPlan())) {
            log.warn("User {} quota exceeded", userId);
            throw new BusinessException(ResultCode.INSUFFICIENT_BALANCE, 
                "本月AI调用额度已用完,请升级套餐");
        }
        
        return point.proceed();
    }
    
    private Long getCurrentUserId() {
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (auth != null && auth.getPrincipal() instanceof UserDetails) {
            // 从UserDetails获取用户ID
            return ((UserDetails) auth.getPrincipal()).getId();
        }
        return null;
    }
}

七、代码示例

7.1 ModelSwitchController

java 复制代码
/**
 * 模型切换控制器
 * 提供模型管理和切换接口
 */
@RestController
@RequestMapping("/api/admin/ai")
@PreAuthorize("hasRole('ADMIN')")
@Slf4j
public class ModelSwitchController {
    
    @Autowired
    private ModelSwitchService modelSwitchService;
    
    @Autowired
    private LlmService llmService;
    
    /**
     * 获取所有可用模型
     */
    @GetMapping("/models")
    public Result<List<ModelInfo>> getAvailableModels() {
        List<ModelInfo> models = modelSwitchService.getAvailableModels().stream()
            .map(this::toModelInfo)
            .collect(Collectors.toList());
        return Result.success(models);
    }
    
    /**
     * 切换默认模型
     */
    @PostMapping("/models/default")
    public Result<Void> switchDefaultModel(@RequestBody @Valid SwitchModelRequest request) {
        ModelProvider provider = ModelProvider.valueOf(request.getModel());
        modelSwitchService.switchDefaultModel(provider);
        log.info("Default model switched to {} by admin", provider);
        return Result.success("模型切换成功");
    }
    
    /**
     * 获取当前默认模型
     */
    @GetMapping("/models/default")
    public Result<ModelInfo> getCurrentDefaultModel() {
        ModelProvider current = modelSwitchService.getCurrentDefault();
        return Result.success(toModelInfo(current));
    }
    
    /**
     * 测试模型可用性
     */
    @PostMapping("/models/{model}/test")
    public Result<ModelTestResult> testModel(@PathVariable String model) {
        ModelProvider provider = ModelProvider.valueOf(model);
        
        long startTime = System.currentTimeMillis();
        boolean available = modelSwitchService.isModelAvailable(provider);
        long latency = System.currentTimeMillis() - startTime;
        
        ModelTestResult result = ModelTestResult.builder()
            .model(model)
            .available(available)
            .latencyMs(latency)
            .build();
        
        return Result.success(result);
    }
    
    /**
     * 清除模型缓存
     */
    @PostMapping("/cache/clear")
    public Result<Void> clearCache() {
        modelSwitchService.clearCache();
        return Result.success("缓存已清除");
    }
    
    /**
     * 获取Token使用统计
     */
    @GetMapping("/usage/stats")
    public Result<UsageStatsResponse> getUsageStats(
            @RequestParam @DateTimeFormat(iso = DateTimeFormat.ISO.DATE) LocalDate startDate,
            @RequestParam @DateTimeFormat(iso = DateTimeFormat.ISO.DATE) LocalDate endDate) {
        
        // 获取系统整体统计
        SystemUsageStats stats = tokenUsageService.getSystemStats(LocalDate.now());
        
        UsageStatsResponse response = UsageStatsResponse.builder()
            .todayRequests(stats.getTotalRequests())
            .todayTokens(stats.getTotalTokens())
            .todayCost(stats.getTotalCost())
            .modelDistribution(stats.getModelDistribution())
            .build();
        
        return Result.success(response);
    }
    
    private ModelInfo toModelInfo(ModelProvider provider) {
        return ModelInfo.builder()
            .code(provider.name())
            .name(provider.getDisplayName())
            .provider(provider.getProvider())
            .modelName(provider.getModelName())
            .pricePer1K(provider.getPricePer1K())
            .isCloud(provider.isCloud())
            .isDefault(provider == modelSwitchService.getCurrentDefault())
            .build();
    }
}

7.2 面试服务中使用多模型

java 复制代码
/**
 * 面试服务 - 多模型调用示例
 */
@Service
@Slf4j
public class InterviewService {
    
    @Autowired
    private LlmService llmService;
    
    @Autowired
    private PromptTemplateService promptTemplateService;
    
    @Autowired
    private InterviewSessionRepository sessionRepository;
    
    /**
     * 开始面试 - 根据用户套餐选择模型
     */
    @CheckQuota
    public InterviewResponse startInterview(StartInterviewRequest request, User user) {
        // 根据场景选择模型
        String scenario = "interview";
        
        // 渲染Prompt
        String prompt = promptTemplateService.renderTemplate("interview-system", 
            Map.of(
                "jobType", request.getJobType(),
                "experience", request.getExperience(),
                "round", request.getRound()
            ));
        
        // 调用AI生成第一个问题
        String response = llmService.chatWithScenario(prompt, scenario, 
            user.getPlan().name());
        
        // 创建会话
        InterviewSession session = InterviewSession.builder()
            .userId(user.getId())
            .jobType(request.getJobType())
            .status(InterviewStatus.IN_PROGRESS)
            .build();
        sessionRepository.save(session);
        
        return InterviewResponse.builder()
            .sessionId(session.getId())
            .question(response)
            .build();
    }
    
    /**
     * 回答面试问题 - 流式响应
     */
    public Flux<String> answerQuestionStream(Long sessionId, String answer, User user) {
        InterviewSession session = sessionRepository.findById(sessionId)
            .orElseThrow(() -> new BusinessException(ResultCode.INTERVIEW_NOT_FOUND));
        
        // 构建带上下文的Prompt
        String prompt = buildContextualPrompt(session, answer);
        
        // 流式调用
        return llmService.streamChat(prompt);
    }
    
    /**
     * 评估面试表现 - 使用更强的模型
     */
    public InterviewEvaluation evaluateInterview(Long sessionId, User user) {
        InterviewSession session = sessionRepository.findById(sessionId)
            .orElseThrow(() -> new BusinessException(ResultCode.INTERVIEW_NOT_FOUND));
        
        // 使用GPT-4进行评估
        String prompt = buildEvaluationPrompt(session);
        
        InterviewEvaluationResult result = llmService.chatForEntity(
            prompt, 
            InterviewEvaluationResult.class,
            ModelProvider.OPENAI_GPT4
        );
        
        return InterviewEvaluation.builder()
            .sessionId(sessionId)
            .overallScore(result.getOverallScore())
            .technicalScore(result.getTechnicalScore())
            .communicationScore(result.getCommunicationScore())
            .feedback(result.getFeedback())
            .build();
    }
    
    private String buildContextualPrompt(InterviewSession session, String answer) {
        // 获取历史对话
        List<ChatMessage> history = session.getMessages();
        
        StringBuilder prompt = new StringBuilder();
        prompt.append("面试岗位:").append(session.getJobType()).append("\n\n");
        prompt.append("历史对话:\n");
        
        for (ChatMessage msg : history) {
            prompt.append(msg.getRole() == Role.USER ? "候选人:" : "面试官:")
                  .append(msg.getContent())
                  .append("\n");
        }
        
        prompt.append("候选人最新回答:").append(answer).append("\n\n");
        prompt.append("请作为面试官给出回应,可以是追问或评价。");
        
        return prompt.toString();
    }
    
    private String buildEvaluationPrompt(InterviewSession session) {
        return """
            请对以下技术面试进行全面评估:
            
            岗位类型:%s
            
            面试对话记录:
            %s
            
            请从以下维度给出评分(0-100)和详细反馈:
            1. 整体表现
            2. 技术能力
            3. 沟通表达
            4. 改进建议
            
            请以JSON格式返回结果。
            """.formatted(session.getJobType(), formatConversation(session.getMessages()));
    }
}

八、总结

8.1 本文要点

本文详细介绍了Spring AI多模型架构的完整实现方案:

  1. Spring AI核心概念:ChatClient、Prompt模板、流式响应
  2. 多模型配置管理:ModelProvider枚举、YAML配置、属性类
  3. 动态模型切换:ModelSwitchService实现运行时热切换
  4. 统一LLM调用:LlmService封装所有模型调用细节
  5. Token计费管理:使用记录、配额检查、成本统计
  6. 代码示例:ModelSwitchController、InterviewService实战

8.2 架构优势

  • 灵活性:运行时切换模型,无需重启服务
  • 成本优化:根据场景和用户套餐智能选择模型
  • 可扩展性:新增模型只需添加枚举和配置
  • 可观测性:完整的Token使用统计和成本分析

8.3 后续优化方向

  • 模型缓存:响应结果缓存,减少重复调用
  • A/B测试:不同模型效果对比
  • 智能路由:基于历史数据自动选择最优模型
  • 模型微调:针对特定场景微调本地模型

参考资料

  1. Spring AI Documentation
  2. OpenAI API Reference
  3. DeepSeek API Documentation
  4. MiniMax API Documentation
  5. Ollama Documentation

本系列文章到此结束,感谢阅读!如果对你有帮助,欢迎点赞、收藏、评论交流!

相关推荐
冬奇Lab8 小时前
Agent 系列(23):Web Agent——让 Agent 真正浏览网页
人工智能·llm·agent
冬奇Lab8 小时前
每日一个开源项目(第135篇):codebase-memory-mcp - 给 AI Agent 一张代码库的知识图谱
人工智能·开源·llm
candyTong8 小时前
RTK 技术原理:一次典型会话里,80% 上下文是怎么省下来的
javascript·后端·架构
IT_陈寒11 小时前
JavaScript的闭包把我坑惨了,说好的内存会自动回收呢?
前端·人工智能·后端
唐某人丶13 小时前
从画架构图开始:架构分析与进阶指南
架构
jooloo15 小时前
Codex 间歇性 400 之谜:一条对话里,它为什么有时候用 chat/completions,有时候切到 responses?
人工智能
用户51914958484515 小时前
OpenSSL PKCS#12 PBMAC1 堆栈缓冲区溢出漏洞 (CVE-2025-11187) 分析与验证
人工智能·aigc
用户51914958484516 小时前
HP Sound Research SECOMNService 权限提升漏洞利用工具
人工智能·aigc