线性回归的总结:

代码

java 复制代码
import java.util.Random;

public class ModelTraining {

    // 简单的线性模型: y = wx + b
    static class LinearModel {
        double w;  // 权重
        double b;  // 偏置
//        static Random random = new Random(42);  // 固定种子
        public LinearModel() {
            this.w = Math.random();  // 随机初始化
            this.b = Math.random();
        }

        // 前向传播:预测
        public double predict(double x) {
            return w * x + b;
        }

        @Override
        public String toString() {
            return String.format("y = %.4fx + %.4f", w, b);
        }
    }

    /**
     * 均方误差损失函数
     */
    public static double mseLoss(double[] x, double[] y, LinearModel model) {
        double sum = 0;
        int n = x.length;

        for (int i = 0; i < n; i++) {
            double predicted = model.predict(x[i]);
            double error = y[i] - predicted;
            sum += error * error;
        }

        return sum / n;
    }

    /**
     * 训练模型(梯度下降)
     * 输入:数据、损失函数、学习率、迭代次数
     * 输出:训练好的模型参数
     */
    public static LinearModel train(
            double[] x,
            double[] y,
            double learningRate,
            int epochs) {

        LinearModel model = new LinearModel();
        int n = x.length;

        System.out.println("初始模型: " + model);
        System.out.println("初始损失: " + mseLoss(x, y, model));
        System.out.println("==========================================");

        // 训练循环
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradW = 0;  // 权重梯度
            double gradB = 0;  // 偏置梯度

            // 计算梯度
            for (int i = 0; i < n; i++) {
                double predicted = model.predict(x[i]);
                double error = y[i] - predicted;

                // 损失函数对 w 和 b 的偏导数
                gradW += -2 * x[i] * error / n;
                gradB += -2 * error / n;
            }

            // 更新参数(梯度下降)
            model.w -= learningRate * gradW;
            model.b -= learningRate * gradB;

            // 每100轮打印一次
            if ((epoch + 1) % 100 == 0) {
                double loss = mseLoss(x, y, model);
                System.out.printf("Epoch %d: 损失=%.6f, 模型=%s%n",
                        epoch + 1, loss, model);
            }
        }

        System.out.println("==========================================");
        System.out.println("训练完成!");
        System.out.println("最终模型: " + model);
        System.out.println("最终损失: " + mseLoss(x, y, model));

        // 返回训练好的模型(最优参数)
        return model;
    }

    // 使用训练好的模型进行预测
    public static void predict(LinearModel model, double[] testX) {
        System.out.println("\n使用训练好的模型进行预测:");
        for (double x : testX) {
            double y = model.predict(x);
            System.out.printf("输入 x=%.2f, 预测 y=%.4f%n", x, y);
        }
    }

    public static void main(String[] args) {
        // 训练数据:y = 2x + 1
        double[] x = {1, 2, 3, 4, 5};
        double[] y = {3, 5, 7, 9, 11};
        //如果只有一组训练数据,可观察其训练效果,远远达不到目的,因为通过一组数据不能发现一组数据的规律。会导致过拟合,只在测试数据上使损失值降低最低,用其他数据测就效果不好。
        //可以这样认为  y = kx + b ,如果只有一组x,y的已知数据,那么k和b的组合有好多种,随着x,y的数据越多,越能使用k,b都能兼顾到。
        // 因为
//        double[] x = {1};
//        double[] y = {3};

        // 训练模型
        // 输入:损失函数(MSE)+ 数据 + 超参数
        // 输出:最优参数 (w≈2, b≈1)
        LinearModel trainedModel = train(x, y, 0.01, 2000);

        // 使用训练结果进行预测
        double[] testX = {6, 7, 8};
        predict(trainedModel, testX);
    }
}

另外:训练的数据集不能太少,例如:

python 复制代码
//        double[] x = {1};
//        double[] y = {3};
相关推荐
JieE2128 小时前
LeetCode 56. 合并区间|超清晰 JS 图解思路,面试高频区间题
javascript·算法·面试
Jack2015 小时前
HarmonyOS开发中错误处理策略:网络异常统一处理
算法
哥布林学者16 小时前
深度学习进阶(三十一)FlashAttention:IO 感知的精确注意力
机器学习·ai
小小杨树17 小时前
读懂色彩:拍照调色不再难
算法·计算机视觉·配色
JieE2121 天前
LeetCode 226. 翻转二叉树|JS 递归超详细拆解,二叉树入门经典题
javascript·算法
JieE2121 天前
LeetCode 104. 二叉树的最大深度|递归思路超详细拆解
javascript·算法
vivo互联网技术2 天前
CVPR 2026 | 全新强化学习框架 BeautyGRPO:重塑真实人像
算法·大模型·cvpr·影像
Darling噜啦啦2 天前
列表转树算法深度解析:从 Map 到 Reduce 的两种实现,面试高频考点
数据结构·算法·面试
用户497863050732 天前
(一)小红的数组操作
算法·编程语言