代码
java
import java.util.Random;
public class ModelTraining {
// 简单的线性模型: y = wx + b
static class LinearModel {
double w; // 权重
double b; // 偏置
// static Random random = new Random(42); // 固定种子
public LinearModel() {
this.w = Math.random(); // 随机初始化
this.b = Math.random();
}
// 前向传播:预测
public double predict(double x) {
return w * x + b;
}
@Override
public String toString() {
return String.format("y = %.4fx + %.4f", w, b);
}
}
/**
* 均方误差损失函数
*/
public static double mseLoss(double[] x, double[] y, LinearModel model) {
double sum = 0;
int n = x.length;
for (int i = 0; i < n; i++) {
double predicted = model.predict(x[i]);
double error = y[i] - predicted;
sum += error * error;
}
return sum / n;
}
/**
* 训练模型(梯度下降)
* 输入:数据、损失函数、学习率、迭代次数
* 输出:训练好的模型参数
*/
public static LinearModel train(
double[] x,
double[] y,
double learningRate,
int epochs) {
LinearModel model = new LinearModel();
int n = x.length;
System.out.println("初始模型: " + model);
System.out.println("初始损失: " + mseLoss(x, y, model));
System.out.println("==========================================");
// 训练循环
for (int epoch = 0; epoch < epochs; epoch++) {
double gradW = 0; // 权重梯度
double gradB = 0; // 偏置梯度
// 计算梯度
for (int i = 0; i < n; i++) {
double predicted = model.predict(x[i]);
double error = y[i] - predicted;
// 损失函数对 w 和 b 的偏导数
gradW += -2 * x[i] * error / n;
gradB += -2 * error / n;
}
// 更新参数(梯度下降)
model.w -= learningRate * gradW;
model.b -= learningRate * gradB;
// 每100轮打印一次
if ((epoch + 1) % 100 == 0) {
double loss = mseLoss(x, y, model);
System.out.printf("Epoch %d: 损失=%.6f, 模型=%s%n",
epoch + 1, loss, model);
}
}
System.out.println("==========================================");
System.out.println("训练完成!");
System.out.println("最终模型: " + model);
System.out.println("最终损失: " + mseLoss(x, y, model));
// 返回训练好的模型(最优参数)
return model;
}
// 使用训练好的模型进行预测
public static void predict(LinearModel model, double[] testX) {
System.out.println("\n使用训练好的模型进行预测:");
for (double x : testX) {
double y = model.predict(x);
System.out.printf("输入 x=%.2f, 预测 y=%.4f%n", x, y);
}
}
public static void main(String[] args) {
// 训练数据:y = 2x + 1
double[] x = {1, 2, 3, 4, 5};
double[] y = {3, 5, 7, 9, 11};
//如果只有一组训练数据,可观察其训练效果,远远达不到目的,因为通过一组数据不能发现一组数据的规律。会导致过拟合,只在测试数据上使损失值降低最低,用其他数据测就效果不好。
//可以这样认为 y = kx + b ,如果只有一组x,y的已知数据,那么k和b的组合有好多种,随着x,y的数据越多,越能使用k,b都能兼顾到。
// 因为
// double[] x = {1};
// double[] y = {3};
// 训练模型
// 输入:损失函数(MSE)+ 数据 + 超参数
// 输出:最优参数 (w≈2, b≈1)
LinearModel trainedModel = train(x, y, 0.01, 2000);
// 使用训练结果进行预测
double[] testX = {6, 7, 8};
predict(trainedModel, testX);
}
}

另外:训练的数据集不能太少,例如:
python
// double[] x = {1};
// double[] y = {3};