LangChain4j + Redis：持久化存储对话记忆
一、前言
LangChain4j 默认将对话记忆保存在内存中（InMemoryChatMemoryStore），进程一旦重启，记忆就会丢失。这里我们借助 Redis 的持久化机制（RDB + AOF）来解决这个问题，实现简单且性能好。注意：Redis 默认只开启 RDB 快照，AOF 需要在 redis.conf 中设置 `appendonly yes` 后才会生效。
二、环境和依赖
Java 17（LangChain4j 要求的最低版本）
<!-- LangChain4j core: AiServices, chat memory, ChatMemoryStore. -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>1.3.0</version>
</dependency>
<!-- DashScope model integration packaged as a Spring Boot starter (community module).
NOTE(review): this community version (1.1.0-beta7) appears to belong to a different release
line than langchain4j 1.3.0 above, so two langchain4j-core versions may meet on the classpath.
Consider importing langchain4j-bom and langchain4j-community-bom, or picking the community
release that matches 1.3.0. Confirm the exact version on Maven Central. -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-community-dashscope-spring-boot-starter</artifactId>
<version>1.1.0-beta7</version>
</dependency>
<!-- Spring Data Redis: provides RedisConnectionFactory, StringRedisTemplate and RedisTemplate.
NOTE(review): if the project inherits spring-boot-starter-parent, the <version> can be omitted
so it follows the Spring Boot version in use. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
<version>3.1.5</version>
</dependency>
此外，还需要在本地运行一个 Redis 服务。
三、实战
配置:
# Spring Boot 3.x reads Redis connection settings from spring.data.redis.*;
# the nesting below is required, otherwise host/port become unrelated top-level keys.
spring:
  data:
    redis:
      host: localhost
      port: 6379
MemoryConfig
/**
 * Spring configuration for the Redis-backed chat memory.
 *
 * <p>Exposes a {@link StringRedisTemplate} and the {@link RedisChatMemoryStore} that persists
 * conversation history through it.
 */
@Configuration
public class MemoryConfig {

    /**
     * Redis template whose keys and values are plain strings.
     *
     * <p>NOTE(review): Spring Boot's Redis auto-configuration normally provides an equivalent
     * {@code stringRedisTemplate} bean; this explicit one looks redundant — confirm before removing.
     *
     * @param factory Redis connection factory injected by Spring
     * @return a string template bound to {@code factory}
     */
    @Bean
    public StringRedisTemplate stringRedisTemplate(RedisConnectionFactory factory) {
        StringRedisTemplate stringTemplate = new StringRedisTemplate(factory);
        return stringTemplate;
    }

    /**
     * Chat-memory store that keeps conversation messages in Redis.
     *
     * @param redisTemplate string template used for all Redis reads and writes
     * @param objectMapper  Jackson mapper handed to the store
     * @return the Redis-backed {@code ChatMemoryStore}
     */
    @Bean
    public RedisChatMemoryStore redisChatMemoryStore(
            StringRedisTemplate redisTemplate,
            ObjectMapper objectMapper) {
        RedisChatMemoryStore store = new RedisChatMemoryStore(redisTemplate, objectMapper);
        return store;
    }
}
RedisConfig
/**
 * Defines a general-purpose {@code RedisTemplate<String, Object>} with string keys and JSON values.
 *
 * <p>This template is not the one used by {@code RedisChatMemoryStore}, which works through a
 * {@code StringRedisTemplate}.
 */
@Configuration
public class RedisConfig {

    /**
     * Builds a template that writes keys (and hash keys) as plain strings and values (and hash
     * values) as JSON via {@link GenericJackson2JsonRedisSerializer}. That serializer records type
     * information alongside each value, so values can be read back as their original class.
     *
     * @param factory Redis connection factory injected by Spring
     * @return the fully initialised template
     */
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {
        StringRedisSerializer keySerializer = new StringRedisSerializer();
        GenericJackson2JsonRedisSerializer valueSerializer = new GenericJackson2JsonRedisSerializer();

        RedisTemplate<String, Object> jsonTemplate = new RedisTemplate<>();
        jsonTemplate.setConnectionFactory(factory);
        jsonTemplate.setKeySerializer(keySerializer);
        jsonTemplate.setValueSerializer(valueSerializer);
        jsonTemplate.setHashKeySerializer(keySerializer);
        jsonTemplate.setHashValueSerializer(valueSerializer);
        jsonTemplate.afterPropertiesSet();
        return jsonTemplate;
    }
}
ChatMessageDTO
package cn.langchain4j.ai.dto;
import lombok.Getter;
/**
 * Flat, Jackson-friendly representation of one chat message: a role plus its text content.
 *
 * <p>The roles produced by {@code ChatMessageConverter} are {@code "user"}, {@code "ai"} and
 * {@code "system"}. The public no-arg constructor and setters allow Jackson to deserialize
 * instances from JSON.
 */
public class ChatMessageDTO {
    /** Message role: "user", "ai" or "system". */
    private String role;
    /** Text content of the message. */
    private String content;

    /** Creates an empty DTO (role and content are {@code null}); used by Jackson. */
    public ChatMessageDTO() {}

    /**
     * Creates a DTO with the given role and content.
     *
     * @param role    message role, e.g. "user"
     * @param content message text
     */
    public ChatMessageDTO(String role, String content) {
        this.role = role;
        this.content = content;
    }

    /** @return the message role, or {@code null} if not set */
    public String getRole() {
        return role;
    }

    /** @param role the message role to set */
    public void setRole(String role) {
        this.role = role;
    }

    /** @return the message text, or {@code null} if not set */
    public String getContent() {
        return content;
    }

    /** @param content the message text to set */
    public void setContent(String content) {
        this.content = content;
    }
}
ChatMessageConverter
package cn.langchain4j.ai.memory;
import cn.langchain4j.ai.dto.ChatMessageDTO;
import dev.langchain4j.data.message.*;
/**
 * Converts between LangChain4j {@link ChatMessage}s and the flat {@link ChatMessageDTO}
 * ({@code role} + {@code content}) used for JSON storage.
 *
 * <p>Only plain-text user, AI and system messages fit in a {@code ChatMessageDTO}. AI messages
 * that request tool executions, and tool results, carry ids, tool names and arguments that the
 * DTO has no fields for. They are rejected with {@link IllegalArgumentException} rather than
 * being stored incompletely.
 */
public final class ChatMessageConverter {

    private ChatMessageConverter() {
        // Static utility class; not instantiable.
    }

    /**
     * Converts a LangChain4j message to its DTO form.
     *
     * <p>User messages are read via {@link UserMessage#singleText()}, which fails for messages
     * that are not a single text part (e.g. ones containing images).
     *
     * @param msg message to convert; must not be {@code null}
     * @return DTO with role {@code "user"}, {@code "ai"} or {@code "system"}
     * @throws IllegalArgumentException if {@code msg} is {@code null}, is an AI message carrying
     *     tool execution requests, or is of any other type (e.g. a tool execution result)
     */
    public static ChatMessageDTO toDTO(ChatMessage msg) {
        if (msg == null) {
            throw new IllegalArgumentException("Chat message must not be null");
        }
        if (msg instanceof UserMessage m) {
            return new ChatMessageDTO("user", m.singleText());
        }
        if (msg instanceof AiMessage m) {
            if (m.hasToolExecutionRequests()) {
                // Keeping only the text would drop the tool calls and orphan the tool results
                // that follow them, so refuse instead of persisting a corrupt history.
                throw new IllegalArgumentException(
                        "AiMessage with tool execution requests cannot be stored as ChatMessageDTO");
            }
            return new ChatMessageDTO("ai", m.text());
        }
        if (msg instanceof SystemMessage m) {
            return new ChatMessageDTO("system", m.text());
        }
        throw new IllegalArgumentException("Unknown message type: " + msg.getClass());
    }

    /**
     * Rebuilds a LangChain4j message from its DTO form.
     *
     * @param dto DTO to convert; must not be {@code null}
     * @return a {@link UserMessage}, {@link AiMessage} or {@link SystemMessage}
     * @throws IllegalArgumentException if {@code dto} is {@code null} or its role is missing or
     *     not one of {@code "user"}, {@code "ai"}, {@code "system"}
     */
    public static ChatMessage toDomain(ChatMessageDTO dto) {
        if (dto == null) {
            throw new IllegalArgumentException("Chat message DTO must not be null");
        }
        String role = dto.getRole();
        if (role == null) {
            // A switch on a null String would throw an uninformative NullPointerException.
            throw new IllegalArgumentException("Unknown role: null");
        }
        return switch (role) {
            case "user" -> UserMessage.from(dto.getContent());
            case "ai" -> AiMessage.from(dto.getContent());
            case "system" -> SystemMessage.from(dto.getContent());
            default -> throw new IllegalArgumentException("Unknown role: " + role);
        };
    }
}
RedisChatMemoryStore
package cn.langchain4j.ai.memory;
import cn.langchain4j.ai.dto.ChatMessageDTO;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import org.springframework.data.redis.core.StringRedisTemplate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
 * {@link ChatMemoryStore} that keeps each conversation's messages in Redis as one JSON string
 * under the key {@code chat_memory:<memoryId>}.
 *
 * <p>Messages are written with LangChain4j's {@link ChatMessageSerializer}, which preserves every
 * message type, including AI tool-call requests and tool execution results. The flat
 * {@link ChatMessageDTO} (role + content) cannot represent those, yet they appear as soon as the
 * AI service invokes a tool.
 *
 * <p>Values written in the earlier DTO-based format can still be read. If
 * {@link ChatMessageDeserializer} rejects the stored JSON, it is parsed as a list of
 * {@link ChatMessageDTO} and converted with {@link ChatMessageConverter#toDomain}. The next
 * {@link #updateMessages} call rewrites it in the new format.
 *
 * <p>Keys are written without an expiry, so a conversation stays until {@link #deleteMessages}
 * removes it. Durability across Redis restarts depends on the server's RDB/AOF settings.
 */
public class RedisChatMemoryStore implements ChatMemoryStore {

    /** Prefix of every Redis key written by this store. */
    private static final String KEY_PREFIX = "chat_memory:";

    private final StringRedisTemplate redisTemplate;
    /** Used only to read values stored in the legacy {@link ChatMessageDTO} format. */
    private final ObjectMapper objectMapper;

    /**
     * @param redisTemplate template used for all Redis reads and writes; must not be null
     * @param objectMapper  Jackson mapper used to read legacy DTO-format values; must not be null
     */
    public RedisChatMemoryStore(StringRedisTemplate redisTemplate,
                                ObjectMapper objectMapper) {
        this.redisTemplate = Objects.requireNonNull(redisTemplate, "redisTemplate");
        this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper");
    }

    /** Builds the Redis key holding the messages of {@code memoryId}. */
    private String buildKey(Object memoryId) {
        return KEY_PREFIX + memoryId;
    }

    /**
     * Loads the stored messages of a conversation.
     *
     * @param memoryId conversation id
     * @return a new mutable list of messages; empty if nothing is stored
     * @throws IllegalStateException if the stored value is neither in the current nor in the
     *     legacy DTO format
     */
    @Override
    public List<ChatMessage> getMessages(Object memoryId) {
        String json = redisTemplate.opsForValue().get(buildKey(memoryId));
        if (json == null) {
            return new ArrayList<>();
        }
        try {
            List<ChatMessage> messages = ChatMessageDeserializer.messagesFromJson(json);
            return messages == null ? new ArrayList<>() : new ArrayList<>(messages);
        } catch (RuntimeException currentFormatError) {
            // Not in ChatMessageSerializer format: probably written by the older DTO-based store.
            return readLegacyFormat(memoryId, json, currentFormatError);
        }
    }

    /**
     * Parses a value stored as a JSON list of {@link ChatMessageDTO}.
     *
     * @param memoryId           conversation id (used in the error message)
     * @param json               stored value
     * @param currentFormatError why the current-format parse failed; kept as the cause
     * @return the converted messages in a new mutable list
     * @throws IllegalStateException if the value cannot be read in the legacy format either
     */
    private List<ChatMessage> readLegacyFormat(Object memoryId, String json,
                                               RuntimeException currentFormatError) {
        try {
            List<ChatMessageDTO> dtoList = objectMapper.readValue(
                    json, new TypeReference<List<ChatMessageDTO>>() {});
            List<ChatMessage> messages = new ArrayList<>(dtoList.size());
            for (ChatMessageDTO dto : dtoList) {
                messages.add(ChatMessageConverter.toDomain(dto));
            }
            return messages;
        } catch (JsonProcessingException | RuntimeException legacyFormatError) {
            IllegalStateException failure = new IllegalStateException(
                    "Cannot deserialize chat memory stored under key " + buildKey(memoryId),
                    currentFormatError);
            failure.addSuppressed(legacyFormatError);
            throw failure;
        }
    }

    /**
     * Replaces the stored messages of a conversation.
     *
     * @param memoryId conversation id
     * @param messages full message list to store; {@code null} is stored as an empty list
     */
    @Override
    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
        List<ChatMessage> toStore = messages == null ? List.of() : messages;
        String json = ChatMessageSerializer.messagesToJson(toStore);
        redisTemplate.opsForValue().set(buildKey(memoryId), json);
    }

    /**
     * Deletes all stored messages of a conversation.
     *
     * @param memoryId conversation id
     */
    @Override
    public void deleteMessages(Object memoryId) {
        redisTemplate.delete(buildKey(memoryId));
    }
}
AiCodeHelperServiceFactory
/**
 * Spring configuration that assembles the {@link AiCodeHelperService} AI service.
 *
 * <p>The service is wired with a blocking and a streaming chat model, per-conversation chat
 * memory persisted through {@link RedisChatMemoryStore}, two local tools and an MCP tool provider.
 */
@Configuration
public class AiCodeHelperServiceFactory {

    /** Upper bound on the number of messages kept in each conversation's memory window. */
    private static final int MAX_MEMORY_MESSAGES = 10;

    @Resource
    private ChatModel qwenChatModel;
    @Resource
    private ContentRetriever contentRetriever;
    @Resource
    private McpToolProvider mcpToolProvider;
    @Resource
    private RedisChatMemoryStore redisChatMemoryStore;
    @Resource
    private StreamingChatModel streamingChatModel;

    /**
     * Builds the AI service.
     *
     * <p>RAG through {@code .contentRetriever(contentRetriever)} is currently not enabled.
     *
     * @return the configured {@link AiCodeHelperService}
     */
    @Bean
    public AiCodeHelperService aiCodeHelperService() {
        return AiServices.builder(AiCodeHelperService.class)
                .chatModel(qwenChatModel)
                .streamingChatModel(streamingChatModel)
                .chatMemoryProvider(this::buildChatMemory)
                .tools(new JavaInfoTool(), new RagTool())
                .toolProvider(mcpToolProvider)
                .build();
    }

    /**
     * Creates the sliding-window memory for one conversation, backed by Redis.
     *
     * @param memoryId conversation id supplied by the AI service
     * @return a window memory holding at most {@link #MAX_MEMORY_MESSAGES} messages
     */
    private MessageWindowChatMemory buildChatMemory(Object memoryId) {
        return MessageWindowChatMemory.builder()
                .id(memoryId)
                .maxMessages(MAX_MEMORY_MESSAGES)
                .chatMemoryStore(redisChatMemoryStore)
                .build();
    }
}