Spring AI Alibaba Graph: Multi-Node Parallel Streaming

Previous related articles

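Spring AI Alibaba Graph Source Code Series: Serializer and Checkpoint Mechanism

Spring AI Alibaba Graph Source Code Walkthrough Series: Core Bootstrap Classes

Spring AI Alibaba Graph Source Code Walkthrough Series: action

Spring AI Alibaba Graph: Interrupt! Human Feedback Steps In and the Flow Finishes Smoothly

Spring AI Alibaba Graph: Assigning MCP to Specific Nodes

Spring AI Alibaba Graph: Node Stream Pass-Through Example

Spring AI Alibaba Graph: Multi-Node Parallelism, Quick Start

Spring AI Alibaba Graph: Quick Start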

For a fee (69.9 CNY), you can get an online preview of the Feishu cloud docs, plus github.com/GTyingzi/sp...

This installment covers multi-node parallel streaming in Spring AI Alibaba Graph.

Multi-node parallel streaming

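Tip: When building a workflow with Graph, we want nodes to run in parallel to speed up execution, and we also want each node to stream the AI model's output in real time while it calls the model.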

Below is a simple example with three nodes: an expander node, a translate node, and a merge node.

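  • Expander node: the AI model streams the expanded text
  • Translate node: the AI model streams the translated text
  • Merge node: the expander and translate nodes run in parallel; once both have finished, their results go to the merge node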
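The three-node shape is a classic fan-out/fan-in. To make the control flow concrete, here is a plain-JDK sketch using CompletableFuture. It does not use the Graph API at all, and `expand`/`translate` are hypothetical stand-ins for the model calls; it only shows "start both at once, merge when both are done".

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class FanOutFanInSketch {

    // Hypothetical stand-in for the expander node (the real one streams from the model).
    static List<String> expand(String query) {
        return List.of(query + " (variant 1)", query + " (variant 2)");
    }

    // Hypothetical stand-in for the translate node.
    static String translate(String query) {
        return "[en] " + query;
    }

    // Fan out: start both tasks at once. Fan in: thenCombine runs the merge step
    // only after both futures have completed.
    static Map<String, Object> run(String query) {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            CompletableFuture<List<String>> expander =
                    CompletableFuture.supplyAsync(() -> expand(query), pool);
            CompletableFuture<String> translator =
                    CompletableFuture.supplyAsync(() -> translate(query), pool);
            return expander.thenCombine(translator,
                            (variants, translation) -> Map.<String, Object>of(
                                    "expandercontent", variants,
                                    "translatecontent", translation))
                    .join();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(run("hello"));
    }
}
```

The merge step here receives both results directly as arguments; in the article, the merge node instead reads them from OverAllState.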

The hands-on example lives in the graph directory of github.com/GTyingzi/sp...; the code for this chapter is its parallel-stream-node module.

pom.xml

<properties>
    <spring-ai-alibaba.version>1.0.0.3</spring-ai-alibaba.version>
</properties>

<dependencies>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-openai</artifactId>
    </dependency>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
    </dependency>

    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
        <version>${spring-ai-alibaba.version}</version>
    </dependency>

    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>

application.yml

server:
  port: 8080
spring:
  application:
    name: parallel-stream-node
  ai:
    openai:
      api-key: ${AI_DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max

config

Fields stored in OverAllState:

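  • query: the user's question
  • expandernumber: how many query variants to generate
  • expandercontent: the expanded content
  • translatelanguage: the target language for translation; defaults to English
  • translatecontent: the translated content
  • mergeresult: the merged result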
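The configuration below registers every one of these keys with ReplaceStrategy. My understanding (an assumption; check the framework source) is that "replace" means a node's new value for a key simply overwrites the stored one, as opposed to an appending strategy. Here is a hypothetical plain-Java model of that per-key merge rule; it is not the framework's KeyStrategy interface.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.BinaryOperator;

public class ReplaceStrategySketch {

    // Hypothetical "replace" rule: the incoming value always wins over the stored one.
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    // Apply one node's partial output to the state, key by key, using each key's rule
    // (keys without a registered rule fall back to REPLACE in this sketch).
    static Map<String, Object> applyUpdate(Map<String, Object> state,
                                           Map<String, Object> update,
                                           Map<String, BinaryOperator<Object>> rules) {
        Map<String, Object> next = new HashMap<>(state);
        update.forEach((key, value) ->
                next.merge(key, value, rules.getOrDefault(key, REPLACE)));
        return next;
    }

    public static void main(String[] args) {
        Map<String, Object> state = Map.of("query", "hi", "translatelanguage", "English");
        Map<String, Object> update = Map.of("translatecontent", "Hello", "translatelanguage", "French");
        System.out.println(applyUpdate(state, update, Map.of("translatelanguage", REPLACE)));
    }
}
```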

The edges are:

START -> expander -> merge
START -> translate -> merge
merge -> END
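Why can expander and translate run at the same time? Neither depends on the other: both depend only on START, and merge depends on both. A small Kahn-style topological layering makes that visible, because nodes in the same layer have no edges between them. This is a generic illustration with hypothetical names, not how the framework schedules nodes internally.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class EdgeLayersSketch {

    // Kahn-style layering: a node enters a layer once all of its predecessors are in
    // earlier layers, so nodes within one layer never depend on each other.
    static List<List<String>> layers(Map<String, List<String>> edges) {
        Map<String, Integer> inDegree = new LinkedHashMap<>();
        edges.forEach((from, targets) -> {
            inDegree.putIfAbsent(from, 0);
            for (String to : targets) {
                inDegree.merge(to, 1, Integer::sum);
            }
        });

        List<List<String>> result = new ArrayList<>();
        List<String> current = new ArrayList<>();
        inDegree.forEach((node, degree) -> {
            if (degree == 0) {
                current.add(node);
            }
        });
        while (!current.isEmpty()) {
            Collections.sort(current); // deterministic order within a layer
            result.add(List.copyOf(current));
            List<String> next = new ArrayList<>();
            for (String node : current) {
                for (String to : edges.getOrDefault(node, List.of())) {
                    if (inDegree.merge(to, -1, Integer::sum) == 0) {
                        next.add(to);
                    }
                }
            }
            current.clear();
            current.addAll(next);
        }
        return result; // nodes on a cycle never reach in-degree 0 and are left out
    }

    public static void main(String[] args) {
        Map<String, List<String>> edges = new LinkedHashMap<>();
        edges.put("START", List.of("expander", "translate"));
        edges.put("expander", List.of("merge"));
        edges.put("translate", List.of("merge"));
        edges.put("merge", List.of("END"));
        // prints [[START], [expander, translate], [merge], [END]]
        System.out.println(layers(edges));
    }
}
```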
GraphConfiguration
package com.spring.ai.tutorial.graph.config;

import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.spring.ai.tutorial.graph.model.NodeStatus;
import com.spring.ai.tutorial.graph.node.ExpanderNode;
import com.spring.ai.tutorial.graph.node.TranslateNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;

/**
 * @author yingzi
 * @since 2025/8/26
 */
@Configuration
public class GraphConfiguration {

    private static final Logger logger = LoggerFactory.getLogger(GraphConfiguration.class);

    @Bean
    public StateGraph parallelStreamGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory = () -> {
            HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();

            // User input
            keyStrategyHashMap.put("query", new ReplaceStrategy());
            keyStrategyHashMap.put("expandernumber", new ReplaceStrategy());
            keyStrategyHashMap.put("expandercontent", new ReplaceStrategy());

            keyStrategyHashMap.put("translatelanguage", new ReplaceStrategy());
            keyStrategyHashMap.put("translatecontent", new ReplaceStrategy());

            keyStrategyHashMap.put("mergeresult", new ReplaceStrategy());

            return keyStrategyHashMap;
        };

        // Written by the expander and translate nodes from different threads, so use a
        // thread-safe map. Note: it is created once per bean and shared by all requests.
        Map<String, NodeStatus> node2Status = new ConcurrentHashMap<>();

        StateGraph stateGraph = new StateGraph(keyStrategyFactory)
                .addNode(ExpanderNode.NODENAME, node_async(new ExpanderNode(chatClientBuilder, node2Status)))
                .addNode(TranslateNode.NODENAME, node_async(new TranslateNode(chatClientBuilder, node2Status)))
                .addNode(MergeResultsNode.NODENAME, node_async(new MergeResultsNode(node2Status)))

                // Fan out: both nodes start from START and run in parallel
                .addEdge(StateGraph.START, TranslateNode.NODENAME)
                .addEdge(StateGraph.START, ExpanderNode.NODENAME)

                // Fan in: both nodes feed the merge node
                .addEdge(TranslateNode.NODENAME, MergeResultsNode.NODENAME)
                .addEdge(ExpanderNode.NODENAME, MergeResultsNode.NODENAME)

                .addEdge(MergeResultsNode.NODENAME, StateGraph.END);

        // Print the graph as PlantUML
        GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML,
                "expander flow");
        logger.info("\n=== expander UML Flow ===");
        logger.info(representation.content());
        logger.info("==================================\n");

        return stateGraph;
    }

    private class MergeResultsNode implements NodeAction {

        public static final String NODENAME = "merge";

        private final Map<String, NodeStatus> node2Status;

        public MergeResultsNode(Map<String, NodeStatus> node2Status) {
            this.node2Status = node2Status;
        }

        @Override
        public Map<String, Object> apply(OverAllState state) {
            if (!isDone(node2Status)) {
                return Map.of();
            }

            Object expanderContent = state.value("expandercontent").orElse("unknown");
            String translateContent = (String) state.value("translatecontent").orElse("");

            return Map.of("mergeresult", Map.of("expandercontent", expanderContent,
                    "translatecontent", translateContent));
        }

        private boolean isDone(Map<String, NodeStatus> node2Status) {
            return node2Status.get(ExpanderNode.NODENAME) == NodeStatus.COMPLETED
                    && node2Status.get(TranslateNode.NODENAME) == NodeStatus.COMPLETED;
        }
    }
}

node

TranslateNode

Translate node, with streaming output

package com.spring.ai.tutorial.graph.node;

import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;
import com.spring.ai.tutorial.graph.model.NodeStatus;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import reactor.core.publisher.Flux;

import java.util.Map;

/**
 * @author yingzi
 * @since 2025/6/13
 */

public class TranslateNode implements NodeAction {

    private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate("Given a user query, translate it to {targetLanguage}.\nIf the query is already in {targetLanguage}, return it unchanged.\nIf you don't know the language of the query, return it unchanged.\nDo not add explanations nor any other text.\n\nOriginal query: {query}\n\nTranslated query:\n");

    private final ChatClient chatClient;

    private final String TARGETLANGUAGE = "English";

    private final Map<String, NodeStatus> node2Status;

    public static final String NODENAME = "translate";


    public TranslateNode(ChatClient.Builder chatClientBuilder, Map<String, NodeStatus> node2Status) {
        this.chatClient = chatClientBuilder.build();
        this.node2Status = node2Status;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        node2Status.put(NODENAME, NodeStatus.RUNNING);

        String query = state.value("query", "");
        String targetLanguage = state.value("translatelanguage", TARGETLANGUAGE);

        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user(user -> user.text(DEFAULTPROMPTTEMPLATE.getTemplate())
                        .param("targetLanguage", targetLanguage)
                        .param("query", query))
                .stream()
                .chatResponse();

        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("translatellmstream")
                .startingState(state)
                .mapResult(response -> {
                    String text = response.getResult().getOutput().getText();
                    node2Status.put(NODENAME, NodeStatus.COMPLETED);
                    // Map.of rejects null values and `assert` is disabled by default, so guard explicitly
                    return Map.of("translatecontent", text == null ? "" : text);
                }).build(chatResponseFlux);

        return Map.of("translatecontent", generator);
    }
}
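TranslateNode hands StreamingChatGenerator a Flux of chunks plus a mapResult callback. As I read it (an assumption worth checking against the framework source), each chunk is emitted to the graph's output stream as it arrives, and mapResult is called once with the accumulated response to produce the node's state update. The plain-Java model below captures only that idea: `StreamAggregationSketch` and `consume` are hypothetical names, and a List stands in for the Flux.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

public class StreamAggregationSketch {

    // Forward each chunk immediately (the "live" streaming part), accumulate them,
    // then call mapResult once on the full text to build the node's state update.
    static Map<String, Object> consume(List<String> chunks,
                                       Consumer<String> onChunk,
                                       Function<String, Map<String, Object>> mapResult) {
        StringBuilder full = new StringBuilder();
        for (String chunk : chunks) {
            onChunk.accept(chunk);
            full.append(chunk);
        }
        return mapResult.apply(full.toString());
    }

    public static void main(String[] args) {
        Map<String, Object> update = consume(
                List.of("Hello", ", nice ", "to meet you"),
                chunk -> System.out.println("chunk: " + chunk),
                text -> Map.of("translatecontent", text));
        System.out.println(update);
    }
}
```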

ExpanderNode

Expander node, with streaming output

package com.spring.ai.tutorial.graph.node;

import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;
import com.spring.ai.tutorial.graph.model.NodeStatus;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import reactor.core.publisher.Flux;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @author yingzi
 * @since 2025/6/13
 */

public class ExpanderNode implements NodeAction {

    private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate("You are an expert at information retrieval and search optimization.\nYour task is to generate {number} different versions of the given query.\n\nEach variant must cover different perspectives or aspects of the topic,\nwhile maintaining the core intent of the original query. The goal is to\nexpand the search space and improve the chances of finding relevant information.\n\nDo not explain your choices or add any other text.\nProvide the query variants separated by newlines.\n\nOriginal query: {query}\n\nQuery variants:\n");

    private final ChatClient chatClient;

    private final Integer NUMBER = 3;

    private final Map<String, NodeStatus> node2Status;

    public static final String NODENAME = "expander";

    public ExpanderNode(ChatClient.Builder chatClientBuilder, Map<String, NodeStatus> node2Status) {
        this.chatClient = chatClientBuilder.build();
        this.node2Status = node2Status;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        node2Status.put(NODENAME, NodeStatus.RUNNING);

        String query = state.value("query", "");
        Integer expanderNumber = state.value("expandernumber", this.NUMBER);

        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user(user -> user.text(DEFAULTPROMPTTEMPLATE.getTemplate())
                        .param("number", expanderNumber)
                        .param("query", query))
                .stream()
                .chatResponse();

        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("expanderllmstream")
                .startingState(state)
                .mapResult(response -> {
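                    String text = response.getResult().getOutput().getText();
                    // Guard against a null result before splitting into one variant per line
                    List<String> queryVariants = text == null ? List.of() : Arrays.asList(text.split("\n"));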
                    node2Status.put(NODENAME, NodeStatus.COMPLETED);
                    return Map.of("expandercontent", queryVariants);
                }).build(chatResponseFlux);
        return Map.of("expandercontent", generator);
    }
}
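ExpanderNode splits the final text on "\n". Models sometimes add blank lines, Windows line endings, or numbering such as "1." even when the prompt asks them not to. If you want a more tolerant parser, here is a hypothetical helper (not part of the article's code) that could replace that split inside mapResult; it needs Java 16+ for `Stream.toList()`.

```java
import java.util.Arrays;
import java.util.List;

public class QueryVariantParser {

    // Split model output into one variant per line, tolerating \r\n, blank lines,
    // and leading list markers like "1.", "2)" or "-" followed by whitespace.
    static List<String> parseVariants(String text) {
        if (text == null || text.isBlank()) {
            return List.of();
        }
        return Arrays.stream(text.split("\\R"))            // \R matches any line break
                .map(String::strip)
                .map(line -> line.replaceFirst("^(\\d+[.)]|[-*])\\s+", ""))
                .filter(line -> !line.isEmpty())
                .toList();
    }

    public static void main(String[] args) {
        // prints [a, b, c]
        System.out.println(parseVariants("1. a\r\n\n- b\n2) c "));
    }
}
```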

MergeResultsNode

Merges the content produced by TranslateNode and ExpanderNode

private class MergeResultsNode implements NodeAction {

    public static final String NODENAME = "merge";

    private final Map<String, NodeStatus> node2Status;

    public MergeResultsNode(Map<String, NodeStatus> node2Status) {
        this.node2Status = node2Status;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        if (!isDone(node2Status)) {
            return Map.of();
        }

        Object expanderContent = state.value("expandercontent").orElse("unknown");
        String translateContent = (String) state.value("translatecontent").orElse("");

        return Map.of("mergeresult", Map.of("expandercontent", expanderContent,
                "translatecontent", translateContent));
    }

    private boolean isDone(Map<String, NodeStatus> node2Status) {
        return node2Status.get(ExpanderNode.NODENAME) == NodeStatus.COMPLETED
                && node2Status.get(TranslateNode.NODENAME) == NodeStatus.COMPLETED;
    }
}
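MergeResultsNode only emits a result once both upstream statuses are COMPLETED. Because those statuses are written from the two parallel branches, the map holding them should be thread-safe; a plain HashMap gives no visibility or consistency guarantees under concurrent writes. Note also that node2Status is created once inside the @Bean method, so every request shares it. The sketch below isolates the gate with a ConcurrentHashMap; `CompletionGateSketch` and its methods are hypothetical.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class CompletionGateSketch {

    enum Status { RUNNING, COMPLETED }

    // Shared status map written by parallel workers; ConcurrentHashMap makes
    // concurrent put/get safe.
    private final Map<String, Status> statuses = new ConcurrentHashMap<>();

    void markRunning(String node) { statuses.put(node, Status.RUNNING); }

    void markCompleted(String node) { statuses.put(node, Status.COMPLETED); }

    // Same check as MergeResultsNode.isDone: every required upstream node is COMPLETED.
    boolean allCompleted(String... nodes) {
        for (String node : nodes) {
            if (statuses.get(node) != Status.COMPLETED) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) throws InterruptedException {
        CompletionGateSketch gate = new CompletionGateSketch();
        ExecutorService pool = Executors.newFixedThreadPool(2);
        for (String node : new String[] {"expander", "translate"}) {
            pool.submit(() -> {
                gate.markRunning(node);
                gate.markCompleted(node);
            });
        }
        pool.shutdown();
        pool.awaitTermination(5, TimeUnit.SECONDS);
        // prints true
        System.out.println(gate.allCompleted("expander", "translate"));
    }
}
```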

NodeStatus

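package com.spring.ai.tutorial.graph.model;

public enum NodeStatus {

    RUNNING("running", "Running"),

    COMPLETED("completed", "Completed"),

    FAILED("failed", "Failed");

    private final String code;

    private final String desc;

    NodeStatus(String code, String desc) {
        this.code = code;
        this.desc = desc;
    }

    public String getCode() {
        return code;
    }

    public String getDesc() {
        return desc;
    }
}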
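The enum stores code and desc but the article never reads them back. If you need to turn a status string into a constant (for example, when parsing it from a response), a reverse lookup is a common companion. This is a hypothetical extension, shown on a separate `NodeStatusSketch` enum so it does not clash with the article's class.

```java
import java.util.Arrays;
import java.util.Optional;

public enum NodeStatusSketch {

    RUNNING("running", "Running"),
    COMPLETED("completed", "Completed"),
    FAILED("failed", "Failed");

    private final String code;
    private final String desc;

    NodeStatusSketch(String code, String desc) {
        this.code = code;
        this.desc = desc;
    }

    public String getCode() {
        return code;
    }

    public String getDesc() {
        return desc;
    }

    // Hypothetical reverse lookup: find the constant whose code matches, if any.
    public static Optional<NodeStatusSketch> fromCode(String code) {
        return Arrays.stream(values())
                .filter(status -> status.code.equals(code))
                .findFirst();
    }
}
```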

GraphStreamController

package com.spring.ai.tutorial.graph.controller;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.spring.ai.tutorial.graph.controller.GraphProcess.GraphProcess;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

import java.util.HashMap;
import java.util.Map;

/**
 * @author yingzi
 * @since 2025/6/13
 */
@RestController
@RequestMapping("/graph/parallel-stream")
public class GraphStreamController {

    private static final Logger logger = LoggerFactory.getLogger(GraphStreamController.class);

    private final CompiledGraph compiledGraph;

    public GraphStreamController(@Qualifier("parallelStreamGraph") StateGraph stateGraph) throws GraphStateException {
        this.compiledGraph = stateGraph.compile();
    }

    // Default query (Chinese): "Hello, nice to meet you. Could you briefly introduce yourself?"
    @GetMapping(value = "/expand-translate", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> expand(@RequestParam(value = "query", defaultValue = "你好,很高兴认识你,能简单介绍一下自己吗?", required = false) String query,
                                                @RequestParam(value = "expandernumber", defaultValue = "3", required = false) Integer expanderNumber,
                                                @RequestParam(value = "translatelanguage", defaultValue = "english", required = false) String translateLanguage,
                                                @RequestParam(value = "threadid", defaultValue = "yingzi", required = false) String threadId) throws GraphRunnerException {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("query", query);
        objectMap.put("expandernumber", expanderNumber);
        objectMap.put("translatelanguage", translateLanguage);

        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.stream(objectMap, runnableConfig);
        graphProcess.processStream(resultFuture, sink);

        return sink.asFlux()
                .doOnCancel(() -> logger.info("Client disconnected from stream"))
                .doOnError(e -> logger.error("Error occurred during streaming", e));
    }


}
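The endpoint produces text/event-stream, and each element is built with `ServerSentEvent.builder(content).build()`, i.e. data only, with no id or event name. In the SSE format, such an event travels as one `data:` line per payload line, terminated by a blank line. The helper below is a hypothetical illustration of that framing, not Spring's encoder, whose exact output (for example whitespace after the colon) may differ.

```java
public class SseFrameSketch {

    // Frame one payload as a Server-Sent Event: each payload line gets its own
    // "data:" prefix, and an empty line terminates the event.
    static String frame(String payload) {
        StringBuilder sb = new StringBuilder();
        for (String line : payload.split("\n", -1)) {
            sb.append("data:").append(line).append('\n');
        }
        return sb.append('\n').toString();
    }

    public static void main(String[] args) {
        System.out.print(frame("{\"translatellmstream\":\"Hello\"}"));
    }
}
```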
GraphProcess
package com.spring.ai.tutorial.graph.controller.GraphProcess;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Sinks;

import java.util.Map;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * @author yingzi
 * @since 2025/6/13
 */

public class GraphProcess {

    private static final Logger logger = LoggerFactory.getLogger(GraphProcess.class);

    private final ExecutorService executor = Executors.newSingleThreadExecutor();

    private CompiledGraph compiledGraph;

    public GraphProcess(CompiledGraph compiledGraph) {
        this.compiledGraph = compiledGraph;
    }

    public void processStream(AsyncGenerator<NodeOutput> generator, Sinks.Many<ServerSentEvent<String>> sink) {
        executor.submit(() -> {
            generator.forEachAsync(output -> {
                try {
                    logger.info("output = {}", output);
                    String nodeName = output.node();
                    String content;
                    if (output instanceof StreamingOutput streamingOutput) {
                        content = JSON.toJSONString(Map.of(nodeName, streamingOutput.chunk()));
                    } else {
                        JSONObject nodeOutput = new JSONObject();
                        nodeOutput.put("data", output.state().data());
                        nodeOutput.put("node", nodeName);
                        content = JSON.toJSONString(nodeOutput);
                    }
                    sink.tryEmitNext(ServerSentEvent.builder(content).build());
                } catch (Exception e) {
                    throw new CompletionException(e);
                }
            }).thenAccept(v -> {
                // Completed normally
                sink.tryEmitComplete();
            }).exceptionally(e -> {
                sink.tryEmitError(e);
                return null;
            });
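        });
        // A new GraphProcess (and executor) is created per request, so release the thread:
        // shutdown() lets the submitted task finish but accepts no new ones.
        executor.shutdown();
    }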
}
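GraphProcess bridges the pull-style AsyncGenerator into a push-style Reactor sink on a background thread. The same shape in plain JDK code, with a hypothetical `RecordingSink` standing in for `Sinks.Many` and an Iterator standing in for the generator. One design difference: `CompletableFuture.runAsync` uses the common pool, so there is no per-request executor to shut down.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

public class SinkBridgeSketch {

    // Hypothetical stand-in for Sinks.Many: records emitted items and the terminal signal.
    static final class RecordingSink {
        final List<String> items = new ArrayList<>();
        volatile String terminal; // "complete" or "error: <message>"

        void emitNext(String item) { items.add(item); }

        void emitComplete() { terminal = "complete"; }

        void emitError(Throwable e) { terminal = "error: " + e.getMessage(); }
    }

    // Drain the source on another thread, forward each element, then signal completion
    // or error, mirroring forEachAsync(...).thenAccept(...).exceptionally(...) in GraphProcess.
    static CompletableFuture<Void> bridge(Iterator<String> source, RecordingSink sink) {
        return CompletableFuture.runAsync(() -> source.forEachRemaining(sink::emitNext))
                .thenRun(sink::emitComplete)
                .exceptionally(e -> {
                    // Failures arrive wrapped in CompletionException; unwrap for a readable message
                    Throwable cause = e instanceof CompletionException && e.getCause() != null
                            ? e.getCause() : e;
                    sink.emitError(cause);
                    return null;
                });
    }

    public static void main(String[] args) {
        RecordingSink sink = new RecordingSink();
        bridge(List.of("chunk-1", "chunk-2").iterator(), sink).join();
        // prints [chunk-1, chunk-2] complete
        System.out.println(sink.items + " " + sink.terminal);
    }
}
```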

Demo

Learning community

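Hi, I'm Yingzi (影子). I have worked at 🐻, a new-energy company, and 老铁, and I am also a Committer in the Spring AI Alibaba open-source community. I also maintain a long-running set of Feishu cloud-doc notes with systematic interview material on backend and big-data topics; send me a private message to get them for free.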
