决策树、随机森林与梯度提升（Boosting）算法

欢迎访问我的主页: https://heeheeaii.github.io/

Kotlin 完整代码：
package com.treevalue.beself.other

import kotlin.math.*
import kotlin.random.Random

/**
 * One sample: a feature vector and its regression target.
 *
 * `equals`/`hashCode` are overridden because the generated data-class versions would compare
 * [features] by reference. Both the features and the target are compared by bit pattern, which
 * keeps `equals` consistent with `hashCode`: a `NaN` target equals another `NaN` target, while
 * `0.0` and `-0.0` are different values.
 */
data class DataNode(val features: DoubleArray, val value: Double) {
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is DataNode) return false
        // IEEE `!=` would make NaN unequal to itself and 0.0 equal to -0.0, contradicting
        // value.hashCode() (bit-based); toBits() matches both hashCode and contentEquals.
        return value.toBits() == other.value.toBits() && features.contentEquals(other.features)
    }

    override fun hashCode(): Int = 31 * features.contentHashCode() + value.hashCode()
}

/**
 * A node of a fitted regression tree: either a [Leaf] carrying a constant prediction, or a
 * [MidNode] that sends a sample to [MidNode.left] when `features[splitIdx] <= threshold`
 * and to [MidNode.right] otherwise.
 */
sealed class TreeNode {
    /** Internal split on feature [splitIdx] at [threshold], with the two child subtrees. */
    data class MidNode(
        val splitIdx: Int,
        val threshold: Double,
        val left: TreeNode,
        val right: TreeNode,
    ) : TreeNode()

    /** Terminal node whose prediction is [value]. */
    data class Leaf(val value: Double) : TreeNode()
}

/**
 * Regression tree. Each internal node splits on the feature/threshold pair that minimises the
 * sample-weighted MSE of its two children; each leaf predicts the mean target of the training
 * samples that reached it.
 *
 * @param maxDepth nodes at this depth are always turned into leaves.
 * @param minSplitNum nodes holding fewer samples than this are not split.
 * @param minSamplesLeaf every child created by a split must hold at least this many samples
 *   (and always at least one); splits that would violate it are never considered.
 */
class DecisionTree(
    private val maxDepth: Int = 10,
    private val minSplitNum: Int = 2,
    private val minSamplesLeaf: Int = 1,
) {
    private var root: TreeNode? = null

    /**
     * Fits the tree to [data], replacing any previously trained tree.
     *
     * @throws IllegalArgumentException if [data] is empty (a leaf over no samples has no mean).
     */
    fun train(data: List<DataNode>) {
        require(data.isNotEmpty()) { "training data must not be empty" }
        root = buildTree(data, depth = 0)
    }

    /** Recursively grows the subtree for [data]; [data] is never empty here. */
    private fun buildTree(data: List<DataNode>, depth: Int): TreeNode {
        val stopSplitting = depth >= maxDepth || // too deep
            data.size < minSplitNum || // too few samples to split
            data.map { it.value }.distinct().size == 1 // all targets identical
        if (stopSplitting) return leafOf(data)

        // null when no candidate split leaves enough samples on both sides
        val (featureIdx, threshold) = findBestSplit(data) ?: return leafOf(data)
        val (leftData, rightData) = splitData(data, featureIdx, threshold)

        return TreeNode.MidNode(
            featureIdx,
            threshold,
            buildTree(leftData, depth + 1),
            buildTree(rightData, depth + 1),
        )
    }

    /** Leaf predicting the mean target of [data]. */
    private fun leafOf(data: List<DataNode>): TreeNode.Leaf =
        TreeNode.Leaf(data.map { it.value }.average())

    /**
     * Exhaustively tries every feature and every midpoint between consecutive distinct values
     * of that feature. Returns (featureIdx, threshold) with the lowest weighted child MSE, or
     * null if no candidate split is admissible.
     */
    private fun findBestSplit(data: List<DataNode>): Pair<Int, Double>? {
        if (data.isEmpty()) return null

        var bestMse = Double.MAX_VALUE // inadmissible splits score exactly this, so never win
        var bestSplit: Pair<Int, Double>? = null // featureIdx, threshold

        for (featureIdx in data[0].features.indices) {
            val featureValues = data.map { it.features[featureIdx] }.distinct().sorted()
            for ((lower, upper) in featureValues.zipWithNext()) {
                val threshold = (lower + upper) / 2
                val mse = calculateSplitMse(data, featureIdx, threshold)
                if (mse < bestMse) { // strict: among ties the first candidate found wins
                    bestMse = mse
                    bestSplit = Pair(featureIdx, threshold)
                }
            }
        }

        return bestSplit
    }

    /**
     * Sample-weighted MSE of the two children obtained by splitting [data] on
     * `features[featureIndex] <= threshold`, or Double.MAX_VALUE when either child would hold
     * fewer than max(1, minSamplesLeaf) samples.
     */
    private fun calculateSplitMse(data: List<DataNode>, featureIndex: Int, threshold: Double): Double {
        val (leftData, rightData) = splitData(data, featureIndex, threshold)

        // Enforcing minSamplesLeaf during the search means an inadmissible lowest-MSE split
        // cannot block an admissible one from being chosen.
        val minChildSize = maxOf(1, minSamplesLeaf)
        if (leftData.size < minChildSize || rightData.size < minChildSize) {
            return Double.MAX_VALUE
        }

        val totalSize = data.size.toDouble()
        return leftData.size / totalSize * calculateMse(leftData) +
            rightData.size / totalSize * calculateMse(rightData)
    }

    /** Mean squared deviation of the targets from their mean; 0.0 for empty [data]. */
    private fun calculateMse(data: List<DataNode>): Double {
        if (data.isEmpty()) return 0.0

        val mean = data.map { it.value }.average()
        return data.map { (it.value - mean).pow(2) }.average()
    }

    /**
     * Splits [data] into (rows with `features[featureIndex] <= threshold`, all other rows).
     * Every row lands on exactly one side. A NaN feature fails `<=` and goes right, which is the
     * same routing [predict] uses.
     */
    private fun splitData(
        data: List<DataNode>,
        featureIndex: Int,
        threshold: Double,
    ): Pair<List<DataNode>, List<DataNode>> =
        data.partition { it.features[featureIndex] <= threshold }

    /** Prediction for one feature vector; returns 0.0 if [train] has never been called. */
    fun predict(features: DoubleArray): Double {
        return root?.let { predictRecursive(it, features) } ?: 0.0
    }

    private fun predictRecursive(inputNode: TreeNode, features: DoubleArray): Double {
        return when (inputNode) {
            is TreeNode.Leaf -> inputNode.value
            is TreeNode.MidNode -> {
                if (features[inputNode.splitIdx] <= inputNode.threshold) {
                    predictRecursive(inputNode.left, features)
                } else {
                    predictRecursive(inputNode.right, features)
                }
            }
        }
    }

    /** Predictions for each feature vector in [dataPoints], in order. */
    fun predict(dataPoints: List<DoubleArray>): List<Double> {
        return dataPoints.map { predict(it) }
    }
}

/**
 * Bagged ensemble of [DecisionTree]s. Each tree is trained on a bootstrap sample (rows drawn
 * with replacement), restricted to a random subset of the features. That subset is chosen once
 * per tree, not per split. The forest predicts the mean of the per-tree predictions.
 *
 * @param maxFeatureNum features given to each tree; null means floor(sqrt(featureCount)).
 *   The value is clamped to [1, featureCount].
 * @param sampleRatio bootstrap size as a fraction of the training-set size (at least one row).
 */
class RandomForest(
    private val maxTreeNum: Int = 100,
    private val maxTreeDepth: Int = 10,
    private val minSplitNum: Int = 2,
    private val minLeafNodeNum: Int = 1,
    private val maxFeatureNum: Int? = null,
    private val sampleRatio: Double = 1.0,
    private val random: Random = Random.Default,
) {
    private val trees = mutableListOf<DecisionTree>()

    // useFeatures[i] holds the original feature indices that trees[i] was trained on
    private val useFeatures = mutableListOf<IntArray>()

    // feature count of the last training set; default length for getWeightStatistic
    private var trainedFeatureNum = 0

    /**
     * Trains [maxTreeNum] trees on [data], discarding any previously trained forest.
     *
     * @throws IllegalArgumentException if [data] is empty.
     */
    fun train(data: List<DataNode>) {
        require(data.isNotEmpty()) { "training data must not be empty" }
        val featureNum = data[0].features.size
        // clamp so every tree sees at least one feature (when any exist) and never more than exist
        val actualMaxFeatures = (maxFeatureNum ?: sqrt(featureNum.toDouble()).toInt())
            .coerceIn(minOf(1, featureNum), featureNum)

        // retraining must replace the old ensemble, not append to it
        trees.clear()
        useFeatures.clear()
        trainedFeatureNum = featureNum

        repeat(maxTreeNum) {
            val sampleData = startSample(data, sampleRatio) // bootstrap rows
            val sampleFeature = selectRandomFeatures(featureNum, actualMaxFeatures) // random columns

            val tree = DecisionTree(maxTreeDepth, minSplitNum, minLeafNodeNum)
            tree.train(createFeatureSubsetData(sampleData, sampleFeature))

            // append both together so the two lists stay index-aligned
            useFeatures.add(sampleFeature)
            trees.add(tree)
        }
    }

    /** Bootstrap sample: max(1, floor(data.size * ratio)) rows drawn uniformly with replacement. */
    private fun startSample(data: List<DataNode>, ratio: Double): List<DataNode> {
        // at least one row, otherwise the tree would be trained on nothing
        val sampleSize = (data.size * ratio).toInt().coerceAtLeast(1)
        return List(sampleSize) { data[random.nextInt(data.size)] }
    }

    /** First [maxFeatures] entries of a random permutation of 0 until [totalFeatures]. */
    private fun selectRandomFeatures(totalFeatures: Int, maxFeatures: Int): IntArray {
        val features = (0 until totalFeatures).toMutableList()
        features.shuffle(random)
        return features.take(maxFeatures).toIntArray()
    }

    /** Copies of [data] whose feature vectors contain only the columns in [featureSubset], in that order. */
    private fun createFeatureSubsetData(data: List<DataNode>, featureSubset: IntArray): List<DataNode> {
        return data.map { point ->
            val subsetFeatures = featureSubset.map { point.features[it] }.toDoubleArray()
            DataNode(subsetFeatures, point.value)
        }
    }

    /** Mean of the per-tree predictions (NaN when the forest has no trees). */
    private fun predict(features: DoubleArray): Double {
        val predictions = trees.mapIndexed { idx, tree ->
            val featureSubset = useFeatures[idx]
            val subsetFeatures = featureSubset.map { features[it] }.toDoubleArray()
            tree.predict(subsetFeatures)
        }

        return predictions.average()
    }

    /** Predictions for each feature vector in [dataPoints], in order. */
    fun predict(dataPoints: List<DoubleArray>): List<Double> {
        return dataPoints.map { predict(it) }
    }

    /**
     * For each feature index, the fraction of all per-tree feature selections that picked it.
     * When anything was selected, the returned weights sum to 1. This is a usage frequency over
     * the random subsets, not an impurity-based importance.
     *
     * @param maxFeatureNum length of the returned array; defaults to the feature count of the
     *   last training set.
     * @throws IllegalArgumentException if a feature index used by some tree does not fit.
     */
    fun getWeightStatistic(maxFeatureNum: Int = trainedFeatureNum): DoubleArray {
        val weights = DoubleArray(maxFeatureNum)

        for (subset in useFeatures) {
            for (idx in subset) {
                require(idx < maxFeatureNum) {
                    "feature index $idx does not fit in $maxFeatureNum weights"
                }
                weights[idx] += 1.0
            }
        }

        val total = weights.sum()
        if (total > 0) { // avoid 0/0 when nothing has been trained
            for (idx in weights.indices) {
                weights[idx] /= total
            }
        }

        return weights
    }
}

/**
 * Gradient-boosting regressor for squared-error loss. It starts from the mean target and adds
 * [learnerNum] regression trees; each tree is fitted to the current residuals and its output is
 * scaled by [learningRate].
 *
 * @param sampleRate fraction of rows, drawn without replacement, that each tree is trained on
 *   (at least one row). Values >= 1.0 use every row.
 */
class GradientBoostingRegressor(
    // gradient boosting regressor
    private val learnerNum: Int = 100,
    private val learningRate: Double = 0.1,
    private val maxTreeDepth: Int = 3,
    private val minSplitNum: Int = 2,
    private val minLeafNum: Int = 1,
    private val sampleRate: Double = 1.0,
    private val random: Random = Random.Default,
) {
    private val trees = mutableListOf<DecisionTree>()
    private var initPrediction: Double = 0.0

    /**
     * Fits the ensemble to [data], discarding any previously trained trees. The training MSE is
     * printed every 20 rounds.
     *
     * @throws IllegalArgumentException if [data] is empty.
     */
    fun train(data: List<DataNode>) {
        require(data.isNotEmpty()) { "training data must not be empty" }
        trees.clear() // retraining must not stack new trees on top of old ones
        initPrediction = data.map { it.value }.average()

        // invariant: residuals[i] == data[i].value - (current ensemble prediction for data[i])
        val residuals = DoubleArray(data.size) { data[it].value - initPrediction }

        repeat(learnerNum) { round ->
            val residualData = data.mapIndexed { i, point ->
                DataNode(point.features, residuals[i])
            }

            val trainData = if (sampleRate < 1.0) {
                // at least one row, otherwise the tree would be fitted to nothing (NaN leaf)
                val sampleSize = (residualData.size * sampleRate).toInt().coerceAtLeast(1)
                residualData.shuffled(random).take(sampleSize)
            } else {
                residualData
            }

            val tree = DecisionTree(
                maxDepth = maxTreeDepth,
                minSplitNum = minSplitNum,
                minSamplesLeaf = minLeafNum,
            )
            tree.train(trainData)
            trees.add(tree)

            for (i in residuals.indices) {
                residuals[i] -= learningRate * tree.predict(data[i].features)
            }

            if ((round + 1) % 20 == 0) {
                // By the invariant above, the training MSE is the mean squared residual
                // (equal up to rounding); no need to re-run every tree on every row.
                val mse = residuals.sumOf { it.pow(2) } / residuals.size
                println("第 ${round + 1} 轮后的MSE: $mse")
            }
        }

        println("梯度提升模型训练完成!")
    }

    /** initPrediction plus the learning-rate-scaled sum of all tree predictions. */
    private fun predict(features: DoubleArray): Double {
        var prediction = initPrediction

        trees.forEach { tree ->
            prediction += learningRate * tree.predict(features)
        }

        return prediction
    }

    /** Predictions for each feature vector in [dataPoints], in order. */
    fun predict(dataPoints: List<DoubleArray>): List<Double> {
        return dataPoints.map { predict(it) }
    }

}

/** Regression metrics over paired lists of actual and predicted values. */
object ModelEvaluator {
    /**
     * Mean squared error.
     *
     * @throws IllegalArgumentException if the lists differ in length or are empty.
     */
    fun calculateMseInSameLen(actual: List<Double>, predicted: List<Double>): Double {
        requireSameLen(actual, predicted)
        return actual.zip(predicted) { a, p -> (a - p).pow(2) }.average()
    }

    /**
     * Root mean squared error.
     *
     * @throws IllegalArgumentException if the lists differ in length or are empty.
     */
    fun calculateRmseInSameLen(actual: List<Double>, predicted: List<Double>): Double {
        return sqrt(calculateMseInSameLen(actual, predicted))
    }

    /**
     * Coefficient of determination, 1 - RSS / TSS.
     *
     * When every actual value is identical, TSS is 0 and R² is undefined. In that case this
     * returns 1.0 for a perfect prediction and 0.0 otherwise, instead of NaN or -Infinity.
     *
     * @throws IllegalArgumentException if the lists differ in length or are empty.
     */
    fun calculateR2InSameLen(actual: List<Double>, predicted: List<Double>): Double {
        requireSameLen(actual, predicted)
        val actualMean = actual.average()
        val totalSumSquares = actual.sumOf { (it - actualMean).pow(2) }
        val residualSumSquares = actual.zip(predicted) { a, p -> (a - p).pow(2) }.sum()
        if (totalSumSquares == 0.0) {
            return if (residualSumSquares == 0.0) 1.0 else 0.0
        }
        return 1.0 - (residualSumSquares / totalSumSquares)
    }

    /** Enforces the "same length" contract that `zip` would otherwise silently truncate. */
    private fun requireSameLen(actual: List<Double>, predicted: List<Double>) {
        require(actual.size == predicted.size) {
            "actual (${actual.size}) and predicted (${predicted.size}) must have the same length"
        }
        require(actual.isNotEmpty()) { "metrics are undefined for empty input" }
    }
}

/** Synthetic regression data for the demo. */
object DataGenerator {
    /**
     * Builds [samples] points (none if [samples] <= 0). Inputs are drawn as x1, x2 ~ U(-π, π)
     * and x3 ~ U(-2, 2). The target is sin(x1) + cos(x2) + x3² + random.nextDouble() * [noise].
     * For a non-negative [noise] that noise term lies in [0, noise), so it is not zero-centred.
     */
    fun generateNonlinearData(
        samples: Int = 1000,
        noise: Double = 0.2,
        random: Random = Random.Default,
    ): List<DataNode> = List(samples.coerceAtLeast(0)) {
        fun symmetric(bound: Double) = random.nextDouble(-bound, bound)

        // Draw order (x1, x2, x3, then noise) is fixed so a seeded Random reproduces the data.
        val x1 = symmetric(PI)
        val x2 = symmetric(PI)
        val x3 = symmetric(2.0)
        val signal = sin(x1) + cos(x2) + x3.pow(2)
        val jitter = random.nextDouble() * noise
        DataNode(doubleArrayOf(x1, x2, x3), signal + jitter)
    }
}

/**
 * Demo entry point. It trains a single decision tree, a random forest and a gradient-boosting
 * regressor on the same synthetic training set, then prints each model's test-set MSE / RMSE / R²
 * and a comparison table.
 */
fun main() {
    println("=== 机器学习算法演示 ===\n")

    // One seeded generator feeds the data sets, then the forest, then the booster (in that
    // order), so the whole run is reproducible.
    println("生成数据集...")
    val rng = Random(42)
    val trainData = DataGenerator.generateNonlinearData(800, 0.2, rng)
    val testData = DataGenerator.generateNonlinearData(200, 0.2, rng)

    val testFeatures = testData.map { it.features }
    val testLabels = testData.map { it.value }

    println("训练集大小: ${trainData.size}")
    println("测试集大小: ${testData.size}\n")

    // (MSE, RMSE, R²) of a prediction list against the held-out labels.
    fun evaluate(predicted: List<Double>): Triple<Double, Double, Double> = Triple(
        ModelEvaluator.calculateMseInSameLen(testLabels, predicted),
        ModelEvaluator.calculateRmseInSameLen(testLabels, predicted),
        ModelEvaluator.calculateR2InSameLen(testLabels, predicted),
    )

    // Prints one model's metrics block; `tail` is appended verbatim after the R² value.
    fun report(header: String, metrics: Triple<Double, Double, Double>, tail: String) {
        val (mse, rmse, r2) = metrics
        println(header)
        println("  MSE: $mse")
        println("  RMSE: $rmse")
        println("  R²: $r2$tail")
    }

    println("=== 1. 决策树 ===")
    val tree = DecisionTree(maxDepth = 8, minSplitNum = 5, minSamplesLeaf = 2)
    tree.train(trainData)
    val treeMetrics = evaluate(tree.predict(testFeatures))
    report("决策树结果:", treeMetrics, "\n")

    println("=== 2. 随机森林 ===")
    val forest = RandomForest(
        maxTreeNum = 50,
        maxTreeDepth = 8,
        minSplitNum = 5,
        minLeafNodeNum = 2,
        maxFeatureNum = 2,
        sampleRatio = 0.8,
        random = rng,
    )
    forest.train(trainData)
    val forestMetrics = evaluate(forest.predict(testFeatures))
    report("随机森林结果:", forestMetrics, "")
    println("  特征重要性: ${forest.getWeightStatistic(3).contentToString()}\n")

    println("=== 3. 梯度提升 ===")
    val booster = GradientBoostingRegressor(
        learnerNum = 100,
        learningRate = 0.1,
        maxTreeDepth = 4,
        minSplitNum = 5,
        sampleRate = 0.8,
        random = rng,
    )
    booster.train(trainData)
    val boosterMetrics = evaluate(booster.predict(testFeatures))
    report("梯度提升结果:", boosterMetrics, "\n")

    println("=== 算法对比 ===")
    println("算法           MSE        RMSE       R²")
    val rows = listOf(
        "决策树      %.6f   %.6f   %.6f" to treeMetrics,
        "随机森林    %.6f   %.6f   %.6f" to forestMetrics,
        "梯度提升    %.6f   %.6f   %.6f" to boosterMetrics,
    )
    rows.forEach { (rowFormat, m) -> println(rowFormat.format(m.first, m.second, m.third)) }

    println("\n=== 演示完成 ===")
}
相关推荐
(●—●)橘子……3 小时前
记力扣2271.毯子覆盖的最多白色砖块数 练习理解
数据结构·笔记·python·学习·算法·leetcode
Tiny番茄4 小时前
排序算法汇总,堆排序,归并排序,冒泡排序,插入排序
算法·排序算法
汽车仪器仪表相关领域4 小时前
南华 NHXJ-02 汽车悬架检验台:技术特性与实操应用指南
人工智能·算法·汽车·安全性测试·稳定性测试·汽车检测·年检站
m0_726965985 小时前
【算法】小点:List.remove
算法
rhy200605205 小时前
SAM的低秩特性
人工智能·算法·机器学习·语言模型
new coder5 小时前
[算法练习]第三天:定长滑动窗口
数据结构·算法
eqwaak05 小时前
科技信息差(9.29)
开发语言·科技·学习·算法
晨非辰5 小时前
《从数组到动态顺序表:数据结构与算法如何优化内存管理?》
c语言·数据结构·经验分享·笔记·其他·算法
麻雀无能为力5 小时前
第三章 鸽巢原理
笔记·算法