机器学习-线性回归

简单线程回归

基本概念

简单线性回归是一种通过单一自变量(x)来预测因变量(y)的统计方法。它假设两个变量之间存在线性关系,并试图找到最佳拟合直线来描述这种关系。

核心目标

寻找最佳拟合线,使得预测误差最小化。具体来说,是最小化观测值(Y)与预测值(Yp)之间的残差平方和:

min⁡∑(yi−yp)2min∑(y iy p)2

实现步骤

1. 数据预处理

  • 导入必要的Python库(如numpy, pandas)
  • 加载数据集
  • 处理缺失值
  • 将数据集划分为训练集和测试集
  • 必要时进行特征缩放(在简单线性回归中通常不需要)

2. 模型训练

  • 使用sklearn.linear_model中的LinearRegression类
  • 创建回归器对象
  • 调用fit()方法在训练数据上训练模型

3. 结果预测

  • 使用训练好的模型对测试集进行预测
  • 将预测结果存储在变量中(如Y_pred)

4. 结果可视化

  • 使用matplotlib.pyplot绘制结果
  • 可视化训练集和测试集的预测效果
  • 直观评估模型性能

应用特点

  • 适用于探索两个连续变量间的线性关系
  • 计算简单,解释性强
  • 是许多复杂模型的基础

📎studentscores.csvhttps://www.yuque.com/attachments/yuque/0/2025/csv/35582240/1753025467553-3d007ead-1c53-4d81-b007-787efcf45091.csv

复制代码
# 第一步:数据预处理
# Pandas 是 Python 中最强大的数据分析库,主要用于数据操作和分析。
import pandas as pd
# NumPy 是 Python 科学计算的基础包,提供高性能的多维数组对象和相关工具。
import numpy as np
# Matplotlib 是 Python 主要的 2D 绘图库,用于数据可视化。
import matplotlib.pyplot as plt

# 读取CSV文件中的数据,文件路径为相对路径
dataset = pd.read_csv('../doc/studentscores.csv')
# 提取数据集中的特征变量(这里假设第一列是特征变量)
X = dataset.iloc[ : ,   : 1 ].values
# 提取数据集中的目标变量(这里假设第二列是目标变量)
Y = dataset.iloc[ : , 1 ].values
# 从sklearn库中导入用于划分训练集和测试集的函数
from sklearn.model_selection import train_test_split
# 将数据集划分为训练集和测试集,测试集大小为数据集的1/4,同时设置随机种子以保证结果的可重复性
X_train, X_test, Y_train, Y_test = train_test_split( X, Y, test_size = 1/4, random_state = 0)

# 第二步:训练集使用简单线性回归模型来训练
# 导入线性回归模型类
from sklearn.linear_model import LinearRegression
# 实例化线性回归模型
regressor = LinearRegression()
# 使用训练数据拟合线性回归模型
regressor = regressor.fit(X_train, Y_train)

# 第三步:预测结果
# 使用训练好的回归器对测试集进行预测
Y_pred = regressor.predict(X_test)

# 第四步:可视化
# 绘制训练集数据点,颜色为红色
plt.scatter(X_train , Y_train, color = 'red')
# 绘制回归线,颜色为蓝色
plt.plot(X_train , regressor.predict(X_train), color ='blue')
# 显示图形
plt.show()

# 测试集结果可视化
plt.scatter(X_test , Y_test, color = 'red')
plt.plot(X_test , regressor.predict(X_test), color ='blue')
plt.show()

多元线性回归

多元线性回归尝试通过用一个线性方程来适配观测数据,这个线性方程是在两个以上(包括两个)的特征和响应之间构建的一个关系。多元线性回归的实现步骤和简单线性回归很相似,在评价部分有所不同。你可以用它来找出在预测结果上哪个因素影响力最大,以及不同变量是如何相互关联的。

📎50_Startups.csvhttps://www.yuque.com/attachments/yuque/0/2025/csv/35582240/1753025498788-f4d80292-cf6e-4c80-825e-47b326698c8b.csv

第1步: 数据预处理

导入库

复制代码
import pandas as pd
import numpy as np

导入数据集

复制代码
# 从CSV文件中读取数据集
# pandas库用于数据处理和分析
# 这里的文件路径是相对路径,指向'doc'目录下的'50_Startups.csv'文件
dataset = pd.read_csv('../doc/50_Startups.csv')

# 提取数据集中除最后一列的所有列作为特征数据
# iloc方法用于通过位置索引获取数据
# :-1表示获取从第一列到倒数第二列的所有列
# .values将提取到的数据转换为NumPy数组,以便后续处理
X = dataset.iloc[ : , :-1].values

# 提取数据集中的最后一列作为目标变量
# 这里使用4作为索引是因为DataFrame的索引是从0开始的
# 如果数据集没有表头或者表头也被认为是数据的一部分,那么最后一列的索引就是4
# 同样使用.values方法将数据转换为NumPy数组
Y = dataset.iloc[ : ,  4 ].values

将类别数据数字化

复制代码
# 导入LabelEncoder和OneHotEncoder用于数据预处理
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
# 初始化LabelEncoder
labelencoder = LabelEncoder()

# 使用LabelEncoder将类别标签转换为数字
X[:, 3] = labelencoder.fit_transform(X[:, 3])

# 使用ColumnTransformer + OneHotEncoder
# 只对第3列(索引为3)进行独热编码,其他列保持不变
ct = ColumnTransformer(
    transformers=[('onehot', OneHotEncoder(), [3])],
    remainder='passthrough'  # 保留其他列
)

# 应用转换并转换为数组
X = ct.fit_transform(X)  # 已经返回NumPy数组

躲避虚拟变量陷阱

复制代码
# 从矩阵X中移除第一列,只保留从第二列到最后一列的数据
# 作用:去除独热编码后的一个特征列(通常是第一列)。
# 原因:在回归问题中,为了避免多重共线性(multicollinearity),通常会去掉一个类别特征作为"基准"。
X = X[: , 1:]

拆分数据集为训练集和测试集

复制代码
# 从sklearn的model_selection模块中导入train_test_split函数
from sklearn.model_selection import train_test_split

# 使用train_test_split函数将数据集X和Y随机划分为训练集和测试集
# 参数test_size=0.2表示将20%的数据用作测试集,剩余80%的数据用作训练集
# 参数random_state确保每次划分时都能得到相同的数据集,这里设置为0
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.2, random_state = 0)

第2步: 在训练集上训练多元线性回归模型

复制代码
# 导入线性回归模型类
from sklearn.linear_model import LinearRegression

# 实例化线性回归模型
regressor = LinearRegression()

# 使用训练数据拟合线性回归模型
regressor.fit(X_train, Y_train)

第3步: 在测试集上预测结果

复制代码
# 使用训练好的回归器模型对测试集进行预测,得到预测结果y_pred
y_pred = regressor.predict(X_test)
相关推荐
boooo_hhh1 小时前
第40周——GAN入门
人工智能·python·机器学习
ChaoQiezi1 小时前
Python:如何在Pycharm中显示geemap地图?
python·gee
小白学大数据1 小时前
1688商品数据抓取:Python爬虫+动态页面解析
爬虫·python·okhttp
华科云商xiao徐2 小时前
突破Python性能墙:关键模块C++化的爬虫优化指南
c++·爬虫·python
躲在云朵里`2 小时前
常用Linux指令:Java/MySQL/Tomcat/Redis/Nginx运维指南
开发语言·python
小白狮ww3 小时前
蛋白质设计新高度,RFdiffusion 实现从零设计高亲和力蛋白质
人工智能·python·开源
星火飞码iFlyCode3 小时前
真实案例 | 如何用iFlyCode开发Webpack插件?
java·python·编辑器
三只熊猫3 小时前
一文打通 AI 知识脉络:大语言模型等关键内容详解
人工智能·python
Kyln.Wu3 小时前
【python实用小脚本-187】Python一键批量改PDF文字:拖进来秒出新文件——再也不用Acrobat来回导
python·pdf·c#
xnglan4 小时前
蓝桥杯手算题和杂题简易做法
数据结构·数据库·c++·python·算法·职场和发展·蓝桥杯