用Python实现9大回归算法详解——06. K近邻回归算法

1. K近邻回归的基本概念

K近邻回归 (K-Nearest Neighbors Regression, KNN Regression)是一种基于实例的学习方法。与传统的回归模型不同,KNN回归不通过显式的函数来建模数据之间的关系,而是通过查找输入样本的"邻居"来进行预测。KNN回归的核心思想是:对于一个给定的输入样本,它的预测结果是其最近 个邻居的目标值的平均值或加权平均值。

2. K近邻回归的算法流程
  • 选择距离度量: 选择一种合适的距离度量方法,用于计算样本之间的距离。常见的距离度量方法包括欧氏距离(Euclidean Distance)、曼哈顿距离(Manhattan Distance)等。

欧氏距离公式:

曼哈顿距离公式:

  • 确定 : 选择一个合适的 值,即用于预测的新样本的邻居数。 的选择会影响模型的复杂度和预测结果。

  • 找到最近的 个邻居 : 对于每个输入样本,根据选定的距离度量,找到训练集中与其最接近的 个样本。

  • 计算目标值的平均值 : 根据这 个最近邻的目标值来计算预测值。常见的做法是计算这 个目标值的平均值:

如果使用加权平均值,公式为:

其中 是第个邻居的权重。

3. K近邻回归的数学表达

假设我们有一个数据集 ​,其中 是样本特征, 是对应的目标值。对于一个新输入样本 ,其预测值由以下公式计算:

其中:

  • 表示输入样本 个最近邻集合。
  • 是最近邻的个数。

如果使用加权平均,预测值的公式为:

4. K近邻回归的优缺点

优点

  1. 简单易懂:KNN回归的思想非常直观,无需复杂的数学模型。
  2. 无模型训练过程:KNN是一种基于实例的学习方法,不需要在训练阶段构建模型,所有的计算都发生在预测时。
  3. 适应性强 :KNN可以用于多种数据分布,只要选择合适的距离度量和 值。

缺点

  1. 计算复杂度高:在预测阶段,KNN需要计算输入样本与所有训练样本的距离,计算成本较高,尤其是在数据量较大时。
  2. 对噪声敏感 :KNN容易受到异常值(噪声)的影响,特别是当 值较小时。
  3. 数据依赖性强:KNN对数据的缩放和归一化非常敏感,因为距离度量直接受特征值范围的影响。

5. K近邻回归案例

我们将通过一个具体的案例来展示如何使用K近邻回归进行预测,并对结果进行详细分析。

5.1 数据加载与预处理

我们将使用一个模拟的数据集来进行回归预测。该数据集包含一个特征和一个目标变量,目标变量与特征之间存在非线性关系。

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error, r2_score

# 生成模拟数据
np.random.seed(42)
X = np.sort(5 * np.random.rand(100, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 绘制数据点
plt.scatter(X, y, color='darkorange', label='data')
plt.title('Data Points')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

输出:

5.2 模型训练

我们使用 KNeighborsRegressor 进行模型训练,并设置 K=5。

python 复制代码
# 定义K近邻回归模型
knn = KNeighborsRegressor(n_neighbors=5)

# 训练模型
knn.fit(X_train, y_train)

# 对测试集进行预测
y_pred = knn.predict(X_test)
5.3 结果分析

我们使用均方误差(MSE)和决定系数()来评估模型的性能。

python 复制代码
# 计算均方误差 (MSE) 和决定系数 (R²)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("均方误差 (MSE):", mse)
print("决定系数 (R²):", r2)

输出:

python 复制代码
均方误差 (MSE): 0.0100487643690526
决定系数 (R²): 0.9790530511131447

解释

  • 均方误差 (MSE):模型的预测误差为 0.010,表明模型对测试集的预测较为准确。
  • 决定系数 (R²) :模型的 值为 0.970,说明模型能够解释 98.6% 的目标变量方差,拟合效果较好。
5.4 可视化分析

为了更直观地展示K近邻回归模型的预测效果,我们将绘制测试集的预测值与实际值的对比图。

python 复制代码
# 绘制测试集预测值与实际值的对比图
plt.scatter(X_test, y_test, color='darkorange', label='Actual')
plt.scatter(X_test, y_pred, color='navy', label='Predicted')
plt.title('KNN Regression: Actual vs Predicted')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.legend()
plt.show()

输出:

可视化解释

  • 实际值(橙色):表示测试集的实际目标值。
  • 预测值(蓝色):表示模型预测的目标值。通过对比可以看出,模型的预测值与实际值非常接近,说明模型具有较好的预测能力。
5.5 K值的选择对模型的影响

为了更深入地理解 值对K近邻回归模型性能的影响,我们可以尝试不同的 值,并比较它们的结果。

python 复制代码
k_values = [1, 3, 5, 10, 15]
mse_values = []

for k in k_values:
    knn = KNeighborsRegressor(n_neighbors=k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    mse_values.append(mean_squared_error(y_test, y_pred))

# 输出不同K值下的MSE
for k, mse in zip(k_values, mse_values):
    print(f"K: {k}, MSE: {mse}")

输出:

python 复制代码
K: 1, MSE: 0.010451768716348698
K: 3, MSE: 0.010861756369177033
K: 5, MSE: 0.0100487643690526
K: 10, MSE: 0.009221657898998472
K: 15, MSE: 0.01604697426802005

解释

  • 时,模型最接近训练数据,可能导致过拟合,因此 MSE 相对较大。
  • 逐渐增加时,模型开始平滑化,MSE 减小,达到最佳值。
  • 过大时,模型变得过于平滑,开始欠拟合,MSE 再次增大。

6. 总结

K近邻回归是一种简单且有效的回归方法,通过对输入样本的邻居进行平均或加权平均来进行预测。尽管KNN回归在计算复杂度和对噪声的敏感性方面存在一定的缺点,但在处理一些简单和直观的问题时,仍然表现出色。通过案例分析,我们展示了如何使用KNN回归进行建模,并讨论了 值的选择对模型性能的影响。

相关推荐
----云烟----1 小时前
QT中QString类的各种使用
开发语言·qt
lsx2024061 小时前
SQL SELECT 语句:基础与进阶应用
开发语言
小二·1 小时前
java基础面试题笔记(基础篇)
java·笔记·python
开心工作室_kaic2 小时前
ssm161基于web的资源共享平台的共享与开发+jsp(论文+源码)_kaic
java·开发语言·前端
向宇it2 小时前
【unity小技巧】unity 什么是反射?反射的作用?反射的使用场景?反射的缺点?常用的反射操作?反射常见示例
开发语言·游戏·unity·c#·游戏引擎
武子康2 小时前
Java-06 深入浅出 MyBatis - 一对一模型 SqlMapConfig 与 Mapper 详细讲解测试
java·开发语言·数据仓库·sql·mybatis·springboot·springcloud
转世成为计算机大神2 小时前
易考八股文之Java中的设计模式?
java·开发语言·设计模式
宅小海3 小时前
scala String
大数据·开发语言·scala
小喵要摸鱼3 小时前
Python 神经网络项目常用语法
python
qq_327342733 小时前
Java实现离线身份证号码OCR识别
java·开发语言