Chat Memory
自定义实现多轮对话
java
/**
 * Hand-rolled multi-turn chat: the whole dialogue is kept in one list so that each follow-up
 * question (e.g. "是985么?") is sent together with the earlier turns and can be answered in context.
 */
public static void main(String[] args) {
    List<ChatMessage> chatMessageList = new ArrayList<>();
    // The system prompt is added exactly once, as the first message of the history.
    chatMessageList.add(SystemMessage.from("你是一个智能助手!"));

    List<String> questions = List.of("北京大学是211么?", "是985么?", "是双一流么?");
    ChatRequestOptions options = ChatRequestOptions.EMPTY;
    for (String userMessage : questions) {
        chatMessageList.add(UserMessage.from(userMessage));
        // Send the full history accumulated so far, not just the latest question.
        ChatRequest request = ChatRequest.builder().messages(chatMessageList).build();
        ChatResponse chat = BASE_MODEL.chat(request, options);
        AiMessage aiMessage = chat.aiMessage();
        // Record the model's reply so the next turn sees it as context.
        chatMessageList.add(aiMessage);
        System.out.println(aiMessage.text());
    }
}
上面的实现有个问题：所有对话都保存在同一个集合里。如何解决？为每个对话分配一个独立的 List 集合。
ChatMemory
java
/**
 * Multi-turn chat using LangChain4j's MessageWindowChatMemory: the memory keeps at most the
 * last 10 messages, and its full content is sent with every request.
 */
public static void main(String[] args) {
    ChatMemory chatMemory = MessageWindowChatMemory.builder().maxMessages(10).build();
    // The system prompt is loop-invariant: add it once instead of re-adding it on every turn.
    chatMemory.add(SystemMessage.from("你是一个智能助手!"));

    List<String> questions = List.of("北京大学是211么?", "是985么?", "是双一流么?");
    ChatRequestOptions options = ChatRequestOptions.EMPTY;
    for (String userMessage : questions) {
        chatMemory.add(UserMessage.from(userMessage));
        ChatRequest request = ChatRequest.builder().messages(chatMemory.messages()).build();
        ChatResponse chat = BASE_MODEL.chat(request, options);
        AiMessage aiMessage = chat.aiMessage();
        // Store the reply so follow-up questions are answered in context.
        chatMemory.add(aiMessage);
        System.out.println(aiMessage.text());
    }
}
追踪源码发现：MessageWindowChatMemory 底层依赖一个 ChatMemoryStore（默认实现是 SingleSlotChatMemoryStore），它内部维护的只是一个 List。
改用框架自带的 InMemoryChatMemoryStore：
java
// Window memory of 10 messages backed by the framework's InMemoryChatMemoryStore
// instead of the default single-slot store.
ChatMemory chatMemory = MessageWindowChatMemory.builder()
        .maxMessages(10)
        .chatMemoryStore(new InMemoryChatMemoryStore())
        .build();
分析 InMemoryChatMemoryStore 的源码：它用一个 ConcurrentHashMap 按 memoryId 分别保存各自的消息列表。
java
/**
 * Excerpt of LangChain4j's InMemoryChatMemoryStore, quoted for analysis.
 * It keeps one message list per memoryId in a ConcurrentHashMap, so different
 * conversations (different memoryIds) no longer share a single List.
 * Nothing is persisted: the histories live only in this object.
 */
public class InMemoryChatMemoryStore implements ChatMemoryStore {
// memoryId -> that conversation's message list.
private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap<>();
/**
 * Constructs a new {@link InMemoryChatMemoryStore}.
 */
public InMemoryChatMemoryStore() {}
/**
 * Returns the messages stored for {@code memoryId}.
 * An unknown id gets a new empty ArrayList, which is also inserted into the map.
 * The returned list is the stored instance itself (no copy), so mutating it
 * changes the stored history.
 */
@Override
public List<ChatMessage> getMessages(Object memoryId) {
return messagesByMemoryId.computeIfAbsent(memoryId, ignored -> new ArrayList<>());
}
/**
 * Replaces the history for {@code memoryId} with {@code messages}.
 * The list reference is stored as-is (no defensive copy).
 */
@Override
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
messagesByMemoryId.put(memoryId, messages);
}
/**
 * Removes the history for {@code memoryId}; a later getMessages call starts from an empty list.
 */
@Override
public void deleteMessages(Object memoryId) {
messagesByMemoryId.remove(memoryId);
}
}
自定义 ChatMemoryStore，这里使用 Redis 实现。
具体支持哪些持久化存储，可参考 LangChain4j 官方文档中 Chat Memory 的相关章节。
引入 Redis 客户端（Jedis）依赖：
xml
<!-- Jedis Redis client: provides the redis.clients.jedis classes (RedisClient) used by RedisChatMemoryStore. -->
<!-- Source: https://mvnrepository.com/artifact/redis.clients/jedis -->
<dependency>
    <groupId>redis.clients</groupId>
    <artifactId>jedis</artifactId>
    <version>7.5.0</version>
</dependency>
使用 Redis 存储对话消息：
java
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import redis.clients.jedis.*;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import static dev.langchain4j.internal.ValidationUtils.*;
/**
 * {@link ChatMemoryStore} backed by Redis: each conversation's messages are stored as a single
 * JSON string under the key {@code keyPrefix + memoryId}.
 *
 * <p>Call {@link #close()} when the store is no longer needed to close the underlying
 * {@link RedisClient}.
 */
public class RedisChatMemoryStore implements ChatMemoryStore, AutoCloseable {

    /** Port used when the caller passes {@code null}. */
    private static final int DEFAULT_PORT = 6379;

    /**
     * Redis client for database operations.
     */
    private final RedisClient client;

    /**
     * Prefix to be added to all Redis keys.
     */
    private final String keyPrefix;

    /**
     * Time-to-live value for Redis keys in seconds.
     * Keys will automatically expire after this duration.
     * A value of 0 or less means keys will not expire.
     */
    private final Long ttl;

    /**
     * Creates a store without a key prefix whose keys never expire.
     *
     * @param host     Redis host name or IP address, without the port
     * @param port     Redis port; {@code null} means 6379
     * @param database Redis logical database index; {@code null} keeps the client's default database
     * @param user     Redis ACL user name; {@code null} or blank to authenticate with the password only
     * @param password Redis password; {@code null} or empty to connect without authentication
     */
    public RedisChatMemoryStore(String host, Integer port, Integer database, String user, String password) {
        this(host, port, database, user, password, "", 0L);
    }

    /**
     * Creates a store connected to the given Redis server.
     *
     * @param host     Redis host name or IP address, without the port
     * @param port     Redis port; {@code null} means 6379
     * @param database Redis logical database index; {@code null} keeps the client's default database
     * @param user     Redis ACL user name; {@code null} or blank to authenticate with the password only
     * @param password Redis password; {@code null} or empty to connect without authentication
     * @param prefix   prefix prepended to every key; must not be {@code null}
     * @param ttl      key time-to-live in seconds, 0 or less for no expiry; must not be {@code null}
     * @throws IllegalArgumentException if {@code host} is blank or {@code prefix}/{@code ttl} is {@code null}
     */
    public RedisChatMemoryStore(
            String host, Integer port, Integer database, String user, String password, String prefix, Long ttl) {
        // Validate first so a failed check cannot leak an already-opened client.
        this.keyPrefix = ensureNotNull(prefix, "prefix");
        this.ttl = ensureNotNull(ttl, "ttl");
        this.client = RedisClient.create(toRedisUrl(host, port, database, user, password));
    }

    /**
     * Returns the stored messages for {@code memoryId}, or a new empty list if none are stored.
     */
    @Override
    public List<ChatMessage> getMessages(Object memoryId) {
        String json = client.get(toRedisKey(memoryId));
        // A missing key simply means the conversation has no history yet.
        return json == null ? new ArrayList<>() : ChatMessageDeserializer.messagesFromJson(json);
    }

    /**
     * Overwrites the stored messages for {@code memoryId}, applying the TTL when it is positive.
     *
     * @throws IllegalArgumentException if {@code messages} is null or empty
     * @throws IllegalStateException    if Redis does not acknowledge the write with "OK"
     */
    @Override
    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
        String key = toRedisKey(memoryId);
        String json = ChatMessageSerializer.messagesToJson(ensureNotEmpty(messages, "messages"));
        String reply = ttl > 0 ? client.setex(key, ttl, json) : client.set(key, json);
        if (!"OK".equals(reply)) {
            throw new IllegalStateException("Failed to store chat memory under key '" + key + "': " + reply);
        }
    }

    /**
     * Deletes the stored messages for {@code memoryId}.
     */
    @Override
    public void deleteMessages(Object memoryId) {
        client.del(toRedisKey(memoryId));
    }

    /**
     * Closes the underlying Redis client.
     */
    @Override
    public void close() {
        client.close();
    }

    private String toMemoryIdString(Object memoryId) {
        boolean isNullOrEmpty = memoryId == null || memoryId.toString().trim().isEmpty();
        if (isNullOrEmpty) {
            throw new IllegalArgumentException("memoryId cannot be null or empty");
        }
        return memoryId.toString();
    }

    private String toRedisKey(Object memoryId) {
        return keyPrefix + toMemoryIdString(memoryId);
    }

    /**
     * Builds a {@code redis://[user]:password@host:port[/database]} URL from the connection
     * parameters, percent-encoding the credentials.
     */
    private static String toRedisUrl(String host, Integer port, Integer database, String user, String password) {
        StringBuilder url = new StringBuilder("redis://");
        if (password != null && !password.isEmpty()) {
            String encodedUser = (user == null || user.isBlank()) ? "" : encodeUserInfo(user);
            url.append(encodedUser).append(':').append(encodeUserInfo(password)).append('@');
        }
        url.append(ensureNotBlank(host, "host")).append(':').append(port == null ? DEFAULT_PORT : port);
        if (database != null) {
            url.append('/').append(database);
        }
        return url.toString();
    }

    /**
     * Percent-encodes a URI user-info component. URLEncoder emits '+' for a space, which a URI
     * would keep as a literal '+', so spaces are re-encoded as %20.
     */
    private static String encodeUserInfo(String value) {
        return URLEncoder.encode(value, StandardCharsets.UTF_8).replace("+", "%20");
    }
}
测试
java
/**
 * Multi-turn chat whose history is persisted in Redis through RedisChatMemoryStore.
 * No memory id is set, so the history is stored under the default id "default".
 */
public static void main(String[] args) {
    // Host and port are separate arguments. The Redis ACL user is not the chat user,
    // so pass null to authenticate with the password only.
    // NOTE(review): load the password from configuration instead of hard-coding it.
    RedisChatMemoryStore chatMemoryStore =
            new RedisChatMemoryStore("192.168.6.15", 6379, 11, null, "redis@2025");
    ChatMemory chatMemory = MessageWindowChatMemory.builder()
            .chatMemoryStore(chatMemoryStore)
            .maxMessages(10)
            .build();
    // The system prompt is loop-invariant: add it once.
    chatMemory.add(SystemMessage.from("你是一个智能助手!"));

    List<String> questions = List.of("北京大学是211么?", "是985么?", "是双一流么?");
    ChatRequestOptions options = ChatRequestOptions.EMPTY;
    for (String userMessage : questions) {
        chatMemory.add(UserMessage.from(userMessage));
        ChatRequest request = ChatRequest.builder().messages(chatMemory.messages()).build();
        ChatResponse chat = BASE_MODEL.chat(request, options);
        AiMessage aiMessage = chat.aiMessage();
        chatMemory.add(aiMessage);
        System.out.println(aiMessage.text());
    }
}
说明：chatMemoryId（即 MessageWindowChatMemory 的 id）如果不指定，默认值为 "default"。
java
/**
 * Shows the default-id problem: without .id(...), this memory also uses "default",
 * so it reads the history that the previous example stored in Redis under the same key.
 */
@Test
void test2() {
    // Host and port are separate arguments; null user = password-only authentication.
    // NOTE(review): load the password from configuration instead of hard-coding it.
    RedisChatMemoryStore chatMemoryStore =
            new RedisChatMemoryStore("192.168.6.15", 6379, 11, null, "redis@2025");
    ChatMemory chatMemory = MessageWindowChatMemory.builder()
            //.id(chatMemoryId)   // set a per-conversation id to isolate this history
            .chatMemoryStore(chatMemoryStore)
            .maxMessages(10)
            .build();
    chatMemory.add(UserMessage.from("帮我总结刚才的问题"));
    ChatRequest request = ChatRequest.builder().messages(chatMemory.messages()).build();
    ChatResponse chat = BASE_MODEL.chat(request, ChatRequestOptions.EMPTY);
    AiMessage aiMessage = chat.aiMessage();
    chatMemory.add(aiMessage);
    System.out.println(aiMessage.text());
}
结果：test2 没有指定 id，和上一个示例一样使用默认的 "default"，因此读到了上一个会话的内容。如何把不同会话区分开？
chatMemoryId 应能唯一标识一个会话，例如由 userId + 会话 id + 其他业务 id + 时间戳等组合而成；同一会话的后续请求必须复用同一个 chatMemoryId，否则每次都会从空历史开始。
java
// One id per conversation. ':' separates the parts so the id stays readable and groups keys
// by user in Redis. Generate it once when the conversation starts, persist it, and pass the
// same value on every later request of that conversation: building a fresh UUID/timestamp
// per request would start a new, empty history each time.
String chatMemoryId = userId + ":" + UUID.randomUUID() + ":" + System.currentTimeMillis();
ChatMemory chatMemory = MessageWindowChatMemory.builder()
        .id(chatMemoryId)
        .chatMemoryStore(chatMemoryStore)
        .maxMessages(10)
        .build();