React Flow + Zustand 搭建工作流编排工作台

一、👉 什么是"工作流编排"?

你可以把它理解成:

用"拖拽 + 连线"的方式,定义一套执行流程

本质是:

一个 AI 推理流程的可视化编辑器

用户可以:

  • 拖节点(检测 / 分类 / 逻辑)
  • 连线(定义执行顺序)
  • 配置节点参数
  • 最终导出 JSON(给后端执行)
👉 最终效果

一句话总结:

用"画图"的方式生成一段可执行逻辑

二、整体架构设计(核心)
🧱 1️⃣ UI 层(看得见的)

负责展示:

  • Canvas 👉 画布
  • CustomNode 👉 节点
  • Toolbar 👉 工具栏
  • ConfigModal 👉 配置面板

👉 一句话总结:

负责"画出来 + 点得到"


Canvas组件

设置背景、控制器、小地图、

设置节点、边缘数据,

设置自定义节点组件

TypeScript 复制代码
import React, { useCallback, useRef } from "react";
import ReactFlow, {
  Background,
  Controls,
  MiniMap,
  BackgroundVariant,
  addEdge,
  Connection,
  Edge,
  useReactFlow,
} from "reactflow";
import "reactflow/dist/style.css";
import { useWorkflowStore } from "../store";
import { CustomNode } from "./CustomNode";
import { validateConnection } from "../utils/topology";
import { message } from "antd";
import { NodeType } from "../utils/constants";

const nodeTypes = {
  customNode: CustomNode,
};

export const Canvas = () => {
  const { nodes, edges, onNodesChange, onEdgesChange, onConnect } =
    useWorkflowStore();
  const reactFlowWrapper = useRef<HTMLDivElement>(null);
  // const { project } = useReactFlow();

  // console.log("project", project);

  const onConnectWrapper = useCallback(
    (params: Edge | Connection) => {
      const sourceNode = nodes.find((n) => n.id === params.source);
      const targetNode = nodes.find((n) => n.id === params.target);
      if (!sourceNode || !targetNode) return;

      const isValid = validateConnection(
        sourceNode.data.type as NodeType,
        targetNode.data.type as NodeType,
      );

      if (!isValid) {
        message.error(
          `不允许从 ${sourceNode.data.name} 连接到 ${targetNode.data.name}`,
        );
        return;
      }
      onConnect(params as Connection);
    },
    [nodes, onConnect],
  );

  return (
    <div
      style={{ flex: 1, height: "100%", position: "relative" }}
      ref={reactFlowWrapper}
    >
      <ReactFlow
        nodes={nodes}
        edges={edges}
        onNodesChange={onNodesChange}
        onEdgesChange={onEdgesChange}
        onConnect={onConnectWrapper}
        nodeTypes={nodeTypes}
        fitView
        deleteKeyCode={["Backspace", "Delete"]} // 允许使用 Delete 或 Backspace 键删除选中节点/连线
      >
        <Background variant={BackgroundVariant.Dots} gap={12} size={1} />
        <Controls />
        <MiniMap zoomable pannable />
      </ReactFlow>
    </div>
  );
};

CustomNode组件,自定义的卡片组件

自定义卡片内容、样式、Handle小圆点

点击编辑,触发自定义事件,modal中展示数据

点击添加,判断当前节点的下级节点并进行添加

点击删除,根据节点id删除

TypeScript 复制代码
import React, { useState } from "react";
import { Handle, Position } from "reactflow";
import { Card, Button, Dropdown, MenuProps, Tooltip } from "antd";
import { EditOutlined, PlusOutlined, DeleteOutlined } from "@ant-design/icons";
import { useWorkflowStore } from "../store";
import { NODE_CONSTRAINTS, NODE_TYPES, NodeType } from "../utils/constants";

export interface NodeData {
  type: NodeType;
  name: string;
  config: any;
  index?: number;
  position?: { x: number; y: number }; // 补充用于布局计算的位置信息
}

export const CustomNode = ({ id, data, xPos, yPos }: any) => {
  const { addNode, removeNode } = useWorkflowStore();
  const [isHovered, setIsHovered] = useState(false);

  const handleAdd = (targetType: NodeType) => {
    // 优先使用 React Flow 传入的真实坐标,降级使用 data 中的坐标
    const currentX =
      typeof xPos === "number" ? xPos : (data.position?.x ?? 250);
    const currentY = typeof yPos === "number" ? yPos : (data.position?.y ?? 50);

    const position = {
      x: currentX, // 去掉随机偏移,让其完全水平居中对齐
      y: currentY + 150, // 保持正下方 150px
    };

    addNode(targetType, position, id);
  };

  const allowedDownstream = NODE_CONSTRAINTS[data.type as NodeType].down;

  const addMenuItems: MenuProps["items"] = allowedDownstream.map(
    (type: NodeType) => ({
      key: type,
      label: `添加 ${NODE_TYPES[type]}`,
      onClick: () => handleAdd(type),
    }),
  );

  // 渲染配置摘要
  const renderConfigSummary = () => {
    if (Object.keys(data.config).length === 0)
      return <div style={{ color: "#999", fontSize: 12 }}>无配置项</div>;
    // return <></>;

    return (
      <div style={{ fontSize: 12, color: "#666", marginTop: 8 }}>
        {Object.entries(data.config).map(([k, v]) => (
          <div
            key={k}
            style={{
              overflow: "hidden",
              textOverflow: "ellipsis",
              whiteSpace: "nowrap",
            }}
          >
            {k}: {Array.isArray(v) ? `[${v.join(",")}]` : String(v)}
          </div>
        ))}
      </div>
    );
  };

  return (
    <Card
      size="small"
      title={
        <div
          style={{
            display: "flex",
            justifyContent: "space-between",
            alignItems: "center",
          }}
        >
          <span>{data.name}</span>
        </div>
      }
      style={{
        width: 240,
        boxShadow: isHovered
          ? "0 4px 12px rgba(0,0,0,0.1)"
          : "0 2px 8px rgba(0,0,0,0.05)",
        border:
          data.type === "start"
            ? "1px solid #52c41a"
            : data.type === "end"
              ? "1px solid #f5222d"
              : "1px solid #d9d9d9",
      }}
      onMouseEnter={() => setIsHovered(true)}
      onMouseLeave={() => setIsHovered(false)}
      actions={[
        data.type !== "start" ? (
          <Tooltip title="编辑配置" key="edit">
            {/* 增加 className="nodrag" 防止 React Flow 拖拽事件拦截点击 */}
            <Button
              className="nodrag"
              type="text"
              icon={<EditOutlined style={{ fontSize: 24 }} />}
              onClick={() =>
                window.dispatchEvent(
                  new CustomEvent("OPEN_CONFIG_MODAL", {
                    detail: { id, data },
                  }),
                )
              }
            />
          </Tooltip>
        ) : (
          <span key="empty_edit"></span>
        ),
        allowedDownstream.length > 0 ? (
          <Dropdown
            menu={{ items: addMenuItems }}
            trigger={["click"]}
            key="add"
          >
            <Tooltip title="添加下游节点">
              {/* 增加 className="nodrag" 防止 React Flow 拖拽事件拦截点击 */}
              <Button
                className="nodrag"
                type="text"
                icon={<PlusOutlined style={{ fontSize: 24 }} />}
              />
            </Tooltip>
          </Dropdown>
        ) : (
          <span key="empty"></span>
        ),
        data.type !== "start" ? (
          <Tooltip title="删除节点" key="delete">
            {/* 增加 className="nodrag" 防止 React Flow 拖拽事件拦截点击 */}
            <Button
              className="nodrag"
              type="text"
              danger
              icon={<DeleteOutlined style={{ fontSize: 24 }} />}
              onClick={() => removeNode(id)}
            />
          </Tooltip>
        ) : (
          <span key="empty2"></span>
        ),
      ]}
    >
      {data.type !== "start" && (
        <Handle
          type="target"
          position={Position.Top}
          style={{
            background: "#52c41a",
            width: 12,
            height: 12,
            top: -6,
            border: "2px solid #fff",
          }}
        />
      )}
      {renderConfigSummary()}
      <Handle
        type="source"
        position={Position.Bottom}
        style={{
          background: "#1890ff",
          width: 12,
          height: 12,
          bottom: -6,
          cursor: "crosshair",
          border: "2px solid #fff",
        }}
      />
    </Card>
  );
};

Toolbar组件

切换版本,加载目标版本数据

查看代码,将当前的节点数据转化为json格式展示出来

自适应,自动调整画布位置

保存,校验当前节点信息是否满足约束,满足则将节点信息转为json

TypeScript 复制代码
import React, { useState } from "react";
import { Button, Tooltip, Space, Modal, Input } from "antd";
import {
  CodeOutlined,
  SyncOutlined,
  FullscreenOutlined,
  SaveOutlined,
} from "@ant-design/icons";
import { useReactFlow } from "reactflow";
import { useWorkflowStore } from "../store";
import { validateTopology } from "../utils/topology";
import { message } from "antd";

export const Toolbar = () => {
  const { fitView } = useReactFlow();
  const { generateJSON, nodes, edges } = useWorkflowStore();
  const [jsonVisible, setJsonVisible] = useState(false);
  const [jsonContent, setJsonContent] = useState("");

  const handleFitView = () => {
    fitView({ padding: 0.2, duration: 800 });
  };

  const handleViewCode = () => {
    const jsonStr = generateJSON();
    setJsonContent(jsonStr);
    setJsonVisible(true);
  };

  const handleSave = () => {
    const validation = validateTopology(nodes, edges);
    if (!validation.valid) {
      message.error(`拓扑校验失败: ${validation.message}`);
      return;
    }
    message.success("校验通过,保存成功!");
    console.log("Saved JSON:", generateJSON());
  };

  return (
    <>
      <div
        style={{
          position: "absolute",
          bottom: 20,
          left: 20,
          zIndex: 10,
          background: "#fff",
          padding: "8px",
          borderRadius: "8px",
          boxShadow: "0 2px 8px rgba(0,0,0,0.15)",
        }}
      >
        <Space direction="vertical">
          <Tooltip title="切换版本" placement="right">
            {/* // TODO: 此处替换本地图标 */}
            <Button
              type="text"
              icon={<SyncOutlined style={{ fontSize: 24 }} />}
            />
          </Tooltip>
          <Tooltip title="查看代码" placement="right">
            {/* // TODO: 此处替换本地图标 */}
            <Button
              type="text"
              icon={<CodeOutlined style={{ fontSize: 24 }} />}
              onClick={handleViewCode}
            />
          </Tooltip>
          <Tooltip title="自适应" placement="right">
            {/* // TODO: 此处替换本地图标 */}
            <Button
              type="text"
              icon={<FullscreenOutlined style={{ fontSize: 24 }} />}
              onClick={handleFitView}
            />
          </Tooltip>
          <Tooltip title="校验并保存" placement="right">
            {/* // TODO: 此处替换本地图标 */}
            <Button
              type="primary"
              icon={<SaveOutlined style={{ fontSize: 24 }} />}
              onClick={handleSave}
              style={{ marginTop: 8 }}
            />
          </Tooltip>
        </Space>
      </div>

      <Modal
        title="工作流 JSON 配置"
        open={jsonVisible}
        onCancel={() => setJsonVisible(false)}
        width={800}
        centered
        footer={[
          <Button
            key="close"
            variant="filled"
            color="default"
            onClick={() => setJsonVisible(false)}
          >
            关闭
          </Button>,
        ]}
      >
        <Input.TextArea
          value={jsonContent}
          rows={20}
          readOnly
          style={{ fontFamily: "monospace", fontSize: 13 }}
        />
      </Modal>
    </>
  );
};

ConfigModal组件

某一个节点的信息数据

TypeScript 复制代码
import React, { useEffect, useState } from "react";
import {
  Modal,
  Form,
  InputNumber,
  Select,
  Checkbox,
  Slider,
  Switch,
  Divider,
  Input,
  Button,
} from "antd";
import { useWorkflowStore } from "../store";
import { MOCK_MODELS, NodeType } from "../utils/constants";

interface ConfigModalProps {
  visible: boolean;
  nodeId: string;
  nodeData: any;
  onCancel: () => void;
}

export const ConfigModal = ({
  visible,
  nodeId,
  nodeData,
  onCancel,
}: ConfigModalProps) => {
  const [form] = Form.useForm();
  const { updateNodeData } = useWorkflowStore();

  useEffect(() => {
    if (visible && nodeData) {
      form.setFieldsValue(nodeData.config);
    }
  }, [visible, nodeData, form]);

  const handleOk = async () => {
    try {
      const values = await form.validateFields();
      updateNodeData(nodeId, values);
      onCancel();
    } catch (error) {
      console.error("Validation Failed:", error);
    }
  };

  const handleModelChange = (value: string) => {
    if (!nodeData) return;
    const { type } = nodeData;
    const modelList =
      type === "feature"
        ? MOCK_MODELS.feature
        : MOCK_MODELS[type as "detect" | "classify"];
    const selectedModel = modelList?.find((m) => m.name === value);

    if (selectedModel) {
      // 模拟从后端获取到的模型关联参数
      form.setFieldsValue({
        onnx: `${selectedModel.name}_${selectedModel.version}.omodel`,
        classes: selectedModel.classes,
      });
    }
  };

  const renderFormItems = () => {
    if (!nodeData) return null;
    const { type } = nodeData;

    switch (type as NodeType) {
      case "detect":
      case "classify":
        return (
          <>
            <Form.Item
              name="model_name"
              label="模型选择"
              rules={[{ required: true }]}
            >
              <Select
                placeholder="请选择模型"
                variant="filled"
                onChange={handleModelChange}
                options={(MOCK_MODELS[type as "detect" | "classify"] || []).map(
                  (m: any) => ({
                    label: `${m.name} v${m.version}`,
                    value: m.name,
                  }),
                )}
              />
            </Form.Item>
            <Form.Item name="onnx" label="模型文件 (仅作数据透传预览)">
              <Input disabled variant="filled" />
            </Form.Item>
            <Form.Item name="classes" label="类别总数 (仅作数据透传预览)">
              <InputNumber disabled variant="filled" />
            </Form.Item>
            <Form.Item
              name="threshold"
              label="置信度"
              rules={[{ required: true, type: "number", min: 0, max: 1 }]}
            >
              <Slider min={0} max={1} step={0.01} marks={{ 0: "0", 1: "1" }} />
            </Form.Item>
            <Form.Item
              name="upload_class"
              label="输出目标/属性"
              tooltip="若有检测到此类目标,则输出给下级;若都不勾选,则表示全部输出给下级"
            >
              <Select
                mode="multiple"
                placeholder="请选择输出类别ID(模拟)"
                variant="filled"
                options={[
                  { label: "人体(0)", value: 0 },
                  { label: "工作服(1000)", value: 1000 },
                ]}
              />
            </Form.Item>
            <Form.Item
              name="unoutput_class"
              label="屏蔽目标/属性"
              tooltip="若有检测到此类目标,则不输出给下级"
            >
              <Select
                mode="multiple"
                variant="filled"
                placeholder="请选择屏蔽类别ID(模拟)"
                options={[{ label: "红色(2000)", value: 2000 }]}
              />
            </Form.Item>
          </>
        );
      case "feature":
        return (
          <>
            <Form.Item
              name="model_name"
              label="大模型选择"
              rules={[{ required: true }]}
            >
              <Select
                variant="filled"
                placeholder="请选择大模型"
                onChange={handleModelChange}
                options={MOCK_MODELS.feature.map((m: any) => ({
                  label: `${m.name} v${m.version}`,
                  value: m.name,
                }))}
              />
            </Form.Item>
            <Form.Item name="model_path" label="大模型路径 (仅作数据透传预览)">
              <Input disabled variant="filled" />
            </Form.Item>
            {/* 简化的正负向提示词配置 */}
            <Form.Item
              name="hit_texts"
              label="正向提示词"
              tooltip="命中描述则输出结果 (模拟输入数组对象)"
            >
              <Select
                mode="tags"
                variant="filled"
                placeholder="例如: 未穿工作服的人"
              />
            </Form.Item>
            <Form.Item
              name="filter_texts"
              label="负向提示词"
              tooltip="命中描述则丢弃结果 (模拟输入数组对象)"
            >
              <Select
                mode="tags"
                variant="filled"
                placeholder="例如: 穿工作服的人"
              />
            </Form.Item>
            <Form.Item
              name="threshold"
              label="置信度阈值"
              rules={[{ required: true, type: "number", min: 0, max: 1 }]}
            >
              <Slider min={0} max={1} step={0.01} />
            </Form.Item>
          </>
        );
      case "logic":
        return (
          <>
            <Form.Item
              name={["filter", "output_class"]}
              label="输出目标"
              tooltip="确定最后输出给下级具体哪些符合条件的目标;若都不勾选,则表示全部输出给下级"
            >
              <Select
                mode="multiple"
                placeholder="输出类别ID"
                options={[{ label: "人体(0)", value: 0 }]}
                variant="filled"
              />
            </Form.Item>
            <Form.Item
              name={["filter", "contain_class"]}
              label="包含关系"
              tooltip="输出目标的框内必须包含此类目标"
            >
              <Select
                mode="multiple"
                variant="filled"
                placeholder="包含类别ID"
                options={[{ label: "工作服(1000)", value: 1000 }]}
              />
            </Form.Item>
            <Form.Item
              name={["filter", "no_contain_class"]}
              label="不包含关系"
              tooltip="输出目标的框内必须不包含此类目标"
            >
              <Select
                mode="multiple"
                variant="filled"
                placeholder="不包含类别ID"
              />
            </Form.Item>
            <Form.Item
              name={["filter", "in_class"]}
              label="属于关系"
              tooltip="输出目标必须在此类目标的框内"
            >
              <Select
                mode="multiple"
                variant="filled"
                placeholder="属于类别ID"
              />
            </Form.Item>
            <Form.Item
              name={["filter", "intersect_class"]}
              label="相交关系"
              tooltip="输出目标与此类目标必须有重叠关系"
            >
              <Select
                mode="multiple"
                variant="filled"
                placeholder="相交类别ID"
              />
            </Form.Item>
          </>
        );
      case "end":
        return (
          <>
            <Form.Item name="min_targets" label="最小目标数量">
              <InputNumber min={0} />
            </Form.Item>
            <Form.Item name="min_width" label="最小宽度限制 (px)">
              <InputNumber min={0} />
            </Form.Item>
            <Form.Item name="min_height" label="最小高度限制 (px)">
              <InputNumber min={0} />
            </Form.Item>
            <Form.Item name="max_width" label="最大宽度限制 (px)">
              <InputNumber min={0} />
            </Form.Item>
            <Form.Item name="max_height" label="最大高度限制 (px)">
              <InputNumber min={0} />
            </Form.Item>
          </>
        );
      default:
        return <div>当前节点类型无需特殊配置</div>;
    }
  };

  return (
    <Modal
      title={`配置节点 - ${nodeData?.name || ""}`}
      open={visible}
      onOk={handleOk}
      onCancel={onCancel}
      width={600}
      destroyOnHidden
      footer={[
        <Button
          key="cancel"
          color="default"
          variant="filled"
          onClick={onCancel}
        >
          取消
        </Button>,
        <Button key="ok" type="primary" onClick={handleOk}>
          确定
        </Button>,
      ]}
    >
      <Form form={form} layout="vertical">
        {renderFormItems()}
      </Form>
    </Modal>
  );
};

🧠 2️⃣ 状态层(Zustand)

核心数据:

  • nodes
  • edges

核心能力:

  • 增删改查节点
  • 处理连线
  • 生成 JSON

👉 一句话总结:
所有操作,本质都是在改状态

TypeScript 复制代码
import { create } from "zustand";
import {
  Node,
  Edge,
  Connection,
  addEdge,
  applyNodeChanges,
  applyEdgeChanges,
  NodeChange,
  EdgeChange,
} from "reactflow";
import {
  NodeType,
  NODE_TYPES,
  INITIAL_NODES,
  MOCK_MODELS,
} from "../utils/constants";

interface WorkflowState {
  nodes: Node[];
  edges: Edge[];
  maxIndex: number;
  onNodesChange: (changes: NodeChange[]) => void;
  onEdgesChange: (changes: EdgeChange[]) => void;
  onConnect: (connection: Connection) => void;
  addNode: (
    type: NodeType,
    position: { x: number; y: number },
    sourceId?: string,
  ) => void;
  removeNode: (id: string) => void;
  updateNodeData: (id: string, data: any) => void;
  generateJSON: () => string;
}

// 用户拖节点->产生 NodeChange[]->applyNodeChanges->nodes 更新
// 用户连线->产生 Connection->addEdge->edges 更新
// 用户操作边(删除 / 选中 / 更新)->产生 EdgeChange[]->applyEdgeChanges->edges 更新
export const useWorkflowStore = create<WorkflowState>((set, get) => ({
  nodes: [
    {
      id: "start_0",
      type: "customNode",
      position: { x: 250, y: 50 },
      data: {
        type: "start",
        name: "开始",
        config: {},
        position: { x: 250, y: 50 },
      },
    },
  ],
  edges: [],
  maxIndex: 0,

  onNodesChange: (changes: NodeChange[]) => {
    // 拦截对开始节点的删除操作
    const filteredChanges = changes.filter(
      (c) => !(c.type === "remove" && c.id.startsWith("start")),
    );

    // 更新节点状态
    const updatedNodes = applyNodeChanges(filteredChanges, get().nodes);

    // 同步更新 node.data.position,以便我们添加子节点时可以拿它做后备坐标
    updatedNodes.forEach((node) => {
      const positionChange = filteredChanges.find(
        (c) => c.type === "position" && c.id === node.id,
      ) as any;
      if (positionChange && positionChange.position) {
        node.data = { ...node.data, position: positionChange.position };
      }
    });

    set({ nodes: updatedNodes });
  },

  onEdgesChange: (changes: EdgeChange[]) => {
    set({
      edges: applyEdgeChanges(changes, get().edges),
    });
  },

  onConnect: (connection: Connection) => {
    set({
      edges: addEdge(
        { ...connection, type: "smoothstep", animated: true },
        get().edges,
      ),
    });
  },

  addNode: (
    type: NodeType,
    position: { x: number; y: number },
    sourceId?: string,
  ) => {
    const { maxIndex, nodes, edges } = get();
    const newIndex = maxIndex + 1;
    const newNodeId = `${type}_${newIndex}`;

    // 初始化默认配置(包含所有 JSON 模板所需的预设或 Mock 数据)
    const defaultConfig: any = {};
    if (type === "detect") {
      defaultConfig.model_name = MOCK_MODELS.detect[0].name; // 默认选中第一个
      defaultConfig.onnx = `${MOCK_MODELS.detect[0].name}_${MOCK_MODELS.detect[0].version}.omodel`; // Mock 从后端获取的参数
      defaultConfig.classes = MOCK_MODELS.detect[0].classes;
      defaultConfig.threshold = 0.45;
      defaultConfig.input_w = 640;
      defaultConfig.input_h = 640;
      defaultConfig.upload_class = [];
      defaultConfig.unoutput_class = [];
    } else if (type === "classify") {
      defaultConfig.model_name = MOCK_MODELS.classify[0].name;
      defaultConfig.onnx = `${MOCK_MODELS.classify[0].name}_${MOCK_MODELS.classify[0].version}.omodel`;
      defaultConfig.classes = MOCK_MODELS.classify[0].classes;
      defaultConfig.threshold = 0.65;
      defaultConfig.upload_class = [];
      defaultConfig.unoutput_class = [];
    } else if (type === "feature") {
      defaultConfig.model_name = MOCK_MODELS.feature[0].name;
      defaultConfig.model_path = "cnclip";
      defaultConfig.hit_texts = [];
      defaultConfig.filter_texts = [];
      defaultConfig.threshold = 0.27; // 新增大模型默认置信度
    } else if (type === "logic") {
      defaultConfig.filter = {
        output_class: [],
        in_class: [],
        no_class: [],
        contain_class: [],
        no_contain_class: [],
        intersect_class: [],
      };
    } else if (type === "end") {
      defaultConfig.min_targets = 1;
      defaultConfig.min_width = 20;
      defaultConfig.min_height = 20;
      defaultConfig.max_width = 1000;
      defaultConfig.max_height = 1000;
    }

    const newNode: Node = {
      id: newNodeId,
      type: "customNode",
      position,
      data: {
        type,
        name: NODE_TYPES[type],
        index: newIndex, // 核心:用于ID偏移计算
        config: defaultConfig,
        position, // 记录一下初始位置给 CustomNode 用
      },
    };

    const newEdges = [...edges];
    if (sourceId) {
      newEdges.push({
        id: `e-${sourceId}-${newNodeId}`,
        source: sourceId,
        target: newNodeId,
        type: "smoothstep",
        animated: true,
      });
    }

    set({
      nodes: [...nodes, newNode],
      edges: newEdges,
      maxIndex: newIndex,
    });
  },

  removeNode: (id: string) => {
    if (id.startsWith("start")) return; // 保护开始节点
    set({
      nodes: get().nodes.filter((node) => node.id !== id),
      edges: get().edges.filter(
        (edge) => edge.source !== id && edge.target !== id,
      ),
    });
  },

  updateNodeData: (id: string, configData: any) => {
    set({
      nodes: get().nodes.map((node) => {
        if (node.id === id) {
          return {
            ...node,
            data: {
              ...node.data,
              config: { ...node.data.config, ...configData },
            },
          };
        }
        return node;
      }),
    });
  },

  generateJSON: () => {
    const { nodes, edges } = get();

    // 构建JSON模版
    const resultJSON = {
      apply: [
        {
          factory: "custom",
          ability: [
            {
              type: "uniformdetect",
              version: "3.3",
              name: "自定义智能分析工作流",
              update_time: Math.floor(Date.now() / 1000),
              model: [] as any[],
              results: [] as any[],
              objects: [] as any[],
              task_config: {},
            },
          ],
        },
      ],
    };

    const ability = resultJSON.apply[0].ability[0];

    // 遍历节点,根据类型分类填充到模型(model)或逻辑(results)
    nodes.forEach((node) => {
      const { type, config, index, name } = node.data;

      // 寻找上游节点
      const upEdges = edges.filter((e) => e.target === node.id);
      const src = upEdges.length > 0 ? upEdges[0].source : "src";

      // 寻找下游节点
      const downEdges = edges.filter((e) => e.source === node.id);
      const dst = downEdges.length > 0 ? downEdges[0].target : "dst";

      if (["detect", "classify", "feature"].includes(type)) {
        // 组装符合 JSON 模板的模型结构
        const modelNode: any = {
          key: node.id,
          type: type,
          src: src,
          dst: dst,
          config: {},
        };

        if (type === "detect" || type === "classify") {
          modelNode.config = {
            onnx: config.onnx,
            classes: config.classes,
            input_w: config.input_w,
            input_h: config.input_h,
            upload_class: config.upload_class || [],
            unoutput_class: config.unoutput_class || [],
            threshold: config.threshold,
          };

          // ID 偏移逻辑
          if (config.upload_class) {
            modelNode.config.upload_class = config.upload_class.map(
              (id: number) => id + index * 1000,
            );
          }
          if (config.unoutput_class) {
            modelNode.config.unoutput_class = config.unoutput_class.map(
              (id: number) => id + index * 1000,
            );
          }
        } else if (type === "feature") {
          // 处理大模型复核的特有配置
          modelNode.config = {
            model_path: config.model_path,
            hit_texts: (config.hit_texts || []).map((text: string) => ({
              text,
              threshold: config.threshold,
            })),
            filter_texts: (config.filter_texts || []).map((text: string) => ({
              text,
              threshold: config.threshold,
            })),
          };
        }

        ability.model.push(modelNode);

        // 我们将不再在这里直接处理 detect 的 objects。
        // objects 将由前一个节点选中的内容来决定,我们把提取逻辑放到单独的循环中
      } else if (type === "logic") {
        // 逻辑节点
        ability.results.push({
          key: node.id,
          filter: { ...config.filter },
        });
      } else if (type === "end") {
        // 结束节点参数映射到 task_config
        ability.task_config = { ...config };
      }
    });

    // ========== 新增:处理 objects 数组逻辑 ==========
    // 规则:展示结束前一个节点选中的输出目标 ID、中文名和英文名

    // 1. 找到所有的结束节点 (end)
    const endNodes = nodes.filter((n) => n.data.type === "end");
    const objectIds = new Set<number>(); // 用于去重

    endNodes.forEach((endNode) => {
      // 2. 找到指向 end 节点的所有上游连线
      const upEdgesToEnd = edges.filter((e) => e.target === endNode.id);

      upEdgesToEnd.forEach((edge) => {
        // 3. 找到紧挨着 end 节点的"前一个节点"
        const prevNode = nodes.find((n) => n.id === edge.source);
        if (!prevNode) return;

        const prevConfig = prevNode.data.config;
        const prevIndex = prevNode.data.index || 0;

        // 4. 提取该节点配置中选中的输出目标 (upload_class)
        // 注意:如果是 logic 节点,它可能自身没有 upload_class。
        // 如果严格要求"结束前一个节点",并且它可能就是包含目标数据的 detect/classify,我们取它的 upload_class。
        // 如果逻辑节点本身要透传前面的目标,这里可能需要递归往上找。
        // 这里按照"前一个节点如果包含 upload_class 则提取"来处理,如果前一个是 logic,我们往上追溯一层找到 detect/classify

        const extractObjectsFromNode = (node: Node) => {
          const config = node.data.config;
          const index = node.data.index || 0;
          if (config.upload_class && Array.isArray(config.upload_class)) {
            config.upload_class.forEach((classId: number) => {
              const offsetId = classId + index * 1000;
              if (!objectIds.has(offsetId)) {
                objectIds.add(offsetId);
                const isHuman = classId === 0; // 简单的Mock判断
                ability.objects.push({
                  id: offsetId,
                  name_ch: isHuman ? "人体" : `目标_${classId}`,
                  name_en: isHuman ? "body" : `target_${classId}`,
                });
              }
            });
          }
        };

        if (prevNode.data.type === "logic") {
          // 如果前一个是逻辑节点,往上追溯找到连接到该逻辑节点的 detect/classify
          const upEdgesToLogic = edges.filter((e) => e.target === prevNode.id);
          upEdgesToLogic.forEach((e) => {
            const sourceNode = nodes.find((n) => n.id === e.source);
            if (sourceNode) extractObjectsFromNode(sourceNode);
          });
        } else {
          // 如果前一个节点直接是包含目标输出的节点
          extractObjectsFromNode(prevNode);
        }
      });
    });

    return JSON.stringify(resultJSON, null, 2);
  },
}));

⚙️ 3️⃣ 引擎层(核心逻辑)

包括:

  • 拓扑校验(是否合法连接)
  • JSON 生成(执行链)

👉 一句话总结:
把"图"变成"逻辑"

TypeScript 复制代码
import { Node, Edge } from 'reactflow';
import { NodeType, NODE_CONSTRAINTS } from './constants';

// 验证两节点相连是否合法
export const validateConnection = (sourceType: NodeType, targetType: NodeType): boolean => {
  return NODE_CONSTRAINTS[sourceType].down.includes(targetType) && 
         NODE_CONSTRAINTS[targetType].up.includes(sourceType);
};

// 全量拓扑校验:DFS 确保从开始节点能连通至结束节点,无孤立节点
export const validateTopology = (nodes: Node[], edges: Edge[]): { valid: boolean; message?: string } => {
  if (nodes.length === 0) return { valid: false, message: '画布不能为空' };

  const startNodes = nodes.filter(n => n.data.type === 'start');
  const endNodes = nodes.filter(n => n.data.type === 'end');

  if (startNodes.length !== 1) return { valid: false, message: '必须有且仅有一个"开始"节点' };
  if (endNodes.length === 0) return { valid: false, message: '至少需要一个"结束"节点' };

  // 检查孤立节点 (既没有作为源,也没有作为目标的非开始/结束节点)
  const connectedNodeIds = new Set<string>();
  edges.forEach(e => {
    connectedNodeIds.add(e.source);
    connectedNodeIds.add(e.target);
  });
  
  // 确保所有节点都在连线中,除了画布中仅有单个开始节点的情况
  if (nodes.length > 1) {
    for (const node of nodes) {
      if (!connectedNodeIds.has(node.id)) {
        return { valid: false, message: `节点 "${node.data.name}" 是孤立节点,请连接它或将其删除` };
      }
    }
  }

  // 构建邻接表
  const adjList = new Map<string, string[]>();
  nodes.forEach(n => adjList.set(n.id, []));
  edges.forEach(e => {
    adjList.get(e.source)?.push(e.target);
  });

  // 检查所有路径是否都能走到 end 节点
  let hasValidPath = false;
  const visited = new Set<string>();

  // 简单的DFS,检查从开始节点是否可达结束节点,以及是否有死胡同
  const dfs = (nodeId: string, path: Set<string>): boolean => {
    const node = nodes.find(n => n.id === nodeId);
    if (!node) return false;
    
    if (node.data.type === 'end') {
      hasValidPath = true;
      return true; // 到达终点,此路径有效
    }

    const neighbors = adjList.get(nodeId) || [];
    if (neighbors.length === 0) {
      // 不是结束节点,且没有下游节点
      return false;
    }

    let allPathsValid = true;
    for (const neighbor of neighbors) {
      if (path.has(neighbor)) {
        // 发现环,视业务需求,如果允许环(如视频流处理可能有状态,但此编排通常是DAG无环图)
        // 这里暂时认为不允许环路(DAG)
        return false;
      }
      path.add(neighbor);
      visited.add(neighbor);
      const res = dfs(neighbor, path);
      if (!res) allPathsValid = false;
      path.delete(neighbor);
    }
    return allPathsValid;
  };

  const startId = startNodes[0].id;
  visited.add(startId);
  const isValid = dfs(startId, new Set([startId]));

  if (!isValid) {
    return { valid: false, message: '存在未连通至结束节点的死胡同路径或环路' };
  }
  
  if (!hasValidPath) {
    return { valid: false, message: '未找到从开始到结束节点的有效连通路径' };
  }

  // 检查是否有节点未被访问到(不连通的子图)
  for (const node of nodes) {
    if (!visited.has(node.id)) {
      return { valid: false, message: `存在未连通至主流程的节点: ${node.data.name}` };
    }
  }

  return { valid: true };
};

🔗 4️⃣ 通信机制

你用了两种:

① ReactFlowProvider

解决:

Canvas 和 Toolbar 共享上下文

② window 事件总线
TypeScript 复制代码
window.dispatchEvent(new CustomEvent("OPEN_CONFIG_MODAL"))

👉 一句话总结:

轻量跨组件通信,不优雅但组件解耦

TypeScript 复制代码
import { useState, useEffect } from "react";
import { ReactFlowProvider } from "reactflow";
import { Canvas } from "./components/Canvas";
import { Toolbar } from "./components/Toolbar";
import { ConfigModal } from "./components/ConfigModal";

const App = () => {
  const [modalVisible, setModalVisible] = useState(false);
  const [activeNode, setActiveNode] = useState<{
    id: string;
    data: any;
  } | null>(null);

  useEffect(() => {
    const handleOpenModal = (e: any) => {
      setActiveNode(e.detail);
      setModalVisible(true);
    };
    // 对于不需要保留状态、不需要 UI 响应;事件总线是最简洁的。
    window.addEventListener("OPEN_CONFIG_MODAL", handleOpenModal);
    return () => {
      window.removeEventListener("OPEN_CONFIG_MODAL", handleOpenModal);
    };
  }, []);

  return (
    <div
      style={{
        width: "100vw",
        height: "100vh",
        display: "flex",
        flexDirection: "column",
      }}
    >
      <div
        style={{
          height: 48,
          background: "#001529",
          color: "#fff",
          display: "flex",
          alignItems: "center",
          padding: "0 20px",
          fontSize: 16,
          fontWeight: "bold",
        }}
      >
        智能算法工作流编排画布
      </div>

      <div
        style={{
          flex: 1,
          position: "relative",
          overflow: "hidden",
          background: "#f0f2f5",
        }}
      >
        <ReactFlowProvider>
          <Canvas />
          <Toolbar />
        </ReactFlowProvider>
      </div>

      <ConfigModal
        visible={modalVisible}
        nodeId={activeNode?.id || ""}
        nodeData={activeNode?.data}
        onCancel={() => setModalVisible(false)}
      />
    </div>
  );
};

export default App;

三、约束系统(核心亮点🔥)

这是你这套系统最有价值的地方之一。

1️⃣ NODE_CONSTRAINTS
TypeScript 复制代码
start: { down: ['detect', 'classify'] }

👉 本质:

定义"图的规则"

2️⃣ validateConnection
TypeScript 复制代码
validateConnection(sourceType, targetType)

作用:

阻止非法连接

TypeScript 复制代码
import { Node, Edge } from 'reactflow';
import { NodeType, NODE_CONSTRAINTS } from './constants';

// 验证两节点相连是否合法
export const validateConnection = (sourceType: NodeType, targetType: NodeType): boolean => {
  return NODE_CONSTRAINTS[sourceType].down.includes(targetType) && 
         NODE_CONSTRAINTS[targetType].up.includes(sourceType);
};

// 全量拓扑校验:DFS 确保从开始节点能连通至结束节点,无孤立节点
export const validateTopology = (nodes: Node[], edges: Edge[]): { valid: boolean; message?: string } => {
  if (nodes.length === 0) return { valid: false, message: '画布不能为空' };

  const startNodes = nodes.filter(n => n.data.type === 'start');
  const endNodes = nodes.filter(n => n.data.type === 'end');

  if (startNodes.length !== 1) return { valid: false, message: '必须有且仅有一个"开始"节点' };
  if (endNodes.length === 0) return { valid: false, message: '至少需要一个"结束"节点' };

  // 检查孤立节点 (既没有作为源,也没有作为目标的非开始/结束节点)
  const connectedNodeIds = new Set<string>();
  edges.forEach(e => {
    connectedNodeIds.add(e.source);
    connectedNodeIds.add(e.target);
  });
  
  // 确保所有节点都在连线中,除了画布中仅有单个开始节点的情况
  if (nodes.length > 1) {
    for (const node of nodes) {
      if (!connectedNodeIds.has(node.id)) {
        return { valid: false, message: `节点 "${node.data.name}" 是孤立节点,请连接它或将其删除` };
      }
    }
  }

  // 构建邻接表
  const adjList = new Map<string, string[]>();
  nodes.forEach(n => adjList.set(n.id, []));
  edges.forEach(e => {
    adjList.get(e.source)?.push(e.target);
  });

  // 检查所有路径是否都能走到 end 节点
  let hasValidPath = false;
  const visited = new Set<string>();

  // 简单的DFS,检查从开始节点是否可达结束节点,以及是否有死胡同
  const dfs = (nodeId: string, path: Set<string>): boolean => {
    const node = nodes.find(n => n.id === nodeId);
    if (!node) return false;
    
    if (node.data.type === 'end') {
      hasValidPath = true;
      return true; // 到达终点,此路径有效
    }

    const neighbors = adjList.get(nodeId) || [];
    if (neighbors.length === 0) {
      // 不是结束节点,且没有下游节点
      return false;
    }

    let allPathsValid = true;
    for (const neighbor of neighbors) {
      if (path.has(neighbor)) {
        // 发现环,视业务需求,如果允许环(如视频流处理可能有状态,但此编排通常是DAG无环图)
        // 这里暂时认为不允许环路(DAG)
        return false;
      }
      path.add(neighbor);
      visited.add(neighbor);
      const res = dfs(neighbor, path);
      if (!res) allPathsValid = false;
      path.delete(neighbor);
    }
    return allPathsValid;
  };

  const startId = startNodes[0].id;
  visited.add(startId);
  const isValid = dfs(startId, new Set([startId]));

  if (!isValid) {
    return { valid: false, message: '存在未连通至结束节点的死胡同路径或环路' };
  }
  
  if (!hasValidPath) {
    return { valid: false, message: '未找到从开始到结束节点的有效连通路径' };
  }

  // 检查是否有节点未被访问到(不连通的子图)
  for (const node of nodes) {
    if (!visited.has(node.id)) {
      return { valid: false, message: `存在未连通至主流程的节点: ${node.data.name}` };
    }
  }

  return { valid: true };
};

👉 升维理解

这其实是:

图结构合法性校验(Graph Validation),不是所有节点都能乱连


四、React Flow 注意点
1️⃣ ReactFlowProvider 的作用
TypeScript 复制代码
<ReactFlowProvider>
  <Canvas />
  <Toolbar />
</ReactFlowProvider>

1️⃣ 提供 React Flow 内部状态(nodes / edges / viewport 等)

2️⃣ 提供操作方法(zoom / fitView / project 等)

3️⃣ 让多个组件共享同一个 flow 实例

否则:

👉 useReactFlow() 会拿不到实例


2️⃣ 节点结构
TypeScript 复制代码
data: {
  type,
  name,
  config,
  position
}
👉 为什么把 position 存进 data?

原因:

ReactFlow 的 position 在拖拽时是"临时态"

你存一份:

👉 可以用于新增子节点定位


3️⃣ Handle 是什么?
TypeScript 复制代码
<Handle type="source" />
<Handle type="target" />

👉 本质:

连接点(输入 / 输出)

相关推荐
kilito_012 小时前
react疑难讲解
前端·react.js·前端框架
字节跳动的猫2 小时前
2026 四款 AI:开发场景适配全面解析
前端·人工智能·开源
gyx_这个杀手不太冷静2 小时前
大人工智能时代下前端界面全新开发模式的思考(四)
前端·架构·ai编程
Ruihong2 小时前
你的 Vue 3 useAttrs(),VuReact 会编译成什么样的 React?
vue.js·react.js·面试
李剑一2 小时前
我做了个微信聊天模拟器,已开源
前端
代码搬运媛3 小时前
30分钟带你从0手搓一个AI-Cli命令行工具
前端
赛博切图仔3 小时前
前端性能内卷终点?Signals 正在重塑我们的开发习惯
前端·javascript·vue.js
小江的记录本3 小时前
【RAG】RAG检索增强生成(核心架构、全流程、RAG优化方案、常见问题与解决方案)
java·前端·人工智能·后端·python·机器学习·架构
程序员buddha3 小时前
SCSS从0到1精通教程
前端·css·scss