【监督学习】线性回归算法步骤及matlab实现

线性回归算法

线性回归算法

线性回归是一种统计方法,用于建模因变量(通常表示为 y y y)和一个或多个自变量(通常表示为 X 1 , X 2 , . . . , X n X_1,X_2,...,X_n X1,X2,...,Xn)之间的关系。它假设这些变量之间的关系是线性的,即可以通过一条直线(在一维情况下)或多维超平面(在多维情况下)来近似描述这种关系。线性回归的目标是找到最佳拟合直线或超平面,使得预测值与实际观测值之间的差异最小化。

1.算法步骤

数据准备 模型假设 参数估计 模型评估 数据清洗 数据收集 特征标准化 因变量与自变量线性关系 误差项独立同分布 最小二乘法优化 求解损失函数 计算R^2,RMSE等指标 残差分析
数据准备 模型假设 参数估计 模型评估 预测应用

  1. 数据准备

    • 数据收集:获取包含自变量(特征)和因变量(目标)的数据集(如房价与面积、楼层等);
    • 数据清洗:处理缺失值、异常值(如删除或填充);
    • 特征标准化:对特征进行归一化(如Z-score标准化),避免量纲差异影响模型收敛。
  2. 模型假设

    • 线性关系 :因变量 y y y 与自变量 X X X 之间满足线性关系: y = β 0 + β 1 X 1 + ⋯ + β n X n + ϵ y=β_0+β_1X_1+⋯+β_nX_n+ϵ y=β0+β1X1+⋯+βnXn+ϵ
    • 误差项 :误差 ϵ ϵ ϵ 服从均值为 0、方差为常数的正态分布,且相互独立。
  3. 参数估计

    • 目标 :找到最优参数 β β β 使得预测值与真实值的误差平方和最小;
    • 损失函数 : J ( β ) = 1 2 m ∑ i = 1 m ( y i − y ^ i ) 2 J(β)=\frac{1}{2m}\sum_{i=1}^{m}(y_i-\hat{y}_i)^2 J(β)=2m1i=1∑m(yi−y^i)2
    • 求解方法
      • 解析解 (即直接求解,适用于小数据): β = ( X T X ) − 1 X T y β=(\mathbf{X}^T\mathbf{X})^{−1}\mathbf{X}^Ty β=(XTX)−1XTy
      • 梯度下降(迭代优化):逐步调整参数,适用于大数据。
  4. 模型评估

    • R²(决定系数) :衡量模型解释的方差比例,越接近 1 越好: R 2 = 1 − S S E S S T R^2=1−\frac{SSE}{SST} R2=1−SSTSSE
    • MSE(均方误差) :预测值与真实值的平均平方误差: M S E = 1 m ∑ i = 1 m ( y i − y ^ i ) 2 MSE=\frac{1}{m}\sum_{i=1}^{m}(y_i-\hat{y}_i)^2 MSE=m1i=1∑m(yi−y^i)2
    • 残差分析:检查残差是否随机分布(理想情况为无趋势、无异方差)。
  5. 预测应用

  • 使用训练好的模型对新数据进行预测: y ^ = X n e w β \hat{y}=\mathbf{X}_{new}β y^=Xnewβ

2.MATLAB 实现

以房价预测为例,房价真实参数为: p r i c e = 1000 + 5 ∗ a r e a + 20 ∗ f l o o r + n o i s e price = 1000 + 5*area + 20*floor + noise price=1000+5∗area+20∗floor+noise
a r e a area area 为面积, f l o o r floor floor 为楼层, n o i s e noise noise 为其他影响, p r i c e price price 为最后每平方米的房价。

matlab 复制代码
%% 线性回归案例:房价预测
% 功能: 训练线性回归模型预测房价
clc; clear; close all;

%% 1. 生成模拟数据
rng(0); % 固定随机种子
m = 100; % 样本数量
area = 50 + 150*rand(m,1); % 面积 (50~200平方米)
floor = randi([1,20], m,1); % 楼层 (1~20层)
noise = 50*randn(m,1); % 噪声

% 真实参数: price = 1000 + 5*area + 20*floor + noise
price = 1000 + 5*area + 20*floor + noise;

% 合并特征矩阵X,添加截距项1
X = [ones(m,1), area, floor]; 

%% 2. 数据标准化(可选,但推荐)
% 标准化特征(截距项不标准化)
X(:,2:end) = zscore(X(:,2:end)); 

%% 3. 划分训练集和测试集(70%训练,30%测试)
split = 0.7;
trainSize = round(split * m);
X_train = X(1:trainSize, :);
y_train = price(1:trainSize);
X_test = X(trainSize+1:end, :);
y_test = price(trainSize+1:end);

%% 4. 参数估计:解析解(最小二乘法)
beta = (X_train' * X_train) \ (X_train' * y_train); % 直接求解

%% 5. 模型评估
% 预测训练集和测试集
y_train_pred = X_train * beta;
y_test_pred = X_test * beta;

% 计算R²和MSE
SSE_train = sum((y_train - y_train_pred).^2);
SST_train = sum((y_train - mean(y_train)).^2);
R2_train = 1 - SSE_train / SST_train;
MSE_train = mean((y_train - y_train_pred).^2);

SSE_test = sum((y_test - y_test_pred).^2);
SST_test = sum((y_test - mean(y_test)).^2);
R2_test = 1 - SSE_test / SST_test;
MSE_test = mean((y_test - y_test_pred).^2);

% 显示结果
fprintf('训练集: R²=%.3f, MSE=%.2f\n', R2_train, MSE_train);
fprintf('测试集: R²=%.3f, MSE=%.2f\n', R2_test, MSE_test);

%% 6. 可视化结果
figure;

% 残差图
subplot(1,2,1);
scatter(y_train_pred, y_train_pred - y_train, 'b');
hold on;
scatter(y_test_pred, y_test_pred - y_test, 'r');
xlabel('预测值');
ylabel('残差');
title('残差图(蓝色:训练集,红色:测试集)');
grid on;

% 预测值与真实值对比
subplot(1,2,2);
plot(y_test, y_test_pred, 'ro');
hold on;
plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], 'k--');
xlabel('真实价格');
ylabel('预测价格');
title('测试集预测效果');
grid on;

%% 7. 输出参数
disp('模型参数:');
disp(['截距项: ', num2str(beta(1))]);
disp(['面积系数: ', num2str(beta(2))]);
disp(['楼层系数: ', num2str(beta(3))]);
相关推荐
机器学习之心12 小时前
多目标鲸鱼优化算法(NSWOA),含46种测试函数和9个评价指标,MATLAB实现
算法·matlab·多目标鲸鱼优化算法·46种测试函数·9个评价指标
max50060013 小时前
基于Meta Llama的二语习得学习者行为预测计算模型
人工智能·算法·机器学习·分类·数据挖掘·llama
王哥儿聊AI14 小时前
Lynx:新一代个性化视频生成模型,单图即可生成视频,重新定义身份一致性与视觉质量
人工智能·算法·安全·机器学习·音视频·软件工程
手握风云-15 小时前
优选算法的寻踪契合:字符串专题
算法
闭着眼睛学算法15 小时前
【华为OD机考正在更新】2025年双机位A卷真题【完全原创题解 | 详细考点分类 | 不断更新题目 | 六种主流语言Py+Java+Cpp+C+Js+Go】
java·c语言·javascript·c++·python·算法·华为od
IT古董16 小时前
【第五章:计算机视觉-项目实战之目标检测实战】2.目标检测实战:中国交通标志检测-(2)中国交通标志检测数据格式转化与读取
算法·目标检测·计算机视觉
MobotStone16 小时前
LLM 采样入门到进阶:理解与实践 Top-K、Top-P、温度控制
算法
Coovally AI模型快速验证16 小时前
从避障到实时建图:机器学习如何让无人机更智能、更安全、更实用(附微型机载演示示例)
人工智能·深度学习·神经网络·学习·安全·机器学习·无人机
杨小码不BUG16 小时前
CSP-J/S初赛知识点精讲-图论
c++·算法·图论··编码·csp-j/s初赛
东木君_17 小时前
RK3588:MIPI底层驱动学习——入门第三篇(IIC与V4L2如何共存?)
学习