1 Model层概述
1.1 架构定位
Model层是Spring AI的核心抽象层,它定义了与各类AI模型交互的统一契约。这一层采用经典的面向接口设计,通过泛型抽象实现了高度的可扩展性,使得开发者可以在不改变业务代码的前提下,无缝切换不同的AI提供商。
Model层的核心价值体现在以下几个方面:
- 统一抽象:屏蔽不同AI提供商API的差异,提供一致的编程模型
- 类型安全:通过泛型机制保证请求和响应类型的编译期检查
- 可扩展性:支持同步调用、流式调用,以及自定义模型实现
- 可观测性:内置Micrometer观测支持,便于监控和调试
1.2 整体类图

2 Generic Model API 源码剖析
2.1 Model接口
Model接口是所有AI模型交互的根接口,它使用Java泛型来支持不同类型的请求和响应:
php
/**
* 所有AI模型的根接口
* @param <TReq> 请求类型,必须继承ModelRequest
* @param <TRes> 响应类型,必须继承ModelResponse
*/
public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {
/**
* 调用AI模型的核心方法
* @param request 请求对象
* @return 响应对象
*/
TRes call(TReq request);
}
设计要点:
- 泛型约束 :
TReq extends ModelRequest<?>确保了请求参数必须是ModelRequest的子类型 - 泛型约束 :
TRes extends ModelResponse<?>确保了响应参数必须是ModelResponse的子类型 - 默认方法 :
call()提供默认实现,避免所有子类都必须实现的方法
2.2 StreamingModel接口
对于原生支持流式响应的模型,Spring AI提供了专门的StreamingModel接口:
php
/**
* 支持流式响应的AI模型接口
* @param <TReq> 请求类型
* @param <TResChunk> 流式响应块类型
*/
public interface StreamingModel<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {
/**
* 流式调用AI模型
* @param request 请求对象
* @return 响应块的Flux流
*/
Flux<TResChunk> stream(TReq request);
}
2.3 ModelRequest接口
ModelRequest封装了对AI模型的请求信息:
csharp
/**
* AI模型请求的抽象
* @param <T> 指令类型(通常是Prompt或Message列表)
*/
public interface ModelRequest<T> {
/**
* 获取发送给AI模型的指令或输入
* @return 指令内容
*/
T getInstructions();
/**
* 获取AI模型交互的可自定义选项
* @return 模型选项
*/
ModelOptions getOptions();
}
2.4 ModelOptions接口
ModelOptions是一个标记接口,用于携带AI模型的自定义参数:
csharp
/**
* AI模型选项的标记接口
* 各个具体模型实现可以扩展此接口添加自己的选项
*/
public interface ModelOptions {
// 标记接口,无具体方法
}
ChatOptions是ModelOptions的重要子接口,定义了对话模型通用的选项:
scss
public interface ChatOptions extends ModelOptions {
String getModel(); // 模型ID
Double getFrequencyPenalty(); // 频率惩罚(-2.0~2.0),降低重复token的可能性
Integer getMaxTokens(); // 最大生成token数
Double getPresencePenalty(); // 存在惩罚(-2.0~2.0),鼓励谈论新话题
List<String> getStopSequences(); // 停止序列
Double getTemperature(); // 采样温度(0.0~2.0),控制随机性
Integer getTopK(); // Top-K采样
Double getTopP(); // Top-P核采样
<T entends ChatOptions> T copy(); // 深拷贝
}
参数详解:
| 参数 | 取值范围 | 作用 | 典型场景 |
|---|---|---|---|
| temperature | 0.0~2.0 | 控制输出的随机性 | 0.2:代码生成;0.8:创意写作 |
| topP | 0.0~1.0 | 核采样,动态选择token | 0.9:平衡质量和多样性 |
| maxTokens | 1~4096 | 限制响应长度 | 根据任务需求设置 |
| frequencyPenalty | -2.0~2.0 | 降低重复内容 | 0.5:减少重复 |
| presencePenalty | -2.0~2.0 | 鼓励新话题 | 0.5:引导话题转移 |
2.5 ModelResponse接口
ModelResponse封装了AI模型的响应结果:
csharp
/**
* AI模型响应的抽象
* @param <T> 结果类型,必须继承ModelResult
*/
public interface ModelResponse<T extends ModelResult<?>> {
/**
* 获取主要结果
* @return 模型生成的结果
*/
T getResult();
/**
* 获取所有生成的结果(用于n>1的情况)
* @return 结果列表
*/
List<T> getResults();
/**
* 获取响应元数据
* @return 响应元数据
*/
ResponseMetadata getMetadata();
}
2.6 ModelResult接口
ModelResult表示AI模型产生的单个输出结果:
csharp
/**
* AI模型输出的抽象
* @param <T> 输出类型(如AssistantMessage)
*/
public interface ModelResult<T> {
/**
* 获取模型生成的输出
* @return 输出内容
*/
T getOutput();
/**
* 获取结果相关的元数据
* @return 结果元数据
*/
ResultMetadata getMetadata();
}
3 ChatModel接口体系详解
3.1 ChatModel接口定义
ChatModel是针对对话场景的特化接口,它扩展了通用的Model接口:
java
public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {
/**
* 便捷方法:直接使用字符串消息调用
* 内部自动创建UserMessage和Prompt
* @param message 用户消息
* @return 响应的文本内容
*/
default String call(String message) {
Prompt prompt = new Prompt(new UserMessage(message));
ChatResponse response = this.call(prompt);
Generation result = response.getResult();
return result != null ? result.getOutput().getText() : null;
}
/**
* 核心方法:使用完整的Prompt对象调用
* @param prompt 提示词对象
* @return 完整的对话响应
*/
@Override
ChatResponse call(Prompt prompt);
}
3.2 StreamingChatModel接口
typescript
public interface StreamingChatModel extends StreamingModel<Prompt, ChatResponse> {
/**
* 便捷方法:流式调用字符串消息
* @param message 用户消息
* @return 文本块的Flux流
*/
default Flux<String> stream(String message) {
Prompt prompt = new Prompt(message);
return stream(prompt).map(response -> Optional.ofNullable(response.getResult())
.map(Generation::getOutput)
.map(AssistantMessage::getText)
.orElse(""));
}
/**
* 核心方法:流式调用完整的Prompt
* @param prompt 提示词对象
* @return ChatResponse块的Flux流
*/
@Override
Flux<ChatResponse> stream(Prompt prompt);
}
3.3 调用流程图

4 Prompt与Message体系
4.1 Message接口与实现
Message是Prompt的基本组成单元,代表对话中的一条消息:
csharp
/**
* 对话消息的抽象
*/
public interface Message extends Content {
/**
* 获取消息类型(用户、系统、助手等)
*/
MessageType getMessageType();
}
// Content接口定义了消息内容的基础能力
public interface Content {
String getText(); // 消息文本内容
Map<String, Object> getMetadata(); // 元数据
}
// 多模态消息扩展
public interface MediaContent extends Content {
List<Media> getMedia(); // 媒体内容列表
}
MessageType枚举:
scss
public enum MessageType {
SYSTEM("system"), // 系统指令,设定AI角色和行为
USER("user"), // 用户输入
ASSISTANT("assistant"),// AI助手回复
TOOL("tool"); // 工具调用请求
}
各Message实现类详解:
scala
// 1. UserMessage - 用户消息
public class UserMessage extends AbstractMessage implements MediaContent{
protected final List<Media> media;
public UserMessage(String textContent) {
this(textContent, new ArrayList<>(), Map.of());
}
private UserMessage(String textContent, Collection<Media> media, Map<String, Object> metadata) {
super(MessageType.USER, textContent, metadata);
this.media = new ArrayList<>(media);
}
}
// 2. SystemMessage - 系统消息(设定AI角色)
public class SystemMessage extends AbstractMessage {
public SystemMessage(String textContent) {
this(textContent, Map.of());
}
private SystemMessage(String textContent, Map<String, Object> metadata) {
super(MessageType.SYSTEM, textContent, metadata);
}
}
// 3. AssistantMessage - 助手回复消息
public class AssistantMessage extends AbstractMessage {
private List<ToolCall> toolCalls; // 可包含工具调用
protected final List<Media> media;
public AssistantMessage(String content) {
this(content, Map.of(), List.of(), List.of());
}
protected AssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls,
List<Media> media) {
super(MessageType.ASSISTANT, content, properties); Assert.notNull(media, "Media must not be null");
this.toolCalls = toolCalls;
this.media = media;
}
}
// 4. ToolResponseMessage - 工具调用消息
public class ToolResponseMessage extends AbstractMessage {
protected final List<ToolResponse> responses;
protected ToolResponseMessage(List<ToolResponse> responses, Map<String, Object> metadata) {
super(MessageType.TOOL, "", metadata);
this.responses = responses;
}
}
4.2 Prompt类源码剖析
Prompt类是ModelRequest的具体实现,封装了消息列表和选项:
csharp
public class Prompt implements ModelRequest<List<Message>> {
private final List<Message> messages;
private ChatOptions chatOptions;
// 构造函数
public Prompt(List<Message> messages) {
this(messages, null);
}
public Prompt(Message... messages) {
this(Arrays.asList(messages), null);
}
public Prompt(String contents) {
this(new UserMessage(contents));
}
// 带选项的构造函数
public Prompt(List<Message> messages, ChatOptions chatOptions) {
this.messages = messages;
this.chatOptions = chatOptions;
}
// 获取消息列表
@Override
public List<Message> getInstructions() {
return this.messages;
}
// 将所有消息合并为单个字符串
public String getContents() {
StringBuilder sb = new StringBuilder();
for (Message message : getInstructions()) {
sb.append(message.getText());
}
return sb.toString();
}
// Builder模式
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String content;
private List<Message> messages;
private ChatOptions options;
...
public Prompt build() {
return new Prompt(this.messages, this.chatOptions);
}
}
}
4.3 PromptTemplate详解
PromptTemplate用于创建提示词的模板。它允许您定义一个包含变量占位符的模板字符串,然后通过为这些变量提供具体值来渲染(生成)该模板。
typescript
public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
private static final TemplateRenderer DEFAULT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build();
private String template;
private final Map<String, Object> variables = new HashMap<>();
private final TemplateRenderer renderer;
public PromptTemplate(String template) {
this(template, new HashMap<>(), DEFAULT_TEMPLATE_RENDERER);
}
PromptTemplate(String template, Map<String, Object> variables, TemplateRenderer renderer) {
Assert.hasText(template, "template cannot be null or empty");
Assert.notNull(variables, "variables cannot be null");
Assert.noNullElements(variables.keySet(), "variables keys cannot be null");
Assert.notNull(renderer, "renderer cannot be null");
this.template = template;
this.variables.putAll(variables);
this.renderer = renderer;
}
public void add(String name, Object value) {
this.variables.put(name, value);
}
@Override
public String render() {
// 使用构造时传入的变量渲染模板,返回字符串
Map<String, Object> processedVariables = new HashMap<>();
for (Entry<String, Object> entry : this.variables.entrySet()) {
if (entry.getValue() instanceof Resource resource) {
processedVariables.put(entry.getKey(), renderResource(resource));
}
else {
processedVariables.put(entry.getKey(), entry.getValue());
}
}
return this.renderer.apply(this.template, processedVariables);
}
@Override
public String render(Map<String, Object> additionalVariables) {
//合并额外变量后渲染
Map<String, Object> combinedVariables = new HashMap<>();
Map<String, Object> mergedVariables = new HashMap<>(this.variables);
// variables + additionalVariables => mergedVariables
if (additionalVariables != null && !additionalVariables.isEmpty()) {
mergedVariables.putAll(additionalVariables);
}
for (Entry<String, Object> entry : mergedVariables.entrySet()) {
if (entry.getValue() instanceof Resource resource) {
combinedVariables.put(entry.getKey(), renderResource(resource));
}
else {
combinedVariables.put(entry.getKey(), entry.getValue());
}
}
return this.renderer.apply(this.template, combinedVariables);
}
@Override
public Message createMessage() {
//渲染后创建 UserMessage
return new UserMessage(render());
}
@Override
public Message createMessage(List<Media> mediaList) {
return UserMessage.builder().text(render()).media(mediaList).build();
}
@Override
public Message createMessage(Map<String, Object> additionalVariables) {
return new UserMessage(render(additionalVariables));
}
// From PromptTemplateActions.
@Override
public Prompt create() {
//渲染后创建 Prompt 对象
return new Prompt(render(new HashMap<>()));
}
@Override
public Prompt create(ChatOptions modelOptions) {
return Prompt.builder().content(render(new HashMap<>())).chatOptions(modelOptions).build();
}
@Override
public Prompt create(Map<String, Object> additionalVariables) {
return new Prompt(render(additionalVariables));
}
@Override
public Prompt create(Map<String, Object> additionalVariables, ChatOptions modelOptions) {
return Prompt.builder().content(render(additionalVariables)).chatOptions(modelOptions).build();
}
public Builder mutate() {
return new Builder().template(this.template).variables(this.variables).renderer(this.renderer);
}
// Builder
public static Builder builder() {
return new Builder();
}
public static class Builder {
protected String template;
protected Resource resource;
protected Map<String, Object> variables = new HashMap<>();
protected TemplateRenderer renderer = DEFAULT_TEMPLATE_RENDERER;
public PromptTemplate build() {
return new PromptTemplate(this.template, this.variables, this.renderer);
}
}
}
三种专用PromptTemplate:
scala
// 1. SystemPromptTemplate - 专门用于创建系统消息
public class SystemPromptTemplate extends PromptTemplate {
@Override
public Message createMessage(Map<String, Object> model) {
return new SystemMessage(render(model));
}
}
// 2. AssistantPromptTemplate - 专门用于创建助手消息
public class AssistantPromptTemplate extends PromptTemplate {
@Override
public Message createMessage(Map<String, Object> model) {
return new AssistantMessage(render(model));
}
}
// 3. ToolResponseMessage - 专门用于创建工具调用消息
public class ToolResponseMessage extends PromptTemplate {
@Override
public Message createMessage(Map<String, Object> model) {
String jsonContent = render(model);
return FunctionCallMessage.fromJson(jsonContent);
}
}
使用示例:
typescript
@Service
public class ChatService {
public String generatePrompt() {
Map<String, Object> variables = new HashMap<>();
variables.put("name", "Spring AI");
SystemPromptTemplate systemPromptTemplate = SystemPromptTemplate.builder()
.template("Hello {name}!")
.variables(variables)
.build();
Prompt prompt = systemPromptTemplate.create(variables);
return prompt;
}
}
5 选项合并机制
5.1 选项层级
Spring AI支持两层选项配置:
- 启动时默认选项 :在创建
ChatModel实例时设置,作为全局默认值 - 运行时请求选项 :在每个
Prompt请求中传递,可以覆盖默认选项
5.2 合并源码剖析
kotlin
/**
* 构建请求 Prompt(将传入的 Prompt 转换为最终的请求 Prompt)
*
* @param prompt 原始 Prompt(包含用户指令和可能的运行时选项)
* @return 合并了默认选项后的最终请求 Prompt
*/
Prompt buildRequestPrompt(Prompt prompt) {
// 处理运行时选项(从原始 Prompt 中提取并转换)
OpenAiChatOptions runtimeOptions = null;
if (prompt.getOptions() != null) {
// 如果选项是 ToolCallingChatOptions 类型,需要特殊处理
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
OpenAiChatOptions.class);
} else {
// 普通 ChatOptions,直接拷贝到 OpenAiChatOptions
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
OpenAiChatOptions.class);
}
}
// 通过合并运行时选项和默认选项,定义最终请求的选项
OpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
OpenAiChatOptions.class);
// 显式合并被 @JsonIgnore 注解标记的选项(这些选项会被 Jackson 忽略,
// 而 ModelOptionsUtils 也会忽略它们,因此需要手动处理)
if (runtimeOptions != null) {
// OpenAI 不支持 topK 参数,给出警告
if (runtimeOptions.getTopK() != null) {
logger.warn("topK 选项不被 OpenAI 聊天模型支持,将被忽略。");
}
// 合并 HTTP 头信息
requestOptions.setHttpHeaders(
mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
// 合并内部工具执行开关
requestOptions.setInternalToolExecutionEnabled(
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
this.defaultOptions.getInternalToolExecutionEnabled()));
// 合并工具名称列表
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
this.defaultOptions.getToolNames()));
// 合并工具回调函数
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
this.defaultOptions.getToolCallbacks()));
// 合并工具上下文
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
this.defaultOptions.getToolContext()));
} else {
// 没有运行时选项时,直接使用默认选项
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
requestOptions.setToolContext(this.defaultOptions.getToolContext());
}
// 验证工具回调函数的有效性
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
// 使用处理后的指令和合并后的选项创建最终的 Prompt
return new Prompt(prompt.getInstructions(), requestOptions);
}
5.3 配置示例
yaml
# application.yml - 全局默认配置
spring:
ai:
openai:
chat:
options:
model: gpt-4
temperature: 0.7
max-tokens: 1000
scss
// 运行时覆盖
@Service
public class DynamicConfigService {
@Autowired
private OpenAiChatModel chatModel;
public String lowTemperatureResponse(String input) {
// 创建低温度选项(更确定性的输出)
ChatOptions lowTempOptions = OpenAiChatOptions.builder()
.temperature(0.2)
.build();
Prompt prompt = new Prompt(
new UserMessage(input),
lowTempOptions
);
return chatModel.call(prompt).getResult().getOutput().getText();
}
public String creativeResponse(String input) {
// 创建高温度选项(更有创意的输出)
ChatOptions creativeOptions = OpenAiChatOptions.builder()
.temperature(1.2)
.topP(0.95)
.build();
Prompt prompt = new Prompt(
new UserMessage(input),
creativeOptions
);
return chatModel.call(prompt).getResult().getOutput().getText();
}
}
6 ChatResponse与Generation
6.1 ChatResponse类
kotlin
public class ChatResponse implements ModelResponse<Generation> {
private final ChatResponseMetadata chatResponseMetadata;
/**
* AI返回的生成消息列表。
*/
private final List<Generation> generations;
public ChatResponse(List<Generation> generations) {
this(generations, new ChatResponseMetadata());
}
public ChatResponse(List<Generation> generations, ChatResponseMetadata chatResponseMetadata) {
this.chatResponseMetadata = chatResponseMetadata;
this.generations = List.copyOf(generations);
}
public static Builder builder() {
return new Builder();
}
@Override
public List<Generation> getResults() {
return this.generations;
}
public Generation getResult() {
if (CollectionUtils.isEmpty(this.generations)) {
return null;
}
return this.generations.get(0);
}
/**
* 返回包含 AI使用信息的
*/
@Override
public ChatResponseMetadata getMetadata() {
return this.chatResponseMetadata;
}
}
6.2 Generation类
kotlin
/**
* Represents a response returned by the AI.
*/
public class Generation implements ModelResult<AssistantMessage> {
private final AssistantMessage assistantMessage;
private ChatGenerationMetadata chatGenerationMetadata;
public Generation(AssistantMessage assistantMessage) {
this(assistantMessage, ChatGenerationMetadata.NULL);
}
public Generation(AssistantMessage assistantMessage, ChatGenerationMetadata chatGenerationMetadata) {
this.assistantMessage = assistantMessage;
this.chatGenerationMetadata = chatGenerationMetadata;
}
@Override
public AssistantMessage getOutput() {
return this.assistantMessage;
}
@Override
public ChatGenerationMetadata getMetadata() {
ChatGenerationMetadata chatGenerationMetadata = this.chatGenerationMetadata;
return chatGenerationMetadata != null ? chatGenerationMetadata : ChatGenerationMetadata.NULL;
}
}
7 流式调用实现剖析
7.1 Flux响应式流
Spring AI的流式调用基于Project Reactor的Flux类型:
scss
// 源码示例:OpenAiChatModel的流式实现
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
// 在进一步处理之前,构建最终的请求 Prompt,
// 合并运行时选项和默认选项
Prompt requestPrompt = buildRequestPrompt(prompt);
return internalStream(requestPrompt, null);
}
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
// 创建流式请求对象(参数 true 表示流式请求)
ChatCompletionRequest request = createRequest(prompt, true);
// 检查是否包含音频输出(流式请求不支持)
if (request.outputModalities() != null
&& request.outputModalities().contains(OpenAiApi.OutputModality.AUDIO)) {
throw new IllegalArgumentException("流式请求不支持音频输出。");
}
// 检查是否包含音频参数(流式请求不支持)
if (request.audioParameters() != null) {
throw new IllegalArgumentException("流式请求不支持音频参数。");
}
// 调用 OpenAI API 发起流式聊天补全请求
Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
getAdditionalHttpHeaders(prompt));
// 用于缓存每个请求 ID 对应的角色信息(只有第一个分块包含 role)
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
// 构建可观测性上下文,用于记录模型调用的指标
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OpenAiApiConstants.PROVIDER_NAME)
.build();
// 创建可观测性记录(用于性能监控和日志)
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry);
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
// 将 ChatCompletionChunk 转换为 ChatCompletion,以便复用函数调用处理逻辑
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
// 如果没有提供 id,则设置为 "NO_ID"
String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id();
// 将每个 choice 转换为 Generation 对象
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
// 缓存角色信息(首个有效分块)
if (choice.message().role() != null) {
roleMap.putIfAbsent(id, choice.message().role().name());
}
// 构建元数据
Map<String, Object> metadata = Map.of(
"id", id,
"role", roleMap.getOrDefault(id, ""),
"index", choice.index() != null ? choice.index() : 0,
"finishReason", getFinishReasonJson(choice.finishReason()),
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(),
"reasoningContent", choice.message().reasoningContent() != null ? choice.message().reasoningContent() : "");
return buildGeneration(choice, metadata, request);
}).toList();
// 处理用量统计(token 使用情况)
OpenAiApi.Usage usage = chatCompletion2.usage();
Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
previousChatResponse);
return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
} catch (Exception e) {
return new ChatResponse(List.of());
}
}))
// 创建重叠缓冲区(2个元素,步长1),用于处理最终响应中的 usage 信息
.buffer(2, 1)
.map(bufferList -> {
ChatResponse firstResponse = bufferList.get(0);
// 如果启用了流式用量统计
if (request.streamOptions() != null && request.streamOptions().includeUsage()) {
if (bufferList.size() == 2) {
ChatResponse secondResponse = bufferList.get(1);
if (secondResponse != null && secondResponse.getMetadata() != null) {
Usage usage = secondResponse.getMetadata().getUsage();
if (!UsageCalculator.isEmpty(usage)) {
// 将最终响应的 usage 存储到倒数第二个响应中
return new ChatResponse(firstResponse.getResults(),
from(firstResponse.getMetadata(), usage));
}
}
}
}
return firstResponse;
});
// 处理工具调用(Tool Calling)
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
// 判断是否需要执行工具调用
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// 使用 bounded elastic 调度器执行同步工具调用
return Flux.deferContextual(ctx -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
// 执行工具调用
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// 直接返回工具执行结果给客户端
return Flux.just(ChatResponse.builder().from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
} else {
// 将工具执行结果发送回模型,继续对话
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
}
}).subscribeOn(Schedulers.boundedElastic());
} else {
// 无需工具调用,直接返回响应
return Flux.just(response);
}
})
.doOnError(observation::error) // 记录错误到可观测性
.doFinally(s -> observation.stop()) // 停止可观测性记录
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // 传递上下文
// 聚合消息并更新可观测性上下文中的响应
return new MessageAggregator().aggregate(flux, observationContext::setResponse);
});
}
7.2 流式响应示例
typescript
@RestController
public class StreamingController {
@Autowired
private ChatModel chatModel;
@GetMapping(value = "/stream/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> streamChat(@RequestParam String message) {
Prompt prompt = new Prompt(new UserMessage(message));
// 使用Flux处理流式响应
return chatModel.stream(prompt)
.map(chatResponse -> {
Generation result = chatResponse.getResult();
return result != null ? result.getOutput().getText() : "";
})
.filter(StringUtils::hasText); // 过滤空内容
}
}
7.3 流式响应聚合
当需要将流式响应聚合为完整响应时(如在WebFlux环境中):
typescript
@Service
public class ChatAggregationService {
@Autowired
private ChatModel chatModel;
public Mono<String> getCompleteResponse(String prompt) {
return chatModel.stream(new Prompt(new UserMessage(prompt)))
.map(response -> {
Generation result = response.getResult();
return result != null ? result.getOutput().getText() : "";
})
.collectList() // 收集所有片段
.map(list -> String.join("", list)); // 拼接为完整响应
}
}
8 核心实现类:OpenAiChatModel源码深度剖析
8.1 类结构
java
public class OpenAiChatModel implements ChatModel {
// 核心依赖
private final OpenAiApi openAiApi; // OpenAI API客户端
private final OpenAiChatOptions defaultOptions; // 默认选项
private final RetryTemplate retryTemplate; // 重试模板
private final ToolCallingManager toolCallingManager; // 工具调用管理器
private final ObservationRegistry observationRegistry; // 观测注册表
// 可选配置
private final Predicate<ChatResponse> toolExecutionEligibilityPredicate;
private final ChatModelObservationConvention observationConvention;
// Builder模式构建
public static Builder builder() {
return new Builder();
}
// ... 其他代码
}
8.2 call方法实现
scss
@Override
public ChatResponse call(Prompt prompt) {
// Before moving any further, build the final request Prompt,
// merging runtime and default options.
Prompt requestPrompt = buildRequestPrompt(prompt);
return this.internalCall(requestPrompt, null);
}
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
// 创建非流式聊天补全请求(参数 false 表示非流式)
ChatCompletionRequest request = createRequest(prompt, false);
// 构建可观测性上下文,用于记录模型调用的指标信息
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OpenAiApiConstants.PROVIDER_NAME)
.build();
// 在可观测性记录范围内执行调用
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
// 使用重试模板执行 API 调用,获取聊天补全响应实体
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
.execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt)));
var chatCompletion = completionEntity.getBody();
// 校验响应体是否为空
if (chatCompletion == null) {
return new ChatResponse(List.of());
}
List<Choice> choices = chatCompletion.choices();
// 校验 choices 列表是否为空
if (choices == null) {
return new ChatResponse(List.of());
}
// 将每个 choice 转换为 Generation 对象
List<Generation> generations = choices.stream().map(choice -> {
// 构建元数据:id、role、索引、完成原因、拒绝信息、标注等
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
"role", choice.message().role() != null ? choice.message().role().name() : "",
"index", choice.index() != null ? choice.index() : 0,
"finishReason", getFinishReasonJson(choice.finishReason()),
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of()));
return buildGeneration(choice, metadata, request);
}).toList();
// 从响应头中提取速率限制信息
RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
// 处理 token 用量统计
OpenAiApi.Usage usage = chatCompletion.usage();
Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
// 累积本次调用和之前调用的用量
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
previousChatResponse);
// 构建最终响应对象
ChatResponse chatResponse = new ChatResponse(generations,
from(chatCompletion, rateLimit, accumulatedUsage));
// 将响应记录到可观测性上下文
observationContext.setResponse(chatResponse);
return chatResponse;
});
// 判断是否需要执行工具调用
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// 执行工具调用
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// 直接返回工具执行结果给客户端(不继续调用模型)
return ChatResponse.builder()
.from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build();
} else {
// 将工具执行结果发送回模型,继续对话
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
}
}
// 无需工具调用,直接返回响应
return response;
}
8.3 请求转换方法
scss
// 创建聊天补全请求(stream: true=流式,false=非流式)
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 将 Prompt 中的各类型 Message 转换为 OpenAI 的 ChatCompletionMessage 列表
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
// 处理 USER 和 SYSTEM 类型消息
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
Object content = message.getText();
if (message instanceof UserMessage userMessage) {
// 处理多媒体内容(图片等)
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<MediaContent> contentList = new ArrayList<>(List.of(new MediaContent(message.getText())));
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
content = contentList;
}
}
return List.of(new ChatCompletionMessage(content,
ChatCompletionMessage.Role.valueOf(message.getMessageType().name())));
}
// 处理 ASSISTANT 类型消息(模型历史回复)
else if (message.getMessageType() == MessageType.ASSISTANT) {
var assistantMessage = (AssistantMessage) message;
// 转换工具调用
List<ToolCall> toolCalls = null;
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
return new ToolCall(toolCall.id(), toolCall.type(), function);
}).toList();
}
// 处理音频输出
AudioOutput audioOutput = null;
if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) {
Assert.isTrue(assistantMessage.getMedia().size() == 1,
"Only one media content is supported for assistant messages");
audioOutput = new AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null);
}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null, null));
}
// 处理 TOOL 类型消息(工具调用结果)
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
toolMessage.getResponses()
.forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
tr.id(), null, null, null, null, null))
.toList();
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
}).flatMap(List::stream).toList();
// 创建基础请求对象
ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
// 合并运行时选项(如 temperature、maxTokens 等)
OpenAiChatOptions requestOptions = (OpenAiChatOptions) prompt.getOptions();
request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);
// 添加工具定义(Function Calling)
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
if (!CollectionUtils.isEmpty(toolDefinitions)) {
request = ModelOptionsUtils.merge(
OpenAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,
ChatCompletionRequest.class);
}
// 非流式请求时移除 streamOptions 参数
if (request.streamOptions() != null && !stream) {
logger.warn("Removing streamOptions from the request as it is not a streaming request!");
request = request.streamOptions(null);
}
return request;
}
9 多Provider适配实现
9.1 Provider统一抽象
Spring AI通过统一的ChatModel接口屏蔽了不同Provider的差异:
java
// 所有Provider都实现相同的ChatModel接口
// 1. OpenAI实现
public class OpenAiChatModel implements ChatModel { ... }
// 2. Azure OpenAI实现
public class AzureOpenAiChatModel implements ChatModel { ... }
// 3. Ollama实现
public class OllamaChatModel implements ChatModel { ... }
// 4. HuggingFace实现
public class HuggingFaceChatModel implements ChatModel { ... }
9.2 使用示例
arduino
@Service
public class MultiProviderService {
// 可以轻松切换Provider,代码无需修改
private final ChatModel chatModel;
public MultiProviderService(ChatModel chatModel) {
this.chatModel = chatModel;
}
public String chat(String message) {
return chatModel.call(message);
}
}
10 小结
本文深入剖析了Spring AI Model层的设计与实现,涵盖以下核心内容:
- Generic Model API :
Model、ModelRequest、ModelResponse等基础接口构成了整个AI模型交互的基石,通过泛型设计实现了高度的类型安全性和可扩展性。 - ChatModel体系 :在通用API之上扩展了对话场景的特化接口,提供了
call()和stream()两种调用方式,满足同步和异步场景需求。 - Prompt与Message :
Prompt作为消息容器,支持多种角色消息(System、User、Assistant、Tool),PromptTemplate基于StTemplateRenderer引擎实现动态提示词渲染。 - 选项合并机制:支持启动时默认选项和运行时请求选项的两层配置,运行时选项可以覆盖默认选项,提供了灵活的配置能力。
- 流式调用 :基于Project Reactor的
Flux类型实现响应式流式输出,有效降低首字延迟,提升用户体验。 - OpenAiChatModel实现:作为官方参考实现,展示了完整的请求转换、响应解析、观测集成和工具调用处理逻辑。