决策树原理(一):信息增益与特征选择 —— Java 实现 ID3 算法

决策树原理(一):信息增益与特征选择 ------ Java 实现 ID3 算法

------别被"黑盒"吓住,一棵树的决策逻辑,应该像 SQL 一样清晰

大家好,我是那个总在模型评审会上被问"为什么这个用户被判定为高风险?"、又在 KES 表里逐层验证每条规则的老架构。今天我们要聊一个看似简单、却深刻影响可解释 AI 的基础模型:

决策树(Decision Tree)

很多人一听"树",就想到 XGBoost、LightGBM 这些集成大杀器。

但真相是:所有复杂树模型的根,都长在 ID3/C4.5 这些朴素算法里

而它的核心思想,只用一句话就能说清:

"每次选一个特征,把数据分得最'纯'。"

今天我们就手写 ID3 算法,用纯 Java 实现,并从电科金仓 KingbaseES(KES)中加载真实业务数据------全程不依赖任何 ML 框架,只为让你真正理解:信息增益如何驱动特征选择,以及为什么决策树是国产化场景中最值得信赖的模型之一


一、问题场景:信贷审批规则挖掘

假设我们在 KES 中有一张信贷申请表:

sql 复制代码
CREATE TABLE ai_features.loan_applications (
    app_id        BIGINT,
    age_group     VARCHAR(10),   -- 'young', 'middle', 'senior'
    income_level  VARCHAR(10),   -- 'low', 'medium', 'high'
    has_car       BOOLEAN,
    credit_score  VARCHAR(10),   -- 'poor', 'fair', 'good'
    approved      BOOLEAN        -- true=通过, false=拒绝 ← 标签
);

目标:自动学习一套 if-else 规则,用于辅助人工审批

这正是决策树的强项------输出人类可读的规则链


二、核心原理:信息熵与信息增益

2.1 信息熵(Entropy)------衡量"混乱度"

对于一个数据集 D,其标签分布越均匀,熵越大;越集中,熵越小。

H(D) = - Σ pᵢ log₂(pᵢ)

例如:

  • 全部通过(p=1)→ H = 0(最纯)
  • 一半通过一半拒绝(p=0.5)→ H = 1(最乱)

2.2 信息增益(Information Gain)------衡量"分得有多纯"

当我们用特征 A 划分数据集 D,得到子集 D₁, D₂, ..., Dₖ:

IG(D, A) = H(D) - Σ (|Dⱼ| / |D|) × H(Dⱼ)

增益越大,说明用 A 划分后整体更"纯净"

ID3 算法就是:每次选 IG 最大的特征作为分裂节点

💡 关键洞察:
信息增益本质是"不确定性减少量"------AI 的决策,就是不断消除不确定性


三、Java 实现:从 KES 加载数据 + 手写 ID3

步骤 1:定义数据结构

java 复制代码
public static class Instance {
    public final Map<String, String> features; // 特征名 → 值(离散)
    public final boolean label;                // true/false

    public Instance(Map<String, String> features, boolean label) {
        this.features = features;
        this.label = label;
    }
}

⚠️ 注意:ID3 只支持离散特征,连续特征需先分箱(binning)。


步骤 2:从 KES 加载信贷数据

java 复制代码
public List<Instance> loadLoanData(Connection conn) throws SQLException {
    String sql = "SELECT age_group, income_level, has_car::TEXT, credit_score, approved FROM ai_features.loan_applications";
    List<Instance> data = new ArrayList<>();
    
    try (PreparedStatement ps = conn.prepareStatement(sql);
         ResultSet rs = ps.executeQuery()) {
        while (rs.next()) {
            Map<String, String> feats = new HashMap<>();
            feats.put("age_group", rs.getString("age_group"));
            feats.put("income_level", rs.getString("income_level"));
            feats.put("has_car", rs.getString("has_car")); // "true"/"false"
            feats.put("credit_score", rs.getString("credit_score"));
            boolean approved = rs.getBoolean("approved");
            data.add(new Instance(feats, approved));
        }
    }
    return data;
}

🔗 使用 KES JDBC 驱动 确保 getBoolean/getString 稳定。


步骤 3:计算熵与信息增益

java 复制代码
// 计算数据集熵
public static double entropy(List<Instance> data) {
    if (data.isEmpty()) return 0.0;
    long total = data.size();
    long positive = data.stream().filter(i -> i.label).count();
    double p = positive / (double) total;
    if (p == 0 || p == 1) return 0.0;
    return -p * log2(p) - (1 - p) * log2(1 - p);
}

private static double log2(double x) {
    return Math.log(x) / Math.log(2);
}

// 计算某特征的信息增益
public static double informationGain(List<Instance> data, String featureName) {
    double baseEntropy = entropy(data);
    Map<String, List<Instance>> splits = data.stream()
        .collect(Collectors.groupingBy(inst -> inst.features.get(featureName)));
    
    double weightedEntropy = 0.0;
    for (List<Instance> subset : splits.values()) {
        double weight = subset.size() / (double) data.size();
        weightedEntropy += weight * entropy(subset);
    }
    return baseEntropy - weightedEntropy;
}

步骤 4:递归构建决策树

java 复制代码
public static TreeNode buildTree(List<Instance> data, Set<String> availableFeatures) {
    // 终止条件1:全为同一类
    if (data.stream().allMatch(i -> i.label)) return new TreeNode(true);
    if (data.stream().noneMatch(i -> i.label)) return new TreeNode(false);

    // 终止条件2:无特征可用
    if (availableFeatures.isEmpty()) {
        boolean majority = data.stream().filter(i -> i.label).count() > data.size() / 2;
        return new TreeNode(majority);
    }

    // 选择信息增益最大的特征
    String bestFeature = null;
    double maxGain = -1;
    for (String feat : availableFeatures) {
        double gain = informationGain(data, feat);
        if (gain > maxGain) {
            maxGain = gain;
            bestFeature = feat;
        }
    }

    // 创建内部节点
    TreeNode node = new TreeNode(bestFeature);
    Map<String, List<Instance>> splits = data.stream()
        .collect(Collectors.groupingBy(inst -> inst.features.get(bestFeature)));

    // 递归构建子树
    Set<String> newFeatures = new HashSet<>(availableFeatures);
    newFeatures.remove(bestFeature);
    for (Map.Entry<String, List<Instance>> entry : splits.entrySet()) {
        TreeNode child = buildTree(entry.getValue(), newFeatures);
        node.addChild(entry.getKey(), child);
    }

    return node;
}

// 树节点定义
public static class TreeNode {
    private final String feature;      // null 表示叶节点
    private final Boolean prediction;  // 叶节点预测值
    private final Map<String, TreeNode> children = new HashMap<>();

    public TreeNode(String feature) { this.feature = feature; this.prediction = null; }
    public TreeNode(boolean prediction) { this.feature = null; this.prediction = prediction; }

    public void addChild(String value, TreeNode child) { children.put(value, child); }
    public boolean isLeaf() { return feature == null; }
    // ... getter
}

步骤 5:打印可读规则

java 复制代码
public static void printRules(TreeNode node, List<String> path) {
    if (node.isLeaf()) {
        System.out.println(String.join(" AND ", path) + " → APPROVED=" + node.prediction);
        return;
    }
    for (Map.Entry<String, TreeNode> child : node.children.entrySet()) {
        List<String> newPath = new ArrayList<>(path);
        newPath.add(node.feature + "=" + child.getKey());
        printRules(child.getValue(), newPath);
    }
}

// 调用
TreeNode root = buildTree(data, Set.of("age_group", "income_level", "has_car", "credit_score"));
printRules(root, new ArrayList<>());

典型输出:

复制代码
credit_score=good AND income_level=high → APPROVED=true  
credit_score=poor → APPROVED=false  
credit_score=fair AND has_car=true → APPROVED=true  
...

✅ 这就是可解释、可审计、可人工干预的 AI 规则!


四、与 KES 协同:将规则存入数据库供风控系统调用

sql 复制代码
CREATE TABLE ai_models.decision_rules (
    rule_id      SERIAL PRIMARY KEY,
    condition    TEXT,          -- "credit_score='good' AND income_level='high'"
    action       VARCHAR(10),   -- 'approve', 'reject'
    created_at   TIMESTAMP DEFAULT NOW()
);

Java 写入:

java 复制代码
public void saveRulesToKES(Connection conn, TreeNode root) throws SQLException {
    List<String> rules = extractRuleStrings(root, new ArrayList<>());
    String sql = "INSERT INTO ai_models.decision_rules (condition, action) VALUES (?, ?)";
    try (PreparedStatement ps = conn.prepareStatement(sql)) {
        for (String rule : rules) {
            String[] parts = rule.split(" → ");
            ps.setString(1, parts[0]);
            ps.setString(2, parts[1].contains("true") ? "approve" : "reject");
            ps.addBatch();
        }
        ps.executeBatch();
    }
}

风控系统只需执行:

sql 复制代码
SELECT action FROM ai_models.decision_rules 
WHERE 'credit_score=good AND income_level=high' LIKE CONCAT('%', condition, '%')
LIMIT 1;

五、ID3 的局限与工程思考

  1. 只支持离散特征 → 连续特征需预处理(如 KES 中用 CASE WHEN 分箱);
  2. 偏向高基数特征 → 后续 C4.5 引入增益率解决;
  3. 容易过拟合 → 需剪枝(后续文章讲)。

但在国产化场景中,ID3 的透明性远胜精度损失------尤其在金融、政务等强监管领域。


结语:可解释,是国产 AI 的护城河

在追逐大模型的时代,我们容易忽略:很多业务问题,根本不需要黑盒

一个由 ID3 生成的决策树,配合电科金仓 KES 的稳定存储,足以支撑信贷审批、合规检查、运维告警等核心场景。

当你能在飞腾服务器上,用不到 200 行 Java 代码,从 KES 读取数据、训练模型、输出可读规则------你就拥有了自主可控、可解释、可落地的 AI 能力

下一期,我们会讲:决策树原理(二):C4.5 与连续特征处理,Java 实现增益率

敬请期待。

------ 一位相信"最好的 AI,是业务人员能看懂的 AI"的架构师

相关推荐
2401_832402752 小时前
使用Docker容器化你的Python应用
jvm·数据库·python
仍然.2 小时前
MySQL--库的操作、数据类型、表的操作
数据库·mysql
让我上个超影吧2 小时前
天机学堂——BitMap实现签到
java·数据库·spring boot·redis·spring cloud
迷路爸爸1802 小时前
无sudo权限远程连接Ubuntu服务器安装TeX Live实操记录(适配VS Code+LaTeX Workshop,含路径选择与卸载方案)
java·服务器·ubuntu·latex
有梦想的攻城狮2 小时前
maven中的os-maven-plugin插件的使用
java·maven·maven插件·os-maven-plugin·classifer
宇神城主_蒋浩宇2 小时前
最简单的es理解 数据库视角看写 ES 加 java正删改查深度分页
大数据·数据库·elasticsearch
A尘埃2 小时前
数值特征标准化StandardScaler和类别不平衡SMOTE
人工智能·深度学习·机器学习
Carry灭霸2 小时前
【BUG】Redisson Connection refused 127.0.0.1
java·redis
消失的旧时光-19432 小时前
第九课实战版:异常与日志体系 —— 后端稳定性的第一道防线
java·后端