Preface
This article takes a look at Prompt in Spring AI.
Prompt
org/springframework/ai/chat/prompt/Prompt.java
public class Prompt implements ModelRequest<List<Message>> {
    private final List<Message> messages;
    private ChatOptions chatOptions;
    public Prompt(String contents) {
        this(new UserMessage(contents));
    }
    public Prompt(Message message) {
        this(Collections.singletonList(message));
    }
    public Prompt(List<Message> messages) {
        this(messages, null);
    }
    public Prompt(Message... messages) {
        this(Arrays.asList(messages), null);
    }
    public Prompt(String contents, ChatOptions chatOptions) {
        this(new UserMessage(contents), chatOptions);
    }
    public Prompt(Message message, ChatOptions chatOptions) {
        this(Collections.singletonList(message), chatOptions);
    }
    public Prompt(List<Message> messages, ChatOptions chatOptions) {
        this.messages = messages;
        this.chatOptions = chatOptions;
    }
    public String getContents() {
        StringBuilder sb = new StringBuilder();
        for (Message message : getInstructions()) {
            sb.append(message.getText());
        }
        return sb.toString();
    }
    //......
}
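Prompt implements the ModelRequest<List<Message>> interface, so its getInstructions method returns a List<Message>. Its getContents method iterates over getInstructions() and appends each message.getText(), with no separator between messages.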
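The following self-contained sketch makes the concatenation behaviour of getContents visible. It uses hypothetical stand-in types (SimpleMessage, SimplePrompt), not the Spring AI classes. Only the getContents loop mirrors the excerpt above.

```java
import java.util.List;

// Hypothetical stand-ins for Spring AI's Message and Prompt; not the real classes.
public class PromptContentsSketch {

    // Minimal stand-in: a message is just a role plus text.
    record SimpleMessage(String role, String text) {
        String getText() {
            return text;
        }
    }

    // Mirrors the shape of Prompt: it holds the list of messages ("instructions").
    static final class SimplePrompt {
        private final List<SimpleMessage> messages;

        SimplePrompt(List<SimpleMessage> messages) {
            this.messages = messages;
        }

        List<SimpleMessage> getInstructions() {
            return messages;
        }

        // Same loop as Prompt.getContents(): appends each text with no separator.
        String getContents() {
            StringBuilder sb = new StringBuilder();
            for (SimpleMessage message : getInstructions()) {
                sb.append(message.getText());
            }
            return sb.toString();
        }
    }

    public static void main(String[] args) {
        SimplePrompt prompt = new SimplePrompt(List.of(
                new SimpleMessage("system", "You are terse."),
                new SimpleMessage("user", "Hi")));
        System.out.println(prompt.getContents()); // prints "You are terse.Hi"
    }
}
```

Because no separator is inserted, getContents is mainly useful for inspection or logging. It does not reproduce the role-structured message list that a chat model actually receives.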
Message
org/springframework/ai/chat/messages/Message.java
public interface Message extends Content {
    /**
     * Get the message type.
     * @return the message type
     */
    MessageType getMessageType();
}
MessageType
org/springframework/ai/chat/messages/MessageType.java
public enum MessageType {
    /**
     * A {@link Message} of type {@literal user}, having the user role and originating
     * from an end-user or developer.
     * @see UserMessage
     */
    USER("user"),
    /**
     * A {@link Message} of type {@literal assistant} passed in subsequent input
     * {@link Message Messages} as the {@link Message} generated in response to the user.
     * @see AssistantMessage
     */
    ASSISTANT("assistant"),
    /**
     * A {@link Message} of type {@literal system} passed as input {@link Message
     * Messages} containing high-level instructions for the conversation, such as behave
     * like a certain character or provide answers in a specific format.
     * @see SystemMessage
     */
    SYSTEM("system"),
    /**
     * A {@link Message} of type {@literal function} passed as input {@link Message
     * Messages} with function content in a chat application.
     * @see ToolResponseMessage
     */
    TOOL("tool");
    private final String value;
    MessageType(String value) {
        this.value = value;
    }
    public static MessageType fromValue(String value) {
        for (MessageType messageType : MessageType.values()) {
            if (messageType.getValue().equals(value)) {
                return messageType;
            }
        }
        throw new IllegalArgumentException("Invalid MessageType value: " + value);
    }
    public String getValue() {
        return this.value;
    }
}
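MessageType defines four types: USER, ASSISTANT, SYSTEM, and TOOL. Each one is backed by a lowercase string value. fromValue maps such a string back to its constant and throws IllegalArgumentException for an unknown value.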
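The lookup in fromValue is a linear scan that compares with equals. The sketch below is a self-contained copy of that logic (a local enum, not the Spring AI type). It shows that the match is case-sensitive, so "system" resolves but "System" does not.

```java
// Self-contained copy of the MessageType lookup logic; a local enum, not the Spring AI type.
public class MessageTypeSketch {

    enum MessageType {
        USER("user"), ASSISTANT("assistant"), SYSTEM("system"), TOOL("tool");

        private final String value;

        MessageType(String value) {
            this.value = value;
        }

        String getValue() {
            return value;
        }

        // Linear scan over values(); unknown strings raise IllegalArgumentException.
        static MessageType fromValue(String value) {
            for (MessageType type : values()) {
                if (type.getValue().equals(value)) {
                    return type;
                }
            }
            throw new IllegalArgumentException("Invalid MessageType value: " + value);
        }
    }

    public static void main(String[] args) {
        System.out.println(MessageType.fromValue("system")); // prints SYSTEM
        try {
            MessageType.fromValue("System"); // equals() is case-sensitive, so no match
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints Invalid MessageType value: System
        }
    }
}
```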
PromptTemplate
PromptTemplateMessageActions
org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java
public interface PromptTemplateMessageActions {
    Message createMessage();
    Message createMessage(List<Media> mediaList);
    Message createMessage(Map<String, Object> model);
}
PromptTemplateMessageActions defines the createMessage methods, each of which produces a Message.
PromptTemplateStringActions
org/springframework/ai/chat/prompt/PromptTemplateStringActions.java
public interface PromptTemplateStringActions {
    String render();
    String render(Map<String, Object> model);
}
PromptTemplateStringActions defines the render methods, which render the template to a String.
PromptTemplateChatActions
org/springframework/ai/chat/prompt/PromptTemplateChatActions.java
public interface PromptTemplateChatActions {
    List<Message> createMessages();
    List<Message> createMessages(Map<String, Object> model);
}
The PromptTemplateChatActions interface defines the createMessages methods, which return a List<Message>.
PromptTemplateActions
org/springframework/ai/chat/prompt/PromptTemplateActions.java
public interface PromptTemplateActions extends PromptTemplateStringActions {
    Prompt create();
    Prompt create(ChatOptions modelOptions);
    Prompt create(Map<String, Object> model);
    Prompt create(Map<String, Object> model, ChatOptions modelOptions);
}
PromptTemplateActions extends the PromptTemplateStringActions interface and adds the create methods, which build a Prompt.
PromptTemplate
public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
    protected String template;
    protected TemplateFormat templateFormat = TemplateFormat.ST;
    private ST st;
    private Map<String, Object> dynamicModel = new HashMap<>();
    //......
    // Render Methods
    @Override
    public String render() {
        validate(this.dynamicModel);
        return this.st.render();
    }
    @Override
    public String render(Map<String, Object> model) {
        validate(model);
        for (Entry<String, Object> entry : model.entrySet()) {
            if (this.st.getAttribute(entry.getKey()) != null) {
                this.st.remove(entry.getKey());
            }
            if (entry.getValue() instanceof Resource) {
                this.st.add(entry.getKey(), renderResource((Resource) entry.getValue()));
            }
            else {
                this.st.add(entry.getKey(), entry.getValue());
            }
        }
        return this.st.render();
    }
    @Override
    public Message createMessage() {
        return new UserMessage(render());
    }
    @Override
    public Message createMessage(List<Media> mediaList) {
        return new UserMessage(render(), mediaList);
    }
    @Override
    public Message createMessage(Map<String, Object> model) {
        return new UserMessage(render(model));
    }
    @Override
    public Prompt create() {
        return new Prompt(render(new HashMap<>()));
    }
    @Override
    public Prompt create(ChatOptions modelOptions) {
        return new Prompt(render(new HashMap<>()), modelOptions);
    }
    @Override
    public Prompt create(Map<String, Object> model) {
        return new Prompt(render(model));
    }
    @Override
    public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
        return new Prompt(render(model), modelOptions);
    }
    //......
}
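PromptTemplate implements the PromptTemplateActions and PromptTemplateMessageActions interfaces, and its render methods use org.stringtemplate.v4.ST to do the rendering. In render(Map), any attribute already set on the ST instance under the same key is removed before the new value is added. Resource values are first converted to text with renderResource. PromptTemplateStringActions focuses on creating and rendering prompt strings, which is the most basic form of prompt generation.
PromptTemplateMessageActions is designed for building prompts by generating and manipulating Message objects.
PromptTemplateActions is meant to return a Prompt object, which can be passed to a ChatModel to generate a response.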
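To show how the three layers stack (string, then Message, then Prompt), here is a minimal toy template. Everything in it is hypothetical: ToyTemplate, Msg and ToyPrompt are local stand-ins, and options is a plain String in place of ChatOptions. The toy also swaps StringTemplate's ST for a simple regex that fills {name} placeholders. Throwing on missing variables is an assumed stand-in for validate(), whose body the excerpt does not show. For brevity, create here goes through createMessage. The excerpt's PromptTemplate instead passes the rendered string straight to a Prompt constructor.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical toy showing render -> createMessage -> create layering (no StringTemplate involved).
public class TemplateLayeringSketch {

    record Msg(String role, String text) {}

    // Stand-in for Prompt; options is a plain String instead of ChatOptions.
    record ToyPrompt(List<Msg> messages, String options) {}

    static class ToyTemplate {
        private static final Pattern VAR = Pattern.compile("\\{(\\w+)}");
        protected final String template;

        ToyTemplate(String template) {
            this.template = template;
        }

        // String level: replace every {var}; missing variables cause an exception (assumed check).
        String render(Map<String, Object> model) {
            Matcher m = VAR.matcher(template);
            StringBuilder sb = new StringBuilder();
            List<String> missing = new ArrayList<>();
            while (m.find()) {
                Object value = model.get(m.group(1));
                if (value == null) {
                    missing.add(m.group(1));
                    value = "";
                }
                m.appendReplacement(sb, Matcher.quoteReplacement(String.valueOf(value)));
            }
            m.appendTail(sb);
            if (!missing.isEmpty()) {
                throw new IllegalStateException("Missing template variables: " + missing);
            }
            return sb.toString();
        }

        // Message level: wrap the rendered text; the base template uses the user role.
        Msg createMessage(Map<String, Object> model) {
            return new Msg("user", render(model));
        }

        // Prompt level: a single-message prompt, with or without options.
        ToyPrompt create(Map<String, Object> model, String options) {
            return new ToyPrompt(List.of(createMessage(model)), options);
        }

        ToyPrompt create(Map<String, Object> model) {
            return create(model, null);
        }
    }

    public static void main(String[] args) {
        ToyTemplate t = new ToyTemplate("Tell me a {adjective} joke about {topic}");
        Map<String, Object> model = Map.of("adjective", "short", "topic", "cats");
        System.out.println(t.render(model)); // Tell me a short joke about cats
        System.out.println(t.createMessage(model)); // Msg[role=user, text=Tell me a short joke about cats]
        System.out.println(t.create(model, "temperature=0.2").messages().size()); // 1
    }
}
```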
SystemPromptTemplate
org/springframework/ai/chat/prompt/SystemPromptTemplate.java
public class SystemPromptTemplate extends PromptTemplate {
    public SystemPromptTemplate(String template) {
        super(template);
    }
    public SystemPromptTemplate(Resource resource) {
        super(resource);
    }
    @Override
    public Message createMessage() {
        return new SystemMessage(render());
    }
    @Override
    public Message createMessage(Map<String, Object> model) {
        return new SystemMessage(render(model));
    }
    @Override
    public Prompt create() {
        return new Prompt(new SystemMessage(render()));
    }
    @Override
    public Prompt create(Map<String, Object> model) {
        return new Prompt(new SystemMessage(render(model)));
    }
}
SystemPromptTemplate extends PromptTemplate, and its createMessage returns a SystemMessage.
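Judging only from the excerpt above, SystemPromptTemplate overrides createMessage(), createMessage(Map), create() and create(Map). It inherits create(ChatOptions) and create(Map, ChatOptions) from PromptTemplate, and those build the Prompt from the rendered string, i.e. as a user message. The sketch below models that override pattern with hypothetical local classes (BaseTemplate, SystemTemplate, Msg, ToyPrompt) and naive placeholder replacement instead of ST. It shows which role each call yields under that reading of the excerpt.

```java
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins mirroring which methods SystemPromptTemplate overrides in the excerpt.
public class OverrideSketch {

    record Msg(String role, String text) {}

    // Stand-in for Prompt; options is a plain String instead of ChatOptions.
    record ToyPrompt(List<Msg> messages, String options) {}

    static class BaseTemplate {
        private final String template;

        BaseTemplate(String template) {
            this.template = template;
        }

        // Naive {key} substitution as a stand-in for ST rendering.
        String render(Map<String, String> model) {
            String out = template;
            for (Map.Entry<String, String> e : model.entrySet()) {
                out = out.replace("{" + e.getKey() + "}", e.getValue());
            }
            return out;
        }

        Msg createMessage(Map<String, String> model) {
            return new Msg("user", render(model));
        }

        ToyPrompt create(Map<String, String> model) {
            return new ToyPrompt(List.of(new Msg("user", render(model))), null);
        }

        // Like PromptTemplate.create(Map, ChatOptions): a user message built from the rendered string.
        ToyPrompt create(Map<String, String> model, String options) {
            return new ToyPrompt(List.of(new Msg("user", render(model))), options);
        }
    }

    // Overrides only createMessage(Map) and create(Map), like the excerpt's SystemPromptTemplate.
    static class SystemTemplate extends BaseTemplate {
        SystemTemplate(String template) {
            super(template);
        }

        @Override
        Msg createMessage(Map<String, String> model) {
            return new Msg("system", render(model));
        }

        @Override
        ToyPrompt create(Map<String, String> model) {
            return new ToyPrompt(List.of(new Msg("system", render(model))), null);
        }
    }

    public static void main(String[] args) {
        SystemTemplate t = new SystemTemplate("Your name is {name}");
        Map<String, String> model = Map.of("name", "Bob");
        System.out.println(t.createMessage(model).role()); // system
        System.out.println(t.create(model).messages().get(0).role()); // system
        System.out.println(t.create(model, "opts").messages().get(0).role()); // user (inherited overload)
    }
}
```

If a system prompt also needs ChatOptions, building the SystemMessage with createMessage and passing it to a Prompt constructor together with the options avoids depending on the inherited overload.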
FunctionPromptTemplate
org/springframework/ai/chat/prompt/FunctionPromptTemplate.java
public class FunctionPromptTemplate extends PromptTemplate {
    private String name;
    public FunctionPromptTemplate(String template) {
        super(template);
    }
}
FunctionPromptTemplate extends PromptTemplate and only adds a name field.
ChatPromptTemplate
org/springframework/ai/chat/prompt/ChatPromptTemplate.java
public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplateChatActions {
    private final List<PromptTemplate> promptTemplates;
    public ChatPromptTemplate(List<PromptTemplate> promptTemplates) {
        this.promptTemplates = promptTemplates;
    }
    @Override
    public String render() {
        StringBuilder sb = new StringBuilder();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            sb.append(promptTemplate.render());
        }
        return sb.toString();
    }
    @Override
    public String render(Map<String, Object> model) {
        StringBuilder sb = new StringBuilder();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            sb.append(promptTemplate.render(model));
        }
        return sb.toString();
    }
    @Override
    public List<Message> createMessages() {
        List<Message> messages = new ArrayList<>();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            messages.add(promptTemplate.createMessage());
        }
        return messages;
    }
    @Override
    public List<Message> createMessages(Map<String, Object> model) {
        List<Message> messages = new ArrayList<>();
        for (PromptTemplate promptTemplate : this.promptTemplates) {
            messages.add(promptTemplate.createMessage(model));
        }
        return messages;
    }
    @Override
    public Prompt create() {
        List<Message> messages = createMessages();
        return new Prompt(messages);
    }
    @Override
    public Prompt create(ChatOptions modelOptions) {
        List<Message> messages = createMessages();
        return new Prompt(messages, modelOptions);
    }
    @Override
    public Prompt create(Map<String, Object> model) {
        List<Message> messages = createMessages(model);
        return new Prompt(messages);
    }
    @Override
    public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
        List<Message> messages = createMessages(model);
        return new Prompt(messages, modelOptions);
    }
}
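ChatPromptTemplate implements the PromptTemplateActions and PromptTemplateChatActions interfaces, and its constructor takes a list of promptTemplates. Its render method iterates over promptTemplates and appends each promptTemplate.render(). Its createMessages method iterates over promptTemplates and adds each promptTemplate.createMessage(). The create methods wrap the resulting message list in a Prompt.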
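The aggregation can be sketched with hypothetical local types (RoleTemplate, ChatTemplate, Msg) and naive placeholder replacement standing in for ST. As in ChatPromptTemplate, the same model map is passed to every child template. createMessages keeps one role-tagged message per template, while render simply concatenates the rendered strings.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of ChatPromptTemplate-style aggregation; not the Spring AI classes.
public class ChatTemplateSketch {

    record Msg(String role, String text) {}

    // One role-tagged template; render does naive {key} substitution (stand-in for ST).
    record RoleTemplate(String role, String template) {
        String render(Map<String, String> model) {
            String out = template;
            for (Map.Entry<String, String> e : model.entrySet()) {
                out = out.replace("{" + e.getKey() + "}", e.getValue());
            }
            return out;
        }

        Msg createMessage(Map<String, String> model) {
            return new Msg(role, render(model));
        }
    }

    // Aggregates child templates; every child receives the same model.
    record ChatTemplate(List<RoleTemplate> templates) {
        String render(Map<String, String> model) {
            StringBuilder sb = new StringBuilder();
            for (RoleTemplate t : templates) {
                sb.append(t.render(model));
            }
            return sb.toString();
        }

        List<Msg> createMessages(Map<String, String> model) {
            List<Msg> messages = new ArrayList<>();
            for (RoleTemplate t : templates) {
                messages.add(t.createMessage(model));
            }
            return messages;
        }
    }

    public static void main(String[] args) {
        ChatTemplate chat = new ChatTemplate(List.of(
                new RoleTemplate("system", "You are a {persona}."),
                new RoleTemplate("user", "Tell me about {topic}.")));
        Map<String, String> model = Map.of("persona", "historian", "topic", "pirates");
        System.out.println(chat.createMessages(model)); // two messages: system, then user
        System.out.println(chat.render(model)); // You are a historian.Tell me about pirates.
    }
}
```

Since every child shares one model, variable names should be unique across the aggregated templates unless the same value is intended everywhere.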
AssistantPromptTemplate
org/springframework/ai/chat/prompt/AssistantPromptTemplate.java
public class AssistantPromptTemplate extends PromptTemplate {
    public AssistantPromptTemplate(String template) {
        super(template);
    }
    public AssistantPromptTemplate(Resource resource) {
        super(resource);
    }
    @Override
    public Prompt create() {
        return new Prompt(new AssistantMessage(render()));
    }
    @Override
    public Prompt create(Map<String, Object> model) {
        return new Prompt(new AssistantMessage(render(model)));
    }
    @Override
    public Message createMessage() {
        return new AssistantMessage(render());
    }
    @Override
    public Message createMessage(Map<String, Object> model) {
        return new AssistantMessage(render(model));
    }
}
AssistantPromptTemplate extends PromptTemplate, and its createMessage method returns an AssistantMessage.
Examples
PromptTemplate example
PromptTemplate promptTemplate = new PromptTemplate("Tell me a {adjective} joke about {topic}");
Prompt prompt = promptTemplate.create(Map.of("adjective", adjective, "topic", topic));
return chatModel.call(prompt).getResult();
SystemPromptTemplate example
String userText = """
Tell me about three famous pirates from the Golden Age of Piracy and why they did.
Write at least a sentence for each pirate.
""";
Message userMessage = new UserMessage(userText);
String systemText = """
You are a helpful AI assistant that helps people find information.
Your name is {name}
You should reply to the user's request with your name and also in the style of a {voice}.
""";
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemText);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
List<Generation> response = chatModel.call(prompt).getResults();
Summary
Spring AI's Message exposes a MessageType, which is one of USER, SYSTEM, ASSISTANT, or TOOL. PromptTemplate's createMessage method returns a UserMessage, SystemPromptTemplate's returns a SystemMessage, and AssistantPromptTemplate's returns an AssistantMessage. SystemPromptTemplate and AssistantPromptTemplate both extend PromptTemplate, whose render methods use org.stringtemplate.v4.ST for rendering. ChatPromptTemplate aggregates a list of promptTemplates.