线性回归的总结:

代码

java 复制代码
import java.util.Random;

public class ModelTraining {

    // 简单的线性模型: y = wx + b
    static class LinearModel {
        double w;  // 权重
        double b;  // 偏置
//        static Random random = new Random(42);  // 固定种子
        public LinearModel() {
            this.w = Math.random();  // 随机初始化
            this.b = Math.random();
        }

        // 前向传播:预测
        public double predict(double x) {
            return w * x + b;
        }

        @Override
        public String toString() {
            return String.format("y = %.4fx + %.4f", w, b);
        }
    }

    /**
     * 均方误差损失函数
     */
    public static double mseLoss(double[] x, double[] y, LinearModel model) {
        double sum = 0;
        int n = x.length;

        for (int i = 0; i < n; i++) {
            double predicted = model.predict(x[i]);
            double error = y[i] - predicted;
            sum += error * error;
        }

        return sum / n;
    }

    /**
     * 训练模型(梯度下降)
     * 输入:数据、损失函数、学习率、迭代次数
     * 输出:训练好的模型参数
     */
    public static LinearModel train(
            double[] x,
            double[] y,
            double learningRate,
            int epochs) {

        LinearModel model = new LinearModel();
        int n = x.length;

        System.out.println("初始模型: " + model);
        System.out.println("初始损失: " + mseLoss(x, y, model));
        System.out.println("==========================================");

        // 训练循环
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradW = 0;  // 权重梯度
            double gradB = 0;  // 偏置梯度

            // 计算梯度
            for (int i = 0; i < n; i++) {
                double predicted = model.predict(x[i]);
                double error = y[i] - predicted;

                // 损失函数对 w 和 b 的偏导数
                gradW += -2 * x[i] * error / n;
                gradB += -2 * error / n;
            }

            // 更新参数(梯度下降)
            model.w -= learningRate * gradW;
            model.b -= learningRate * gradB;

            // 每100轮打印一次
            if ((epoch + 1) % 100 == 0) {
                double loss = mseLoss(x, y, model);
                System.out.printf("Epoch %d: 损失=%.6f, 模型=%s%n",
                        epoch + 1, loss, model);
            }
        }

        System.out.println("==========================================");
        System.out.println("训练完成!");
        System.out.println("最终模型: " + model);
        System.out.println("最终损失: " + mseLoss(x, y, model));

        // 返回训练好的模型(最优参数)
        return model;
    }

    // 使用训练好的模型进行预测
    public static void predict(LinearModel model, double[] testX) {
        System.out.println("\n使用训练好的模型进行预测:");
        for (double x : testX) {
            double y = model.predict(x);
            System.out.printf("输入 x=%.2f, 预测 y=%.4f%n", x, y);
        }
    }

    public static void main(String[] args) {
        // 训练数据:y = 2x + 1
        double[] x = {1, 2, 3, 4, 5};
        double[] y = {3, 5, 7, 9, 11};
        //如果只有一组训练数据,可观察其训练效果,远远达不到目的,因为通过一组数据不能发现一组数据的规律。会导致过拟合,只在测试数据上使损失值降低最低,用其他数据测就效果不好。
        //可以这样认为  y = kx + b ,如果只有一组x,y的已知数据,那么k和b的组合有好多种,随着x,y的数据越多,越能使用k,b都能兼顾到。
        // 因为
//        double[] x = {1};
//        double[] y = {3};

        // 训练模型
        // 输入:损失函数(MSE)+ 数据 + 超参数
        // 输出:最优参数 (w≈2, b≈1)
        LinearModel trainedModel = train(x, y, 0.01, 2000);

        // 使用训练结果进行预测
        double[] testX = {6, 7, 8};
        predict(trainedModel, testX);
    }
}

另外:训练的数据集不能太少,例如:

python 复制代码
//        double[] x = {1};
//        double[] y = {3};
相关推荐
郝亚军1 小时前
IEEE 754 单精度浮点的SEM表示
开发语言·c++·算法
青山师1 小时前
动态规划算法深度解析:从状态转移方程到工业级优化
数据结构·算法·面试·动态规划·代理模式·java面试
黎阳之光2 小时前
数智透明·安全兜底|黎阳之光透明矿山,AI+数字孪生守护矿山生命线
人工智能·物联网·算法·安全·数字孪生
吴可可1232 小时前
控制弦高精度的样条离散化方法
算法
人工智能培训2 小时前
设备故障?数字孪生提前预警
人工智能·深度学习·神经网络·机器学习·生成对抗网络
风落无尘2 小时前
第十一章《对齐与安全》 完整学习资料
python·安全·机器学习
wuweijianlove2 小时前
算法设计中的空间复用与数据对齐优化的技术5
算法
yuan199973 小时前
基于 MATLAB PSO 工具箱的函数寻优算法
开发语言·算法·matlab
YUANQIANG20243 小时前
博弈论中势函数与势博弈构造:为什么看似 “先射箭后画靶”
算法·信息与通信