Spring AI Design Patterns: Combined Application and a Complete Project Implementation

🎯 I. Project Overview

This project simulates the core principles of Spring AI, combining several design patterns to demonstrate a complete AI service integration framework. We will implement the following core features:

  1. Unified AI client interface - supports multiple AI models (OpenAI, Azure, local models)
  2. Intelligent routing - automatically selects an adapter based on the model type
  3. Extensible plugin architecture - supports custom model extensions
  4. Complete configuration management - Spring Boot auto-configuration
  5. Monitoring and metrics collection - Micrometer integration
  6. A complete sample application - demonstrating usage

📁 II. Complete Project Structure

spring-ai-simulation/
├── pom.xml                              # Maven dependency configuration
├── src/
│   ├── main/
│   │   ├── java/
│   │   │   └── com/
│   │   │       └── example/
│   │   │           └── ai/
│   │   │               ├── SpringAiSimulationApplication.java
│   │   │               ├── annotation/
│   │   │               │   └── EnableAiClient.java
│   │   │               ├── config/
│   │   │               │   ├── AiAutoConfiguration.java
│   │   │               │   ├── AiProperties.java
│   │   │               │   ├── AiClientConfig.java
│   │   │               │   └── condition/
│   │   │               │       ├── OnAiModelCondition.java
│   │   │               │       └── OnAiProviderCondition.java
│   │   │               ├── core/
│   │   │               │   ├── AiClient.java
│   │   │               │   ├── ChatClient.java
│   │   │               │   ├── EmbeddingClient.java
│   │   │               │   ├── model/
│   │   │               │   │   ├── ChatRequest.java
│   │   │               │   │   ├── ChatResponse.java
│   │   │               │   │   ├── Message.java
│   │   │               │   │   ├── Choice.java
│   │   │               │   │   ├── Usage.java
│   │   │               │   │   ├── EmbeddingRequest.java
│   │   │               │   │   └── EmbeddingResponse.java
│   │   │               │   ├── template/
│   │   │               │   │   ├── AiClientTemplate.java
│   │   │               │   │   ├── BaseAiClient.java
│   │   │               │   │   ├── AbstractAiClient.java
│   │   │               │   │   └── ChatClientTemplate.java
│   │   │               │   └── exception/
│   │   │               │       ├── AiClientException.java
│   │   │               │       ├── AiServiceException.java
│   │   │               │       ├── RateLimitException.java
│   │   │               │       └── UnauthorizedException.java
│   │   │               ├── pattern/
│   │   │               │   ├── factory/
│   │   │               │   │   ├── AiClientFactory.java
│   │   │               │   │   └── AiAdapterFactory.java
│   │   │               │   ├── strategy/
│   │   │               │   │   ├── AiStrategy.java
│   │   │               │   │   ├── AiProviderStrategy.java
│   │   │               │   │   ├── OpenAiStrategy.java
│   │   │               │   │   ├── AzureAiStrategy.java
│   │   │               │   │   └── LocalAiStrategy.java
│   │   │               │   ├── adapter/
│   │   │               │   │   ├── AiModelAdapter.java
│   │   │               │   │   ├── OpenAiAdapter.java
│   │   │               │   │   ├── AzureAiAdapter.java
│   │   │               │   │   └── LocalAiAdapter.java
│   │   │               │   ├── builder/
│   │   │               │   │   └── ChatRequestBuilder.java
│   │   │               │   ├── chain/
│   │   │               │   │   ├── AiClientInterceptor.java
│   │   │               │   │   └── AiClientInterceptorChain.java
│   │   │               │   └── proxy/
│   │   │               │       └── AiClientProxy.java
│   │   │               ├── provider/
│   │   │               │   ├── OpenAiClient.java
│   │   │               │   ├── AzureAiClient.java
│   │   │               │   └── LocalAiClient.java
│   │   │               ├── service/
│   │   │               │   ├── AiService.java
│   │   │               │   ├── ChatService.java
│   │   │               │   └── EmbeddingService.java
│   │   │               ├── interceptor/
│   │   │               │   ├── LoggingInterceptor.java
│   │   │               │   ├── MetricsInterceptor.java
│   │   │               │   └── RetryInterceptor.java
│   │   │               ├── metrics/
│   │   │               │   └── AiClientMetrics.java
│   │   │               └── router/
│   │   │                   └── ModelRouter.java
│   │   └── resources/
│   │       ├── application.yml
│   │       └── application-dev.yml
│   └── test/
│       └── java/
│           └── com/
│               └── example/
│                   └── ai/
│                       ├── SpringAiSimulationApplicationTests.java
│                       └── service/
│                           └── AiServiceTest.java

📦 III. Maven Dependency Configuration

xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
         http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.example</groupId>
    <artifactId>spring-ai-simulation</artifactId>
    <version>1.0.0</version>
    <name>Spring AI Simulation</name>
    <description>A simulation framework demonstrating Spring AI design patterns</description>

    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.1.5</version>
        <relativePath/>
    </parent>

    <properties>
        <java.version>17</java.version>
        <jackson.version>2.15.2</jackson.version>
        <micrometer.version>1.11.5</micrometer.version>
    </properties>

    <dependencies>
        <!-- Spring Boot Starters -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-actuator</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        
        <!-- Monitoring -->
        <dependency>
            <groupId>io.micrometer</groupId>
            <artifactId>micrometer-core</artifactId>
        </dependency>
        <dependency>
            <groupId>io.micrometer</groupId>
            <artifactId>micrometer-registry-prometheus</artifactId>
        </dependency>
        
        <!-- HTTP Client -->
        <dependency>
            <groupId>org.apache.httpcomponents.client5</groupId>
            <artifactId>httpclient5</artifactId>
        </dependency>
        
        <!-- JSON Processing -->
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        
        <!-- Utilities -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>32.1.3-jre</version>
        </dependency>
        
        <!-- Testing -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.mockito</groupId>
            <artifactId>mockito-core</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

⚙️ IV. Core Configuration Files

1. application.yml

yaml
# Application configuration
server:
  port: 8080
  servlet:
    context-path: /api

spring:
  application:
    name: spring-ai-simulation
  
  # AI configuration
  ai:
    enabled: true
    default-provider: openai
    default-model: gpt-3.5-turbo
    enable-metrics: true
    
    # OpenAI configuration
    openai:
      enabled: true
      api-key: ${OPENAI_API_KEY:your-openai-key}
      base-url: https://api.openai.com/v1
      timeout: 30s
      max-retries: 3
    
    # Azure AI configuration
    azure:
      enabled: false
      endpoint: https://your-resource.openai.azure.com
      api-key: ${AZURE_API_KEY:your-azure-key}
      deployment-name: gpt-35-turbo
      api-version: 2023-12-01-preview
    
    # Local model configuration
    local:
      enabled: false
      model-path: /path/to/local/model
      device: cpu
    
    # HTTP client configuration
    http-client:
      max-connections: 50
      connection-timeout: 10s
      read-timeout: 30s
      keep-alive: 5m
    
    # Monitoring configuration
    monitoring:
      enabled: true
      metrics-prefix: ai.client
      slow-query-threshold: 1000ms
      enable-tracing: true

# Logging configuration
logging:
  level:
    com.example.ai: DEBUG
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"

# Actuator configuration
management:
  endpoints:
    web:
      exposure:
        include: health,info,metrics,prometheus
  prometheus:
    metrics:
      export:
        enabled: true
  endpoint:
    health:
      show-details: always
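The `spring.ai.*` keys above are bound by the project's `AiProperties` class, which is listed in the structure but not shown. Below is a minimal, hypothetical sketch of how that binding class might look, assuming Lombok and Spring Boot's `@ConfigurationProperties`; field names follow the yaml keys via relaxed binding.

```java
package com.example.ai.config;

import java.time.Duration;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;

// Hypothetical sketch of the binding class for the spring.ai.* keys above
@Data
@ConfigurationProperties(prefix = "spring.ai")
public class AiProperties {

    private boolean enabled = true;
    private String defaultProvider = "openai";
    private String defaultModel = "gpt-3.5-turbo";
    private boolean enableMetrics = true;

    private final Openai openai = new Openai();
    private final Azure azure = new Azure();
    private final Local local = new Local();

    @Data
    public static class Openai {
        private boolean enabled;
        private String apiKey;
        private String baseUrl;
        private Duration timeout = Duration.ofSeconds(30);
        private int maxRetries = 3;
    }

    @Data
    public static class Azure {
        private boolean enabled;
        private String endpoint;
        private String apiKey;
        private String deploymentName;
        private String apiVersion;
    }

    @Data
    public static class Local {
        private boolean enabled;
        private String modelPath;
        private String device = "cpu";
    }
}
```

The nested-final-field style lets Spring bind `spring.ai.openai.*` and friends onto the pre-created nested objects via their setters.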

2. Application Startup Class

java
package com.example.ai;

import com.example.ai.annotation.EnableAiClient;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
@EnableAiClient
public class SpringAiSimulationApplication {
    public static void main(String[] args) {
        SpringApplication.run(SpringAiSimulationApplication.class, args);
    }
}

🔧 V. Core Design Pattern Implementations

1. Strategy Pattern - AI Provider Strategy

java
// Strategy interface
package com.example.ai.pattern.strategy;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;

/**
 * AI provider strategy interface
 */
public interface AiProviderStrategy {
    String getProviderName();
    
    ChatResponse chat(ChatRequest request);
    
    EmbeddingResponse embed(EmbeddingRequest request);
    
    boolean supports(String model);
    
    void validateApiKey(String apiKey);
}
java
// OpenAI strategy implementation
package com.example.ai.pattern.strategy;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.client.RestTemplate;

@Slf4j
public class OpenAiStrategy implements AiProviderStrategy {
    
    private final RestTemplate restTemplate;
    private final ObjectMapper objectMapper;
    private final String apiKey;
    private final String baseUrl;
    
    public OpenAiStrategy(RestTemplate restTemplate, 
                         ObjectMapper objectMapper,
                         String apiKey, 
                         String baseUrl) {
        this.restTemplate = restTemplate;
        this.objectMapper = objectMapper;
        this.apiKey = apiKey;
        this.baseUrl = baseUrl;
    }
    
    @Override
    public String getProviderName() {
        return "openai";
    }
    
    @Override
    public ChatResponse chat(ChatRequest request) {
        log.debug("OpenAI Strategy: Processing chat request for model: {}", request.getModel());
        
        // Build the request in the OpenAI-specific format
        var openAiRequest = buildOpenAiChatRequest(request);
        
        // Send the request
        var headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.setBearerAuth(apiKey);
        
        var entity = new HttpEntity<>(openAiRequest, headers);
        var url = String.format("%s/chat/completions", baseUrl);
        
        try {
            var response = restTemplate.postForObject(url, entity, String.class);
            return parseOpenAiChatResponse(response);
        } catch (Exception e) {
            throw new RuntimeException("OpenAI API call failed", e);
        }
    }
    
    @Override
    public EmbeddingResponse embed(EmbeddingRequest request) {
        // Embedding implementation (placeholder in this sample)
        return null;
    }
    
    @Override
    public boolean supports(String model) {
        return model != null && 
               (model.startsWith("gpt-") || 
                model.startsWith("text-embedding-"));
    }
    
    @Override
    public void validateApiKey(String apiKey) {
        if (apiKey == null || apiKey.trim().isEmpty()) {
            throw new IllegalArgumentException("OpenAI API key is required");
        }
        if (!apiKey.startsWith("sk-")) {
            log.warn("OpenAI API key may be invalid, should start with 'sk-'");
        }
    }
    
    private Object buildOpenAiChatRequest(ChatRequest request) {
        // Conversion logic (placeholder in this sample)
        return null;
    }
    
    private ChatResponse parseOpenAiChatResponse(String response) {
        // Parsing logic (placeholder in this sample)
        return null;
    }
}
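To make the selection mechanism concrete, here is a self-contained miniature of the strategy lookup, using simplified stand-in types rather than the project's real interfaces: each strategy declares which models it `supports`, and the caller picks the first match.

```java
import java.util.List;

public class StrategyDemo {
    // Simplified stand-in for AiProviderStrategy
    interface ProviderStrategy {
        String name();
        boolean supports(String model);
    }

    public static void main(String[] args) {
        ProviderStrategy openai = new ProviderStrategy() {
            public String name() { return "openai"; }
            public boolean supports(String m) {
                return m != null && (m.startsWith("gpt-") || m.startsWith("text-embedding-"));
            }
        };
        ProviderStrategy local = new ProviderStrategy() {
            public String name() { return "local"; }
            public boolean supports(String m) { return m != null && m.contains("llama"); }
        };

        List<ProviderStrategy> strategies = List.of(openai, local);
        for (String model : new String[]{"gpt-4", "llama-2-7b", "mystery"}) {
            // First strategy that supports the model wins; otherwise fall back
            String provider = strategies.stream()
                .filter(s -> s.supports(model))
                .map(ProviderStrategy::name)
                .findFirst()
                .orElse("default");
            System.out.println(model + " -> " + provider);
        }
    }
}
```

The same first-match rule is what `supports(String model)` enables in the real `AiProviderStrategy` interface above.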

2. Template Method Pattern - AI Client Template

java
// Abstract template class
package com.example.ai.core.template;

import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;
import com.example.ai.pattern.chain.AiClientInterceptorChain;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;

/**
 * Abstract AI client template.
 * Defines the standard processing flow for AI requests.
 */
@Slf4j
public abstract class AbstractAiClient implements AiClient {
    
    protected final AiClientInterceptorChain interceptorChain;
    protected final Executor executor;
    
    protected AbstractAiClient(AiClientInterceptorChain interceptorChain, 
                              Executor executor) {
        this.interceptorChain = interceptorChain;
        this.executor = executor;
    }
    
    @Override
    public ChatResponse chat(ChatRequest request) {
        // Template method: defines the standard processing flow
        long startTime = System.currentTimeMillis();
        
        try {
            // 1. Pre-processing
            preProcess(request);
            
            // 2. Apply the interceptor chain's pre-handling; abort if any interceptor vetoes
            if (!interceptorChain.applyPreHandle(request)) {
                throw new IllegalStateException("Request rejected by interceptor chain");
            }
            
            // 3. Execute the actual business logic (implemented by subclasses)
            ChatResponse response = doChat(request);
            
            // 4. Apply the interceptor chain's post-handling
            interceptorChain.applyPostHandle(request, response);
            
            // 5. Post-processing
            postProcess(request, response, startTime);
            
            return response;
            
        } catch (Exception e) {
            // 6. Exception handling; also notify interceptors via afterCompletion
            interceptorChain.applyAfterCompletion(request, e);
            handleException(request, e, startTime);
            throw e;
        }
    }
    
    @Override
    public CompletableFuture<ChatResponse> chatAsync(ChatRequest request) {
        return CompletableFuture.supplyAsync(() -> chat(request), executor);
    }
    
    /**
     * Pre-processing (hook method)
     */
    protected void preProcess(ChatRequest request) {
        log.debug("Pre-processing chat request: {}", request);
        validateRequest(request);
    }
    
    /**
     * Actual chat logic (implemented by subclasses)
     */
    protected abstract ChatResponse doChat(ChatRequest request);
    
    /**
     * Post-processing (hook method)
     */
    protected void postProcess(ChatRequest request, ChatResponse response, long startTime) {
        long duration = System.currentTimeMillis() - startTime;
        log.debug("Chat completed in {} ms. Tokens used: {}", 
                 duration, response.getUsage().getTotalTokens());
        
        // Publish a completion event
        publishChatCompletedEvent(request, response, duration);
    }
    
    /**
     * Exception handling
     */
    protected void handleException(ChatRequest request, Exception e, long startTime) {
        long duration = System.currentTimeMillis() - startTime;
        log.error("Chat request failed after {} ms: {}", duration, e.getMessage(), e);
        
        // Publish a failure event
        publishChatFailedEvent(request, e, duration);
    }
    
    /**
     * Validate the request
     */
    protected void validateRequest(ChatRequest request) {
        if (request == null) {
            throw new IllegalArgumentException("ChatRequest cannot be null");
        }
        if (request.getMessages() == null || request.getMessages().isEmpty()) {
            throw new IllegalArgumentException("Messages cannot be empty");
        }
    }
    
    // Other hook methods...
    protected abstract void publishChatCompletedEvent(ChatRequest request, 
                                                     ChatResponse response, 
                                                     long duration);
    
    protected abstract void publishChatFailedEvent(ChatRequest request, 
                                                  Exception e, 
                                                  long duration);
}
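The fixed-flow idea above is easiest to see in a self-contained miniature. This sketch uses simplified stand-in types; the real class additionally wires in interceptors, async execution, and events.

```java
public class TemplateDemo {
    // Simplified stand-in for AbstractAiClient
    static abstract class AbstractClient {
        // Template method: the flow is fixed, only doChat varies per subclass
        final String chat(String prompt) {
            validate(prompt);                       // 1. pre-processing hook
            String response = doChat(prompt);       // 2. subclass-specific core step
            System.out.println("post-processing");  // 3. post-processing hook
            return response;
        }

        void validate(String prompt) {
            if (prompt == null || prompt.isBlank()) {
                throw new IllegalArgumentException("prompt required");
            }
        }

        abstract String doChat(String prompt);
    }

    // Concrete subclass supplies only the core step
    static class EchoClient extends AbstractClient {
        String doChat(String prompt) { return "echo: " + prompt; }
    }

    public static void main(String[] args) {
        System.out.println(new EchoClient().chat("hi"));
    }
}
```

Marking the template method `final`, as here, prevents subclasses from bypassing the fixed flow.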

3. Factory Pattern - AI Client Factory

java
// AI client factory
package com.example.ai.pattern.factory;

import com.example.ai.core.AiClient;
import com.example.ai.config.AiProperties;
import com.example.ai.pattern.strategy.AiProviderStrategy;
import com.example.ai.pattern.strategy.OpenAiStrategy;
import com.example.ai.pattern.strategy.AzureAiStrategy;
import com.example.ai.pattern.strategy.LocalAiStrategy;
import com.example.ai.provider.OpenAiClient;
import com.example.ai.provider.AzureAiClient;
import com.example.ai.provider.LocalAiClient;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.client.RestTemplate;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * AI client factory.
 * Creates the appropriate type of AI client based on configuration.
 */
@Slf4j
public class AiClientFactory {
    
    // ConcurrentHashMap keeps computeIfAbsent safe under concurrent access
    private final Map<String, AiClient> clientCache = new ConcurrentHashMap<>();
    private final AiProperties aiProperties;
    private final RestTemplate restTemplate;
    private final ObjectMapper objectMapper;
    
    public AiClientFactory(AiProperties aiProperties, 
                          RestTemplate restTemplate, 
                          ObjectMapper objectMapper) {
        this.aiProperties = aiProperties;
        this.restTemplate = restTemplate;
        this.objectMapper = objectMapper;
    }
    
    /**
     * Create an AI client for the given provider
     */
    public AiClient createAiClient(String provider) {
        return clientCache.computeIfAbsent(provider, this::createNewAiClient);
    }
    
    /**
     * Automatically select a client for the given model
     */
    public AiClient createAiClientForModel(String model) {
        String provider = determineProviderByModel(model);
        return createAiClient(provider);
    }
    
    private AiClient createNewAiClient(String provider) {
        log.info("Creating new AI client for provider: {}", provider);
        
        switch (provider.toLowerCase()) {
            case "openai":
                return createOpenAiClient();
            case "azure":
                return createAzureAiClient();
            case "local":
                return createLocalAiClient();
            default:
                throw new IllegalArgumentException("Unsupported AI provider: " + provider);
        }
    }
    
    private AiClient createOpenAiClient() {
        var properties = aiProperties.getOpenai();
        var strategy = new OpenAiStrategy(
            restTemplate,
            objectMapper,
            properties.getApiKey(),
            properties.getBaseUrl()
        );
        
        return new OpenAiClient(strategy, aiProperties);
    }
    
    private AiClient createAzureAiClient() {
        var properties = aiProperties.getAzure();
        var strategy = new AzureAiStrategy(
            restTemplate,
            objectMapper,
            properties.getApiKey(),
            properties.getEndpoint(),
            properties.getDeploymentName(),
            properties.getApiVersion()
        );
        
        return new AzureAiClient(strategy, aiProperties);
    }
    
    private AiClient createLocalAiClient() {
        var properties = aiProperties.getLocal();
        var strategy = new LocalAiStrategy(
            properties.getModelPath(),
            properties.getDevice()
        );
        
        return new LocalAiClient(strategy, aiProperties);
    }
    
    private String determineProviderByModel(String model) {
        if (model == null) {
            return aiProperties.getDefaultProvider();
        }
        
        if (model.startsWith("gpt-") || model.startsWith("text-")) {
            return "openai";
        } else if (model.contains("azure")) {
            return "azure";
        } else if (model.contains("local") || model.contains("llama")) {
            return "local";
        }
        
        return aiProperties.getDefaultProvider();
    }
    
    /**
     * Get all supported provider strategies
     */
    public Map<String, AiProviderStrategy> getSupportedStrategies() {
        Map<String, AiProviderStrategy> strategies = new HashMap<>();
        
        if (aiProperties.getOpenai().isEnabled()) {
            strategies.put("openai", new OpenAiStrategy(
                restTemplate, objectMapper,
                aiProperties.getOpenai().getApiKey(),
                aiProperties.getOpenai().getBaseUrl()
            ));
        }
        
        if (aiProperties.getAzure().isEnabled()) {
            strategies.put("azure", new AzureAiStrategy(
                restTemplate, objectMapper,
                aiProperties.getAzure().getApiKey(),
                aiProperties.getAzure().getEndpoint(),
                aiProperties.getAzure().getDeploymentName(),
                aiProperties.getAzure().getApiVersion()
            ));
        }
        
        if (aiProperties.getLocal().isEnabled()) {
            strategies.put("local", new LocalAiStrategy(
                aiProperties.getLocal().getModelPath(),
                aiProperties.getLocal().getDevice()
            ));
        }
        
        return strategies;
    }
}
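The caching-plus-dispatch shape of the factory can be reduced to a self-contained miniature (stand-in types, not the project's real classes): `computeIfAbsent` guarantees one cached client per provider, and unknown providers fail fast.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class FactoryDemo {
    // Simplified stand-ins for AiClient and its implementations
    interface Client { String name(); }
    record SimpleClient(String name) implements Client {}

    static class ClientFactory {
        private final Map<String, Client> cache = new ConcurrentHashMap<>();

        // computeIfAbsent yields exactly one cached instance per provider
        Client create(String provider) {
            return cache.computeIfAbsent(provider.toLowerCase(), p -> {
                switch (p) {
                    case "openai": return new SimpleClient("openai-client");
                    case "local":  return new SimpleClient("local-client");
                    default: throw new IllegalArgumentException("Unsupported provider: " + p);
                }
            });
        }
    }

    public static void main(String[] args) {
        ClientFactory factory = new ClientFactory();
        Client a = factory.create("openai");
        Client b = factory.create("openai");
        // The second call returns the cached instance, not a new one
        System.out.println(a.name() + " cached=" + (a == b));
    }
}
```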

4. Adapter Pattern - Model Adapters

java
// Adapter interface
package com.example.ai.pattern.adapter;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;

/**
 * AI model adapter interface.
 * Converts the unified request into a model-specific format.
 */
public interface AiModelAdapter {
    
    String getModelType();
    
    boolean supports(String model);
    
    Object adaptRequest(ChatRequest request);
    
    ChatResponse adaptResponse(Object rawResponse, ChatRequest originalRequest);
}
java
// OpenAI adapter implementation
package com.example.ai.pattern.adapter;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Message;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Slf4j
@RequiredArgsConstructor
public class OpenAiAdapter implements AiModelAdapter {
    
    private final ObjectMapper objectMapper;
    
    @Override
    public String getModelType() {
        return "openai";
    }
    
    @Override
    public boolean supports(String model) {
        return model != null && 
               (model.startsWith("gpt-") || 
                model.startsWith("text-"));
    }
    
    @Override
    public Object adaptRequest(ChatRequest request) {
        Map<String, Object> openAiRequest = new HashMap<>();
        
        // Set the model name
        openAiRequest.put("model", request.getModel());
        
        // Convert the messages
        List<Map<String, String>> messages = request.getMessages().stream()
            .map(this::convertMessage)
            .collect(Collectors.toList());
        openAiRequest.put("messages", messages);
        
        // Convert optional parameters
        if (request.getTemperature() != null) {
            openAiRequest.put("temperature", request.getTemperature());
        }
        if (request.getMaxTokens() != null) {
            openAiRequest.put("max_tokens", request.getMaxTokens());
        }
        if (request.getTopP() != null) {
            openAiRequest.put("top_p", request.getTopP());
        }
        
        // Streaming response flag
        if (request.isStream()) {
            openAiRequest.put("stream", true);
        }
        
        log.debug("Adapted OpenAI request: {}", openAiRequest);
        return openAiRequest;
    }
    
    @Override
    public ChatResponse adaptResponse(Object rawResponse, ChatRequest originalRequest) {
        try {
            String responseJson = objectMapper.writeValueAsString(rawResponse);
            Map<String, Object> responseMap = objectMapper.readValue(responseJson, Map.class);
            
            return ChatResponse.builder()
                .id((String) responseMap.get("id"))
                .model((String) responseMap.get("model"))
                .created((Integer) responseMap.get("created"))
                .choices(extractChoices(responseMap))
                .usage(extractUsage(responseMap))
                .build();
                
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Failed to adapt OpenAI response", e);
        }
    }
    
    private Map<String, String> convertMessage(Message message) {
        Map<String, String> result = new HashMap<>();
        result.put("role", message.getRole().name().toLowerCase());
        result.put("content", message.getContent());
        return result;
    }
    
    private List<ChatResponse.Choice> extractChoices(Map<String, Object> responseMap) {
        // Extract choices (placeholder in this sample)
        return null;
    }
    
    private ChatResponse.Usage extractUsage(Map<String, Object> responseMap) {
        // Extract usage (placeholder in this sample)
        return null;
    }
}
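The request-adaptation step can be tried standalone. Below is a minimal sketch of the unified-to-OpenAI-wire-format conversion, using plain maps and a stand-in message type (no Jackson); note how optional parameters are only added when present, mirroring `adaptRequest` above.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class AdapterDemo {
    // Simplified stand-in for the project's Message type
    record Msg(String role, String content) {}

    // Convert a unified request into the OpenAI-style wire format
    static Map<String, Object> adaptRequest(String model, List<Msg> messages, Double temperature) {
        Map<String, Object> out = new HashMap<>();
        out.put("model", model);

        List<Map<String, String>> wireMessages = new ArrayList<>();
        for (Msg m : messages) {
            Map<String, String> wm = new HashMap<>();
            wm.put("role", m.role());
            wm.put("content", m.content());
            wireMessages.add(wm);
        }
        out.put("messages", wireMessages);

        // Optional parameters are only added when present
        if (temperature != null) {
            out.put("temperature", temperature);
        }
        return out;
    }

    public static void main(String[] args) {
        Map<String, Object> req = adaptRequest("gpt-4", List.of(new Msg("user", "hi")), 0.2);
        System.out.println(req.get("model") + " keys=" + req.size());
    }
}
```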

5. Chain of Responsibility - Interceptor Chain

java
// Interceptor interface
package com.example.ai.pattern.chain;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;

/**
 * AI client interceptor
 */
public interface AiClientInterceptor {
    
    /**
     * Pre-handling
     */
    default boolean preHandle(ChatRequest request) {
        return true;
    }
    
    /**
     * Post-handling
     */
    default void postHandle(ChatRequest request, ChatResponse response) {
    }
    
    /**
     * Exception handling
     */
    default void afterCompletion(ChatRequest request, Exception ex) {
    }
}
java
// Interceptor chain
package com.example.ai.pattern.chain;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.List;

@Slf4j
@RequiredArgsConstructor
public class AiClientInterceptorChain {
    
    private final List<AiClientInterceptor> interceptors;
    
    /**
     * Apply the pre-handling of all interceptors
     */
    public boolean applyPreHandle(ChatRequest request) {
        for (AiClientInterceptor interceptor : interceptors) {
            if (!interceptor.preHandle(request)) {
                log.debug("Interceptor {} prevented request processing", 
                         interceptor.getClass().getSimpleName());
                return false;
            }
        }
        return true;
    }
    
    /**
     * Apply the post-handling of all interceptors
     */
    public void applyPostHandle(ChatRequest request, ChatResponse response) {
        for (AiClientInterceptor interceptor : interceptors) {
            try {
                interceptor.postHandle(request, response);
            } catch (Exception e) {
                log.error("Interceptor postHandle failed", e);
            }
        }
    }
    
    /**
     * Apply the completion handling of all interceptors
     */
    public void applyAfterCompletion(ChatRequest request, Exception ex) {
        for (AiClientInterceptor interceptor : interceptors) {
            try {
                interceptor.afterCompletion(request, ex);
            } catch (Exception e) {
                log.error("Interceptor afterCompletion failed", e);
            }
        }
    }
}
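The veto-and-continue semantics of the chain are easiest to see in a self-contained miniature, using plain `String` stand-ins for the request/response types: any interceptor returning false from `preHandle` stops the chain, while `postHandle` runs for every interceptor.

```java
import java.util.List;

public class ChainDemo {
    interface Interceptor {
        default boolean preHandle(String request) { return true; }
        default void postHandle(String request, String response) {}
    }

    // Vetoes blank requests
    static class ValidationInterceptor implements Interceptor {
        public boolean preHandle(String request) { return !request.isBlank(); }
    }

    // Logs completed requests
    static class LoggingInterceptor implements Interceptor {
        public void postHandle(String request, String response) {
            System.out.println("completed: " + request + " -> " + response);
        }
    }

    record Chain(List<Interceptor> interceptors) {
        boolean applyPreHandle(String request) {
            for (Interceptor i : interceptors) {
                if (!i.preHandle(request)) return false;  // first veto stops the chain
            }
            return true;
        }
        void applyPostHandle(String request, String response) {
            for (Interceptor i : interceptors) i.postHandle(request, response);
        }
    }

    public static void main(String[] args) {
        Chain chain = new Chain(List.of(new ValidationInterceptor(), new LoggingInterceptor()));
        if (chain.applyPreHandle("hello")) chain.applyPostHandle("hello", "world");
        System.out.println("blank allowed: " + chain.applyPreHandle(" "));
    }
}
```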

6. Builder Pattern - Request Builder

java
// Chat request builder
package com.example.ai.pattern.builder;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.Message;

import java.util.ArrayList;
import java.util.List;

/**
 * Chat request builder.
 * Uses the builder pattern to create complex request objects.
 */
public class ChatRequestBuilder {
    
    private String model;
    private final List<Message> messages = new ArrayList<>();
    private Double temperature;
    private Integer maxTokens;
    private Double topP;
    private boolean stream = false;
    
    public ChatRequestBuilder model(String model) {
        this.model = model;
        return this;
    }
    
    public ChatRequestBuilder message(Message message) {
        this.messages.add(message);
        return this;
    }
    
    public ChatRequestBuilder message(String role, String content) {
        return message(Message.builder()
            .role(Message.Role.valueOf(role.toUpperCase()))
            .content(content)
            .build());
    }
    
    public ChatRequestBuilder systemMessage(String content) {
        return message("system", content);
    }
    
    public ChatRequestBuilder userMessage(String content) {
        return message("user", content);
    }
    
    public ChatRequestBuilder assistantMessage(String content) {
        return message("assistant", content);
    }
    
    public ChatRequestBuilder temperature(Double temperature) {
        this.temperature = temperature;
        return this;
    }
    
    public ChatRequestBuilder maxTokens(Integer maxTokens) {
        this.maxTokens = maxTokens;
        return this;
    }
    
    public ChatRequestBuilder topP(Double topP) {
        this.topP = topP;
        return this;
    }
    
    public ChatRequestBuilder stream(boolean stream) {
        this.stream = stream;
        return this;
    }
    
    public ChatRequest build() {
        if (model == null || model.trim().isEmpty()) {
            throw new IllegalArgumentException("Model is required");
        }
        if (messages.isEmpty()) {
            throw new IllegalArgumentException("At least one message is required");
        }
        
        return ChatRequest.builder()
            .model(model)
            .messages(new ArrayList<>(messages))
            .temperature(temperature)
            .maxTokens(maxTokens)
            .topP(topP)
            .stream(stream)
            .build();
    }
}
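Here is how the fluent API reads in practice. This is a minimal, self-contained sketch with simplified stand-in types rather than the project's real `ChatRequest`/`Message` classes, but the method chaining and build-time validation are the same.

```java
import java.util.ArrayList;
import java.util.List;

public class BuilderDemo {
    // Simplified stand-ins for the project's Message/ChatRequest types
    record Msg(String role, String content) {}
    record Req(String model, List<Msg> messages, Double temperature) {}

    static class ReqBuilder {
        private String model;
        private final List<Msg> messages = new ArrayList<>();
        private Double temperature;

        ReqBuilder model(String m) { this.model = m; return this; }
        ReqBuilder systemMessage(String c) { messages.add(new Msg("system", c)); return this; }
        ReqBuilder userMessage(String c) { messages.add(new Msg("user", c)); return this; }
        ReqBuilder temperature(Double t) { this.temperature = t; return this; }

        Req build() {
            // Validation happens once, at build time
            if (model == null || model.isBlank())
                throw new IllegalArgumentException("Model is required");
            if (messages.isEmpty())
                throw new IllegalArgumentException("At least one message is required");
            return new Req(model, List.copyOf(messages), temperature);
        }
    }

    public static void main(String[] args) {
        Req req = new ReqBuilder()
            .model("gpt-3.5-turbo")
            .systemMessage("You are a helpful assistant.")
            .userMessage("Hello!")
            .temperature(0.7)
            .build();
        System.out.println(req.model() + " / " + req.messages().size() + " messages");
    }
}
```

Copying the message list in `build()` (`List.copyOf`) keeps the built request immutable even if the builder is reused.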

🔄 VI. Model Router

java
package com.example.ai.router;

import com.example.ai.core.AiClient;
import com.example.ai.pattern.factory.AiClientFactory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Model router.
 * Routes requests to the appropriate AI client based on the model name.
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class ModelRouter {
    
    private final AiClientFactory aiClientFactory;
    private final Map<String, AiClient> clientCache = new ConcurrentHashMap<>();
    
    /**
     * Route to the appropriate AI client
     */
    public AiClient route(String model) {
        return clientCache.computeIfAbsent(model, this::createClientForModel);
    }
    
    /**
     * Get all available models
     */
    public Map<String, String> getAvailableModels() {
        Map<String, String> models = new ConcurrentHashMap<>();
        
        // OpenAI models
        models.put("gpt-3.5-turbo", "openai");
        models.put("gpt-4", "openai");
        models.put("text-embedding-ada-002", "openai");
        
        // Azure models
        models.put("gpt-35-turbo", "azure");
        models.put("gpt-4-azure", "azure");
        
        // Local models
        models.put("llama-2-7b", "local");
        models.put("vicuna-13b", "local");
        
        return models;
    }
    
    /**
     * Select the best provider for a model; an explicit preference wins
     */
    public String selectBestProvider(String model, String preferredProvider) {
        if (preferredProvider != null) {
            return preferredProvider;
        }
        
        var availableModels = getAvailableModels();
        return availableModels.getOrDefault(model, "openai");
    }
    
    private AiClient createClientForModel(String model) {
        String provider = determineProvider(model);
        log.info("Creating AI client for model: {} -> provider: {}", model, provider);
        
        return aiClientFactory.createAiClient(provider);
    }
    
    private String determineProvider(String model) {
        if (model == null) {
            return "openai";
        }
        
        // Check the explicit model-to-provider mapping first so that models
        // like "gpt-35-turbo" (Azure) are not misrouted by the name heuristics
        String mapped = getAvailableModels().get(model);
        if (mapped != null) {
            return mapped;
        }
        
        String normalized = model.toLowerCase();
        if (normalized.startsWith("gpt-") || normalized.startsWith("text-")) {
            return "openai";
        } else if (normalized.contains("azure")) {
            return "azure";
        } else if (normalized.contains("local") || normalized.contains("llama")) {
            return "local";
        }
        
        return "openai";
    }
}
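The routing decision can be exercised on its own. Below is a minimal standalone sketch of the heuristic: it consults an explicit model-to-provider mapping before falling back to name prefixes (the class name and the small model map are illustrative assumptions mirroring the examples above):

```java
import java.util.Map;

// Standalone sketch of the model-to-provider routing heuristic
class MiniModelRouter {
    // Explicit mapping, consulted before any name-based guessing
    private static final Map<String, String> KNOWN_MODELS = Map.of(
        "gpt-3.5-turbo", "openai",
        "gpt-35-turbo", "azure",
        "llama-2-7b", "local"
    );

    static String determineProvider(String model) {
        if (model == null) return "openai";
        String mapped = KNOWN_MODELS.get(model);
        if (mapped != null) return mapped;          // explicit mapping wins
        String m = model.toLowerCase();
        if (m.startsWith("gpt-") || m.startsWith("text-")) return "openai";
        if (m.contains("azure")) return "azure";
        if (m.contains("local") || m.contains("llama")) return "local";
        return "openai";                            // default provider
    }
}
```

Checking the explicit mapping first matters: a purely prefix-based rule would send the Azure deployment name `gpt-35-turbo` to OpenAI.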

📊 7. Metrics Collection

java
package com.example.ai.metrics;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * AI client metrics collection
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class AiClientMetrics {
    
    private final MeterRegistry meterRegistry;
    private final Map<String, Timer> timers = new ConcurrentHashMap<>();
    private final Map<String, Counter> counters = new ConcurrentHashMap<>();
    
    /**
     * Record a successful request
     */
    public void recordSuccess(String provider, String model, String operation, long duration) {
        getTimer(provider, model, operation)
            .record(duration, TimeUnit.MILLISECONDS);
        
        incrementCounter("ai.client.requests.success", 
            Map.of("provider", provider, "model", model, "operation", operation));
        
        // Record the response-time distribution
        meterRegistry.timer("ai.client.response.time", 
            "provider", provider, 
            "model", model, 
            "operation", operation)
            .record(duration, TimeUnit.MILLISECONDS);
    }
    
    /**
     * Record a failed request
     */
    public void recordError(String provider, String model, String operation, String errorType) {
        incrementCounter("ai.client.requests.error", 
            Map.of("provider", provider, 
                   "model", model, 
                   "operation", operation, 
                   "error_type", errorType));
    }
    
    /**
     * Record token usage
     */
    public void recordTokenUsage(String provider, String model, 
                                int promptTokens, int completionTokens, int totalTokens) {
        meterRegistry.counter("ai.client.tokens.prompt", 
            "provider", provider, "model", model)
            .increment(promptTokens);
        
        meterRegistry.counter("ai.client.tokens.completion", 
            "provider", provider, "model", model)
            .increment(completionTokens);
        
        meterRegistry.counter("ai.client.tokens.total", 
            "provider", provider, "model", model)
            .increment(totalTokens);
    }
    
    /**
     * Record a streaming response chunk
     */
    public void recordStreamChunk(String provider, String model) {
        incrementCounter("ai.client.stream.chunks", 
            Map.of("provider", provider, "model", model));
    }
    
    private Timer getTimer(String provider, String model, String operation) {
        String key = provider + ":" + model + ":" + operation;
        
        return timers.computeIfAbsent(key, k -> 
            Timer.builder("ai.client.requests")
                .tag("provider", provider)
                .tag("model", model)
                .tag("operation", operation)
                .publishPercentiles(0.5, 0.95, 0.99)
                // serviceLevelObjectives replaces the deprecated sla(...) API
                .serviceLevelObjectives(java.time.Duration.ofMillis(100), 
                    java.time.Duration.ofMillis(500), 
                    java.time.Duration.ofMillis(1000))
                .register(meterRegistry)
        );
    }
    
    private void incrementCounter(String name, Map<String, String> tags) {
        Counter.Builder builder = Counter.builder(name);
        tags.forEach(builder::tag);
        builder.register(meterRegistry).increment();
    }
}
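The core idea in `getTimer` is meter caching: meters are keyed by `provider:model:operation` and created at most once per key. A minimal standalone sketch of that pattern (a `LongAdder` stands in for a Micrometer `Counter`; the class name is illustrative):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

// Standalone sketch of the meter-caching pattern used by AiClientMetrics
class MiniMetrics {
    private final Map<String, LongAdder> counters = new ConcurrentHashMap<>();

    void recordSuccess(String provider, String model, String operation) {
        String key = provider + ":" + model + ":" + operation;
        // computeIfAbsent creates the counter once and is thread-safe
        counters.computeIfAbsent(key, k -> new LongAdder()).increment();
    }

    long count(String provider, String model, String operation) {
        LongAdder a = counters.get(provider + ":" + model + ":" + operation);
        return a == null ? 0 : a.sum();
    }
}
```

Micrometer registries perform a similar lookup internally, but the local cache avoids rebuilding tag sets on every request.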

🎯 8. Usage Examples

1. Chat service example

java
package com.example.ai.service;

import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Message;
import com.example.ai.router.ModelRouter;
import com.example.ai.metrics.AiClientMetrics;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.concurrent.CompletableFuture;

@Slf4j
@Service
@RequiredArgsConstructor
public class ChatService {
    
    private final ModelRouter modelRouter;
    private final AiClientMetrics metrics;
    
    /**
     * Synchronous chat
     */
    public String chat(String model, String userMessage) {
        long startTime = System.currentTimeMillis();
        // Resolve the provider once so metrics are tagged correctly
        // (the original hardcoded "openai" even for Azure/local models)
        String provider = modelRouter.selectBestProvider(model, null);
        
        try {
            // Resolve the appropriate AI client
            AiClient aiClient = modelRouter.route(model);
            
            // Build the request
            ChatRequest request = ChatRequest.builder()
                .model(model)
                .messages(List.of(
                    Message.userMessage(userMessage)
                ))
                .temperature(0.7)
                .maxTokens(1000)
                .build();
            
            // Execute the chat call
            ChatResponse response = aiClient.chat(request);
            String content = response.getContent();
            
            // Record metrics
            long duration = System.currentTimeMillis() - startTime;
            metrics.recordSuccess(provider, model, "chat", duration);
            metrics.recordTokenUsage(
                provider, model,
                response.getUsage().getPromptTokens(),
                response.getUsage().getCompletionTokens(),
                response.getUsage().getTotalTokens()
            );
            
            return content;
            
        } catch (Exception e) {
            metrics.recordError(provider, model, "chat", e.getClass().getSimpleName());
            throw e;
        }
    }
    
    /**
     * Asynchronous chat
     */
    public CompletableFuture<String> chatAsync(String model, String userMessage) {
        return CompletableFuture.supplyAsync(() -> chat(model, userMessage));
    }
    
    /**
     * Chat with a system prompt and conversation history
     */
    public String chatWithContext(String model, String systemPrompt, 
                                 List<Message> history, String userMessage) {
        AiClient aiClient = modelRouter.route(model);
        
        // Assemble the message list: system prompt, history, then the new user message
        List<Message> messages = new java.util.ArrayList<>();
        messages.add(Message.systemMessage(systemPrompt));
        messages.addAll(history);
        messages.add(Message.userMessage(userMessage));
        
        ChatRequest request = ChatRequest.builder()
            .model(model)
            .messages(messages)
            .temperature(0.7)
            .maxTokens(2000)
            .build();
        
        return aiClient.chat(request).getContent();
    }
}
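The ordering in `chatWithContext` is the important part: system prompt first, then history, then the new user message. A minimal standalone sketch (plain strings stand in for the framework's `Message` type; the class name is illustrative):

```java
import java.util.ArrayList;
import java.util.List;

// Standalone sketch of context assembly order for a chat request
class ContextAssembler {
    static List<String> assemble(String systemPrompt, List<String> history, String userMessage) {
        List<String> messages = new ArrayList<>();
        messages.add("system: " + systemPrompt);   // instructions always lead
        messages.addAll(history);                  // prior turns, oldest first
        messages.add("user: " + userMessage);      // the new turn comes last
        return messages;
    }
}
```

Most chat-completion APIs interpret the message list positionally, so swapping this order (e.g. appending the system prompt last) would change the model's behavior.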

2. REST controller

java
package com.example.ai.controller;

import com.example.ai.service.ChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.util.Map;
import java.util.concurrent.CompletableFuture;

@Slf4j
@RestController
@RequestMapping("/api/v1/ai")
@RequiredArgsConstructor
public class AiController {
    
    private final ChatService chatService;
    
    /**
     * Simple chat endpoint
     */
    @PostMapping("/chat")
    public ResponseEntity<Map<String, Object>> chat(
            @RequestParam(defaultValue = "gpt-3.5-turbo") String model,
            @RequestBody Map<String, String> request) {
        
        String message = request.get("message");
        if (message == null || message.trim().isEmpty()) {
            return ResponseEntity.badRequest()
                .body(Map.of("error", "Message is required"));
        }
        
        try {
            String response = chatService.chat(model, message);
            return ResponseEntity.ok(Map.of(
                "model", model,
                "response", response
            ));
        } catch (Exception e) {
            log.error("Chat failed", e);
            return ResponseEntity.internalServerError()
                .body(Map.of("error", e.getMessage()));
        }
    }
    
    /**
     * Asynchronous chat endpoint
     */
    @PostMapping("/chat/async")
    public CompletableFuture<ResponseEntity<Map<String, Object>>> chatAsync(
            @RequestParam(defaultValue = "gpt-3.5-turbo") String model,
            @RequestBody Map<String, String> request) {
        
        // Validate up front, like the synchronous endpoint, instead of
        // failing later inside the async pipeline
        String message = request.get("message");
        if (message == null || message.trim().isEmpty()) {
            return CompletableFuture.completedFuture(
                ResponseEntity.badRequest().body(Map.of("error", "Message is required")));
        }
        
        return chatService.chatAsync(model, message)
            .thenApply(response -> ResponseEntity.ok(Map.of(
                "model", model,
                "response", response
            )))
            .exceptionally(e -> ResponseEntity.internalServerError()
                .body(Map.of("error", e.getMessage())));
    }
    
    /**
     * Health check
     */
    @GetMapping("/health")
    public ResponseEntity<Map<String, Object>> health() {
        return ResponseEntity.ok(Map.of(
            "status", "UP",
            "timestamp", System.currentTimeMillis()
        ));
    }
    
    /**
     * List supported models
     */
    @GetMapping("/models")
    public ResponseEntity<Map<String, Object>> getAvailableModels() {
        return ResponseEntity.ok(Map.of(
            "models", Map.of(
                "gpt-3.5-turbo", "OpenAI GPT-3.5 Turbo",
                "gpt-4", "OpenAI GPT-4",
                "text-embedding-ada-002", "OpenAI Embedding Model"
            )
        ));
    }
}
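The async endpoint's error handling hinges on `exceptionally`: a failure anywhere in the pipeline is mapped to an error payload instead of escaping as an exception. A minimal standalone sketch of that composition (the class name and string results are illustrative):

```java
import java.util.concurrent.CompletableFuture;

// Standalone sketch of the supplyAsync / thenApply / exceptionally pattern
class AsyncPattern {
    static String callWithFallback(boolean fail) {
        return CompletableFuture.supplyAsync(() -> {
                if (fail) throw new IllegalStateException("provider unavailable");
                return "ok: response";
            })
            .thenApply(r -> "200 " + r)
            // exceptionally receives a CompletionException wrapping the cause
            .exceptionally(e -> "500 error: " + e.getCause().getMessage())
            .join();
    }
}
```

Note that `exceptionally` sees a `CompletionException`, so the original error message must be read from `getCause()` — a common stumbling block with this API.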

🧪 9. Tests

java
package com.example.ai.service;

import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Message;
import com.example.ai.router.ModelRouter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import com.example.ai.metrics.AiClientMetrics;
import com.example.ai.core.model.Usage;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class ChatServiceTest {
    
    @Mock
    private ModelRouter modelRouter;
    
    @Mock
    private AiClient aiClient;
    
    @Mock
    private AiClientMetrics metrics;
    
    private ChatService chatService;
    
    @BeforeEach
    void setUp() {
        // Mock the metrics collaborator; passing null would make the
        // recordSuccess/recordError calls inside ChatService throw NPE
        chatService = new ChatService(modelRouter, metrics);
    }
    
    @Test
    void testChat_Success() {
        // Arrange
        String model = "gpt-3.5-turbo";
        String userMessage = "Hello, how are you?";
        String expectedResponse = "I'm fine, thank you!";
        
        when(modelRouter.route(model)).thenReturn(aiClient);
        
        // The service reads usage for token metrics, so the stub must include it
        // (assumes Usage, like ChatResponse, exposes a Lombok @Builder)
        ChatResponse mockResponse = ChatResponse.builder()
            .content(expectedResponse)
            .usage(Usage.builder()
                .promptTokens(5).completionTokens(7).totalTokens(12)
                .build())
            .build();
        when(aiClient.chat(any(ChatRequest.class))).thenReturn(mockResponse);
        
        // Act
        String result = chatService.chat(model, userMessage);
        
        // Assert
        assertEquals(expectedResponse, result);
        verify(modelRouter, times(1)).route(model);
        verify(aiClient, times(1)).chat(any(ChatRequest.class));
    }
    
    @Test
    void testChat_WithSystemPrompt() {
        // Arrange
        String model = "gpt-3.5-turbo";
        String systemPrompt = "You are a helpful assistant.";
        String userMessage = "What is AI?";
        
        when(modelRouter.route(model)).thenReturn(aiClient);
        
        ChatResponse mockResponse = ChatResponse.builder()
            .content("AI stands for Artificial Intelligence.")
            .build();
        when(aiClient.chat(any(ChatRequest.class))).thenReturn(mockResponse);
        
        // Act
        String result = chatService.chatWithContext(
            model, systemPrompt, List.of(), userMessage
        );
        
        // Assert
        assertNotNull(result);
        verify(aiClient, times(1)).chat(any(ChatRequest.class));
    }
    
    @Test
    void testChat_ModelNotFound() {
        // Arrange
        String model = "unknown-model";
        String userMessage = "Hello";
        
        when(modelRouter.route(model)).thenThrow(
            new IllegalArgumentException("Model not found: " + model)
        );
        
        // Act & assert: the routing failure propagates to the caller
        assertThrows(IllegalArgumentException.class, () -> {
            chatService.chat(model, userMessage);
        });
    }
}

📈 10. Architecture: Design Patterns in Combination

[Architecture diagram] The framework is organized in layers. Client layer: REST controller and service layer. Routing layer: model router and its factory. Core layer, where the design patterns combine: Factory (AiClientFactory), Strategy (AiProviderStrategy), Template Method (AbstractAiClient), Adapter (AiModelAdapter), Chain of Responsibility (AiClientInterceptorChain), Builder (ChatRequestBuilder). Provider layer: OpenAI client, Azure AI client, local-model client. Monitoring layer: metrics collector, interceptor chain, Micrometer integration. Configuration layer: auto-configuration, property binding, conditional assembly.

🔄 11. Request Flow (Sequence)

[Sequence diagram] End-to-end flow (typically 200-800 ms): the client POSTs to /api/v1/ai/chat; the REST controller calls the chat service; the service asks the model router for a client; the router obtains the client factory, which selects a provider strategy and returns the AI client; the service invokes chat(); the interceptor chain runs its pre-handlers; the model adapter converts the request to the provider's format; the provider API is called and its raw response is converted back to the unified format; post-handlers run; metrics are recorded; and the result flows back through the service and controller to the client as the HTTP response.

🎯 12. Running and Testing

1. Start the application

bash
# Set environment variables
export OPENAI_API_KEY=your-api-key-here

# Build and run
mvn clean package
java -jar target/spring-ai-simulation-1.0.0.jar

# Or run directly with Maven
mvn spring-boot:run

2. Test the API

bash
# Health check
curl http://localhost:8080/api/v1/ai/health

# List available models
curl http://localhost:8080/api/v1/ai/models

# Chat test
curl -X POST http://localhost:8080/api/v1/ai/chat \
  -H "Content-Type: application/json" \
  -d '{"message": "Hello, how are you?"}'

# Chat with an explicit model
curl -X POST "http://localhost:8080/api/v1/ai/chat?model=gpt-3.5-turbo" \
  -H "Content-Type: application/json" \
  -d '{"message": "What is Spring AI?"}'

3. Monitoring metrics

All metrics are published through Micrometer; with Spring Boot Actuator on the classpath they can be browsed under /actuator/metrics (for example ai.client.requests, ai.client.tokens.total, and ai.client.requests.error).

📊 13. Design Pattern Summary

| Pattern | Applied in | Problem solved | Key class |
|---|---|---|---|
| Factory | AI client creation | Uniform creation of clients for different AI providers | AiClientFactory |
| Strategy | Provider switching | Support for multiple AI service providers | AiProviderStrategy |
| Template Method | Request processing | A unified AI request-handling pipeline | AbstractAiClient |
| Adapter | Model format conversion | Converting request/response formats between models | AiModelAdapter |
| Chain of Responsibility | Interceptor handling | A pluggable interceptor chain | AiClientInterceptorChain |
| Builder | Request construction | Building complex request objects safely | ChatRequestBuilder |
| Singleton | Configuration management | A single shared configuration object | AiProperties |
| Proxy | Monitoring enhancement | Adding monitoring to AI clients transparently | AiClientProxy |

🔧 14. Extension Guide

1. Add a new AI provider

java
@Component
public class CustomAiStrategy implements AiProviderStrategy {
    
    @Override
    public String getProviderName() {
        return "custom";
    }
    
    @Override
    public ChatResponse chat(ChatRequest request) {
        // TODO: call the custom AI service and map its response to ChatResponse
        throw new UnsupportedOperationException("custom provider not implemented yet");
    }
    
    // Remaining interface methods...
}
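Once registered, the new strategy is resolved by its provider name. The lookup at the heart of the factory can be sketched in isolation (the class name and the `Function` stand-in for the strategy interface are illustrative assumptions):

```java
import java.util.Map;
import java.util.function.Function;

// Standalone sketch of strategy lookup by provider name, the essence of
// what AiClientFactory does with its AiProviderStrategy beans
class StrategyRegistry {
    private static final Map<String, Function<String, String>> STRATEGIES = Map.of(
        "openai", prompt -> "[openai] " + prompt,
        "custom", prompt -> "[custom] " + prompt
    );

    static String chat(String provider, String prompt) {
        Function<String, String> strategy = STRATEGIES.get(provider);
        if (strategy == null) {
            // Fail fast on unknown providers rather than silently defaulting
            throw new IllegalArgumentException("No strategy for provider: " + provider);
        }
        return strategy.apply(prompt);
    }
}
```

In the real framework, Spring populates this map automatically by injecting all `AiProviderStrategy` beans, keyed by `getProviderName()`, so adding a provider requires no factory changes.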

2. Add a custom interceptor

java
@Component
public class CustomInterceptor implements AiClientInterceptor {
    
    @Override
    public boolean preHandle(ChatRequest request) {
        // Custom pre-processing; return false to short-circuit the chain
        return true;
    }
    
    @Override
    public void postHandle(ChatRequest request, ChatResponse response) {
        // Custom post-processing of the response
    }
}
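The interceptor contract is easiest to see in a standalone chain: pre-handlers run in registration order, and any one of them returning false stops the request before it reaches the client. A minimal sketch (simplified String types and names, not the framework's actual classes):

```java
import java.util.List;

// Simplified interceptor: preHandle can veto, postHandle observes the result
interface MiniInterceptor {
    boolean preHandle(String request);
    default void postHandle(String request, String response) {}
}

// Standalone sketch of the chain-of-responsibility execution order
class MiniChain {
    static String execute(List<MiniInterceptor> interceptors, String request) {
        for (MiniInterceptor i : interceptors) {
            if (!i.preHandle(request)) {
                return "rejected";          // short-circuit: later interceptors never run
            }
        }
        String response = "handled: " + request;
        for (MiniInterceptor i : interceptors) {
            i.postHandle(request, response);
        }
        return response;
    }
}
```

This is the behavior a custom interceptor plugs into: returning false from `preHandle` (e.g. for rate limiting or content filtering) vetoes the whole request.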

3. Configure a custom model

yaml
spring:
  ai:
    custom:
      enabled: true
      api-key: ${CUSTOM_API_KEY}
      endpoint: https://api.custom-ai.com/v1
      models:
        - custom-model-1
        - custom-model-2

This complete Spring AI simulation framework shows how combining multiple design patterns yields a flexible, extensible, and maintainable AI service integration framework. Each pattern solves a specific problem; together they form a complete solution.
