Preface
This article takes a look at AiServices in langchain4j.
Example
Plain version
java
public interface Assistant {
String chat(String userMessage);
}
Build
java
Assistant assistant = AiServices.create(Assistant.class, chatLanguageModel);
String resp = assistant.chat(userMessage);
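Here chatLanguageModel can be any ChatLanguageModel implementation. A minimal sketch of wiring one up, assuming the langchain4j-open-ai module is on the classpath (the class name AssistantDemo, the model name, and the environment variable are illustrative choices):
java
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.service.AiServices;

public class AssistantDemo {
    public static void main(String[] args) {
        // Build a concrete ChatLanguageModel; the options below are illustrative
        ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY")) // assumed to be set in the environment
                .modelName("gpt-4o-mini")                // hypothetical model choice
                .build();

        Assistant assistant = AiServices.create(Assistant.class, chatLanguageModel);
        System.out.println(assistant.chat("Hello"));
    }
}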
Spring Boot version
java
@AiService
public interface AssistantV2 {
@SystemMessage("You are a polite assistant")
String chat(String userMessage);
}
It can then be injected and used like any other managed bean:
java
@RestController
public class AiServiceController { // illustrative controller wrapper

    @Autowired
    AssistantV2 assistantV2;

    @GetMapping("/ai-service")
    public String aiService(@RequestParam("prompt") String prompt) {
        return assistantV2.chat(prompt);
    }
}
Source code
AiServices
dev/langchain4j/service/AiServices.java
java
public abstract class AiServices<T> {
protected static final String DEFAULT = "default";
protected final AiServiceContext context;
private boolean retrieverSet = false;
private boolean contentRetrieverSet = false;
private boolean retrievalAugmentorSet = false;
protected AiServices(AiServiceContext context) {
this.context = context;
}
/**
* Creates an AI Service (an implementation of the provided interface), that is backed by the provided chat model.
* This convenience method can be used to create simple AI Services.
* For more complex cases, please use {@link #builder}.
*
* @param aiService The class of the interface to be implemented.
* @param chatLanguageModel The chat model to be used under the hood.
* @return An instance of the provided interface, implementing all its defined methods.
*/
public static <T> T create(Class<T> aiService, ChatLanguageModel chatLanguageModel) {
return builder(aiService).chatLanguageModel(chatLanguageModel).build();
}
/**
* Creates an AI Service (an implementation of the provided interface), that is backed by the provided streaming chat model.
* This convenience method can be used to create simple AI Services.
* For more complex cases, please use {@link #builder}.
*
* @param aiService The class of the interface to be implemented.
* @param streamingChatLanguageModel The streaming chat model to be used under the hood.
* The return type of all methods should be {@link TokenStream}.
* @return An instance of the provided interface, implementing all its defined methods.
*/
public static <T> T create(Class<T> aiService, StreamingChatLanguageModel streamingChatLanguageModel) {
return builder(aiService)
.streamingChatLanguageModel(streamingChatLanguageModel)
.build();
}
/**
* Begins the construction of an AI Service.
*
* @param aiService The class of the interface to be implemented.
* @return builder
*/
public static <T> AiServices<T> builder(Class<T> aiService) {
AiServiceContext context = new AiServiceContext(aiService);
for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) {
return factory.create(context);
}
return new DefaultAiServices<>(context);
}
/**
* Configures chat model that will be used under the hood of the AI Service.
* <p>
* Either {@link ChatLanguageModel} or {@link StreamingChatLanguageModel} should be configured,
* but not both at the same time.
*
* @param chatLanguageModel Chat model that will be used under the hood of the AI Service.
* @return builder
*/
public AiServices<T> chatLanguageModel(ChatLanguageModel chatLanguageModel) {
context.chatModel = chatLanguageModel;
return this;
}
/**
* Configures streaming chat model that will be used under the hood of the AI Service.
* The methods of the AI Service must return a {@link TokenStream} type.
* <p>
* Either {@link ChatLanguageModel} or {@link StreamingChatLanguageModel} should be configured,
* but not both at the same time.
*
* @param streamingChatLanguageModel Streaming chat model that will be used under the hood of the AI Service.
* @return builder
*/
public AiServices<T> streamingChatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
context.streamingChatModel = streamingChatLanguageModel;
return this;
}
/**
* Configures the system message provider, which provides a system message to be used each time an AI service is invoked.
* <br>
* When both {@code @SystemMessage} and the system message provider are configured,
* {@code @SystemMessage} takes precedence.
*
* @param systemMessageProvider A {@link Function} that accepts a chat memory ID
* (a value of a method parameter annotated with @{@link MemoryId})
* and returns a system message to be used.
* If there is no parameter annotated with {@code @MemoryId},
* the value of memory ID is "default".
* The returned {@link String} can be either a complete system message
* or a system message template containing unresolved template variables (e.g. "{{name}}"),
* which will be resolved using the values of method parameters annotated with @{@link V}.
* @return builder
*/
public AiServices<T> systemMessageProvider(Function<Object, String> systemMessageProvider) {
context.systemMessageProvider = systemMessageProvider.andThen(Optional::ofNullable);
return this;
}
/**
* Configures the chat memory that will be used to preserve conversation history between method calls.
* <p>
* Unless a {@link ChatMemory} or {@link ChatMemoryProvider} is configured, all method calls will be independent of each other.
* In other words, the LLM will not remember the conversation from the previous method calls.
* <p>
* The same {@link ChatMemory} instance will be used for every method call.
* <p>
* If you want to have a separate {@link ChatMemory} for each user/conversation, configure {@link #chatMemoryProvider} instead.
* <p>
* Either a {@link ChatMemory} or a {@link ChatMemoryProvider} can be configured, but not both simultaneously.
*
* @param chatMemory An instance of chat memory to be used by the AI Service.
* @return builder
*/
public AiServices<T> chatMemory(ChatMemory chatMemory) {
context.chatMemories = new ConcurrentHashMap<>();
context.chatMemories.put(DEFAULT, chatMemory);
return this;
}
/**
* Configures the chat memory provider, which provides a dedicated instance of {@link ChatMemory} for each user/conversation.
* To distinguish between users/conversations, one of the method's arguments should be a memory ID (of any data type)
* annotated with {@link MemoryId}.
* For each new (previously unseen) memoryId, an instance of {@link ChatMemory} will be automatically obtained
* by invoking {@link ChatMemoryProvider#get(Object id)}.
* Example:
* <pre>
* interface Assistant {
*
* String chat(@MemoryId int memoryId, @UserMessage String message);
* }
* </pre>
* If you prefer to use the same (shared) {@link ChatMemory} for all users/conversations, configure a {@link #chatMemory} instead.
* <p>
* Either a {@link ChatMemory} or a {@link ChatMemoryProvider} can be configured, but not both simultaneously.
*
* @param chatMemoryProvider The provider of a {@link ChatMemory} for each new user/conversation.
* @return builder
*/
public AiServices<T> chatMemoryProvider(ChatMemoryProvider chatMemoryProvider) {
context.chatMemories = new ConcurrentHashMap<>();
context.chatMemoryProvider = chatMemoryProvider;
return this;
}
/**
* Configures a moderation model to be used for automatic content moderation.
* If a method in the AI Service is annotated with {@link Moderate}, the moderation model will be invoked
* to check the user content for any inappropriate or harmful material.
*
* @param moderationModel The moderation model to be used for content moderation.
* @return builder
* @see Moderate
*/
public AiServices<T> moderationModel(ModerationModel moderationModel) {
context.moderationModel = moderationModel;
return this;
}
/**
* Configures the tools that the LLM can use.
*
* @param objectsWithTools One or more objects whose methods are annotated with {@link Tool}.
* All these tools (methods annotated with {@link Tool}) will be accessible to the LLM.
* Note that inherited methods are ignored.
* @return builder
* @see Tool
*/
public AiServices<T> tools(Object... objectsWithTools) {
return tools(asList(objectsWithTools));
}
/**
* Configures the tools that the LLM can use.
*
* @param objectsWithTools A list of objects whose methods are annotated with {@link Tool}.
* All these tools (methods annotated with {@link Tool}) are accessible to the LLM.
* Note that inherited methods are ignored.
* @return builder
* @see Tool
*/
public AiServices<T> tools(Collection<Object> objectsWithTools) {
context.toolService.tools(objectsWithTools);
return this;
}
/**
* Configures the tool provider that the LLM can use
*
* @param toolProvider Decides which tools the LLM could use to handle the request
* @return builder
*/
public AiServices<T> toolProvider(ToolProvider toolProvider) {
context.toolService.toolProvider(toolProvider);
return this;
}
/**
* Configures the tools that the LLM can use.
*
* @param tools A map of {@link ToolSpecification} to {@link ToolExecutor} entries.
* This method of configuring tools is useful when tools must be configured programmatically.
* Otherwise, it is recommended to use the {@link Tool}-annotated java methods
* and configure tools with the {@link #tools(Object...)} and {@link #tools(Collection)} methods.
* @return builder
*/
public AiServices<T> tools(Map<ToolSpecification, ToolExecutor> tools) {
context.toolService.tools(tools);
return this;
}
/**
* Configures the strategy to be used when the LLM hallucinates a tool name (i.e., attempts to call a nonexistent tool).
*
* @param hallucinatedToolNameStrategy A Function from {@link ToolExecutionRequest} to {@link ToolExecutionResultMessage} defining
* the response provided to the LLM when it hallucinates a tool name.
* @return builder
*/
public AiServices<T> hallucinatedToolNameStrategy(
Function<ToolExecutionRequest, ToolExecutionResultMessage> hallucinatedToolNameStrategy) {
context.toolService.hallucinatedToolNameStrategy(hallucinatedToolNameStrategy);
return this;
}
/**
* @param retriever The retriever to be used by the AI Service.
* @return builder
* @deprecated Use {@link #contentRetriever(ContentRetriever)}
* (e.g. {@link EmbeddingStoreContentRetriever}) instead.
* <br>
* Configures a retriever that will be invoked on every method call to fetch relevant information
* related to the current user message from an underlying source (e.g., embedding store).
* This relevant information is automatically injected into the message sent to the LLM.
*/
@Deprecated(forRemoval = true)
public AiServices<T> retriever(Retriever<TextSegment> retriever) {
if (contentRetrieverSet || retrievalAugmentorSet) {
throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
}
if (retriever != null) {
AiServices<T> withContentRetriever = contentRetriever(retriever.toContentRetriever());
retrieverSet = true;
return withContentRetriever;
}
return this;
}
/**
* Configures a content retriever to be invoked on every method call for retrieving relevant content
* related to the user's message from an underlying data source
* (e.g., an embedding store in the case of an {@link EmbeddingStoreContentRetriever}).
* The retrieved relevant content is then automatically incorporated into the message sent to the LLM.
* <br>
* This method provides a straightforward approach for those who do not require
* a customized {@link RetrievalAugmentor}.
* It configures a {@link DefaultRetrievalAugmentor} with the provided {@link ContentRetriever}.
*
* @param contentRetriever The content retriever to be used by the AI Service.
* @return builder
*/
public AiServices<T> contentRetriever(ContentRetriever contentRetriever) {
if (retrieverSet || retrievalAugmentorSet) {
throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
}
contentRetrieverSet = true;
context.retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.contentRetriever(ensureNotNull(contentRetriever, "contentRetriever"))
.build();
return this;
}
/**
* Configures a retrieval augmentor to be invoked on every method call.
*
* @param retrievalAugmentor The retrieval augmentor to be used by the AI Service.
* @return builder
*/
public AiServices<T> retrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
if (retrieverSet || contentRetrieverSet) {
throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
}
retrievalAugmentorSet = true;
context.retrievalAugmentor = ensureNotNull(retrievalAugmentor, "retrievalAugmentor");
return this;
}
/**
* Constructs and returns the AI Service.
*
* @return An instance of the AI Service implementing the specified interface.
*/
public abstract T build();
//......
}
AiServices is an abstract class. Its static builder method creates a DefaultAiServices by default (unless an AiServicesFactory is found via SPI), and it exposes configuration methods for chatLanguageModel, streamingChatLanguageModel, systemMessageProvider, chatMemory, chatMemoryProvider, moderationModel, tools, toolProvider, contentRetriever, and retrievalAugmentor. The abstract build method is left for subclasses to implement.
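For richer setups, these builder methods can be combined. A minimal sketch reusing the Assistant interface and chatLanguageModel from above (the Calculator tool and the ten-message window are illustrative assumptions):
java
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.service.AiServices;

// Hypothetical tool class: @Tool-annotated methods become callable by the LLM
class Calculator {

    @Tool("Adds two numbers")
    double add(double a, double b) {
        return a + b;
    }
}

Assistant assistant = AiServices.builder(Assistant.class)
        .chatLanguageModel(chatLanguageModel)                             // the underlying LLM
        .systemMessageProvider(memoryId -> "You are a polite assistant")  // applied on every call
        .chatMemory(MessageWindowChatMemory.withMaxMessages(10))          // shared conversation history
        .tools(new Calculator())
        .build();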
DefaultAiServices
dev/langchain4j/service/DefaultAiServices.java
java
class DefaultAiServices<T> extends AiServices<T> {
private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();
private final Collection<TokenStreamAdapter> tokenStreamAdapters = loadFactories(TokenStreamAdapter.class);
DefaultAiServices(AiServiceContext context) {
super(context);
}
//......
public T build() {
performBasicValidation();
for (Method method : context.aiServiceClass.getMethods()) {
if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {
throw illegalConfiguration(
"The @Moderate annotation is present, but the moderationModel is not set up. "
+ "Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
}
if (method.getReturnType() == Result.class
|| method.getReturnType() == List.class
|| method.getReturnType() == Set.class) {
TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
}
if (context.chatMemoryProvider == null) {
for (Parameter parameter : method.getParameters()) {
if (parameter.isAnnotationPresent(MemoryId.class)) {
throw illegalConfiguration(
"In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.",
context.aiServiceClass.getName());
}
}
}
}
Object proxyInstance = Proxy.newProxyInstance(
context.aiServiceClass.getClassLoader(),
new Class<?>[] {context.aiServiceClass},
new InvocationHandler() {
private final ExecutorService executor = Executors.newCachedThreadPool();
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
if (method.getDeclaringClass() == Object.class) {
// methods like equals(), hashCode() and toString() should not be handled by this proxy
return method.invoke(this, args);
}
validateParameters(method);
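// Resolve the memory ID from any @MemoryId-annotated argument, falling back to "default"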
Object memoryId = findMemoryId(method, args).orElse(DEFAULT);
Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
UserMessage userMessage = prepareUserMessage(method, args);
AugmentationResult augmentationResult = null;
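// RAG: when a RetrievalAugmentor is configured, enrich the user message with retrieved content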
if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
: null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
userMessage = (UserMessage) augmentationResult.chatMessage();
}
// TODO give user ability to provide custom OutputParser
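// A method is treated as streaming when it returns TokenStream or a type adaptable from it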
Type returnType = method.getGenericReturnType();
boolean streaming = returnType == TokenStream.class || canAdaptTokenStreamTo(returnType);
boolean supportsJsonSchema =
supportsJsonSchema(); // TODO should it be called for returnType==String?
Optional<JsonSchema> jsonSchema = Optional.empty();
if (supportsJsonSchema && !streaming) {
jsonSchema = jsonSchemaFrom(returnType);
}
if ((!supportsJsonSchema || jsonSchema.isEmpty()) && !streaming) {
// TODO append after storing in the memory?
userMessage = appendOutputFormatInstructions(returnType, userMessage);
}
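// Persist the system and user messages in chat memory, when one is configured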
if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
systemMessage.ifPresent(chatMemory::add);
chatMemory.add(userMessage);
}
List<ChatMessage> messages;
if (context.hasChatMemory()) {
messages = context.chatMemory(memoryId).messages();
} else {
messages = new ArrayList<>();
systemMessage.ifPresent(messages::add);
messages.add(userMessage);
}
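// Trigger moderation asynchronously when the method is annotated with @Moderate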
Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);
ToolExecutionContext toolExecutionContext =
context.toolService.executionContext(memoryId, userMessage);
if (streaming) {
TokenStream tokenStream = new AiServiceTokenStream(
messages,
toolExecutionContext.toolSpecifications(),
toolExecutionContext.toolExecutors(),
augmentationResult != null ? augmentationResult.contents() : null,
context,
memoryId);
// TODO moderation
if (returnType == TokenStream.class) {
return tokenStream;
} else {
return adapt(tokenStream, returnType);
}
}
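// Non-streaming path: build the ChatRequest, attaching a JSON schema response format if supported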
ResponseFormat responseFormat = null;
if (supportsJsonSchema && jsonSchema.isPresent()) {
responseFormat = ResponseFormat.builder()
.type(JSON)
.jsonSchema(jsonSchema.get())
.build();
}
ChatRequestParameters parameters = ChatRequestParameters.builder()
.toolSpecifications(toolExecutionContext.toolSpecifications())
.responseFormat(responseFormat)
.build();
ChatRequest chatRequest = ChatRequest.builder()
.messages(messages)
.parameters(parameters)
.build();
ChatResponse chatResponse = context.chatModel.chat(chatRequest);
verifyModerationIfNeeded(moderationFuture);
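// Run the inference-and-tool-execution loop until the model stops requesting tool calls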
ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
chatResponse,
parameters,
messages,
context.chatModel,
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
memoryId,
toolExecutionContext.toolExecutors());
chatResponse = toolExecutionResult.chatResponse();
FinishReason finishReason = chatResponse.metadata().finishReason();
Response<AiMessage> response = Response.from(
chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);
Object parsedResponse = serviceOutputParser.parse(response, returnType);
if (typeHasRawClass(returnType, Result.class)) {
return Result.builder()
.content(parsedResponse)
.tokenUsage(toolExecutionResult.tokenUsageAccumulator())
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(finishReason)
.toolExecutions(toolExecutionResult.toolExecutions())
.build();
} else {
return parsedResponse;
}
}
private boolean canAdaptTokenStreamTo(Type returnType) {
for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
return true;
}
}
return false;
}
private Object adapt(TokenStream tokenStream, Type returnType) {
for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
return tokenStreamAdapter.adapt(tokenStream);
}
}
throw new IllegalStateException("Can't find suitable TokenStreamAdapter");
}
private boolean supportsJsonSchema() {
return context.chatModel != null
&& context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
}
private UserMessage appendOutputFormatInstructions(Type returnType, UserMessage userMessage) {
String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
String text = userMessage.singleText() + outputFormatInstructions;
if (isNotNullOrBlank(userMessage.name())) {
userMessage = UserMessage.from(userMessage.name(), text);
} else {
userMessage = UserMessage.from(text);
}
return userMessage;
}
private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
if (method.isAnnotationPresent(Moderate.class)) {
return executor.submit(() -> {
List<ChatMessage> messagesToModerate = removeToolMessages(messages);
return context.moderationModel
.moderate(messagesToModerate)
.content();
});
}
return null;
}
});
return (T) proxyInstance;
}
//......
}
DefaultAiServices extends AiServices. Its build method creates the implementation class via Proxy.newProxyInstance; the InvocationHandler prepares the system message and user message, builds the chat memory and the toolExecutionContext, assembles a ChatRequest, executes it via context.chatModel.chat(chatRequest), and finally parses and adapts the output.
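The mechanism at work is the standard JDK dynamic proxy. A stripped-down sketch of the same pattern, independent of the actual langchain4j internals:
java
import java.lang.reflect.Proxy;

interface Greeter {
    String greet(String name);
}

public class ProxyDemo {
    public static void main(String[] args) {
        // Every call on the returned Greeter goes through the handler lambda,
        // analogous to how DefaultAiServices intercepts calls to assemble a ChatRequest
        Greeter greeter = (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(),
                new Class<?>[] {Greeter.class},
                (proxy, method, methodArgs) -> "Hello, " + methodArgs[0]);
        System.out.println(greeter.greet("world")); // prints: Hello, world
    }
}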
Summary
langchain4j provides low-level components such as ChatLanguageModel, ChatMessage, and ChatMemory, as well as high-level components such as Chains and AI Services that orchestrate multiple components (prompt templates, ChatMemory, the LLM, output parsing, and RAG components such as embedding and scoring models). Chains was ported from Python's LangChain, but it is hard to customize, so no new features are being added to it. AI Services is its replacement: somewhat like JPA or Retrofit, you simply declare an interface and a proxy implementation is generated automatically, taking care of formatting the LLM input, parsing the LLM output, ChatMemory, Tools, and RAG.
langchain4j's AiServices builder creates a DefaultAiServices by default, which in turn builds the implementation class via the JDK's Proxy.newProxyInstance.