spring-ai workflows

Workflow concepts

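A workflow breaks a task down by hand into a relatively fixed pattern: one large task becomes a predefined flow with several branches. Its main strength is determinism. The model is just one node in the flow and mostly makes classification decisions or generates content, so workflows suit scenarios with clear-cut categories, such as intent recognition.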

Reference documentation: https://java2ai.com/docs/1.0.0.2/get-started/workflow/?spm=4347728f.7cee0e64.0.0.39076dd1jbppqZ

Workflow diagram

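Product review classification flowchart: every review first goes to a positive/negative classifier. Positive reviews go straight to the recorder; negative reviews first pass through a second classifier (after-sale service, transportation, product quality, others) and then reach the recorder.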
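The same flow can be written down as data. The sketch below encodes each node as a map from route label to next node and uses a depth-first walk to list every path from START to END. The node names and route labels are copied from the StateGraph definition later in this article; the FlowchartPaths class and the "next" label for unconditional edges are hypothetical and are not part of the Spring AI Alibaba API.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java encoding of the review-classification flowchart.
// Node names mirror the StateGraph used later; this is NOT library code.
class FlowchartPaths {

    // node -> (route label -> next node); unconditional edges use the label "next"
    static final Map<String, Map<String, String>> EDGES = new LinkedHashMap<>();

    static {
        EDGES.put("START", Map.of("next", "feedback_classifier"));
        EDGES.put("feedback_classifier", orderedMap(
                "positive", "recorder",
                "negative", "specific_question_classifier"));
        EDGES.put("specific_question_classifier", orderedMap(
                "after-sale", "recorder",
                "transportation", "recorder",
                "quality", "recorder",
                "others", "recorder"));
        EDGES.put("recorder", Map.of("next", "END"));
    }

    // Insertion-ordered map from alternating key/value arguments.
    static Map<String, String> orderedMap(String... kv) {
        Map<String, String> m = new LinkedHashMap<>();
        for (int i = 0; i < kv.length; i += 2) {
            m.put(kv[i], kv[i + 1]);
        }
        return m;
    }

    // Depth-first enumeration of every labelled path from START to END.
    static List<String> allPaths() {
        List<String> out = new ArrayList<>();
        walk("START", "START", out);
        return out;
    }

    private static void walk(String node, String pathSoFar, List<String> out) {
        if (node.equals("END")) {
            out.add(pathSoFar);
            return;
        }
        for (Map.Entry<String, String> e : EDGES.get(node).entrySet()) {
            String step = e.getKey().equals("next") ? " -> " : " -[" + e.getKey() + "]-> ";
            walk(e.getValue(), pathSoFar + step + e.getValue(), out);
        }
    }

    public static void main(String[] args) {
        allPaths().forEach(System.out::println);
    }
}
```

It prints five paths: one for positive reviews and one per problem category for negative reviews.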

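For example, given the following user feedback:

  • This product is excellent, I love it!

    Output: Praise, no action taken.

    Explanation: the feedback is positive, so no improvement is needed.

  • The product broke after one day, very disappointed.

    Output: product quality

    Explanation: the feedback reports a problem, specifically a product-quality issue.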
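To trace how the two sample reviews travel through the flow without calling a model, the sketch below replaces both LLM classifiers with simple keyword checks. The keyword lists and the SampleRun class are illustrative assumptions; in the real workflow the model does the classifying, and its exact wording may differ from these category strings.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical offline stand-in for the two LLM classifiers.
// The keyword lists are illustrative assumptions, not the real logic.
class SampleRun {

    static String classifySentiment(String review) {
        String r = review.toLowerCase(Locale.ROOT);
        List<String> negativeHints = List.of("broke", "disappointed", "terrible");
        return negativeHints.stream().anyMatch(r::contains) ? "negative feedback" : "positive feedback";
    }

    static String classifyProblem(String review) {
        String r = review.toLowerCase(Locale.ROOT);
        if (r.contains("refund") || r.contains("return")) return "after-sale service";
        if (r.contains("delivery") || r.contains("shipping")) return "transportation";
        if (r.contains("broke") || r.contains("defect")) return "product quality";
        return "others";
    }

    // Mirrors the RecordingNode rule: positive -> fixed praise text,
    // otherwise the specific problem category becomes the solution.
    static String run(String review) {
        if (classifySentiment(review).contains("positive")) {
            return "Praise, no action taken.";
        }
        return classifyProblem(review);
    }

    public static void main(String[] args) {
        System.out.println(run("This product is excellent, I love it!"));
        System.out.println(run("The product broke after one day, very disappointed."));
    }
}
```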

Spring Boot implementation

Framework: Spring AI Alibaba Graph

The Maven pom.xml:

xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>3.4.6</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.example</groupId>
	<artifactId>demo-spring-test</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>demo-spring-test</name>
	<description>Demo project for Spring Boot</description>
	<url/>
	<licenses>
		<license/>
	</licenses>
	<developers>
		<developer/>
	</developers>
	<scm>
		<connection/>
		<developerConnection/>
		<tag/>
		<url/>
	</scm>
	<properties>
		<java.version>17</java.version>
		<spring-ai.version>1.0.0</spring-ai.version>
	</properties>
	<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>

		<!-- Spring AI Alibaba (Tongyi model support) -->
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-starter</artifactId>
			<version>1.0.0-M6.1</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.ai</groupId>
			<artifactId>spring-ai-core</artifactId>
			<version>1.0.0-M6</version>
		</dependency>
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-autoconfigure</artifactId>
			<version>1.0.0-M6.1</version>
		</dependency>

		<!-- Graph core dependency -->
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-graph-core</artifactId>
			<version>1.0.0.2</version>
		</dependency>
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-starter-document-parser-tika</artifactId>
			<version>1.0.0.2</version>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
			</plugin>
		</plugins>
	</build>

</project>

Define the nodes (Node)

Create the core nodes of the workflow: two text-classification nodes and one recording node.

Classification nodes:

java
// Node that classifies a review as positive or negative
QuestionClassifierNode feedbackClassifier = QuestionClassifierNode.builder()
      .chatClient(chatClient)
      .inputTextKey("input")
      .categories(List.of("positive feedback", "negative feedback"))
      .classificationInstructions(
          List.of("Try to understand the user's feeling when he/she is giving the feedback."))
      .build();
// Node that classifies the specific problem behind a negative review
QuestionClassifierNode specificQuestionClassifier = QuestionClassifierNode.builder()
      .chatClient(chatClient)
      .inputTextKey("input")
      .categories(List.of("after-sale service", "transportation", "product quality", "others"))
      .classificationInstructions(List.of(
          "What kind of service or help the customer is trying to get from us? " +
          "Classify the question based on your understanding."))
      .build();
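QuestionClassifierNode builds its prompt internally, and that prompt is not shown in this article. As a hypothetical illustration of the job such a node has to do, the sketch below turns the categories and classification instructions into a prompt and maps a free-form model reply back onto the fixed category list. The prompt wording, the fallback to the last category, and the helper names are assumptions, not the node's actual implementation.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical sketch of a question classifier's two halves:
// prompt construction and reply parsing. Not QuestionClassifierNode's code.
class ClassifierPromptSketch {

    static String buildPrompt(String input, List<String> categories, List<String> instructions) {
        StringBuilder sb = new StringBuilder();
        sb.append("Classify the text into exactly one of these categories: ")
          .append(String.join(", ", categories)).append(".\n");
        for (String instruction : instructions) {
            sb.append("Instruction: ").append(instruction).append('\n');
        }
        sb.append("Text: ").append(input).append('\n');
        sb.append("Answer with the category name only.");
        return sb.toString();
    }

    // Map a free-form reply onto the fixed category list (case-insensitive);
    // assumed fallback: the last category, e.g. "others".
    static String parseCategory(String reply, List<String> categories) {
        String r = reply.toLowerCase(Locale.ROOT);
        for (String c : categories) {
            if (r.contains(c.toLowerCase(Locale.ROOT))) {
                return c;
            }
        }
        return categories.get(categories.size() - 1);
    }

    public static void main(String[] args) {
        List<String> cats = List.of("after-sale service", "transportation", "product quality", "others");
        System.out.println(buildPrompt("The product broke after one day, very disappointed.", cats,
                List.of("Classify the question based on your understanding.")));
        System.out.println(parseCategory("Category: Product Quality", cats));
    }
}
```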

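Recording node RecordingNode: it reads classifier_output from the global state and writes the final answer under the solution key.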

java
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

public class RecordingNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(RecordingNode.class);

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        String feedback = (String) state.value("classifier_output").get();

        Map<String, Object> updatedState = new HashMap<>();
        if (feedback.contains("positive")) {
            logger.info("Received positive feedback: {}", feedback);
            updatedState.put("solution", "Praise, no action taken.");
        } else {
            logger.info("Received negative feedback: {}", feedback);
            updatedState.put("solution", feedback);
        }

        return updatedState;
    }

}
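RecordingNode returns only the keys it changes (here just solution), not the whole state. With the ReplaceStrategy registered for each key in the complete code further down, a returned value presumably overwrites the previous one. The sketch below models that merge with plain Maps; PartialUpdateSketch and mergeReplace are hypothetical names, and this is a simplified model rather than the library's own merge code.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of folding a node's partial update into the state
// with "replace" semantics. Not the Spring AI Alibaba implementation.
class PartialUpdateSketch {

    // Each returned key overwrites the old value; other keys are untouched.
    static Map<String, Object> mergeReplace(Map<String, Object> state, Map<String, Object> update) {
        Map<String, Object> next = new HashMap<>(state);
        next.putAll(update);
        return next;
    }

    // Same decision rule as RecordingNode, written against a plain Map.
    static Map<String, Object> recordingNode(Map<String, Object> state) {
        String feedback = (String) state.getOrDefault("classifier_output", "");
        Map<String, Object> update = new HashMap<>();
        update.put("solution", feedback.contains("positive") ? "Praise, no action taken." : feedback);
        return update;
    }

    public static void main(String[] args) {
        Map<String, Object> state = new HashMap<>();
        state.put("input", "The product broke after one day, very disappointed.");
        state.put("classifier_output", "product quality");
        Map<String, Object> after = mergeReplace(state, recordingNode(state));
        System.out.println(after.get("solution")); // product quality
        System.out.println(after.get("input"));    // unchanged input text
    }
}
```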

Define the graph (StateGraph)

java
StateGraph graph = new StateGraph("Consumer Service Workflow Demo", stateFactory)
        // Add the nodes
        .addNode("feedback_classifier", node_async(feedbackClassifier))
        .addNode("specific_question_classifier", node_async(specificQuestionClassifier))
        .addNode("recorder", node_async(recordingNode))
        // Define the edges (flow order)
        .addEdge(START, "feedback_classifier")  // entry point
        .addConditionalEdges("feedback_classifier",
                edge_async(new CustomerServiceController.FeedbackQuestionDispatcher()),
                Map.of("positive", "recorder", "negative", "specific_question_classifier"))
        .addConditionalEdges("specific_question_classifier",
                edge_async(new CustomerServiceController.SpecificQuestionDispatcher()),
                Map.of("after-sale", "recorder", "transportation", "recorder",
                        "quality", "recorder", "others", "recorder"))
        .addEdge("recorder", END);  // exit point
return graph;
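To make the graph semantics concrete, here is a hypothetical, synchronous mini runner in plain Java with the same shape as the definition above: nodes that return partial state updates, a START edge, conditional edges whose router returns a label that is looked up in a route map, and an END edge. The classifier nodes are stubbed with fixed outputs. MiniGraph and MiniGraphDemo are illustration-only names; the real StateGraph wraps actions with node_async/edge_async and is compiled before it is invoked, which this sketch does not model.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

// Builds the same topology as the StateGraph above with stubbed classifiers.
class MiniGraphDemo {

    static Map<String, Object> run(String stubSentiment, String stubProblem) {
        MiniGraph graph = new MiniGraph()
                .addNode("feedback_classifier", s -> Map.of("classifier_output", stubSentiment))
                .addNode("specific_question_classifier", s -> Map.of("classifier_output", stubProblem))
                .addNode("recorder", s -> {
                    String out = (String) s.get("classifier_output");
                    return Map.of("solution", out.contains("positive") ? "Praise, no action taken." : out);
                })
                .addEdge(MiniGraph.START, "feedback_classifier")
                .addConditionalEdges("feedback_classifier",
                        s -> ((String) s.get("classifier_output")).contains("positive") ? "positive" : "negative",
                        Map.of("positive", "recorder", "negative", "specific_question_classifier"))
                .addConditionalEdges("specific_question_classifier",
                        s -> ((String) s.get("classifier_output")).contains("quality") ? "quality" : "others",
                        Map.of("quality", "recorder", "others", "recorder"))
                .addEdge("recorder", MiniGraph.END);
        return graph.invoke(Map.of("input", "demo"));
    }

    public static void main(String[] args) {
        System.out.println(run("positive feedback", "unused").get("solution"));
        System.out.println(run("negative feedback", "product quality").get("solution"));
    }
}

// Hypothetical synchronous graph runner; not Spring AI Alibaba code.
class MiniGraph {
    static final String START = "__START__";
    static final String END = "__END__";

    private final Map<String, Function<Map<String, Object>, Map<String, Object>>> nodes = new LinkedHashMap<>();
    private final Map<String, String> edges = new HashMap<>();
    private final Map<String, Function<Map<String, Object>, String>> routers = new HashMap<>();
    private final Map<String, Map<String, String>> routeMaps = new HashMap<>();

    MiniGraph addNode(String id, Function<Map<String, Object>, Map<String, Object>> action) {
        nodes.put(id, action);
        return this;
    }

    MiniGraph addEdge(String from, String to) {
        edges.put(from, to);
        return this;
    }

    MiniGraph addConditionalEdges(String from, Function<Map<String, Object>, String> router,
                                  Map<String, String> routes) {
        routers.put(from, router);
        routeMaps.put(from, routes);
        return this;
    }

    // Walk from START to END, merging each node's partial update into the state.
    Map<String, Object> invoke(Map<String, Object> input) {
        Map<String, Object> state = new HashMap<>(input);
        String current = edges.get(START);
        while (!END.equals(current)) {
            state.putAll(nodes.get(current).apply(state));
            if (routers.containsKey(current)) {
                String label = routers.get(current).apply(state);
                current = routeMaps.get(current).get(label);
                if (current == null) {
                    throw new IllegalStateException("no route for label " + label);
                }
            } else {
                current = edges.get(current);
            }
        }
        return state;
    }
}
```

With run("negative feedback", "product quality"), the runner visits feedback_classifier, specific_question_classifier and recorder, and the final state holds solution = "product quality".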

Complete code:

java
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.OverAllStateFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.node.QuestionClassifierNode;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.List;
import java.util.Map;

import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;

@Configuration
public class WorkflowAutoconfiguration {

    @Bean
    public StateGraph workflowGraph(ChatModel chatModel) throws GraphStateException {
        ChatClient chatClient = ChatClient.builder(chatModel)
                .defaultAdvisors(new SimpleLoggerAdvisor()).build();

        RecordingNode recordingNode = new RecordingNode();

        // Node that classifies a review as positive or negative
        QuestionClassifierNode feedbackClassifier = QuestionClassifierNode.builder()
                .chatClient(chatClient)
                .inputTextKey("input")
                .categories(List.of("positive feedback", "negative feedback"))
                .classificationInstructions(
                        List.of("Try to understand the user's feeling when he/she is giving the feedback."))
                .build();

        // Node that classifies the specific problem behind a negative review
        QuestionClassifierNode specificQuestionClassifier = QuestionClassifierNode.builder()
                .chatClient(chatClient)
                .inputTextKey("input")
                .categories(List.of("after-sale service", "transportation", "product quality", "others"))
                .classificationInstructions(List.of(
                        "What kind of service or help the customer is trying to get from us? " +
                                "Classify the question based on your understanding."))
                .build();

        // OverAllStateFactory that creates a fresh initial global state each time the workflow runs
        OverAllStateFactory stateFactory = () -> {
            OverAllState state = new OverAllState();
            state.registerKeyAndStrategy("input", new ReplaceStrategy());
            state.registerKeyAndStrategy("classifier_output", new ReplaceStrategy());
            state.registerKeyAndStrategy("solution", new ReplaceStrategy());
            return state;
        };

        StateGraph graph = new StateGraph("Consumer Service Workflow Demo", stateFactory)
                .addNode("feedback_classifier", node_async(feedbackClassifier))
                .addNode("specific_question_classifier", node_async(specificQuestionClassifier))
                .addNode("recorder", node_async(recordingNode))
                // Define the edges (flow order)
                .addEdge(START, "feedback_classifier")  // entry point
                .addConditionalEdges("feedback_classifier",
                        edge_async(new CustomerServiceController.FeedbackQuestionDispatcher()),
                        Map.of("positive", "recorder", "negative", "specific_question_classifier"))
                .addConditionalEdges("specific_question_classifier",
                        edge_async(new CustomerServiceController.SpecificQuestionDispatcher()),
                        Map.of("after-sale", "recorder", "transportation", "recorder",
                                "quality", "recorder", "others", "recorder"))
                .addEdge("recorder", END);  // exit point
        return graph;
    }

}
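The comment in the code above says the OverAllStateFactory creates a fresh global state each time the workflow runs. The sketch below shows why that matters, using a plain Supplier in place of the factory and a fake one-step workflow: with a single shared map, the second request still sees the first request's solution, while a fresh map per run starts empty. StateFactorySketch and runOnce are hypothetical names.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical illustration of per-run state creation. The Supplier plays
// the role of the stateFactory lambda; runOnce is a fake workflow that only
// writes "solution" for an obviously positive review.
class StateFactorySketch {

    static final Supplier<Map<String, Object>> STATE_FACTORY = HashMap::new;

    static Map<String, Object> runOnce(Map<String, Object> state, String input) {
        state.put("input", input);
        if (input.contains("love")) {
            state.put("solution", "Praise, no action taken.");
        }
        return state;
    }

    public static void main(String[] args) {
        // One shared map: the second request still sees the first request's solution.
        Map<String, Object> shared = new HashMap<>();
        runOnce(shared, "I love it!");
        System.out.println(runOnce(shared, "It arrived.").get("solution"));

        // Factory: every run starts from an empty state, so nothing leaks.
        runOnce(STATE_FACTORY.get(), "I love it!");
        System.out.println(runOnce(STATE_FACTORY.get(), "It arrived.").get("solution"));
    }
}
```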

Testing with a controller

  • Complete code of CustomerServiceController:
java
import java.util.HashMap;
import java.util.Map;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/customer")
public class CustomerServiceController {

    private static final Logger logger = LoggerFactory.getLogger(CustomerServiceController.class);

    private final CompiledGraph compiledGraph;

    public CustomerServiceController(@Qualifier("workflowGraph") StateGraph stateGraph) throws GraphStateException {
        this.compiledGraph = stateGraph.compile();
    }

    /**
     * localhost:8080/customer/chat?query=The product broke after one day, very disappointed.
     */
    @GetMapping("/chat")
    public String simpleChat(String query) throws Exception {
        logger.info("simpleChat: {}", query);
        return compiledGraph.invoke(Map.of("input", query))
                .get().value("solution")
                .get().toString();
    }

    public static class FeedbackQuestionDispatcher implements EdgeAction {

        @Override
        public String apply(OverAllState state) throws Exception {
            // Example: feedback describing negative aspects of the product
            // is routed as "negative".
            String classifierOutput = (String) state.value("classifier_output").orElse("");
            logger.info("classifierOutput: {}", classifierOutput);

            if (classifierOutput.contains("positive")) {
                return "positive";
            }
            return "negative";
        }

    }

    public static class SpecificQuestionDispatcher implements EdgeAction {

        @Override
        public String apply(OverAllState state) throws Exception {
            // Example: feedback about the product's quality
            // is routed as "quality".
            String classifierOutput = (String) state.value("classifier_output").orElse("");
            logger.info("classifierOutput: {}", classifierOutput);

            Map<String, String> classifierMap = new HashMap<>();
            classifierMap.put("after-sale", "after-sale");
            classifierMap.put("quality", "quality");
            classifierMap.put("transportation", "transportation");

            for (Map.Entry<String, String> entry : classifierMap.entrySet()) {
                if (classifierOutput.contains(entry.getKey())) {
                    return entry.getValue();
                }
            }

            return "others";
        }

    }

}
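The dispatchers only do substring matching on the classifier's text, so their rules can be checked without a model. The sketch below ports both rules to plain Java with two deliberate changes, both assumptions of this sketch rather than fixes taken from the source: a LinkedHashMap gives a fixed matching order (HashMap iteration order is unspecified), and the classifier output is lower-cased first in case the model answers with different capitalization. DispatcherRules is a hypothetical name.

```java
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;

// Plain-Java port of FeedbackQuestionDispatcher / SpecificQuestionDispatcher
// matching rules, for testing without a model.
class DispatcherRules {

    static String feedbackRoute(String classifierOutput) {
        return classifierOutput.toLowerCase(Locale.ROOT).contains("positive") ? "positive" : "negative";
    }

    static String specificRoute(String classifierOutput) {
        String out = classifierOutput.toLowerCase(Locale.ROOT);
        // Insertion order = matching order.
        Map<String, String> rules = new LinkedHashMap<>();
        rules.put("after-sale", "after-sale");
        rules.put("quality", "quality");
        rules.put("transportation", "transportation");
        for (Map.Entry<String, String> rule : rules.entrySet()) {
            if (out.contains(rule.getKey())) {
                return rule.getValue();
            }
        }
        return "others";
    }

    public static void main(String[] args) {
        System.out.println(feedbackRoute("positive feedback"));  // positive
        System.out.println(specificRoute("product quality"));    // quality
        System.out.println(specificRoute("After-sale service")); // after-sale
        System.out.println(specificRoute("packaging"));          // others
    }
}
```

Whatever label a dispatcher returns has to be one of the keys in the Map passed to addConditionalEdges (positive/negative, and after-sale/transportation/quality/others); otherwise the graph has no edge to follow.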
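Test in the browser by passing the user input in the query parameter, for example:
localhost:8080/customer/chat?query=The product broke after one day, very disappointed.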
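The same request can be sent from Java with the JDK's built-in HttpClient (Java 11+). The review text contains spaces and a comma, so it is URL-encoded first; a browser does this automatically. The sketch assumes the application is running locally on port 8080; ChatEndpointClient is a hypothetical name.

```java
import java.io.IOException;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

// Calls the /customer/chat endpoint from Java instead of the browser.
class ChatEndpointClient {

    static URI chatUri(String baseUrl, String query) {
        // Spaces become '+', the comma becomes %2C.
        return URI.create(baseUrl + "/customer/chat?query="
                + URLEncoder.encode(query, StandardCharsets.UTF_8));
    }

    public static void main(String[] args) throws InterruptedException {
        URI uri = chatUri("http://localhost:8080",
                "The product broke after one day, very disappointed.");
        System.out.println("GET " + uri);
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest.newBuilder(uri).GET().build();
        try {
            HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
            System.out.println(response.statusCode() + " " + response.body());
        } catch (IOException e) {
            System.out.println("Request failed (is the application running?): " + e);
        }
    }
}
```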