A Look at Spring AI Alibaba's PlantUMLGenerator

This article takes a look at Spring AI Alibaba's PlantUMLGenerator.

DiagramGenerator

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/DiagramGenerator.java

public abstract class DiagramGenerator {

  public enum CallStyle {

    DEFAULT, START, END, CONDITIONAL, PARALLEL

  }

  public record Context(StringBuilder sb, String title, boolean printConditionalEdge, boolean isSubGraph) {

    static Builder builder() {
      return new Builder();
    }

    static public class Builder {

      String title;

      boolean printConditionalEdge;

      boolean IsSubGraph;

      private Builder() {
      }

      public Builder title(String title) {
        this.title = title;
        return this;
      }

      public Builder printConditionalEdge(boolean value) {
        this.printConditionalEdge = value;
        return this;
      }

      public Builder isSubGraph(boolean value) {
        this.IsSubGraph = value;
        return this;
      }

      public Context build() {
        return new Context(new StringBuilder(), title, printConditionalEdge, IsSubGraph);
      }

    }

    /**
     * Converts a given title string to snake_case format by replacing all
     * non-alphanumeric characters with underscores.
     * @return the snake_case formatted string
     */
    public String titleToSnakeCase() {
      return title.replaceAll("[^a-zA-Z0-9]", "_");
    }

    /**
     * Returns a string representation of this object by returning the string built in
     * {@link #sb}.
     * @return a string representation of this object.
     */
    @Override
    public String toString() {
      return sb.toString();
    }
  }

  /**
   * Appends a header to the output based on the provided context.
   * @param ctx The {@link Context} containing the information needed for appending the
   * header.
   */
  protected abstract void appendHeader(Context ctx);

  /**
   * Appends a footer to the content.
   * @param ctx Context object containing the necessary information.
   */
  protected abstract void appendFooter(Context ctx);

  /**
   * Appends an edge (arrow) from {@code from} to {@code to}, rendered according to the
   * given {@link CallStyle}.
   * @param ctx The current context whose buffer receives the output.
   * @param from The id of the source element.
   * @param to The id of the target element.
   * @param style The style of the edge (default, start, conditional or parallel).
   */
  protected abstract void call(Context ctx, String from, String to, CallStyle style);

  /**
   * Appends a labelled edge from {@code from} to {@code to}, rendered according to the
   * given {@link CallStyle}.
   * @param ctx The current context whose buffer receives the output.
   * @param from The id of the source element.
   * @param to The id of the target element.
   * @param description The label shown on the edge, e.g. the condition value.
   * @param style The style of the edge.
   */
  protected abstract void call(Context ctx, String from, String to, String description, CallStyle style);

  /**
   * Declares the condition element used when the entry edge from {@code START} is
   * conditional. It marks the start of a conditional section identified by
   * {@code name}.
   * @param ctx The context whose buffer receives the declaration.
   * @param name The name (alias) of the condition element to be declared.
   */
  protected abstract void declareConditionalStart(Context ctx, String name);

  /**
   * Declares a node in the specified context with the given name.
   * @param ctx the context in which to declare the node (not null)
   * @param name the name of the node to be declared (not null, not empty)
   */
  protected abstract void declareNode(Context ctx, String name);

  /**
   * Declares a conditional edge in the context with a specified ordinal.
   * @param ctx the context
   * @param ordinal the ordinal value
   */
  protected abstract void declareConditionalEdge(Context ctx, int ordinal);

  /**
   * Optionally comments out the next line written to the given context.
   * @param ctx The context in which the line is to be commented.
   * @param yesOrNo {@literal true} to comment out the next line, {@literal false} to
   * leave it as is.
   */
  protected abstract void commentLine(Context ctx, boolean yesOrNo);

  /**
   * Generate a textual representation of the given graph.
   * @param nodes the state graph nodes used to generate the context, which must not be
   * null
   * @param edges the state graph edges used to generate the context, which must not be
   * null
   * @param title The title of the graph.
   * @param printConditionalEdge Whether to print the conditional edge condition.
   * @return A string representation of the graph.
   */
  public final String generate(StateGraph.Nodes nodes, StateGraph.Edges edges, String title,
      boolean printConditionalEdge) {

    return generate(nodes, edges,
        Context.builder().title(title).isSubGraph(false).printConditionalEdge(printConditionalEdge).build())
      .toString();

  }

  /**
   * Generates a context based on the given state graph.
   * @param nodes the state graph nodes used to generate the context, which must not be
   * null
   * @param edges the state graph edges used to generate the context, which must not be
   * null
   * @param ctx the initial context, which must not be null
   * @return the generated context, which will not be null
   */
  protected final Context generate(StateGraph.Nodes nodes, StateGraph.Edges edges, Context ctx) {

    appendHeader(ctx);

    for (var n : nodes.elements) {

      if (n instanceof SubGraphNode subGraphNode) {

        @SuppressWarnings("unchecked")
        var subGraph = (StateGraph) subGraphNode.subGraph();
        Context subgraphCtx = generate(subGraph.nodes, subGraph.edges,
            Context.builder()
              .title(n.id())
              .printConditionalEdge(ctx.printConditionalEdge)
              .isSubGraph(true)
              .build());
        ctx.sb().append(subgraphCtx);
      }
      else {
        declareNode(ctx, n.id());
      }
    }

    final int[] conditionalEdgeCount = { 0 };

    edges.elements.stream()
      .filter(e -> !Objects.equals(e.sourceId(), START))
      .filter(e -> !e.isParallel())
      .forEach(e -> {
        if (e.target().value() != null) {
          conditionalEdgeCount[0] += 1;
          commentLine(ctx, !ctx.printConditionalEdge());
          declareConditionalEdge(ctx, conditionalEdgeCount[0]);
        }
      });

    var edgeStart = edges.elements.stream()
      .filter(e -> Objects.equals(e.sourceId(), START))
      .findFirst()
      .orElseThrow();
    if (edgeStart.isParallel()) {
      edgeStart.targets().forEach(target -> {
        call(ctx, START, target.id(), CallStyle.START);
      });
    }
    else if (edgeStart.target().id() != null) {
      call(ctx, START, edgeStart.target().id(), CallStyle.START);
    }
    else if (edgeStart.target().value() != null) {
      String conditionName = "startcondition";
      commentLine(ctx, !ctx.printConditionalEdge());
      declareConditionalStart(ctx, conditionName);
      edgeCondition(ctx, edgeStart.target().value(), START, conditionName);
    }

    conditionalEdgeCount[0] = 0; // reset

    edges.elements.stream().filter(e -> !Objects.equals(e.sourceId(), START)).forEach(v -> {

      if (v.isParallel()) {
        v.targets().forEach(target -> {
          call(ctx, v.sourceId(), target.id(), CallStyle.PARALLEL);
        });
      }
      else if (v.target().id() != null) {
        call(ctx, v.sourceId(), v.target().id(), CallStyle.DEFAULT);
      }
      else if (v.target().value() != null) {
        conditionalEdgeCount[0] += 1;
        String conditionName = format("condition%d", conditionalEdgeCount[0]);

        edgeCondition(ctx, v.targets().get(0).value(), v.sourceId(), conditionName);
      }
    });

    appendFooter(ctx);

    return ctx;

  }

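  /**
   * Renders a conditional edge: an edge from the source node to the condition element,
   * plus one labelled edge per mapping entry, both via the condition element and
   * directly from the source node. Depending on {@code printConditionalEdge}, one of the
   * two variants is active and the other is commented out.
   * @param ctx the current context
   * @param condition the edge condition whose mappings are rendered
   * @param k the id of the source node of the conditional edge
   * @param conditionName the name of the condition element
   */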
  private void edgeCondition(Context ctx, EdgeCondition condition, String k, String conditionName) {
    commentLine(ctx, !ctx.printConditionalEdge());
    call(ctx, k, conditionName, CallStyle.CONDITIONAL);

    condition.mappings().forEach((cond, to) -> {
      commentLine(ctx, !ctx.printConditionalEdge());
      call(ctx, conditionName, to, cond, CallStyle.CONDITIONAL);
      commentLine(ctx, ctx.printConditionalEdge());
      call(ctx, k, to, cond, CallStyle.CONDITIONAL);
    });
  }

}

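DiagramGenerator is an abstract base class for diagram generators. It declares the abstract hooks appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge and commentLine, and its final generate method drives them in a fixed order to build a textual representation of the graph from nodes, edges and ctx: first the header, then one declaration per node (a SubGraphNode is rendered recursively with isSubGraph set to true), then the condition elements of conditional edges, then the edges leaving START, then all remaining edges (plain, parallel or conditional), and finally the footer.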
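To make the template-method structure concrete, here is a minimal, self-contained sketch. It is not the library's code: MiniDiagramGenerator, MiniPlantUmlGenerator and SimpleEdge are hypothetical names, nodes are plain ids, edges are (from, to) pairs, and conditional, parallel and subgraph handling are left out. It only illustrates how a final generate method fixes the output order while subclasses supply the syntax.

```java
import java.util.List;

// Hypothetical sketch of the template-method structure behind DiagramGenerator.
// Not library code: nodes are plain ids and edges are (from, to) pairs; conditional,
// parallel and subgraph handling are intentionally omitted.
public class TemplateMethodSketch {

    static final String START = "__START__";

    enum CallStyle { DEFAULT, START }

    record SimpleEdge(String from, String to) { }

    static abstract class MiniDiagramGenerator {

        protected abstract void appendHeader(StringBuilder sb, String title);

        protected abstract void appendFooter(StringBuilder sb);

        protected abstract void declareNode(StringBuilder sb, String name);

        protected abstract void call(StringBuilder sb, String from, String to, CallStyle style);

        // The fixed skeleton: header, node declarations, edges leaving START,
        // remaining edges, footer. Subclasses only decide the syntax.
        public final String generate(List<String> nodes, List<SimpleEdge> edges, String title) {
            StringBuilder sb = new StringBuilder();
            appendHeader(sb, title);
            nodes.forEach(n -> declareNode(sb, n));
            edges.stream()
                .filter(e -> e.from().equals(START))
                .forEach(e -> call(sb, e.from(), e.to(), CallStyle.START));
            edges.stream()
                .filter(e -> !e.from().equals(START))
                .forEach(e -> call(sb, e.from(), e.to(), CallStyle.DEFAULT));
            appendFooter(sb);
            return sb.toString();
        }
    }

    // PlantUML-flavoured hooks, modelled on PlantUMLGenerator.
    static class MiniPlantUmlGenerator extends MiniDiagramGenerator {

        @Override
        protected void appendHeader(StringBuilder sb, String title) {
            sb.append(String.format("@startuml %s\n", title));
        }

        @Override
        protected void appendFooter(StringBuilder sb) {
            sb.append("@enduml\n");
        }

        @Override
        protected void declareNode(StringBuilder sb, String name) {
            sb.append(String.format("usecase \"%s\"<<Node>>\n", name));
        }

        @Override
        protected void call(StringBuilder sb, String from, String to, CallStyle style) {
            // Like PlantUMLGenerator, every non-conditional style uses a solid arrow.
            sb.append(String.format("\"%s\" -down-> \"%s\"\n", from, to));
        }
    }

    static String demo() {
        List<SimpleEdge> edges = List.of(
            new SimpleEdge("agent_1", "agent_2"),
            new SimpleEdge(START, "agent_1"),
            new SimpleEdge("agent_2", "__END__"));
        return new MiniPlantUmlGenerator().generate(List.of("agent_1", "agent_2"), edges, "demo");
    }

    public static void main(String[] args) {
        System.out.print(demo());
    }
}
```

The START edge is listed second in the input but printed first, mirroring how the real generate handles the edge from START before all other edges.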

PlantUMLGenerator

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/diagram/PlantUMLGenerator.java

public class PlantUMLGenerator extends DiagramGenerator {

  @Override
  protected void appendHeader(Context ctx) {

    if (ctx.isSubGraph()) {
      ctx.sb()
        .append(format("rectangle %s [ {{\ntitle \"%s\"\n", ctx.title(), ctx.title()))
        .append(format("circle \" \" as %s\n", START))
        .append(format("circle exit as %s\n", END));
    }
    else {
      ctx.sb()
        .append(format("@startuml %s\n", ctx.titleToSnakeCase()))
        .append("skinparam usecaseFontSize 14\n")
        .append("skinparam usecaseStereotypeFontSize 12\n")
        .append("skinparam hexagonFontSize 14\n")
        .append("skinparam hexagonStereotypeFontSize 12\n")
        .append(format("title \"%s\"\n", ctx.title()))
        .append("footer\n\n")
        .append("powered by spring-ai-alibaba\n")
        .append("end footer\n")
        .append(format("circle start<<input>> as %s\n", START))
        .append(format("circle stop as %s\n", END));
    }
  }

  @Override
  protected void appendFooter(Context ctx) {
    if (ctx.isSubGraph()) {
      ctx.sb().append("\n}} ]\n");
    }
    else {
      ctx.sb().append("@enduml\n");
    }
  }

  @Override
  protected void call(Context ctx, String from, String to, CallStyle style) {
    ctx.sb().append(switch (style) {
      case CONDITIONAL -> format("\"%s\" .down.> \"%s\"\n", from, to);
      default -> format("\"%s\" -down-> \"%s\"\n", from, to);
    });
  }

  @Override
  protected void call(Context ctx, String from, String to, String description, CallStyle style) {

    ctx.sb().append(switch (style) {
      case CONDITIONAL -> format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, description);
      default -> format("\"%s\" -down-> \"%s\": \"%s\"\n", from, to, description);
    });
  }

  @Override
  protected void declareConditionalStart(Context ctx, String name) {
    ctx.sb().append(format("hexagon \"check state\" as %s<<Condition>>\n", name));
  }

  @Override
  protected void declareNode(Context ctx, String name) {
    ctx.sb().append(format("usecase \"%s\"<<Node>>\n", name));
  }

  @Override
  protected void declareConditionalEdge(Context ctx, int ordinal) {
    ctx.sb().append(format("hexagon \"check state\" as condition%d<<Condition>>\n", ordinal));
  }

  @Override
  protected void commentLine(Context ctx, boolean yesOrNo) {
    if (yesOrNo)
      ctx.sb().append("'");
  }

}

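PlantUMLGenerator implements the abstract methods of DiagramGenerator using PlantUML syntax. The top-level graph opens with @startuml followed by skinparam, title and footer settings and closes with @enduml, while a subgraph is wrapped in a rectangle ... [ {{ ... }} ] block. START and END become circle elements, ordinary nodes become usecase elements with the <<Node>> stereotype, and condition checks become hexagon elements with the <<Condition>> stereotype. Conditional edges are drawn as dotted .down.> arrows and every other call style as a solid -down-> arrow. commentLine prefixes the next line with ', which PlantUML treats as a comment.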
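The least obvious part is how a conditional edge ends up in the output. The sketch below is a hypothetical standalone re-creation, for a single conditional edge, of how DiagramGenerator.edgeCondition combines PlantUMLGenerator's call and commentLine hooks; ConditionalEdgeSketch and its methods are illustrative names, not library API. In the real output the hexagon declaration and the edges are separated by other lines, and the mapping order follows whatever Map was passed to addConditionalEdges (a LinkedHashMap is used here to keep the order predictable).

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical standalone re-creation of how DiagramGenerator.edgeCondition drives
// PlantUMLGenerator's hooks for ONE conditional edge. Not library code.
public class ConditionalEdgeSketch {

    // Mirrors PlantUMLGenerator.commentLine: a leading ' makes the line a PlantUML comment.
    static void commentLine(StringBuilder sb, boolean yesOrNo) {
        if (yesOrNo) {
            sb.append("'");
        }
    }

    // Mirrors the CONDITIONAL branch of PlantUMLGenerator.call (dotted arrow, optional label).
    static void conditionalCall(StringBuilder sb, String from, String to, String description) {
        if (description == null) {
            sb.append(String.format("\"%s\" .down.> \"%s\"\n", from, to));
        } else {
            sb.append(String.format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, description));
        }
    }

    static String render(String source, Map<String, String> mappings, boolean printConditionalEdge) {
        StringBuilder sb = new StringBuilder();
        String conditionName = "condition1";

        // Declaration phase (declareConditionalEdge with ordinal 1).
        commentLine(sb, !printConditionalEdge);
        sb.append(String.format("hexagon \"check state\" as %s<<Condition>>\n", conditionName));

        // Edge phase (edgeCondition): source -> condition element ...
        commentLine(sb, !printConditionalEdge);
        conditionalCall(sb, source, conditionName, null);
        // ... then, per mapping, condition -> target and the direct source -> target edge;
        // exactly one of the two is commented out.
        mappings.forEach((cond, to) -> {
            commentLine(sb, !printConditionalEdge);
            conditionalCall(sb, conditionName, to, cond);
            commentLine(sb, printConditionalEdge);
            conditionalCall(sb, source, to, cond);
        });
        return sb.toString();
    }

    public static void main(String[] args) {
        Map<String, String> mappings = new LinkedHashMap<>();
        mappings.put("continue", "agent_2");
        mappings.put("end", "__END__");
        System.out.print(render("agent_1", mappings, true));
        System.out.println("----");
        System.out.print(render("agent_1", mappings, false));
    }
}
```

With printConditionalEdge set to true, the hexagon and the edges through it are active and the direct source-to-target edges are commented out; with false, the reverse holds, so the diagram shows labelled dotted arrows straight from the source node to each target.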

StateGraph

spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/StateGraph.java

/**
 * Represents a state graph with nodes and edges.
 *
 */
public class StateGraph {

  public static String END = "__END__";

  public static String START = "__START__";

  final Nodes nodes = new Nodes();

  final Edges edges = new Edges();

  private OverAllState overAllState;

  private String name;

  public OverAllState getOverAllState() {
    return overAllState;
  }

  public StateGraph setOverAllState(OverAllState overAllState) {
    this.overAllState = overAllState;
    return this;
  }

  private final PlainTextStateSerializer stateSerializer;

  //......

  /**
   * Instantiates a new State graph.
   * @param overAllState the overall state
   * @param plainTextStateSerializer the plain text state serializer
   */
  public StateGraph(OverAllState overAllState, PlainTextStateSerializer plainTextStateSerializer) {
    this.overAllState = overAllState;
    this.stateSerializer = plainTextStateSerializer;
  }

  public StateGraph(String name, OverAllState overAllState) {
    this.name = name;
    this.overAllState = overAllState;
    this.stateSerializer = new GsonSerializer();
  }

  /**
   * Instantiates a new State graph.
   * @param overAllState the overall state
   */
  public StateGraph(OverAllState overAllState) {
    this.overAllState = overAllState;
    this.stateSerializer = new GsonSerializer();
  }

  public StateGraph(String name, AgentStateFactory<OverAllState> factory) {
    this.name = name;
    this.overAllState = factory.apply(Map.of());
    this.stateSerializer = new GsonSerializer2(factory);
  }

  public StateGraph(AgentStateFactory<OverAllState> factory) {
    this.overAllState = factory.apply(Map.of());
    this.stateSerializer = new GsonSerializer2(factory);
  }

  /**
   * Instantiates a new State graph.
   */
  public StateGraph() {
    this.stateSerializer = new GsonSerializer();
  }

  public String getName() {
    return name;
  }

  /**
   * Key strategies map.
   * @return the map
   */
  public Map<String, KeyStrategy> keyStrategies() {
    return overAllState.keyStrategies();
  }

  /**
   * Gets state serializer.
   * @return the state serializer
   */
  public StateSerializer getStateSerializer() {
    return stateSerializer;
  }

  /**
   * Gets state factory.
   * @return the state factory
   */
  public final AgentStateFactory<OverAllState> getStateFactory() {
    return stateSerializer.stateFactory();
  }

  /**
   * Adds a node to the graph.
   * @param id the identifier of the node
   * @param action the action to be performed by the node
   * @return this
   * @throws GraphStateException if the node identifier is invalid or the node already
   * exists
   */
  public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
    return addNode(id, AsyncNodeActionWithConfig.of(action));
  }

  /**
   * Adds a node to the graph, wrapping the given action in a {@link Node}.
   * @param id the identifier of the node
   * @param actionWithConfig the action to be performed by the node
   * @return this
   * @throws GraphStateException if the node identifier is invalid or the node already
   * exists
   */
  public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
    Node node = new Node(id, (config) -> actionWithConfig);
    return addNode(id, node);
  }

  /**
   * Adds the given node to the graph.
   * @param id the identifier of the node
   * @param node the node to be added
   * @return this
   * @throws GraphStateException if the node identifier is invalid or the node already
   * exists
   */
  public StateGraph addNode(String id, Node node) throws GraphStateException {
    if (Objects.equals(node.id(), END)) {
      throw Errors.invalidNodeIdentifier.exception(END);
    }
    if (!Objects.equals(node.id(), id)) {
      throw Errors.invalidNodeIdentifier.exception(node.id(), id);
    }

    if (nodes.elements.contains(node)) {
      throw Errors.duplicateNodeError.exception(id);
    }

    nodes.elements.add(node);
    return this;
  }

  /**
   * Adds a compiled subgraph to the state graph by creating a node with the specified
   * identifier. This implies that the subgraph shares the same state as the parent graph.
   * @param id the identifier of the node representing the subgraph
   * @param subGraph the compiled subgraph to be added
   * @return this state graph instance
   * @throws GraphStateException if the node identifier is invalid or the node already
   * exists
   */
  public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
    if (Objects.equals(id, END)) {
      throw Errors.invalidNodeIdentifier.exception(END);
    }

    var node = new SubCompiledGraphNode(id, subGraph);

    if (nodes.elements.contains(node)) {
      throw Errors.duplicateNodeError.exception(id);
    }

    nodes.elements.add(node);
    return this;

  }

  /**
   * Adds a subgraph to the state graph by creating a node with the specified
   * identifier. This implies that the subgraph shares the same state as the parent graph:
   * key strategies the parent does not yet have are copied from the subgraph, and the
   * subgraph's overall state is replaced by the parent's.
   * @param id the identifier of the node representing the subgraph
   * @param subGraph the subgraph to be added. It will be compiled on compilation of the
   * parent
   * @return this state graph instance
   * @throws GraphStateException if the node identifier is invalid or the node already
   * exists
   */
  public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
    if (Objects.equals(id, END)) {
      throw Errors.invalidNodeIdentifier.exception(END);
    }

    subGraph.validateGraph();
    OverAllState subGraphOverAllState = subGraph.getOverAllState();
    OverAllState superOverAllState = getOverAllState();
    if (subGraphOverAllState != null) {
      Map<String, KeyStrategy> strategies = subGraphOverAllState.keyStrategies();
      for (Map.Entry<String, KeyStrategy> strategyEntry : strategies.entrySet()) {
        if (!superOverAllState.containStrategy(strategyEntry.getKey())) {
          superOverAllState.registerKeyAndStrategy(strategyEntry.getKey(), strategyEntry.getValue());
        }
      }
    }
    subGraph.setOverAllState(getOverAllState());

    var node = new SubStateGraphNode(id, subGraph);

    if (nodes.elements.contains(node)) {
      throw Errors.duplicateNodeError.exception(id);
    }

    nodes.elements.add(node);
    return this;
  }

  /**
   * Adds an edge to the graph. If an edge with the same source already exists, the new
   * target is appended to it, turning it into a parallel edge.
   * @param sourceId the identifier of the source node
   * @param targetId the identifier of the target node
   * @return this
   * @throws GraphStateException if the source identifier is {@code END}
   */
  public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
    if (Objects.equals(sourceId, END)) {
      throw Errors.invalidEdgeIdentifier.exception(END);
    }

    // if (Objects.equals(sourceId, START)) {
    // this.entryPoint = new EdgeValue<>(targetId);
    // return this;
    // }

    var newEdge = new Edge(sourceId, new EdgeValue(targetId));

    int index = edges.elements.indexOf(newEdge);
    if (index >= 0) {
      var newTargets = new ArrayList<>(edges.elements.get(index).targets());
      newTargets.add(newEdge.target());
      edges.elements.set(index, new Edge(sourceId, newTargets));
    }
    else {
      edges.elements.add(newEdge);
    }

    return this;
  }

  /**
   * Adds conditional edges to the graph.
   * @param sourceId the identifier of the source node
   * @param condition the condition to determine the target node
   * @param mappings the mappings of conditions to target nodes
   * @throws GraphStateException if the edge identifier is invalid, the mappings are
   * empty, or the edge already exists
   */
  public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
      throws GraphStateException {
    if (Objects.equals(sourceId, END)) {
      throw Errors.invalidEdgeIdentifier.exception(END);
    }
    if (mappings == null || mappings.isEmpty()) {
      throw Errors.edgeMappingIsEmpty.exception(sourceId);
    }

    var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));

    if (edges.elements.contains(newEdge)) {
      throw Errors.duplicateConditionalEdgeError.exception(sourceId);
    }
    else {
      edges.elements.add(newEdge);
    }
    return this;
  }

  void validateGraph() throws GraphStateException {
    var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);

    edgeStart.validate(nodes);

    for (Edge edge : edges.elements) {
      edge.validate(nodes);
    }

  }

  /**
   * Compiles the state graph into a compiled graph.
   * @param config the compile configuration
   * @return a compiled graph
   * @throws GraphStateException if there are errors related to the graph state
   */
  public CompiledGraph compile(CompileConfig config) throws GraphStateException {
    Objects.requireNonNull(config, "config cannot be null");

    validateGraph();

    return new CompiledGraph(this, config);
  }

  /**
   * Compiles the state graph into a compiled graph.
   * @return a compiled graph
   * @throws GraphStateException if there are errors related to the graph state
   */
  public CompiledGraph compile() throws GraphStateException {
    SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
    return compile(CompileConfig.builder()
      .plainTextStateSerializer(new JacksonSerializer())
      .saverConfig(saverConfig)
      .build());
  }

  /**
   * Generates a drawable graph representation of the state graph.
   * @param type the type of graph representation to generate
   * @param title the title of the graph
   * @param printConditionalEdges whether to print conditional edges
   * @return a diagram code of the state graph
   */
  public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {

    String content = type.generator.generate(nodes, edges, title, printConditionalEdges);

    return new GraphRepresentation(type, content);
  }

  /**
   * Generates a drawable graph representation of the state graph.
   * @param type the type of graph representation to generate
   * @param title the title of the graph
   * @return a diagram code of the state graph
   */
  public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {

    String content = type.generator.generate(nodes, edges, title, true);

    return new GraphRepresentation(type, content);
  }

  public GraphRepresentation getGraph(GraphRepresentation.Type type) {

    String content = type.generator.generate(nodes, edges, name, true);

    return new GraphRepresentation(type, content);
  }

  //......
}

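StateGraph provides methods such as addNode, addEdge and addConditionalEdges. Its getGraph method generates the diagram with the DiagramGenerator associated with the given GraphRepresentation.Type; the overloads without printConditionalEdges pass true, and getGraph(type) uses the graph's name as the title.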
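addEdge does not reject a second edge from the same source; it merges the new target into the existing edge, which is how parallel edges arise. generate then draws each target of such an edge with CallStyle.PARALLEL, which PlantUMLGenerator renders as a plain -down-> arrow. Below is a minimal sketch of that merge, assuming, as the indexOf lookup in addEdge implies, that edges compare equal when their sourceId matches; ParallelEdgeSketch and SimpleEdge are hypothetical names, not library classes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Hypothetical, simplified model of StateGraph.addEdge's merge behaviour. Not library code.
// ASSUMPTION: as the indexOf lookup in addEdge implies, edges with the same sourceId are
// treated as the same edge, so a second addEdge from that source extends its targets.
public class ParallelEdgeSketch {

    record SimpleEdge(String sourceId, List<String> targets) {
        // More than one target means the edge fans out in parallel.
        boolean isParallel() {
            return targets.size() > 1;
        }
    }

    private final List<SimpleEdge> edges = new ArrayList<>();

    ParallelEdgeSketch addEdge(String sourceId, String targetId) {
        for (int i = 0; i < edges.size(); i++) {
            SimpleEdge existing = edges.get(i);
            if (Objects.equals(existing.sourceId(), sourceId)) {
                // Same source: replace the edge with one that has the extra target.
                List<String> merged = new ArrayList<>(existing.targets());
                merged.add(targetId);
                edges.set(i, new SimpleEdge(sourceId, merged));
                return this;
            }
        }
        edges.add(new SimpleEdge(sourceId, List.of(targetId)));
        return this;
    }

    List<SimpleEdge> edges() {
        return edges;
    }

    public static void main(String[] args) {
        ParallelEdgeSketch graph = new ParallelEdgeSketch()
            .addEdge("__START__", "agent_1")
            .addEdge("agent_1", "agent_2")
            .addEdge("agent_1", "agent_3");
        for (SimpleEdge e : graph.edges()) {
            System.out.println(e.sourceId() + " -> " + e.targets() + (e.isParallel() ? " (parallel)" : ""));
        }
    }
}
```

Running main prints __START__ -> [agent_1] followed by agent_1 -> [agent_2, agent_3] (parallel): three addEdge calls collapse into two edges.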

Example

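  // node_async is statically imported from AsyncNodeAction; getOverAllState() is a helper
  // in the same test class (not shown here) that builds the OverAllState.
  @Test
  public void testGraph() throws GraphStateException {
    OverAllState overAllState = getOverAllState();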
    StateGraph workflow = new StateGraph(overAllState).addNode("agent_1", node_async(state -> {
          System.out.println("agent_1");
          return Map.of("messages", "message1");
        })).addNode("agent_2", node_async(state -> {
          System.out.println("agent_2");
          return Map.of("messages", new String[] { "message2" });
        })).addNode("agent_3", node_async(state -> {
          System.out.println("agent_3");
          List<String> messages = state.value("messages")
              .filter(List.class::isInstance)
              .map(List.class::cast)
              .orElse(new ArrayList<>());

          int steps = messages.size() + 1;

          return Map.of("messages", "message3", "steps", steps);
        }))
        .addEdge("agent_1", "agent_2")
        .addEdge("agent_2", "agent_3")
        .addEdge(StateGraph.START, "agent_1")
        .addEdge("agent_3", StateGraph.END);
    GraphRepresentation representation = workflow.getGraph(GraphRepresentation.Type.PLANTUML, "demo");
    System.out.println(representation.content());
  }

The output is as follows:

@startuml demo
skinparam usecaseFontSize 14
skinparam usecaseStereotypeFontSize 12
skinparam hexagonFontSize 14
skinparam hexagonStereotypeFontSize 12
title "demo"
footer

powered by spring-ai-alibaba
end footer
circle start<<input>> as __START__
circle stop as __END__
usecase "agent_1"<<Node>>
usecase "agent_2"<<Node>>
usecase "agent_3"<<Node>>
"__START__" -down-> "agent_1"
"agent_1" -down-> "agent_2"
"agent_2" -down-> "agent_3"
"agent_3" -down-> "__END__"
@enduml

Summary

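DiagramGenerator is an abstract base class for diagram generation. It declares the abstract methods appendHeader, appendFooter, call, declareConditionalStart, declareNode, declareConditionalEdge and commentLine, and provides a generate method that produces a textual representation of the graph from nodes, edges and ctx. PlantUMLGenerator extends DiagramGenerator and implements these abstract methods using PlantUML syntax. StateGraph.getGraph selects the generator through GraphRepresentation.Type, so passing GraphRepresentation.Type.PLANTUML yields PlantUML source like the example output above.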
