【scikit-learn基础】--『监督学习』之 LASSO回归

LASSOLeast Absolute Shrinkage and Selection Operator)回归模型一般都是用英文缩写表示,

硬要翻译的话,可翻译为 最小绝对收缩和选择算子

它是一种线性回归模型的扩展,其主要目标是解决高维数据中的特征选择和正则化问题。

1. 概述

LASSO中,通过使用L1正则化 项,它能够在回归系数中引入稀疏性,

也就是允许某些系数在优化过程中缩减为零,从而实现特征的选择。

与岭回归不同的是,LASSO的损失函数一般定义为:\(L(w) = (y-wX)^2+\lambda\parallel w\parallel_1\)

其中 \(\lambda\parallel w\parallel_1\),也就是 L1正则化项 (岭回归中用的是 L2正则化项)。

模型训练的过程就是寻找让损失函数\(L(w)\)最小的参数\(w\)。

也就等价于:\(\begin{align} & arg\ min(y-wX)^2 \\ & s.t. \sum |w_{ij}| < s \end{align}\)

这两个公式表示,在满足约束条件 \(\sum |w_{ij}| < s\)的情况下,计算 \((y-wX)^2\)的最小值。

2. 创建样本数据

相比于岭回归模型,LASSO回归模型不仅对于共线性数据集友好,

对于高维数据的数据集,也有不错的性能表现。

它通过将不重要的特征的系数压缩为零,帮助我们选择最重要的特征,从而提高模型的预测准确性和可解释性。

下面我们模拟创建一些高维数据,创建一个特征数比样本数还多的样本数据集。

python 复制代码
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=80, n_features=100, noise=10)

这个数据集中,只有80个样本,每个样本却有100个特征,并且噪声也设置的很大(noise=10)。

3. 模型训练

第一步,分割训练集测试集

python 复制代码
from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

scikit-learn中的LASSO模型来训练:

python 复制代码
from sklearn.linear_model import Lasso

# 初始化LASSO线性模型
reg = Lasso()
# 训练模型
reg.fit(X_train, y_train)

这里使用的 Lasso()的默认参数来训练模型,它的主要参数包括:

  1. alpha :正则化项系数。它控制了L1正则化 项的强度,即对模型复杂度的惩罚。alpha越大,模型越简单,但过大的alpha可能会导致模型欠拟合;alpha越小,模型越复杂,但过小的alpha可能会导致模型过拟合。默认值为1.0
  2. fit_intercept :布尔值,指定是否需要计算截距b值 。如果设为False,则不计算b值默认值为True
  3. normalize :布尔值。如果设为True,则在模型训练之前将数据归一化。默认值为False
  4. precompute :布尔值,指定是否预先计算X的平方和。如果设为True,则在每次迭代之前计算X的平方和。默认值为False
  5. copy_X :布尔值,指定是否在训练过程中复制X。如果设为True,则在训练过程中复制X默认值为True
  6. max_iter :最大迭代次数。默认值为1000
  7. tol :阈值,用于判断是否达到收敛条件。默认值为1e-4
  8. warm_start :布尔值,如果设为True,则使用前一次的解作为本次迭代的起始点。默认值为False
  9. positive :布尔值,如果设为True,则强制系数为正。默认值为False
  10. selection :用于在每次迭代中选择系数的算法(有"cyclic"和"random"两种选择)。默认值为"cyclic",即循环选择。

最后验证模型的训练效果:

python 复制代码
from sklearn import metrics

y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred)

print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error))

# 运行结果
均方误差:441.07830708712186
复相关系数:0.9838880665687711
中位数绝对误差:11.643348614829785

误差看上去不小,因为这次实际生成的样本,不仅数量小(80件)且噪声大(noise=10)。

3.1. 与岭回归模型比较

单独看LASSO模型的训练结果,看不出其处理高维数据的优势。

同样用上面分割好的训练集测试集 ,来看看岭回归模型的拟合效果。

python 复制代码
from sklearn.linear_model import Ridge
# from sklearn.model_selection import train_test_split

mse, r2, m_error = 0.0, 0.0, 0.0

# 初始化岭回归线性模型
reg = Ridge()
# 训练模型
reg.fit(X_train, y_train)

y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred)

print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error))

# 运行结果
均方误差:6315.046844910431
复相关系数:0.7693207470296398
中位数绝对误差:60.65140692273637

对于高维数据,可以看出,岭回归模型的误差 远远大于 LASSO模型。

3.2. 与最小二乘法模型比较

同样用上面分割好的训练集测试集 ,再来看看线性模型(最小二乘法)的拟合效果。

python 复制代码
from sklearn.linear_model import LinearRegression

mse, r2, m_error = 0.0, 0.0, 0.0

# 初始化最小二乘法线性模型
reg = LinearRegression()
# 训练模型
reg.fit(X_train, y_train)

y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred)

print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error))

# 运行结果
均方误差:5912.442445894787
复相关系数:0.7840272859181612
中位数绝对误差:62.89225147465376

可以看出,线性模型 的训练效果和岭回归 模型差不多,但是都远远不如LASSO模型

4. 总结

总的来说,LASSO回归模型是一种流行的线性回归扩展,具有一些显著的优势和劣势。

比如,在特征选择 上,LASSO通过将某些系数压缩为零,能够有效地进行特征选择,这在高维数据集中特别有用。

此外,LASSO可以作为正则化工具,有助于防止过拟合。

不过,LASSO会假设特征是线性相关的,对于非线性关系的数据,效果可能不佳。

而且,如果数据存在复杂模式或噪声,LASSO可能会过度拟合这些模式。

相关推荐
用户8356290780511 天前
Python 实现 PDF 文件加密与解密方法
后端·python
用户8356290780511 天前
使用 Python 冻结与拆分 Excel 窗格教程
后端·python
你好潘先生1 天前
别再记命令了,用 yeero do 说句人话就能跑脚本,而且不烧 token
服务器·python·命令行
Agent_大师1 天前
WebSocket 行情重连成功,K线缺口不会自动消失
python
荣码1 天前
LLM结构化输出:让AI返回JSON而不是废话,我踩了4个坑
java·python
copyer_xyf1 天前
FastAPI 如何连接 MySQL
后端·python
apocelipes2 天前
常用编程语言和库的正则表达式性能对比
c语言·c++·python·性能优化·golang·开发工具和环境
用户8356290780512 天前
使用 Python 在 PDF 中创建与管理书签
后端·python
MeixianAgent2 天前
Python 回测数据入口怎么验?历史 K 线入库前先做 5 个检查
后端·python
咕白m6252 天前
用 Python 实现一键批量查找与替换 Excel 数据
后端·python