序
本文主要研究一下Spring AI Alibaba的PlantUMLGenerator
DiagramGenerator
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/DiagramGenerator.java
java
/**
 * Base class for generators that turn a state graph into the source text of a diagram.
 * <p>
 * {@link #generate(StateGraph.Nodes, StateGraph.Edges, Context)} walks the nodes and
 * edges in a fixed order and delegates every piece of concrete syntax to the abstract
 * hooks declared here. All output accumulates in the {@link StringBuilder} held by the
 * {@link Context}.
 */
public abstract class DiagramGenerator {

    /**
     * Style hint handed to the {@code call} hooks so that an implementation can draw
     * different kinds of connections differently.
     */
    public enum CallStyle {

        DEFAULT, START, END, CONDITIONAL, PARALLEL

    }

    /**
     * Carries the output buffer and the rendering options of one generation run.
     *
     * @param sb buffer that receives the generated diagram text
     * @param title title of the (sub)graph; may be {@code null}
     * @param printConditionalEdge whether the condition elements are rendered
     * @param isSubGraph whether this context renders a nested subgraph
     */
    public record Context(StringBuilder sb, String title, boolean printConditionalEdge, boolean isSubGraph) {

        static Builder builder() {
            return new Builder();
        }

        /**
         * Fluent builder for {@link Context}; each built context gets a fresh, empty
         * {@link StringBuilder}.
         */
        static public class Builder {

            String title;

            boolean printConditionalEdge;

            boolean IsSubGraph;

            private Builder() {
            }

            public Builder title(String title) {
                this.title = title;
                return this;
            }

            public Builder printConditionalEdge(boolean value) {
                this.printConditionalEdge = value;
                return this;
            }

            public Builder isSubGraph(boolean value) {
                this.IsSubGraph = value;
                return this;
            }

            public Context build() {
                return new Context(new StringBuilder(), title, printConditionalEdge, IsSubGraph);
            }

        }

        /**
         * Returns the title with every character outside {@code [a-zA-Z0-9]} replaced
         * by an underscore, so it can be used as an identifier in diagram source.
         * @return the sanitized title, or an empty string when no title was set
         */
        public String titleToSnakeCase() {
            // The title is optional (a graph may have no name), so guard against null
            // instead of failing with a NullPointerException.
            return (title == null) ? "" : title.replaceAll("[^a-zA-Z0-9]", "_");
        }

        /**
         * Returns the diagram text accumulated so far in {@link #sb}.
         * @return the current content of the buffer
         */
        @Override
        public String toString() {
            return sb.toString();
        }

    }

    /**
     * Appends the opening part of the diagram (or of a nested subgraph).
     * @param ctx the context whose buffer receives the header
     */
    protected abstract void appendHeader(Context ctx);

    /**
     * Appends the closing part of the diagram (or of a nested subgraph).
     * @param ctx the context whose buffer receives the footer
     */
    protected abstract void appendFooter(Context ctx);

    /**
     * Appends a connection between two diagram elements.
     * @param ctx the context whose buffer receives the connection
     * @param from identifier of the element the connection starts at
     * @param to identifier of the element the connection ends at
     * @param style how the connection should be drawn
     */
    protected abstract void call(Context ctx, String from, String to, CallStyle style);

    /**
     * Appends a labelled connection between two diagram elements.
     * @param ctx the context whose buffer receives the connection
     * @param from identifier of the element the connection starts at
     * @param to identifier of the element the connection ends at
     * @param description label attached to the connection
     * @param style how the connection should be drawn
     */
    protected abstract void call(Context ctx, String from, String to, String description, CallStyle style);

    /**
     * Declares the condition element used when the edge leaving the start node is
     * conditional.
     * @param ctx the context whose buffer receives the declaration
     * @param name identifier of the condition element
     */
    protected abstract void declareConditionalStart(Context ctx, String name);

    /**
     * Declares a node of the graph.
     * @param ctx the context whose buffer receives the declaration
     * @param name identifier of the node
     */
    protected abstract void declareNode(Context ctx, String name);

    /**
     * Declares the condition element of a conditional edge.
     * @param ctx the context whose buffer receives the declaration
     * @param ordinal 1-based position of the conditional edge among the conditional
     * edges of the graph; {@code generate} refers to it as {@code condition<ordinal>}
     */
    protected abstract void declareConditionalEdge(Context ctx, int ordinal);

    /**
     * Optionally turns the line emitted next into a comment.
     * @param ctx the context whose buffer receives the comment marker
     * @param yesOrNo {@code true} to comment the following line out, {@code false} to
     * leave it active
     */
    protected abstract void commentLine(Context ctx, boolean yesOrNo);

    /**
     * Generates the textual representation of a graph.
     * @param nodes the nodes of the graph, must not be null
     * @param edges the edges of the graph, must not be null
     * @param title the title of the graph
     * @param printConditionalEdge whether the condition elements are rendered
     * @return the diagram source text
     */
    public final String generate(StateGraph.Nodes nodes, StateGraph.Edges edges, String title,
            boolean printConditionalEdge) {
        Context ctx = Context.builder()
            .title(title)
            .isSubGraph(false)
            .printConditionalEdge(printConditionalEdge)
            .build();
        return generate(nodes, edges, ctx).toString();
    }

    /**
     * Renders the graph into the given context: header, node declarations (subgraphs
     * are rendered recursively), condition declarations, the edge(s) leaving the start
     * node, all remaining edges, and finally the footer.
     * @param nodes the nodes of the graph, must not be null
     * @param edges the edges of the graph, must not be null
     * @param ctx the context to render into, must not be null
     * @return {@code ctx}, now containing the rendered graph
     * @throws java.util.NoSuchElementException if no edge starts at {@code START}
     */
    protected final Context generate(StateGraph.Nodes nodes, StateGraph.Edges edges, Context ctx) {

        appendHeader(ctx);

        // 1. Nodes: a subgraph node is rendered with its own context and embedded as text.
        for (var n : nodes.elements) {
            if (n instanceof SubGraphNode subGraphNode) {
                var subGraph = (StateGraph) subGraphNode.subGraph();
                Context subgraphCtx = generate(subGraph.nodes, subGraph.edges,
                        Context.builder()
                            .title(n.id())
                            .printConditionalEdge(ctx.printConditionalEdge())
                            .isSubGraph(true)
                            .build());
                ctx.sb().append(subgraphCtx);
            }
            else {
                declareNode(ctx, n.id());
            }
        }

        // 2. Declare one condition element per conditional edge that does not start at
        // START. The array is a mutable counter usable from inside the lambdas.
        final int[] conditionalEdgeCount = { 0 };

        edges.elements.stream()
            .filter(e -> !Objects.equals(e.sourceId(), START))
            .filter(e -> !e.isParallel())
            .forEach(e -> {
                if (e.target().value() != null) {
                    conditionalEdgeCount[0] += 1;
                    // Commented out when conditional edges are not printed.
                    commentLine(ctx, !ctx.printConditionalEdge());
                    declareConditionalEdge(ctx, conditionalEdgeCount[0]);
                }
            });

        // 3. The edge leaving START: parallel, direct or conditional.
        var edgeStart = edges.elements.stream()
            .filter(e -> Objects.equals(e.sourceId(), START))
            .findFirst()
            .orElseThrow();

        if (edgeStart.isParallel()) {
            edgeStart.targets().forEach(target -> call(ctx, START, target.id(), CallStyle.START));
        }
        else if (edgeStart.target().id() != null) {
            call(ctx, START, edgeStart.target().id(), CallStyle.START);
        }
        else if (edgeStart.target().value() != null) {
            String conditionName = "startcondition";
            commentLine(ctx, !ctx.printConditionalEdge());
            declareConditionalStart(ctx, conditionName);
            edgeCondition(ctx, edgeStart.target().value(), START, conditionName);
        }

        // 4. All other edges. The counter restarts so that the generated names match
        // the ordinals used for the declarations in step 2 (same iteration order).
        conditionalEdgeCount[0] = 0;

        edges.elements.stream().filter(e -> !Objects.equals(e.sourceId(), START)).forEach(v -> {
            if (v.isParallel()) {
                v.targets().forEach(target -> call(ctx, v.sourceId(), target.id(), CallStyle.PARALLEL));
            }
            else if (v.target().id() != null) {
                call(ctx, v.sourceId(), v.target().id(), CallStyle.DEFAULT);
            }
            else if (v.target().value() != null) {
                conditionalEdgeCount[0] += 1;
                String conditionName = format("condition%d", conditionalEdgeCount[0]);
                edgeCondition(ctx, v.targets().get(0).value(), v.sourceId(), conditionName);
            }
        });

        appendFooter(ctx);

        return ctx;
    }

    /**
     * Renders a conditional edge. Two alternative drawings are emitted and exactly one
     * of them is commented out, depending on {@link Context#printConditionalEdge()}:
     * either the source is connected to the condition element which in turn is
     * connected to every mapped target, or the source is connected to every mapped
     * target directly.
     * @param ctx the context to render into
     * @param condition the condition whose mappings (label to target id) are rendered
     * @param k identifier of the source element of the edge
     * @param conditionName identifier of the condition element
     */
    private void edgeCondition(Context ctx, EdgeCondition condition, String k, String conditionName) {
        commentLine(ctx, !ctx.printConditionalEdge());
        call(ctx, k, conditionName, CallStyle.CONDITIONAL);

        condition.mappings().forEach((cond, to) -> {
            // Variant through the condition element.
            commentLine(ctx, !ctx.printConditionalEdge());
            call(ctx, conditionName, to, cond, CallStyle.CONDITIONAL);
            // Variant that bypasses the condition element.
            commentLine(ctx, ctx.printConditionalEdge());
            call(ctx, k, to, cond, CallStyle.CONDITIONAL);
        });
    }

}
DiagramGenerator是流程图生成器的抽象基类，它声明了appendHeader、appendFooter、call、declareConditionalStart、declareNode、declareConditionalEdge、commentLine等抽象方法，并提供了generate方法，根据nodes、edges、ctx生成图的文本表示。
PlantUMLGenerator
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/diagram/PlantUMLGenerator.java
java
/**
 * {@link DiagramGenerator} that emits PlantUML source.
 */
public class PlantUMLGenerator extends DiagramGenerator {

    /**
     * Opens the diagram. A subgraph is opened as a titled {@code rectangle} with its
     * own start and exit circles; the top-level graph gets the {@code @startuml}
     * preamble, skin parameters, title, footer and the start/stop circles.
     */
    @Override
    protected void appendHeader(Context ctx) {
        final StringBuilder out = ctx.sb();

        if (!ctx.isSubGraph()) {
            out.append(format("@startuml %s\n", ctx.titleToSnakeCase()));
            out.append("skinparam usecaseFontSize 14\n");
            out.append("skinparam usecaseStereotypeFontSize 12\n");
            out.append("skinparam hexagonFontSize 14\n");
            out.append("skinparam hexagonStereotypeFontSize 12\n");
            out.append(format("title \"%s\"\n", ctx.title()));
            out.append("footer\n\n");
            out.append("powered by spring-ai-alibaba\n");
            out.append("end footer\n");
            out.append(format("circle start<<input>> as %s\n", START));
            out.append(format("circle stop as %s\n", END));
            return;
        }

        out.append(format("rectangle %s [ {{\ntitle \"%s\"\n", ctx.title(), ctx.title()));
        out.append(format("circle \" \" as %s\n", START));
        out.append(format("circle exit as %s\n", END));
    }

    /**
     * Closes what {@link #appendHeader(Context)} opened: the rectangle of a subgraph
     * or the {@code @enduml} of the top-level graph.
     */
    @Override
    protected void appendFooter(Context ctx) {
        final String closing = ctx.isSubGraph() ? "\n}} ]\n" : "@enduml\n";
        ctx.sb().append(closing);
    }

    /**
     * Appends an arrow from {@code from} to {@code to}; conditional connections are
     * dotted, every other style is solid.
     */
    @Override
    protected void call(Context ctx, String from, String to, CallStyle style) {
        final StringBuilder out = ctx.sb();
        final String arrowLine = switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\"\n", from, to);
            default -> format("\"%s\" -down-> \"%s\"\n", from, to);
        };
        out.append(arrowLine);
    }

    /**
     * Appends an arrow from {@code from} to {@code to} labelled with
     * {@code description}; conditional connections are dotted, every other style is
     * solid.
     */
    @Override
    protected void call(Context ctx, String from, String to, String description, CallStyle style) {
        final StringBuilder out = ctx.sb();
        final String arrowLine = switch (style) {
            case CONDITIONAL -> format("\"%s\" .down.> \"%s\": \"%s\"\n", from, to, description);
            default -> format("\"%s\" -down-> \"%s\": \"%s\"\n", from, to, description);
        };
        out.append(arrowLine);
    }

    /** Declares the hexagon that represents the condition evaluated at the start. */
    @Override
    protected void declareConditionalStart(Context ctx, String name) {
        final String declaration = format("hexagon \"check state\" as %s<<Condition>>\n", name);
        ctx.sb().append(declaration);
    }

    /** Declares a graph node as a {@code usecase} element. */
    @Override
    protected void declareNode(Context ctx, String name) {
        final String declaration = format("usecase \"%s\"<<Node>>\n", name);
        ctx.sb().append(declaration);
    }

    /** Declares the hexagon named {@code condition<ordinal>} of a conditional edge. */
    @Override
    protected void declareConditionalEdge(Context ctx, int ordinal) {
        final String declaration = format("hexagon \"check state\" as condition%d<<Condition>>\n", ordinal);
        ctx.sb().append(declaration);
    }

    /**
     * Writes a single quote, which starts a PlantUML line comment, when
     * {@code yesOrNo} is {@code true}; writes nothing otherwise.
     */
    @Override
    protected void commentLine(Context ctx, boolean yesOrNo) {
        if (!yesOrNo) {
            return;
        }
        ctx.sb().append("'");
    }

}
PlantUMLGenerator实现了DiagramGenerator的抽象方法
StateGraph
spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/StateGraph.java
java
/**
 * Represents a state graph with nodes and edges.
 * <p>
 * Nodes and edges are registered through the fluent {@code addNode},
 * {@code addEdge} and {@code addConditionalEdges} methods; the graph is then either
 * compiled with {@link #compile()} or rendered as diagram source with
 * {@code getGraph}.
 */
public class StateGraph {

    /** Reserved identifier of the end of the graph; it cannot be used as a node id or edge source. */
    public static final String END = "__END__";

    /** Reserved identifier of the entry point of the graph. */
    public static final String START = "__START__";

    final Nodes nodes = new Nodes();

    final Edges edges = new Edges();

    private final PlainTextStateSerializer stateSerializer;

    private OverAllState overAllState;

    private String name;

    public OverAllState getOverAllState() {
        return overAllState;
    }

    public StateGraph setOverAllState(OverAllState overAllState) {
        this.overAllState = overAllState;
        return this;
    }

    //......

    /**
     * Instantiates a new State graph.
     * @param overAllState the over all state
     * @param plainTextStateSerializer the plain text state serializer
     */
    public StateGraph(OverAllState overAllState, PlainTextStateSerializer plainTextStateSerializer) {
        this.overAllState = overAllState;
        this.stateSerializer = plainTextStateSerializer;
    }

    /**
     * Instantiates a new named State graph that uses a {@code GsonSerializer}.
     * @param name the name of the graph
     * @param overAllState the over all state
     */
    public StateGraph(String name, OverAllState overAllState) {
        this.name = name;
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    /**
     * Instantiates a new State graph.
     * @param overAllState the over all state
     */
    public StateGraph(OverAllState overAllState) {
        this.overAllState = overAllState;
        this.stateSerializer = new GsonSerializer();
    }

    /**
     * Instantiates a new named State graph whose state is created by the factory from
     * an empty map.
     * @param name the name of the graph
     * @param factory the factory that creates the over all state
     */
    public StateGraph(String name, AgentStateFactory<OverAllState> factory) {
        this.name = name;
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    /**
     * Instantiates a new State graph whose state is created by the factory from an
     * empty map.
     * @param factory the factory that creates the over all state
     */
    public StateGraph(AgentStateFactory<OverAllState> factory) {
        this.overAllState = factory.apply(Map.of());
        this.stateSerializer = new GsonSerializer2(factory);
    }

    /**
     * Instantiates a new State graph without a name and without an over all state.
     */
    public StateGraph() {
        this.stateSerializer = new GsonSerializer();
    }

    /**
     * Gets the name of the graph.
     * @return the name, or {@code null} when the graph was created without one
     */
    public String getName() {
        return name;
    }

    /**
     * Key strategies map.
     * @return the map
     */
    public Map<String, KeyStrategy> keyStrategies() {
        return overAllState.keyStrategies();
    }

    /**
     * Gets state serializer.
     * @return the state serializer
     */
    public StateSerializer getStateSerializer() {
        return stateSerializer;
    }

    /**
     * Gets state factory.
     * @return the state factory
     */
    public final AgentStateFactory<OverAllState> getStateFactory() {
        return stateSerializer.stateFactory();
    }

    /**
     * Adds a node to the graph.
     * @param id the identifier of the node
     * @param action the action to be performed by the node
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
        return addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * Adds a node to the graph.
     * @param id the identifier of the node
     * @param actionWithConfig the action to be performed by the node
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
        Node node = new Node(id, (config) -> actionWithConfig);
        return addNode(id, node);
    }

    /**
     * Adds a node to the graph.
     * @param id the identifier of the node
     * @param node the node to be added; its own id must equal {@code id}
     * @return this
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, Node node) throws GraphStateException {
        if (Objects.equals(node.id(), END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        if (!Objects.equals(node.id(), id)) {
            throw Errors.invalidNodeIdentifier.exception(node.id(), id);
        }
        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that Subgraph share the same state with parent graph
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the compiled subgraph to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        var node = new SubCompiledGraphNode(id, subGraph);
        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that Subgraph share the same state with parent graph
     * <p>
     * The subgraph is validated first. Key strategies of the subgraph's state that the
     * parent's state does not contain yet are registered on the parent, and the
     * subgraph's state is then replaced by the parent's. These changes are applied
     * before the duplicate-node check.
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the subgraph to be added. it will be compiled on compilation of the
     * parent
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid, the subgraph is not
     * valid or the node already exists
     */
    public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
        subGraph.validateGraph();
        OverAllState subGraphOverAllState = subGraph.getOverAllState();
        OverAllState superOverAllState = getOverAllState();
        if (subGraphOverAllState != null) {
            // NOTE(review): the parent state is dereferenced here without a null check;
            // confirm that a parent created without an over all state is not expected.
            Map<String, KeyStrategy> strategies = subGraphOverAllState.keyStrategies();
            for (Map.Entry<String, KeyStrategy> strategyEntry : strategies.entrySet()) {
                if (!superOverAllState.containStrategy(strategyEntry.getKey())) {
                    superOverAllState.registerKeyAndStrategy(strategyEntry.getKey(), strategyEntry.getValue());
                }
            }
        }
        subGraph.setOverAllState(getOverAllState());

        var node = new SubStateGraphNode(id, subGraph);

        if (nodes.elements.contains(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }

        nodes.elements.add(node);
        return this;
    }

    /**
     * Adds an edge to the graph.
     * <p>
     * When an edge equal to the new one is already registered, the new target is
     * appended to the targets of the existing edge instead of raising an error.
     * NOTE(review): {@code Edge} equality presumably depends on the source id only -
     * confirm against {@code Edge.equals}.
     * @param sourceId the identifier of the source node
     * @param targetId the identifier of the target node
     * @return this
     * @throws GraphStateException if the source identifier is {@code END}
     */
    public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }

        var newEdge = new Edge(sourceId, new EdgeValue(targetId));

        int index = edges.elements.indexOf(newEdge);
        if (index >= 0) {
            // Merge into the existing edge: copy its targets and add the new one.
            var newTargets = new ArrayList<>(edges.elements.get(index).targets());
            newTargets.add(newEdge.target());
            edges.elements.set(index, new Edge(sourceId, newTargets));
        }
        else {
            edges.elements.add(newEdge);
        }

        return this;
    }

    /**
     * Adds conditional edges to the graph.
     * @param sourceId the identifier of the source node
     * @param condition the condition to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @return this
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
            throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(sourceId);
        }

        var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));

        if (edges.elements.contains(newEdge)) {
            throw Errors.duplicateConditionalEdgeError.exception(sourceId);
        }
        edges.elements.add(newEdge);

        return this;
    }

    /**
     * Checks that an edge leaving {@code START} exists and validates every edge
     * against the registered nodes.
     * @throws GraphStateException if the entry point is missing or an edge is invalid
     */
    void validateGraph() throws GraphStateException {
        var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);

        edgeStart.validate(nodes);

        for (Edge edge : edges.elements) {
            edge.validate(nodes);
        }
    }

    /**
     * Compiles the state graph into a compiled graph.
     * @param config the compile configuration
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile(CompileConfig config) throws GraphStateException {
        Objects.requireNonNull(config, "config cannot be null");

        validateGraph();

        return new CompiledGraph(this, config);
    }

    /**
     * Compiles the state graph into a compiled graph, using a configuration with a
     * {@code MemorySaver} registered under {@code SaverConstant.MEMORY} and a
     * {@code JacksonSerializer} as plain text state serializer.
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile() throws GraphStateException {
        SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
        return compile(CompileConfig.builder()
            .plainTextStateSerializer(new JacksonSerializer())
            .saverConfig(saverConfig)
            .build());
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @param printConditionalEdges whether to print conditional edges
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
        String content = type.generator.generate(nodes, edges, title, printConditionalEdges);

        return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph with conditional
     * edges printed.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        return getGraph(type, title, true);
    }

    /**
     * Generates a drawable graph representation of the state graph, titled with the
     * name of the graph and with conditional edges printed.
     * @param type the type of graph representation to generate
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
        return getGraph(type, name, true);
    }

    //......

}
StateGraph提供了addNode、addEdge、addConditionalEdges等方法,其中getGraph方法根据指定GraphRepresentation.Type的DiagramGenerator来生成状态图
示例
java
/**
 * Builds a linear three-node workflow (START -> agent_1 -> agent_2 -> agent_3 -> END)
 * and prints its PlantUML representation titled "demo".
 */
@Test
public void testGraph() throws GraphStateException {
    OverAllState sharedState = getOverAllState();
    StateGraph workflow = new StateGraph(sharedState);

    // Nodes: every addNode call returns the same graph, so the calls need no chaining.
    workflow.addNode("agent_1", node_async(state -> {
        System.out.println("agent_1");
        return Map.of("messages", "message1");
    }));
    workflow.addNode("agent_2", node_async(state -> {
        System.out.println("agent_2");
        return Map.of("messages", new String[] { "message2" });
    }));
    workflow.addNode("agent_3", node_async(state -> {
        System.out.println("agent_3");
        // Use the stored messages only when they are a List; otherwise start empty.
        List<String> messages = Optional.ofNullable(state.value("messages").get())
            .filter(List.class::isInstance)
            .map(List.class::cast)
            .orElse(new ArrayList<>());
        int steps = messages.size() + 1;
        return Map.of("messages", "message3", "steps", steps);
    }));

    // Edges, registered in the same order as before.
    workflow.addEdge("agent_1", "agent_2");
    workflow.addEdge("agent_2", "agent_3");
    workflow.addEdge(StateGraph.START, "agent_1");
    workflow.addEdge("agent_3", StateGraph.END);

    GraphRepresentation diagram = workflow.getGraph(GraphRepresentation.Type.PLANTUML, "demo");
    System.out.println(diagram.content());
}
输出如下:
plantuml
@startuml demo
skinparam usecaseFontSize 14
skinparam usecaseStereotypeFontSize 12
skinparam hexagonFontSize 14
skinparam hexagonStereotypeFontSize 12
title "demo"
footer
powered by spring-ai-alibaba
end footer
circle start<<input>> as __START__
circle stop as __END__
usecase "agent_1"<<Node>>
usecase "agent_2"<<Node>>
usecase "agent_3"<<Node>>
"__START__" -down-> "agent_1"
"agent_1" -down-> "agent_2"
"agent_2" -down-> "agent_3"
"agent_3" -down-> "__END__"
@enduml
小结
DiagramGenerator是流程图生成器的抽象基类，它声明了appendHeader、appendFooter、call、declareConditionalStart、declareNode、declareConditionalEdge、commentLine等抽象方法，并提供了generate方法，根据nodes、edges、ctx生成图的文本表示。PlantUMLGenerator继承了DiagramGenerator，按照PlantUML语法实现了这些抽象方法。