java中使用BP网络进行回归

在神经网络中,激活函数扮演着至关重要的角色,它们为网络引入了非线性因素,使得网络能够学习和模拟复杂的非线性关系。下面列出的各种激活函数及其特点如下:

  1. LINEAR ("Linear") :
    • 线性激活函数实际上就是恒等函数,即 f(x)=x。
    • 它不会给网络带来任何非线性特性,因此在实际应用中很少使用(除了在某些特殊情况下,如回归问题或某些层作为输出层时)。
  2. RAMP ("Ramp") :
    • Ramp函数是一种分段线性函数,通常定义为当 x<0 时 f(x)=0,当 x≥0 时 f(x)=x。
    • 它与ReLU(Rectified Linear Unit)类似,但Ramp函数在 x<0 时保持输出为0,而不是像ReLU那样保持不变。Ramp函数有时被认为是一个更平滑的ReLU版本。
  3. STEP ("Step") :
    • Step函数是一种简单的二值激活函数,通常定义为当 x<0 时 f(x)=0,当 x≥0 时 f(x)=1。
    • 它将任意输入值映射到两个输出值之一,这种非连续性使得它在大多数现代神经网络中不太实用。
  4. SIGMOID ("Sigmoid") :
    • Sigmoid函数是一个S形曲线,定义为 f(x)=1+e−x1。
    • 它将任意实值压缩到区间(0, 1)内,非常适合用于二分类问题的输出层。然而,由于梯度消失问题和计算量较大,它在深度神经网络中的使用已经减少。
  5. TANH ("Tanh") :
    • Tanh函数是双曲正切函数,定义为 f(x)=ex+e−xex−e−x。
    • 它将输入值压缩到区间(-1, 1)内,并且是零中心的(即,输出的均值为0)。这有助于加快收敛速度,因此它比Sigmoid函数更受欢迎。然而,梯度消失问题仍然存在。
  6. GAUSSIAN ("Gaussian") :
    • 高斯(或称为高斯)激活函数使用高斯(或正态)分布函数,形式可能因具体实现而异,但通常与标准正态分布相关。
    • 它将输入值映射到一个钟形曲线,但这种激活函数在神经网络中不太常见,因为其他激活函数通常能提供更好的性能。
  7. TRAPEZOID ("Trapezoid") :
    • 梯形激活函数是一种非线性函数,其形状类似于梯形。
    • 它结合了线性和非线性的部分,但在实际应用中不太常见,可能是因为它缺乏足够的理论依据或优势。
  8. SGN ("Sgn") :
    • Sgn(符号)函数根据输入值的正负输出+1、-1或0(取决于具体定义)。
    • 它是一种简单的二值激活函数,但在神经网络中并不常用,因为它的梯度在大部分输入上都是0(除了0点附近的微小区域),这会导致梯度消失问题。
  9. SIN ("Sin") :
    • Sin函数是正弦函数,将输入值映射到[-1, 1]区间的正弦波上。
    • 它是一个周期性的非线性函数,但在神经网络中不常用,因为神经网络通常希望激活函数具有单调递增或递减的特性,以便进行有效的梯度下降。
  10. LOG ("Log") :
    • Log函数(可能指的是自然对数或其他对数)将输入值映射到对数尺度上。
    • 它在某些特定领域(如机器学习中的某些损失函数)中很有用,但在作为神经网络的激活函数时并不常见。
  11. RECTIFIED ("RectifiedLinear", ReLU) :
    • ReLU(Rectified Linear Unit)函数是当前深度学习中使用最广泛的激活函数之一。
    • 它定义为 f(x)=max(0,x),即将所有负值置为0,保持正值不变。
    • ReLU函数简单、计算效率高,且有助于缓解梯度消失问题(在正值区域内)。然而,它也存在死亡ReLU问题(即某些神经元在训练过程中永远不会被激活)。

为了分析1个y和3个自变量(x,z, k) (0为训练集,1为测试集) 的回归问题使用java的BP网络激活函数使用RECTIFIED如下使用依赖如下:

XML 复制代码
  <properties>
<neuroph-core.version>2.96</neuroph-core.version>
<visrec-api.version>1.0.3</visrec-api.version>     
  </properties>

   <dependency>
            <groupId>org.neuroph</groupId>
            <artifactId>neuroph-core</artifactId>
            <version>${neuroph-core.version}</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/javax.visrec/visrec-api -->
        <dependency>
            <groupId>javax.visrec</groupId>
            <artifactId>visrec-api</artifactId>
            <version>${visrec-api.version}</version>
        </dependency>

进行回归的方法如下:

复制代码
private static void poly3UseBpn(Double[] x0, Double[] z0, Double[] k0, Double[] y, Double[] x1, Double[] z1, Double[] k1, Double[] y0, Double[] y1) {
    // 归一化数据
    double minX = Arrays.stream(x0).min(Double::compareTo).get();
    double maxX = Arrays.stream(x0).max(Double::compareTo).get();
    double rangeX = maxX - minX;
    double minZ = Arrays.stream(z0).min(Double::compareTo).get();
    double maxZ = Arrays.stream(z0).max(Double::compareTo).get();
    double rangeZ = maxZ - minZ;
    double minK = Arrays.stream(k0).min(Double::compareTo).get();
    double maxK = Arrays.stream(k0).max(Double::compareTo).get();
    double rangeK = maxK - minK;
    // 创建数据集 (3 个输入,1 个输出)
    DataSet trainingSet = new DataSet(3, 1);
    for (int i = 0; i < x0.length; i++) {
        double normalizedInputX = (x0[i] - minX) / rangeX;
        double normalizedInputZ = (z0[i] - minZ) / rangeZ;
        double normalizedInputK = (k0[i] - minK) / rangeZ;
        trainingSet.add(new DataSetRow(new double[]{normalizedInputX, normalizedInputZ, normalizedInputK}, new double[]{y[i]}));
    }
    // 创建一个具有 3 个输入层神经元,10 个隐藏层神经元,1 个输出层神经元的神经网络
    MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(TransferFunctionType.RECTIFIED, 3, 10, 1);
    //使用固定的权重
    neuralNet.randomizeWeights(new Random(1));
    // 设置学习率和迭代次数
    double currentLearningRate = 0.01;
    int iterations = 10000;
    double errorThreshold = 0.001;
    double minLearningRate = 1e-5;
    double decay = 0.9;
    // 训练网络
    BackPropagation backPropagation = new BackPropagation();
    backPropagation.setLearningRate(currentLearningRate);
    neuralNet.setLearningRule(backPropagation);
    for (int i = 0; i < iterations; i++) {
        neuralNet.learn(trainingSet);
        // 获取当前误差
        double error = backPropagation.getTotalNetworkError();
        if (error < errorThreshold) {
            System.out.println("error:" + error + "---->" + (error < errorThreshold) + "====>" + i);
            break;
        }
        // 更新学习率(指数衰减)
        currentLearningRate = Math.max(minLearningRate, currentLearningRate * decay);
        backPropagation.setLearningRate(currentLearningRate);
    }

    // 在 x0 和 z0 上评估拟合值,并将结果存储在 y0 中
    for (int i = 0; i < x0.length; i++) {
        double normalizedInputX = (x0[i] - minX) / rangeX;
        double normalizedInputZ = (z0[i] - minZ) / rangeZ;
        double normalizedInputK = (k0[i] - minK) / rangeZ;
        neuralNet.setInput(normalizedInputX, normalizedInputZ, normalizedInputK);
        neuralNet.calculate();
        y0[i] = neuralNet.getOutput()[0];
    }
    // 在 x1 和 z1 上评估拟合值,并将结果存储在 y1 中
    for (int i = 0; i < x1.length; i++) {
        double normalizedInputX = (x1[i] - minX) / rangeX;
        double normalizedInputZ = (z1[i] - minZ) / rangeZ;
        double normalizedInputK = (k1[i] - minK) / rangeZ;
        neuralNet.setInput(normalizedInputX, normalizedInputZ, normalizedInputK);
        neuralNet.calculate();
        y1[i] = neuralNet.getOutput()[0];
    }

由于这边数据量比较小大概在70条左右,由于neuroph是单线程的,在执行到neuralNet.learn(trainingSet);这里时会卡住,调整学习率和减少隐藏层单元后仍旧无法训练故只能换个方法如下使用deeplearning4j

XML 复制代码
  <dl4j.version>1.0.0-M1</dl4j.version>
        <nd4j.version>1.0.0-M1</nd4j.version>
  
<dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>${nd4j.version}</version>
        </dependency>

修改后的训练方法如下:

java 复制代码
public void poly3UseBpn(double[] x0, double[] z0, double[] k0, double[] y,
                            double[] x1, double[] z1, double[] k1, double[] y0, double[] y1) {
        // 创建输入数据集 (3个输入,1个输出)
        List<org.nd4j.linalg.dataset.DataSet> dataSetList = new ArrayList<>();
        for (int i = 0; i < x0.length; i++) {
            INDArray input = Nd4j.create(new double[][]{{x0[i], z0[i], k0[i]}});
            INDArray output = Nd4j.create(new double[][]{{y[i]}});
            org.nd4j.linalg.dataset.DataSet DataSet = new org.nd4j.linalg.dataset.DataSet(input, output);
            dataSetList.add(DataSet);
        }
        ListDataSetIterator<org.nd4j.linalg.dataset.DataSet> trainingData = new ListDataSetIterator<>(dataSetList);
        // 构建网络配置
        MultiLayerNetwork model = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .seed(42)//随机种子
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) //随机梯度下降
                .updater(new org.nd4j.linalg.learning.config.Adam(0.01)) //初始化学习率
                .weightInit(WeightInit.RELU)//权重初始化
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(3)
                        .nOut(10)
                        .activation(Activation.RELU) //激活函数
                        .build()) // 输入层 3个输入
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) //MSE损失函数 MSE表示均方误差
                        .nIn(10)
                        .nOut(1)
                        .activation(Activation.IDENTITY)
                        .build()) // 输出层 1个输出
                .build());
        //网络初始化
        model.init();
        //用于输出误差信息 100次输出1个
        model.setListeners(new ScoreIterationListener(100));
        // 设置训练参数
        double errorThreshold = 0.001;
        int maxIterations = 1000;
        int patience = 10;
        double bestError = Double.MAX_VALUE;
        int noImprovementCount = 0;
        for (int i = 0; i < maxIterations; i++) {
            model.fit(trainingData);
            double currentError = model.score();
            System.out.println("error:" + currentError + "====>" + i);

            if (currentError < bestError) {
                bestError = currentError;
                noImprovementCount = 0;
            } else {
                noImprovementCount++;
            }

            if (currentError < errorThreshold || noImprovementCount >= patience) {
                break;
            }
        }

        // 评估在 x0 和 z0 上的拟合值,并将结果存储在 y0 中
        for (int i = 0; i < x0.length; i++) {
            INDArray input = Nd4j.create(new double[][]{{x0[i], z0[i], k0[i]}});
            INDArray output = model.output(input);
            y0[i] = output.getDouble(0);
        }

        // 评估在 x1 和 z1 上的拟合值,并将结果存储在 y1 中
        for (int i = 0; i < x1.length; i++) {
            INDArray input = Nd4j.create(new double[][]{{x1[i], z1[i], k1[i]}});
            INDArray output = model.output(input);
            y1[i] = output.getDouble(0);
        }
    }

修改完后重新运行,已经可以顺利训练出数据了,自此记录一下!

相关推荐
xlsw_2 小时前
java全栈day20--Web后端实战(Mybatis基础2)
java·开发语言·mybatis
神仙别闹3 小时前
基于java的改良版超级玛丽小游戏
java
黄油饼卷咖喱鸡就味增汤拌孜然羊肉炒饭3 小时前
SpringBoot如何实现缓存预热?
java·spring boot·spring·缓存·程序员
暮湫3 小时前
泛型(2)
java
超爱吃士力架3 小时前
邀请逻辑
java·linux·后端
南宫生3 小时前
力扣-图论-17【算法学习day.67】
java·学习·算法·leetcode·图论
转码的小石4 小时前
12/21java基础
java
李小白664 小时前
Spring MVC(上)
java·spring·mvc
GoodStudyAndDayDayUp4 小时前
IDEA能够从mapper跳转到xml的插件
xml·java·intellij-idea
装不满的克莱因瓶4 小时前
【Redis经典面试题六】Redis的持久化机制是怎样的?
java·数据库·redis·持久化·aof·rdb