SpringAI (GA): Tool Source Code and Tool-Calling Flow Walkthrough

Original article: SpringAI (GA): Tool Source Code and Tool-Calling Flow Walkthrough

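Tutorial notes

Note: this tutorial is based on the official GA release of May 20, 2025, and covers:

  1. Quick-start tutorials for the core modules
  2. Source-level walkthroughs of the core modules
  3. Quick-start tutorials and source-level walkthroughs for the Spring AI Alibaba enhancements

Versions: JDK 21 + Spring Boot 3.4.5 + Spring AI 1.0.0 + Spring AI Alibaba 1.0.0.2

The chapters below will be completed over time. This article belongs to Chapter 3 (tool integration) and covers the Tool source code and the tool-calling flow.

The code is open source at: github.com/GTyingzi/sp...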

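Earlier WeChat articles in this series:

Chapter 1

SpringAI (GA) Chat: Quick Start and Auto-Injection Source Walkthrough

SpringAI (GA): ChatClient Call Flow Walkthrough

Chapter 2

SpringAI Advisors: Quick Start and Source Walkthrough

SpringAI (GA): SQLite, MySQL, and Redis Message Storage (Quick Start)

Chapter 3

SpringAI (GA): Tool Integration (Quick Start)

Chapter 5

SpringAI (GA): In-Memory, Redis, and ES Vector Stores (Quick Start)

SpringAI (GA): Vector Database Concepts and Source Walkthrough, plus Redis and ES Integration Source

Chapter 6

SpringAI (GA): RAG Quick Start and Modular Walkthrough

SpringAI (GA): ETL in RAG (Quick Start)

SpringAI (GA): ETL in RAG (Source Walkthrough)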

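Putting this material together takes effort. For a better reading experience, paid access to the Feishu cloud document with the latest Spring AI tutorials is available, currently 49.9; the price will rise gradually as the content grows.

Note: the quick-start tutorials and source walkthroughs for the M6 release are already available for free as a Feishu cloud document.

To encourage everyone to contribute actively to the Spring AI Alibaba open-source community: github.com/alibaba/spr...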

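Tool Source Code Walkthrough

Tip: this document walks through the underlying mechanism of tool calling.

For the source code of tool auto-injection, see "Tool Auto-Injection Source Code".

For MCP-related tools, see "ToolCallback under MCP".

Overview of the tool components (excluding MCP)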

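Tool (tool annotation)

Marks a method as a tool in Spring AI, so that the framework can discover it and the AI model can call it.

  • name: the tool name; defaults to the method name
  • description: the tool description; defaults to the method name
  • returnDirect: whether the tool result is returned directly to the caller instead of being passed back to the model; defaults to false
  • resultConverter: the converter for the tool call result; defaults to DefaultToolCallResultConverter, which converts the result to a string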
package org.springframework.ai.tool.annotation;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
import org.springframework.ai.tool.execution.ToolCallResultConverter;

@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Tool {

    String name() default "";

    String description() default "";

    boolean returnDirect() default false;
    Class<? extends ToolCallResultConverter> resultConverter() default DefaultToolCallResultConverter.class;

}
ToolParam

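Marks the input parameters of a tool method (it can also be placed on fields); typically used together with @Tool.

  • required: whether the parameter is required; defaults to true
  • description: the parameter description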
package org.springframework.ai.tool.annotation;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ ElementType.PARAMETER, ElementType.FIELD, ElementType.ANNOTATION_TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface ToolParam {

    boolean required() default true;

    String description() default "";

}
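Both annotations use `RetentionPolicy.RUNTIME`, which is what allows the framework to read them reflectively when it builds tool callbacks. The JDK-only sketch below mirrors that lookup with two hypothetical stand-ins, `DemoTool` and `DemoToolParam` (not Spring AI types), and applies the defaults listed above: a blank `name` falls back to the method name, `returnDirect` stays false, and `required` stays true unless overridden.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

class ToolAnnotationDemo {

    // Hypothetical stand-in for @Tool (same defaults; resultConverter omitted).
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    @interface DemoTool {
        String name() default "";
        String description() default "";
        boolean returnDirect() default false;
    }

    // Hypothetical stand-in for @ToolParam.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.PARAMETER)
    @interface DemoToolParam {
        boolean required() default true;
        String description() default "";
    }

    static class WeatherTools {
        @DemoTool(description = "Get the current weather for a city")
        String getWeather(@DemoToolParam(description = "City name") String city,
                @DemoToolParam(required = false) String unit) {
            return "sunny in " + city;
        }
    }

    // Same fallback as ToolUtils.getToolName: a blank name means "use the method name".
    static String toolName(Method method) {
        DemoTool tool = method.getAnnotation(DemoTool.class);
        return (tool == null || tool.name().isBlank()) ? method.getName() : tool.name();
    }

    // Same rule as ToolUtils.getToolReturnDirect: false unless the annotation says otherwise.
    static boolean returnDirect(Method method) {
        DemoTool tool = method.getAnnotation(DemoTool.class);
        return tool != null && tool.returnDirect();
    }

    // Reads 'required' for the parameter at the given index; this sketch treats unannotated as required.
    static boolean isRequired(Method method, int index) {
        DemoToolParam param = method.getParameters()[index].getAnnotation(DemoToolParam.class);
        return param == null || param.required();
    }

    static Method weatherMethod() {
        try {
            return WeatherTools.class.getDeclaredMethod("getWeather", String.class, String.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Method m = weatherMethod();
        System.out.println(toolName(m) + " returnDirect=" + returnDirect(m)
                + " city required=" + isRequired(m, 0) + " unit required=" + isRequired(m, 1));
    }
}
```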

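ToolDefinition (tool definition)

Defines a tool's basic information and the structure of its input parameters.

| Method | Description |
|-------------|------------------------|
| name | The tool name; must be unique among the tools provided to the AI model |
| description | The tool description, which helps the AI model understand what the tool is for |
| inputSchema | The structure of the tool's input parameters (a JSON Schema string) |
| builder | Creates a default ToolDefinition builder |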

package org.springframework.ai.tool.definition;

public interface ToolDefinition {

    String name();

    String description();

    String inputSchema();

    static DefaultToolDefinition.Builder builder() {
       return DefaultToolDefinition.builder();
    }

}
DefaultToolDefinition

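The default implementation of the ToolDefinition interface: a record whose compact constructor rejects a blank name, description, or inputSchema. If no description is set, the builder derives one from the name.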

package org.springframework.ai.tool.definition;

import org.springframework.ai.util.ParsingUtils;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public record DefaultToolDefinition(String name, String description, String inputSchema) implements ToolDefinition {

    public DefaultToolDefinition {
       Assert.hasText(name, "name cannot be null or empty");
       Assert.hasText(description, "description cannot be null or empty");
       Assert.hasText(inputSchema, "inputSchema cannot be null or empty");
    }

    public static Builder builder() {
       return new Builder();
    }

    public static final class Builder {

       private String name;

       private String description;

       private String inputSchema;

       private Builder() {
       }

       public Builder name(String name) {
          this.name = name;
          return this;
       }

       public Builder description(String description) {
          this.description = description;
          return this;
       }

       public Builder inputSchema(String inputSchema) {
          this.inputSchema = inputSchema;
          return this;
       }

       public ToolDefinition build() {
          if (!StringUtils.hasText(this.description)) {
             Assert.hasText(this.name, "toolName cannot be null or empty");
             this.description = ParsingUtils.reConcatenateCamelCase(this.name, " ");
          }
          return new DefaultToolDefinition(this.name, this.description, this.inputSchema);
       }

    }

}
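The builder only checks the description: when it is blank, the builder requires a name and derives the description from it via `ParsingUtils.reConcatenateCamelCase`, and the record's compact constructor then rejects any remaining blank field. The sketch below reproduces that flow with a hypothetical `DemoToolDefinition` record. Its camel-case splitter is an approximation written for this example (split before each upper-case letter, lower-case every word) and may differ from `ParsingUtils` on edge cases such as acronyms.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical mirror of DefaultToolDefinition: a record that validates in its compact constructor.
record DemoToolDefinition(String name, String description, String inputSchema) {

    DemoToolDefinition {
        requireText(name, "name cannot be null or empty");
        requireText(description, "description cannot be null or empty");
        requireText(inputSchema, "inputSchema cannot be null or empty");
    }

    private static void requireText(String value, String message) {
        if (value == null || value.isBlank()) {
            throw new IllegalArgumentException(message);
        }
    }

    // Approximation of reConcatenateCamelCase: "getCurrentWeather" -> "get current weather".
    static String splitCamelCase(String text, String delimiter) {
        List<String> words = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        for (char c : text.toCharArray()) {
            if (Character.isUpperCase(c) && current.length() > 0) {
                words.add(current.toString());
                current.setLength(0);
            }
            current.append(Character.toLowerCase(c));
        }
        if (current.length() > 0) {
            words.add(current.toString());
        }
        return String.join(delimiter, words);
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        private String name;
        private String description;
        private String inputSchema;

        Builder name(String name) { this.name = name; return this; }
        Builder description(String description) { this.description = description; return this; }
        Builder inputSchema(String inputSchema) { this.inputSchema = inputSchema; return this; }

        DemoToolDefinition build() {
            // Same order as the original: derive the description first, then let the record validate.
            if (description == null || description.isBlank()) {
                requireText(name, "toolName cannot be null or empty");
                description = splitCamelCase(name, " ");
            }
            return new DemoToolDefinition(name, description, inputSchema);
        }
    }
}
```

Because validation lives in the record constructor, a missing `inputSchema` surfaces as an `IllegalArgumentException` at `build()` time rather than later when the tool is sent to the model.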
ToolDefinitions

Creates and builds ToolDefinition instances directly from a Java Method object.

package org.springframework.ai.tool.support;

import java.lang.reflect.Method;

import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.util.Assert;

public final class ToolDefinitions {

    private ToolDefinitions() {
    }

    public static DefaultToolDefinition.Builder builder(Method method) {
       Assert.notNull(method, "method cannot be null");
       return DefaultToolDefinition.builder()
          .name(ToolUtils.getToolName(method))
          .description(ToolUtils.getToolDescription(method))
          .inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
    }

    public static ToolDefinition from(Method method) {
       return builder(method).build();
    }

}
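`ToolDefinitions.from(method)` combines three lookups: the name and description from `ToolUtils`, and the input schema from `JsonSchemaGenerator.generateForMethodInput`. The real generator emits a full JSON Schema (descriptions, nested types, required flags from `@ToolParam`); the toy below, `ToySchemaDemo`, is a hypothetical illustration that only maps a few JDK types to JSON Schema type names and marks every parameter as required, to show that the schema is derived from the method's parameter list. Its property names come from `Parameter.getName()`, which returns the source names only when compiling with `-parameters` (otherwise `arg0`, `arg1`, ...).

```java
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.StringJoiner;

class ToySchemaDemo {

    // Maps a few JDK types to JSON Schema type names; everything else becomes "object".
    static String jsonType(Class<?> type) {
        if (type == String.class) return "string";
        if (type == int.class || type == Integer.class || type == long.class || type == Long.class) return "integer";
        if (type == double.class || type == Double.class || type == float.class || type == Float.class) return "number";
        if (type == boolean.class || type == Boolean.class) return "boolean";
        return "object";
    }

    // Builds {"type":"object","properties":{...},"required":[...]}; this toy marks every parameter required.
    static String schemaFor(Method method) {
        StringJoiner props = new StringJoiner(",");
        StringJoiner required = new StringJoiner(",");
        for (Parameter p : method.getParameters()) {
            props.add("\"" + p.getName() + "\":{\"type\":\"" + jsonType(p.getType()) + "\"}");
            required.add("\"" + p.getName() + "\"");
        }
        return "{\"type\":\"object\",\"properties\":{" + props + "},\"required\":[" + required + "]}";
    }

    static String forecast(String city, int days) {
        return city + ":" + days;
    }

    static Method forecastMethod() {
        try {
            return ToySchemaDemo.class.getDeclaredMethod("forecast", String.class, int.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(schemaFor(forecastMethod()));
    }
}
```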

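ToolMetadata (tool metadata)

Describes a tool's metadata. Currently it only controls whether the tool result is returned directly to the caller instead of being sent back to the AI model.

| Method | Description |
|--------------|-------------------------------------|
| returnDirect | Whether the tool result is returned directly to the caller rather than sent back to the AI model; defaults to false |
| builder | Creates a default ToolMetadata builder |
| from | Creates a ToolMetadata instance from a reflected Method (reads returnDirect from @Tool) |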

package org.springframework.ai.tool.metadata;

import java.lang.reflect.Method;

import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.util.Assert;

public interface ToolMetadata {

    default boolean returnDirect() {
       return false;
    }
    static DefaultToolMetadata.Builder builder() {
       return DefaultToolMetadata.builder();
    }
    
    static ToolMetadata from(Method method) {
       Assert.notNull(method, "method cannot be null");
       return DefaultToolMetadata.builder().returnDirect(ToolUtils.getToolReturnDirect(method)).build();
    }

}
DefaultToolMetadata

The default implementation of the ToolMetadata interface; its builder defaults returnDirect to false.

package org.springframework.ai.tool.metadata;


public record DefaultToolMetadata(boolean returnDirect) implements ToolMetadata {

    public static Builder builder() {
       return new Builder();
    }

    public static final class Builder {

       private boolean returnDirect = false;

       private Builder() {
       }

       public Builder returnDirect(boolean returnDirect) {
          this.returnDirect = returnDirect;
          return this;
       }

       public ToolMetadata build() {
          return new DefaultToolMetadata(this.returnDirect);
       }

    }

}

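ToolUtils (tool helper class)

Provides static helper methods for Spring AI's internal tool infrastructure: extracting metadata from @Tool-annotated methods, resolving tool names and descriptions, instantiating result converters, and checking that tool names are unique.

| Method | Description |
|----------------------------|-----------------------------------------------------------|
| getToolName | Gets a method's tool name: the @Tool name attribute if set, otherwise the method name |
| getToolDescriptionFromName | Generates a human-readable description from a tool name by splitting camelCase into space-separated words |
| getToolDescription | Gets a method's tool description: the @Tool description attribute if set; otherwise the method name if the method is annotated, or the camelCase-split method name if it is not |
| getToolReturnDirect | Whether the method's @Tool annotation sets returnDirect, i.e. whether the tool returns its result directly |
| getToolCallResultConverter | Instantiates the method's result converter: the type given by @Tool resultConverter, or DefaultToolCallResultConverter if the method has no @Tool |
| getDuplicateToolNames | Checks the tool callbacks for duplicate tool names and returns every duplicated name |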

package org.springframework.ai.tool.support;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
import org.springframework.ai.tool.execution.ToolCallResultConverter;
import org.springframework.ai.util.ParsingUtils;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public final class ToolUtils {

    private ToolUtils() {
    }

    public static String getToolName(Method method) {
       Assert.notNull(method, "method cannot be null");
       var tool = method.getAnnotation(Tool.class);
       if (tool == null) {
          return method.getName();
       }
       return StringUtils.hasText(tool.name()) ? tool.name() : method.getName();
    }

    public static String getToolDescriptionFromName(String toolName) {
       Assert.hasText(toolName, "toolName cannot be null or empty");
       return ParsingUtils.reConcatenateCamelCase(toolName, " ");
    }

    public static String getToolDescription(Method method) {
       Assert.notNull(method, "method cannot be null");
       var tool = method.getAnnotation(Tool.class);
       if (tool == null) {
          return ParsingUtils.reConcatenateCamelCase(method.getName(), " ");
       }
       return StringUtils.hasText(tool.description()) ? tool.description() : method.getName();
    }

    public static boolean getToolReturnDirect(Method method) {
       Assert.notNull(method, "method cannot be null");
       var tool = method.getAnnotation(Tool.class);
       return tool != null && tool.returnDirect();
    }

    public static ToolCallResultConverter getToolCallResultConverter(Method method) {
       Assert.notNull(method, "method cannot be null");
       var tool = method.getAnnotation(Tool.class);
       if (tool == null) {
          return new DefaultToolCallResultConverter();
       }
       var type = tool.resultConverter();
       try {
          return type.getDeclaredConstructor().newInstance();
       }
       catch (Exception e) {
          throw new IllegalArgumentException("Failed to instantiate ToolCallResultConverter: " + type, e);
       }
    }

    public static List<String> getDuplicateToolNames(List<ToolCallback> toolCallbacks) {
       Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
       return toolCallbacks.stream()
          .collect(Collectors.groupingBy(toolCallback -> toolCallback.getToolDefinition().name(),
                Collectors.counting()))
          .entrySet()
          .stream()
          .filter(entry -> entry.getValue() > 1)
          .map(Map.Entry::getKey)
          .collect(Collectors.toList());
    }

    public static List<String> getDuplicateToolNames(ToolCallback... toolCallbacks) {
       Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
       return getDuplicateToolNames(Arrays.asList(toolCallbacks));
    }

}
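`getDuplicateToolNames` groups callbacks by `getToolDefinition().name()`, counts each group, and keeps the names that occur more than once. The sketch below runs the same stream pipeline on plain strings (a simplification: names instead of `ToolCallback` instances) in a hypothetical `DuplicateNameDemo` class, with one addition, `sorted()`, to make the result order stable.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class DuplicateNameDemo {

    // Same pipeline as ToolUtils.getDuplicateToolNames, applied to tool names directly.
    static List<String> duplicates(List<String> toolNames) {
        return toolNames.stream()
            .collect(Collectors.groupingBy(name -> name, Collectors.counting()))
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue() > 1)
            .map(Map.Entry::getKey)
            .sorted() // groupingBy returns a HashMap, so sort for a stable result (not in the original)
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(duplicates(List.of("getWeather", "getTime", "getWeather")));
    }
}
```

Since the original pipeline does not sort, the order of the returned names, and therefore of the names in the `IllegalStateException` message built by `MethodToolCallbackProvider.validateToolCallbacks`, is unspecified.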

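ToolCallback (tool callback)

Defines a tool callback that the AI model can trigger.

| Method | Description |
|-------------------|------------------------|
| getToolDefinition | Returns the tool definition |
| getToolMetadata | Returns the tool metadata |
| call | Executes the tool logic with the tool input (a JSON string) and, optionally, a tool context |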

package org.springframework.ai.tool;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.lang.Nullable;

public interface ToolCallback {

    ToolDefinition getToolDefinition();

    default ToolMetadata getToolMetadata() {
       return ToolMetadata.builder().build();
    }

    String call(String toolInput);

    default String call(String toolInput, @Nullable ToolContext tooContext) {
       if (tooContext != null && !tooContext.getContext().isEmpty()) {
          throw new UnsupportedOperationException("Tool context is not supported!");
       }
       return call(toolInput);
    }

}
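The default two-argument `call` lets simple implementations provide only `call(String)`: a null or empty `ToolContext` is ignored, while a non-empty one is rejected instead of being silently dropped. The JDK-only mirror below uses a hypothetical `DemoToolCallback` interface, with a plain `Map` standing in for `ToolContext.getContext()`.

```java
import java.util.Map;

class ToolCallbackDefaultDemo {

    // Hypothetical mirror of ToolCallback; the Map plays the role of ToolContext.getContext().
    interface DemoToolCallback {
        String call(String toolInput);

        default String call(String toolInput, Map<String, Object> toolContext) {
            if (toolContext != null && !toolContext.isEmpty()) {
                throw new UnsupportedOperationException("Tool context is not supported!");
            }
            return call(toolInput);
        }
    }

    // Implements only the one-argument variant, so it inherits the context check above.
    static final DemoToolCallback ECHO = input -> "echo: " + input;

    public static void main(String[] args) {
        System.out.println(ECHO.call("{}", Map.of())); // empty context: delegated to call(String)
        try {
            ECHO.call("{}", Map.of("userId", "42")); // non-empty context: rejected
        } catch (UnsupportedOperationException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

`FunctionToolCallback` and `MethodToolCallback` both override the two-argument variant, so this check only affects custom implementations that do not.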
FunctionToolCallback

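Wraps a Java functional interface (Function, BiFunction, Supplier, or Consumer) as a tool that the AI framework can call.

  • ToolDefinition toolDefinition: the tool definition
  • ToolMetadata toolMetadata: the tool metadata
  • Type toolInputType: the type of the tool input, used for JSON deserialization
  • BiFunction<I, ToolContext, O> toolFunction: the function that actually runs
  • ToolCallResultConverter toolCallResultConverter: the result converter; by default converts the tool result to a string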
package org.springframework.ai.tool.function;

import java.lang.reflect.Type;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
import org.springframework.ai.tool.execution.ToolCallResultConverter;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class FunctionToolCallback<I, O> implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);

    private static final ToolCallResultConverter DEFAULTRESULTCONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULTTOOLMETADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Type toolInputType;

    private final BiFunction<I, ToolContext, O> toolFunction;

    private final ToolCallResultConverter toolCallResultConverter;

    public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
          BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
       Assert.notNull(toolDefinition, "toolDefinition cannot be null");
       Assert.notNull(toolInputType, "toolInputType cannot be null");
       Assert.notNull(toolFunction, "toolFunction cannot be null");
       this.toolDefinition = toolDefinition;
       this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULTTOOLMETADATA;
       this.toolFunction = toolFunction;
       this.toolInputType = toolInputType;
       this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
             : DEFAULTRESULTCONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
       return this.toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
       return this.toolMetadata;
    }

    @Override
    public String call(String toolInput) {
       return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
       Assert.hasText(toolInput, "toolInput cannot be null or empty");

       logger.debug("Starting execution of tool: {}", this.toolDefinition.name());

       I request = JsonParser.fromJson(toolInput, this.toolInputType);
       O response = this.toolFunction.apply(request, toolContext);

       logger.debug("Successful execution of tool: {}", this.toolDefinition.name());

       return this.toolCallResultConverter.convert(response, null);
    }

    @Override
    public String toString() {
       return "FunctionToolCallback{" + "toolDefinition=" + this.toolDefinition + ", toolMetadata=" + this.toolMetadata
             + '}';
    }

    /**
     * Build a {@link FunctionToolCallback} from a {@link BiFunction}.
     */
    public static <I, O> Builder<I, O> builder(String name, BiFunction<I, ToolContext, O> function) {
       return new Builder<>(name, function);
    }

    /**
     * Build a {@link FunctionToolCallback} from a {@link Function}.
     */
    public static <I, O> Builder<I, O> builder(String name, Function<I, O> function) {
       Assert.notNull(function, "function cannot be null");
       return new Builder<>(name, (request, context) -> function.apply(request));
    }

    /**
     * Build a {@link FunctionToolCallback} from a {@link Supplier}.
     */
    public static <O> Builder<Void, O> builder(String name, Supplier<O> supplier) {
       Assert.notNull(supplier, "supplier cannot be null");
       Function<Void, O> function = input -> supplier.get();
       return builder(name, function).inputType(Void.class);
    }

    /**
     * Build a {@link FunctionToolCallback} from a {@link Consumer}.
     */
    public static <I> Builder<I, Void> builder(String name, Consumer<I> consumer) {
       Assert.notNull(consumer, "consumer cannot be null");
       Function<I, Void> function = (I input) -> {
          consumer.accept(input);
          return null;
       };
       return builder(name, function);
    }

    public static final class Builder<I, O> {

       private String name;

       private String description;

       private String inputSchema;

       private Type inputType;

       private ToolMetadata toolMetadata;

       private BiFunction<I, ToolContext, O> toolFunction;

       private ToolCallResultConverter toolCallResultConverter;

       private Builder(String name, BiFunction<I, ToolContext, O> toolFunction) {
          Assert.hasText(name, "name cannot be null or empty");
          Assert.notNull(toolFunction, "toolFunction cannot be null");
          this.name = name;
          this.toolFunction = toolFunction;
       }

       public Builder<I, O> description(String description) {
          this.description = description;
          return this;
       }

       public Builder<I, O> inputSchema(String inputSchema) {
          this.inputSchema = inputSchema;
          return this;
       }

       public Builder<I, O> inputType(Type inputType) {
          this.inputType = inputType;
          return this;
       }

       public Builder<I, O> inputType(ParameterizedTypeReference<?> inputType) {
          Assert.notNull(inputType, "inputType cannot be null");
          this.inputType = inputType.getType();
          return this;
       }

       public Builder<I, O> toolMetadata(ToolMetadata toolMetadata) {
          this.toolMetadata = toolMetadata;
          return this;
       }

       public Builder<I, O> toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) {
          this.toolCallResultConverter = toolCallResultConverter;
          return this;
       }

       public FunctionToolCallback<I, O> build() {
          Assert.notNull(this.inputType, "inputType cannot be null");
          var toolDefinition = DefaultToolDefinition.builder()
             .name(this.name)
             .description(StringUtils.hasText(this.description) ? this.description
                   : ToolUtils.getToolDescriptionFromName(this.name))
             .inputSchema(StringUtils.hasText(this.inputSchema) ? this.inputSchema
                   : JsonSchemaGenerator.generateForType(this.inputType))
             .build();
          return new FunctionToolCallback<>(toolDefinition, this.toolMetadata, this.inputType, this.toolFunction,
                this.toolCallResultConverter);
       }

    }

}
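All four `builder(...)` overloads normalize the user's function into one `BiFunction<I, ToolContext, O>`: a `Function` ignores the context, a `Supplier` ignores both input and context, and a `Consumer` returns `null`. The JDK-only sketch below reproduces that adaptation in a hypothetical `FunctionAdapterDemo` class, with a hypothetical `Ctx` record standing in for `ToolContext`.

```java
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

class FunctionAdapterDemo {

    // Hypothetical stand-in for ToolContext.
    record Ctx(Map<String, Object> context) {}

    static <I, O> BiFunction<I, Ctx, O> fromFunction(Function<I, O> function) {
        return (request, ctx) -> function.apply(request); // context ignored
    }

    static <O> BiFunction<Void, Ctx, O> fromSupplier(Supplier<O> supplier) {
        return fromFunction((Void input) -> supplier.get()); // input and context ignored
    }

    static <I> BiFunction<I, Ctx, Void> fromConsumer(Consumer<I> consumer) {
        return fromFunction((I input) -> {
            consumer.accept(input);
            return null; // a Consumer has no result
        });
    }

    public static void main(String[] args) {
        Ctx ctx = new Ctx(Map.of("tenant", "demo"));
        BiFunction<String, Ctx, Integer> length = fromFunction(String::length);
        System.out.println(length.apply("hello", ctx));
        System.out.println(fromSupplier(() -> "12:00").apply(null, ctx));
    }
}
```

In the real builder only the `Supplier` path sets `inputType(Void.class)`; since `build()` asserts that `inputType` is not null, tools built from a `Function`, `BiFunction`, or `Consumer` must call `.inputType(...)` explicitly.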
MethodToolCallback

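Wraps a Java method as a tool that the AI framework can call.

  • ToolDefinition toolDefinition: the tool definition
  • ToolMetadata toolMetadata: the tool metadata
  • Method toolMethod: the Java method to invoke
  • Object toolObject: the object the method belongs to; may be null for static methods, required for instance methods
  • ToolCallResultConverter toolCallResultConverter: the result converter; by default converts the tool result to a string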
package org.springframework.ai.tool.method;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.stream.Stream;

import com.fasterxml.jackson.core.type.TypeReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
import org.springframework.ai.tool.execution.ToolCallResultConverter;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;

public final class MethodToolCallback implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);

    private static final ToolCallResultConverter DEFAULTRESULTCONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULTTOOLMETADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Method toolMethod;

    @Nullable
    private final Object toolObject;

    private final ToolCallResultConverter toolCallResultConverter;

    public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
          @Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
       Assert.notNull(toolDefinition, "toolDefinition cannot be null");
       Assert.notNull(toolMethod, "toolMethod cannot be null");
       Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
             "toolObject cannot be null for non-static methods");
       this.toolDefinition = toolDefinition;
       this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULTTOOLMETADATA;
       this.toolMethod = toolMethod;
       this.toolObject = toolObject;
       this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
             : DEFAULTRESULTCONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
       return this.toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
       return this.toolMetadata;
    }

    @Override
    public String call(String toolInput) {
       return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
       Assert.hasText(toolInput, "toolInput cannot be null or empty");

       logger.debug("Starting execution of tool: {}", this.toolDefinition.name());

       validateToolContextSupport(toolContext);

       Map<String, Object> toolArguments = extractToolArguments(toolInput);

       Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);

       Object result = callMethod(methodArguments);

       logger.debug("Successful execution of tool: {}", this.toolDefinition.name());

       Type returnType = this.toolMethod.getGenericReturnType();

       return this.toolCallResultConverter.convert(result, returnType);
    }

    private void validateToolContextSupport(@Nullable ToolContext toolContext) {
       var isNonEmptyToolContextProvided = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext());
       var isToolContextAcceptedByMethod = Stream.of(this.toolMethod.getParameterTypes())
          .anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class));
       if (isToolContextAcceptedByMethod && !isNonEmptyToolContextProvided) {
          throw new IllegalArgumentException("ToolContext is required by the method as an argument");
       }
    }

    private Map<String, Object> extractToolArguments(String toolInput) {
       return JsonParser.fromJson(toolInput, new TypeReference<>() {
       });
    }

    // Based on the implementation in MethodToolCallback.
    private Object[] buildMethodArguments(Map<String, Object> toolInputArguments, @Nullable ToolContext toolContext) {
       return Stream.of(this.toolMethod.getParameters()).map(parameter -> {
          if (parameter.getType().isAssignableFrom(ToolContext.class)) {
             return toolContext;
          }
          Object rawArgument = toolInputArguments.get(parameter.getName());
          return buildTypedArgument(rawArgument, parameter.getParameterizedType());
       }).toArray();
    }

    @Nullable
    private Object buildTypedArgument(@Nullable Object value, Type type) {
       if (value == null) {
          return null;
       }

       if (type instanceof Class<?>) {
          return JsonParser.toTypedObject(value, (Class<?>) type);
       }

       // For generic types, use the fromJson method that accepts Type
       String json = JsonParser.toJson(value);
       return JsonParser.fromJson(json, type);
    }

    @Nullable
    private Object callMethod(Object[] methodArguments) {
       if (isObjectNotPublic() || isMethodNotPublic()) {
          this.toolMethod.setAccessible(true);
       }

       Object result;
       try {
          result = this.toolMethod.invoke(this.toolObject, methodArguments);
       }
       catch (IllegalAccessException ex) {
          throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
       }
       catch (InvocationTargetException ex) {
          throw new ToolExecutionException(this.toolDefinition, ex.getCause());
       }
       return result;
    }

    private boolean isObjectNotPublic() {
       return this.toolObject != null && !Modifier.isPublic(this.toolObject.getClass().getModifiers());
    }

    private boolean isMethodNotPublic() {
       return !Modifier.isPublic(this.toolMethod.getModifiers());
    }

    @Override
    public String toString() {
       return "MethodToolCallback{" + "toolDefinition=" + this.toolDefinition + ", toolMetadata=" + this.toolMetadata
             + '}';
    }

    public static Builder builder() {
       return new Builder();
    }

    public static final class Builder {

       private ToolDefinition toolDefinition;

       private ToolMetadata toolMetadata;

       private Method toolMethod;

       private Object toolObject;

       private ToolCallResultConverter toolCallResultConverter;

       private Builder() {
       }

       public Builder toolDefinition(ToolDefinition toolDefinition) {
          this.toolDefinition = toolDefinition;
          return this;
       }

       public Builder toolMetadata(ToolMetadata toolMetadata) {
          this.toolMetadata = toolMetadata;
          return this;
       }

       public Builder toolMethod(Method toolMethod) {
          this.toolMethod = toolMethod;
          return this;
       }

       public Builder toolObject(Object toolObject) {
          this.toolObject = toolObject;
          return this;
       }

       public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) {
          this.toolCallResultConverter = toolCallResultConverter;
          return this;
       }

       public MethodToolCallback build() {
          return new MethodToolCallback(this.toolDefinition, this.toolMetadata, this.toolMethod, this.toolObject,
                this.toolCallResultConverter);
       }

    }

}
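`call` runs a fixed sequence: parse the JSON input into a `Map`, build the argument array (passing the `ToolContext` to any parameter whose type is assignable from it and looking up everything else by `parameter.getName()`), invoke the method reflectively, then convert the result. The JDK-only sketch below keeps that sequence but, as simplifications, skips JSON parsing and type conversion (arguments arrive as an already-typed `Map`), omits `validateToolContextSupport`, uses `String.valueOf` in place of the result converter, and uses a hypothetical `Ctx` record for `ToolContext`. Because lookup is by `parameter.getName()`, source names are only visible when compiling with `-parameters`; otherwise the JDK reports `arg0`, `arg1`, and so on, which is why the demo builds its argument map from the reflected names.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.HashMap;
import java.util.Map;

class MethodInvokeDemo {

    // Hypothetical stand-in for ToolContext.
    record Ctx(Map<String, Object> context) {}

    static class MathTools {
        int add(int a, int b) {
            return a + b;
        }

        String whoAmI(String greeting, Ctx ctx) {
            return greeting + ", " + ctx.context().get("user");
        }
    }

    // Mirrors buildMethodArguments + callMethod, minus JSON conversion.
    static String invoke(Object target, Method method, Map<String, Object> args, Ctx ctx) {
        Parameter[] params = method.getParameters();
        Object[] values = new Object[params.length];
        for (int i = 0; i < params.length; i++) {
            values[i] = params[i].getType().isAssignableFrom(Ctx.class)
                    ? ctx                               // context is injected, not read from args
                    : args.get(params[i].getName());   // arg0/arg1... without -parameters
        }
        try {
            method.setAccessible(true); // the original only does this for non-public classes/methods
            return String.valueOf(method.invoke(target, values));
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Could not access method: " + e.getMessage(), e);
        } catch (InvocationTargetException e) {
            throw new RuntimeException("Tool failed: " + method.getName(), e.getCause());
        }
    }

    // Builds an argument map keyed by whatever parameter names the compiler recorded.
    static Map<String, Object> argsFor(Method method, Object... values) {
        Map<String, Object> args = new HashMap<>();
        Parameter[] params = method.getParameters();
        for (int i = 0; i < values.length; i++) {
            args.put(params[i].getName(), values[i]);
        }
        return args;
    }

    static Method method(String name, Class<?>... types) {
        try {
            return MathTools.class.getDeclaredMethod(name, types);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Method add = method("add", int.class, int.class);
        System.out.println(add.getParameters()[0].getName()); // "a" with -parameters, "arg0" otherwise
        System.out.println(invoke(new MathTools(), add, argsFor(add, 2, 3), null));
    }
}
```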

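ToolCallbackProvider (tool callback provider)

Abstracts how ToolCallbacks are obtained, so that callbacks from different sources can be managed centrally through implementations such as MethodToolCallbackProvider and StaticToolCallbackProvider; this makes the framework easier to extend and integrate.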

package org.springframework.ai.tool;

import java.util.List;

public interface ToolCallbackProvider {

    ToolCallback[] getToolCallbacks();

    static ToolCallbackProvider from(List<? extends ToolCallback> toolCallbacks) {
       return new StaticToolCallbackProvider(toolCallbacks);
    }

    static ToolCallbackProvider from(ToolCallback... toolCallbacks) {
       return new StaticToolCallbackProvider(toolCallbacks);
    }

}
StaticToolCallbackProvider

A simple ToolCallbackProvider implementation that holds and returns a fixed set of ToolCallbacks.

package org.springframework.ai.tool;

import java.util.List;

import org.springframework.util.Assert;

public class StaticToolCallbackProvider implements ToolCallbackProvider {

    private final ToolCallback[] toolCallbacks;

    public StaticToolCallbackProvider(ToolCallback... toolCallbacks) {
       Assert.notNull(toolCallbacks, "ToolCallbacks must not be null");
       this.toolCallbacks = toolCallbacks;
    }

    public StaticToolCallbackProvider(List<? extends ToolCallback> toolCallbacks) {
       Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
       this.toolCallbacks = toolCallbacks.toArray(new ToolCallback[0]);
    }

    @Override
    public ToolCallback[] getToolCallbacks() {
       return this.toolCallbacks;
    }

}
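As written, the varargs constructor stores the caller's array as-is and `getToolCallbacks()` returns that same array, so the provider is only as immutable as the caller's array; the `List` constructor, by contrast, copies via `toArray`. The sketch below shows the aliasing with a hypothetical `DemoProvider` that copies both constructors, using `String` in place of `ToolCallback`.

```java
import java.util.List;

class ProviderAliasingDemo {

    // Hypothetical mirror of StaticToolCallbackProvider with String in place of ToolCallback.
    static final class DemoProvider {
        private final String[] toolCallbacks;

        DemoProvider(String... toolCallbacks) {
            this.toolCallbacks = toolCallbacks; // stores the caller's array, no copy
        }

        DemoProvider(List<String> toolCallbacks) {
            this.toolCallbacks = toolCallbacks.toArray(new String[0]); // fresh array
        }

        String[] getToolCallbacks() {
            return this.toolCallbacks; // returns the internal array itself
        }
    }

    public static void main(String[] args) {
        String[] tools = {"getWeather", "getTime"};
        DemoProvider provider = new DemoProvider(tools);
        tools[0] = "replaced";
        System.out.println(provider.getToolCallbacks()[0]); // prints "replaced"
    }
}
```

Callers who want isolation can pass a copy (for example `tools.clone()`) or go through `ToolCallbackProvider.from(List)`.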
MethodToolCallbackProvider

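Builds ToolCallback instances dynamically from @Tool-annotated methods and exposes them to the AI framework. It suits annotation-based tool registration and simplifies discovering and managing tool methods.

The getToolCallbacks method works as follows:

  1. Iterate over the tool objects: process each object in the toolObjects list
  2. Collect methods: for each object, get all declared methods (for an AOP proxy, those of the target class)
  3. Keep @Tool methods: filter for methods annotated with @Tool
  4. Exclude functional types: filter out methods whose return type is Function, Supplier, or Consumer (checked by isFunctionalType)
  5. Build ToolCallback instances: for each remaining method, build a ToolCallback with MethodToolCallback.builder()
  6. Collect all ToolCallbacks: gather the built instances into a single array
  7. Check name uniqueness: call validateToolCallbacks to make sure every tool name is unique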
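Steps 2 to 4 and step 7 above can be reproduced with plain reflection. The sketch below uses a hypothetical `@DemoTool` annotation and a hypothetical `ToolScanDemo.scan` method: it scans declared methods, keeps the annotated ones, drops functional return types, and rejects duplicate names. It skips AOP proxy unwrapping and the `assertToolAnnotatedMethodsPresent` check, and returns names instead of `MethodToolCallback`s. Note the direction of `ClassUtils.isAssignable(returnType, Function.class)` in the original: it asks whether a `Function` could be assigned to the return type, so, as written, a `@Tool` method declared to return `Object` is skipped too; the sketch mirrors that.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ToolScanDemo {

    // Hypothetical stand-in for @Tool.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    @interface DemoTool {
        String name() default "";
    }

    static class MyTools {
        @DemoTool
        String getWeather(String city) { return "sunny in " + city; }

        @DemoTool(name = "clock")
        String getTime() { return "12:00"; }

        @DemoTool // returns Function: excluded
        Function<String, String> upper() { return String::toUpperCase; }

        @DemoTool // returns Object: also excluded, because a Function is assignable to Object
        Object raw() { return "raw"; }

        String helper() { return "no annotation"; } // not a tool
    }

    // Same direction as ClassUtils.isAssignable(returnType, X.class): can an X be assigned to the return type?
    static boolean isFunctionalType(Method m) {
        Class<?> r = m.getReturnType();
        return r.isAssignableFrom(Function.class) || r.isAssignableFrom(Supplier.class)
                || r.isAssignableFrom(Consumer.class);
    }

    // Returns the sorted tool names found on the given objects, failing on duplicates.
    static List<String> scan(Object... toolObjects) {
        List<String> names = Stream.of(toolObjects)
            .flatMap(o -> Stream.of(o.getClass().getDeclaredMethods()))
            .filter(m -> m.isAnnotationPresent(DemoTool.class))
            .filter(m -> !isFunctionalType(m))
            .map(m -> {
                DemoTool tool = m.getAnnotation(DemoTool.class);
                return tool.name().isBlank() ? m.getName() : tool.name();
            })
            .sorted() // getDeclaredMethods has no guaranteed order
            .toList();
        List<String> duplicates = names.stream()
            .collect(Collectors.groupingBy(n -> n, Collectors.counting()))
            .entrySet().stream()
            .filter(e -> e.getValue() > 1)
            .map(Map.Entry::getKey)
            .toList();
        if (!duplicates.isEmpty()) {
            throw new IllegalStateException("Multiple tools with the same name: " + duplicates);
        }
        return names;
    }
}
```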
package org.springframework.ai.tool.method;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.ai.tool.support.ToolDefinitions;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.aop.support.AopUtils;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

public final class MethodToolCallbackProvider implements ToolCallbackProvider {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class);

    private final List<Object> toolObjects;

    private MethodToolCallbackProvider(List<Object> toolObjects) {
       Assert.notNull(toolObjects, "toolObjects cannot be null");
       Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
       assertToolAnnotatedMethodsPresent(toolObjects);
       this.toolObjects = toolObjects;
       validateToolCallbacks(getToolCallbacks());
    }

    private void assertToolAnnotatedMethodsPresent(List<Object> toolObjects) {

       for (Object toolObject : toolObjects) {
          List<Method> toolMethods = Stream
             .of(ReflectionUtils.getDeclaredMethods(
                   AopUtils.isAopProxy(toolObject) ? AopUtils.getTargetClass(toolObject) : toolObject.getClass()))
             .filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class))
             .filter(toolMethod -> !isFunctionalType(toolMethod))
             .toList();

          if (toolMethods.isEmpty()) {
             throw new IllegalStateException("No @Tool annotated methods found in " + toolObject + "."
                   + "Did you mean to pass a ToolCallback or ToolCallbackProvider? If so, you have to use .toolCallbacks() instead of .tool()");
          }
       }
    }

    @Override
    public ToolCallback[] getToolCallbacks() {
       var toolCallbacks = this.toolObjects.stream()
          .map(toolObject -> Stream
             .of(ReflectionUtils.getDeclaredMethods(
                   AopUtils.isAopProxy(toolObject) ? AopUtils.getTargetClass(toolObject) : toolObject.getClass()))
             .filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class))
             .filter(toolMethod -> !isFunctionalType(toolMethod))
             .map(toolMethod -> MethodToolCallback.builder()
                .toolDefinition(ToolDefinitions.from(toolMethod))
                .toolMetadata(ToolMetadata.from(toolMethod))
                .toolMethod(toolMethod)
                .toolObject(toolObject)
                .toolCallResultConverter(ToolUtils.getToolCallResultConverter(toolMethod))
                .build())
             .toArray(ToolCallback[]::new))
          .flatMap(Stream::of)
          .toArray(ToolCallback[]::new);

       validateToolCallbacks(toolCallbacks);

       return toolCallbacks;
    }

    private boolean isFunctionalType(Method toolMethod) {
       var isFunction = ClassUtils.isAssignable(toolMethod.getReturnType(), Function.class)
             || ClassUtils.isAssignable(toolMethod.getReturnType(), Supplier.class)
             || ClassUtils.isAssignable(toolMethod.getReturnType(), Consumer.class);

       if (isFunction) {
          logger.warn("Method {} is annotated with @Tool but returns a functional type. "
                + "This is not supported and the method will be ignored.", toolMethod.getName());
       }

       return isFunction;
    }

    private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
       List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
       if (!duplicateToolNames.isEmpty()) {
          throw new IllegalStateException("Multiple tools with the same name (%s) found in sources: %s".formatted(
                String.join(", ", duplicateToolNames),
                this.toolObjects.stream().map(o -> o.getClass().getName()).collect(Collectors.joining(", "))));
       }
    }

    public static Builder builder() {
       return new Builder();
    }

    public static final class Builder {

       private List<Object> toolObjects;

       private Builder() {
       }

       public Builder toolObjects(Object... toolObjects) {
          Assert.notNull(toolObjects, "toolObjects cannot be null");
          this.toolObjects = Arrays.asList(toolObjects);
          return this;
       }

       public MethodToolCallbackProvider build() {
          return new MethodToolCallbackProvider(this.toolObjects);
       }

    }

}
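The scan in getToolCallbacks() can be reproduced with plain reflection. The sketch below is a simplified model, not Spring AI code:

- `@DemoTool` is a hypothetical stand-in for @Tool.
- Instead of building MethodToolCallback instances, it only returns tool names.
- It keeps annotated methods, skips those returning a functional interface, and rejects objects with no tools or with duplicate names. These are the same three checks the provider performs.

One detail is deliberately different. The sketch tests `Function.class.isAssignableFrom(returnType)`. Spring's `ClassUtils.isAssignable(lhsType, rhsType)` checks whether rhsType can be assigned to lhsType. So the original call `isAssignable(returnType, Function.class)` has two consequences:

- It is true when the declared return type is Function or one of its supertypes, including Object.
- It is false for a subinterface such as UnaryOperator.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

class ToolScanSketch {

    // Hypothetical stand-in for @Tool; only the name attribute is modelled.
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @interface DemoTool {
        String name() default "";
    }

    static class WeatherTools {
        @DemoTool(name = "currentWeather")
        public String weather(String city) { return "sunny in " + city; }

        @DemoTool // no name given: falls back to the method name
        public int add(int a, int b) { return a + b; }

        @DemoTool // returns a functional type: skipped, like isFunctionalType() in the provider
        public Function<String, String> echo() { return s -> s; }

        public void helper() { } // not annotated: ignored
    }

    // Returns the tool names declared on the object's class, sorted for a stable order.
    static List<String> scan(Object toolObject) {
        List<String> names = new ArrayList<>();
        for (Method m : toolObject.getClass().getDeclaredMethods()) {
            DemoTool tool = m.getAnnotation(DemoTool.class);
            if (tool == null || isFunctionalType(m)) {
                continue;
            }
            names.add(tool.name().isEmpty() ? m.getName() : tool.name());
        }
        if (names.isEmpty()) {
            throw new IllegalStateException("No @DemoTool annotated methods found in " + toolObject);
        }
        Set<String> seen = new HashSet<>();
        for (String name : names) {
            if (!seen.add(name)) {
                throw new IllegalStateException("Multiple tools with the same name: " + name);
            }
        }
        names.sort(null);
        return names;
    }

    // True when the return type is Function, Supplier or Consumer, or a subtype of one of them.
    static boolean isFunctionalType(Method m) {
        Class<?> returnType = m.getReturnType();
        return Function.class.isAssignableFrom(returnType) || Supplier.class.isAssignableFrom(returnType)
                || Consumer.class.isAssignableFrom(returnType);
    }

    public static void main(String[] args) {
        System.out.println(scan(new WeatherTools())); // [add, currentWeather]
    }
}
```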

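ToolCallbackResolver (Tool Callback Resolver)

Gives the framework a single entry point for resolving tool callbacks. It looks up the actual tool implementation by tool name, which decouples tool registration from tool invocation and suits scenarios where many tools are dispatched dynamically. resolve returns null when the name is unknown.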

java
package org.springframework.ai.tool.resolution;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;

public interface ToolCallbackResolver {

    @Nullable
    ToolCallback resolve(String toolName);

}
DelegatingToolCallbackResolver

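Implements chained delegation for tool-callback resolution. It combines resolvers from different sources (such as StaticToolCallbackResolver and SpringBeanToolCallbackResolver) and exposes a single lookup of a ToolCallback by name. Resolvers are queried in list order and the first non-null result is returned; if none matches, the result is null.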

java
package org.springframework.ai.tool.resolution;

import java.util.List;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

public class DelegatingToolCallbackResolver implements ToolCallbackResolver {

    private final List<ToolCallbackResolver> toolCallbackResolvers;

    public DelegatingToolCallbackResolver(List<ToolCallbackResolver> toolCallbackResolvers) {
       Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null");
       Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements");
       this.toolCallbackResolvers = toolCallbackResolvers;
    }

    @Override
    @Nullable
    public ToolCallback resolve(String toolName) {
       Assert.hasText(toolName, "toolName cannot be null or empty");

       for (ToolCallbackResolver toolCallbackResolver : this.toolCallbackResolvers) {
          ToolCallback toolCallback = toolCallbackResolver.resolve(toolName);
          if (toolCallback != null) {
             return toolCallback;
          }
       }
       return null;
    }

}
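The delegation is a plain chain of responsibility: resolvers are asked in list order and the first non-null answer wins, so an earlier resolver shadows a later one for the same name.

The stdlib sketch below uses hypothetical, simplified stand-ins: `Resolver`, `MapResolver` and `Delegating`. Tools are represented as strings rather than ToolCallback instances. It does not model the order in which Spring AI's auto-configuration registers its resolvers.

```java
import java.util.List;
import java.util.Map;

class DelegatingSketch {

    // Simplified stand-in for ToolCallbackResolver: returns null when the name is unknown.
    interface Resolver {
        String resolve(String toolName);
    }

    // Backed by a fixed map, in the spirit of StaticToolCallbackResolver.
    record MapResolver(Map<String, String> tools) implements Resolver {
        public String resolve(String toolName) {
            return tools.get(toolName);
        }
    }

    // Asks each delegate in order and returns the first non-null result, or null.
    record Delegating(List<Resolver> delegates) implements Resolver {
        public String resolve(String toolName) {
            if (toolName == null || toolName.isBlank()) {
                throw new IllegalArgumentException("toolName cannot be null or empty");
            }
            for (Resolver r : delegates) {
                String found = r.resolve(toolName);
                if (found != null) {
                    return found;
                }
            }
            return null;
        }
    }

    public static void main(String[] args) {
        Resolver chain = new Delegating(List.of(
                new MapResolver(Map.of("weather", "static-weather")),
                new MapResolver(Map.of("weather", "bean-weather", "time", "bean-time"))));
        System.out.println(chain.resolve("weather")); // static-weather (first resolver wins)
        System.out.println(chain.resolve("time"));    // bean-time (falls through to the second)
        System.out.println(chain.resolve("missing")); // null
    }
}
```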
StaticToolCallbackResolver

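Centrally manages a fixed set of known ToolCallbacks and looks them up by tool name through a map, so the framework can fetch tools on demand. The implementation is simple and thread-safe for lookups: the map is filled once in the constructor and only read afterwards. Callbacks are keyed by getToolDefinition().name(); if two callbacks share a name, the later one overwrites the earlier one.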

java
package org.springframework.ai.tool.resolution;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.util.Assert;

public class StaticToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class);

    private final Map<String, ToolCallback> toolCallbacks = new HashMap<>();

    public StaticToolCallbackResolver(List<ToolCallback> toolCallbacks) {
       Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
       Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");

       toolCallbacks
          .forEach(toolCallback -> this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback));
    }

    @Override
    public ToolCallback resolve(String toolName) {
       Assert.hasText(toolName, "toolName cannot be null or empty");
       logger.debug("ToolCallback resolution attempt from static registry");
       return this.toolCallbacks.get(toolName);
    }

}
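Because the constructor simply calls put for every callback, two callbacks that share a tool name do not cause an error here. The later one silently replaces the earlier one.

The sketch below reproduces that keying behaviour. `Callback` is a hypothetical record standing in for ToolCallback; only the name, which plays the role of getToolDefinition().name(), matters for the lookup.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class StaticResolverSketch {

    // Hypothetical stand-in: name plays the role of getToolDefinition().name().
    record Callback(String name, String impl) { }

    static final class StaticResolver {
        private final Map<String, Callback> byName = new HashMap<>();

        StaticResolver(List<Callback> callbacks) {
            // As in the original: put() in list order, so a later duplicate overwrites an earlier one.
            callbacks.forEach(cb -> byName.put(cb.name(), cb));
        }

        Callback resolve(String toolName) {
            if (toolName == null || toolName.isBlank()) {
                throw new IllegalArgumentException("toolName cannot be null or empty");
            }
            return byName.get(toolName); // null when the name is unknown
        }
    }

    public static void main(String[] args) {
        StaticResolver resolver = new StaticResolver(List.of(
                new Callback("weather", "v1"), new Callback("weather", "v2"), new Callback("time", "t1")));
        System.out.println(resolver.resolve("weather").impl()); // v2
        System.out.println(resolver.resolve("unknown"));        // null
    }
}
```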
SpringBeanToolCallbackResolver

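A tool callback resolver backed by the Spring ApplicationContext. It looks up a bean by name in the Spring container and wraps it as a ToolCallback.

  • Map<String, ToolCallback> toolCallbacksCache: a cache of tool callbacks that have already been resolved. It is static, so all resolver instances share it.
  • GenericApplicationContext applicationContext: the Spring application context used to find and fetch beans
  • SchemaType schemaType: the schema type used by this instance. It determines how the input-parameter schema is generated (OPEN_API_SCHEMA produces upper-case type values) and defaults to JSON_SCHEMA.

How resolve works

  1. Check the cache first; on a hit, return the cached callback directly.
  2. On a miss, look up the bean in the Spring container:
     • Infer the bean type and its input type (Void for a Supplier).
     • Resolve the description: @Description on the bean first, then @JsonClassDescription on the input type, then a description derived from the tool name.
     • Generate the input schema, build the ToolCallback and cache it.
  3. If the bean cannot be found or any exception occurs, return null.
java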
package org.springframework.ai.tool.resolution;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.ai.util.json.schema.SchemaType;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Description;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.KotlinDetector;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class SpringBeanToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class);

    private static final Map<String, ToolCallback> toolCallbacksCache = new HashMap<>();

    private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA;

    private final GenericApplicationContext applicationContext;

    private final SchemaType schemaType;

    public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext,
          @Nullable SchemaType schemaType) {
       Assert.notNull(applicationContext, "applicationContext cannot be null");

       this.applicationContext = applicationContext;
       this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE;
    }

    @Override
    public ToolCallback resolve(String toolName) {
       Assert.hasText(toolName, "toolName cannot be null or empty");

       logger.debug("ToolCallback resolution attempt from Spring application context");

       ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName);

       if (resolvedToolCallback != null) {
          return resolvedToolCallback;
       }

       try {
          ResolvableType toolType = TypeResolverHelper.resolveBeanType(this.applicationContext, toolName);
          ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType))
                ? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0);

          String toolDescription = resolveToolDescription(toolName, toolInputType.toClass());
          Object bean = this.applicationContext.getBean(toolName);

          resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean);

          toolCallbacksCache.put(toolName, resolvedToolCallback);

          return resolvedToolCallback;
       }
       catch (Exception e) {
          logger.debug("ToolCallback resolution failed from Spring application context", e);
          return null;
       }
    }

    public SchemaType getSchemaType() {
       return this.schemaType;
    }

    private String resolveToolDescription(String toolName, Class<?> toolInputType) {
       Description descriptionAnnotation = this.applicationContext.findAnnotationOnBean(toolName, Description.class);
       if (descriptionAnnotation != null && StringUtils.hasText(descriptionAnnotation.value())) {
          return descriptionAnnotation.value();
       }

       JsonClassDescription jsonClassDescriptionAnnotation = toolInputType.getAnnotation(JsonClassDescription.class);
       if (jsonClassDescriptionAnnotation != null && StringUtils.hasText(jsonClassDescriptionAnnotation.value())) {
          return jsonClassDescriptionAnnotation.value();
       }

       return ToolUtils.getToolDescriptionFromName(toolName);
    }

    private ToolCallback buildToolCallback(String toolName, ResolvableType toolType, ResolvableType toolInputType,
          String toolDescription, Object bean) {
       if (KotlinDetector.isKotlinPresent()) {
          if (KotlinDelegate.isKotlinFunction(toolType.toClass())) {
             return FunctionToolCallback.builder(toolName, KotlinDelegate.wrapKotlinFunction(bean))
                .description(toolDescription)
                .inputSchema(generateSchema(toolInputType))
                .inputType(ParameterizedTypeReference.forType(toolInputType.getType()))
                .build();
          }
          if (KotlinDelegate.isKotlinBiFunction(toolType.toClass())) {
             return FunctionToolCallback.builder(toolName, KotlinDelegate.wrapKotlinBiFunction(bean))
                .description(toolDescription)
                .inputSchema(generateSchema(toolInputType))
                .inputType(ParameterizedTypeReference.forType(toolInputType.getType()))
                .build();
          }
          if (KotlinDelegate.isKotlinSupplier(toolType.toClass())) {
             return FunctionToolCallback.builder(toolName, KotlinDelegate.wrapKotlinSupplier(bean))
                .description(toolDescription)
                .inputSchema(generateSchema(toolInputType))
                .inputType(ParameterizedTypeReference.forType(toolInputType.getType()))
                .build();
          }
       }

       if (bean instanceof Function<?, ?> function) {
          return FunctionToolCallback.builder(toolName, function)
             .description(toolDescription)
             .inputSchema(generateSchema(toolInputType))
             .inputType(ParameterizedTypeReference.forType(toolInputType.getType()))
             .build();
       }
       if (bean instanceof BiFunction<?, ?, ?>) {
          return FunctionToolCallback.builder(toolName, (BiFunction<?, ToolContext, ?>) bean)
             .description(toolDescription)
             .inputSchema(generateSchema(toolInputType))
             .inputType(ParameterizedTypeReference.forType(toolInputType.getType()))
             .build();
       }
       if (bean instanceof Supplier<?> supplier) {
          return FunctionToolCallback.builder(toolName, supplier)
             .description(toolDescription)
             .inputSchema(generateSchema(toolInputType))
             .inputType(ParameterizedTypeReference.forType(toolInputType.getType()))
             .build();
       }
       if (bean instanceof Consumer<?> consumer) {
          return FunctionToolCallback.builder(toolName, consumer)
             .description(toolDescription)
             .inputSchema(generateSchema(toolInputType))
             .inputType(ParameterizedTypeReference.forType(toolInputType.getType()))
             .build();
       }

       throw new IllegalStateException(
             "Unsupported bean type. Support types: Function, BiFunction, Supplier, Consumer.");
    }

    private String generateSchema(ResolvableType toolInputType) {
       if (this.schemaType == SchemaType.OPEN_API_SCHEMA) {
          return JsonSchemaGenerator.generateForType(toolInputType.getType(),
                JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES);
       }
       return JsonSchemaGenerator.generateForType(toolInputType.getType());
    }

    public static Builder builder() {
       return new Builder();
    }

    public static class Builder {

       private GenericApplicationContext applicationContext;

       private SchemaType schemaType;

       public Builder applicationContext(GenericApplicationContext applicationContext) {
          this.applicationContext = applicationContext;
          return this;
       }

       public Builder schemaType(SchemaType schemaType) {
          this.schemaType = schemaType;
          return this;
       }

       public SpringBeanToolCallbackResolver build() {
          return new SpringBeanToolCallbackResolver(this.applicationContext, this.schemaType);
       }

    }

    private static final class KotlinDelegate {

       public static boolean isKotlinSupplier(Class<?> clazz) {
          return Function0.class.isAssignableFrom(clazz);
       }

       @SuppressWarnings("unchecked")
       public static Supplier<?> wrapKotlinSupplier(Object bean) {
          return () -> ((Function0<Object>) bean).invoke();
       }

       public static boolean isKotlinFunction(Class<?> clazz) {
          return Function1.class.isAssignableFrom(clazz);
       }

       @SuppressWarnings("unchecked")
       public static Function<?, ?> wrapKotlinFunction(Object bean) {
          return t -> ((Function1<Object, Object>) bean).invoke(t);
       }

       public static boolean isKotlinBiFunction(Class<?> clazz) {
          return Function2.class.isAssignableFrom(clazz);
       }

       @SuppressWarnings("unchecked")
       public static BiFunction<?, ToolContext, ?> wrapKotlinBiFunction(Object bean) {
          return (t, u) -> ((Function2<Object, ToolContext, Object>) bean).invoke(t, u);
       }

    }

}
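resolve() together with resolveToolDescription() combines two things:

- a cache-aside lookup;
- a three-step description fallback: @Description on the bean, then @JsonClassDescription on the input type, then a description derived from the tool name.

The stdlib sketch below models only that control flow.

- `BeanInfo` is a hypothetical placeholder for what Spring reads from the ApplicationContext.
- `describeFromName` is a hypothetical name-based fallback, not necessarily what ToolUtils.getToolDescriptionFromName produces.
- The cache here is per instance, whereas the original uses a static map shared by all instances.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

class BeanResolverSketch {

    // Hypothetical bean metadata: the two description sources the original checks; either may be absent.
    record BeanInfo(Optional<String> descriptionAnnotation, Optional<String> inputTypeDescription) { }

    private final Map<String, BeanInfo> beans;                  // plays the role of the ApplicationContext
    private final Map<String, String> cache = new HashMap<>();  // tool name -> resolved description
    int misses = 0;                                             // number of real lookups (cache misses)

    BeanResolverSketch(Map<String, BeanInfo> beans) {
        this.beans = beans;
    }

    // Cache-aside lookup: cached value if present, otherwise resolve, store and return; null if unknown.
    String resolve(String toolName) {
        String cached = cache.get(toolName);
        if (cached != null) {
            return cached;
        }
        BeanInfo info = beans.get(toolName);
        if (info == null) {
            return null; // the original returns null when the bean lookup throws
        }
        misses++;
        String description = info.descriptionAnnotation()       // 1. @Description on the bean
                .or(info::inputTypeDescription)                 // 2. @JsonClassDescription on the input type
                .orElseGet(() -> describeFromName(toolName));   // 3. derived from the tool name
        cache.put(toolName, description);
        return description;
    }

    // Hypothetical fallback: splits camelCase into lower-case words ("currentDate" -> "current date").
    static String describeFromName(String toolName) {
        return toolName.replaceAll("([a-z])([A-Z])", "$1 $2").toLowerCase(Locale.ROOT);
    }

    public static void main(String[] args) {
        BeanResolverSketch resolver = new BeanResolverSketch(Map.of(
                "weatherFunction", new BeanInfo(Optional.of("Get the weather"), Optional.empty()),
                "currentDate", new BeanInfo(Optional.empty(), Optional.empty())));
        System.out.println(resolver.resolve("weatherFunction")); // Get the weather
        System.out.println(resolver.resolve("currentDate"));     // current date
        resolver.resolve("currentDate");                         // served from the cache
        System.out.println(resolver.misses);                     // 2
    }
}
```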
TypeResolverHelper

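Resolves type information for functional interfaces (Function, BiFunction, Supplier, Consumer and Kotlin functions) within the Spring AI tool system. It helps infer the types of beans in the Spring container and obtain their parameter types.

| Method | Description |
|---|---|
| getConsumerInputClass | Gets the input type of a Consumer |
| getBiFunctionInputClass | Gets the first input type of a BiFunction |
| getFunctionInputClass | Gets the input type of a Function |
| getFunctionOutputClass | Gets the output type of a Function |
| getFunctionArgumentClass | Gets the type at the given type-argument index of a Function (0 = input, 1 = output) |
| getBiFunctionArgumentClass | Gets the type at the given type-argument index of a BiFunction |
| resolveBeanType | Resolves the type of a named bean in the Spring container, covering direct resolution, factory methods, @Component classes and other cases |
| getFunctionArgumentType | Gets the ResolvableType of the given type argument of a function type (Java and Kotlin functional interfaces are supported) |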
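The core idea behind getFunctionInputClass and getFunctionArgumentType is to read the type arguments a class declares for a functional interface.

The plain-reflection sketch below covers only a class that directly implements `Function<I, O>`. `functionArgument`, `WeatherFunction` and the request/response records are hypothetical. TypeResolverHelper itself works with Spring's ResolvableType and also handles factory methods, BiFunction, Supplier, Consumer and Kotlin types, none of which are covered here.

Note that a bean defined as a lambda has a runtime class that carries no such generic information. That is presumably why resolveBeanType inspects the bean definition and factory method instead.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.function.Function;

class TypeArgSketch {

    record WeatherRequest(String city) { }
    record WeatherResponse(double celsius) { }

    static class WeatherFunction implements Function<WeatherRequest, WeatherResponse> {
        public WeatherResponse apply(WeatherRequest request) {
            return new WeatherResponse(21.5);
        }
    }

    // Returns type argument `index` (0 = input, 1 = output) that `clazz` declares for Function.
    // Only interfaces implemented directly by clazz are inspected; superclasses and type variables are not resolved.
    static Type functionArgument(Class<?> clazz, int index) {
        for (Type itf : clazz.getGenericInterfaces()) {
            if (itf instanceof ParameterizedType pt && pt.getRawType() == Function.class) {
                return pt.getActualTypeArguments()[index];
            }
        }
        throw new IllegalArgumentException(clazz + " does not directly implement Function");
    }

    public static void main(String[] args) {
        System.out.println(functionArgument(WeatherFunction.class, 0)); // the WeatherRequest class
        System.out.println(functionArgument(WeatherFunction.class, 1)); // the WeatherResponse class
    }
}
```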

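ToolCallingManager (Tool Calling Manager)

This interface manages the tool-calling flow of a chat model.

| Method | Description |
|---|---|
| resolveToolDefinitions | Resolves the list of available tool definitions from the model's tool-calling options |
| executeToolCalls | Executes the tool calls requested in the model response and returns the execution result |
| builder | Returns a builder for the default implementation (DefaultToolCallingManager) |

java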
package org.springframework.ai.model.tool;

import java.util.List;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.tool.definition.ToolDefinition;

public interface ToolCallingManager {

    List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);

    ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);

    static DefaultToolCallingManager.Builder builder() {
       return DefaultToolCallingManager.builder();
    }

}
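The two methods are meant to be driven by the chat model in a loop:

1. Send the prompt.
2. While the response asks for tools, execute them.
3. Either return the tool output directly (returnDirect) or send the extended history back to the model.

The sketch below models that loop with stdlib types only. `Reply`, `Model`, `Outcome`, `ToolRunner` and `chat` are hypothetical simplifications, and messages are plain strings. The loop is an assumption about how a caller typically uses executeToolCalls, not a copy of any Spring AI ChatModel implementation.

```java
import java.util.ArrayList;
import java.util.List;

class ToolLoopSketch {

    // Simplified model reply: final text, or the names of the tools it wants called.
    record Reply(String text, List<String> toolCalls) {
        boolean wantsTools() { return !toolCalls.isEmpty(); }
    }

    // Simplified stand-ins for the chat model and for ToolCallingManager.executeToolCalls.
    interface Model { Reply call(List<String> history); }
    record Outcome(List<String> history, boolean returnDirect) { }
    interface ToolRunner { Outcome execute(List<String> history, Reply reply); }

    static String chat(Model model, ToolRunner tools, List<String> prompt, int maxRounds) {
        List<String> history = new ArrayList<>(prompt);
        Reply reply = model.call(history);
        int rounds = 0;
        while (reply.wantsTools()) {
            if (++rounds > maxRounds) {
                throw new IllegalStateException("Too many tool-calling rounds");
            }
            Outcome outcome = tools.execute(history, reply);
            List<String> next = outcome.history();
            if (outcome.returnDirect()) {
                // returnDirect: the last message (the tool output) goes straight back to the caller.
                return next.get(next.size() - 1);
            }
            history = new ArrayList<>(next);
            reply = model.call(history); // ask the model again, now with the tool results in the history
        }
        return reply.text();
    }

    // Demo model: asks for the "weather" tool until a tool result is in the history.
    static Model demoModel() {
        return history -> history.stream().anyMatch(m -> m.startsWith("tool:"))
                ? new Reply("It is sunny.", List.of())
                : new Reply(null, List.of("weather"));
    }

    // Demo runner: appends the assistant tool request and one tool message per requested tool.
    static ToolRunner demoRunner(boolean returnDirect) {
        return (history, reply) -> {
            List<String> next = new ArrayList<>(history);
            next.add("assistant:" + reply.toolCalls());
            reply.toolCalls().forEach(name -> next.add("tool:" + name + "=sunny"));
            return new Outcome(next, returnDirect);
        };
    }

    public static void main(String[] args) {
        System.out.println(chat(demoModel(), demoRunner(false), List.of("user:weather?"), 3)); // It is sunny.
        System.out.println(chat(demoModel(), demoRunner(true), List.of("user:weather?"), 3));  // tool:weather=sunny
    }
}
```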
DefaultToolCallingManager

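The default implementation of ToolCallingManager. It manages the tool-calling flow of an AI chat model: resolving tool definitions, executing tool calls, handling exceptions and recording observations.

  • ObservationRegistry observationRegistry: the observation registry used to instrument and monitor tool calls (defaults to ObservationRegistry.NOOP)
  • ToolCallbackResolver toolCallbackResolver: the tool callback resolver that looks up the ToolCallback for a tool name (defaults to an empty DelegatingToolCallbackResolver)
  • ToolExecutionExceptionProcessor toolExecutionExceptionProcessor: the tool execution exception processor, which turns a ToolExecutionException thrown by a tool into a result that can be returned
  • ToolCallingObservationConvention observationConvention: the tool-calling observation convention, which defines the structure and content of the observation data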

| Method | Description |
|---|---|
| resolveToolDefinitions | Resolves and returns the tool definitions available for the current request |
| executeToolCalls | Executes the tool calls requested in the model response (steps listed below the table) |
| setObservationConvention | Sets a custom observation convention |

executeToolCalls runs as follows:

  1. Extract the tool-call request: find the first Generation in the chatResponse results whose output contains tool calls (toolCalls). If there is none, throw IllegalStateException.
  2. Build the ToolContext from the prompt and the extracted AssistantMessage.
     • If the options are ToolCallingChatOptions with a non-empty toolContext, it holds those context entries plus the conversation history up to and including the AssistantMessage, stored under ToolContext.TOOL_CALL_HISTORY.
     • Otherwise it is empty.
  3. Execute the tool calls: the private method executeToolCall invokes the tools one after another.
     1. Get the request-scoped ToolCallback list: if the Prompt options are ToolCallingChatOptions, take their toolCallbacks; otherwise use an empty list.
     2. Iterate over every ToolCall in the AssistantMessage:
        1. Find the matching ToolCallback: look it up by name in toolCallbacks first, then fall back to toolCallbackResolver. If neither finds it, throw IllegalStateException.
        2. Handle returnDirect: the first tool's returnDirect is taken as-is and each later one is ANDed with it. The result is therefore true only if every called tool requests returnDirect.
        3. Build the observation context: a ToolCallingObservationContext built from the tool definition, metadata and call arguments, used for observation.
        4. Execute and observe: observation.observe runs the callback (toolCallback.call). A ToolExecutionException is turned into a result by the exception processor, and the result is written to the observation context.
        5. Collect the response: wrap each result (an empty string if it is null) in a ToolResponseMessage.ToolResponse and add it to the response list.
     3. Return: wrap all responses in one ToolResponseMessage and return it together with the returnDirect flag.
  4. Build the new conversation history: append the AssistantMessage and the ToolResponseMessage to the original prompt messages.
  5. Return a ToolExecutionResult holding the new conversation history and the returnDirect flag.
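The per-call loop described above reduces to a small piece of logic:

- request-scoped callbacks are looked up first, then the resolver;
- returnDirect is AND-aggregated across the called tools;
- a null tool result is turned into an empty string.

The sketch keeps only that logic. `Tool`, `ToolCall`, `Response`, `Result` and `executeAll` are hypothetical stand-ins, and observation and exception processing are left out. The original never runs this loop with an empty call list; the sketch maps that case to false instead of failing on a null returnDirect.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

class ExecuteToolCallSketch {

    record Tool(String name, boolean returnDirect, Function<String, String> body) { }
    record ToolCall(String id, String name, String arguments) { }
    record Response(String id, String name, String data) { }
    record Result(List<Response> responses, boolean returnDirect) { }

    static Result executeAll(List<ToolCall> calls, List<Tool> requestTools, Function<String, Tool> resolver) {
        List<Response> responses = new ArrayList<>();
        Boolean returnDirect = null; // null until the first tool has been seen
        for (ToolCall call : calls) {
            // Request-scoped tools win; the resolver is only a fallback.
            Tool tool = requestTools.stream()
                    .filter(t -> t.name().equals(call.name()))
                    .findFirst()
                    .orElseGet(() -> resolver.apply(call.name()));
            if (tool == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + call.name());
            }
            // Ends up true only if every called tool has returnDirect = true.
            returnDirect = (returnDirect == null) ? tool.returnDirect() : returnDirect && tool.returnDirect();
            String output = tool.body().apply(call.arguments());
            responses.add(new Response(call.id(), call.name(), output != null ? output : ""));
        }
        return new Result(responses, Boolean.TRUE.equals(returnDirect));
    }

    public static void main(String[] args) {
        Tool weather = new Tool("weather", true, in -> "sunny in " + in);
        Tool time = new Tool("time", false, in -> null);
        Result result = executeAll(
                List.of(new ToolCall("1", "weather", "Paris"), new ToolCall("2", "time", "{}")),
                List.of(weather), name -> name.equals("time") ? time : null);
        System.out.println(result.returnDirect()); // false: "time" does not request returnDirect
        System.out.println(result.responses());
    }
}
```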

java
package org.springframework.ai.model.tool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.observation.DefaultToolCallingObservationConvention;
import org.springframework.ai.tool.observation.ToolCallingObservationContext;
import org.springframework.ai.tool.observation.ToolCallingObservationConvention;
import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

public final class DefaultToolCallingManager implements ToolCallingManager {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);

    private static final ObservationRegistry DEFAULTOBSERVATIONREGISTRY
          = ObservationRegistry.NOOP;

    private static final ToolCallingObservationConvention DEFAULTOBSERVATIONCONVENTION
          = new DefaultToolCallingObservationConvention();

    private static final ToolCallbackResolver DEFAULTTOOLCALLBACKRESOLVER
          = new DelegatingToolCallbackResolver(List.of());

    private static final ToolExecutionExceptionProcessor DEFAULTTOOLEXECUTIONEXCEPTIONPROCESSOR
          = DefaultToolExecutionExceptionProcessor.builder().build();

    private final ObservationRegistry observationRegistry;

    private final ToolCallbackResolver toolCallbackResolver;

    private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

    private ToolCallingObservationConvention observationConvention = DEFAULTOBSERVATIONCONVENTION;

    public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
          ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
       Assert.notNull(observationRegistry, "observationRegistry cannot be null");
       Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
       Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

       this.observationRegistry = observationRegistry;
       this.toolCallbackResolver = toolCallbackResolver;
       this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
    }

    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
       Assert.notNull(chatOptions, "chatOptions cannot be null");

       List<ToolCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
       for (String toolName : chatOptions.getToolNames()) {
          // Skip the tool if it is already present in the request toolCallbacks.
          // That might happen if a tool is defined in the options
          // both as a ToolCallback and as a tool name.
          if (chatOptions.getToolCallbacks()
             .stream()
             .anyMatch(tool -> tool.getToolDefinition().name().equals(toolName))) {
             continue;
          }
          ToolCallback toolCallback = this.toolCallbackResolver.resolve(toolName);
          if (toolCallback == null) {
             throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
          }
          toolCallbacks.add(toolCallback);
       }

       return toolCallbacks.stream().map(ToolCallback::getToolDefinition).toList();
    }

    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
       Assert.notNull(prompt, "prompt cannot be null");
       Assert.notNull(chatResponse, "chatResponse cannot be null");

       Optional<Generation> toolCallGeneration = chatResponse.getResults()
          .stream()
          .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
          .findFirst();

       if (toolCallGeneration.isEmpty()) {
          throw new IllegalStateException("No tool call requested by the chat model");
       }

       AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();

       ToolContext toolContext = buildToolContext(prompt, assistantMessage);

       InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
             toolContext);

       List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
             assistantMessage, internalToolExecutionResult.toolResponseMessage());

       return ToolExecutionResult.builder()
          .conversationHistory(conversationHistory)
          .returnDirect(internalToolExecutionResult.returnDirect())
          .build();
    }

    private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) {
       Map<String, Object> toolContextMap = Map.of();

       if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions
             && !CollectionUtils.isEmpty(toolCallingChatOptions.getToolContext())) {
          toolContextMap = new HashMap<>(toolCallingChatOptions.getToolContext());

          List<Message> messageHistory = new ArrayList<>(prompt.copy().getInstructions());
          messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
                assistantMessage.getToolCalls()));

          toolContextMap.put(ToolContext.TOOL_CALL_HISTORY,
                buildConversationHistoryBeforeToolExecution(prompt, assistantMessage));
       }

       return new ToolContext(toolContextMap);
    }

    private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt prompt,
          AssistantMessage assistantMessage) {
       List<Message> messageHistory = new ArrayList<>(prompt.copy().getInstructions());
       messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
             assistantMessage.getToolCalls()));
       return messageHistory;
    }

    /**
     * Execute the tool call and return the response message.
     */
    private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
          ToolContext toolContext) {
       List<ToolCallback> toolCallbacks = List.of();
       if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
          toolCallbacks = toolCallingChatOptions.getToolCallbacks();
       }

       List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

       Boolean returnDirect = null;

       for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

          logger.debug("Executing tool call: {}", toolCall.name());

          String toolName = toolCall.name();
          String toolInputArguments = toolCall.arguments();

          ToolCallback toolCallback = toolCallbacks.stream()
             .filter(tool -> toolName.equals(tool.getToolDefinition().name()))
             .findFirst()
             .orElseGet(() -> this.toolCallbackResolver.resolve(toolName));

          if (toolCallback == null) {
             throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
          }

          if (returnDirect == null) {
             returnDirect = toolCallback.getToolMetadata().returnDirect();
          }
          else {
             returnDirect = returnDirect && toolCallback.getToolMetadata().returnDirect();
          }

          ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
             .toolDefinition(toolCallback.getToolDefinition())
             .toolMetadata(toolCallback.getToolMetadata())
             .toolCallArguments(toolInputArguments)
             .build();

          String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
             .observation(this.observationConvention, DEFAULTOBSERVATIONCONVENTION, () -> observationContext,
                   this.observationRegistry)
             .observe(() -> {
                String toolResult;
                try {
                   toolResult = toolCallback.call(toolInputArguments, toolContext);
                }
                catch (ToolExecutionException ex) {
                   toolResult = this.toolExecutionExceptionProcessor.process(ex);
                }
                observationContext.setToolCallResult(toolResult);
                return toolResult;
             });

          toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
                toolCallResult != null ? toolCallResult : ""));
       }

       return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
    }

    private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
          AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
       List<Message> messages = new ArrayList<>(previousMessages);
       messages.add(assistantMessage);
       messages.add(toolResponseMessage);
       return messages;
    }

    public void setObservationConvention(ToolCallingObservationConvention observationConvention) {
       this.observationConvention = observationConvention;
    }

    public static Builder builder() {
       return new Builder();
    }

    private record InternalToolExecutionResult(ToolResponseMessage toolResponseMessage, boolean returnDirect) {
    }

    public final static class Builder {

       private ObservationRegistry observationRegistry = DEFAULTOBSERVATIONREGISTRY;

       private ToolCallbackResolver toolCallbackResolver = DEFAULTTOOLCALLBACKRESOLVER;

       private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULTTOOLEXECUTIONEXCEPTIONPROCESSOR;

       private Builder() {
       }

       public Builder observationRegistry(ObservationRegistry observationRegistry) {
          this.observationRegistry = observationRegistry;
          return this;
       }

       public Builder toolCallbackResolver(ToolCallbackResolver toolCallbackResolver) {
          this.toolCallbackResolver = toolCallbackResolver;
          return this;
       }

       public Builder toolExecutionExceptionProcessor(
             ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
          this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
          return this;
       }

       public DefaultToolCallingManager build() {
          return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver,
                this.toolExecutionExceptionProcessor);
       }

    }

}
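resolveToolDefinitions merges two sources. ToolCallbacks passed directly in the options come first. A name from getToolNames() is resolved only if no callback with that name was passed in, and a name that cannot be resolved is an error.

The sketch below uses hypothetical string-based stand-ins: `resolveDefinitions` takes callback names and tool names, and a `Function<String, String>` resolver returns null for unknown names. Like the original, it checks names against the callbacks from the options only, not against names resolved earlier in the same loop.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

class ResolveDefinitionsSketch {

    // callbackNames: names of the ToolCallbacks in the options (stand-in for getToolCallbacks())
    // toolNames: names from getToolNames(); resolver returns null for an unknown name
    static List<String> resolveDefinitions(List<String> callbackNames, List<String> toolNames,
            Function<String, String> resolver) {
        List<String> result = new ArrayList<>(callbackNames);
        for (String name : toolNames) {
            if (callbackNames.contains(name)) {
                continue; // already supplied as a callback: skip, as in the original
            }
            String resolved = resolver.apply(name);
            if (resolved == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + name);
            }
            result.add(resolved);
        }
        return result;
    }

    public static void main(String[] args) {
        Function<String, String> resolver = name -> name.equals("time") ? "time" : null;
        System.out.println(resolveDefinitions(List.of("weather"), List.of("weather", "time"), resolver));
        // [weather, time]
    }
}
```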

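ToolExecutionResult (Tool Execution Result)

This interface represents the result of one round of tool execution. It bundles the conversation history after execution with the return strategy.

  • String FINISH_REASON: the finish reason ("returnDirect") set on generations built from tool results
  • String METADATA_TOOL_ID: the metadata key under which the tool call ID is stored
  • String METADATA_TOOL_NAME: the metadata key under which the tool name is stored

| Method | Description |
|---|---|
| conversationHistory | Returns the full conversation history including the tool execution results, for further context handling or for returning to the client |
| returnDirect | Whether the tool result should be returned directly to the user (true) or passed back to the model for further processing (false). The default implementation returns false |
| builder | Returns a builder for the default implementation, for fluent construction of ToolExecutionResult instances |
| buildGenerations | Converts the tool execution result into a list of Generations, for sending directly to the client or for further processing. If the last message of the history is a ToolResponseMessage, one Generation is created per tool response in it |

java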
package org.springframework.ai.model.tool;

import java.util.ArrayList;
import java.util.List;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.Generation;

public interface ToolExecutionResult {

    String FINISH_REASON = "returnDirect";

    String METADATA_TOOL_ID = "toolId";

    String METADATA_TOOL_NAME = "toolName";

    List<Message> conversationHistory();

    default boolean returnDirect() {
       return false;
    }

    static DefaultToolExecutionResult.Builder builder() {
       return DefaultToolExecutionResult.builder();
    }

    static List<Generation> buildGenerations(ToolExecutionResult toolExecutionResult) {
       List<Message> conversationHistory = toolExecutionResult.conversationHistory();
       List<Generation> generations = new ArrayList<>();
       if (conversationHistory
          .get(conversationHistory.size() - 1) instanceof ToolResponseMessage toolResponseMessage) {
          toolResponseMessage.getResponses().forEach(response -> {
             AssistantMessage assistantMessage = new AssistantMessage(response.responseData());
             Generation generation = new Generation(assistantMessage,
                   ChatGenerationMetadata.builder()
                      .metadata(METADATA_TOOL_ID, response.id())
                      .metadata(METADATA_TOOL_NAME, response.name())
                      .finishReason(FINISH_REASON)
                      .build());
             generations.add(generation);
          });
       }
       return generations;
    }

}
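To make the buildGenerations logic concrete, here is a plain-Java sketch that mirrors it without depending on Spring AI. The types `Msg`, `UserMsg`, `ToolResponse`, `ToolResponseMsg` and `Gen` are hypothetical stand-ins for Message, ToolResponseMessage.ToolResponse, ToolResponseMessage and Generation; only the control flow (inspect the last message, emit one generation per tool response, tag it with tool id, tool name and the "returnDirect" finish reason) follows the source above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

class BuildGenerationsSketch {

    public static void main(String[] args) {
        List<Msg> history = List.of(
                new UserMsg("What's the weather in Hangzhou?"),
                new ToolResponseMsg(List.of(new ToolResponse("call_1", "getWeather", "\"sunny\""))));
        buildGenerations(history).forEach(System.out::println);
    }

    // Mirrors ToolExecutionResult.buildGenerations: only the LAST message in the history is inspected.
    static List<Gen> buildGenerations(List<Msg> conversationHistory) {
        List<Gen> generations = new ArrayList<>();
        if (conversationHistory.get(conversationHistory.size() - 1) instanceof ToolResponseMsg toolResponseMsg) {
            for (ToolResponse response : toolResponseMsg.responses()) {
                // One generation per tool response, tagged with the tool id/name and the finish reason.
                generations.add(new Gen(response.responseData(),
                        Map.of("toolId", response.id(), "toolName", response.name()),
                        "returnDirect"));
            }
        }
        return generations;
    }
}

// Hypothetical stand-ins for Message, ToolResponseMessage(.ToolResponse) and Generation.
interface Msg {
}

record UserMsg(String text) implements Msg {
}

record ToolResponse(String id, String name, String responseData) {
}

record ToolResponseMsg(List<ToolResponse> responses) implements Msg {
}

record Gen(String text, Map<String, String> metadata, String finishReason) {
}
```

If the last message is not a tool response, the sketch (like the source) simply returns an empty list.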
DefaultToolExecutionResult
package org.springframework.ai.model.tool;

import java.util.List;

import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

public record DefaultToolExecutionResult(List<Message> conversationHistory,
       boolean returnDirect) implements ToolExecutionResult {

    public DefaultToolExecutionResult {
       Assert.notNull(conversationHistory, "conversationHistory cannot be null");
       Assert.noNullElements(conversationHistory, "conversationHistory cannot contain null elements");
    }

    public static Builder builder() {
       return new Builder();
    }

    public static final class Builder {

       private List<Message> conversationHistory = List.of();

       private boolean returnDirect;

       private Builder() {
       }

       public Builder conversationHistory(List<Message> conversationHistory) {
          this.conversationHistory = conversationHistory;
          return this;
       }

       public Builder returnDirect(boolean returnDirect) {
          this.returnDirect = returnDirect;
          return this;
       }

       public DefaultToolExecutionResult build() {
          return new DefaultToolExecutionResult(this.conversationHistory, this.returnDirect);
       }

    }

}
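The returnDirect flag decides what happens after the tools have run. The sketch below shows that branching in plain Java: when returnDirect is true, the last tool output goes straight back to the caller; otherwise the whole conversation history (now including the tool result) is sent to the model for another round. `ToolResultSketch`, the `callModel` function and the string-based messages are hypothetical simplifications; Spring AI chat model implementations (for example OpenAiChatModel) contain a comparable branch written against ToolExecutionResult and Prompt.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

class ReturnDirectSketch {

    public static void main(String[] args) {
        // A fake "model" that only reports how many messages it received.
        Function<List<String>, String> fakeModel = history -> "model saw " + history.size() + " messages";
        List<String> history = List.of("user: weather?", "assistant: <tool call>", "tool: sunny");

        System.out.println(next(new ToolResultSketch(history, true), fakeModel));
        System.out.println(next(new ToolResultSketch(history, false), fakeModel));
    }

    // returnDirect == true  -> hand the last tool output back to the caller as-is
    // returnDirect == false -> send the full history back to the model for another round
    static String next(ToolResultSketch result, Function<List<String>, String> callModel) {
        List<String> history = result.conversationHistory();
        if (result.returnDirect()) {
            return history.get(history.size() - 1);
        }
        return callModel.apply(new ArrayList<>(history));
    }
}

// Hypothetical simplification of DefaultToolExecutionResult (messages are plain strings here).
record ToolResultSketch(List<String> conversationHistory, boolean returnDirect) {
    ToolResultSketch {
        // List.copyOf rejects a null list or null elements, similar in spirit to the Assert checks above.
        conversationHistory = List.copyOf(conversationHistory);
    }
}
```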

ToolCallingChatOptions (tool calling chat options)

This interface configures the tool-calling options used when interacting with a ChatModel.

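| Method | Description |
|---|---|
| getToolCallbacks / setToolCallbacks | Get / set the registered ToolCallbacks |
| getToolNames / setToolNames | Get / set the names of the tools exposed to the model (resolved to callbacks by name) |
| getInternalToolExecutionEnabled / setInternalToolExecutionEnabled | Get / set whether tool calls are executed internally by the framework; null means "use the default" (true) |
| isInternalToolExecutionEnabled | Static helper that evaluates the flag above for any ChatOptions, falling back to the default when it is unset |
| getToolContext / setToolContext | Get / set the context parameters passed to tools at call time |
| mergeToolCallbacks | Choose between runtime and default tool callbacks: the runtime list is used if it is non-empty, otherwise the defaults |
| mergeToolNames | Same rule as mergeToolCallbacks, applied to tool names |
| mergeToolContext | Merge runtime and default tool context; runtime entries override default entries with the same key |
| validateToolCallbacks | Check the tool callbacks for duplicate tool names and throw IllegalStateException if any are found |
| builder | Get a builder for fluent configuration of all options |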

package org.springframework.ai.model.tool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

public interface ToolCallingChatOptions extends ChatOptions {

    boolean DEFAULTTOOLEXECUTIONENABLED = true;
    
    List<ToolCallback> getToolCallbacks();
    
    void setToolCallbacks(List<ToolCallback> toolCallbacks);

    Set<String> getToolNames();

    void setToolNames(Set<String> toolNames);

    @Nullable
    Boolean getInternalToolExecutionEnabled();

    void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled);

    Map<String, Object> getToolContext();
    
    void setToolContext(Map<String, Object> toolContext);

    static Builder builder() {
       return new DefaultToolCallingChatOptions.Builder();
    }

    static boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) {
       Assert.notNull(chatOptions, "chatOptions cannot be null");
       boolean internalToolExecutionEnabled;
       if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions
             && toolCallingChatOptions.getInternalToolExecutionEnabled() != null) {
          internalToolExecutionEnabled = Boolean.TRUE
             .equals(toolCallingChatOptions.getInternalToolExecutionEnabled());
       }
       else {
          internalToolExecutionEnabled = DEFAULTTOOLEXECUTIONENABLED;
       }
       return internalToolExecutionEnabled;
    }

    static Set<String> mergeToolNames(Set<String> runtimeToolNames, Set<String> defaultToolNames) {
       Assert.notNull(runtimeToolNames, "runtimeToolNames cannot be null");
       Assert.notNull(defaultToolNames, "defaultToolNames cannot be null");
       if (CollectionUtils.isEmpty(runtimeToolNames)) {
          return new HashSet<>(defaultToolNames);
       }
       return new HashSet<>(runtimeToolNames);
    }

    static List<ToolCallback> mergeToolCallbacks(List<ToolCallback> runtimeToolCallbacks,
          List<ToolCallback> defaultToolCallbacks) {
       Assert.notNull(runtimeToolCallbacks, "runtimeToolCallbacks cannot be null");
       Assert.notNull(defaultToolCallbacks, "defaultToolCallbacks cannot be null");
       if (CollectionUtils.isEmpty(runtimeToolCallbacks)) {
          return new ArrayList<>(defaultToolCallbacks);
       }
       return new ArrayList<>(runtimeToolCallbacks);
    }

    static Map<String, Object> mergeToolContext(Map<String, Object> runtimeToolContext,
          Map<String, Object> defaultToolContext) {
       Assert.notNull(runtimeToolContext, "runtimeToolContext cannot be null");
       Assert.noNullElements(runtimeToolContext.keySet(), "runtimeToolContext keys cannot be null");
       Assert.notNull(defaultToolContext, "defaultToolContext cannot be null");
       Assert.noNullElements(defaultToolContext.keySet(), "defaultToolContext keys cannot be null");
       var mergedToolContext = new HashMap<>(defaultToolContext);
       mergedToolContext.putAll(runtimeToolContext);
       return mergedToolContext;
    }

    static void validateToolCallbacks(List<ToolCallback> toolCallbacks) {
       List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
       if (!duplicateToolNames.isEmpty()) {
          throw new IllegalStateException("Multiple tools with the same name (%s) found in ToolCallingChatOptions"
             .formatted(String.join(", ", duplicateToolNames)));
       }
    }

    interface Builder extends ChatOptions.Builder {

       Builder toolCallbacks(List<ToolCallback> toolCallbacks);

       Builder toolCallbacks(ToolCallback... toolCallbacks);

       Builder toolNames(Set<String> toolNames);

       Builder toolNames(String... toolNames);

       Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled);

       Builder toolContext(Map<String, Object> context);

       Builder toolContext(String key, Object value);

       // ChatOptions.Builder methods

       @Override
       Builder model(@Nullable String model);

       @Override
       Builder frequencyPenalty(@Nullable Double frequencyPenalty);

       @Override
       Builder maxTokens(@Nullable Integer maxTokens);

       @Override
       Builder presencePenalty(@Nullable Double presencePenalty);

       @Override
       Builder stopSequences(@Nullable List<String> stopSequences);

       @Override
       Builder temperature(@Nullable Double temperature);

       @Override
       Builder topK(@Nullable Integer topK);

       @Override
       Builder topP(@Nullable Double topP);

       @Override
       ToolCallingChatOptions build();

    }

}
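The three merge helpers do not behave the same way: for tool callbacks and tool names a non-empty runtime value replaces the defaults entirely, while the tool context is a true merge in which runtime entries override default entries with the same key. The plain-Java sketch below reproduces those rules, plus the duplicate-name check of validateToolCallbacks, on strings and maps so they can be run without Spring AI; the class and method names are hypothetical.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class MergeRulesSketch {

    public static void main(String[] args) {
        System.out.println(mergeToolNames(Set.of(), Set.of("getWeather")));          // defaults are used
        System.out.println(mergeToolNames(Set.of("getTime"), Set.of("getWeather"))); // runtime replaces defaults
        System.out.println(mergeToolContext(Map.of("user", "bob"), Map.of("user", "alice", "tenant", "t1")));
    }

    // Same rule as mergeToolNames / mergeToolCallbacks: runtime wins unless it is empty.
    static Set<String> mergeToolNames(Set<String> runtime, Set<String> defaults) {
        return runtime.isEmpty() ? new HashSet<>(defaults) : new HashSet<>(runtime);
    }

    // Same rule as mergeToolContext: start from the defaults, then let runtime entries override.
    static Map<String, Object> mergeToolContext(Map<String, Object> runtime, Map<String, Object> defaults) {
        Map<String, Object> merged = new HashMap<>(defaults);
        merged.putAll(runtime);
        return merged;
    }

    // Analogue of validateToolCallbacks: reject duplicate tool names.
    static void validateToolNames(List<String> toolNames) {
        Set<String> seen = new HashSet<>();
        Set<String> duplicates = new LinkedHashSet<>();
        for (String name : toolNames) {
            if (!seen.add(name)) {
                duplicates.add(name);
            }
        }
        if (!duplicates.isEmpty()) {
            throw new IllegalStateException("Multiple tools with the same name (%s) found"
                    .formatted(String.join(", ", duplicates)));
        }
    }
}
```

In practice this means per-request tools do not add to the default tools; they replace them, whereas per-request context only adds or overrides keys.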
DefaultToolCallingChatOptions

Default implementation of the ToolCallingChatOptions interface.

package org.springframework.ai.model.tool;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {

    private List<ToolCallback> toolCallbacks = new ArrayList<>();

    private Set<String> toolNames = new HashSet<>();

    private Map<String, Object> toolContext = new HashMap<>();

    @Nullable
    private Boolean internalToolExecutionEnabled;

    @Nullable
    private String model;

    @Nullable
    private Double frequencyPenalty;

    @Nullable
    private Integer maxTokens;

    @Nullable
    private Double presencePenalty;

    @Nullable
    private List<String> stopSequences;

    @Nullable
    private Double temperature;

    @Nullable
    private Integer topK;

    @Nullable
    private Double topP;

    @Override
    public List<ToolCallback> getToolCallbacks() {
       return List.copyOf(this.toolCallbacks);
    }

    @Override
    public void setToolCallbacks(List<ToolCallback> toolCallbacks) {
       Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
       Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
       this.toolCallbacks = new ArrayList<>(toolCallbacks);
    }

    @Override
    public Set<String> getToolNames() {
       return Set.copyOf(this.toolNames);
    }

    @Override
    public void setToolNames(Set<String> toolNames) {
       Assert.notNull(toolNames, "toolNames cannot be null");
       Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
       toolNames.forEach(toolName -> Assert.hasText(toolName, "toolNames cannot contain empty elements"));
       this.toolNames = new HashSet<>(toolNames);
    }

    @Override
    public Map<String, Object> getToolContext() {
       return Map.copyOf(this.toolContext);
    }

    @Override
    public void setToolContext(Map<String, Object> toolContext) {
       Assert.notNull(toolContext, "toolContext cannot be null");
       Assert.noNullElements(toolContext.keySet(), "toolContext cannot contain null keys");
       this.toolContext = new HashMap<>(toolContext);
    }

    @Override
    @Nullable
    public Boolean getInternalToolExecutionEnabled() {
       return this.internalToolExecutionEnabled;
    }

    @Override
    public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
       this.internalToolExecutionEnabled = internalToolExecutionEnabled;
    }

    @Override
    @Nullable
    public String getModel() {
       return this.model;
    }

    public void setModel(@Nullable String model) {
       this.model = model;
    }

    @Override
    @Nullable
    public Double getFrequencyPenalty() {
       return this.frequencyPenalty;
    }

    public void setFrequencyPenalty(@Nullable Double frequencyPenalty) {
       this.frequencyPenalty = frequencyPenalty;
    }

    @Override
    @Nullable
    public Integer getMaxTokens() {
       return this.maxTokens;
    }

    public void setMaxTokens(@Nullable Integer maxTokens) {
       this.maxTokens = maxTokens;
    }

    @Override
    @Nullable
    public Double getPresencePenalty() {
       return this.presencePenalty;
    }

    public void setPresencePenalty(@Nullable Double presencePenalty) {
       this.presencePenalty = presencePenalty;
    }

    @Override
    @Nullable
    public List<String> getStopSequences() {
       return this.stopSequences;
    }

    public void setStopSequences(@Nullable List<String> stopSequences) {
       this.stopSequences = stopSequences;
    }

    @Override
    @Nullable
    public Double getTemperature() {
       return this.temperature;
    }

    public void setTemperature(@Nullable Double temperature) {
       this.temperature = temperature;
    }

    @Override
    @Nullable
    public Integer getTopK() {
       return this.topK;
    }

    public void setTopK(@Nullable Integer topK) {
       this.topK = topK;
    }

    @Override
    @Nullable
    public Double getTopP() {
       return this.topP;
    }

    public void setTopP(@Nullable Double topP) {
       this.topP = topP;
    }

    @Override
    @SuppressWarnings("unchecked")
    public <T extends ChatOptions> T copy() {
       DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions();
       options.setToolCallbacks(getToolCallbacks());
       options.setToolNames(getToolNames());
       options.setToolContext(getToolContext());
       options.setInternalToolExecutionEnabled(getInternalToolExecutionEnabled());
       options.setModel(getModel());
       options.setFrequencyPenalty(getFrequencyPenalty());
       options.setMaxTokens(getMaxTokens());
       options.setPresencePenalty(getPresencePenalty());
       options.setStopSequences(getStopSequences());
       options.setTemperature(getTemperature());
       options.setTopK(getTopK());
       options.setTopP(getTopP());
       return (T) options;
    }

    public static Builder builder() {
       return new Builder();
    }

    /**
     * Default implementation of {@link ToolCallingChatOptions.Builder}.
     */
    public static class Builder implements ToolCallingChatOptions.Builder {

       private final DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions();

       @Override
       public ToolCallingChatOptions.Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
          this.options.setToolCallbacks(toolCallbacks);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder toolCallbacks(ToolCallback... toolCallbacks) {
          Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
          this.options.setToolCallbacks(Arrays.asList(toolCallbacks));
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder toolNames(Set<String> toolNames) {
          this.options.setToolNames(toolNames);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder toolNames(String... toolNames) {
          Assert.notNull(toolNames, "toolNames cannot be null");
          this.options.setToolNames(Set.of(toolNames));
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder toolContext(Map<String, Object> context) {
          this.options.setToolContext(context);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder toolContext(String key, Object value) {
          Assert.hasText(key, "key cannot be null");
          Assert.notNull(value, "value cannot be null");
          Map<String, Object> updatedToolContext = new HashMap<>(this.options.getToolContext());
          updatedToolContext.put(key, value);
          this.options.setToolContext(updatedToolContext);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder internalToolExecutionEnabled(
             @Nullable Boolean internalToolExecutionEnabled) {
          this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder model(@Nullable String model) {
          this.options.setModel(model);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder frequencyPenalty(@Nullable Double frequencyPenalty) {
          this.options.setFrequencyPenalty(frequencyPenalty);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder maxTokens(@Nullable Integer maxTokens) {
          this.options.setMaxTokens(maxTokens);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder presencePenalty(@Nullable Double presencePenalty) {
          this.options.setPresencePenalty(presencePenalty);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder stopSequences(@Nullable List<String> stopSequences) {
          this.options.setStopSequences(stopSequences);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder temperature(@Nullable Double temperature) {
          this.options.setTemperature(temperature);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder topK(@Nullable Integer topK) {
          this.options.setTopK(topK);
          return this;
       }

       @Override
       public ToolCallingChatOptions.Builder topP(@Nullable Double topP) {
          this.options.setTopP(topP);
          return this;
       }

       @Override
       public ToolCallingChatOptions build() {
          return this.options;
       }

    }

}
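A notable detail of DefaultToolCallingChatOptions is its copy-in / copy-out style: setters store a fresh copy, getters return List.copyOf / Set.copyOf / Map.copyOf snapshots, and the builder's toolContext(key, value) rebuilds the whole map on every call. The sketch below reproduces that pattern with a hypothetical `OptionsSketch` class to show that callers cannot mutate the options through a returned collection.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

class DefensiveCopySketch {

    public static void main(String[] args) {
        OptionsSketch options = new OptionsSketch();
        options.setToolNames(Set.of("getWeather"));
        options.putToolContext("tenantId", "t1");
        options.putToolContext("userId", "u1");
        try {
            options.getToolNames().add("getTime"); // the getter hands out an unmodifiable copy
        } catch (UnsupportedOperationException e) {
            System.out.println("options cannot be mutated through a getter");
        }
        System.out.println(options.getToolNames());           // prints [getWeather]
        System.out.println(options.getToolContext().size());  // prints 2
    }
}

// Hypothetical class following the copy-in / copy-out pattern of DefaultToolCallingChatOptions.
class OptionsSketch {

    private Set<String> toolNames = new HashSet<>();

    private Map<String, Object> toolContext = new HashMap<>();

    Set<String> getToolNames() {
        return Set.copyOf(this.toolNames);          // copy-out: unmodifiable snapshot
    }

    void setToolNames(Set<String> toolNames) {
        this.toolNames = new HashSet<>(toolNames);  // copy-in: the caller's set is not retained
    }

    Map<String, Object> getToolContext() {
        return Map.copyOf(this.toolContext);        // Map.copyOf also rejects null keys and values
    }

    // Mirrors Builder.toolContext(key, value): copy the current map, add one entry, set it back.
    void putToolContext(String key, Object value) {
        Map<String, Object> updated = new HashMap<>(getToolContext());
        updated.put(key, value);
        this.toolContext = new HashMap<>(updated);
    }
}
```

Because Map.copyOf rejects null values, the builder's Assert.notNull(value, ...) check is consistent with the getter's behavior.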

ToolExecutionEligibilityPredicate (tool execution eligibility check)

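Extends BiPredicate<ChatOptions, ChatResponse> and decides, based on the chat options and the model response, whether tool calls should be executed.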

package org.springframework.ai.model.tool;

import java.util.function.BiPredicate;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.util.Assert;

public interface ToolExecutionEligibilityPredicate extends BiPredicate<ChatOptions, ChatResponse> {

    default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse) {
       Assert.notNull(promptOptions, "promptOptions cannot be null");
       Assert.notNull(chatResponse, "chatResponse cannot be null");
       return test(promptOptions, chatResponse);
    }

}
DefaultToolExecutionEligibilityPredicate

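The default implementation of ToolExecutionEligibilityPredicate: tool execution is required only when internal tool execution is enabled in promptOptions and chatResponse is non-null and contains tool calls.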

package org.springframework.ai.model.tool;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;

public class DefaultToolExecutionEligibilityPredicate implements ToolExecutionEligibilityPredicate {

    @Override
    public boolean test(ChatOptions promptOptions, ChatResponse chatResponse) {
       return ToolCallingChatOptions.isInternalToolExecutionEnabled(promptOptions) && chatResponse != null
             && chatResponse.hasToolCalls();
    }

}
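Putting isInternalToolExecutionEnabled and DefaultToolExecutionEligibilityPredicate together: a null flag falls back to the default (true), and tools are executed only when that flag is on and the response actually contains tool calls. The sketch below models options and response as two hypothetical records (`OptionsStub`, `ResponseStub`) and implements the check as a BiPredicate, just like the interface above.

```java
import java.util.List;
import java.util.function.BiPredicate;

class EligibilitySketch {

    static final boolean DEFAULT_TOOL_EXECUTION_ENABLED = true;

    // Same fallback rule as ToolCallingChatOptions.isInternalToolExecutionEnabled.
    static boolean isInternalToolExecutionEnabled(OptionsStub options) {
        Boolean flag = options.internalToolExecutionEnabled();
        return flag != null ? flag : DEFAULT_TOOL_EXECUTION_ENABLED;
    }

    // Same condition as DefaultToolExecutionEligibilityPredicate.test.
    static final BiPredicate<OptionsStub, ResponseStub> ELIGIBLE =
            (options, response) -> isInternalToolExecutionEnabled(options)
                    && response != null && !response.toolCalls().isEmpty();

    public static void main(String[] args) {
        ResponseStub withCalls = new ResponseStub(List.of("getWeather"));
        System.out.println(ELIGIBLE.test(new OptionsStub(null), withCalls));  // prints true (null -> default)
        System.out.println(ELIGIBLE.test(new OptionsStub(false), withCalls)); // prints false (caller executes tools)
    }
}

// Hypothetical stand-ins for ToolCallingChatOptions and ChatResponse.
record OptionsStub(Boolean internalToolExecutionEnabled) {
}

record ResponseStub(List<String> toolCalls) {
}
```

When the flag is false, the ChatModel returns the tool-call request to the caller, who can then drive execution manually, for example through ToolCallingManager.executeToolCalls.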

ToolCallResultConverter (tool result converter)

A functional interface that converts the result object returned by a tool into a String that can be sent back to the AI model.

package org.springframework.ai.tool.execution;

import java.lang.reflect.Type;

import org.springframework.lang.Nullable;

@FunctionalInterface
public interface ToolCallResultConverter {

    String convert(@Nullable Object result, @Nullable Type returnType);

}
DefaultToolCallResultConverter

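Default implementation of ToolCallResultConverter. It serializes tool results uniformly, so that return values of different types are all turned into strings the AI model can understand.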

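Core logic of convert:

  1. No return value (Void.TYPE): if the tool's declared return type is void, return the JSON string "Done" to signal that the operation completed
  2. Image (RenderedImage): if the result is an image, encode it as PNG bytes, Base64-encode them, and return JSON with mimeType and data fields; if encoding fails, an error message string is returned instead
  3. Any other type: serialize it to a JSON string with JsonParser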
package org.springframework.ai.tool.execution;

import java.awt.image.RenderedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Base64;
import java.util.Map;

import javax.imageio.ImageIO;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;

public final class DefaultToolCallResultConverter implements ToolCallResultConverter {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);

    @Override
    public String convert(@Nullable Object result, @Nullable Type returnType) {
       if (returnType == Void.TYPE) {
          logger.debug("The tool has no return type. Converting to conventional response.");
          return JsonParser.toJson("Done");
       }
       if (result instanceof RenderedImage) {
          final var buf = new ByteArrayOutputStream(1024 * 4);
          try {
             ImageIO.write((RenderedImage) result, "PNG", buf);
          }
          catch (IOException e) {
             return "Failed to convert tool result to a base64 image: " + e.getMessage();
          }
          final var imgB64 = Base64.getEncoder().encodeToString(buf.toByteArray());
          return JsonParser.toJson(Map.of("mimeType", "image/png", "data", imgB64));
       }
       else {
          logger.debug("Converting tool result to JSON.");
          return JsonParser.toJson(result);
       }
    }

}
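The three branches of convert can be exercised without Spring AI. In the sketch below, `toJson` is a deliberately tiny, hypothetical stand-in for JsonParser.toJson (which is Jackson-based and handles arbitrary objects); it only covers strings, numbers and booleans, enough to show the void and image paths. The image branch uses the JDK's own ImageIO and Base64 the same way the source does, although the JSON field order here is fixed by hand.

```java
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Base64;
import javax.imageio.ImageIO;

class ResultConverterSketch {

    public static void main(String[] args) {
        System.out.println(convert(null, void.class));   // a void tool
        System.out.println(convert(42, Integer.class));
        System.out.println(convert(new BufferedImage(1, 1, BufferedImage.TYPE_INT_RGB), RenderedImage.class));
    }

    static String convert(Object result, Type returnType) {
        if (returnType == Void.TYPE) {
            return toJson("Done");                        // void tools report a fixed "Done"
        }
        if (result instanceof RenderedImage image) {
            var buf = new ByteArrayOutputStream(1024 * 4);
            try {
                ImageIO.write(image, "PNG", buf);         // encode the image as PNG bytes
            } catch (IOException e) {
                return "Failed to convert tool result to a base64 image: " + e.getMessage();
            }
            String imgB64 = Base64.getEncoder().encodeToString(buf.toByteArray());
            return "{\"mimeType\":\"image/png\",\"data\":\"" + imgB64 + "\"}";
        }
        return toJson(result);
    }

    // Hypothetical, minimal JSON writer; NOT a replacement for JsonParser.toJson.
    static String toJson(Object value) {
        if (value == null) {
            return "null";
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        String escaped = value.toString().replace("\\", "\\\\").replace("\"", "\\\"");
        return "\"" + escaped + "\"";
    }
}
```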

ToolExecutionExceptionProcessor (tool execution exception processor)

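A functional interface for handling a ToolExecutionException thrown while a tool runs; it returns the String to send back to the model, or rethrows.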

package org.springframework.ai.tool.execution;

@FunctionalInterface
public interface ToolExecutionExceptionProcessor {

    String process(ToolExecutionException exception);

}
DefaultToolExecutionExceptionProcessor

Default implementation of ToolExecutionExceptionProcessor.

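  • boolean alwaysThrow: if true, the exception is rethrown for the upstream caller to handle; if false (the default), the exception message is returned as a string, which is typically fed back to the AI model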
package org.springframework.ai.tool.execution;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.util.Assert;

public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {

    private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);

    private static final boolean DEFAULTALWAYSTHROW = false;

    private final boolean alwaysThrow;

    public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
       this.alwaysThrow = alwaysThrow;
    }

    @Override
    public String process(ToolExecutionException exception) {
       Assert.notNull(exception, "exception cannot be null");
       if (this.alwaysThrow) {
          throw exception;
       }
       logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(),
             exception.getMessage());
       return exception.getMessage();
    }

    public static Builder builder() {
       return new Builder();
    }

    public static class Builder {

       private boolean alwaysThrow = DEFAULTALWAYSTHROW;

       public Builder alwaysThrow(boolean alwaysThrow) {
          this.alwaysThrow = alwaysThrow;
          return this;
       }

       public DefaultToolExecutionExceptionProcessor build() {
          return new DefaultToolExecutionExceptionProcessor(this.alwaysThrow);
       }

    }

}
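The alwaysThrow switch is the only behavioral knob. The plain-Java sketch below mirrors it with a hypothetical `ToolFailure` exception (standing in for ToolExecutionException, which carries a ToolDefinition rather than a plain name) and a hypothetical `FailureProcessor` interface. Because the real interface is functional, a custom policy can also be a lambda, for example one that prefixes the tool name so the model knows which tool failed.

```java
class ExceptionProcessorSketch {

    public static void main(String[] args) {
        ToolFailure failure = new ToolFailure("getWeather", "city not found");

        // Default behaviour (alwaysThrow = false): the message is returned and can be sent to the model.
        System.out.println(defaultProcessor(false).process(failure)); // prints city not found

        // Hypothetical custom policy written as a lambda.
        FailureProcessor withToolName = e -> "Tool " + e.toolName() + " failed: " + e.getMessage();
        System.out.println(withToolName.process(failure));

        try {
            defaultProcessor(true).process(failure);
        } catch (ToolFailure e) {
            System.out.println("rethrown to the caller");
        }
    }

    // Mirrors DefaultToolExecutionExceptionProcessor.process.
    static FailureProcessor defaultProcessor(boolean alwaysThrow) {
        return exception -> {
            if (alwaysThrow) {
                throw exception;
            }
            return exception.getMessage();
        };
    }
}

// Hypothetical stand-in for ToolExecutionExceptionProcessor.
@FunctionalInterface
interface FailureProcessor {
    String process(ToolFailure exception);
}

// Hypothetical stand-in for ToolExecutionException.
class ToolFailure extends RuntimeException {

    private final String toolName;

    ToolFailure(String toolName, String message) {
        super(message);
        this.toolName = toolName;
    }

    String toolName() {
        return this.toolName;
    }
}
```

Since the auto-configured toolExecutionExceptionProcessor bean (shown later) is @ConditionalOnMissingBean, declaring your own ToolExecutionExceptionProcessor bean, such as new DefaultToolExecutionExceptionProcessor(true), replaces it.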

ToolContext (tool context)

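Encapsulates and passes the context information a tool needs during function (tool) calls. The context map is wrapped with Collections.unmodifiableMap, so tools receive a read-only view and cannot modify it; note that this is a view of the map passed in, not a copy.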

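  • Map<String, Object> context: arbitrary key-value data for the tool call; when a tool context is configured, the framework also stores the conversation history in it, which getToolCallHistory() reads back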
package org.springframework.ai.chat.model;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.messages.Message;

public final class ToolContext {
    public static final String TOOL_CALL_HISTORY = "TOOL_CALL_HISTORY";

    private final Map<String, Object> context;

    public ToolContext(Map<String, Object> context) {
        this.context = Collections.unmodifiableMap(context);
    }

    public Map<String, Object> getContext() {
        return this.context;
    }

    @SuppressWarnings("unchecked")
    public List<Message> getToolCallHistory() {
        return (List<Message>) this.context.get(TOOL_CALL_HISTORY);
    }
}
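Because Collections.unmodifiableMap returns a read-only view rather than a copy, tool code cannot modify the context, but later changes to the map passed into the constructor stay visible through it. The sketch below demonstrates both facts with a hypothetical `ContextSketch` class that wraps the map the same way ToolContext does; a caller that needs a true snapshot could pass Map.copyOf(...) instead.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

class ToolContextViewSketch {

    public static void main(String[] args) {
        Map<String, Object> backing = new HashMap<>();
        backing.put("tenantId", "t1");
        ContextSketch ctx = new ContextSketch(backing);

        try {
            ctx.getContext().put("tenantId", "t2");            // tool side: writes are rejected
        } catch (UnsupportedOperationException e) {
            System.out.println("read-only for tools");
        }

        backing.put("tenantId", "t3");                          // caller side: the view reflects this
        System.out.println(ctx.getContext().get("tenantId"));  // prints t3
    }
}

// Hypothetical class wrapping the map exactly like ToolContext's constructor.
final class ContextSketch {

    private final Map<String, Object> context;

    ContextSketch(Map<String, Object> context) {
        this.context = Collections.unmodifiableMap(context);
    }

    Map<String, Object> getContext() {
        return this.context;
    }
}
```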

Tool trigger chain walkthrough

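[!TIP] Adding the tool dependency auto-configures the ToolCallingManager that ChatModel implementations need; this part walks through how tools are triggered end to end.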

pom.xml

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-autoconfigure-model-tool</artifactId>
</dependency>

ToolCallingAutoConfiguration

Auto-configures the core components for AI tool calling, simplifying how developers integrate and use tools.

Beans provided

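| Bean method | Description |
|---|---|
| toolCallbackResolver | Provides the ToolCallbackResolver bean, which resolves all available ToolCallbacks. It combines ToolCallback beans and ToolCallbackProvider contributions (static resolution) with lookup of Spring beans by name |
| toolExecutionExceptionProcessor | Provides the ToolExecutionExceptionProcessor bean, which handles exceptions thrown during tool execution |
| toolCallingManager | Provides the ToolCallingManager bean, the core manager for tool calls; it coordinates tool callback resolution, exception handling and observation registration |
| toolCallingContentObservationFilter | Provides the ToolCallingContentObservationFilter bean, which includes tool call arguments and results in observations. It is only created when spring.ai.tools.observations.include-content=true, and it logs a warning on creation because sensitive information may be exposed |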

package org.springframework.ai.model.tool.autoconfigure;

import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.observation.ToolCallingContentObservationFilter;
import org.springframework.ai.tool.observation.ToolCallingObservationConvention;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.support.GenericApplicationContext;

@AutoConfiguration
@ConditionalOnClass({ChatModel.class})
@EnableConfigurationProperties({ToolCallingProperties.class})
public class ToolCallingAutoConfiguration {
    private static final Logger logger = LoggerFactory.getLogger(ToolCallingAutoConfiguration.class);

    @Bean
    @ConditionalOnMissingBean
    ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
            List<ToolCallback> toolCallbacks, List<ToolCallbackProvider> tcbProviders) {
        // Collect ToolCallback beans plus the callbacks exposed by every ToolCallbackProvider
        List<ToolCallback> allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks);
        tcbProviders.stream()
            .map(pr -> List.of(pr.getToolCallbacks()))
            .forEach(allFunctionAndToolCallbacks::addAll);
        StaticToolCallbackResolver staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks);
        SpringBeanToolCallbackResolver springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder()
            .applicationContext(applicationContext)
            .build();
        // Static callbacks are consulted first, then Spring beans looked up by name
        return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver));
    }

    @Bean
    @ConditionalOnMissingBean
    ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() {
        return new DefaultToolExecutionExceptionProcessor(false);
    }

    @Bean
    @ConditionalOnMissingBean
    ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver,
            ToolExecutionExceptionProcessor toolExecutionExceptionProcessor,
            ObjectProvider<ObservationRegistry> observationRegistry,
            ObjectProvider<ToolCallingObservationConvention> observationConvention) {
        DefaultToolCallingManager toolCallingManager = ToolCallingManager.builder()
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .toolCallbackResolver(toolCallbackResolver)
            .toolExecutionExceptionProcessor(toolExecutionExceptionProcessor)
            .build();
        observationConvention.ifAvailable(toolCallingManager::setObservationConvention);
        return toolCallingManager;
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(
        prefix = "spring.ai.tools.observations",
        name = {"include-content"},
        havingValue = "true"
    )
    ToolCallingContentObservationFilter toolCallingContentObservationFilter() {
        logger.warn("You have enabled the inclusion of the tool call arguments and result in the observations, with the risk of exposing sensitive or private information. Please, be careful!");
        return new ToolCallingContentObservationFilter();
    }
}
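The toolCallbackResolver bean chains two resolvers: a static one built from every ToolCallback bean and ToolCallbackProvider, then a SpringBeanToolCallbackResolver that looks up beans by name. As far as we can tell from the Spring AI source, DelegatingToolCallbackResolver asks each resolver in list order and returns the first non-null hit. The sketch below reproduces that lookup order with a hypothetical `Resolver` interface and plain strings in place of ToolCallback instances.

```java
import java.util.List;
import java.util.Map;

class DelegatingResolverSketch {

    public static void main(String[] args) {
        // Stand-in for StaticToolCallbackResolver (ToolCallback beans + providers).
        Resolver staticResolver = Map.of("getWeather", "static:getWeather")::get;
        // Stand-in for SpringBeanToolCallbackResolver (beans looked up by name).
        Resolver beanResolver = Map.of("getWeather", "bean:getWeather", "getTime", "bean:getTime")::get;
        Resolver delegating = delegating(List.of(staticResolver, beanResolver));

        System.out.println(delegating.resolve("getWeather")); // the static resolver wins
        System.out.println(delegating.resolve("getTime"));    // falls through to the bean resolver
        System.out.println(delegating.resolve("unknown"));    // null when no resolver knows the tool
    }

    // First non-null result wins, in list order.
    static Resolver delegating(List<Resolver> resolvers) {
        return toolName -> {
            for (Resolver resolver : resolvers) {
                String callback = resolver.resolve(toolName);
                if (callback != null) {
                    return callback;
                }
            }
            return null;
        };
    }
}

// Hypothetical stand-in for ToolCallbackResolver; a "callback" is just a String here.
@FunctionalInterface
interface Resolver {
    String resolve(String toolName);
}
```

Because the static resolver comes first, a ToolCallback registered as a bean or via a provider shadows a Spring bean with the same name.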

ToolCallingProperties

Configuration properties class for Spring AI tool calling (prefix spring.ai.tools).

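  • boolean throwExceptionOnError: controls how exceptions during tool calls are handled; true throws the exception, false (the default) returns the error message to the AI model
  • Observations observations: observation options; includeContent (property spring.ai.tools.observations.include-content) decides whether tool call arguments and results are included in observation data. Defaults to false, because enabling it risks leaking sensitive information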
package org.springframework.ai.model.tool.autoconfigure;

import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties("spring.ai.tools")
public class ToolCallingProperties {
    public static final String CONFIG_PREFIX = "spring.ai.tools";
    private final Observations observations = new Observations();

    public static class Observations {
        private boolean includeContent = false;

        public boolean isIncludeContent() {
            return this.includeContent;
        }

        public void setIncludeContent(boolean includeContent) {
            this.includeContent = includeContent;
        }
    }
}

Client-triggered tool call chain

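Learning and discussion group

Hi, I'm 影子 (Yingzi). I have worked at 🐻, a new-energy company and 老铁, and I am now an AI R&D engineer and a Committer of the Spring AI Alibaba open-source community. I have started a discussion group: one person walks fast, but a group walks far. I also maintain a set of Feishu cloud-document notes covering systematic backend and big-data interview material, available for free via private message.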
