【机器学习|Python】sklearn中的逻辑回归模型

前言

本文主要说明 Python 的 sklearn 库中的随机森林模型的常用接口、属性以及参数调优说明。需要读者或多或少了解过sklearn库和一些基本的机器学习知识。

sklearn中的逻辑回归

sklearn中的逻辑回归相关类:

  • 逻辑回归模型:sklearn.linear_model.LogisticRegression
  • 交叉熵损失(又称对数损失或逻辑损失):sklearn.metrics.log_loss

基本使用

模型基本使用(以sklearn中的乳腺癌细胞数据集为例):

python 复制代码
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# 数据集
breast_cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(breast_cancer.data, breast_cancer.target, test_size=0.3, random_state=42)

lr = LogisticRegression()  # 创建模型对象
lr.fit(X_train, y_train)  # 拟合训练集

print(lr.predict(X_test))  # 测试集上的预测结果
print(lr.score(X_test, y_test))   # 测试集上的准确率

常用属性和接口

  • .coef_:模型中各特征对应的权重系数(或参数)
  • .intercept_:逻辑回归中预测函数中的截距(或偏差)
  • .n_iter_:模型的实际迭代次数(二分类中仅构建一个模型,返回一个值;多分类模型中构建多个模型,返回多个值)
  • lr.classes_:输出数据集中包含的各标签对应的列表
  • predict_proba:预测各个样本属于各个标签的概率

参数说明

默认参数

python 复制代码
LogisticRegression(penalty='l2', dual=False, tol=0.0001, C=1.0, 
                   fit_intercept=True, intercept_scaling=1, class_weight=None, 
                   random_state=None, solver='lbfgs', max_iter=100,
                   multi_class='auto', verbose=0, warm_start=False, n_jobs=None, 
                   l1_ratio=None)

常用参数

  • penalty:正则化参数,选择正则化类型
  • C:对于引入正则化系数的逻辑回归模型,设置正则化强度
  • solver:模型求解器,选择模型优化算法
  • multi_class:设置分类模式或模型要处理的分类问题的类型

penalty & C

正则化:正则化是用来防止模型过拟合的过程,常用的有L1正则化和L2正则化两种选项,分别通过在损失函数后加上参数向量的L1范式或L2范式来实现,其中增加的范式被称为 "正则项"(或惩罚项),基于损失函数的最优化来求解的参数取值由于引入的正则项而发生改变,以此达到调节模型拟合程度的目的(截距一般是不参与正则化的)。

① penalty:默认值为 'l2'。

  • 'l1':指定l1正则化
  • 'l2':指定l2正则化

② C:默认值为 1.0,即默认正则项与损失函数的比值是 1 : 1;必须是一个大于0的浮点数,表示为正则化强度的倒数;C越大,正则化强度越小,C越小,正则化强度越大。

其它说明:L1正则化和L2正则化虽然都可以控制过拟合,但它们的效果并不相同。当正则化强度逐渐增大(即C逐渐变小),各权重系数(参数)的取值会逐渐变小,但L1正则化会趋于将参数压缩为0,而L2正则化则趋于让参数尽量小,不会取到0,由于这个特性,L1正则化也常用于基于Embedded嵌入法的特征选择。

参数设置:通常情况下一般设置参数penalty 为 'l2',当进行特征选择时我们一般使用 'l1',当有效特征数量较少时一般使用 'l2';对于参数C,可以根据不同情况绘制对应学习曲线观察效果后进行调整。

solver & multi_class

① multi_class:默认值为 'auto';该参数表示设置分类方式

  • 'auto':根据数据的分类情况和其他参数来确定模型要处理的分类问题的类型;若训练数据为二分类或solver的取值为 "liblinear",选择"ovr";反之选择 "multinomial"。
  • 'ovr':表示分类问题是二分类,或让模型使用 "一对多" 的形式来处理多分类问题。
  • 'multinomial':表示处理多分类问题,让模型使用 "多对多" 的形式来处理多分类问题。

② solver:默认值为'lbfgs';设置优化算法(模型求解算法)

  • 'liblinear':坐标下降法
  • 'lbfgs':拟牛顿法的一种
  • 'newton-cg':牛顿法的一种
  • 'sag':随机平均梯度下降
  • 'saga':随机平均梯度下降的变形

其它说明:对于solver参数,牛顿法和拟牛顿法相关算法一般使用二阶导数矩阵(海森矩阵)来进行模型迭代,因此若使用 'lbfgs' 或 'newton-cg',则无法使用l1正则化;随机平均梯度下降若使用l1正则化会产生一些问题,我们的损失函数一般为凸函数,但当使用l1正则化后,可能引发一些例如多个极小值或某些区域不可导的问题,因此传统的梯度下降法一般不适用带l1正则化的情况;即:'lbfgs'、'newton-cg'、'sag' 不能使用 l1 正则化,只能使用 l2 正则化,'liblinear' 和 'saga' l1正则化和l2正则化都可以使用;此外,坐标下降法 'liblinear' 不能用于 'multinomial' 模式下的求解算法,只能用于 'ovr' 模式。

对于multi_class,一般使用 'auto',在二分类情况下,一般 'ovr' 有着更好的效果,相对地,在多分类任务下 'multinomial' 一般有着更好的效果;在多分类任务下,'ovr' 和 'multinomial',对于 N 分类任务,都会建立 N 个分类器来进行决策,但相对ovr,'multinomial' 在多分类任务上有着更严谨的数学过程,'ovr' 则对数据的质量要求比较高。

参数设置:逻辑回归一般使用 'saga' 作为优化算法,若进行特征选择,由于只能使用l1正则化,因此只能考虑 'liblinear' 或 'saga',一般使用 'saga';对于 multi_class,一般使用 'auto'。

其它参数

① max_iter:默认值为100,;设置模型在训练集上的最大迭代次数,当模型在该参数之前已收敛,则会自动停止迭代,若未收敛,则到达该参数时停止迭代。

② fit_intercept:布尔值,默认值为True;若为True,则添加截距(偏差)到模型中;若为False,则不包含截距(偏差)。

③ random_state:默认为None,随机数生成器为np.random模块下的一个RandomState实例;当模型求解器为 'liblinear' 或 'sag' 时才有效,一般输入整数。

④ n_jobs:默认为None;在多分类任务下平行计算所使用的CPU线程数,一般使用 -1,表示使用所有可用的线程数;当参数solver设置为'liblinear'时,忽略此参数。

③ class_weight:用于处理样本不平衡的情况,但这个参数我们一般不用,对于这种场景,我们更多地使用上采样或下采样方法来应对,对于逻辑回归来说,上采样一般为首要选择。

总结

在逻辑回归模型中,对于需要手动调参的参数的数量比较少,penalty、solver、multi_class一般通过相应任务和使用场景就可以优先调好,而参数C我们则可以通过画学习曲线的方式来进行调整,但在训练模型之前我们可能会先基于Embedded嵌入法进行特征选择,由于我们要基于加上正则化的模型对特征进行选择,这时我们一般先对参数C绘制学习曲线来优先确定参数C的取值以得到更好的模型,然后基于使用该参数的模型对嵌入法绘制学习曲线以确定保留特征个数。

Reference

相关推荐
q0_0p30 分钟前
从零开始的Python世界生活——基础篇(Python字典)
python·python基础
上海合宙LuatOS35 分钟前
直接抄作业!Air780E模组LuatOS开发:位运算(bit)示例
人工智能·单片机·嵌入式硬件·物联网·硬件工程·iot
databook39 分钟前
manim边做边学--圆柱体
python·动效
池央1 小时前
深度学习模型:卷积神经网络(CNN)
人工智能·深度学习·cnn
deephub1 小时前
Scikit-learn Pipeline完全指南:高效构建机器学习工作流
人工智能·python·机器学习·scikit-learn
知来者逆1 小时前
首次公开用系统审查与评估大语言模型安全性的数据集
人工智能·机器学习·语言模型·自然语言处理·llm·大语言模型
麻衣带我去上学1 小时前
Pytest使用Jpype调用jar包报错:Windows fatal exception: access violation
windows·python·pytest·jar
易风有点疯2 小时前
Python:序列化
开发语言·python
亚图跨际2 小时前
Python和R统计检验比较各组之间的免疫浸润
python·r语言·统计检验
Promising_GEO2 小时前
使用R语言绘制简单地图的教程
开发语言·python·r语言