线性回归学习

一、线性回归简介核心思想:线性回归是一种通过属性的线性组合来做预测的模型。它的目标很明确,就是找到一条合适的直线、平面或者更高维度的超平面,让预测出来的值和实际真实值之间的差距尽可能小。比如在预测房屋价格时,就可以根据房屋大小这个属性,拟合出一条能预测价格的直线。一般形式:对于一个有多个属性描述的样本,线性回归会把这些属性分别乘以对应的权重,再加上一个偏置项,得到预测结果。用向量的形式可以更简洁地表示这种组合关系。二、模型求解:最小二乘法基本原理:最小二乘法是基于 "欧氏距离" 来寻找最优模型的方法。它的核心是找到一条线,让所有样本点到这条线的欧氏距离加起来最小,也就是让预测值和真实值之间的误差总和最小。参数估计:这个过程就是找到最合适的权重和偏置项,使得误差函数的值最小。这里的误差函数反映的是所有样本预测误差的平方和。求解过程:通过对误差函数分别关于权重和偏置项求导,然后让导数等于 0,就能计算出权重和偏置项的最优值。三、线性回归的评估指标误差平方和 / 残差平方和(SSE/RSS):它是把每个样本的预测值和真实值之间的差值平方后加起来的结果,能反映出预测值和真实值之间的总误差大小。平方损失 / 均方误差(MSE):是误差平方和的平均值,它消除了样本数量对误差结果的影响,方便不同数据集之间进行误差比较。R 方(\(R^2\)):这个指标用来衡量模型对数据的拟合效果。它的值越接近 1,说明模型对真实数据的解释能力越强,拟合效果也就越好。简单来说,就是看预测值能解释真实值变化的程度。四、多元线性回归模型形式:当样本有多个属性时,多元线性回归会把每个属性都考虑进来,每个属性都有对应的权重,再加上一个偏置项,共同组成预测公式。矩阵表示:对于多元线性回归,可以用矩阵的形式来表示模型和相关计算,这样在处理大规模数据时会更方便高效。五、实践应用:sklearn 中的线性回归相关函数:在 Python 的 sklearn 库中,linear_model.LinearRegression()函数可以直接实现线性回归算法。主要参数:fit_intercept:用于设置模型是否包含偏置项,如果设置为 False,那么拟合的直线会经过原点,默认是包含偏置项的。normalize:设置是否对数据进行归一化处理,默认是不进行归一化。应用示例:比如用这个函数可以实现对波士顿房价的预测,通过输入房屋的各种属性,得到预测的房价。通过本次学习,我对线性回归的基本概念、求解方法、评估方式和实际应用有了全面的了解。线性回归作为一种简单实用的模型,在很多预测场景中都发挥着重要作用。

代码参考

导入必要的库

import pandas as pd # 用于数据处理和分析的工具库

import numpy as np # 用于数值计算的基础库

import matplotlib.pyplot as plt # 用于数据可视化的库

from sklearn.datasets import load_diabetes # 从sklearn加载糖尿病数据集

from sklearn.model_selection import train_test_split # 用于划分训练集和测试集

from sklearn.linear_model import LinearRegression # 线性回归模型

from sklearn.metrics import mean_squared_error, r2_score # 模型评估指标

1. 加载糖尿病数据集

该数据集包含442名糖尿病患者的10项生理特征和1年后的病情发展指标

diabetes = load_diabetes()

2. 数据处理与格式化

将特征数据转换为DataFrame格式,便于查看和处理

diabetes.data包含特征数据,diabetes.feature_names是特征名称列表

X = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)

将目标值(标签)转换为Series格式,目标值表示患者1年后的病情发展情况

y = pd.Series(diabetes.target, name='DiseaseProgression')

3. 数据探索:查看数据集基本信息

print("数据集基本信息:")

print(f"特征数量:{X.shape[1]}") # X.shape[1]表示列数,即特征数量

print(f"样本数量:{X.shape[0]}") # X.shape[0]表示行数,即样本数量

print("\n特征名称:", diabetes.feature_names) # 打印所有特征的名称

print("\n数据统计描述:")

打印数据的统计信息(均值、标准差、最小值、最大值等)

print(X.describe())

4. 划分训练集和测试集

train_test_split函数用于将数据集随机划分为训练集和测试集

test_size=0.3表示测试集占总数据的30%,训练集占70%

random_state=42设置随机种子,确保每次运行划分结果一致,保证实验可复现

X_train, X_test, y_train, y_test = train_test_split(

X, y, test_size=0.3, random_state=42

)

5. 构建线性回归模型并训练

创建线性回归模型实例,fit_intercept=True表示模型会计算偏置项(截距)

model = LinearRegression(fit_intercept=True)

使用训练集数据训练模型,即通过最小二乘法找到最优的权重和偏置

模型会学习特征(X_train)与目标值(y_train)之间的线性关系

model.fit(X_train, y_train)

6. 输出模型参数(权重和偏置)

print("\n模型参数:")

遍历每个特征及其对应的系数(权重)

系数的大小和符号表示该特征对预测结果的影响程度和方向

for feature, coef in zip(diabetes.feature_names, model.coef_):

print(f"{feature}: {coef:.4f}") # 保留4位小数输出

输出偏置项(intercept),即线性方程中的常数项

print(f"偏置(intercept): {model.intercept_:.4f}")

7. 使用训练好的模型进行预测

用测试集的特征数据进行预测,得到模型的预测结果y_pred

y_pred = model.predict(X_test)

8. 模型评估:计算评估指标

计算均方误差(MSE):预测值与真实值差值的平方的平均值

mse = mean_squared_error(y_test, y_pred)

计算均方根误差(RMSE):MSE的平方根,与目标值单位一致,更易解释

rmse = np.sqrt(mse)

计算R方(R²):表示模型解释的方差比例,范围0-1,越接近1说明拟合效果越好

r2 = r2_score(y_test, y_pred)

打印评估指标结果

print("\n模型评估指标:")

print(f"均方误差(MSE):{mse:.4f}")

print(f"均方根误差(RMSE):{rmse:.4f}")

print(f"R方(R^2):{r2:.4f}")

9. 结果可视化:真实值与预测值对比

创建画布,设置大小为10x6英寸

plt.figure(figsize=(10, 6))

绘制散点图:x轴为真实值,y轴为预测值,alpha设置点的透明度

plt.scatter(y_test, y_pred, alpha=0.5, label='预测值 vs 真实值')

绘制参考线:y=x的虚线,表示理想情况下预测值等于真实值

plt.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', label='理想预测线')

设置坐标轴标签和图表标题

plt.xlabel('真实病情发展指标')

plt.ylabel('预测病情发展指标')

plt.title('糖尿病病情预测:真实值 vs 预测值')

添加图例和网格线

plt.legend()

plt.grid(alpha=0.3) # alpha设置网格线透明度

显示图表

plt.show()

10. 特征重要性可视化

计算各特征系数的绝对值,用于表示特征重要性(绝对值越大影响越大)

coef_abs = np.abs(model.coef_)

创建画布,设置大小为10x6英寸

plt.figure(figsize=(10, 6))

绘制柱状图展示各特征的重要性

plt.bar(diabetes.feature_names, coef_abs)

设置图表标题和坐标轴标签

plt.title('特征重要性(基于系数绝对值)')

plt.xlabel('特征名称')

plt.ylabel('系数绝对值')

添加y轴网格线,便于查看数值

plt.grid(axis='y', alpha=0.3)

显示图表

plt.show()

相关推荐
MowenPan199524 分钟前
高等数学 9.1多元函数的基本概念
笔记·学习·高等数学
没有梦想的咸鱼185-1037-16632 小时前
AI大模型支持下的:CMIP6数据分析与可视化、降尺度技术与气候变化的区域影响、极端气候分析
人工智能·python·深度学习·机器学习·chatgpt·数据挖掘·数据分析
今天也要学习吖4 小时前
Azure TTS Importer:一键导入,将微软TTS语音接入你的阅读软件!
人工智能·学习·microsoft·ai·大模型·aigc·azure
楼田莉子5 小时前
C++算法学习专题:滑动窗口
开发语言·数据结构·c++·学习·算法·leetcode
小晶晶京京5 小时前
day38-HTTP
网络·网络协议·学习·http
炸膛坦客5 小时前
C++ 学习与 CLion 使用:(四)常量和变量,包括字面常量和符号常量
开发语言·c++·学习
zheshiyangyang5 小时前
uni-app学习【pages】
前端·学习·uni-app
livemetee6 小时前
Flink2.0学习笔记:使用HikariCP 自定义sink实现数据库连接池化
大数据·数据库·笔记·学习·flink
艾莉丝努力练剑8 小时前
《递归与迭代:从斐波那契到汉诺塔的算法精髓》
c语言·学习·算法