机器学习(二)-简单线性回归

文章目录

    • [1. 简单线性回归理论](#1. 简单线性回归理论)
    • [2. python通过简单线性回归预测房价](#2. python通过简单线性回归预测房价)
      • [2.1 预测数据](#2.1 预测数据)
      • 2.2导入标准库
      • [2.3 导入数据](#2.3 导入数据)
      • [2.4 划分数据集](#2.4 划分数据集)
      • [2.5 导入线性回归模块](#2.5 导入线性回归模块)
      • [2.6 对测试集进行预测](#2.6 对测试集进行预测)
      • [2.7 计算均方误差 J](#2.7 计算均方误差 J)
      • [2.8 计算参数 w0、w1](#2.8 计算参数 w0、w1)
      • [2.9 可视化训练集拟合结果](#2.9 可视化训练集拟合结果)
      • [2.10 可视化测试集拟合结果](#2.10 可视化测试集拟合结果)
      • [2.11 保存模型](#2.11 保存模型)
      • [2.12 加载模型并预测](#2.12 加载模型并预测)

在机器学习和统计学中,简单线性回归是一种基础而强大的工具,用于建立自变量与因变量之间的关系。

假设你是一个房产中介,想通过房屋面积来预测房价。简单线性回归可以帮助你找到房屋面积与房价之间的线性关系,进而为客户提供更合理的报价。

本文将带你深入了解简单线性回归的理论基础、公式推导以及如何在Python中实现这一模型。

1. 简单线性回归理论

简单线性回归的基本假设是,因变量 Y(例如房价)与自变量 X(例如人口)之间存在线性关系。我们可以用以下的线性方程来表示这种关系:

其中:

  • y 是因变量(我们要预测的变量)。

  • x 是自变量(我们用来进行预测的变量)。

  • w0是截距(当x=0) 时,y的值)。

  • w1是斜率(自变量变化一个单位时,因变量的变化量)。

我们的目标是求 w0和w1的值,来找到一条跟预测值相关的直线。

从图中我们可以看出预测值与真实值之间存在误差,那么我们引入机器学习中的一个概念均方误差,它表示的是这些差值的平方和的平均数。这些误差的表达式如下:

均方误差的表达式如下:

2. python通过简单线性回归预测房价

2.1 预测数据

数据如下:

tex 复制代码
polulation,median_house_value
961,3.03
234,0.68
1074,2.92
1547,4.24
805,2.39
597,1.59
784,2.21
498,1.31
1602,4.28
292,0.54
1499,4.18
718,1.95
180,0.43
1202,3.62
1258,3.48
453,1.08
845,2.31
1032,2.96
384,0.68
896,2.62
425,0.82
928,2.95
1324,3.59
1435,4.02
543,1.62
1132,3.34
328,0.76
638,1.54
1389,3.78
692,1.79

x 轴是人口数量,y轴是房价

2.2导入标准库

python 复制代码
# 导入标准库
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
matplotlib.use('TkAgg')

2.3 导入数据

python 复制代码
# 导入数据集
dataset = pd.read_csv('Data.csv')
x = dataset.iloc[:, :-1]
y = dataset.iloc[:, 1]

2.4 划分数据集

python 复制代码
# 数据集划分 训练集/测试集
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)

2.5 导入线性回归模块

python 复制代码
# 简单线性回归算法
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
regressor.fit(X_train, y_train)

2.6 对测试集进行预测

python 复制代码
# 对测试集进行预测
y_pred = regressor.predict(X_test)

2.7 计算均方误差 J

python 复制代码
# 计算J
J = 1/X_train.shape[0] * np.sum((regressor.predict(X_train) - y_train)**2)
print("J = {}".format(J))

输出结果:

tex 复制代码
J = 0.031198935319832692

2.8 计算参数 w0、w1

python 复制代码
# 计算参数 w0、w1
w0 = regressor.intercept_
w1 = regressor.coef_[0]
print("w0 = {}, w1 = {}".format(w0, w1))

输出结果:

tex 复制代码
w0 = -0.16411984840092098, w1 = 0.0029383965595942067

2.9 可视化训练集拟合结果

python 复制代码
# 可视化训练集拟合结果
plt.figure(1)
plt.scatter(X_train, y_train, color = 'red')
plt.plot(X_train, regressor.predict(X_train), color = 'blue')
plt.title('population VS median_house_value (training set)')
plt.xlabel('population')
plt.ylabel('median_house_value')
plt.show()

输出结果:

可以很好的看到拟合的直线可以很好的表示原始数据的人口和房价的走势

2.10 可视化测试集拟合结果

python 复制代码
# 可视化测试集拟合结果
plt.figure(2)
plt.scatter(X_test, y_test, color = 'red')
plt.plot(X_train, regressor.predict(X_train), color = 'blue')
plt.title('population VS median_house_value (test set)')
plt.xlabel('population')
plt.ylabel('median_house_value')
plt.show()

输出结果:

可以看到,拟合的直线在测试集上的表现是相当不错了,说明我们训练的线性模型有很好的应用效果。

2.11 保存模型

python 复制代码
# 保存模型
import pickle
with open('../model/simple_house_price_model.pkl','wb') as file:
    pickle.dump(regressor,file);

2.12 加载模型并预测

python 复制代码
import pickle
import numpy as np
import pandas as pd
# 加载模型并预测
with open('../model/simple_house_price_model.pkl','rb') as file:
    model = pickle.load(file)

x_test = np.array([693,694])
x_test = pd.DataFrame(x_test)
x_test.columns=['polulation']
y_pred = model.predict(x_test)
print(y_pred)

输出结果:

tex 复制代码
[1.87218897 1.87512736]
相关推荐
audyxiao00116 分钟前
人工智能顶级期刊PR论文解读|HCRT:基于相关性感知区域的混合网络,用于DCE-MRI图像中的乳腺肿瘤分割
网络·人工智能·智慧医疗·肿瘤分割
零售ERP菜鸟18 分钟前
IT价值证明:从“成本中心”到“增长引擎”的确定性度量
大数据·人工智能·职场和发展·创业创新·学习方法·业界资讯
叫我:松哥22 分钟前
基于大数据和深度学习的智能空气质量监测与预测平台,采用Spark数据预处理,利用TensorFlow构建LSTM深度学习模型
大数据·python·深度学习·机器学习·spark·flask·lstm
童话名剑1 小时前
目标检测(吴恩达深度学习笔记)
人工智能·目标检测·滑动窗口·目标定位·yolo算法·特征点检测
木卫四科技1 小时前
【木卫四 CES 2026】观察:融合智能体与联邦数据湖的安全数据运营成为趋势
人工智能·安全·汽车
珠海西格电力7 小时前
零碳园区有哪些政策支持?
大数据·数据库·人工智能·物联网·能源
じ☆冷颜〃7 小时前
黎曼几何驱动的算法与系统设计:理论、实践与跨领域应用
笔记·python·深度学习·网络协议·算法·机器学习
启途AI7 小时前
2026免费好用的AIPPT工具榜:智能演示文稿制作新纪元
人工智能·powerpoint·ppt
TH_17 小时前
35、AI自动化技术与职业变革探讨
运维·人工智能·自动化
楚来客7 小时前
AI基础概念之八:Transformer算法通俗解析
人工智能·算法·transformer