04-Spring-AI多模型架构

Spring AI多模型架构:OpenAI/DeepSeek/MiniMax一键切换实战

一、前言

1.1 为什么需要多模型支持

在AI应用开发中,单一模型往往无法满足所有场景需求:

场景 推荐模型 原因
通用对话 GPT-4 / DeepSeek-V3 能力强、理解准确
代码生成 GPT-4 / Claude 3.5 代码质量高
中文优化 DeepSeek / MiniMax 中文理解更好
成本敏感 GPT-3.5 / DeepSeek-V2 价格便宜
离线部署 Ollama本地模型 数据隐私保护

多模型架构的核心价值

  • 成本优化:不同场景选择性价比最优的模型
  • 能力互补:利用各模型优势处理特定任务
  • 风险分散:避免单一供应商依赖
  • 用户体验:根据用户套餐提供差异化服务

1.2 Spring AI 简介

Spring AI 是 Spring 官方推出的 AI 应用开发框架,目标是简化 Java 开发者集成 LLM 的过程:

复制代码
Spring AI 核心特性:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. 模型无关的抽象层
   - ChatClient:统一聊天接口
   - EmbeddingClient:统一向量接口
   - ImageClient:统一图像接口

2. Prompt 管理
   - PromptTemplate:模板引擎
   - 支持变量替换和条件渲染

3. 向量存储
   - PgVector、Redis、Chroma等
   - 统一的Document和VectorStore接口

4. 函数调用
   - 支持模型调用本地Java方法
   - 实现Agent能力

5. 自动配置
   - Spring Boot Starter支持
   - 声明式配置

二、Spring AI 核心概念

2.1 ChatClient 与 ChatModel

java 复制代码
/**
 * Spring AI 核心接口关系
 */

// ChatClient - 面向应用的高级接口
public interface ChatClient {
    ChatClientRequestSpec prompt();  // 创建Prompt构建器
    String call(String message);     // 简单调用
    Flux<String> stream(String message);  // 流式调用
}

// ChatModel - 面向模型的底层接口
public interface ChatModel {
    ChatResponse call(Prompt prompt);     // 同步调用
    Flux<ChatResponse> stream(Prompt prompt);  // 流式调用
}

// 实际使用 - ChatClient更简洁
@Autowired
private ChatClient chatClient;

public String chat(String message) {
    return chatClient.prompt()
        .user(message)
        .call()
        .content();
}

2.2 Prompt 模板系统

java 复制代码
/**
 * Prompt 模板使用示例
 */

// 1. 简单模板
public String simplePrompt(String jobTitle, String resume) {
    return chatClient.prompt()
        .system("你是一位专业的" + jobTitle + "面试官")
        .user(u -> u.text("""
            请根据以下简历内容,生成5个面试问题:
            
            简历内容:
            {resume}
            """).param("resume", resume))
        .call()
        .content();
}

// 2. 结构化Prompt(推荐)
public InterviewResponse structuredPrompt(InterviewRequest request) {
    return chatClient.prompt()
        .system(systemPromptProvider.getInterviewSystemPrompt())
        .user(u -> u.text("""
            岗位类型:{jobType}
            工作经验:{experience}
            面试轮次:{round}
            
            请生成适合该候选人的面试问题。
            """)
            .param("jobType", request.getJobType())
            .param("experience", request.getExperience())
            .param("round", request.getRound()))
        .options(ChatOptions.builder()
            .temperature(0.7)
            .maxTokens(2000)
            .build())
        .call()
        .entity(InterviewResponse.class);  // 自动JSON解析
}

// 3. 多轮对话
public String multiTurnChat(List<Message> history, String newMessage) {
    ChatClient.ChatClientRequestSpec spec = chatClient.prompt();
    
    // 添加历史消息
    for (Message msg : history) {
        if (msg.getRole() == Role.USER) {
            spec.user(msg.getContent());
        } else {
            spec.system(msg.getContent());
        }
    }
    
    return spec.user(newMessage)
        .call()
        .content();
}

2.3 流式响应处理

java 复制代码
/**
 * 流式响应处理 - 实现打字机效果
 */
@RestController
@RequestMapping("/api/ai")
@Slf4j
public class StreamingController {
    
    @Autowired
    private ChatClient chatClient;
    
    /**
     * Server-Sent Events 流式输出
     */
    @GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> streamChat(@RequestParam String message) {
        return chatClient.prompt()
            .user(message)
            .stream()
            .content()
            .map(content -> ServerSentEvent.builder(content).build())
            .onErrorResume(e -> {
                log.error("Stream error: ", e);
                return Flux.just(ServerSentEvent.builder("Error: " + e.getMessage()).build());
            });
    }
    
    /**
     * WebFlux 流式输出(前端使用EventSource接收)
     */
    @PostMapping(value = "/interview/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> interviewStream(@RequestBody InterviewRequest request) {
        String prompt = buildInterviewPrompt(request);
        
        return chatClient.prompt()
            .system("你是一位专业的技术面试官")
            .user(prompt)
            .options(ChatOptions.builder()
                .temperature(0.8)
                .maxTokens(4000)
                .build())
            .stream()
            .content()
            .doOnNext(chunk -> log.debug("Received chunk: {}", chunk))
            .doOnComplete(() -> log.info("Stream completed"));
    }
}

三、多模型配置管理

3.1 ModelProvider 枚举定义

java 复制代码
/**
 * 模型提供商枚举
 * 支持10+主流模型提供商
 */
@Getter
@AllArgsConstructor
public enum ModelProvider {
    
    // OpenAI 系列
    OPENAI_GPT4("openai", "gpt-4-turbo-preview", "OpenAI GPT-4", true, 50),
    OPENAI_GPT35("openai", "gpt-3.5-turbo", "OpenAI GPT-3.5", true, 10),
    
    // DeepSeek 系列(深度求索)
    DEEPSEEK_V3("deepseek", "deepseek-chat", "DeepSeek V3", true, 8),
    DEEPSEEK_CODER("deepseek", "deepseek-coder", "DeepSeek Coder", true, 8),
    
    // MiniMax 系列(海螺AI)
    MINIMAX_ABAB6("minimax", "abab6-chat", "MiniMax abab6", true, 15),
    MINIMAX_ABAB5("minimax", "abab5.5-chat", "MiniMax abab5.5", true, 10),
    
    // Anthropic Claude
    CLAUDE_3_OPUS("anthropic", "claude-3-opus-20240229", "Claude 3 Opus", true, 80),
    CLAUDE_3_SONNET("anthropic", "claude-3-sonnet-20240229", "Claude 3 Sonnet", true, 25),
    
    // Google Gemini
    GEMINI_PRO("google", "gemini-pro", "Gemini Pro", true, 12),
    GEMINI_ULTRA("google", "gemini-ultra", "Gemini Ultra", true, 100),
    
    // 本地模型(Ollama)
    OLLAMA_LLAMA3("ollama", "llama3", "Llama 3 (Local)", false, 0),
    OLLAMA_QWEN("ollama", "qwen:14b", "Qwen 14B (Local)", false, 0),
    OLLAMA_MISTRAL("ollama", "mistral", "Mistral (Local)", false, 0);
    
    private final String provider;      // 提供商代码
    private final String modelName;     // 模型名称
    private final String displayName;   // 显示名称
    private final boolean isCloud;      // 是否为云端模型
    private final int pricePer1K;       // 每千Token价格(分)
    
    /**
     * 根据模型名称查找
     */
    public static Optional<ModelProvider> fromModelName(String modelName) {
        return Arrays.stream(values())
            .filter(p -> p.modelName.equals(modelName))
            .findFirst();
    }
    
    /**
     * 根据提供商获取所有模型
     */
    public static List<ModelProvider> byProvider(String provider) {
        return Arrays.stream(values())
            .filter(p -> p.provider.equals(provider))
            .collect(Collectors.toList());
    }
    
    /**
     * 获取所有可用模型
     */
    public static List<ModelProvider> availableModels() {
        return Arrays.stream(values())
            .filter(ModelProvider::isEnabled)
            .collect(Collectors.toList());
    }
    
    /**
     * 检查模型是否启用(从配置读取)
     */
    public boolean isEnabled() {
        // 实际实现从配置中心读取
        return true;
    }
}

3.2 多模型 YAML 配置

yaml 复制代码
# application.yml - AI模型配置
spring:
  ai:
    # OpenAI 配置
    openai:
      api-key: ${OPENAI_API_KEY}
      base-url: https://api.openai.com  # 可替换为代理地址
      chat:
        options:
          model: gpt-3.5-turbo
          temperature: 0.7
          max-tokens: 2000
    
    # DeepSeek 配置
    deepseek:
      api-key: ${DEEPSEEK_API_KEY}
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
          temperature: 0.7
          max-tokens: 4000
    
    # Ollama 本地模型配置
    ollama:
      base-url: http://localhost:11434
      chat:
        options:
          model: llama3
          temperature: 0.8

# 自定义AI配置
ai:
  model:
    # 默认模型
    default: OPENAI_GPT35
    
    # 场景模型映射
    scenarios:
      interview: OPENAI_GPT4        # 面试场景用GPT-4
      resume: DEEPSEEK_V3           # 简历优化用DeepSeek
      code: DEEPSEEK_CODER          # 代码生成用DeepSeek Coder
      chat: OPENAI_GPT35            # 普通对话用GPT-3.5
    
    # 用户套餐模型限制
    plan-limits:
      FREE:
        - OPENAI_GPT35
        - DEEPSEEK_V3
      PREMIUM:
        - OPENAI_GPT35
        - OPENAI_GPT4
        - DEEPSEEK_V3
        - DEEPSEEK_CODER
        - MINIMAX_ABAB5
      ENTERPRISE:
        - all  # 所有模型
    
    # 模型参数配置
    parameters:
      interview:
        temperature: 0.7
        max-tokens: 4000
        top-p: 0.9
      resume:
        temperature: 0.5
        max-tokens: 3000
      code:
        temperature: 0.3
        max-tokens: 4000

3.3 模型配置属性类

java 复制代码
/**
 * AI模型配置属性
 */
@Data
@Configuration
@ConfigurationProperties(prefix = "ai.model")
public class AiModelProperties {
    
    private String defaultModel;
    private Map<String, String> scenarios = new HashMap<>();
    private Map<String, List<String>> planLimits = new HashMap<>();
    private Map<String, ModelParameters> parameters = new HashMap<>();
    
    @Data
    public static class ModelParameters {
        private Double temperature;
        private Integer maxTokens;
        private Double topP;
        private Double frequencyPenalty;
        private Double presencePenalty;
    }
    
    /**
     * 获取场景对应的模型
     */
    public ModelProvider getModelForScenario(String scenario) {
        String modelName = scenarios.getOrDefault(scenario, defaultModel);
        return ModelProvider.valueOf(modelName);
    }
    
    /**
     * 获取用户套餐允许的模型列表
     */
    public List<ModelProvider> getModelsForPlan(String plan) {
        List<String> modelNames = planLimits.getOrDefault(plan, 
            planLimits.get("FREE"));
        
        if (modelNames.contains("all")) {
            return Arrays.asList(ModelProvider.values());
        }
        
        return modelNames.stream()
            .map(ModelProvider::valueOf)
            .collect(Collectors.toList());
    }
    
    /**
     * 获取场景参数
     */
    public ModelParameters getParameters(String scenario) {
        return parameters.getOrDefault(scenario, new ModelParameters());
    }
}

四、动态模型切换机制

4.1 ModelSwitchService 实现

java 复制代码
/**
 * 模型切换服务
 * 实现运行时动态切换模型,无需重启服务
 */
@Service
@Slf4j
public class ModelSwitchService {
    
    @Autowired
    private AiModelProperties modelProperties;
    
    @Autowired
    private ApplicationContext applicationContext;
    
    // 缓存ChatClient实例
    private final Map<ModelProvider, ChatClient> clientCache = new ConcurrentHashMap<>();
    
    // 当前默认模型
    private volatile ModelProvider currentDefault = ModelProvider.OPENAI_GPT35;
    
    /**
     * 获取指定模型的ChatClient
     */
    public ChatClient getChatClient(ModelProvider provider) {
        return clientCache.computeIfAbsent(provider, this::createChatClient);
    }
    
    /**
     * 获取当前默认模型的ChatClient
     */
    public ChatClient getDefaultChatClient() {
        return getChatClient(currentDefault);
    }
    
    /**
     * 根据场景获取ChatClient
     */
    public ChatClient getChatClientForScenario(String scenario) {
        ModelProvider provider = modelProperties.getModelForScenario(scenario);
        return getChatClient(provider);
    }
    
    /**
     * 创建ChatClient实例
     */
    private ChatClient createChatClient(ModelProvider provider) {
        log.info("Creating ChatClient for provider: {}", provider);
        
        ChatModel chatModel = createChatModel(provider);
        
        return ChatClient.builder(chatModel)
            .defaultSystem(systemText -> systemText.text(
                "你是AI面试助手,帮助用户准备技术面试。"))
            .defaultOptions(ChatOptions.builder()
                .temperature(0.7)
                .maxTokens(2000)
                .build())
            .build();
    }
    
    /**
     * 创建ChatModel实例
     */
    private ChatModel createChatModel(ModelProvider provider) {
        return switch (provider.getProvider()) {
            case "openai" -> createOpenAiModel(provider);
            case "deepseek" -> createDeepSeekModel(provider);
            case "minimax" -> createMiniMaxModel(provider);
            case "ollama" -> createOllamaModel(provider);
            default -> throw new IllegalArgumentException("Unknown provider: " + provider);
        };
    }
    
    private ChatModel createOpenAiModel(ModelProvider provider) {
        OpenAiApi openAiApi = new OpenAiApi(
            System.getenv("OPENAI_API_KEY"),
            "https://api.openai.com"
        );
        
        OpenAiChatModel chatModel = new OpenAiChatModel(
            openAiApi,
            OpenAiChatOptions.builder()
                .model(provider.getModelName())
                .temperature(0.7)
                .build()
        );
        
        return chatModel;
    }
    
    private ChatModel createDeepSeekModel(ModelProvider provider) {
        // DeepSeek使用OpenAI兼容接口
        OpenAiApi deepSeekApi = new OpenAiApi(
            System.getenv("DEEPSEEK_API_KEY"),
            "https://api.deepseek.com"
        );
        
        return new OpenAiChatModel(
            deepSeekApi,
            OpenAiChatOptions.builder()
                .model(provider.getModelName())
                .temperature(0.7)
                .build()
        );
    }
    
    private ChatModel createMiniMaxModel(ModelProvider provider) {
        // MiniMax需要自定义实现
        return new MiniMaxChatModel(
            System.getenv("MINIMAX_API_KEY"),
            provider.getModelName()
        );
    }
    
    private ChatModel createOllamaModel(ModelProvider provider) {
        OllamaApi ollamaApi = new OllamaApi("http://localhost:11434");
        
        return OllamaChatModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(OllamaOptions.builder()
                .model(provider.getModelName())
                .temperature(0.8)
                .build())
            .build();
    }
    
    /**
     * 切换默认模型(热切换)
     */
    public void switchDefaultModel(ModelProvider provider) {
        log.info("Switching default model from {} to {}", currentDefault, provider);
        this.currentDefault = provider;
    }
    
    /**
     * 获取当前默认模型
     */
    public ModelProvider getCurrentDefault() {
        return currentDefault;
    }
    
    /**
     * 获取所有可用模型
     */
    public List<ModelProvider> getAvailableModels() {
        return Arrays.stream(ModelProvider.values())
            .filter(this::isModelAvailable)
            .collect(Collectors.toList());
    }
    
    /**
     * 检查模型是否可用
     */
    public boolean isModelAvailable(ModelProvider provider) {
        try {
            ChatClient client = getChatClient(provider);
            // 发送测试请求
            String response = client.prompt()
                .user("Hi")
                .call()
                .content();
            return StringUtils.isNotBlank(response);
        } catch (Exception e) {
            log.warn("Model {} is not available: {}", provider, e.getMessage());
            return false;
        }
    }
    
    /**
     * 清除模型缓存(用于配置更新后)
     */
    public void clearCache() {
        clientCache.clear();
        log.info("ChatClient cache cleared");
    }
}

4.2 模型选择策略

java 复制代码
/**
 * 模型选择策略
 * 根据场景、成本、可用性自动选择最优模型
 */
@Service
@Slf4j
public class ModelSelectionStrategy {
    
    @Autowired
    private ModelSwitchService modelSwitchService;
    
    @Autowired
    private TokenUsageService tokenUsageService;
    
    /**
     * 智能模型选择
     */
    public ModelProvider selectModel(ModelSelectionContext context) {
        // 1. 检查用户是否有权限使用该模型
        if (!hasPermission(context.getUserPlan(), context.getPreferredModel())) {
            log.warn("User plan {} does not have access to model {}", 
                context.getUserPlan(), context.getPreferredModel());
            // 降级到套餐允许的最佳模型
            return getBestAvailableModel(context.getUserPlan(), context.getScenario());
        }
        
        // 2. 检查模型是否可用
        if (!modelSwitchService.isModelAvailable(context.getPreferredModel())) {
            log.warn("Preferred model {} is not available, falling back", 
                context.getPreferredModel());
            return getFallbackModel(context.getPreferredModel());
        }
        
        // 3. 成本优化:如果是简单任务,使用更便宜的模型
        if (context.isCostSensitive() && isSimpleTask(context)) {
            ModelProvider cheaperModel = getCheaperAlternative(context.getPreferredModel());
            if (cheaperModel != null && modelSwitchService.isModelAvailable(cheaperModel)) {
                log.info("Using cheaper model {} instead of {}", 
                    cheaperModel, context.getPreferredModel());
                return cheaperModel;
            }
        }
        
        return context.getPreferredModel();
    }
    
    /**
     * 检查用户是否有权限使用模型
     */
    private boolean hasPermission(String plan, ModelProvider model) {
        List<ModelProvider> allowedModels = modelSwitchService.getAvailableModels();
        return allowedModels.contains(model);
    }
    
    /**
     * 获取最佳可用模型
     */
    private ModelProvider getBestAvailableModel(String plan, String scenario) {
        List<ModelProvider> candidates = switch (scenario) {
            case "interview" -> List.of(
                ModelProvider.OPENAI_GPT4,
                ModelProvider.DEEPSEEK_V3,
                ModelProvider.OPENAI_GPT35
            );
            case "code" -> List.of(
                ModelProvider.DEEPSEEK_CODER,
                ModelProvider.OPENAI_GPT4,
                ModelProvider.CLAUDE_3_SONNET
            );
            default -> List.of(
                ModelProvider.OPENAI_GPT35,
                ModelProvider.DEEPSEEK_V3
            );
        };
        
        for (ModelProvider model : candidates) {
            if (hasPermission(plan, model) && modelSwitchService.isModelAvailable(model)) {
                return model;
            }
        }
        
        return ModelProvider.OPENAI_GPT35;  // 最终兜底
    }
    
    /**
     * 获取降级模型
     */
    private ModelProvider getFallbackModel(ModelProvider original) {
        return switch (original) {
            case OPENAI_GPT4 -> ModelProvider.DEEPSEEK_V3;
            case DEEPSEEK_V3 -> ModelProvider.OPENAI_GPT35;
            case CLAUDE_3_OPUS -> ModelProvider.CLAUDE_3_SONNET;
            default -> ModelProvider.OPENAI_GPT35;
        };
    }
    
    /**
     * 判断是否简单任务
     */
    private boolean isSimpleTask(ModelSelectionContext context) {
        String prompt = context.getPrompt();
        // 简单启发式判断
        return prompt.length() < 500 && 
               !prompt.contains("代码") &&
               !prompt.contains("算法") &&
               !prompt.contains("设计");
    }
    
    /**
     * 获取更便宜的替代模型
     */
    private ModelProvider getCheaperAlternative(ModelProvider original) {
        return switch (original) {
            case OPENAI_GPT4 -> ModelProvider.OPENAI_GPT35;
            case CLAUDE_3_OPUS -> ModelProvider.CLAUDE_3_SONNET;
            case DEEPSEEK_CODER -> ModelProvider.DEEPSEEK_V3;
            default -> null;
        };
    }
}

/**
 * 模型选择上下文
 */
@Data
@Builder
public class ModelSelectionContext {
    private ModelProvider preferredModel;  // 用户偏好的模型
    private String userPlan;               // 用户套餐
    private String scenario;               // 使用场景
    private String prompt;                 // 提示内容
    private boolean costSensitive;         // 是否成本敏感
}

五、统一 LLM 调用封装

5.1 LlmService 抽象层

java 复制代码
/**
 * 统一LLM调用服务
 * 封装所有模型调用细节,对外提供统一接口
 */
@Service
@Slf4j
public class LlmService {
    
    @Autowired
    private ModelSwitchService modelSwitchService;
    
    @Autowired
    private ModelSelectionStrategy modelSelectionStrategy;
    
    @Autowired
    private TokenUsageService tokenUsageService;
    
    @Autowired
    private PromptTemplateService promptTemplateService;
    
    /**
     * 同步调用 - 简单场景
     */
    public String chat(String message) {
        return chat(message, modelSwitchService.getCurrentDefault());
    }
    
    /**
     * 同步调用 - 指定模型
     */
    public String chat(String message, ModelProvider model) {
        ChatClient client = modelSwitchService.getChatClient(model);
        
        long startTime = System.currentTimeMillis();
        try {
            String response = client.prompt()
                .user(message)
                .call()
                .content();
            
            // 记录Token使用
            recordUsage(model, message, response, System.currentTimeMillis() - startTime);
            
            return response;
        } catch (Exception e) {
            log.error("Chat failed with model {}: {}", model, e.getMessage());
            throw new BusinessException(ResultCode.AI_MODEL_ERROR, e.getMessage());
        }
    }
    
    /**
     * 场景化调用 - 自动选择模型
     */
    public String chatWithScenario(String message, String scenario, String userPlan) {
        // 选择最优模型
        ModelSelectionContext context = ModelSelectionContext.builder()
            .preferredModel(modelSwitchService.getCurrentDefault())
            .userPlan(userPlan)
            .scenario(scenario)
            .prompt(message)
            .costSensitive(!"ENTERPRISE".equals(userPlan))
            .build();
        
        ModelProvider selectedModel = modelSelectionStrategy.selectModel(context);
        
        log.info("Selected model {} for scenario {}", selectedModel, scenario);
        
        // 获取场景参数
        AiModelProperties.ModelParameters params = 
            promptTemplateService.getParametersForScenario(scenario);
        
        ChatClient client = modelSwitchService.getChatClient(selectedModel);
        
        long startTime = System.currentTimeMillis();
        String response = client.prompt()
            .user(message)
            .options(ChatOptions.builder()
                .temperature(params.getTemperature())
                .maxTokens(params.getMaxTokens())
                .build())
            .call()
            .content();
        
        recordUsage(selectedModel, message, response, System.currentTimeMillis() - startTime);
        
        return response;
    }
    
    /**
     * 结构化输出 - 自动解析JSON
     */
    public <T> T chatForEntity(String message, Class<T> responseType) {
        return chatForEntity(message, responseType, 
            modelSwitchService.getCurrentDefault());
    }
    
    public <T> T chatForEntity(String message, Class<T> responseType, 
                               ModelProvider model) {
        ChatClient client = modelSwitchService.getChatClient(model);
        
        try {
            return client.prompt()
                .user(message)
                .call()
                .entity(responseType);
        } catch (Exception e) {
            log.error("Entity parsing failed: {}", e.getMessage());
            throw new BusinessException(ResultCode.AI_RESPONSE_ERROR, e.getMessage());
        }
    }
    
    /**
     * 流式调用
     */
    public Flux<String> streamChat(String message) {
        return streamChat(message, modelSwitchService.getCurrentDefault());
    }
    
    public Flux<String> streamChat(String message, ModelProvider model) {
        ChatClient client = modelSwitchService.getChatClient(model);
        
        return client.prompt()
            .user(message)
            .stream()
            .content()
            .doOnNext(chunk -> log.debug("Stream chunk: {}", chunk))
            .doOnError(e -> log.error("Stream error: ", e));
    }
    
    /**
     * 多轮对话
     */
    public String chatWithHistory(List<ChatMessage> history, String newMessage) {
        return chatWithHistory(history, newMessage, 
            modelSwitchService.getCurrentDefault());
    }
    
    public String chatWithHistory(List<ChatMessage> history, String newMessage,
                                  ModelProvider model) {
        ChatClient.ChatClientRequestSpec spec = 
            modelSwitchService.getChatClient(model).prompt();
        
        // 添加系统提示
        spec.system("你是一个专业的AI面试助手");
        
        // 添加历史消息
        for (ChatMessage msg : history) {
            switch (msg.getRole()) {
                case USER -> spec.user(msg.getContent());
                case ASSISTANT -> spec.system(msg.getContent());
            }
        }
        
        return spec.user(newMessage)
            .call()
            .content();
    }
    
    /**
     * 带重试的调用
     */
    public String chatWithRetry(String message, int maxRetries) {
        int attempts = 0;
        Exception lastException = null;
        
        while (attempts < maxRetries) {
            try {
                return chat(message);
            } catch (Exception e) {
                lastException = e;
                attempts++;
                log.warn("Chat attempt {} failed: {}", attempts, e.getMessage());
                
                if (attempts < maxRetries) {
                    // 指数退避
                    try {
                        Thread.sleep((long) Math.pow(2, attempts) * 1000);
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                    }
                }
            }
        }
        
        throw new BusinessException(ResultCode.AI_MODEL_ERROR, 
            "Failed after " + maxRetries + " attempts: " + lastException.getMessage());
    }
    
    /**
     * 记录Token使用
     */
    private void recordUsage(ModelProvider model, String prompt, String response, 
                            long latency) {
        try {
            // 估算Token数(实际应从API响应中获取)
            int promptTokens = estimateTokens(prompt);
            int completionTokens = estimateTokens(response);
            int totalTokens = promptTokens + completionTokens;
            
            TokenUsageRecord record = TokenUsageRecord.builder()
                .model(model)
                .promptTokens(promptTokens)
                .completionTokens(completionTokens)
                .totalTokens(totalTokens)
                .cost(calculateCost(model, totalTokens))
                .latencyMs(latency)
                .timestamp(LocalDateTime.now())
                .build();
            
            tokenUsageService.recordUsage(record);
        } catch (Exception e) {
            log.error("Failed to record token usage: {}", e.getMessage());
        }
    }
    
    /**
     * 估算Token数(简化版)
     */
    private int estimateTokens(String text) {
        // 粗略估算:中文约1.5字符/Token,英文约4字符/Token
        int chineseChars = (int) text.chars().filter(c -> c > 127).count();
        int englishChars = text.length() - chineseChars;
        return (int) (chineseChars / 1.5 + englishChars / 4);
    }
    
    /**
     * 计算成本
     */
    private BigDecimal calculateCost(ModelProvider model, int tokens) {
        // 价格单位:分/千Token
        BigDecimal pricePer1K = BigDecimal.valueOf(model.getPricePer1K());
        return pricePer1K.multiply(BigDecimal.valueOf(tokens))
            .divide(BigDecimal.valueOf(1000), 4, RoundingMode.HALF_UP);
    }
}

5.2 Prompt 模板服务

java 复制代码
/**
 * Prompt 模板服务
 * 管理各类场景的Prompt模板
 */
@Service
@Slf4j
public class PromptTemplateService {
    
    @Autowired
    private AiModelProperties modelProperties;
    
    private final Map<String, String> templates = new HashMap<>();
    
    @PostConstruct
    public void init() {
        // 初始化内置模板
        loadBuiltinTemplates();
        
        // 从文件加载自定义模板
        loadCustomTemplates();
    }
    
    private void loadBuiltinTemplates() {
        // 面试场景模板
        templates.put("interview-system", """
            你是一位经验丰富的技术面试官,正在面试{jobType}岗位的候选人。
            
            候选人背景:
            - 工作年限:{experience}
            - 面试轮次:{round}
            
            你的任务是:
            1. 根据候选人背景提出有针对性的技术问题
            2. 评估候选人的回答质量
            3. 给出建设性的反馈和改进建议
            
            请保持专业、友善的态度,问题难度要适中。
            """);
        
        // 简历优化模板
        templates.put("resume-optimize", """
            你是一位资深HR和职业规划师,擅长简历优化。
            
            请分析以下简历,并给出优化建议:
            
            目标岗位JD:
            {jobDescription}
            
            简历内容:
            {resumeContent}
            
            请从以下维度分析:
            1. 简历与岗位的匹配度评分(0-100)
            2. 关键词匹配情况
            3. 具体优化建议
            4. 修改后的简历内容
            """);
        
        // 代码评审模板
        templates.put("code-review", """
            你是一位资深软件工程师,请对以下代码进行评审。
            
            编程语言:{language}
            
            代码:
            ```{language}
            {code}
            ```
            
            请从以下方面给出评审意见:
            1. 代码质量和可读性
            2. 潜在的问题和Bug
            3. 性能优化建议
            4. 最佳实践建议
            """);
    }
    
    private void loadCustomTemplates() {
        // 从配置文件或数据库加载
        // templates.putAll(loadFromDatabase());
    }
    
    /**
     * 渲染模板
     */
    public String renderTemplate(String templateKey, Map<String, Object> variables) {
        String template = templates.get(templateKey);
        if (template == null) {
            throw new BusinessException("Template not found: " + templateKey);
        }
        
        // 简单变量替换
        String result = template;
        for (Map.Entry<String, Object> entry : variables.entrySet()) {
            result = result.replace("{" + entry.getKey() + "}", 
                String.valueOf(entry.getValue()));
        }
        
        return result;
    }
    
    /**
     * 获取场景参数
     */
    public AiModelProperties.ModelParameters getParametersForScenario(String scenario) {
        AiModelProperties.ModelParameters params = 
            modelProperties.getParameters().get(scenario);
        
        if (params == null) {
            // 返回默认参数
            params = new AiModelProperties.ModelParameters();
            params.setTemperature(0.7);
            params.setMaxTokens(2000);
        }
        
        return params;
    }
    
    /**
     * 获取系统Prompt
     */
    public String getSystemPrompt(String scenario) {
        return templates.getOrDefault(scenario + "-system", 
            "你是一个专业的AI助手。");
    }
}

六、Token 计费与配额管理

6.1 Token 使用记录

java 复制代码
/**
 * Token使用记录实体
 */
@Entity
@Table(name = "token_usage_records")
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class TokenUsageRecord extends BaseEntity {
    
    @Column(name = "user_id")
    private Long userId;
    
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    private ModelProvider model;
    
    @Column(name = "prompt_tokens")
    private Integer promptTokens;
    
    @Column(name = "completion_tokens")
    private Integer completionTokens;
    
    @Column(name = "total_tokens")
    private Integer totalTokens;
    
    @Column(name = "cost", precision = 10, scale = 4)
    private BigDecimal cost;  // 成本(元)
    
    @Column(name = "latency_ms")
    private Long latencyMs;
    
    @Column(name = "scenario")
    private String scenario;
    
    @Column(name = "request_id")
    private String requestId;
    
    @Column(name = "status")
    @Enumerated(EnumType.STRING)
    private UsageStatus status;
    
    public enum UsageStatus {
        SUCCESS, FAILED, CACHED
    }
}

6.2 Token 使用服务

java 复制代码
/**
 * Token使用记录服务
 */
@Service
@Slf4j
public class TokenUsageService {
    
    @Autowired
    private TokenUsageRecordRepository usageRepository;
    
    @Autowired
    private StringRedisTemplate redisTemplate;
    
    private static final String USER_QUOTA_KEY = "user:quota:";
    private static final String USER_USAGE_KEY = "user:usage:";
    
    /**
     * 记录Token使用
     */
    @Async
    public void recordUsage(TokenUsageRecord record) {
        try {
            usageRepository.save(record);
            
            // 更新Redis使用统计
            updateUsageStats(record);
            
            log.debug("Token usage recorded: {} tokens, cost: {}", 
                record.getTotalTokens(), record.getCost());
        } catch (Exception e) {
            log.error("Failed to record token usage: {}", e.getMessage());
        }
    }
    
    /**
     * 检查用户配额
     */
    public boolean checkQuota(Long userId, UserPlan plan) {
        String key = USER_USAGE_KEY + userId + ":" + getCurrentMonth();
        String usedStr = redisTemplate.opsForValue().get(key);
        int used = usedStr != null ? Integer.parseInt(usedStr) : 0;
        
        return used < plan.getQuota();
    }
    
    /**
     * 扣除配额
     */
    public void deductQuota(Long userId, int tokens) {
        String key = USER_USAGE_KEY + userId + ":" + getCurrentMonth();
        redisTemplate.opsForValue().increment(key, tokens);
        
        // 设置过期时间为月底
        LocalDateTime endOfMonth = LocalDateTime.now()
            .with(TemporalAdjusters.lastDayOfMonth())
            .with(LocalTime.MAX);
        long seconds = Duration.between(LocalDateTime.now(), endOfMonth).getSeconds();
        redisTemplate.expire(key, seconds, TimeUnit.SECONDS);
    }
    
    /**
     * 获取用户使用统计
     */
    public UsageStats getUsageStats(Long userId, LocalDate startDate, LocalDate endDate) {
        List<TokenUsageRecord> records = usageRepository
            .findByUserIdAndCreatedAtBetween(userId, startDate.atStartOfDay(), 
                endDate.atTime(LocalTime.MAX));
        
        int totalTokens = records.stream()
            .mapToInt(TokenUsageRecord::getTotalTokens)
            .sum();
        
        BigDecimal totalCost = records.stream()
            .map(TokenUsageRecord::getCost)
            .reduce(BigDecimal.ZERO, BigDecimal::add);
        
        Map<ModelProvider, Integer> modelUsage = records.stream()
            .collect(Collectors.groupingBy(
                TokenUsageRecord::getModel,
                Collectors.summingInt(TokenUsageRecord::getTotalTokens)
            ));
        
        return UsageStats.builder()
            .totalTokens(totalTokens)
            .totalCost(totalCost)
            .requestCount(records.size())
            .modelUsage(modelUsage)
            .averageLatency(records.stream()
                .mapToLong(TokenUsageRecord::getLatencyMs)
                .average()
                .orElse(0))
            .build();
    }
    
    /**
     * 获取系统整体使用统计
     */
    public SystemUsageStats getSystemStats(LocalDate date) {
        LocalDateTime start = date.atStartOfDay();
        LocalDateTime end = date.atTime(LocalTime.MAX);
        
        List<TokenUsageRecord> records = usageRepository
            .findByCreatedAtBetween(start, end);
        
        return SystemUsageStats.builder()
            .date(date)
            .totalRequests(records.size())
            .totalTokens(records.stream()
                .mapToInt(TokenUsageRecord::getTotalTokens)
                .sum())
            .totalCost(records.stream()
                .map(TokenUsageRecord::getCost)
                .reduce(BigDecimal.ZERO, BigDecimal::add))
            .modelDistribution(records.stream()
                .collect(Collectors.groupingBy(
                    TokenUsageRecord::getModel,
                    Collectors.counting()
                )))
            .build();
    }
    
    private void updateUsageStats(TokenUsageRecord record) {
        if (record.getUserId() == null) return;
        
        String key = USER_USAGE_KEY + record.getUserId() + ":" + getCurrentMonth();
        redisTemplate.opsForValue().increment(key, record.getTotalTokens());
    }
    
    private String getCurrentMonth() {
        return LocalDate.now().format(DateTimeFormatter.ofPattern("yyyy-MM"));
    }
}

6.3 配额限制切面

java 复制代码
/**
 * 配额限制切面
 * 在AI调用前检查用户配额
 */
@Aspect
@Component
@Slf4j
public class QuotaLimitAspect {
    
    @Autowired
    private TokenUsageService tokenUsageService;
    
    @Autowired
    private UserService userService;
    
    @Around("@annotation(checkQuota)")
    public Object around(ProceedingJoinPoint point, CheckQuota checkQuota) 
            throws Throwable {
        
        // 获取当前用户
        Long userId = getCurrentUserId();
        if (userId == null) {
            return point.proceed();
        }
        
        User user = userService.findById(userId);
        
        // 检查配额
        if (!tokenUsageService.checkQuota(userId, user.getPlan())) {
            log.warn("User {} quota exceeded", userId);
            throw new BusinessException(ResultCode.INSUFFICIENT_BALANCE, 
                "本月AI调用额度已用完,请升级套餐");
        }
        
        return point.proceed();
    }
    
    private Long getCurrentUserId() {
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (auth != null && auth.getPrincipal() instanceof UserDetails) {
            // 从UserDetails获取用户ID
            return ((UserDetails) auth.getPrincipal()).getId();
        }
        return null;
    }
}

七、代码示例

7.1 ModelSwitchController

java 复制代码
/**
 * 模型切换控制器
 * 提供模型管理和切换接口
 */
@RestController
@RequestMapping("/api/admin/ai")
@PreAuthorize("hasRole('ADMIN')")
@Slf4j
public class ModelSwitchController {
    
    @Autowired
    private ModelSwitchService modelSwitchService;
    
    @Autowired
    private LlmService llmService;
    
    /**
     * 获取所有可用模型
     */
    @GetMapping("/models")
    public Result<List<ModelInfo>> getAvailableModels() {
        List<ModelInfo> models = modelSwitchService.getAvailableModels().stream()
            .map(this::toModelInfo)
            .collect(Collectors.toList());
        return Result.success(models);
    }
    
    /**
     * 切换默认模型
     */
    @PostMapping("/models/default")
    public Result<Void> switchDefaultModel(@RequestBody @Valid SwitchModelRequest request) {
        ModelProvider provider = ModelProvider.valueOf(request.getModel());
        modelSwitchService.switchDefaultModel(provider);
        log.info("Default model switched to {} by admin", provider);
        return Result.success("模型切换成功");
    }
    
    /**
     * 获取当前默认模型
     */
    @GetMapping("/models/default")
    public Result<ModelInfo> getCurrentDefaultModel() {
        ModelProvider current = modelSwitchService.getCurrentDefault();
        return Result.success(toModelInfo(current));
    }
    
    /**
     * 测试模型可用性
     */
    @PostMapping("/models/{model}/test")
    public Result<ModelTestResult> testModel(@PathVariable String model) {
        ModelProvider provider = ModelProvider.valueOf(model);
        
        long startTime = System.currentTimeMillis();
        boolean available = modelSwitchService.isModelAvailable(provider);
        long latency = System.currentTimeMillis() - startTime;
        
        ModelTestResult result = ModelTestResult.builder()
            .model(model)
            .available(available)
            .latencyMs(latency)
            .build();
        
        return Result.success(result);
    }
    
    /**
     * 清除模型缓存
     */
    @PostMapping("/cache/clear")
    public Result<Void> clearCache() {
        modelSwitchService.clearCache();
        return Result.success("缓存已清除");
    }
    
    /**
     * 获取Token使用统计
     */
    @GetMapping("/usage/stats")
    public Result<UsageStatsResponse> getUsageStats(
            @RequestParam @DateTimeFormat(iso = DateTimeFormat.ISO.DATE) LocalDate startDate,
            @RequestParam @DateTimeFormat(iso = DateTimeFormat.ISO.DATE) LocalDate endDate) {
        
        // 获取系统整体统计
        SystemUsageStats stats = tokenUsageService.getSystemStats(LocalDate.now());
        
        UsageStatsResponse response = UsageStatsResponse.builder()
            .todayRequests(stats.getTotalRequests())
            .todayTokens(stats.getTotalTokens())
            .todayCost(stats.getTotalCost())
            .modelDistribution(stats.getModelDistribution())
            .build();
        
        return Result.success(response);
    }
    
    private ModelInfo toModelInfo(ModelProvider provider) {
        return ModelInfo.builder()
            .code(provider.name())
            .name(provider.getDisplayName())
            .provider(provider.getProvider())
            .modelName(provider.getModelName())
            .pricePer1K(provider.getPricePer1K())
            .isCloud(provider.isCloud())
            .isDefault(provider == modelSwitchService.getCurrentDefault())
            .build();
    }
}

7.2 面试服务中使用多模型

java 复制代码
/**
 * 面试服务 - 多模型调用示例
 */
@Service
@Slf4j
public class InterviewService {
    
    @Autowired
    private LlmService llmService;
    
    @Autowired
    private PromptTemplateService promptTemplateService;
    
    @Autowired
    private InterviewSessionRepository sessionRepository;
    
    /**
     * 开始面试 - 根据用户套餐选择模型
     */
    @CheckQuota
    public InterviewResponse startInterview(StartInterviewRequest request, User user) {
        // 根据场景选择模型
        String scenario = "interview";
        
        // 渲染Prompt
        String prompt = promptTemplateService.renderTemplate("interview-system", 
            Map.of(
                "jobType", request.getJobType(),
                "experience", request.getExperience(),
                "round", request.getRound()
            ));
        
        // 调用AI生成第一个问题
        String response = llmService.chatWithScenario(prompt, scenario, 
            user.getPlan().name());
        
        // 创建会话
        InterviewSession session = InterviewSession.builder()
            .userId(user.getId())
            .jobType(request.getJobType())
            .status(InterviewStatus.IN_PROGRESS)
            .build();
        sessionRepository.save(session);
        
        return InterviewResponse.builder()
            .sessionId(session.getId())
            .question(response)
            .build();
    }
    
    /**
     * 回答面试问题 - 流式响应
     */
    public Flux<String> answerQuestionStream(Long sessionId, String answer, User user) {
        InterviewSession session = sessionRepository.findById(sessionId)
            .orElseThrow(() -> new BusinessException(ResultCode.INTERVIEW_NOT_FOUND));
        
        // 构建带上下文的Prompt
        String prompt = buildContextualPrompt(session, answer);
        
        // 流式调用
        return llmService.streamChat(prompt);
    }
    
    /**
     * 评估面试表现 - 使用更强的模型
     */
    public InterviewEvaluation evaluateInterview(Long sessionId, User user) {
        InterviewSession session = sessionRepository.findById(sessionId)
            .orElseThrow(() -> new BusinessException(ResultCode.INTERVIEW_NOT_FOUND));
        
        // 使用GPT-4进行评估
        String prompt = buildEvaluationPrompt(session);
        
        InterviewEvaluationResult result = llmService.chatForEntity(
            prompt, 
            InterviewEvaluationResult.class,
            ModelProvider.OPENAI_GPT4
        );
        
        return InterviewEvaluation.builder()
            .sessionId(sessionId)
            .overallScore(result.getOverallScore())
            .technicalScore(result.getTechnicalScore())
            .communicationScore(result.getCommunicationScore())
            .feedback(result.getFeedback())
            .build();
    }
    
    private String buildContextualPrompt(InterviewSession session, String answer) {
        // 获取历史对话
        List<ChatMessage> history = session.getMessages();
        
        StringBuilder prompt = new StringBuilder();
        prompt.append("面试岗位:").append(session.getJobType()).append("\n\n");
        prompt.append("历史对话:\n");
        
        for (ChatMessage msg : history) {
            prompt.append(msg.getRole() == Role.USER ? "候选人:" : "面试官:")
                  .append(msg.getContent())
                  .append("\n");
        }
        
        prompt.append("候选人最新回答:").append(answer).append("\n\n");
        prompt.append("请作为面试官给出回应,可以是追问或评价。");
        
        return prompt.toString();
    }
    
    private String buildEvaluationPrompt(InterviewSession session) {
        return """
            请对以下技术面试进行全面评估:
            
            岗位类型:%s
            
            面试对话记录:
            %s
            
            请从以下维度给出评分(0-100)和详细反馈:
            1. 整体表现
            2. 技术能力
            3. 沟通表达
            4. 改进建议
            
            请以JSON格式返回结果。
            """.formatted(session.getJobType(), formatConversation(session.getMessages()));
    }
}

八、总结

8.1 本文要点

本文详细介绍了Spring AI多模型架构的完整实现方案:

  1. Spring AI核心概念:ChatClient、Prompt模板、流式响应
  2. 多模型配置管理:ModelProvider枚举、YAML配置、属性类
  3. 动态模型切换:ModelSwitchService实现运行时热切换
  4. 统一LLM调用:LlmService封装所有模型调用细节
  5. Token计费管理:使用记录、配额检查、成本统计
  6. 代码示例:ModelSwitchController、InterviewService实战

8.2 架构优势

  • 灵活性:运行时切换模型,无需重启服务
  • 成本优化:根据场景和用户套餐智能选择模型
  • 可扩展性:新增模型只需添加枚举和配置
  • 可观测性:完整的Token使用统计和成本分析

8.3 后续优化方向

  • 模型缓存:响应结果缓存,减少重复调用
  • A/B测试:不同模型效果对比
  • 智能路由:基于历史数据自动选择最优模型
  • 模型微调:针对特定场景微调本地模型

参考资料

  1. Spring AI Documentation
  2. OpenAI API Reference
  3. DeepSeek API Documentation
  4. MiniMax API Documentation
  5. Ollama Documentation

本系列文章到此结束,感谢阅读!如果对你有帮助,欢迎点赞、收藏、评论交流!

相关推荐
Mr数据杨14 小时前
【CanMV K210】通信扩展 PCF8591 ADC 模数转换与模拟量读取
人工智能·硬件开发·canmv k210
止水编程 water_proof14 小时前
Spring Web MVC 入门
前端·spring·mvc
DogDaoDao14 小时前
【GitHub】RealtimeSTT 深度解析:打造低延迟、生产级语音识别应用的全栈利器
人工智能·语言模型·大模型·github·语音识别·stt·realtimestt
菜鸡旭旭14 小时前
【AI培训中台-练习评分V0】
人工智能
chenying99817914 小时前
本地部署 TTS 方案横向对比:Fish Speech、CosyVoice 2、GPT-SoVITS 与 VoxFlash-TTS
人工智能·实时音视频·语音合成·tts
颖火虫盟主15 小时前
Lua 协程:从 API 到底层原理再到 Skynet 架构的完整学习路径
学习·架构·lua
A153625515 小时前
流量暗战:2026年科技公司GEO应用成熟度调查
人工智能·科技·chatgpt
私人珍藏库15 小时前
[Android] 全能语音计算器v4.6
人工智能·windows·语音识别·工具·软件·多功能
梦想三三15 小时前
【OpenCV】图像的轮廓检测
人工智能·opencv·计算机视觉