【springai】 Model层设计与实现

1 Model层概述

1.1 架构定位

Model层是Spring AI的核心抽象层,它定义了与各类AI模型交互的统一契约。这一层采用经典的面向接口设计,通过泛型抽象实现了高度的可扩展性,使得开发者可以在不改变业务代码的前提下,无缝切换不同的AI提供商。

Model层的核心价值体现在以下几个方面:

  • 统一抽象:屏蔽不同AI提供商API的差异,提供一致的编程模型
  • 类型安全:通过泛型机制保证请求和响应类型的编译期检查
  • 可扩展性:支持同步调用、流式调用,以及自定义模型实现
  • 可观测性:内置Micrometer观测支持,便于监控和调试

1.2 整体类图

2 Generic Model API 源码剖析

2.1 Model接口

Model接口是所有AI模型交互的根接口,它使用Java泛型来支持不同类型的请求和响应:

php 复制代码
/**
 * 所有AI模型的根接口
 * @param <TReq> 请求类型,必须继承ModelRequest
 * @param <TRes> 响应类型,必须继承ModelResponse
 */
public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {

    /**
     * 调用AI模型的核心方法
     * @param request 请求对象
     * @return 响应对象
     */
    TRes call(TReq request);

}

设计要点

  1. 泛型约束TReq extends ModelRequest<?>确保了请求参数必须是ModelRequest的子类型
  2. 泛型约束TRes extends ModelResponse<?>确保了响应参数必须是ModelResponse的子类型
  3. 默认方法call()提供默认实现,避免所有子类都必须实现的方法

2.2 StreamingModel接口

对于原生支持流式响应的模型,Spring AI提供了专门的StreamingModel接口:

php 复制代码
/**
 * 支持流式响应的AI模型接口
 * @param <TReq> 请求类型
 * @param <TResChunk> 流式响应块类型
 */
public interface StreamingModel<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {

    /**
     * 流式调用AI模型
     * @param request 请求对象
     * @return 响应块的Flux流
     */
    Flux<TResChunk> stream(TReq request);
}

2.3 ModelRequest接口

ModelRequest封装了对AI模型的请求信息:

csharp 复制代码
/**
 * AI模型请求的抽象
 * @param <T> 指令类型(通常是Prompt或Message列表)
 */
public interface ModelRequest<T> {

    /**
     * 获取发送给AI模型的指令或输入
     * @return 指令内容
     */
    T getInstructions();

    /**
     * 获取AI模型交互的可自定义选项
     * @return 模型选项
     */
    ModelOptions getOptions();
}

2.4 ModelOptions接口

ModelOptions是一个标记接口,用于携带AI模型的自定义参数:

csharp 复制代码
/**
 * AI模型选项的标记接口
 * 各个具体模型实现可以扩展此接口添加自己的选项
 */
public interface ModelOptions {
    // 标记接口,无具体方法
}

ChatOptionsModelOptions的重要子接口,定义了对话模型通用的选项:

scss 复制代码
public interface ChatOptions extends ModelOptions {
    
    String getModel();                    // 模型ID
    Double getFrequencyPenalty();          // 频率惩罚(-2.0~2.0),降低重复token的可能性
    Integer getMaxTokens();               // 最大生成token数
    Double getPresencePenalty();           // 存在惩罚(-2.0~2.0),鼓励谈论新话题
    List<String> getStopSequences();      // 停止序列
    Double getTemperature();               // 采样温度(0.0~2.0),控制随机性
    Integer getTopK();                    // Top-K采样
    Double getTopP();                      // Top-P核采样
    
    <T entends ChatOptions> T copy();                   // 深拷贝
}

参数详解

参数 取值范围 作用 典型场景
temperature 0.0~2.0 控制输出的随机性 0.2:代码生成;0.8:创意写作
topP 0.0~1.0 核采样,动态选择token 0.9:平衡质量和多样性
maxTokens 1~4096 限制响应长度 根据任务需求设置
frequencyPenalty -2.0~2.0 降低重复内容 0.5:减少重复
presencePenalty -2.0~2.0 鼓励新话题 0.5:引导话题转移

2.5 ModelResponse接口

ModelResponse封装了AI模型的响应结果:

csharp 复制代码
/**
 * AI模型响应的抽象
 * @param <T> 结果类型,必须继承ModelResult
 */
public interface ModelResponse<T extends ModelResult<?>> {

    /**
     * 获取主要结果
     * @return 模型生成的结果
     */
    T getResult();

    /**
     * 获取所有生成的结果(用于n>1的情况)
     * @return 结果列表
     */
    List<T> getResults();

    /**
     * 获取响应元数据
     * @return 响应元数据
     */
    ResponseMetadata getMetadata();
}

2.6 ModelResult接口

ModelResult表示AI模型产生的单个输出结果:

csharp 复制代码
/**
 * AI模型输出的抽象
 * @param <T> 输出类型(如AssistantMessage)
 */
public interface ModelResult<T> {

    /**
     * 获取模型生成的输出
     * @return 输出内容
     */
    T getOutput();

    /**
     * 获取结果相关的元数据
     * @return 结果元数据
     */
    ResultMetadata getMetadata();
}

3 ChatModel接口体系详解

3.1 ChatModel接口定义

ChatModel是针对对话场景的特化接口,它扩展了通用的Model接口:

java 复制代码
public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {

    /**
     * 便捷方法:直接使用字符串消息调用
     * 内部自动创建UserMessage和Prompt
     * @param message 用户消息
     * @return 响应的文本内容
     */
    default String call(String message) {
        Prompt prompt = new Prompt(new UserMessage(message));
        ChatResponse response = this.call(prompt);
        Generation result = response.getResult();
        return result != null ? result.getOutput().getText() : null;
    }

    /**
     * 核心方法:使用完整的Prompt对象调用
     * @param prompt 提示词对象
     * @return 完整的对话响应
     */
    @Override
    ChatResponse call(Prompt prompt);
}

3.2 StreamingChatModel接口

typescript 复制代码
public interface StreamingChatModel extends StreamingModel<Prompt, ChatResponse> {

    /**
     * 便捷方法:流式调用字符串消息
     * @param message 用户消息
     * @return 文本块的Flux流
     */
    default Flux<String> stream(String message) {
		Prompt prompt = new Prompt(message);
		return stream(prompt).map(response -> Optional.ofNullable(response.getResult())
			.map(Generation::getOutput)
			.map(AssistantMessage::getText)
			.orElse(""));
	}

    /**
     * 核心方法:流式调用完整的Prompt
     * @param prompt 提示词对象
     * @return ChatResponse块的Flux流
     */
    @Override
    Flux<ChatResponse> stream(Prompt prompt);
}

3.3 调用流程图

4 Prompt与Message体系

4.1 Message接口与实现

MessagePrompt的基本组成单元,代表对话中的一条消息:

csharp 复制代码
/**
 * 对话消息的抽象
 */
public interface Message extends Content {
    /**
     * 获取消息类型(用户、系统、助手等)
     */
    MessageType getMessageType();
}

// Content接口定义了消息内容的基础能力
public interface Content {
    String getText();                      // 消息文本内容
    Map<String, Object> getMetadata();     // 元数据
}

// 多模态消息扩展
public interface MediaContent extends Content {
    List<Media> getMedia();          // 媒体内容列表
}

MessageType枚举

scss 复制代码
public enum MessageType {
    SYSTEM("system"),      // 系统指令,设定AI角色和行为
    USER("user"),          // 用户输入
    ASSISTANT("assistant"),// AI助手回复
    TOOL("tool");          // 工具调用请求
}

各Message实现类详解

scala 复制代码
// 1. UserMessage - 用户消息
public class UserMessage extends AbstractMessage implements MediaContent{
    protected final List<Media> media;
    public UserMessage(String textContent) {
		this(textContent, new ArrayList<>(), Map.of());
	}

	private UserMessage(String textContent, Collection<Media> media, Map<String, Object> metadata) {
		super(MessageType.USER, textContent, metadata);
		this.media = new ArrayList<>(media);
	}
}

// 2. SystemMessage - 系统消息(设定AI角色)
public class SystemMessage extends AbstractMessage {
    public SystemMessage(String textContent) {
		this(textContent, Map.of());
	}

	private SystemMessage(String textContent, Map<String, Object> metadata) {
		super(MessageType.SYSTEM, textContent, metadata);
	}
}

// 3. AssistantMessage - 助手回复消息
public class AssistantMessage extends AbstractMessage {
    private List<ToolCall> toolCalls;  // 可包含工具调用
    protected final List<Media> media;

	public AssistantMessage(String content) {
		this(content, Map.of(), List.of(), List.of());
	}

	protected AssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls,
			List<Media> media) {
		super(MessageType.ASSISTANT, content, properties);		Assert.notNull(media, "Media must not be null");
		this.toolCalls = toolCalls;
		this.media = media;
	}
}

// 4. ToolResponseMessage - 工具调用消息
public class ToolResponseMessage extends AbstractMessage {

	protected final List<ToolResponse> responses;

	protected ToolResponseMessage(List<ToolResponse> responses, Map<String, Object> metadata) {
		super(MessageType.TOOL, "", metadata);
		this.responses = responses;
	}
}

4.2 Prompt类源码剖析

Prompt类是ModelRequest的具体实现,封装了消息列表和选项:

csharp 复制代码
public class Prompt implements ModelRequest<List<Message>> {

    private final List<Message> messages;
    private ChatOptions chatOptions;

    // 构造函数
    public Prompt(List<Message> messages) {
        this(messages, null);
    }

    public Prompt(Message... messages) {
        this(Arrays.asList(messages), null);
    }

    public Prompt(String contents) {
        this(new UserMessage(contents));
    }

    // 带选项的构造函数
    public Prompt(List<Message> messages, ChatOptions chatOptions) {
        this.messages = messages;
        this.chatOptions = chatOptions;
    }

    // 获取消息列表
    @Override
    public List<Message> getInstructions() {
        return this.messages;
    }


    // 将所有消息合并为单个字符串
    public String getContents() {
        StringBuilder sb = new StringBuilder();
        for (Message message : getInstructions()) {
            sb.append(message.getText());
        }
        return sb.toString();
    }

    // Builder模式
    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String content;
        private List<Message> messages;
        private ChatOptions options;

        ...

        public Prompt build() {
            return new Prompt(this.messages, this.chatOptions);
        }
    }
}

4.3 PromptTemplate详解

PromptTemplate用于创建提示词的模板。它允许您定义一个包含变量占位符的模板字符串,然后通过为这些变量提供具体值来渲染(生成)该模板。

typescript 复制代码
public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {

	private static final TemplateRenderer DEFAULT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build();
	private String template;
	private final Map<String, Object> variables = new HashMap<>();
	private final TemplateRenderer renderer;

	public PromptTemplate(String template) {
		this(template, new HashMap<>(), DEFAULT_TEMPLATE_RENDERER);
	}

	PromptTemplate(String template, Map<String, Object> variables, TemplateRenderer renderer) {
		Assert.hasText(template, "template cannot be null or empty");
		Assert.notNull(variables, "variables cannot be null");
		Assert.noNullElements(variables.keySet(), "variables keys cannot be null");
		Assert.notNull(renderer, "renderer cannot be null");

		this.template = template;
		this.variables.putAll(variables);
		this.renderer = renderer;
	}

	public void add(String name, Object value) {
		this.variables.put(name, value);
	}
	@Override
	public String render() {
		// 使用构造时传入的变量渲染模板,返回字符串
		Map<String, Object> processedVariables = new HashMap<>();
		for (Entry<String, Object> entry : this.variables.entrySet()) {
			if (entry.getValue() instanceof Resource resource) {
				processedVariables.put(entry.getKey(), renderResource(resource));
			}
			else {
				processedVariables.put(entry.getKey(), entry.getValue());
			}
		}
		return this.renderer.apply(this.template, processedVariables);
	}

	@Override
	public String render(Map<String, Object> additionalVariables) {
		//合并额外变量后渲染
		Map<String, Object> combinedVariables = new HashMap<>();
		Map<String, Object> mergedVariables = new HashMap<>(this.variables);
		// variables + additionalVariables => mergedVariables
		if (additionalVariables != null && !additionalVariables.isEmpty()) {
			mergedVariables.putAll(additionalVariables);
		}

		for (Entry<String, Object> entry : mergedVariables.entrySet()) {
			if (entry.getValue() instanceof Resource resource) {
				combinedVariables.put(entry.getKey(), renderResource(resource));
			}
			else {
				combinedVariables.put(entry.getKey(), entry.getValue());
			}
		}

		return this.renderer.apply(this.template, combinedVariables);
	}
 

	@Override
	public Message createMessage() {
		//渲染后创建 UserMessage
		return new UserMessage(render());
	}

	@Override
	public Message createMessage(List<Media> mediaList) {
		return UserMessage.builder().text(render()).media(mediaList).build();
	}

	@Override
	public Message createMessage(Map<String, Object> additionalVariables) {
		return new UserMessage(render(additionalVariables));
	}

	// From PromptTemplateActions.

	@Override
	public Prompt create() {
		//渲染后创建 Prompt 对象
		return new Prompt(render(new HashMap<>()));  
	}

	@Override
	public Prompt create(ChatOptions modelOptions) {
		return Prompt.builder().content(render(new HashMap<>())).chatOptions(modelOptions).build();
	}

	@Override
	public Prompt create(Map<String, Object> additionalVariables) {
		return new Prompt(render(additionalVariables));
	}

	@Override
	public Prompt create(Map<String, Object> additionalVariables, ChatOptions modelOptions) {
		return Prompt.builder().content(render(additionalVariables)).chatOptions(modelOptions).build();
	}

	public Builder mutate() {
		return new Builder().template(this.template).variables(this.variables).renderer(this.renderer);
	}

	// Builder

	public static Builder builder() {
		return new Builder();
	}

	public static class Builder {

		protected String template;

		protected Resource resource;

		protected Map<String, Object> variables = new HashMap<>();

		protected TemplateRenderer renderer = DEFAULT_TEMPLATE_RENDERER;
 
		public PromptTemplate build() {
				return new PromptTemplate(this.template, this.variables, this.renderer);
		}

	}

}

三种专用PromptTemplate

scala 复制代码
// 1. SystemPromptTemplate - 专门用于创建系统消息
public class SystemPromptTemplate extends PromptTemplate {
    @Override
    public Message createMessage(Map<String, Object> model) {
        return new SystemMessage(render(model));
    }
}

// 2. AssistantPromptTemplate - 专门用于创建助手消息
public class AssistantPromptTemplate extends PromptTemplate {
    @Override
    public Message createMessage(Map<String, Object> model) {
        return new AssistantMessage(render(model));
    }
}

// 3. ToolResponseMessage - 专门用于创建工具调用消息
public class ToolResponseMessage extends PromptTemplate {
    @Override
    public Message createMessage(Map<String, Object> model) {
        String jsonContent = render(model);
        return FunctionCallMessage.fromJson(jsonContent);
    }
}

使用示例

typescript 复制代码
@Service
public class ChatService {
    
    public String generatePrompt() {
        Map<String, Object> variables = new HashMap<>();
        variables.put("name", "Spring AI");
        SystemPromptTemplate systemPromptTemplate = SystemPromptTemplate.builder()
			.template("Hello {name}!")
			.variables(variables)
			.build();
        Prompt prompt = systemPromptTemplate.create(variables);
        return prompt;
    }
}

5 选项合并机制

5.1 选项层级

Spring AI支持两层选项配置:

  1. 启动时默认选项 :在创建ChatModel实例时设置,作为全局默认值
  2. 运行时请求选项 :在每个Prompt请求中传递,可以覆盖默认选项

5.2 合并源码剖析

kotlin 复制代码
/**
 * 构建请求 Prompt(将传入的 Prompt 转换为最终的请求 Prompt)
 *
 * @param prompt 原始 Prompt(包含用户指令和可能的运行时选项)
 * @return 合并了默认选项后的最终请求 Prompt
 */
Prompt buildRequestPrompt(Prompt prompt) {
    // 处理运行时选项(从原始 Prompt 中提取并转换)
    OpenAiChatOptions runtimeOptions = null;
    if (prompt.getOptions() != null) {
        // 如果选项是 ToolCallingChatOptions 类型,需要特殊处理
        if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
                    OpenAiChatOptions.class);
        } else {
            // 普通 ChatOptions,直接拷贝到 OpenAiChatOptions
            runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
                    OpenAiChatOptions.class);
        }
    }

    // 通过合并运行时选项和默认选项,定义最终请求的选项
    OpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
            OpenAiChatOptions.class);

    // 显式合并被 @JsonIgnore 注解标记的选项(这些选项会被 Jackson 忽略,
    // 而 ModelOptionsUtils 也会忽略它们,因此需要手动处理)
    if (runtimeOptions != null) {
        // OpenAI 不支持 topK 参数,给出警告
        if (runtimeOptions.getTopK() != null) {
            logger.warn("topK 选项不被 OpenAI 聊天模型支持,将被忽略。");
        }

        // 合并 HTTP 头信息
        requestOptions.setHttpHeaders(
                mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
        // 合并内部工具执行开关
        requestOptions.setInternalToolExecutionEnabled(
                ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
                        this.defaultOptions.getInternalToolExecutionEnabled()));
        // 合并工具名称列表
        requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
                this.defaultOptions.getToolNames()));
        // 合并工具回调函数
        requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
                this.defaultOptions.getToolCallbacks()));
        // 合并工具上下文
        requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
                this.defaultOptions.getToolContext()));
    } else {
        // 没有运行时选项时,直接使用默认选项
        requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
        requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
        requestOptions.setToolNames(this.defaultOptions.getToolNames());
        requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
        requestOptions.setToolContext(this.defaultOptions.getToolContext());
    }

    // 验证工具回调函数的有效性
    ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());

    // 使用处理后的指令和合并后的选项创建最终的 Prompt
    return new Prompt(prompt.getInstructions(), requestOptions);
}

5.3 配置示例

yaml 复制代码
# application.yml - 全局默认配置
spring:
  ai:
    openai:
      chat:
        options:
          model: gpt-4
          temperature: 0.7
          max-tokens: 1000
scss 复制代码
// 运行时覆盖
@Service
public class DynamicConfigService {
    @Autowired
    private OpenAiChatModel chatModel;
    public String lowTemperatureResponse(String input) {
        // 创建低温度选项(更确定性的输出)
        ChatOptions lowTempOptions = OpenAiChatOptions.builder()
            .temperature(0.2)
            .build();
        
        Prompt prompt = new Prompt(
            new UserMessage(input),
            lowTempOptions
        );
        
        return chatModel.call(prompt).getResult().getOutput().getText();
    }
    
    public String creativeResponse(String input) {
        // 创建高温度选项(更有创意的输出)
        ChatOptions creativeOptions = OpenAiChatOptions.builder()
            .temperature(1.2)
            .topP(0.95)
            .build();
        
        Prompt prompt = new Prompt(
            new UserMessage(input),
            creativeOptions
        );
        
        return chatModel.call(prompt).getResult().getOutput().getText();
    }
}

6 ChatResponse与Generation

6.1 ChatResponse类

kotlin 复制代码
public class ChatResponse implements ModelResponse<Generation> {

	private final ChatResponseMetadata chatResponseMetadata;

	/**
	 * AI返回的生成消息列表。
	 */
	private final List<Generation> generations;

	public ChatResponse(List<Generation> generations) {
		this(generations, new ChatResponseMetadata());
	}

	public ChatResponse(List<Generation> generations, ChatResponseMetadata chatResponseMetadata) {
		this.chatResponseMetadata = chatResponseMetadata;
		this.generations = List.copyOf(generations);
	}

	public static Builder builder() {
		return new Builder();
	}

	@Override
	public List<Generation> getResults() {
		return this.generations;
	}

	public Generation getResult() {
		if (CollectionUtils.isEmpty(this.generations)) {
			return null;
		}
		return this.generations.get(0);
	}

	/**
	 * 返回包含 AI使用信息的
	 */
	@Override
	public ChatResponseMetadata getMetadata() {
		return this.chatResponseMetadata;
	}
}

6.2 Generation类

kotlin 复制代码
/**
 * Represents a response returned by the AI.
 */
public class Generation implements ModelResult<AssistantMessage> {

	private final AssistantMessage assistantMessage;

	private ChatGenerationMetadata chatGenerationMetadata;

	public Generation(AssistantMessage assistantMessage) {
		this(assistantMessage, ChatGenerationMetadata.NULL);
	}

	public Generation(AssistantMessage assistantMessage, ChatGenerationMetadata chatGenerationMetadata) {
		this.assistantMessage = assistantMessage;
		this.chatGenerationMetadata = chatGenerationMetadata;
	}

	@Override
	public AssistantMessage getOutput() {
		return this.assistantMessage;
	}

	@Override
	public ChatGenerationMetadata getMetadata() {
		ChatGenerationMetadata chatGenerationMetadata = this.chatGenerationMetadata;
		return chatGenerationMetadata != null ? chatGenerationMetadata : ChatGenerationMetadata.NULL;
	}
}

7 流式调用实现剖析

7.1 Flux响应式流

Spring AI的流式调用基于Project Reactor的Flux类型:

scss 复制代码
// 源码示例:OpenAiChatModel的流式实现
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
    // 在进一步处理之前,构建最终的请求 Prompt,
    // 合并运行时选项和默认选项
    Prompt requestPrompt = buildRequestPrompt(prompt);
    return internalStream(requestPrompt, null);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
    return Flux.deferContextual(contextView -> {
        // 创建流式请求对象(参数 true 表示流式请求)
        ChatCompletionRequest request = createRequest(prompt, true);

        // 检查是否包含音频输出(流式请求不支持)
        if (request.outputModalities() != null
                && request.outputModalities().contains(OpenAiApi.OutputModality.AUDIO)) { 
            throw new IllegalArgumentException("流式请求不支持音频输出。");
        }

        // 检查是否包含音频参数(流式请求不支持)
        if (request.audioParameters() != null) { 
            throw new IllegalArgumentException("流式请求不支持音频参数。");
        }

        // 调用 OpenAI API 发起流式聊天补全请求
        Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
                getAdditionalHttpHeaders(prompt));

        // 用于缓存每个请求 ID 对应的角色信息(只有第一个分块包含 role)
        ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

        // 构建可观测性上下文,用于记录模型调用的指标
        final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(OpenAiApiConstants.PROVIDER_NAME)
                .build();

        // 创建可观测性记录(用于性能监控和日志)
        Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                this.observationRegistry);

        observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

        // 将 ChatCompletionChunk 转换为 ChatCompletion,以便复用函数调用处理逻辑
        Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
            .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
                try {
                    // 如果没有提供 id,则设置为 "NO_ID"
                    String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id();

                    // 将每个 choice 转换为 Generation 对象
                    List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
                        // 缓存角色信息(首个有效分块)
                        if (choice.message().role() != null) {
                            roleMap.putIfAbsent(id, choice.message().role().name());
                        }
                        // 构建元数据
                        Map<String, Object> metadata = Map.of(
                                "id", id,
                                "role", roleMap.getOrDefault(id, ""),
                                "index", choice.index() != null ? choice.index() : 0,
                                "finishReason", getFinishReasonJson(choice.finishReason()),
                                "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
                                "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(),
                                "reasoningContent", choice.message().reasoningContent() != null ? choice.message().reasoningContent() : "");
                        return buildGeneration(choice, metadata, request);
                    }).toList();

                    // 处理用量统计(token 使用情况)
                    OpenAiApi.Usage usage = chatCompletion2.usage();
                    Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
                    Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
                            previousChatResponse);
                    return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
                } catch (Exception e) { 
                    return new ChatResponse(List.of());
                }
            }))
            // 创建重叠缓冲区(2个元素,步长1),用于处理最终响应中的 usage 信息
            .buffer(2, 1)
            .map(bufferList -> {
                ChatResponse firstResponse = bufferList.get(0);
                // 如果启用了流式用量统计
                if (request.streamOptions() != null && request.streamOptions().includeUsage()) {
                    if (bufferList.size() == 2) {
                        ChatResponse secondResponse = bufferList.get(1);
                        if (secondResponse != null && secondResponse.getMetadata() != null) {
                            Usage usage = secondResponse.getMetadata().getUsage();
                            if (!UsageCalculator.isEmpty(usage)) {
                                // 将最终响应的 usage 存储到倒数第二个响应中
                                return new ChatResponse(firstResponse.getResults(),
                                        from(firstResponse.getMetadata(), usage));
                            }
                        }
                    }
                }
                return firstResponse;
            });

        // 处理工具调用(Tool Calling)
        Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
            // 判断是否需要执行工具调用
            if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
                // 使用 bounded elastic 调度器执行同步工具调用
                return Flux.deferContextual(ctx -> {
                    ToolExecutionResult toolExecutionResult;
                    try {
                        ToolCallReactiveContextHolder.setContext(ctx);
                        // 执行工具调用
                        toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
                    } finally {
                        ToolCallReactiveContextHolder.clearContext();
                    }
                    if (toolExecutionResult.returnDirect()) {
                        // 直接返回工具执行结果给客户端
                        return Flux.just(ChatResponse.builder().from(response)
                                .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                                .build());
                    } else {
                        // 将工具执行结果发送回模型,继续对话
                        return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                                response);
                    }
                }).subscribeOn(Schedulers.boundedElastic());
            } else {
                // 无需工具调用,直接返回响应
                return Flux.just(response);
            }
        })
        .doOnError(observation::error)  // 记录错误到可观测性
        .doFinally(s -> observation.stop())  // 停止可观测性记录
        .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));  // 传递上下文

        // 聚合消息并更新可观测性上下文中的响应
        return new MessageAggregator().aggregate(flux, observationContext::setResponse);
    });
}

7.2 流式响应示例

typescript 复制代码
@RestController
public class StreamingController {
    @Autowired
    private ChatModel chatModel;
    
    @GetMapping(value = "/stream/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> streamChat(@RequestParam String message) {
        Prompt prompt = new Prompt(new UserMessage(message));
        
        // 使用Flux处理流式响应
        return chatModel.stream(prompt)
            .map(chatResponse -> {
                Generation result = chatResponse.getResult();
                return result != null ? result.getOutput().getText() : "";
            })
            .filter(StringUtils::hasText);  // 过滤空内容
    }
}

7.3 流式响应聚合

当需要将流式响应聚合为完整响应时(如在WebFlux环境中):

typescript 复制代码
@Service
public class ChatAggregationService {
    @Autowired
    private ChatModel chatModel;
    
    public Mono<String> getCompleteResponse(String prompt) {
        return chatModel.stream(new Prompt(new UserMessage(prompt)))
            .map(response -> {
                Generation result = response.getResult();
                return result != null ? result.getOutput().getText() : "";
            })
            .collectList()          // 收集所有片段
            .map(list -> String.join("", list));  // 拼接为完整响应
    }
}

8 核心实现类:OpenAiChatModel源码深度剖析

8.1 类结构

java 复制代码
public class OpenAiChatModel implements ChatModel {
    
    // 核心依赖
    private final OpenAiApi openAiApi;                       // OpenAI API客户端
    private final OpenAiChatOptions defaultOptions;          // 默认选项
    private final RetryTemplate retryTemplate;               // 重试模板
    private final ToolCallingManager toolCallingManager;     // 工具调用管理器
    private final ObservationRegistry observationRegistry;   // 观测注册表
    
    // 可选配置
    private final Predicate<ChatResponse> toolExecutionEligibilityPredicate;
    private final ChatModelObservationConvention observationConvention;
    
    // Builder模式构建
    public static Builder builder() {
        return new Builder();
    }
    
    // ... 其他代码
}

8.2 call方法实现

scss 复制代码
@Override
public ChatResponse call(Prompt prompt) {
    // Before moving any further, build the final request Prompt,
    // merging runtime and default options.
    Prompt requestPrompt = buildRequestPrompt(prompt);
    return this.internalCall(requestPrompt, null);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

    // 创建非流式聊天补全请求(参数 false 表示非流式)
    ChatCompletionRequest request = createRequest(prompt, false);

    // 构建可观测性上下文,用于记录模型调用的指标信息
    ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OpenAiApiConstants.PROVIDER_NAME)
            .build();

    // 在可观测性记录范围内执行调用
    ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {

                // 使用重试模板执行 API 调用,获取聊天补全响应实体
                ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
                    .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt)));

                var chatCompletion = completionEntity.getBody();

                // 校验响应体是否为空
                if (chatCompletion == null) {
                    return new ChatResponse(List.of());
                }

                List<Choice> choices = chatCompletion.choices();
                // 校验 choices 列表是否为空
                if (choices == null) {
                    return new ChatResponse(List.of());
                }

                // 将每个 choice 转换为 Generation 对象
                List<Generation> generations = choices.stream().map(choice -> {
                    // 构建元数据:id、role、索引、完成原因、拒绝信息、标注等
                    Map<String, Object> metadata = Map.of(
                            "id", chatCompletion.id() != null ? chatCompletion.id() : "",
                            "role", choice.message().role() != null ? choice.message().role().name() : "",
                            "index", choice.index() != null ? choice.index() : 0,
                            "finishReason", getFinishReasonJson(choice.finishReason()),
                            "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
                            "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of()));
                    return buildGeneration(choice, metadata, request);
                }).toList();

                // 从响应头中提取速率限制信息
                RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

                // 处理 token 用量统计
                OpenAiApi.Usage usage = chatCompletion.usage();
                Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
                // 累积本次调用和之前调用的用量
                Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
                        previousChatResponse);
                
                // 构建最终响应对象
                ChatResponse chatResponse = new ChatResponse(generations,
                        from(chatCompletion, rateLimit, accumulatedUsage));

                // 将响应记录到可观测性上下文
                observationContext.setResponse(chatResponse);

                return chatResponse;
            });

    // 判断是否需要执行工具调用
    if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
        // 执行工具调用
        var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
        
        if (toolExecutionResult.returnDirect()) {
            // 直接返回工具执行结果给客户端(不继续调用模型)
            return ChatResponse.builder()
                    .from(response)
                    .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                    .build();
        } else {
            // 将工具执行结果发送回模型,继续对话
            return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                    response);
        }
    }

    // 无需工具调用,直接返回响应
    return response;
}

8.3 请求转换方法

scss 复制代码
// 创建聊天补全请求(stream: true=流式,false=非流式)
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

    // 将 Prompt 中的各类型 Message 转换为 OpenAI 的 ChatCompletionMessage 列表
    List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
        // 处理 USER 和 SYSTEM 类型消息
        if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
            Object content = message.getText();
            if (message instanceof UserMessage userMessage) {
                // 处理多媒体内容(图片等)
                if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                    List<MediaContent> contentList = new ArrayList<>(List.of(new MediaContent(message.getText())));
                    contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
                    content = contentList;
                }
            }
            return List.of(new ChatCompletionMessage(content,
                    ChatCompletionMessage.Role.valueOf(message.getMessageType().name())));
        }
        // 处理 ASSISTANT 类型消息(模型历史回复)
        else if (message.getMessageType() == MessageType.ASSISTANT) {
            var assistantMessage = (AssistantMessage) message;
            // 转换工具调用
            List<ToolCall> toolCalls = null;
            if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
                    var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
                    return new ToolCall(toolCall.id(), toolCall.type(), function);
                }).toList();
            }
            // 处理音频输出
            AudioOutput audioOutput = null;
            if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) {
                Assert.isTrue(assistantMessage.getMedia().size() == 1,
                        "Only one media content is supported for assistant messages");
                audioOutput = new AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null);
            }
            return List.of(new ChatCompletionMessage(assistantMessage.getText(),
                    ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null, null));
        }
        // 处理 TOOL 类型消息(工具调用结果)
        else if (message.getMessageType() == MessageType.TOOL) {
            ToolResponseMessage toolMessage = (ToolResponseMessage) message;
            toolMessage.getResponses()
                .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
            return toolMessage.getResponses()
                .stream()
                .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
                        tr.id(), null, null, null, null, null))
                .toList();
        }
        else {
            throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
        }
    }).flatMap(List::stream).toList();

    // 创建基础请求对象
    ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);

    // 合并运行时选项(如 temperature、maxTokens 等)
    OpenAiChatOptions requestOptions = (OpenAiChatOptions) prompt.getOptions();
    request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);

    // 添加工具定义(Function Calling)
    List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
    if (!CollectionUtils.isEmpty(toolDefinitions)) {
        request = ModelOptionsUtils.merge(
                OpenAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,
                ChatCompletionRequest.class);
    }

    // 非流式请求时移除 streamOptions 参数
    if (request.streamOptions() != null && !stream) {
        logger.warn("Removing streamOptions from the request as it is not a streaming request!");
        request = request.streamOptions(null);
    }

    return request;
}

9 多Provider适配实现

9.1 Provider统一抽象

Spring AI通过统一的ChatModel接口屏蔽了不同Provider的差异:

java 复制代码
// 所有Provider都实现相同的ChatModel接口
// 1. OpenAI实现
public class OpenAiChatModel implements ChatModel { ... }

// 2. Azure OpenAI实现
public class AzureOpenAiChatModel implements ChatModel { ... }

// 3. Ollama实现
public class OllamaChatModel implements ChatModel { ... }

// 4. HuggingFace实现
public class HuggingFaceChatModel implements ChatModel { ... }

9.2 使用示例

arduino 复制代码
@Service
public class MultiProviderService {
    
    // 可以轻松切换Provider,代码无需修改
    private final ChatModel chatModel;
    
    public MultiProviderService(ChatModel chatModel) {
        this.chatModel = chatModel;
    }
    
    public String chat(String message) {
        return chatModel.call(message);
    }
}

10 小结

本文深入剖析了Spring AI Model层的设计与实现,涵盖以下核心内容:

  1. Generic Model APIModelModelRequestModelResponse等基础接口构成了整个AI模型交互的基石,通过泛型设计实现了高度的类型安全性和可扩展性。
  2. ChatModel体系 :在通用API之上扩展了对话场景的特化接口,提供了call()stream()两种调用方式,满足同步和异步场景需求。
  3. Prompt与MessagePrompt作为消息容器,支持多种角色消息(System、User、Assistant、Tool),PromptTemplate基于StTemplateRenderer引擎实现动态提示词渲染。
  4. 选项合并机制:支持启动时默认选项和运行时请求选项的两层配置,运行时选项可以覆盖默认选项,提供了灵活的配置能力。
  5. 流式调用 :基于Project Reactor的Flux类型实现响应式流式输出,有效降低首字延迟,提升用户体验。
  6. OpenAiChatModel实现:作为官方参考实现,展示了完整的请求转换、响应解析、观测集成和工具调用处理逻辑。
相关推荐
认真的薛薛12 小时前
Linux基础:GitOps发布流程
java·linux·运维
鱼鳞_12 小时前
苍穹外卖-Day05(Redis)
java·redis
雨落在了我的手上13 小时前
初识java(九):类和对象(⼀)
java·开发语言
是码龙不是码农13 小时前
数据库主键选型:为什么别用自增 ID?
java·数据库
北风toto13 小时前
Jenkins新手入门安装插件全报错
java·运维·jenkins
罗超驿13 小时前
20.MySQL事务隔离级别示例详解(脏读、不可重复读、幻读)
java·数据库·mysql·面试
Dicky-_-zhang13 小时前
KubeEdge边缘部署实践
java·jvm
码银13 小时前
在若依中如何新建一个模块(图文教程)
java·javascript