A Look at Spring AI's Prompt

This article takes a look at the Prompt abstraction in Spring AI.

Prompt

org/springframework/ai/chat/prompt/Prompt.java

public class Prompt implements ModelRequest<List<Message>> {

	private final List<Message> messages;

	private ChatOptions chatOptions;

	public Prompt(String contents) {
		this(new UserMessage(contents));
	}

	public Prompt(Message message) {
		this(Collections.singletonList(message));
	}

	public Prompt(List<Message> messages) {
		this(messages, null);
	}

	public Prompt(Message... messages) {
		this(Arrays.asList(messages), null);
	}

	public Prompt(String contents, ChatOptions chatOptions) {
		this(new UserMessage(contents), chatOptions);
	}

	public Prompt(Message message, ChatOptions chatOptions) {
		this(Collections.singletonList(message), chatOptions);
	}

	public Prompt(List<Message> messages, ChatOptions chatOptions) {
		this.messages = messages;
		this.chatOptions = chatOptions;
	}

	public String getContents() {
		StringBuilder sb = new StringBuilder();
		for (Message message : getInstructions()) {
			sb.append(message.getText());
		}
		return sb.toString();
	}

	//......
}	

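Prompt implements the ModelRequest<List<Message>> interface, so its getInstructions method returns a List<Message>. Its getContents method iterates over getInstructions and appends each message.getText(), with no separator between messages.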
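A minimal sketch of the concatenation in getContents, written without Spring AI on the classpath. SimpleMessage and PromptContentsSketch are hypothetical stand-ins, not the real Message and Prompt classes; only the loop mirrors the listing above.

```java
import java.util.List;

// Simplified stand-ins for Spring AI's Message and Prompt (hypothetical, not the real classes).
class PromptContentsSketch {

    // Minimal message: a role label plus its text.
    record SimpleMessage(String role, String text) {}

    // Mirrors the loop in Prompt#getContents: texts are appended with no separator.
    static String getContents(List<SimpleMessage> instructions) {
        StringBuilder sb = new StringBuilder();
        for (SimpleMessage message : instructions) {
            sb.append(message.text());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<SimpleMessage> messages = List.of(
                new SimpleMessage("system", "You are terse."),
                new SimpleMessage("user", "Hi"));
        System.out.println(getContents(messages)); // prints "You are terse.Hi"
    }
}
```

Because nothing is inserted between messages, the roles are lost in the concatenated string; it is probably best treated as a convenience for logging rather than as a faithful view of the conversation.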

Message

org/springframework/ai/chat/messages/Message.java

public interface Message extends Content {

	/**
	 * Get the message type.
	 * @return the message type
	 */
	MessageType getMessageType();

}

MessageType

org/springframework/ai/chat/messages/MessageType.java

public enum MessageType {

	/**
	 * A {@link Message} of type {@literal user}, having the user role and originating
	 * from an end-user or developer.
	 * @see UserMessage
	 */
	USER("user"),

	/**
	 * A {@link Message} of type {@literal assistant} passed in subsequent input
	 * {@link Message Messages} as the {@link Message} generated in response to the user.
	 * @see AssistantMessage
	 */
	ASSISTANT("assistant"),

	/**
	 * A {@link Message} of type {@literal system} passed as input {@link Message
	 * Messages} containing high-level instructions for the conversation, such as behave
	 * like a certain character or provide answers in a specific format.
	 * @see SystemMessage
	 */
	SYSTEM("system"),

	/**
	 * A {@link Message} of type {@literal function} passed as input {@link Message
	 * Messages} with function content in a chat application.
	 * @see ToolResponseMessage
	 */
	TOOL("tool");

	private final String value;

	MessageType(String value) {
		this.value = value;
	}

	public static MessageType fromValue(String value) {
		for (MessageType messageType : MessageType.values()) {
			if (messageType.getValue().equals(value)) {
				return messageType;
			}
		}
		throw new IllegalArgumentException("Invalid MessageType value: " + value);
	}

	public String getValue() {
		return this.value;
	}

}

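MessageType defines four types: USER, ASSISTANT, SYSTEM, and TOOL. Its fromValue method looks a type up by its string value and throws IllegalArgumentException when nothing matches.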
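The lookup in fromValue can be tried in isolation. The sketch below uses a trimmed copy of the enum under the hypothetical name Role so that it compiles on its own; the scan itself follows the listing above.

```java
class MessageTypeLookupSketch {

    // Trimmed copy of the enum shown above, kept only so the sketch compiles on its own.
    enum Role {
        USER("user"), ASSISTANT("assistant"), SYSTEM("system"), TOOL("tool");

        private final String value;

        Role(String value) {
            this.value = value;
        }

        String getValue() {
            return this.value;
        }

        // Same linear scan as MessageType.fromValue: an exact, case-sensitive match.
        static Role fromValue(String value) {
            for (Role role : values()) {
                if (role.getValue().equals(value)) {
                    return role;
                }
            }
            throw new IllegalArgumentException("Invalid MessageType value: " + value);
        }
    }

    public static void main(String[] args) {
        System.out.println(Role.fromValue("tool")); // prints "TOOL"
        try {
            Role.fromValue("TOOL"); // upper case does not match the stored value "tool"
        }
        catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints "Invalid MessageType value: TOOL"
        }
    }
}
```

Note that the comparison uses equals, so callers holding mixed-case role names would need to normalize them first.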

PromptTemplate

PromptTemplateMessageActions

org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java

public interface PromptTemplateMessageActions {

	Message createMessage();

	Message createMessage(List<Media> mediaList);

	Message createMessage(Map<String, Object> model);

}

PromptTemplateMessageActions defines the createMessage methods, each of which returns a single Message.

PromptTemplateStringActions

org/springframework/ai/chat/prompt/PromptTemplateStringActions.java

public interface PromptTemplateStringActions {

	String render();

	String render(Map<String, Object> model);

}

PromptTemplateStringActions defines the render methods, which render the template to a String.

PromptTemplateChatActions

org/springframework/ai/chat/prompt/PromptTemplateChatActions.java

public interface PromptTemplateChatActions {

	List<Message> createMessages();

	List<Message> createMessages(Map<String, Object> model);

}

The PromptTemplateChatActions interface defines the createMessages methods, which return a List<Message>.

PromptTemplateActions

org/springframework/ai/chat/prompt/PromptTemplateActions.java

public interface PromptTemplateActions extends PromptTemplateStringActions {

	Prompt create();

	Prompt create(ChatOptions modelOptions);

	Prompt create(Map<String, Object> model);

	Prompt create(Map<String, Object> model, ChatOptions modelOptions);

}

PromptTemplateActions extends the PromptTemplateStringActions interface and adds the create methods, which build a Prompt.

PromptTemplate

public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {

	protected String template;

	protected TemplateFormat templateFormat = TemplateFormat.ST;

	private ST st;

	private Map<String, Object> dynamicModel = new HashMap<>();

	//......

	// Render Methods
	@Override
	public String render() {
		validate(this.dynamicModel);
		return this.st.render();
	}

	@Override
	public String render(Map<String, Object> model) {
		validate(model);
		for (Entry<String, Object> entry : model.entrySet()) {
			if (this.st.getAttribute(entry.getKey()) != null) {
				this.st.remove(entry.getKey());
			}
			if (entry.getValue() instanceof Resource) {
				this.st.add(entry.getKey(), renderResource((Resource) entry.getValue()));
			}
			else {
				this.st.add(entry.getKey(), entry.getValue());
			}

		}
		return this.st.render();
	}

	@Override
	public Message createMessage() {
		return new UserMessage(render());
	}

	@Override
	public Message createMessage(List<Media> mediaList) {
		return new UserMessage(render(), mediaList);
	}

	@Override
	public Message createMessage(Map<String, Object> model) {
		return new UserMessage(render(model));
	}

	@Override
	public Prompt create() {
		return new Prompt(render(new HashMap<>()));
	}

	@Override
	public Prompt create(ChatOptions modelOptions) {
		return new Prompt(render(new HashMap<>()), modelOptions);
	}

	@Override
	public Prompt create(Map<String, Object> model) {
		return new Prompt(render(model));
	}

	@Override
	public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
		return new Prompt(render(model), modelOptions);
	}

	//......		
}	

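PromptTemplate implements the PromptTemplateActions and PromptTemplateMessageActions interfaces; its render methods delegate to org.stringtemplate.v4.ST to do the rendering.

PromptTemplateStringActions focuses on creating and rendering prompt strings, the most basic form of prompt generation.

PromptTemplateMessageActions is designed for creating prompts by generating and manipulating Message objects.

PromptTemplateActions is designed to return a Prompt object, which can be passed to a ChatModel to generate a response.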
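To make the rendering step concrete, here is a rough sketch of {placeholder} substitution in plain Java. It is not how PromptTemplate works internally: the real class delegates to StringTemplate, which supports far more than simple substitution. The regular expression, the render helper, and the decision to reject missing keys are all assumptions of this sketch.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Plain-Java stand-in for the rendering step; the real PromptTemplate uses StringTemplate (ST).
class PlaceholderRenderSketch {

    // Matches placeholders such as {adjective} or {topic}.
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    // Replaces every {key} with the matching model value, converted with String.valueOf.
    static String render(String template, Map<String, Object> model) {
        Matcher matcher = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (matcher.find()) {
            String key = matcher.group(1);
            if (!model.containsKey(key)) {
                // Assumption of this sketch: a missing variable is an error.
                throw new IllegalStateException("Missing value for placeholder: " + key);
            }
            // quoteReplacement keeps '$' and '\' in values from being treated as group references.
            matcher.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(model.get(key))));
        }
        matcher.appendTail(out);
        return out.toString();
    }

    public static void main(String[] args) {
        String text = render("Tell me a {adjective} joke about {topic}",
                Map.of("adjective", "funny", "topic", "cats"));
        System.out.println(text); // prints "Tell me a funny joke about cats"
    }
}
```

The three interfaces then differ only in what they do with such a rendered string: return it as is, wrap it in a Message, or wrap it in a Prompt.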

SystemPromptTemplate

org/springframework/ai/chat/prompt/SystemPromptTemplate.java

public class SystemPromptTemplate extends PromptTemplate {

	public SystemPromptTemplate(String template) {
		super(template);
	}

	public SystemPromptTemplate(Resource resource) {
		super(resource);
	}

	@Override
	public Message createMessage() {
		return new SystemMessage(render());
	}

	@Override
	public Message createMessage(Map<String, Object> model) {
		return new SystemMessage(render(model));
	}

	@Override
	public Prompt create() {
		return new Prompt(new SystemMessage(render()));
	}

	@Override
	public Prompt create(Map<String, Object> model) {
		return new Prompt(new SystemMessage(render(model)));
	}

}

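SystemPromptTemplate extends PromptTemplate; its createMessage methods return a SystemMessage, and create() and create(Map) wrap the rendered text in a SystemMessage as well. In the version shown here the overloads that take ChatOptions are not overridden, so they fall back to PromptTemplate, which builds the Prompt from a String and therefore from a UserMessage.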
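The override pattern, and the effect of leaving one overload un-overridden, can be sketched with two small classes. BaseTemplate, SystemTemplate, and createMessageWithOptions are hypothetical names standing in for PromptTemplate, SystemPromptTemplate, and the ChatOptions overloads; roles are plain strings here rather than Message subclasses.

```java
class RoleOverrideSketch {

    record SimpleMessage(String role, String text) {}

    // Stand-in for PromptTemplate: wraps the text in a user message by default.
    static class BaseTemplate {

        protected final String text;

        BaseTemplate(String text) {
            this.text = text;
        }

        SimpleMessage createMessage() {
            return new SimpleMessage("user", this.text);
        }

        // Not overridden in the subclass below, so it keeps the base behavior.
        SimpleMessage createMessageWithOptions(String options) {
            return new SimpleMessage("user", this.text);
        }
    }

    // Stand-in for SystemPromptTemplate: only the no-argument variant changes the role.
    static class SystemTemplate extends BaseTemplate {

        SystemTemplate(String text) {
            super(text);
        }

        @Override
        SimpleMessage createMessage() {
            return new SimpleMessage("system", this.text);
        }
    }

    public static void main(String[] args) {
        BaseTemplate template = new SystemTemplate("Be brief.");
        System.out.println(template.createMessage().role()); // prints "system"
        System.out.println(template.createMessageWithOptions("temperature=0").role()); // prints "user"
    }
}
```

When both a system role and model options are needed, building the Message first and passing it to a Prompt constructor that accepts ChatOptions sidesteps the inherited overloads.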

FunctionPromptTemplate

org/springframework/ai/chat/prompt/FunctionPromptTemplate.java

public class FunctionPromptTemplate extends PromptTemplate {

	private String name;

	public FunctionPromptTemplate(String template) {
		super(template);
	}

}

FunctionPromptTemplate extends PromptTemplate and adds a name field.

ChatPromptTemplate

org/springframework/ai/chat/prompt/ChatPromptTemplate.java

public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplateChatActions {

	private final List<PromptTemplate> promptTemplates;

	public ChatPromptTemplate(List<PromptTemplate> promptTemplates) {
		this.promptTemplates = promptTemplates;
	}

	@Override
	public String render() {
		StringBuilder sb = new StringBuilder();
		for (PromptTemplate promptTemplate : this.promptTemplates) {
			sb.append(promptTemplate.render());
		}
		return sb.toString();
	}

	@Override
	public String render(Map<String, Object> model) {
		StringBuilder sb = new StringBuilder();
		for (PromptTemplate promptTemplate : this.promptTemplates) {
			sb.append(promptTemplate.render(model));
		}
		return sb.toString();
	}

	@Override
	public List<Message> createMessages() {
		List<Message> messages = new ArrayList<>();
		for (PromptTemplate promptTemplate : this.promptTemplates) {
			messages.add(promptTemplate.createMessage());
		}
		return messages;
	}

	@Override
	public List<Message> createMessages(Map<String, Object> model) {
		List<Message> messages = new ArrayList<>();
		for (PromptTemplate promptTemplate : this.promptTemplates) {
			messages.add(promptTemplate.createMessage(model));
		}
		return messages;
	}

	@Override
	public Prompt create() {
		List<Message> messages = createMessages();
		return new Prompt(messages);
	}

	@Override
	public Prompt create(ChatOptions modelOptions) {
		List<Message> messages = createMessages();
		return new Prompt(messages, modelOptions);
	}

	@Override
	public Prompt create(Map<String, Object> model) {
		List<Message> messages = createMessages(model);
		return new Prompt(messages);
	}

	@Override
	public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
		List<Message> messages = createMessages(model);
		return new Prompt(messages, modelOptions);
	}

}

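ChatPromptTemplate implements the PromptTemplateActions and PromptTemplateChatActions interfaces. Its constructor takes a list of promptTemplates. Its render methods iterate over promptTemplates and append each promptTemplate.render() result in order; its createMessages methods iterate over promptTemplates and collect each promptTemplate.createMessage() result into a list.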
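The aggregation can be sketched as follows. RoleTemplate and SimpleMessage are hypothetical stand-ins for PromptTemplate subclasses and Message, and the naive String.replace substitution is an assumption of this sketch; only the loop in createMessages follows the listing above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

class ChatTemplateAggregationSketch {

    record SimpleMessage(String role, String text) {}

    // Hypothetical stand-in for a PromptTemplate: a fixed role plus a template string.
    record RoleTemplate(String role, String template) {

        // Naive substitution of {key} placeholders, enough for this sketch.
        SimpleMessage createMessage(Map<String, Object> model) {
            String text = this.template;
            for (Map.Entry<String, Object> entry : model.entrySet()) {
                text = text.replace("{" + entry.getKey() + "}", String.valueOf(entry.getValue()));
            }
            return new SimpleMessage(this.role, text);
        }
    }

    // Mirrors ChatPromptTemplate#createMessages(Map): one message per template,
    // in list order, with the same model map passed to every template.
    static List<SimpleMessage> createMessages(List<RoleTemplate> templates, Map<String, Object> model) {
        List<SimpleMessage> messages = new ArrayList<>();
        for (RoleTemplate template : templates) {
            messages.add(template.createMessage(model));
        }
        return messages;
    }

    public static void main(String[] args) {
        List<RoleTemplate> templates = List.of(
                new RoleTemplate("system", "Your name is {name}."),
                new RoleTemplate("user", "Tell me a joke about {topic}."));
        Map<String, Object> model = Map.of("name", "Bob", "topic", "cats");
        for (SimpleMessage message : createMessages(templates, model)) {
            System.out.println(message.role() + ": " + message.text());
        }
    }
}
```

The order of the resulting messages is simply the order of the template list, so the caller decides where the system message sits in the conversation.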

AssistantPromptTemplate

org/springframework/ai/chat/prompt/AssistantPromptTemplate.java

public class AssistantPromptTemplate extends PromptTemplate {

	public AssistantPromptTemplate(String template) {
		super(template);
	}

	public AssistantPromptTemplate(Resource resource) {
		super(resource);
	}

	@Override
	public Prompt create() {
		return new Prompt(new AssistantMessage(render()));
	}

	@Override
	public Prompt create(Map<String, Object> model) {
		return new Prompt(new AssistantMessage(render(model)));
	}

	@Override
	public Message createMessage() {
		return new AssistantMessage(render());
	}

	@Override
	public Message createMessage(Map<String, Object> model) {
		return new AssistantMessage(render(model));
	}

}

AssistantPromptTemplate extends PromptTemplate; its createMessage methods return an AssistantMessage.

Examples

PromptTemplate example

PromptTemplate promptTemplate = new PromptTemplate("Tell me a {adjective} joke about {topic}");

Prompt prompt = promptTemplate.create(Map.of("adjective", adjective, "topic", topic));

return chatModel.call(prompt).getResult();

SystemPromptTemplate example

String userText = """
    Tell me about three famous pirates from the Golden Age of Piracy and what they did.
    Write at least a sentence for each pirate.
    """;

Message userMessage = new UserMessage(userText);

String systemText = """
  You are a helpful AI assistant that helps people find information.
  Your name is {name}
  You should reply to the user's request with your name and also in the style of a {voice}.
  """;

SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemText);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));

Prompt prompt = new Prompt(List.of(userMessage, systemMessage));

List<Generation> response = chatModel.call(prompt).getResults();

Summary

Spring AI's Message exposes a MessageType through getMessageType, with four types: USER, SYSTEM, ASSISTANT, and TOOL. PromptTemplate's createMessage methods return a UserMessage, SystemPromptTemplate's return a SystemMessage, and AssistantPromptTemplate's return an AssistantMessage. SystemPromptTemplate and AssistantPromptTemplate both extend PromptTemplate, whose render methods use org.stringtemplate.v4.ST to do the rendering; ChatPromptTemplate, in turn, aggregates a list of promptTemplates.
