决策树实战:基于 KingbaseES 的鸢尾花分类 ------ 模型可视化输出
------别让"Hello World"停留在内存里,让它在国产数据库中生根发芽
大家好,我是那个总在演示会上被问"能不能把这棵树画出来?"、又在 KES 表里手动拼接 feature + ' <= ' + threshold 的老架构。今天我们要干一件看似简单、却极具象征意义的事:
用决策树对经典的鸢尾花(Iris)数据集做分类,并把整棵树从电科金仓 KingbaseES(KES)中读出来、训练、再完整地可视化输出。
很多人说:"鸢尾花?那是玩具数据。"
但真相是:玩具数据的价值,在于验证你的技术栈是否闭环。
如果你连 Iris 都跑不通端到端------从建表、入库、训练到可视化------那你凭什么相信自己能在真实业务中驾驭更复杂的模型?
今天我们就用 Java + 自研 CART 实现 + Graphviz,完成一次纯国产技术栈的 AI 全流程演练。全程不依赖 Python、不调 sklearn,只为证明一件事:
在电科金仓的土壤上,我们也能长出清晰、可解释、可展示的 AI 之树。
一、为什么选鸢尾花?因为它暴露一切细节
Iris 数据集只有 150 条记录,4 个连续特征(花萼/花瓣长宽),3 个类别(setosa, versicolor, virginica)。
但它足够小,能让你看清:
- 特征如何分裂;
- 树深多少合适;
- 可视化是否准确。
更重要的是:它没有缺失值、没有噪声、标签干净------是检验你工程链路是否通畅的"试金石"。
二、在 KES 中建表并加载数据
首先,在 KingbaseES 中创建 schema 和表:
sql
CREATE SCHEMA IF NOT EXISTS ai_demo;
CREATE TABLE ai_demo.iris_data (
id SERIAL PRIMARY KEY,
sepal_length REAL NOT NULL,
sepal_width REAL NOT NULL,
petal_length REAL NOT NULL,
petal_width REAL NOT NULL,
species VARCHAR(20) NOT NULL -- 'setosa', 'versicolor', 'virginica'
);
然后用 Java 批量插入(模拟从 CSV 或外部系统导入):
java
public void loadIrisDataToKES(Connection conn) throws SQLException {
// 鸢尾花标准数据(此处简化,实际可从资源文件读取)
double[][] features = {
{5.1, 3.5, 1.4, 0.2}, // setosa
{7.0, 3.2, 4.7, 1.4}, // versicolor
{6.3, 3.3, 6.0, 2.5} // virginica
// ... 共150条
};
String[] labels = {"setosa", "versicolor", "virginica", /*...*/};
String sql = "INSERT INTO ai_demo.iris_data (sepal_length, sepal_width, petal_length, petal_width, species) VALUES (?, ?, ?, ?, ?)";
try (PreparedStatement ps = conn.prepareStatement(sql)) {
for (int i = 0; i < features.length; i++) {
ps.setDouble(1, features[i][0]);
ps.setDouble(2, features[i][1]);
ps.setDouble(3, features[i][2]);
ps.setDouble(4, features[i][3]);
ps.setString(5, labels[i]);
ps.addBatch();
}
ps.executeBatch();
System.out.println("Loaded " + features.length + " iris samples into KES.");
}
}
🔗 确保使用 电科金仓 JDBC 驱动 支持
REAL类型精确写入。
三、从 KES 读取数据并训练 CART 树
复用上期优化后的 CART 实现,支持多分类(基尼不纯度自然扩展):
java
public List<Instance> loadIrisFromKES(Connection conn) throws SQLException {
String sql = "SELECT sepal_length, sepal_width, petal_length, petal_width, species FROM ai_demo.iris_data";
List<Instance> data = new ArrayList<>();
try (PreparedStatement ps = conn.prepareStatement(sql);
ResultSet rs = ps.executeQuery()) {
while (rs.next()) {
Map<String, FeatureValue> feats = new HashMap<>();
feats.put("sepal_length", new FeatureValue("sepal_length", rs.getDouble("sepal_length")));
feats.put("sepal_width", new FeatureValue("sepal_width", rs.getDouble("sepal_width")));
feats.put("petal_length", new FeatureValue("petal_length", rs.getDouble("petal_length")));
feats.put("petal_width", new FeatureValue("petal_width", rs.getDouble("petal_width")));
String species = rs.getString("species");
data.add(new Instance(feats, species));
}
}
return data;
}
// 注意:Instance.label 改为 String 以支持多类
训练(限制 maxDepth=3 避免过拟合):
java
List<Instance> irisData = loadIrisFromKES(conn);
Set<String> features = Set.of("sepal_length", "sepal_width", "petal_length", "petal_width");
TreeNode root = buildCartTree(irisData, features, 0, minSamplesSplit=5, maxDepth=3);
典型分裂结果(人工验证):
- 根节点:
petal_length <= 2.45→ 左子树全为 setosa; - 右子树:
petal_width <= 1.75→ 区分 versicolor/virginica。
四、模型可视化:生成 Graphviz DOT 文件
为了让业务方"看见"模型,我们输出标准 DOT 格式:
java
public void exportTreeToDot(TreeNode node, PrintWriter writer, AtomicInteger nodeId) {
int currentId = nodeId.getAndIncrement();
if (node.isLeaf()) {
// 叶节点:显示类别分布
String label = node.prediction + "\\n(samples=" + node.sampleCount + ")";
writer.println(currentId + " [label=\"" + label + "\", shape=box, style=filled, fillcolor=\"#e6f7ff\"];");
} else {
// 内部节点:显示分裂条件
String condition = node.featureName + " <= " + String.format("%.2f", node.threshold);
writer.println(currentId + " [label=\"" + condition + "\\nsamples=" + node.sampleCount + "\", shape=ellipse];");
// 递归左右子树
int leftId = nodeId.get();
exportTreeToDot(node.left, writer, nodeId);
writer.println(currentId + " -> " + leftId + " [label=\"True\"];");
int rightId = nodeId.get();
exportTreeToDot(node.right, writer, nodeId);
writer.println(currentId + " -> " + rightId + " [label=\"False\"];");
}
}
// 调用
try (PrintWriter writer = new PrintWriter("iris_tree.dot")) {
writer.println("digraph IrisTree {");
exportTreeToDot(root, writer, new AtomicInteger(0));
writer.println("}");
}
生成的 iris_tree.dot 示例:
dot
digraph IrisTree {
0 [label="petal_length <= 2.45\nsamples=150", shape=ellipse];
1 [label="setosa\nsamples=50", shape=box, style=filled, fillcolor="#e6f7ff"];
0 -> 1 [label="True"];
2 [label="petal_width <= 1.75\nsamples=100", shape=ellipse];
0 -> 2 [label="False"];
...
}
用 Graphviz 渲染:
bash
dot -Tpng iris_tree.dot -o iris_tree.png

✅ 这就是可交付、可汇报、可嵌入文档的模型资产。
五、将树结构存回 KES 供在线服务
为支持实时预测,我们将树节点持久化:
sql
CREATE TABLE ai_models.iris_tree_nodes (
node_id SERIAL PRIMARY KEY,
parent_id INT,
is_leaf BOOLEAN,
feature_name VARCHAR(32),
threshold DOUBLE PRECISION,
prediction VARCHAR(20),
sample_count INT,
path TEXT -- 如 "root/left/right"
);
Java 写入(略去递归逻辑):
java
saveNodeToKES(conn, root, null, "root");
在线预测服务只需递归查询:
sql
WITH RECURSIVE predict(path, feature, threshold, pred) AS (
SELECT 'root', feature_name, threshold, prediction
FROM ai_models.iris_tree_nodes WHERE path = 'root'
UNION ALL
SELECT ...,
CASE WHEN input.petal_length <= threshold THEN path || '/left' ELSE path || '/right' END
...
)
SELECT pred FROM predict WHERE is_leaf;
六、为什么这很重要?
- 闭环验证:从 KES 出,回 KES 去,证明国产数据库可作为 AI 数据底座;
- 可解释交付:一张图胜过千行 accuracy 报告;
- 轻量部署:无需 Python 环境,JVM + SQL 即可推理;
- 教学价值:新人可通过 Iris 快速理解整个 ML 流程。
结语:小数据,大意义
在追逐千亿参数的时代,我们容易忽略:AI 落地的第一步,往往是把一个 150 行的数据集跑通。
当你能在电科金仓 KES 中,用 Java 完成 Iris 的端到端训练与可视化------你就已经搭建了一个自主可控、可审计、可展示的 AI 基础能力。
而这,正是国产化替代最需要的"最小可行闭环"。
下一期,我们会讲:随机森林实战:用 KES 构建高精度信贷违约预测模型 。
敬请期待。
------ 一位相信"再小的模型,也值得被认真对待"的架构师