Day 65: 集成学习之 AdaBoost (3. 集成器)

代码:

（语言：Java）
package dl;

import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;

/**
 * The booster which ensembles base classifiers (AdaBoost).
 *
 * Each training round builds a {@link StumpClassifier} on the (re-)weighted training data and gives it a
 * vote weight of 0.5 * ln(1 / e - 1), where e is its weighted training error. An instance is classified
 * as the label with the largest total vote weight over all trained classifiers.
 */
public class Booster {

    /**
     * Lower bound (and 1 minus the upper bound) applied to a weighted error before computing the classifier
     * weight. Without it, an error of exactly 0 yields an infinite weight, which would then be passed to
     * adjustWeights in the next round.
     */
    private static final double MIN_WEIGHTED_ERROR = 1e-10;

    /**
     * Classifiers. After training, only the first numClassifiers entries are used for classification.
     */
    SimpleClassifier[] classifiers;

    /**
     * Number of classifiers. Set by setNumBaseClassifiers, then reset and counted up during training.
     */
    int numClassifiers;

    /**
     * Whether or not stop after the training error is 0.
     */
    boolean stopAfterConverge = false;

    /**
     * The weights (votes) of classifiers, parallel to the classifiers array.
     */
    double[] classifierWeights;

    /**
     * The training data.
     */
    Instances trainingData;

    /**
     * The testing data.
     */
    Instances testingData;

    /**
     ******************
     * The first constructor. The testing set is the same as the training set. The last attribute is used
     * as the class attribute, and training stops early once the training accuracy reaches 100%.
     *
     * On a read failure, the error is printed and the JVM exits with a non-zero status.
     *
     * @param paraTrainingFilename
     *            The data filename (ARFF, read with weka's Instances reader).
     ******************
     */
    public Booster(String paraTrainingFilename) {
        // Step 1. Read training set. try-with-resources closes the reader even if parsing fails.
        try (FileReader tempFileReader = new FileReader(paraTrainingFilename)) {
            trainingData = new Instances(tempFileReader);
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            // Non-zero status so that scripts calling this program can detect the failure.
            System.exit(1);
        } // Of try

        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same as the training data.
        testingData = trainingData;

        stopAfterConverge = true;

        System.out.println("****************Data**********\r\n" + trainingData);
    }// Of the first constructor

    /**
     ******************
     * Set the number of base classifier, and allocate space for them. Must be called before train().
     *
     * @param paraNumBaseClassifiers
     *            The number of base classifier. Must be non-negative.
     * @throws IllegalArgumentException
     *             If paraNumBaseClassifiers is negative.
     ******************
     */
    public void setNumBaseClassifiers(int paraNumBaseClassifiers) {
        // Validate before touching any state, so a bad call leaves the booster unchanged.
        if (paraNumBaseClassifiers < 0) {
            throw new IllegalArgumentException(
                    "The number of base classifiers must be non-negative: " + paraNumBaseClassifiers);
        } // Of if

        numClassifiers = paraNumBaseClassifiers;

        // Step 1. Allocate space (only reference) for classifiers
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Initialize classifier weights.
        classifierWeights = new double[numClassifiers];
    }// Of setNumBaseClassifiers

    /**
     ******************
     * Train the booster. Round i trains classifiers[i] on instance weights adjusted from round i - 1, then
     * sets its vote weight. If stopAfterConverge is set, training stops as soon as the (unweighted) training
     * accuracy exceeds 0.999999.
     *
     * @throws IllegalStateException
     *             If setNumBaseClassifiers has not been called.
     * @see StumpClassifier#train()
     ******************
     */
    public void train() {
        if (classifiers == null) {
            throw new IllegalStateException("setNumBaseClassifiers() must be called before train().");
        } // Of if

        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build the classifiers, one per round.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: Construct or adjust the weightedInstances
            if (i == 0) {
                tempWeightedInstances = new WeightedInstances(trainingData);
            } else {
                // Adjust the weights of the data according to the previous classifier's results.
                tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            } // Of if

            // Step 2.2 Train the next classifier.
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

            tempError = classifiers[i].computeWeightedError();

            // Key code: Set the classifier weight, 0.5 * ln(1 / e - 1).
            // Clamp e away from 0 and 1 so the weight is always finite.
            double tempClampedError = Math.min(Math.max(tempError, MIN_WEIGHTED_ERROR), 1 - MIN_WEIGHTED_ERROR);
            classifierWeights[i] = 0.5 * Math.log(1 / tempClampedError - 1);
            // A weighted error of about 0.5 or more gives a (near-)zero or negative weight: no vote.
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            } // Of if

            System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
                    + classifierWeights[i] + "\r\n");

            numClassifiers++;

            // The accuracy is enough.
            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuray();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at the round: " + i + " due to converge.\r\n");
                    break;
                } // Of if
            } // Of if
        } // Of for i
    }// Of train

    /**
     ******************
     * Classify an instance by weighted voting of the trained classifiers. Ties are broken in favor of the
     * smallest label index; with no trained classifier the result is label 0.
     *
     * @param paraInstance
     *            The given instance.
     * @return The predicted label.
     ******************
     */
    public int classify(Instance paraInstance) {
        // Accumulate the vote weight received by each label.
        double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelsCountArray[tempLabel] += classifierWeights[i];
        } // Of for i

        // Arg-max over the votes; strict comparison keeps the first (smallest) label on ties.
        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelsCountArray.length; i++) {
            if (tempMax < tempLabelsCountArray[i]) {
                tempMax = tempLabelsCountArray[i];
                resultLabel = i;
            } // Of if
        } // Of for

        return resultLabel;
    }// Of classify

    /**
     ******************
     * Test the booster on the testing data (the same as the training data when the first constructor is
     * used).
     *
     * @return The classification accuracy.
     ******************
     */
    public double test() {
        System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

        return test(testingData);
    }// Of test

    /**
     ******************
     * Test the booster. Note: this sets the class index of paraInstances to its last attribute.
     *
     * @param paraInstances
     *            The testing set.
     * @return The classification accuracy (NaN if the set is empty).
     ******************
     */
    public double test(Instances paraInstances) {
        double tempCorrect = 0;
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double resultAccuracy = tempCorrect / paraInstances.numInstances();
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    } // Of test

    /**
     ******************
     * Compute the training accuracy of the booster. It is not weighted. (The method name keeps its
     * original spelling for compatibility with existing callers.)
     *
     * @return The training accuracy.
     ******************
     */
    public double computeTrainingAccuray() {
        double tempCorrect = 0;

        for (int i = 0; i < trainingData.numInstances(); i++) {
            Instance tempInstance = trainingData.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double tempAccuracy = tempCorrect / trainingData.numInstances();

        return tempAccuracy;
    }// Of computeTrainingAccuray

    /**
     ******************
     * For integration test.
     *
     * @param args
     *            Optional: args[0] is the ARFF file to use. Without it, the original default path is used.
     ******************
     */
    public static void main(String args[]) {
        System.out.println("Starting AdaBoosting...");
        String tempFilename = args.length > 0 ? args[0]
                : "C:\\Users\\86183\\IdeaProjects\\deepLearning\\src\\main\\java\\resources\\iris.arff";
        Booster tempBooster = new Booster(tempFilename);

        tempBooster.setNumBaseClassifiers(100);
        tempBooster.train();

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuray());
        tempBooster.test();
    }// Of main

}// Of class Booster

结果:

相关推荐
twc82932 分钟前
大模型生成 QA Pairs 提升 RAG 应用测试效率的实践
服务器·数据库·人工智能·windows·rag·大模型测试
宇擎智脑科技33 分钟前
A2A Python SDK 源码架构解读:一个请求是如何被处理的
人工智能·python·架构·a2a
IT_陈寒34 分钟前
Redis缓存击穿:3个鲜为人知的防御策略,90%开发者都忽略了!
前端·人工智能·后端
vx_biyesheji000136 分钟前
Python 全国城市租房洞察系统 Django框架 Requests爬虫 可视化 房子 房源 大数据 大模型 计算机毕业设计源码(建议收藏)✅
爬虫·python·机器学习·django·flask·课程设计·旅游
电商API&Tina1 小时前
【电商API接口】开发者一站式电商API接入说明
大数据·数据库·人工智能·云计算·json
湘美书院--湘美谈教育1 小时前
湘美谈教育湘美书院网文研究:人工智能与微型小说选集
人工智能·深度学习·神经网络·机器学习·ai写作
uzong1 小时前
Harness Engineering 是什么?一场新的 AI 范式已经开始
人工智能·后端·架构
墨有6661 小时前
FieldFormer:基于物理场论的极简AI大模型底层架构,附带源码
人工智能·架构·电磁场算法映射
Mountain and sea2 小时前
从零搭建工业机器人激光切割+焊接产线:KUKA七轴协同+节卡AGV+视觉检测实战复盘
人工智能·机器人·视觉检测
K姐研究社2 小时前
阿里JVS Claw实测 – 手机一键部署 OpenClaw,开箱即用
人工智能·智能手机·aigc·飞书