【大模型记忆实战Demo】基于SpringAIAlibaba通过内存和Redis两种方式实现多轮记忆对话

文章目录

多轮对话记忆管理------基于Memory的对话记忆

Spring AI Alibaba共实现了三种方式:

  1. 基于内存的方式
  2. 基于jdbc(数据库)的方式
  3. 基于redis的方式

下文主要演示基于内存和redis的方式

基于内存存储历史对话

  • 代码
    首先定义大模型的角色,一个旅游规划师
    设置增强拦截器
    接着接口传入prompt和chatId
    设定好唯一标识符和记忆轮数
java 复制代码
	private final ChatClient chatClient;

	public ChatMemoryController(ChatModel chatModel) {

		this.chatClient = ChatClient
				.builder(chatModel)
				.defaultSystem("你是一个旅游规划师,请根据用户的需求提供旅游规划建议。")
				.defaultAdvisors(new MessageChatMemoryAdvisor(new InMemoryChatMemory()))
				.build();
	}	
/**
	 * 获取内存中的聊天内容
	 * 根据提供的prompt和chatId,从内存中获取相关的聊天内容,并设置响应的字符编码为UTF-8。
	 *
	 * @param prompt 用于获取聊天内容的提示信息
	 * @param chatId 聊天的唯一标识符,用于区分不同的聊天会话
	 * @param response HTTP响应对象,用于设置响应的字符编码
	 * @return 返回包含聊天内容的Flux<String>对象
	 */
	@GetMapping("/in-memory")
	public Flux<String> memory(
			@RequestParam("prompt") String prompt,
			@RequestParam("chatId") String chatId,
			HttpServletResponse response
	) {

		response.setCharacterEncoding("UTF-8");
		return chatClient.prompt(prompt).advisors(
				a -> a
						.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
						.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100)
		).stream().content();
	}
  • 调用结果
    第一次:

我提问,想去杭州玩

  • 第二轮

我提问:那有哪些好玩的地方

可以看到,第二次根据我第一次"杭州"的关键词进行了推荐,拥有了记忆

基于内存的方法存在一个缺点:如果机器重启了,记忆就消失了,因此可以采用持久化到Redis的方式

基于Redis存储历史对话

  • 代码
java 复制代码
	private final ChatClient chatClient;

	public ChatMemoryController(ChatModel chatModel) {

		this.chatClient = ChatClient
				.builder(chatModel)
				.defaultSystem("你是一个旅游规划师,请根据用户的需求提供旅游规划建议。")
				.defaultAdvisors(new MessageChatMemoryAdvisor(new RedisChatMemory(
						"127.0.0.1",
						6379,
						null
				)))
				.build();
	}
	/**
	 * 从Redis中获取聊天内容
	 * 根据提供的prompt和chatId,从Redis中检索聊天内容,并以Flux<String>的形式返回
	 *
	 * @param prompt 聊天内容的提示或查询关键字
	 * @param chatId 聊天的唯一标识符,用于从Redis中检索特定的聊天内容
	 * @param response HttpServletResponse对象,用于设置响应的字符编码为UTF-8
	 * @return Flux<String> 包含聊天内容的反应式流
	 */
	@GetMapping("/redis")
	public Flux<String> redis(
			@RequestParam("prompt") String prompt,
			@RequestParam("chatId") String chatId,
			HttpServletResponse response
	) {

		response.setCharacterEncoding("UTF-8");

		return chatClient.prompt(prompt)
				.advisors(
				a -> a
						.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
						.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)
				)
				.stream().content();
	}

其中的RedisChatMemory:

java 复制代码
/**
 *
 * 基于Redis的聊天记忆实现。
 * 该类实现了ChatMemory接口,提供了将聊天消息存储到Redis中的功能。
 *
 * @author Fox
 */
public class RedisChatMemory implements ChatMemory, AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);

    private static final String DEFAULT_KEY_PREFIX = "chat:";

    private static final String DEFAULT_HOST = "127.0.0.1";

    private static final int DEFAULT_PORT = 6379;

    private static final String DEFAULT_PASSWORD = null;

    private final JedisPool jedisPool;


    private final ObjectMapper objectMapper;

    public RedisChatMemory() {

        this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);
    }

    public RedisChatMemory(String host, int port, String password) {

        JedisPoolConfig poolConfig = new JedisPoolConfig();

        this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);
        this.objectMapper = new ObjectMapper();
        logger.info("Connected to Redis at {}:{}", host, port);
    }

    @Override
    public void add(String conversationId, List<Message> messages) {

        String key = DEFAULT_KEY_PREFIX + conversationId;

        AtomicLong timestamp = new AtomicLong(System.currentTimeMillis());

        try (Jedis jedis = jedisPool.getResource()) {
            // 使用pipeline批量操作提升性能
            var pipeline = jedis.pipelined();
            messages.forEach(message ->
                    pipeline.hset(key, String.valueOf(timestamp.getAndIncrement()), message.toString())
            );
            pipeline.sync();
        }

        logger.info("Added messages to conversationId: {}", conversationId);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {

        String key = DEFAULT_KEY_PREFIX + conversationId;

        try (Jedis jedis = jedisPool.getResource()) {
            Map<String, String> allMessages = jedis.hgetAll(key);
            if (allMessages.isEmpty()) {
                return List.of();
            }

            return allMessages.entrySet().stream()
                    .sorted((e1, e2) ->
                            Long.compare(Long.parseLong(e2.getKey()), Long.parseLong(e1.getKey()))
                    )
                    .limit(lastN)
                    .map(entry -> new UserMessage(entry.getValue()))
                    .collect(Collectors.toList());
        }


    }

    @Override
    public void clear(String conversationId) {

        String key = DEFAULT_KEY_PREFIX + conversationId;

        try (Jedis jedis = jedisPool.getResource()) {
            jedis.del(key);
        }
        logger.info("Cleared messages for conversationId: {}", conversationId);
    }

    @Override
    public void close() {
        try (Jedis jedis = jedisPool.getResource()) {
            if (jedis != null) {

                jedis.close();

                logger.info("Redis connection closed.");
            }
            if (jedisPool != null) {

                jedisPool.close();

                logger.info("Jedis pool closed.");
            }
        }

    }

    public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        try {
            String key = DEFAULT_KEY_PREFIX + conversationId;
            try (Jedis jedis = jedisPool.getResource()) {
                List<String> all = jedis.lrange(key, 0, -1);

                if (all.size() >= maxLimit) {
                    all = all.stream().skip(Math.max(0, deleteSize)).toList();
                }
                this.clear(conversationId);
                for (String message : all) {
                    jedis.rpush(key, message);
                }
            }
        }
        catch (Exception e) {
            logger.error("Error clearing messages from Redis chat memory", e);
            throw new RuntimeException(e);
        }
    }

}
  • 第一次调用

提问:我想去三亚

可见成功将提问和回答都写入redis

  • 第二次调用

成功读取记忆,并将新的问答结果写入

至此,我们成功完成了基于内存和Redis两种方式,实现大模型的多轮记忆对话!