先看效果

对话过程被缓存到了Redis 中。
原理
在上一节我们快速入门了SpringAI,具体文章请查看:快速入门Spring AI
创建 ChatClient 的代码如下:
java
this.chatClient = ChatClient.builder(chatModel)
.defaultSystem(DEFAULT_PROMPT)
.defaultAdvisors(new MessageChatMemoryAdvisor(new InMemoryChatMemory()))
.defaultAdvisors(new SimpleLoggerAdvisor())
.defaultOptions(OpenAiChatOptions.builder().temperature(0.0d).build())
.build();
其中new MessageChatMemoryAdvisor(new InMemoryChatMemory())
会将对话缓存在内存中,查看类InMemoryChatMemory
的源码发现,它实际上实现了ChatMemory
接口,实现了 add
,get
以及clear
三个方法。
实现
先添加 Redis 的依赖:
xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
然后定义一个类 RedisChatMemory
实现 ChatMemory
接口,实现三个方法:
java
@Override
public void add(String conversationId, List<Message> messages) {
long time = System.currentTimeMillis();
for (Message message : messages) {
redisTemplate.opsForHash().put(conversationId, String.valueOf(time), message);
}
}
@Override
public List<Message> get(String conversationId, int lastN) {
Map<Object, Object> entries = redisTemplate.opsForHash().entries(conversationId);
return entries.entrySet().stream()
.sorted((o1, o2) -> {
long time1 = Long.parseLong(o1.getKey().toString());
long time2 = Long.parseLong(o2.getKey().toString());
return Long.compare(time1, time2);
}).limit(lastN)
.map(e -> new UserMessage(e.getValue().toString()))
.collect(Collectors.toList());
}
@Override
public void clear(String conversationId) {
redisTemplate.delete(conversationId);
}
再把 RedisChatMemory
注册成 Bean 对象:
java
@Bean
public RedisChatMemory redisChatMemory(RedisTemplate<String, Object> redisTemplate) {
return new RedisChatMemory(redisTemplate);
}
最后替换 ChatClient 定义中的 InMemoryChatMemory
:
java
this.chatClient = ChatClient.builder(chatModel)
.defaultSystem(DEFAULT_PROMPT)
// .defaultAdvisors(new MessageChatMemoryAdvisor(new InMemoryChatMemory()))
.defaultAdvisors(new MessageChatMemoryAdvisor(redisChatMemory))
.defaultAdvisors(new SimpleLoggerAdvisor())
.defaultOptions(OpenAiChatOptions.builder().temperature(0.0d).build())
.build();
具体代码:代码地址