Spring AI Alibaba 1.x Series [13]: The Checkpoint Mechanism and Its Persistence Implementations

Table of Contents

  • 1. Overview
  • 2. The Checkpoint Class
  • 3. The BaseCheckpointSaver Interface
    • 3.1 ThreadId
    • 3.2 Tag
    • 3.3 getLast()
    • 3.4 list()
    • 3.5 get()
    • 3.6 put()
    • 3.7 release()
  • 4. BaseCheckpointSaver Implementations
    • 4.1 RedisSaver
    • 4.2 MongoSaver
    • 4.3 VersionedMemorySaver
    • 4.4 MemorySaver
      • 4.4.1 MysqlSaver
      • 4.4.2 OracleSaver
      • 4.4.3 FileSystemSaver
      • 4.4.4 PostgresSaver

1. Overview

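Checkpointing is a general-purpose systems mechanism, widely used in database recovery and distributed computing. In Spring AI Alibaba, a checkpoint is a snapshot of an Agent's state at a given moment during execution. It contains: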

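  • Message history: the current list of conversation messages (the messages key of state)
  • Execution position: the node that was running (nodeId) and the node to resume from (nextNodeId)
  • Custom state: any other keys the graph keeps in state
  • Identity: a unique id; the checkpoint itself is filed under the threadId taken from the RunnableConfig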

Because the checkpoint mechanism saves an Agent's state during execution, it supports:

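  • State recovery: restore the Agent's execution state from a saved checkpoint
  • Session persistence: keep the context of multi-turn conversations
  • Execution tracing: record the Agent's complete history of execution states
  • Resumption after interruption: continue an interrupted Agent run from a checkpoint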
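The save/resume cycle can be made concrete with a minimal sketch. It is built on stated assumptions and is not the framework's implementation. A hypothetical two-node linear graph (llm, then tool, then END) records a Snapshot of nodeId, nextNodeId and state after every node, and resume() restarts from the latest snapshot's nextNodeId with that snapshot's state. Snapshot, ResumeDemo and the node names are illustrative only.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.UnaryOperator;

// Hypothetical snapshot: which node just ran, which node runs next, and the state at that moment.
record Snapshot(String nodeId, String nextNodeId, Map<String, Object> state) { }

class ResumeDemo {

    static final String END = "END";

    // A fixed linear "graph": ORDER defines the edges, NODES the work each node does.
    static final List<String> ORDER = List.of("llm", "tool");

    static final Map<String, UnaryOperator<Map<String, Object>>> NODES = Map.of(
            "llm", s -> withEntry(s, "answer", "call tool"),
            "tool", s -> withEntry(s, "toolResult", "42"));

    // Returns a copy of the state with one extra entry, leaving the input untouched.
    static Map<String, Object> withEntry(Map<String, Object> s, String key, Object value) {
        Map<String, Object> copy = new HashMap<>(s);
        copy.put(key, value);
        return copy;
    }

    static String next(String nodeId) {
        int i = ORDER.indexOf(nodeId);
        return i + 1 < ORDER.size() ? ORDER.get(i + 1) : END;
    }

    // Runs from startNode until END, or until maxSteps nodes have run (simulating an interruption).
    // One snapshot is appended to the history after every executed node.
    static Map<String, Object> run(String startNode, Map<String, Object> state,
                                   List<Snapshot> history, int maxSteps) {
        String current = startNode;
        int steps = 0;
        while (!END.equals(current) && steps < maxSteps) {
            state = NODES.get(current).apply(state);
            String nextNode = next(current);
            history.add(new Snapshot(current, nextNode, state));
            current = nextNode;
            steps++;
        }
        return state;
    }

    // Resumes from the most recent snapshot: its nextNodeId is the entry point, its state the input.
    static Map<String, Object> resume(List<Snapshot> history) {
        Snapshot last = history.get(history.size() - 1);
        return run(last.nextNodeId(), last.state(), history, Integer.MAX_VALUE);
    }
}
```

If a run is interrupted after llm, it leaves one snapshot whose nextNodeId is tool, and resuming executes only tool. nextNodeId plays the same role in the Checkpoint class described next.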

2. The Checkpoint Class

The Checkpoint class is how Spring AI Alibaba represents a checkpoint. It holds the complete state at a particular moment.

public class Checkpoint {

    // Unique identifier (UUID)
    private final String id;

    // Graph state data (as a Map)
    private Map<String, Object> state = null;

    // ID of the current node
    private String nodeId = null;

    // ID of the next node
    private String nextNodeId = null;

    // ... constructors and methods
}

Field reference:

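| Field | Type | Purpose | Notes and examples |
|---|---|---|---|
| id | String | Unique identifier of the checkpoint, used to look it up, update it or delete it | Generated from a UUID; declared final, so it cannot change after creation |
| state | Map<String, Object> | The graph's complete state; becomes the initial state when execution resumes | A snapshot of OverAllState.data(); contains the messages list plus any other custom state keys |
| nodeId | String | The node that was executing when the checkpoint was created | e.g. agent_llm, agent_tool, custom_node |
| nextNodeId | String | The node where execution restarts on resume, i.e. the entry point for recovery | e.g. agent_tool, agent_llm, END |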

Private constructor (used by Jackson for deserialization):

@JsonCreator // Jackson uses this constructor for deserialization
private Checkpoint(
    @JsonProperty("id") String id, // binds the JSON property to this parameter
    @JsonProperty("state") Map<String, Object> state,
    @JsonProperty("nodeId") String nodeId,
    @JsonProperty("nextNodeId") String nextNodeId) {

    // requireNonNull makes every field mandatory
    this.id = requireNonNull(id, "id cannot be null");
    this.state = requireNonNull(state, "state cannot be null");
    this.nodeId = requireNonNull(nodeId, "nodeId cannot be null");
    this.nextNodeId = requireNonNull(nextNodeId, "Checkpoint.nextNodeId cannot be null");
}

State update method. It calls OverAllState.updateState() to merge the new values into the current state and returns a new Checkpoint with the same ID:

public Checkpoint updateState(Map<String, Object> values, Map<String, KeyStrategy> channels) {
    return new Checkpoint(
        this.id,  // same ID
        OverAllState.updateState(this.state, values, channels),  // merged state
        this.nodeId,
        this.nextNodeId
    );
}
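Merging according to per-key strategies can be shown with a simplified stand-in. The sketch assumes two hypothetical strategies, REPLACE and APPEND, in place of the library's actual KeyStrategy implementations. It treats keys without a registered strategy as REPLACE, which is an assumption and not documented behavior. It also builds a new map instead of mutating the old one, so the previous snapshot stays intact.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for per-key merge strategies (not the library's KeyStrategy API).
enum MergeStrategy { REPLACE, APPEND }

class StateMerge {

    // Returns a NEW map: the old state is copied, then each incoming value is merged
    // according to its key's strategy. Keys without a strategy default to REPLACE (an assumption).
    static Map<String, Object> updateState(Map<String, Object> state,
                                           Map<String, Object> values,
                                           Map<String, MergeStrategy> strategies) {
        Map<String, Object> result = new HashMap<>(state);
        values.forEach((key, value) -> {
            MergeStrategy strategy = strategies.getOrDefault(key, MergeStrategy.REPLACE);
            if (strategy == MergeStrategy.APPEND) {
                List<Object> merged = new ArrayList<>();
                Object old = result.get(key);
                if (old instanceof Collection<?> oldItems) {
                    merged.addAll(oldItems);      // keep the existing elements first
                } else if (old != null) {
                    merged.add(old);
                }
                if (value instanceof Collection<?> newItems) {
                    merged.addAll(newItems);      // then append the incoming elements
                } else {
                    merged.add(value);
                }
                result.put(key, merged);
            } else {
                result.put(key, value);           // REPLACE: the new value wins
            }
        });
        return result;
    }
}
```

With messages registered as APPEND, an update carrying one new message extends the history. An unregistered key such as step is simply overwritten.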

Static factory method that creates a copy (new ID, all other fields the same):

public static Checkpoint copyOf(Checkpoint checkpoint) {
    requireNonNull(checkpoint, "checkpoint cannot be null");
    return new Checkpoint(
        UUID.randomUUID().toString(),  // new ID
        checkpoint.state,               // same state map (shared, not deep-copied)
        checkpoint.nodeId,              // same nodeId
        checkpoint.nextNodeId           // same nextNodeId
    );
}

Construction via the Builder:

public static Builder builder() {
    return new Builder();
}

public static class Builder {

    private String id = UUID.randomUUID().toString();  // generated automatically by default
    private Map<String, Object> state = null;
    private String nodeId = null;
    private String nextNodeId = null;

    public Builder id(String id) {
        this.id = id;
        return this;
    }

    public Builder state(OverAllState state) {
        this.state = state.data();  // extract the data map from OverAllState
        return this;
    }

    public Builder state(Map<String, Object> state) {
        this.state = state;  // use the Map as-is
        return this;
    }

    public Builder nodeId(String nodeId) {
        this.nodeId = nodeId;
        return this;
    }

    public Builder nextNodeId(String nextNodeId) {
        this.nextNodeId = nextNodeId;
        return this;
    }

    public Checkpoint build() {
        return new Checkpoint(id, state, nodeId, nextNodeId);
    }
}
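Two consequences of the code above are easy to miss. First, build() goes through the private constructor, so a builder that never sets state, nodeId or nextNodeId fails with a NullPointerException instead of producing a partial checkpoint. Second, copyOf() hands the same state map to the copy, so only the ID is new and the copy is shallow. The hypothetical MiniCheckpoint below is a trimmed-down version of the class that reproduces both behaviors without the framework on the classpath.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import static java.util.Objects.requireNonNull;

// Trimmed-down, hypothetical copy of the Checkpoint pattern shown above (not the real class).
final class MiniCheckpoint {
    private final String id;
    private final Map<String, Object> state;
    private final String nodeId;
    private final String nextNodeId;

    private MiniCheckpoint(String id, Map<String, Object> state, String nodeId, String nextNodeId) {
        // Same null checks as the real constructor: every field is mandatory.
        this.id = requireNonNull(id, "id cannot be null");
        this.state = requireNonNull(state, "state cannot be null");
        this.nodeId = requireNonNull(nodeId, "nodeId cannot be null");
        this.nextNodeId = requireNonNull(nextNodeId, "nextNodeId cannot be null");
    }

    // New ID, but the SAME state map reference: a shallow copy.
    static MiniCheckpoint copyOf(MiniCheckpoint other) {
        return new MiniCheckpoint(UUID.randomUUID().toString(), other.state, other.nodeId, other.nextNodeId);
    }

    static Builder builder() { return new Builder(); }

    String id() { return id; }
    Map<String, Object> state() { return state; }

    static final class Builder {
        private String id = UUID.randomUUID().toString(); // defaulted, like the real Builder
        private Map<String, Object> state;                // null until set
        private String nodeId;
        private String nextNodeId;

        Builder state(Map<String, Object> state) { this.state = state; return this; }
        Builder nodeId(String nodeId) { this.nodeId = nodeId; return this; }
        Builder nextNodeId(String nextNodeId) { this.nextNodeId = nextNodeId; return this; }
        MiniCheckpoint build() { return new MiniCheckpoint(id, state, nodeId, nextNodeId); }
    }
}
```

If a copy must not observe later changes to the original map, copy the map yourself (for example with new HashMap<>(...)) before passing it to the builder.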

3. The BaseCheckpointSaver Interface

The base interface of checkpoint savers, which manage the persistence and restoration of an Agent's execution state.

public interface BaseCheckpointSaver {
    String THREAD_ID_DEFAULT = "$default";

    // Returns the latest checkpoint (default implementation)
    default Optional<Checkpoint> getLast(LinkedList<Checkpoint> checkpoints, RunnableConfig config) {
        return checkpoints.isEmpty() ? Optional.empty() : ofNullable(checkpoints.peek());
    }

    // Core methods (must be implemented)
    Collection<Checkpoint> list(RunnableConfig config);
    Optional<Checkpoint> get(RunnableConfig config);
    RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception;
    Tag release(RunnableConfig config) throws Exception;

    // Record returned by release()
    record Tag(String threadId, Collection<Checkpoint> checkpoints) { }
}

Each member of the interface is described below. The implementation classes are covered in Section 4.

3.1 ThreadId

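ThreadId is the key identifier in checkpoint management. It isolates conversation histories across users and sessions, and its default value is $default.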

Use cases:

  • Multi-user systems: each user has its own threadId
  • Multi-session management: different sessions of the same user use different threadIds
    /**
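     * Default thread ID, used when no threadId is specified.
     * <p>
     * This default applies when RunnableConfig does not set a threadId, so all
     * sessions without an explicit thread share the same default storage space.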
     */
    String THREAD_ID_DEFAULT = "$default";
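In storage terms, this isolation means every history is keyed by the thread ID, with THREAD_ID_DEFAULT as the fallback key. Below is a minimal sketch. It assumes a plain Optional<String> in place of RunnableConfig.threadId() and uses a hypothetical ThreadScopedStore class.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical store illustrating per-thread isolation (not a framework class).
class ThreadScopedStore {

    static final String THREAD_ID_DEFAULT = "$default";

    private final Map<String, List<String>> historyByThread = new HashMap<>();

    // Mirrors the documented fallback: no threadId means the shared "$default" bucket.
    static String resolve(Optional<String> threadId) {
        return threadId.orElse(THREAD_ID_DEFAULT);
    }

    void append(Optional<String> threadId, String message) {
        historyByThread.computeIfAbsent(resolve(threadId), k -> new ArrayList<>()).add(message);
    }

    // Each thread only ever sees its own messages; unknown threads get an empty history.
    List<String> history(Optional<String> threadId) {
        return historyByThread.getOrDefault(resolve(threadId), List.of());
    }
}
```

Not every saver applies this fallback. RedisSaver in Section 4.1, for example, rejects a config without a threadId with an IllegalArgumentException.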

3.2 Tag

Tag is the result returned when a session is released. It contains the released threadId and all of that session's checkpoints.

    /**
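     * Record that wraps the result of a checkpoint release.
     * <p>
     * A Tag carries two pieces of information:
     * <ul>
     *   <li>threadId: the ID of the released thread</li>
     *   <li>checkpoints: the released checkpoints (an immutable copy)</li>
     * </ul>
     * <p>
     * checkpoints is stored as an immutable list, so the released checkpoint
     * list cannot be modified from outside.
     *
     * @param threadId thread ID
     * @param checkpoints checkpoint collection
     */
    record Tag(String threadId, Collection<Checkpoint> checkpoints) {

        /**
         * Canonical constructor of Tag.
         * <p>
         * Copies checkpoints into an immutable list to prevent external
         * modification; if null is passed, an empty list is used.
         *
         * @param threadId thread ID
         * @param checkpoints checkpoint collection (copied into an immutable list)
         */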
        public Tag(String threadId, Collection<Checkpoint> checkpoints) {
            this.threadId = threadId;
            this.checkpoints = ofNullable(checkpoints)
                    .map(List::copyOf)
                    .orElseGet(List::of);
        }
    }
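The effect of the List::copyOf call is easy to verify. ReleaseTag below is a hypothetical stand-in with the same constructor body, using String IDs in place of Checkpoint objects.

```java
import java.util.Collection;
import java.util.List;

import static java.util.Optional.ofNullable;

// Hypothetical stand-in with the same defensive-copy constructor as Tag.
record ReleaseTag(String threadId, Collection<String> checkpoints) {
    ReleaseTag(String threadId, Collection<String> checkpoints) {
        this.threadId = threadId;
        // null -> empty list; otherwise an unmodifiable snapshot of the caller's collection.
        this.checkpoints = ofNullable(checkpoints)
                .map(List::copyOf)
                .orElseGet(List::of);
    }
}
```

Note that List.copyOf rejects null elements with a NullPointerException. The same constructor body therefore cannot wrap a collection that contains a null checkpoint.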

3.3 getLast()

The default method for fetching the latest checkpoint. Implementations can use it as-is or override it:

    /**
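     * Returns the latest checkpoint from a checkpoint list.
     * <p>
     * This default implementation takes the head of the LinkedList. The list is
     * expected to be ordered newest first, so the head is the latest checkpoint.
     * <p>
     * Implementations may override this method to provide their own lookup logic.
     *
     * @param checkpoints checkpoint list, ordered newest first
     * @param config run configuration, including the threadId
     * @return the latest checkpoint, or Optional.empty() if the list is empty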
     */
    default Optional<Checkpoint> getLast(LinkedList<Checkpoint> checkpoints, RunnableConfig config) {
        return (checkpoints.isEmpty()) ? Optional.empty() : ofNullable(checkpoints.peek());
    }
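The name getLast can be confusing, because the method reads the head of the list. The reason is the insertion order. RedisSaver (Section 4.1) stores new checkpoints with LinkedList.push(), which adds at the front, so peek() returns the most recently saved one. Here is a quick check, with strings standing in for checkpoints. NewestFirst is a hypothetical helper.

```java
import java.util.LinkedList;
import java.util.Optional;

import static java.util.Optional.ofNullable;

class NewestFirst {
    // Same logic as the default getLast(): empty list -> Optional.empty(), else the head element.
    static <T> Optional<T> getLast(LinkedList<T> items) {
        return items.isEmpty() ? Optional.empty() : ofNullable(items.peek());
    }

    static LinkedList<String> saveInOrder(String... ids) {
        LinkedList<String> list = new LinkedList<>();
        for (String id : ids) {
            list.push(id); // push() == addFirst(): the newest element becomes the head
        }
        return list;
    }
}
```

The trap also works in the other direction: LinkedList.getLast() on the same list returns the oldest checkpoint.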

3.4 list()

Lists all checkpoints for the given configuration:

    /**
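     * Lists all checkpoints for the given configuration.
     * <p>
     * Returns every checkpoint associated with the given threadId, ordered by time.
     * Typical uses:
     * <ul>
     *   <li>viewing the complete execution history of a session</li>
     *   <li>choosing a specific historical state to restore</li>
     *   <li>analyzing an Agent's execution trace</li>
     * </ul>
     *
     * @param config run configuration; its threadId locates the checkpoints
     * @return all checkpoints associated with the threadId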
     */
    Collection<Checkpoint> list(RunnableConfig config);

3.5 get()

Gets the current checkpoint for the given configuration:

    /**
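     * Gets the current checkpoint for the given configuration.
     * <p>
     * Returns the latest checkpoint of the given threadId, used to restore or continue Agent execution.
     * This is the core method of state recovery: the Agent restores its execution state from this checkpoint.
     *
     * @param config run configuration; its threadId locates the checkpoint
     * @return the current checkpoint, or Optional.empty() if none exists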
     */
    Optional<Checkpoint> get(RunnableConfig config);
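Both RedisSaver and MongoSaver (Section 4) resolve get() in two steps. When the config carries a checkPointId, they look for that exact ID and return empty if it is missing. Otherwise they fall back to getLast(). Below is a minimal sketch of that lookup, using a hypothetical Cp record and a plain Optional<String> in place of RunnableConfig.checkPointId().

```java
import java.util.LinkedList;
import java.util.Optional;

// Hypothetical minimal checkpoint: just an ID and the node to resume from.
record Cp(String id, String nextNodeId) { }

class CheckpointLookup {
    // An explicit ID wins (and yields empty if unknown); otherwise return the newest (head) entry.
    static Optional<Cp> get(LinkedList<Cp> newestFirst, Optional<String> checkPointId) {
        if (checkPointId.isPresent()) {
            return checkPointId.flatMap(id -> newestFirst.stream()
                    .filter(cp -> cp.id().equals(id))
                    .findFirst());
        }
        return newestFirst.isEmpty() ? Optional.empty() : Optional.ofNullable(newestFirst.peek());
    }
}
```

An unknown checkPointId does not silently fall back to the latest checkpoint, so a caller asking for a specific historical state gets either that state or nothing.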

3.6 put()

Saves a checkpoint to storage:

    /**
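     * Saves a checkpoint to storage.
     * <p>
     * Saves the current Agent execution state as a checkpoint associated with the given threadId.
     * Returns an updated RunnableConfig, which may carry a new checkpointId and similar information.
     * <p>
     * This method is typically called at key points of Agent execution:
     * <ul>
     *   <li>before the Agent loop starts</li>
     *   <li>after a model call completes</li>
     *   <li>after a tool call completes</li>
     *   <li>when Agent execution is interrupted</li>
     * </ul>
     *
     * @param config run configuration; its threadId associates the checkpoint
     * @param checkpoint the checkpoint to save
     * @return the updated run configuration
     * @throws Exception if saving fails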
     */
    RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception;
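In RedisSaver (Section 4.1), put() has two paths. With a checkPointId in the config, it overwrites the matching entry in place and fails if there is none. Otherwise it pushes a new checkpoint to the front. In both cases the returned config carries the saved checkpoint's ID. The sketch below isolates that logic. SavedCp is hypothetical, and the String return value stands in for the updated RunnableConfig.

```java
import java.util.LinkedList;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.IntStream;

// Hypothetical minimal checkpoint used only for this sketch.
record SavedCp(String id, String nodeId) { }

class PutLogic {
    // Returns the ID the caller should carry forward (the real method returns a RunnableConfig).
    static String put(LinkedList<SavedCp> newestFirst, Optional<String> checkPointId, SavedCp checkpoint) {
        if (checkPointId.isPresent()) {
            String target = checkPointId.get();
            int index = IntStream.range(0, newestFirst.size())
                    .filter(i -> newestFirst.get(i).id().equals(target))
                    .findFirst()
                    .orElseThrow(() -> new NoSuchElementException("No checkpoint with id " + target));
            newestFirst.set(index, checkpoint); // replace in place; position in the list is unchanged
        } else {
            newestFirst.push(checkpoint);       // the new entry becomes the head (newest)
        }
        return checkpoint.id();
    }
}
```

An in-place update keeps the entry's position, so updating an older checkpoint does not make it the "latest" one returned by getLast().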

3.7 release()

Releases the checkpoint resources for the given configuration:

    /**
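     * Releases the checkpoint resources for the given configuration.
     * <p>
     * Cleans up or releases checkpoint resources and returns a Tag
     * containing the released threadId and its checkpoint list.
     * <p>
     * Typical scenarios:
     * <ul>
     *   <li>cleanup after a session ends</li>
     *   <li>freeing storage space</li>
     *   <li>resetting session state</li>
     * </ul>
     *
     * @param config run configuration; its threadId locates the resources to release
     * @return a Tag with the threadId and its checkpoint list
     * @throws Exception if releasing fails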
     */
    Tag release(RunnableConfig config) throws Exception;

4. BaseCheckpointSaver Implementations

4.1 RedisSaver

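A distributed checkpoint store built on Redis and Redisson. It isolates checkpoints per thread, serializes and deserializes them, and uses distributed locks to make concurrent access safe.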

Redis key prefix constants:

	// ==================== Redis key prefix constants ====================
	/** Key prefix for stored checkpoint content */
	private static final String CHECKPOINT_PREFIX = "graph:checkpoint:content:";
	/** Key prefix for thread metadata */
	private static final String THREAD_META_PREFIX = "graph:thread:meta:";
	/** Key prefix for the reverse thread mapping */
	private static final String THREAD_REVERSE_PREFIX = "graph:thread:reverse:";
	/** Key prefix for distributed locks */
	private static final String LOCK_PREFIX = "graph:checkpoint:lock:";
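RedisSaver keys the thread metadata and the lock by the caller-supplied thread name. The checkpoint content and the reverse mapping are keyed by an internally generated UUID (see getOrCreateThreadId further down). The hypothetical helper below only concatenates the constants above to show which keys one logical thread touches.

```java
import java.util.List;

class RedisKeyLayout {
    static final String CHECKPOINT_PREFIX = "graph:checkpoint:content:";
    static final String THREAD_META_PREFIX = "graph:thread:meta:";
    static final String THREAD_REVERSE_PREFIX = "graph:thread:reverse:";
    static final String LOCK_PREFIX = "graph:checkpoint:lock:";

    // threadName = the threadId the caller passes in RunnableConfig;
    // internalId = the UUID RedisSaver generates for the current "generation" of that thread.
    static List<String> keysFor(String threadName, String internalId) {
        return List.of(
                THREAD_META_PREFIX + threadName,     // Hash: thread_id, is_released
                LOCK_PREFIX + threadName,            // Redisson lock guarding every operation
                CHECKPOINT_PREFIX + internalId,      // String: Base64-encoded checkpoint list
                THREAD_REVERSE_PREFIX + internalId); // Hash: thread_name, is_released
    }
}
```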

Hash field constants for thread metadata:

	/** Thread ID field */
	private static final String FIELD_THREAD_ID = "thread_id";
	/** Whether the thread has been released */
	private static final String FIELD_IS_RELEASED = "is_released";
	/** Thread name field */
	private static final String FIELD_THREAD_NAME = "thread_name";

Checkpoint serializer:

	/** Checkpoint serializer */
	private final Serializer<Checkpoint> checkpointSerializer;

The Redisson client used to access Redis:

	/** Redisson client used to access Redis */
	private RedissonClient redisson;

Constructor:

	/**
	 * Constructor.
	 * Not meant to be called directly; build instances via builder().
	 * @param redisson Redisson client instance
	 * @param stateSerializer state serializer
	 */
	protected RedisSaver(RedissonClient redisson, StateSerializer stateSerializer) {
		requireNonNull(redisson, "redisson client cannot be null");
		requireNonNull(stateSerializer, "stateSerializer cannot be null");
		this.redisson = redisson;
		this.checkpointSerializer = new CheckPointSerializer(stateSerializer);
	}
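The helper methods that use this serializer store a thread's entire checkpoint list as one Redis string. The format is an int count, then each checkpoint as written by checkpointSerializer, with the resulting bytes Base64-encoded. The round trip below keeps that framing. So that it runs on its own, it writes each element with ObjectOutputStream.writeUTF in place of CheckPointSerializer. CountPrefixedCodec is a hypothetical name.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;

class CountPrefixedCodec {

    // Same framing as serializeCheckpoints(): size first, then the elements, then Base64.
    static String encode(List<String> items) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeInt(items.size());
            for (String item : items) {
                oos.writeUTF(item); // stand-in for checkpointSerializer.write(checkpoint, oos)
            }
        } // closing the stream flushes it into baos
        return Base64.getEncoder().encodeToString(baos.toByteArray());
    }

    // Same shape as deserializeCheckpoints(): null or empty input yields an empty list.
    static LinkedList<String> decode(String content) throws IOException {
        LinkedList<String> items = new LinkedList<>();
        if (content == null || content.isEmpty()) {
            return items;
        }
        byte[] bytes = Base64.getDecoder().decode(content);
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            int size = ois.readInt();
            for (int i = 0; i < size; i++) {
                items.add(ois.readUTF()); // stand-in for checkpointSerializer.read(ois)
            }
        }
        return items;
    }
}
```

Because put() decodes and re-encodes the whole list on every call, its write cost grows with the length of a thread's history.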

The implementations of the methods declared in BaseCheckpointSaver, together with their helper methods:

    /**
     * Serializes a checkpoint list into a Base64 string
     * @param checkpoints checkpoint collection
     * @return Base64-encoded string
     * @throws IOException on serialization errors
     */
    private String serializeCheckpoints(List<Checkpoint> checkpoints) throws IOException {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeInt(checkpoints.size());
            for (Checkpoint checkpoint : checkpoints) {
                checkpointSerializer.write(checkpoint, oos);
            }
            oos.flush();
            byte[] bytes = baos.toByteArray();
            return Base64.getEncoder().encodeToString(bytes);
        }
    }

    /**
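     * Deserializes a Base64 string into a linked list of checkpoints
     * @param content Base64 string stored in Redis
     * @return linked list of checkpoints
     * @throws IOException on deserialization I/O errors
     * @throws ClassNotFoundException if a serialized class cannot be found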
     */
    private LinkedList<Checkpoint> deserializeCheckpoints(String content) throws IOException, ClassNotFoundException {
        if (content == null || content.isEmpty()) {
            return new LinkedList<>();
        }
        byte[] bytes = Base64.getDecoder().decode(content);
        try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            int size = ois.readInt();
            LinkedList<Checkpoint> checkpoints = new LinkedList<>();
            for (int i = 0; i < size; i++) {
                checkpoints.add(checkpointSerializer.read(ois));
            }
            return checkpoints;
        }
    }

    /**
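     * Gets or creates the thread ID for a thread name.
     * Returns the active thread's ID if there is one; otherwise creates a new thread ID and initializes its metadata
     * @param threadName thread name
     * @return unique thread ID (UUID)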
     */
    private String getOrCreateThreadId(String threadName) {
        String metaKey = THREAD_META_PREFIX + threadName;
        RMap<String, String> meta = redisson.getMap(metaKey);

        // Look for an active (unreleased) thread
        String threadId = meta.get(FIELD_THREAD_ID);
        String isReleased = meta.get(FIELD_IS_RELEASED);

        if (threadId != null && !"true".equals(isReleased)) {
            // An active thread exists: return it
            return threadId;
        }

        // No active thread, or it was released: create a new thread ID
        String newThreadId = UUID.randomUUID().toString();
        meta.put(FIELD_THREAD_ID, newThreadId);
        meta.put(FIELD_IS_RELEASED, "false");

        // Maintain the reverse mapping: thread ID -> thread name
        String reverseKey = THREAD_REVERSE_PREFIX + newThreadId;
        RMap<String, String> reverse = redisson.getMap(reverseKey);
        reverse.put(FIELD_THREAD_NAME, threadName);
        reverse.put(FIELD_IS_RELEASED, "false");

        return newThreadId;
    }

    /**
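     * Gets the active thread ID for a thread name.
     * Returns only unreleased thread IDs; returns null if there is no active thread
     * @param threadName thread name
     * @return active thread ID, or null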
     */
    private String getActiveThreadId(String threadName) {
        String metaKey = THREAD_META_PREFIX + threadName;
        RMap<String, String> meta = redisson.getMap(metaKey);

        String threadId = meta.get(FIELD_THREAD_ID);
        String isReleased = meta.get(FIELD_IS_RELEASED);

        if (threadId != null && !"true".equals(isReleased)) {
            return threadId;
        }
        // No active thread
        return null;
    }

    /**
     * Lists all checkpoints of the given thread
     * @param config run configuration (must contain a threadId)
     * @return checkpoint collection
     */
    @Override
    public Collection<Checkpoint> list(RunnableConfig config) {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("threadId must not be empty");
        }

        String threadName = threadNameOpt.get();
        RLock lock = redisson.getLock(LOCK_PREFIX + threadName);
        boolean tryLock = false;
        try {
            // Lock wait timeout for reads: 500 ms (an empty result is returned on timeout)
            tryLock = lock.tryLock(500, TimeUnit.MILLISECONDS);
            if (!tryLock) {
                return List.of();
            }

            // Get the active thread ID
            String threadId = getActiveThreadId(threadName);
            if (threadId == null) {
                return List.of();
            }

            // Load the checkpoint data from Redis
            RBucket<String> bucket = redisson.getBucket(CHECKPOINT_PREFIX + threadId);
            String content = bucket.get();
            return deserializeCheckpoints(content);

        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Failed to deserialize checkpoints", e);
        }
        finally {
            if (lock.isHeldByCurrentThread()) {
                lock.unlock();
            }
        }
    }

    /**
     * Gets a single checkpoint.
     * Looks up by checkPointId when one is given; otherwise returns the latest checkpoint
     * @param config run configuration
     * @return the checkpoint as an Optional
     */
    @Override
    public Optional<Checkpoint> get(RunnableConfig config) {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("threadId must not be empty");
        }

        String threadName = threadNameOpt.get();
        RLock lock = redisson.getLock(LOCK_PREFIX + threadName);
        boolean tryLock = false;
        try {
            // Lock wait timeout for reads: 500 ms (an empty result is returned on timeout)
            tryLock = lock.tryLock(500, TimeUnit.MILLISECONDS);
            if (!tryLock) {
                return Optional.empty();
            }

            // Get the active thread ID
            String threadId = getActiveThreadId(threadName);
            if (threadId == null) {
                return Optional.empty();
            }

            // Load and deserialize the checkpoints
            RBucket<String> bucket = redisson.getBucket(CHECKPOINT_PREFIX + threadId);
            String content = bucket.get();
            LinkedList<Checkpoint> checkpoints = deserializeCheckpoints(content);

            // Exact match by ID
            if (config.checkPointId().isPresent()) {
                return config.checkPointId()
                        .flatMap(id -> checkpoints.stream()
                                .filter(checkpoint -> checkpoint.getId().equals(id))
                                .findFirst());
            }
            // Otherwise return the latest checkpoint (head of the list)
            return getLast(checkpoints, config);

        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Failed to deserialize checkpoints", e);
        }
        finally {
            if (lock.isHeldByCurrentThread()) {
                lock.unlock();
            }
        }
    }

    /**
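     * Saves or updates a checkpoint.
     * Updates the matching checkpoint if checkPointId is set; otherwise adds a new one
     * @param config run configuration
     * @param checkpoint checkpoint object
     * @return the updated run configuration
     * @throws Exception if saving fails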
     */
    @Override
    public RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("threadId must not be empty");
        }

        String threadName = threadNameOpt.get();
        RLock lock = redisson.getLock(LOCK_PREFIX + threadName);
        boolean tryLock = false;
        try {
            // Lock wait timeout for writes: 3 s (a longer wait under contention)
            tryLock = lock.tryLock(3, TimeUnit.SECONDS);
            if (!tryLock) {
                throw new RuntimeException("Failed to acquire thread lock: " + threadName);
            }

            // Get or create the thread ID
            String threadId = getOrCreateThreadId(threadName);

            // Load the existing checkpoints
            RBucket<String> bucket = redisson.getBucket(CHECKPOINT_PREFIX + threadId);
            String content = bucket.get();
            LinkedList<Checkpoint> checkpoints = deserializeCheckpoints(content);

            if (config.checkPointId().isPresent()) {
                // Update an existing checkpoint in place
                String checkPointId = config.checkPointId().get();
                int index = IntStream.range(0, checkpoints.size())
                        .filter(i -> checkpoints.get(i).getId().equals(checkPointId))
                        .findFirst()
                        .orElseThrow(() -> new NoSuchElementException(
                                format("No checkpoint found with ID %s!", checkPointId)));
                checkpoints.set(index, checkpoint);
            }
            else {
                // Add a new checkpoint at the head of the list
                checkpoints.push(checkpoint);
            }

            // Serialize and save to Redis
            bucket.set(serializeCheckpoints(checkpoints));
            return RunnableConfig.builder(config).checkPointId(checkpoint.getId()).build();

        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Failed to serialize/deserialize checkpoints", e);
        }
        finally {
            if (lock.isHeldByCurrentThread()) {
                lock.unlock();
            }
        }
    }

    /**
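     * Releases a thread's checkpoints.
     * Marks the thread as released and returns a Tag with its checkpoints
     * @param config run configuration
     * @return checkpoint Tag
     * @throws Exception if releasing fails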
     */
    @Override
    public Tag release(RunnableConfig config) throws Exception {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("threadId must not be empty");
        }

        String threadName = threadNameOpt.get();
        RLock lock = redisson.getLock(LOCK_PREFIX + threadName);
        boolean tryLock = false;
        try {
            // Lock wait timeout for writes: 3 s
            tryLock = lock.tryLock(3, TimeUnit.SECONDS);
            if (!tryLock) {
                throw new RuntimeException("Failed to acquire thread lock: " + threadName);
            }

            // Load the thread metadata
            String metaKey = THREAD_META_PREFIX + threadName;
            RMap<String, String> meta = redisson.getMap(metaKey);

            String threadId = meta.get(FIELD_THREAD_ID);
            if (threadId == null) {
                throw new IllegalStateException("Thread does not exist: " + threadName);
            }

            // Mark the thread as released
            meta.put(FIELD_IS_RELEASED, "true");

            // Update the status in the reverse mapping
            String reverseKey = THREAD_REVERSE_PREFIX + threadId;
            RMap<String, String> reverse = redisson.getMap(reverseKey);
            if (reverse != null) {
                reverse.put(FIELD_IS_RELEASED, "true");
            }

            // Load the checkpoint data and wrap it in a Tag
            String contentKey = CHECKPOINT_PREFIX + threadId;
            RBucket<String> bucket = redisson.getBucket(contentKey);
            String content = bucket.get();
            Collection<Checkpoint> checkpoints = deserializeCheckpoints(content);

            return new Tag(threadName, checkpoints);

        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Failed to deserialize checkpoints", e);
        }
        finally {
            if (lock.isHeldByCurrentThread()) {
                lock.unlock();
            }
        }
    }
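release() above never deletes checkpoint content. It sets is_released to true, and the next put() for the same thread name makes getOrCreateThreadId allocate a fresh UUID. Old checkpoints therefore stay in Redis under the previous ID, while list() and get() see only checkpoints saved after the release.

The sketch below models that bookkeeping in-process. ConcurrentHashMap stands in for the Redis hashes, and a ReentrantLock stands in for the Redisson RLock. It keeps the same 3 s and 500 ms tryLock timeouts and the same isHeldByCurrentThread() check before unlock(). It is not distributed, and ThreadGenerations and its methods are hypothetical.

```java
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical in-process model of RedisSaver's bookkeeping: NOT distributed.
class ThreadGenerations {
    private final Map<String, String> activeIdByName = new ConcurrentHashMap<>();          // ~ graph:thread:meta:*
    private final Map<String, LinkedList<String>> contentById = new ConcurrentHashMap<>(); // ~ graph:checkpoint:content:*
    private final Map<String, ReentrantLock> locks = new ConcurrentHashMap<>();            // ~ graph:checkpoint:lock:*

    private ReentrantLock lockFor(String threadName) {
        return locks.computeIfAbsent(threadName, n -> new ReentrantLock());
    }

    // Write path (put/release): 3 s timeout, and failing to get the lock is an error.
    private static void lockForWrite(ReentrantLock lock, String threadName) throws InterruptedException {
        if (!lock.tryLock(3, TimeUnit.SECONDS)) {
            throw new IllegalStateException("Failed to acquire lock for thread: " + threadName);
        }
    }

    private static void unlockIfHeld(ReentrantLock lock) {
        if (lock.isHeldByCurrentThread()) {
            lock.unlock();
        }
    }

    void put(String threadName, String checkpointId) throws InterruptedException {
        ReentrantLock lock = lockFor(threadName);
        try {
            lockForWrite(lock, threadName);
            // ~ getOrCreateThreadId: reuse the active ID, or start a new generation with a fresh UUID.
            String id = activeIdByName.computeIfAbsent(threadName, n -> UUID.randomUUID().toString());
            contentById.computeIfAbsent(id, k -> new LinkedList<>()).push(checkpointId);
        } finally {
            unlockIfHeld(lock);
        }
    }

    // Read path (list/get): 500 ms timeout, and a timeout degrades to an empty result.
    List<String> list(String threadName) throws InterruptedException {
        ReentrantLock lock = lockFor(threadName);
        try {
            if (!lock.tryLock(500, TimeUnit.MILLISECONDS)) {
                return List.of();
            }
            String id = activeIdByName.get(threadName);
            return id == null ? List.of() : List.copyOf(contentById.getOrDefault(id, new LinkedList<>()));
        } finally {
            unlockIfHeld(lock);
        }
    }

    // ~ release(): drops the name -> ID mapping (RedisSaver flips is_released) but keeps the content.
    List<String> release(String threadName) throws InterruptedException {
        ReentrantLock lock = lockFor(threadName);
        try {
            lockForWrite(lock, threadName);
            String id = activeIdByName.remove(threadName);
            if (id == null) {
                throw new IllegalStateException("Thread does not exist: " + threadName);
            }
            return List.copyOf(contentById.getOrDefault(id, new LinkedList<>()));
        } finally {
            unlockIfHeld(lock);
        }
    }

    // Number of content "generations" ever stored, including released ones.
    int storedGenerations() {
        return contentById.size();
    }
}
```

Storage therefore grows with every released generation unless something else removes the old content keys. That cleanup is not part of the code shown above.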

4.2 MongoSaver

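A distributed checkpoint store built on MongoDB. It supports transactions, atomic operations and per-thread isolation, and implements persisting, querying, updating and releasing checkpoints.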

Core fields and constants:

public class MongoSaver implements BaseCheckpointSaver {

    /** Logger */
    private static final Logger logger = LoggerFactory.getLogger(MongoSaver.class);

    /** MongoDB database name */
    private static final String DB_NAME = "check_point_db";

    /** Name of the thread-metadata collection */
    private static final String THREAD_META_COLLECTION = "thread_meta";

    /** Name of the checkpoint collection */
    private static final String CHECKPOINT_COLLECTION = "checkpoint_collection";

    /** _id prefix for thread-metadata documents */
    private static final String THREAD_META_PREFIX = "mongo:thread:meta:";

    /** _id prefix for checkpoint-content documents */
    private static final String CHECKPOINT_PREFIX = "mongo:checkpoint:content:";

    /** Document field that holds the checkpoint content */
    private static final String DOCUMENT_CONTENT_KEY = "checkpoint_content";

    // ==================== Thread metadata document fields ====================
    /** Thread ID field */
    private static final String FIELD_THREAD_ID = "thread_id";
    /** Whether the thread has been released */
    private static final String FIELD_IS_RELEASED = "is_released";
    /** Thread name field */
    private static final String FIELD_THREAD_NAME = "thread_name";

    /** Checkpoint serializer */
    private final Serializer<Checkpoint> checkpointSerializer;

    /** MongoDB client */
    private MongoClient client;

    /** MongoDB database instance */
    private MongoDatabase database;

    /** Transaction options (majority write concern) */
    private TransactionOptions txnOptions;

Constructor:

    /**
     * Constructor.
     * Not meant to be called directly; build instances via builder().
     * @param client MongoDB client
     * @param stateSerializer state serializer
     */
    protected MongoSaver(MongoClient client, StateSerializer stateSerializer) {
        Objects.requireNonNull(client, "MongoDB client cannot be null");
        Objects.requireNonNull(stateSerializer, "stateSerializer cannot be null");
        this.client = client;
        this.database = client.getDatabase(DB_NAME);
        // Transactions use majority write concern
        this.txnOptions = TransactionOptions.builder().writeConcern(WriteConcern.MAJORITY).build();
        this.checkpointSerializer = new CheckPointSerializer(stateSerializer);
        // Register a JVM shutdown hook that closes the Mongo client
        Runtime.getRuntime().addShutdownHook(new Thread(client::close));
    }
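MongoSaver has no Redisson lock, so its getOrCreateThreadId (below) relies on single-document atomic operations instead: findOneAndUpdate, plus an upsert that writes fields with setOnInsert, all within a ClientSession. list() and get() also open multi-document transactions with startTransaction(). MongoDB supports these only on replica sets and sharded clusters, not on a standalone server.

The rule those operations enforce is: reuse the active thread ID, otherwise allocate a new UUID, atomically per thread name. It can be illustrated in-process with ConcurrentHashMap.compute, whose remapping function runs atomically for a given key. This is a swapped-in technique for illustration, not how MongoSaver is implemented. ThreadMeta and AtomicThreadIds are hypothetical names.

```java
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for one thread_meta document.
record ThreadMeta(String threadId, boolean released) { }

class AtomicThreadIds {
    private final ConcurrentHashMap<String, ThreadMeta> metaByName = new ConcurrentHashMap<>();

    // compute() runs atomically per key, so concurrent callers for one name agree on one ID.
    String getOrCreate(String threadName) {
        ThreadMeta meta = metaByName.compute(threadName, (name, current) ->
                current != null && !current.released()
                        ? current                                                // active: reuse it
                        : new ThreadMeta(UUID.randomUUID().toString(), false)); // absent or released: new ID
        return meta.threadId();
    }

    // Marks the thread released; the next getOrCreate() allocates a new ID.
    void release(String threadName) {
        metaByName.computeIfPresent(threadName, (name, current) -> new ThreadMeta(current.threadId(), true));
    }
}
```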

The implementations of the BaseCheckpointSaver methods, together with their helper methods:

    /**
     * Serializes a checkpoint list into a Base64 string
     * @param checkpoints checkpoint collection
     * @return Base64-encoded string
     * @throws IOException on serialization I/O errors
     */
    private String serializeCheckpoints(List<Checkpoint> checkpoints) throws IOException {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeInt(checkpoints.size());
            for (Checkpoint checkpoint : checkpoints) {
                checkpointSerializer.write(checkpoint, oos);
            }
            oos.flush();
            byte[] bytes = baos.toByteArray();
            return Base64.getEncoder().encodeToString(bytes);
        }
    }

    /**
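     * Deserializes a Base64 string into a linked list of checkpoints
     * @param content Base64 string stored in MongoDB
     * @return linked list of checkpoints
     * @throws IOException on deserialization I/O errors
     * @throws ClassNotFoundException if a serialized class cannot be found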
     */
    private LinkedList<Checkpoint> deserializeCheckpoints(String content) throws IOException, ClassNotFoundException {
        if (content == null || content.isEmpty()) {
            return new LinkedList<>();
        }
        byte[] bytes = Base64.getDecoder().decode(content);
        try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
             ObjectInputStream ois = new ObjectInputStream(bais)) {
            int size = ois.readInt();
            LinkedList<Checkpoint> checkpoints = new LinkedList<>();
            for (int i = 0; i < size; i++) {
                checkpoints.add(checkpointSerializer.read(ois));
            }
            return checkpoints;
        }
    }

    /**
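     * Atomically gets or creates the thread ID.
     * Returns the active thread's ID if there is one; otherwise creates a new ID, safely under concurrency
     * @param threadName thread name
     * @param clientSession MongoDB session (transaction)
     * @return unique thread ID (UUID)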
     */
    private String getOrCreateThreadId(String threadName, ClientSession clientSession) {
        MongoCollection<Document> threadMetaCollection = database.getCollection(THREAD_META_COLLECTION);
        String metaId = THREAD_META_PREFIX + threadName;

        // Step 1: atomically look up an active (unreleased) thread
        Document activeThreadFilter = new Document("_id", metaId)
                .append(FIELD_IS_RELEASED, new Document("$ne", true));

        FindOneAndUpdateOptions findOptions = new FindOneAndUpdateOptions()
                .returnDocument(ReturnDocument.AFTER);

        // Atomic read, safe under concurrency
        Document existingDoc = threadMetaCollection.findOneAndUpdate(
                clientSession,
                activeThreadFilter,
                Updates.currentDate("_lastAccessed"), // only refreshes _lastAccessed; lets findOneAndUpdate act as an atomic read
                findOptions
        );

        if (existingDoc != null) {
            String threadId = existingDoc.getString(FIELD_THREAD_ID);
            if (threadId != null) {
                // An active thread exists: return it
                return threadId;
            }
        }

        // Step 2: no active thread, so atomically create one
        String newThreadId = UUID.randomUUID().toString();
        FindOneAndUpdateOptions upsertOptions = new FindOneAndUpdateOptions()
                .upsert(true)
                .returnDocument(ReturnDocument.AFTER);

        // Create the document if it is absent; leave an existing document unchanged
        Document createResult = threadMetaCollection.findOneAndUpdate(
                clientSession,
                Filters.eq("_id", metaId),
                Updates.combine(
                        Updates.setOnInsert(FIELD_THREAD_ID, newThreadId),
                        Updates.setOnInsert(FIELD_IS_RELEASED, false)
                ),
                upsertOptions
        );

        if (createResult != null) {
            Boolean isReleased = createResult.getBoolean(FIELD_IS_RELEASED, false);
            String existingThreadId = createResult.getString(FIELD_THREAD_ID);

            // The thread is valid: return it
            if (existingThreadId != null && !Boolean.TRUE.equals(isReleased)) {
                return existingThreadId;
            }

            // The thread was released: atomically switch it to a new thread ID
            if (Boolean.TRUE.equals(isReleased)) {
                Document updateResult = threadMetaCollection.findOneAndUpdate(
                        clientSession,
                        Filters.and(
                                Filters.eq("_id", metaId),
                                Filters.eq(FIELD_IS_RELEASED, true)
                        ),
                        Updates.combine(
                                Updates.set(FIELD_THREAD_ID, newThreadId),
                                Updates.set(FIELD_IS_RELEASED, false)
                        ),
                        new FindOneAndUpdateOptions().returnDocument(ReturnDocument.AFTER)
                );

                if (updateResult != null) {
                    return updateResult.getString(FIELD_THREAD_ID);
                }

                // Lost a concurrent update race: read the document again
                Document finalDoc = threadMetaCollection.find(clientSession, new BasicDBObject("_id", metaId)).first();
                if (finalDoc != null) {
                    String finalThreadId = finalDoc.getString(FIELD_THREAD_ID);
                    Boolean finalIsReleased = finalDoc.getBoolean(FIELD_IS_RELEASED, false);
                    if (finalThreadId != null && !Boolean.TRUE.equals(finalIsReleased)) {
                        return finalThreadId;
                    }
                }
            }
        }

        // Final fallback read
        Document finalDoc = threadMetaCollection.find(clientSession, new BasicDBObject("_id", metaId)).first();
        if (finalDoc != null) {
            String finalThreadId = finalDoc.getString(FIELD_THREAD_ID);
            if (finalThreadId != null) {
                return finalThreadId;
            }
        }

        return newThreadId;
    }

    /**
     * Gets the active thread ID.
     * Returns only unreleased threads; returns null if there is no active thread
     * @param threadName thread name
     * @param clientSession MongoDB session
     * @return active thread ID, or null
     */
    private String getActiveThreadId(String threadName, ClientSession clientSession) {
        MongoCollection<Document> threadMetaCollection = database.getCollection(THREAD_META_COLLECTION);
        String metaId = THREAD_META_PREFIX + threadName;

        Document metaDoc = threadMetaCollection.find(clientSession, new BasicDBObject("_id", metaId)).first();

        if (metaDoc != null) {
            String threadId = metaDoc.getString(FIELD_THREAD_ID);
            Boolean isReleased = metaDoc.getBoolean(FIELD_IS_RELEASED, false);

            if (threadId != null && !Boolean.TRUE.equals(isReleased)) {
                return threadId;
            }
        }

        // No active thread
        return null;
    }

    /**
     * Lists all checkpoints of the given thread
     * @param config run configuration (must contain a threadId)
     * @return checkpoint collection
     */
    @Override
    public Collection<Checkpoint> list(RunnableConfig config) {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("threadId must not be empty");
        }

        String threadName = threadNameOpt.get();
        // Start a session and a transaction
        ClientSession clientSession = this.client
                .startSession(ClientSessionOptions.builder().defaultTransactionOptions(txnOptions).build());
        clientSession.startTransaction();
        List<Checkpoint> checkpoints = null;
        try {
            // Get the active thread ID
            String threadId = getActiveThreadId(threadName, clientSession);
            if (threadId == null) {
                clientSession.commitTransaction();
                return Collections.emptyList();
            }

            // Load the checkpoint document
            MongoCollection<Document> collection = database.getCollection(CHECKPOINT_COLLECTION);
            String checkpointId = CHECKPOINT_PREFIX + threadId;
            Document document = collection.find(clientSession, new BasicDBObject("_id", checkpointId)).first();
            if (document == null) {
                clientSession.commitTransaction();
                return Collections.emptyList();
            }
            String checkpointsStr = document.getString(DOCUMENT_CONTENT_KEY);
            checkpoints = deserializeCheckpoints(checkpointsStr);
            clientSession.commitTransaction();
        }
        catch (Exception e) {
            clientSession.abortTransaction();
            throw new RuntimeException("Failed to list checkpoints", e);
        }
        finally {
            clientSession.close();
        }
        return checkpoints;
    }

    /**
     * Gets a single checkpoint.
     * Looks up by ID when one is given; otherwise returns the latest checkpoint
     * @param config run configuration
     * @return the checkpoint as an Optional
     */
    @Override
    public Optional<Checkpoint> get(RunnableConfig config) {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("threadId must not be empty");
        }

        String threadName = threadNameOpt.get();
        ClientSession clientSession = this.client
                .startSession(ClientSessionOptions.builder().defaultTransactionOptions(txnOptions).build());
        LinkedList<Checkpoint> checkpoints = null;
        try {
            clientSession.startTransaction();

            // Resolve the active thread ID
            String threadId = getActiveThreadId(threadName, clientSession);
            if (threadId == null) {
                clientSession.commitTransaction();
                return Optional.empty();
            }

            // Load the checkpoint document
            MongoCollection<Document> collection = database.getCollection(CHECKPOINT_COLLECTION);
            String checkpointId = CHECKPOINT_PREFIX + threadId;
            Document document = collection.find(clientSession, new BasicDBObject("_id", checkpointId)).first();
            if (document == null) {
                clientSession.commitTransaction();
                return Optional.empty();
            }
            String checkpointsStr = document.getString(DOCUMENT_CONTENT_KEY);
            checkpoints = deserializeCheckpoints(checkpointsStr);
            clientSession.commitTransaction();

            // Exact match by checkpoint ID
            if (config.checkPointId().isPresent()) {
                List<Checkpoint> finalCheckpoints = checkpoints;
                return config.checkPointId()
                        .flatMap(id -> finalCheckpoints.stream()
                                .filter(checkpoint -> checkpoint.getId().equals(id))
                                .findFirst());
            }
            // Otherwise return the latest checkpoint
            return getLast(checkpoints, config);
        }
        catch (Exception e) {
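            // The transaction may already be committed at this point; abort only if it is still active
            if (clientSession.hasActiveTransaction()) {
                clientSession.abortTransaction();
            }
            throw new RuntimeException("Failed to get checkpoint", e);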
        }
        finally {
            clientSession.close();
        }
    }

    /**
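     * Save or update a checkpoint.
     * If the config carries a checkpoint ID, the matching checkpoint is replaced; otherwise a new one is added.
     * @param config the runnable config
     * @param checkpoint the checkpoint to store
     * @return the updated runnable config
     * @throws Exception if saving fails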
     */
    @Override
    public RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("Thread ID must not be empty");
        }

        String threadName = threadNameOpt.get();
        ClientSession clientSession = this.client
                .startSession(ClientSessionOptions.builder().defaultTransactionOptions(txnOptions).build());
        clientSession.startTransaction();
        try {
            // Get or create the thread ID
            String threadId = getOrCreateThreadId(threadName, clientSession);

            MongoCollection<Document> collection = database.getCollection(CHECKPOINT_COLLECTION);
            String checkpointDocId = CHECKPOINT_PREFIX + threadId;
            Document document = collection.find(clientSession, new BasicDBObject("_id", checkpointDocId)).first();
            LinkedList<Checkpoint> checkpointLinkedList = null;

            // Document exists: update or replace
            if (Objects.nonNull(document)) {
                String checkpointsStr = document.getString(DOCUMENT_CONTENT_KEY);
                checkpointLinkedList = deserializeCheckpoints(checkpointsStr);
                LinkedList<Checkpoint> finalCheckpointLinkedList = checkpointLinkedList;
                // Checkpoint ID given: replace that entry in place
                if (config.checkPointId().isPresent()) {
                    String checkPointId = config.checkPointId().get();
                    int index = IntStream.range(0, checkpointLinkedList.size())
                            .filter(i -> finalCheckpointLinkedList.get(i).getId().equals(checkPointId))
                            .findFirst()
                            .orElseThrow(() -> (new NoSuchElementException(
                                    format("Checkpoint with id %s not found!", checkPointId))));
                    finalCheckpointLinkedList.set(index, checkpoint);
                    Document tempDocument = new Document().append("_id", checkpointDocId)
                            .append(DOCUMENT_CONTENT_KEY, serializeCheckpoints(finalCheckpointLinkedList));
                    collection.replaceOne(clientSession, Filters.eq("_id", checkpointDocId), tempDocument);
                    clientSession.commitTransaction();
                    return RunnableConfig.builder(config).checkPointId(checkpoint.getId()).build();
                }
            }

            // Add a new checkpoint at the head of the list
            if (checkpointLinkedList == null) {
                checkpointLinkedList = new LinkedList<>();
                checkpointLinkedList.push(checkpoint);
                Document tempDocument = new Document().append("_id", checkpointDocId)
                        .append(DOCUMENT_CONTENT_KEY, serializeCheckpoints(checkpointLinkedList));
                collection.insertOne(clientSession, tempDocument);
            }
            else {
                checkpointLinkedList.push(checkpoint);
                Document tempDocument = new Document().append("_id", checkpointDocId)
                        .append(DOCUMENT_CONTENT_KEY, serializeCheckpoints(checkpointLinkedList));
                ReplaceOptions opts = new ReplaceOptions().upsert(true);
                collection.replaceOne(clientSession, Filters.eq("_id", checkpointDocId), tempDocument, opts);
            }
            clientSession.commitTransaction();
        }
        catch (Exception e) {
            clientSession.abortTransaction();
            throw new RuntimeException("Failed to save checkpoint", e);
        }
        finally {
            clientSession.close();
        }
        return RunnableConfig.builder(config).checkPointId(checkpoint.getId()).build();
    }

    /**
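     * Release a thread's checkpoints.
     * Atomically marks the thread as released and returns its checkpoints as a Tag.
     * @param config the runnable config
     * @return the checkpoint tag
     * @throws Exception if releasing fails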
     */
    @Override
    public Tag release(RunnableConfig config) throws Exception {
        Optional<String> threadNameOpt = config.threadId();
        if (!threadNameOpt.isPresent()) {
            throw new IllegalArgumentException("Thread ID must not be empty");
        }

        String threadName = threadNameOpt.get();
        ClientSession clientSession = this.client
                .startSession(ClientSessionOptions.builder().defaultTransactionOptions(txnOptions).build());
        clientSession.startTransaction();
        try {
            MongoCollection<Document> threadMetaCollection = database.getCollection(THREAD_META_COLLECTION);
            String metaId = THREAD_META_PREFIX + threadName;

            // Look up the thread metadata
            Document metaDoc = threadMetaCollection.find(clientSession, new BasicDBObject("_id", metaId)).first();
            if (metaDoc == null) {
                clientSession.abortTransaction();
                throw new IllegalStateException("Thread does not exist: " + threadName);
            }

            String threadId = metaDoc.getString(FIELD_THREAD_ID);
            if (threadId == null) {
                clientSession.abortTransaction();
                throw new IllegalStateException("Thread does not exist: " + threadName);
            }

            // Atomically mark the thread as released (only an active thread can be released)
            Document releaseFilter = new Document("_id", metaId)
                    .append(FIELD_IS_RELEASED, false);

            Document updatedDoc = threadMetaCollection.findOneAndUpdate(
                    clientSession,
                    releaseFilter,
                    Updates.set(FIELD_IS_RELEASED, true),
                    new FindOneAndUpdateOptions().returnDocument(ReturnDocument.AFTER)
            );

            if (updatedDoc == null) {
                clientSession.abortTransaction();
                throw new IllegalStateException("Thread is not active or already released: " + threadName);
            }

            // Load the checkpoint data
            MongoCollection<Document> checkpointCollection = database.getCollection(CHECKPOINT_COLLECTION);
            String checkpointDocId = CHECKPOINT_PREFIX + threadId;
            Document checkpointDoc = checkpointCollection.find(clientSession, new BasicDBObject("_id", checkpointDocId))
                    .first();

            Collection<Checkpoint> checkpoints = Collections.emptyList();
            if (checkpointDoc != null) {
                String checkpointsStr = checkpointDoc.getString(DOCUMENT_CONTENT_KEY);
                if (checkpointsStr != null) {
                    checkpoints = deserializeCheckpoints(checkpointsStr);
                }
            }

            clientSession.commitTransaction();
            return new Tag(threadName, checkpoints);

        }
        catch (Exception e) {
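            // The branches above may have aborted already; a second abortTransaction() throws IllegalStateException
            if (clientSession.hasActiveTransaction()) {
                clientSession.abortTransaction();
            }
            throw new RuntimeException("Failed to release checkpoints", e);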
        }
        finally {
            clientSession.close();
        }
    }

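The put() and get() methods above keep every checkpoint of a thread in a single document, as a serialized list whose head is the newest entry (put() adds with push()). Below is a minimal in-memory sketch of only that list logic. It assumes a hypothetical Cp record in place of the real Checkpoint class, leaves out MongoDB, transactions and serialization, and assumes getLast() returns the head of the list. With a checkpoint ID, put() replaces the matching entry in place and fails if there is none; without one, it pushes the new checkpoint to the front. get() returns the entry with the requested ID, or the newest one when no ID is given.

```java
import java.util.LinkedList;
import java.util.NoSuchElementException;
import java.util.Optional;

// Sketch of the newest-first checkpoint list behind MongoSaver.put()/get().
// Cp is a hypothetical stand-in for Checkpoint; MongoDB, transactions and serialization are left out.
final class CheckpointListSketch {

    record Cp(String id, String nodeId) {}

    // One list per thread document; the head (index 0) is the newest checkpoint.
    private final LinkedList<Cp> checkpoints = new LinkedList<>();

    // Mirrors put(): replace by ID when an ID is given, otherwise push to the front.
    String put(Optional<String> checkPointId, Cp checkpoint) {
        if (checkPointId.isPresent()) {
            String wanted = checkPointId.get();
            int index = -1;
            for (int i = 0; i < checkpoints.size(); i++) {
                if (checkpoints.get(i).id().equals(wanted)) {
                    index = i;
                    break;
                }
            }
            if (index < 0) {
                throw new NoSuchElementException("Checkpoint with id " + wanted + " not found");
            }
            checkpoints.set(index, checkpoint);
        }
        else {
            checkpoints.push(checkpoint); // push() inserts at the head
        }
        // The real method returns a RunnableConfig carrying this ID; the sketch returns the ID itself.
        return checkpoint.id();
    }

    // Mirrors get(): exact match by ID, otherwise the newest entry (the head).
    Optional<Cp> get(Optional<String> checkPointId) {
        if (checkPointId.isPresent()) {
            String wanted = checkPointId.get();
            return checkpoints.stream().filter(cp -> cp.id().equals(wanted)).findFirst();
        }
        return Optional.ofNullable(checkpoints.peek());
    }

    int size() {
        return checkpoints.size();
    }

    public static void main(String[] args) {
        CheckpointListSketch thread = new CheckpointListSketch();
        thread.put(Optional.empty(), new Cp("c1", "agent_llm"));
        thread.put(Optional.empty(), new Cp("c2", "agent_tool"));
        // Overwrite c1 in place instead of adding a third entry.
        thread.put(Optional.of("c1"), new Cp("c1b", "agent_llm"));
        System.out.println(thread.size());                              // 2
        System.out.println(thread.get(Optional.empty()).get().id());    // c2
        System.out.println(thread.get(Optional.of("c1b")).isPresent()); // true
    }
}
```

Because the whole list lives in one document, every put() is a read-modify-write of that document; running it inside the session transaction is what keeps concurrent writers from silently overwriting each other's entries.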
4.3 VersionedMemorySaver

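Implements the BaseCheckpointSaverHasVersions interface and provides in-memory checkpoint storage with version management: it keeps a version history per thread ID and can look up historical checkpoints by version. It is an experimental feature.

Core fields:

java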
    /**
     * Checkpoint history grouped by thread ID.
     * Key: thread ID
     * Value: version map ordered by version number (key = version, value = released Tag)
     */
    private final Map<String, TreeMap<Integer, Tag>> checkpointsHistoryByThread = new HashMap<>();

    /**
     * Underlying non-versioned in-memory saver that performs the actual checkpoint reads and writes
     */
    private final MemorySaver noVersionSaver = new MemorySaver();

    /**
     * Reentrant lock that keeps data access thread-safe under concurrent use
     */
    private final ReentrantLock lock = new ReentrantLock();

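Its methods: the version queries declared by BaseCheckpointSaverHasVersions, plus list/get/put/release, which delegate to the internal MemorySaver while holding the lock:

java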
    /**
     * Get the checkpoint history for a thread ID.
     * @param threadId the thread ID
     * @return the version map if present, otherwise an empty Optional
     */
    private Optional<TreeMap<Integer, Tag>> getCheckpointHistoryByThread(String threadId) {
        return ofNullable(checkpointsHistoryByThread.get(threadId));
    }

    /**
     * Get the tag for a version number.
     * @param checkpointsHistory the checkpoint version history
     * @param threadVersion the thread version number
     * @return the tag, wrapped in an Optional
     */
    final Optional<Tag> getTagByVersion(TreeMap<Integer, Tag> checkpointsHistory, int threadVersion) {
        lock.lock();
        try {
            return ofNullable(checkpointsHistory.get(threadVersion));
        }
        finally {
            lock.unlock();
        }
    }

    /**
     * Get the checkpoints for a given thread ID and version.
     * @param threadId the thread ID
     * @param threadVersion the version number
     * @return the checkpoints
     * @throws IllegalArgumentException if the version does not exist
     */
    final Collection<Checkpoint> getCheckpointsByVersion(String threadId, int threadVersion) {
        lock.lock();
        try {
            return getCheckpointHistoryByThread(threadId)
                    .map(history -> history.get(threadVersion))
                    .map(Tag::checkpoints)
                    .orElseThrow(() -> new IllegalArgumentException(
                            format("Version [%s] of thread [%s] not found", threadVersion, threadId)));
        }
        finally {
            lock.unlock();
        }
    }

    /**
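     * Get all historical version numbers for a thread ID.
     * @param threadId the thread ID (may be null, in which case the default thread ID is used)
     * @return the version numbers, or an empty collection if there are none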
     */
    @Override
    public Collection<Integer> versionsByThreadId(String threadId) {
        return getCheckpointHistoryByThread(ofNullable(threadId).orElse(THREAD_ID_DEFAULT))
                .map(history -> (Collection<Integer>) history.keySet())
                .orElse(Collections.emptyList());
    }

    /**
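     * Get the latest version number for a thread ID.
     * @param threadId the thread ID (may be null, in which case the default thread ID is used)
     * @return the latest version number, wrapped in an Optional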
     */
    @Override
    public Optional<Integer> lastVersionByThreadId(String threadId) {
        return getCheckpointHistoryByThread(ofNullable(threadId).orElse(THREAD_ID_DEFAULT))
                .map(TreeMap::lastKey);
    }

    /**
     * List the checkpoints for the given config.
     * @param config the runnable config
     * @return the checkpoints
     * @throws RuntimeException if the query fails
     */
    @Override
    public Collection<Checkpoint> list(RunnableConfig config) {
        lock.lock();
        try {
            return noVersionSaver.list(config);
        }
        finally {
            lock.unlock();
        }
    }

    /**
     * Get a single checkpoint for the given config.
     * @param config the runnable config
     * @return the checkpoint, wrapped in an Optional
     */
    @Override
    public Optional<Checkpoint> get(RunnableConfig config) {
        lock.lock();
        try {
            return noVersionSaver.get(config);
        }
        finally {
            lock.unlock();
        }
    }

    /**
     * Save a checkpoint.
     * @param config the runnable config
     * @param checkpoint the checkpoint to store
     * @return the updated runnable config
     * @throws Exception if saving fails
     */
    @Override
    public RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception {
        lock.lock();
        try {
            return noVersionSaver.put(config, checkpoint);
        }
        finally {
            lock.unlock();
        }
    }

    /**
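     * Release the current checkpoints and record them as a new version.
     * The released checkpoints are stored in the version history under an incremented version number.
     * @param config the runnable config
     * @return the tag recorded for the new version
     * @throws Exception if the operation fails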
     */
    @Override
    public Tag release(RunnableConfig config) throws Exception {
        lock.lock();
        try {
            // Resolve the thread ID, falling back to the default
            String threadId = config.threadId().orElse(THREAD_ID_DEFAULT);
            // Delegate to the underlying saver to produce the tag
            Tag tag = noVersionSaver.release(config);
            // Get or create the version history for this thread
            TreeMap<Integer, Tag> checkpointsHistory = checkpointsHistoryByThread.computeIfAbsent(threadId, k -> new TreeMap<>());
            // Compute the new version number (last version + 1, starting at 1)
            int newVersion = ofNullable(checkpointsHistory.lastEntry()).map(Map.Entry::getKey).orElse(0) + 1;
            // Store the new version
            checkpointsHistory.put(newVersion, tag);
            return tag;
        }
        finally {
            lock.unlock();
        }
    }

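The versioning rule in release() above is small: the Tag returned by the underlying MemorySaver is stored in a per-thread TreeMap under version last + 1, starting at 1. The following self-contained sketch reproduces only that bookkeeping, with a hypothetical VersionTag record standing in for Tag and plain checkpoint ID strings instead of Checkpoint objects.

```java
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.concurrent.locks.ReentrantLock;

// Sketch of VersionedMemorySaver's per-thread version history.
// VersionTag is a hypothetical stand-in for Tag.
final class VersionHistorySketch {

    record VersionTag(String threadId, List<String> checkpointIds) {}

    private final Map<String, TreeMap<Integer, VersionTag>> historyByThread = new HashMap<>();
    private final ReentrantLock lock = new ReentrantLock();

    // Store a released tag under the next version number (1, 2, 3, ...).
    int recordRelease(String threadId, VersionTag tag) {
        lock.lock();
        try {
            TreeMap<Integer, VersionTag> history =
                    historyByThread.computeIfAbsent(threadId, k -> new TreeMap<>());
            int newVersion = Optional.ofNullable(history.lastEntry())
                    .map(Map.Entry::getKey)
                    .orElse(0) + 1;
            history.put(newVersion, tag);
            return newVersion;
        }
        finally {
            lock.unlock();
        }
    }

    // All versions of a thread in ascending order, as an immutable copy.
    Collection<Integer> versions(String threadId) {
        lock.lock();
        try {
            TreeMap<Integer, VersionTag> history = historyByThread.get(threadId);
            return history == null ? Collections.emptyList() : List.copyOf(history.keySet());
        }
        finally {
            lock.unlock();
        }
    }

    Optional<Integer> lastVersion(String threadId) {
        lock.lock();
        try {
            return Optional.ofNullable(historyByThread.get(threadId)).map(TreeMap::lastKey);
        }
        finally {
            lock.unlock();
        }
    }

    public static void main(String[] args) {
        VersionHistorySketch saver = new VersionHistorySketch();
        saver.recordRelease("t1", new VersionTag("t1", List.of("c1", "c2")));
        saver.recordRelease("t1", new VersionTag("t1", List.of("c3")));
        System.out.println(saver.versions("t1"));    // [1, 2]
        System.out.println(saver.lastVersion("t1")); // Optional[2]
        System.out.println(saver.lastVersion("t2")); // Optional.empty
    }
}
```

Unlike the excerpt above, this sketch also takes the lock in the read methods and returns a copy of the version numbers, so a caller never iterates a live key set while another thread is releasing; that is a choice made for the sketch, not necessarily how the library behaves.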
4.4 MemorySaver

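A lightweight checkpoint store backed by an in-memory HashMap. Access is guarded by a lock, so it is thread-safe. It suits single-node testing and development environments; all data is lost on restart.

Core fields:

java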
    /**
     * Checkpoints stored per thread ID.
     * Key: thread ID; value: that thread's checkpoint list (head insertion, newest first)
     */
    final Map<String, LinkedList<Checkpoint>> _checkpointsByThread = new HashMap<>();

    /**
     * Reentrant lock that keeps concurrent operations thread-safe
     */
    private final ReentrantLock _lock = new ReentrantLock();

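_checkpointsByThread is a plain HashMap of LinkedLists, and neither collection is thread-safe on its own, so every read and write has to happen while holding _lock. A minimal sketch of that pattern, using a hypothetical saver that stores only checkpoint ID strings: several worker threads push into the same thread ID concurrently, and the lock ensures no write is lost.

```java
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;

// Sketch of MemorySaver's storage layout: HashMap + LinkedList guarded by a ReentrantLock.
// Class and method names are hypothetical.
final class LockedMemorySaverSketch {

    // Key: thread ID; value: checkpoint IDs, newest first (head insertion).
    private final Map<String, LinkedList<String>> checkpointsByThread = new HashMap<>();
    private final ReentrantLock lock = new ReentrantLock();

    void put(String threadId, String checkpointId) {
        lock.lock();
        try {
            checkpointsByThread.computeIfAbsent(threadId, k -> new LinkedList<>()).push(checkpointId);
        }
        finally {
            lock.unlock();
        }
    }

    List<String> list(String threadId) {
        lock.lock();
        try {
            // Return a copy so the caller can iterate without holding the lock.
            return List.copyOf(checkpointsByThread.getOrDefault(threadId, new LinkedList<>()));
        }
        finally {
            lock.unlock();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        LockedMemorySaverSketch saver = new LockedMemorySaverSketch();
        ExecutorService pool = Executors.newFixedThreadPool(4);
        for (int i = 0; i < 1000; i++) {
            String id = "cp-" + i;
            pool.execute(() -> saver.put("thread-1", id));
        }
        pool.shutdown();
        pool.awaitTermination(10, TimeUnit.SECONDS);
        System.out.println(saver.list("thread-1").size()); // 1000
    }
}
```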
4.4.1 MysqlSaver

Extends MemorySaver and persists graph state to a MySQL database, with automatic table creation, transaction control, and state serialization/deserialization.

It relies on two tables:

  • GRAPH_THREAD (thread metadata)
  • GRAPH_CHECKPOINT (checkpoint data)

Thread metadata table:

sql
-- Thread metadata table
CREATE TABLE GRAPH_THREAD (
     thread_id VARCHAR(36) PRIMARY KEY,
     thread_name VARCHAR(255),
     is_released BOOLEAN DEFAULT FALSE NOT NULL
);
CREATE UNIQUE INDEX IDX_GRAPH_THREAD_NAME_RELEASED
     ON GRAPH_THREAD(thread_name, is_released);

Checkpoint data table:

sql
-- Checkpoint data table
CREATE TABLE GRAPH_CHECKPOINT (
     checkpoint_id VARCHAR(36) PRIMARY KEY,
     thread_id VARCHAR(36) NOT NULL,
     node_id VARCHAR(255),
     next_node_id VARCHAR(255),
     state_data JSON NOT NULL,
     saved_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,

     CONSTRAINT GRAPH_FK_THREAD
         FOREIGN KEY(thread_id)
         REFERENCES GRAPH_THREAD(thread_id)
         ON DELETE CASCADE
);

The DDL and DML statements are defined as constants:

java
    // ==================== DDL statements ====================
    /** Create the thread metadata table */
    private static final String CREATE_THREAD_TABLE = """
			CREATE TABLE IF NOT EXISTS GRAPH_THREAD (
			   thread_id VARCHAR(36) PRIMARY KEY,
			   thread_name VARCHAR(255),
			   is_released BOOLEAN DEFAULT FALSE NOT NULL
			)""";

    /** Create the unique index on (thread_name, is_released) */
    private static final String INDEX_THREAD_TABLE = """
			CREATE UNIQUE INDEX IDX_GRAPH_THREAD_NAME_RELEASED
			  ON GRAPH_THREAD(thread_name, is_released)
			""";

    /** Create the checkpoint table */
    private static final String CREATE_CHECKPOINT_TABLE = """
			CREATE TABLE IF NOT EXISTS GRAPH_CHECKPOINT (
			   checkpoint_id VARCHAR(36) PRIMARY KEY,
			   thread_id VARCHAR(36) NOT NULL,
			   node_id VARCHAR(255),
			   next_node_id VARCHAR(255),
			   state_data JSON NOT NULL,
			   saved_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			
			   CONSTRAINT GRAPH_FK_THREAD
			       FOREIGN KEY(thread_id)
			       REFERENCES GRAPH_THREAD(thread_id)
			       ON DELETE CASCADE
			)""";

    /** Drop the checkpoint table */
    private static final String DROP_CHECKPOINT_TABLE = "DROP TABLE IF EXISTS GRAPH_CHECKPOINT";
    /** Drop the thread metadata table */
    private static final String DROP_THREAD_TABLE = "DROP TABLE IF EXISTS GRAPH_THREAD";

    // ==================== DML statements ====================
    /** Insert a thread row (no-op if the key already exists) */
    private static final String UPSERT_THREAD = """
			INSERT INTO GRAPH_THREAD (thread_id, thread_name, is_released)
			VALUES (?, ?, FALSE)
			ON DUPLICATE KEY UPDATE thread_id = thread_id
			""";

    /** Insert a checkpoint linked to the thread's active (unreleased) record */
    private static final String INSERT_CHECKPOINT = """
			INSERT INTO GRAPH_CHECKPOINT(checkpoint_id, thread_id, node_id, next_node_id, state_data)
			SELECT ?, thread_id, ?, ?, ?
			FROM GRAPH_THREAD
			WHERE thread_name = ? AND is_released = FALSE
			""";

    /** Update a checkpoint */
    private static final String UPDATE_CHECKPOINT = """
			UPDATE GRAPH_CHECKPOINT
			SET
			  checkpoint_id = ?,
			  node_id = ?,
			  next_node_id = ?,
			  state_data = ?
			WHERE checkpoint_id = ?
			""";

    /** Select all checkpoints of the given active thread, newest first */
    private static final String SELECT_CHECKPOINTS = """
			SELECT
			  c.checkpoint_id,
			  c.node_id,
			  c.next_node_id,
			  JSON_UNQUOTE(JSON_EXTRACT(c.state_data, '$.binaryPayload')) AS base64_data
			FROM GRAPH_CHECKPOINT c
			  INNER JOIN GRAPH_THREAD t ON c.thread_id = t.thread_id
			WHERE t.thread_name = ? AND t.is_released != TRUE
			ORDER BY c.saved_at DESC
			""";

    /** Delete a checkpoint by ID */
    private static final String DELETE_CHECKPOINTS = """
			    DELETE FROM GRAPH_CHECKPOINT WHERE checkpoint_id = ?
			""";

    /** Mark a thread as released */
    private static final String RELEASE_THREAD = """
			UPDATE GRAPH_THREAD SET is_released = TRUE WHERE thread_name = ? AND is_released = FALSE
			""";

Core fields:

java
    /** Database data source */
    private final DataSource dataSource;
    /** Table creation strategy */
    private final CreateOption createOption;
    /** State serializer */
    private final StateSerializer stateSerializer;

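SELECT_CHECKPOINTS reads the state with JSON_UNQUOTE(JSON_EXTRACT(c.state_data, '$.binaryPayload')) and aliases it as base64_data, which suggests the JSON column wraps the serialized state as a Base64 string under a binaryPayload key. A minimal sketch of that round trip, assuming exactly that JSON shape (inferred from the query, not confirmed) and using raw bytes in place of the real StateSerializer output; the class and helper names are hypothetical.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

// Sketch of the {"binaryPayload": "<base64>"} wrapper that SELECT_CHECKPOINTS appears to expect in state_data.
// The JSON shape is inferred from the query; the helper names are hypothetical.
final class BinaryPayloadSketch {

    private static final String PREFIX = "{\"binaryPayload\":\"";
    private static final String SUFFIX = "\"}";

    // Wrap serialized state bytes into a JSON document for the state_data column.
    // Base64 output only contains A-Z, a-z, 0-9, '+', '/' and '=', so no JSON escaping is needed.
    static String toStateDataJson(byte[] serializedState) {
        return PREFIX + Base64.getEncoder().encodeToString(serializedState) + SUFFIX;
    }

    // Decode the base64_data column (what JSON_UNQUOTE(JSON_EXTRACT(...)) returns) back into bytes.
    static byte[] fromBase64Column(String base64Data) {
        return Base64.getDecoder().decode(base64Data);
    }

    public static void main(String[] args) {
        byte[] state = "messages=[hi]".getBytes(StandardCharsets.UTF_8);
        String json = toStateDataJson(state);
        System.out.println(json); // {"binaryPayload":"bWVzc2FnZXM9W2hpXQ=="}

        // Stand-in for the database: strip the wrapper the way JSON_EXTRACT + JSON_UNQUOTE would.
        String base64Data = json.substring(PREFIX.length(), json.length() - SUFFIX.length());
        System.out.println(new String(fromBase64Column(base64Data), StandardCharsets.UTF_8)); // messages=[hi]
    }
}
```

In the real saver the bytes would come from its StateSerializer; a UTF-8 string is used here only to keep the output readable.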
4.4.2 OracleSaver

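Extends MemorySaver and persists graph state to an Oracle database, with support for transactions, the JSON data type, and cascading deletes.

Core fields:

java
    /** Database data source */
    private final DataSource dataSource;
    /** Table creation strategy */
    private final CreateOption createOption;
    /** State serializer */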
    private final StateSerializer stateSerializer;

4.4.3 FileSystemSaver

Extends MemorySaver and persists checkpoint data to the local file system, with one separate file per thread ID.

File naming rule: thread-{threadId}.saver

Core fields:

java
    /**
     * Default file extension for checkpoint files
     */
    public static final String EXTENSION = ".saver";

    /**
     * Logger
     */
    private static final org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger(FileSystemSaver.class);

    /**
     * Root directory where checkpoint files are stored
     */
    private final Path targetFolder;

    /**
     * Serializer for checkpoints
     */
    private final Serializer<Checkpoint> serializer;

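Since every thread ID maps to its own thread-{threadId}.saver file under targetFolder, locating and rewriting a thread's checkpoints comes down to a path computation plus a whole-file write. A minimal sketch of that layout using only java.nio.file, assuming the checkpoints have already been serialized to a byte array (the real saver uses its Serializer<Checkpoint>, which is not reproduced here); the class and method names are hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;

// Sketch of the one-file-per-thread layout used by FileSystemSaver.
// Class and method names are hypothetical; serialization is assumed to happen elsewhere.
final class ThreadFileLayoutSketch {

    static final String EXTENSION = ".saver";

    private final Path targetFolder;

    ThreadFileLayoutSketch(Path targetFolder) {
        this.targetFolder = targetFolder;
    }

    // thread-{threadId}.saver inside the target folder.
    Path fileFor(String threadId) {
        return targetFolder.resolve("thread-" + threadId + EXTENSION);
    }

    // Overwrite the thread's file with the latest serialized checkpoints.
    void write(String threadId, byte[] serializedCheckpoints) throws IOException {
        Files.createDirectories(targetFolder);
        Files.write(fileFor(threadId), serializedCheckpoints); // CREATE + TRUNCATE_EXISTING + WRITE
    }

    // Empty when the thread has never been saved.
    Optional<byte[]> read(String threadId) throws IOException {
        Path file = fileFor(threadId);
        return Files.exists(file) ? Optional.of(Files.readAllBytes(file)) : Optional.empty();
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("checkpoints");
        ThreadFileLayoutSketch layout = new ThreadFileLayoutSketch(dir);
        layout.write("conversation-1", new byte[] {1, 2, 3});
        System.out.println(layout.fileFor("conversation-1").getFileName()); // thread-conversation-1.saver
        System.out.println(layout.read("conversation-1").get().length);     // 3
        System.out.println(layout.read("missing").isPresent());             // false
    }
}
```

The sketch puts the thread ID into the file name verbatim, as the naming rule describes; if thread IDs can come from untrusted input, rejecting path separators and ".." before resolving the path would be worth considering.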
4.4.4 PostgresSaver

Extends MemorySaver and persists graph checkpoints in a PostgreSQL database.

Core fields:

java
	/**
	 * Datasource used to create the store
	 */
	protected final DataSource datasource;

	private final StateSerializer stateSerializer;