512-Spring AI Alibaba 字段分类分级 Graph 示例

本案例演示如何使用 Spring AI Alibaba 的 Graph 功能,实现一个字段分类分级的智能工作流系统。该系统通过多个节点协同工作,实现敏感词检测、字段分类、人工审核等功能。

1. 案例目标

我们将构建一个包含以下核心功能的字段分类分级系统:

  1. 敏感词检测:自动检测输入字段是否包含敏感词,决定后续处理流程。
  2. 字段分类:基于知识库对字段进行智能分类和分级。
  3. 人工审核:对分类结果进行人工审核,可以批准或拒绝AI的分类结果。
  4. 结果保存:将最终分类结果保存到数据库中。

2. 技术栈与核心依赖

  • Spring Boot 3.x
  • Spring AI Alibaba Graph (用于构建智能工作流)
  • MyBatis-Plus (用于数据库操作)
  • MySQL (数据存储)
  • OpenAI API (通过DashScope兼容模式调用通义大模型)

pom.xml 中,核心依赖包括:

复制代码
<dependencies>
    <!-- Spring Web 用于构建 RESTful API -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    
    <!-- OpenAI 模型支持 -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
    </dependency>
    
    <!-- Spring AI Alibaba Graph 核心 -->
    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-core</artifactId>
        <version>1.0.0.3</version>
    </dependency>
    
    <!-- MyBatis-Plus Starter -->
    <dependency>
        <groupId>com.baomidou</groupId>
        <artifactId>mybatis-plus-boot-starter</artifactId>
        <version>3.5.6</version>
    </dependency>
    
    <!-- MySQL Connector -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <version>8.2.0</version>
    </dependency>
</dependencies>

3. 项目配置

src/main/resources/application.yml 文件中,配置数据库连接和AI模型参数:

复制代码
server:
  port: 8080

spring:
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    url: jdbc:mysql://127.0.0.1:3306/test?characterEncoding=utf8&autoReconnect=true&useUnicode=true&useSSL=false&serverTimezone=UTC
    username: your_username
    password: your_password

  application:
    name: spring-ai-alibaba-graph-sec

  ai:
    openai:
      api-key: your_api_key
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      embedding:
        options:
          model: text-embedding-v1
      chat:
        options:
          model: qwen-max

mybatis-plus:
  configuration:
    map-underscore-to-camel-case: true

4. 系统架构

本系统基于 Spring AI Alibaba Graph 构建,包含以下核心组件:

4.1 Graph 工作流

系统工作流由多个节点组成,通过状态图(StateGraph)定义节点间的流转关系:

复制代码
@Configuration
public class SecGraphBuilder {
    
    @Bean
    public StateGraph secGraph() {
        // 构建状态图
        return StateGraph.builder(OverAllState.class)
                .addNode("sensitive", node_async(new SensitiveWordDecNode()))
                .addNode("answer", node_async(new AnswerNode()))
                .addNode("clft", node_async(new ClftNode()))
                .addNode("human", node_async(new HumanFeedbackNode()))
                .addNode("saveTool", node_async(new ToolNode(List.of(new FieldSaveTool()))))
                .addEdge(START, "sensitive")
                .addEdge("answer", END)
                .addConditionalEdges("sensitive", 
                    AsyncEdgeAction.of(new SensitiveDispatcher()),
                    Map.of("yes", "answer", "no", "clft"))
                .addConditionalEdges("clft", 
                    AsyncEdgeAction.of(new ClftDispatcher()),
                    Map.of("yes", "human", "no", "saveTool"))
                .addConditionalEdges("human", 
                    AsyncEdgeAction.of(new HumanFeedbackDispatcher()),
                    Map.of("saveTool", "saveTool", "clft", "clft"))
                .addEdge("saveTool", END)
                .build();
    }
}

4.2 节点(Nodes)

系统包含以下核心节点:

  • SensitiveWordDecNode:敏感词检测节点,判断输入字段是否包含敏感词。
  • AnswerNode:回答节点,当检测到敏感词时给出相应回答。
  • ClftNode:分类节点,对字段进行分类和分级。
  • HumanFeedbackNode:人工反馈节点,处理人工审核结果。
  • ToolNode:工具节点,执行字段保存操作。

4.3 调度器(Dispatchers)

调度器负责根据节点执行结果决定下一个流转的节点:

  • SensitiveDispatcher:根据敏感词检测结果决定流向。
  • ClftDispatcher:根据分类结果决定是否需要人工审核。
  • HumanFeedbackDispatcher:根据人工反馈结果决定下一步操作。

4.4 数据存储

系统使用 MySQL 数据库存储字段分类结果,包含以下实体:

  • Field:字段实体,包含字段名、分类、级别和推理过程。
  • FieldMapper:数据访问层接口。
  • IFieldService:服务层接口,提供字段相关业务逻辑。

5. 核心代码实现

5.1 主应用类

复制代码
@SpringBootApplication
@Slf4j
public class Application {
    
    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }
    
    @Bean
    CommandLineRunner vectorIngestRunner(
            @Value("${rag.source:classpath:rag/rag_friendly_classification.txt}") Resource ragSource,
            EmbeddingModel embeddingModel,
            @Qualifier("classificationVectorStore") VectorStore classificationVectorStore
    ) {
        return args -> {
            logger.info("🔄 正在向量化加载分类分级知识库...");
            var chunks = new TokenTextSplitter().transform(new TextReader(ragSource).read());
            classificationVectorStore.write(chunks);
        };
    }
    
    @Bean
    public VectorStore classificationVectorStore(EmbeddingModel embeddingModel) {
        return SimpleVectorStore.builder(embeddingModel).build();
    }
    
    @Bean
    public ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder().build();
    }
}

5.2 控制器实现

复制代码
@RestController
@RequestMapping("/sec/graph")
@Slf4j
public class SecGraphController {
    private final CompiledGraph compiledGraph;

    public SecGraphController(@Qualifier("secGraph") StateGraph stateGraph) throws GraphStateException {
        SaverConfig saverConfig = SaverConfig.builder()
                .register(SaverConstant.MEMORY, new MemorySaver())
                .build();

        this.compiledGraph = stateGraph
                .compile(CompileConfig.builder()
                        .saverConfig(saverConfig)
                        .interruptBefore("human")
                        .build());
    }

    @GetMapping(value = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> simpleChat(
            @RequestParam("fieldName") String fieldName,
            @RequestParam(value = "thread_id", defaultValue = "yhong", required = false) String threadId) throws Exception {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.stream(Map.of("field", fieldName), runnableConfig);
        graphProcess.processStream(resultFuture, sink);
        return sink.asFlux()
                .doOnCancel(() -> log.info("Client disconnected from stream"))
                .doOnError(e -> log.error("Error occurred during streaming", e));
    }

    @GetMapping(value = "/resume", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> resume(
            @RequestParam(value = "thread_id", defaultValue = "yhong", required = false) String threadId,
            @RequestParam(value = "feed_back", defaultValue = "true", required = false) boolean feedBack,
            @RequestParam(value = "feedback_reason", defaultValue = "", required = false) String humanReason) throws GraphRunnerException {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        StateSnapshot stateSnapshot = this.compiledGraph.getState(runnableConfig);
        OverAllState state = stateSnapshot.state();
        state.withResume();

        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("feed_back", feedBack);
        objectMap.put("feedback_reason", humanReason);
        state.withHumanFeedback(new OverAllState.HumanFeedback(objectMap, "feed_back"));

        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.streamFromInitialNode(state, runnableConfig);
        graphProcess.processStream(resultFuture, sink);

        return sink.asFlux()
                .doOnCancel(() -> log.info("Client disconnected from stream"))
                .doOnError(e -> log.error("Error occurred during streaming", e));
    }
}

5.3 敏感词检测节点

复制代码
public class SensitiveWordDecNode implements NodeAction {
    
    @Override
    public Map<String, Object> apply(OverAllState state) {
        String field = state.value("field");
        boolean isSensitive = checkSensitiveWords(field);
        
        Map<String, Object> result = new HashMap<>();
        result.put("is_sensitive", isSensitive);
        
        if (isSensitive) {
            result.put("sensitive_reason", "检测到敏感词: " + field);
        }
        
        return result;
    }
    
    private boolean checkSensitiveWords(String field) {
        // 实现敏感词检测逻辑
        return false; // 简化示例
    }
}

5.4 分类节点

复制代码
public class ClftNode implements NodeAction {
    
    private final ChatClient.Builder chatClientBuilder;
    private final VectorStore vectorStore;
    
    public ClftNode(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClientBuilder = chatClientBuilder;
        this.vectorStore = vectorStore;
    }
    
    @Override
    public Map<String, Object> apply(OverAllState state) {
        String field = state.value("field");
        
        // 使用RAG检索相关分类知识
        List<Document> similarDocs = vectorStore.similaritySearch(SearchRequest.query(field).withTopK(3));
        
        // 构建提示词
        String context = similarDocs.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n\n"));
        
        // 调用大模型进行分类
        String classification = chatClientBuilder.build()
                .prompt()
                .user(u -> u.text("请根据以下上下文信息,对字段 '{field}' 进行分类和分级。\n\n上下文信息:\n{context}")
                        .param("field", field)
                        .param("context", context))
                .call()
                .content();
        
        // 解析分类结果
        Map<String, Object> result = parseClassificationResult(classification);
        return result;
    }
    
    private Map<String, Object> parseClassificationResult(String classification) {
        // 解析大模型返回的分类结果
        Map<String, Object> result = new HashMap<>();
        // 简化示例,实际应根据大模型返回格式进行解析
        result.put("classification", "个人信息");
        result.put("level", "高");
        result.put("reasoning", "该字段涉及用户个人身份信息");
        result.put("need_human_review", true);
        return result;
    }
}

5.5 人工反馈节点

复制代码
public class HumanFeedbackNode implements NodeAction {
    
    @Override
    public Map<String, Object> apply(OverAllState state) {
        if (state.humanFeedback() == null || !state.humanFeedback().isResume()) {
            throw new GraphRunnerException("等待人工反馈...");
        }
        
        Map<String, Object> feedbackData = state.humanFeedback().data();
        boolean isApproved = (boolean) feedbackData.get("feed_back");
        String feedbackReason = (String) feedbackData.get("feedback_reason");
        
        Map<String, Object> result = new HashMap<>();
        result.put("human_next_node", isApproved ? "saveTool" : "clft");
        result.put("feedback_reason", feedbackReason);
        result.put("feedback", isApproved ? "approved" : "rejected");
        
        return result;
    }
}

5.6 字段保存工具

复制代码
public class FieldSaveTool implements ToolCallback {
    
    private final IFieldService fieldService;
    private final ObjectMapper objectMapper;
    
    public FieldSaveTool(IFieldService fieldService, ObjectMapper objectMapper) {
        this.fieldService = fieldService;
        this.objectMapper = objectMapper;
    }
    
    @Override
    public String getName() {
        return "save_field_classification";
    }
    
    @Override
    public String getDescription() {
        return "保存字段分类分级信息";
    }
    
    @Override
    public String getInputSchema() {
        return "{\"type\":\"object\",\"properties\":{\"fieldName\":{\"type\":\"string\",\"description\":\"字段名称\"},\"classification\":{\"type\":\"string\",\"description\":\"分类\"},\"level\":{\"type\":\"string\",\"description\":\"级别\"},\"reasoning\":{\"type\":\"string\",\"description\":\"推理过程\"}},\"required\":[\"fieldName\",\"classification\",\"level\",\"reasoning\"]}";
    }
    
    @Override
    public String call(String toolInput) {
        try {
            JsonNode jsonNode = objectMapper.readTree(toolInput);
            Field field = new Field();
            field.setFieldName(jsonNode.get("fieldName").asText());
            field.setClassification(jsonNode.get("classification").asText());
            field.setLevel(jsonNode.get("level").asText());
            field.setReasoning(jsonNode.get("reasoning").asText());
            
            boolean success = fieldService.save(field);
            return success ? "字段分类分级信息保存成功" : "字段分类分级信息保存失败";
        } catch (Exception e) {
            return "保存失败: " + e.getMessage();
        }
    }
}

6. 运行与测试

  1. 启动应用:运行 Spring Boot 主程序。
  2. 使用浏览器或 API 工具进行测试

测试 1:字段分类流程

访问以下 URL,对字段"用户姓名"进行分类:

复制代码
http://localhost:8080/sec/graph/chat?fieldName=用户姓名

预期响应:系统将返回流式响应,展示整个分类流程的执行过程。

测试 2:人工审核流程

当系统检测到需要人工审核时,工作流会在人工反馈节点暂停。此时可以通过以下接口恢复流程:

复制代码
http://localhost:8080/sec/graph/resume?thread_id=yhong&feed_back=true&feedback_reason=分类准确

参数说明

  • thread_id:工作流线程ID,用于标识特定的工作流实例。
  • feed_back:是否批准分类结果,true表示批准,false表示拒绝。
  • feedback_reason:人工审核的理由或说明。

7. 实现思路与扩展建议

实现思路

本案例的核心思想是**"基于Graph的工作流编排"**。通过将复杂的业务流程拆分为多个节点,并定义节点间的流转关系,实现灵活可扩展的业务流程管理。这使得:

  • 流程可视化:通过状态图可以清晰地看到整个业务流程的执行路径。
  • 节点复用:每个节点都是独立的组件,可以在不同流程中复用。
  • 人机协作:通过人工反馈节点实现人机协作,提高系统可靠性。
  • 状态持久化:通过检查点机制实现工作流状态的持久化,支持长时间运行的业务流程。

扩展建议

  • 多模态支持:扩展系统以支持图片、音频等多模态字段的分类。
  • 规则引擎集成:集成规则引擎,实现更复杂的业务规则判断。
  • 分布式执行:将节点执行分布到多个服务实例,提高系统吞吐量。
  • 流程监控:增加流程执行监控和告警机制,及时发现和处理异常情况。
  • 批量处理:支持批量字段分类,提高处理效率。
相关推荐
时序大模型8 小时前
KDD2025 |DUET:时间 - 通道双聚类框架,多变量时序预测的 “全能选手”出现!
人工智能·机器学习·时间序列预测·时间序列·kdd2025
共绩算力9 小时前
Ming Lite 万能模型对标 GPT-4o 的多模态能力
人工智能·共绩算力
猫先生Mr.Mao9 小时前
2025年8月AGI月评|AI开源项目全解析:从智能体到3D世界,技术边界再突破
人工智能·开源·aigc·agi·ai资讯·分布式推理框架
深入理解GEE云计算9 小时前
遥感生态指数(RSEI):理论发展、方法论争与实践进展
javascript·人工智能·算法·机器学习
IT_陈寒9 小时前
从2秒到200ms:我是如何用JavaScript优化页面加载速度的🚀
前端·人工智能·后端
深度学习lover9 小时前
<项目代码>yolo织物缺陷识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·织物缺陷识别·项目代码
StarPrayers.9 小时前
Binary Classification& sigmoid 函数的逻辑回归&Decision Boundary
人工智能·分类·数据挖掘
m0_736927049 小时前
Spring Boot项目中如何实现接口幂等
java·开发语言·spring boot·后端·spring·面试·职场和发展
渡我白衣9 小时前
C++:链接的两难 —— ODR中的强与弱符号机制
开发语言·c++·人工智能·深度学习·网络协议·算法·机器学习