Day 65: 集成学习之 AdaBoosting (3. 集成器)

代码:

java 复制代码
package dl;

import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;

/**
 * The booster which ensembles base classifiers (AdaBoosting).
 *
 * <p>
 * Typical usage: construct with a data file, call
 * {@link #setNumBaseClassifiers(int)}, then {@link #train()}, and finally
 * {@link #test()} or {@link #classify(Instance)}.
 */
public class Booster {

    /**
     * The data file used by {@link #main(String[])} when no filename is given on
     * the command line.
     */
    private static final String DEFAULT_DATA_FILENAME = "C:\\Users\\86183\\IdeaProjects\\deepLearning\\src\\main\\java\\resources\\iris.arff";

    /**
     * Classifiers. Only the first {@link #numClassifiers} entries are valid after
     * training.
     */
    SimpleClassifier[] classifiers;

    /**
     * Number of classifiers. Set to the requested capacity by
     * {@link #setNumBaseClassifiers(int)} and to the number of classifiers
     * actually built by {@link #train()}.
     */
    int numClassifiers;

    /**
     * Whether or not to stop training once the (unweighted) training accuracy of
     * the booster reaches 1.
     */
    boolean stopAfterConverge = false;

    /**
     * The weights of classifiers, used as voting weights in
     * {@link #classify(Instance)}.
     */
    double[] classifierWeights;

    /**
     * The training data.
     */
    Instances trainingData;

    /**
     * The testing data.
     */
    Instances testingData;

    /**
     ******************
     * The first constructor. The testing set is the same as the training set.
     * The last attribute is used as the class attribute, and early stopping on
     * convergence is enabled.
     *
     * <p>
     * If the file cannot be read, a message is printed and the JVM exits with a
     * non-zero status.
     *
     * @param paraTrainingFilename
     *            The data filename.
     ******************
     */
    public Booster(String paraTrainingFilename) {
        // Step 1. Read training set. try-with-resources closes the reader even
        // when parsing fails.
        try (FileReader tempFileReader = new FileReader(paraTrainingFilename)) {
            trainingData = new Instances(tempFileReader);
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            // A non-zero status tells the caller that the run failed.
            System.exit(1);
        } // Of try

        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same as the training data.
        testingData = trainingData;

        stopAfterConverge = true;

        System.out.println("****************Data**********\r\n" + trainingData);
    }// Of the first constructor

    /**
     ******************
     * Set the number of base classifier, and allocate space for them. The
     * classifiers themselves are only created by {@link #train()}.
     *
     * @param paraNumBaseClassifiers
     *            The number of base classifier.
     ******************
     */
    public void setNumBaseClassifiers(int paraNumBaseClassifiers) {
        numClassifiers = paraNumBaseClassifiers;

        // Step 1. Allocate space (only reference) for classifiers
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Initialize classifier weights.
        classifierWeights = new double[numClassifiers];
    }// Of setNumBaseClassifiers

    /**
     ******************
     * Train the booster. Base classifiers are built one after another; after
     * each round the instance weights are adjusted according to the correctness
     * of the previous classifier and its weight.
     *
     * <p>
     * Training ends when all allocated classifiers are built, when
     * {@link #stopAfterConverge} is set and the training accuracy reaches 1, or
     * when a base classifier has zero weighted error.
     *
     * @throws IllegalStateException
     *             If {@link #setNumBaseClassifiers(int)} has not been called.
     * @see StumpClassifier#train()
     ******************
     */
    public void train() {
        if (classifiers == null) {
            throw new IllegalStateException("setNumBaseClassifiers() must be called before train().");
        } // Of if

        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build other classifiers.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: Construct or adjust the weightedInstances
            if (i == 0) {
                tempWeightedInstances = new WeightedInstances(trainingData);
            } else {
                // Adjust the weights of the data.
                tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            } // Of if

            // Step 2.2 Train the next classifier.
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

            tempError = classifiers[i].computeWeightedError();

            // Key code: Set the classifier weight, 0.5 * ln((1 - e) / e).
            classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
            // Errors of (about) 0.5 or more give non-positive weights; such
            // classifiers get no vote.
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            } // Of if

            System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
                    + classifierWeights[i] + "\r\n");

            numClassifiers++;

            // The accuracy is enough.
            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuray();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at the round: " + i + " due to converge.\r\n");
                    break;
                } // Of if
            } // Of if

            // A zero weighted error yields an infinite classifier weight. This
            // classifier already dominates every vote, and passing an infinite
            // weight to adjustWeights() in the next round is presumably not
            // meaningful (adjustWeights is not visible here), so stop.
            if (Double.isInfinite(classifierWeights[i])) {
                System.out.println("Stop at the round: " + i + " due to a perfect base classifier.\r\n");
                break;
            } // Of if
        } // Of for i
    }// Of train

    /**
     ******************
     * Classify an instance by weighted voting: each trained classifier adds its
     * weight to the label it predicts, and the label with the largest total
     * wins (the lowest index wins ties).
     *
     * @param paraInstance
     *            The given instance.
     * @return The predicted label, as an index into the values of the class
     *         attribute of the training data.
     ******************
     */
    public int classify(Instance paraInstance) {
        double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelsCountArray[tempLabel] += classifierWeights[i];
        } // Of for i

        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelsCountArray.length; i++) {
            if (tempMax < tempLabelsCountArray[i]) {
                tempMax = tempLabelsCountArray[i];
                resultLabel = i;
            } // Of if
        } // Of for

        return resultLabel;
    }// Of classify

    /**
     ******************
     * Test the booster on the testing data.
     *
     * @return The classification accuracy.
     ******************
     */
    public double test() {
        System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

        return test(testingData);
    }// Of test

    /**
     ******************
     * Test the booster. The last attribute of the given set is set as its class
     * attribute, and the accuracy is printed.
     *
     * @param paraInstances
     *            The testing set.
     * @return The classification accuracy (NaN for an empty set).
     ******************
     */
    public double test(Instances paraInstances) {
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        double resultAccuracy = computeAccuracy(paraInstances);
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    } // Of test

    /**
     ******************
     * Compute the training accuracy of the booster. It is not weighted.
     *
     * <p>
     * The misspelled name is kept for existing callers; see
     * {@link #computeTrainingAccuracy()}.
     *
     * @return The training accuracy.
     ******************
     */
    public double computeTrainingAccuray() {
        return computeAccuracy(trainingData);
    }// Of computeTrainingAccuray

    /**
     ******************
     * Compute the training accuracy of the booster. It is not weighted. This is
     * the correctly spelled equivalent of {@link #computeTrainingAccuray()}.
     *
     * @return The training accuracy.
     ******************
     */
    public double computeTrainingAccuracy() {
        return computeTrainingAccuray();
    }// Of computeTrainingAccuracy

    /**
     ******************
     * Compute the unweighted fraction of instances whose predicted label equals
     * the class value.
     *
     * @param paraInstances
     *            The instances to classify; the class index must be set.
     * @return The accuracy (NaN for an empty set).
     ******************
     */
    private double computeAccuracy(Instances paraInstances) {
        double tempCorrect = 0;

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        return tempCorrect / paraInstances.numInstances();
    }// Of computeAccuracy

    /**
     ******************
     * For integration test. Trains a booster with up to 100 base classifiers and
     * prints the training and testing accuracy.
     *
     * @param args
     *            Optional. {@code args[0]} is the data filename; when absent a
     *            built-in default path is used.
     ******************
     */
    public static void main(String args[]) {
        System.out.println("Starting AdaBoosting...");
        String tempFilename = DEFAULT_DATA_FILENAME;
        if (args != null && args.length > 0) {
            tempFilename = args[0];
        } // Of if
        Booster tempBooster = new Booster(tempFilename);

        tempBooster.setNumBaseClassifiers(100);
        tempBooster.train();

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuray());
        tempBooster.test();
    }// Of main

}// Of class Booster

结果:

相关推荐
点云SLAM3 分钟前
Decisive 英文单词学习
人工智能·学习·英文单词学习·雅思备考·decisive·起决定性的·果断的
码农很忙4 分钟前
让复杂AI应用构建像搭积木:Spring AI Alibaba Graph深度指南与源码拆解
开发语言·人工智能·python
余俊晖15 分钟前
多模态视觉语言模型增强原生分辨率继续预训练方法-COMP架构及训练方法
人工智能·语言模型·自然语言处理
运维@小兵27 分钟前
使用Spring-ai实现同步响应和流式响应
java·人工智能·spring-ai·ai流式响应
玩具猴_wjh28 分钟前
线性规划核心知识点
人工智能·机器学习
科学最TOP33 分钟前
IJCAI25|如何平衡文本与时序信息的融合适配?
人工智能·深度学习·神经网络·机器学习·时间序列
maycho1231 小时前
探索锂电池主动均衡仿真:从开关电容到多种电路的奇妙之旅
人工智能
余俊晖1 小时前
多模态文档智能解析模型进展-英伟达NVIDIA-Nemotron-Parse-v1.1
人工智能·ocr·多模态
南太湖小蚂蚁1 小时前
通过TRAE和LLM实现电影数据查询和分析
人工智能