[机器学习]04-逻辑回归(python)-03-API与癌症分类案例讲解

**逻辑回归(**Logistic Regression) 的一API 介绍

关于如何配置模型中的优化器、正则化和惩罚项。

1. 逻辑回归 API

在 Scikit-learn 中,逻辑回归可以通过如下方式定义:

python 复制代码
from sklearn.linear_model import LogisticRegression
​
model = LogisticRegression(solver='liblinear', penalty='l2', C=1.0)

solver:这是用于优化损失函数的算法(或求解器),因为我们需要根据给定的数据去寻找模型参数,所以这个求解器决定了使用哪种算法去最小化损失函数。

penalty :这是正则化的类型,可以选择 'l1''l2',它是为了防止模型过拟合。

C:这是正则化的强度参数,值越小正则化越强,默认是 1.0。

2. solver(求解器)介绍

**solver**是用于最小化损失函数的算法选项,在不同的数据规模和正则化条件下,选择合适的求解器会影响计算效率和结果。

liblinear :这个求解器适合 小数据集 ,速度更快。它采用 坐标下降法,可以用于 L1 和 L2 正则化。因为它的计算复杂度较低,所以在小数据集上的表现很好.

sag、saga :这两个求解器适合 大数据集 ,它们使用的是 随机梯度下降法(Stochastic Average Gradient Descent),这种方法可以在处理大量数据时表现得更高效。

sagsaga 支持 L2 正则化,或者也可以不使用正则化。

  • saga 支持 L1 和 L2 正则化,所以它更灵活,特别适合处理稀疏数据集(很多 0 值的数据)。

3. 正则化

正则化(penalty) 是用于防止模型 过拟合 的一种技术。过拟合意味着模型在训练集上表现很好,但在测试集上表现不佳,正则化的作用就是让模型在训练时不过于依赖训练数据。

  • L1 正则化 :又称为 Lasso 正则化 ,它会让某些权重变为 0,从而达到特征选择的效果。它适合处理有很多不相关特征的数据集,因为它会自动选择出相关特征,忽略掉不相关的特征。
  • L2 正则化 :又称为 Ridge 正则化,它不会让权重变为 0,但会压缩权重的值。它倾向于使模型的参数尽量小,从而减少模型的复杂度。

4. 惩罚参数 C

C 是正则化的强度,值越小,正则化越强

  • 当 ( C ) 很小的时候,模型的正则化效果很强,它会强制模型的权重变得更小,以此来减少模型的复杂性。
  • 当 ( C ) 很大的时候,正则化的效果很弱,模型允许有较大的权重,可能会导致模型过拟合。

5. 总结 API 的使用场景

  • 当我们面对 小数据集 时,可以选择 liblinear 求解器,它的计算速度更快,适合处理小规模问题。
  • 当面对 大数据集 时,推荐使用 sagsaga,因为它们能够更高效地处理大规模数据集。
  • 正则化的选择:
    • 如果我们想要做 特征选择 或者数据集中的特征较多,使用 L1 正则化(例如在稀疏数据中,比如文本分类)。
    • 如果我们仅仅想防止过拟合,但不需要特征选择,使用 L2 正则化

举个简单的例子:

假设你在做一个文本分类任务,比如垃圾邮件分类。你有一个非常大的数据集,每封邮件都可以用几万个单词表示,那么这里很多单词是不相关的,你可以使用:

python 复制代码
model = LogisticRegression(solver='saga', penalty='l1', C=1.0)
  • 选择 saga,因为数据集很大,它的求解效率高。
  • 选择 L1 正则化,因为它可以自动选择出有用的特征(即对垃圾邮件的预测有用的单词)。
  • 通过这种组合,你可以让模型在大规模数据集上表现得更好,同时还能筛选出关键特征。

我们使用一份在线的癌症数据来进行分析,下面是数据的info()信息和PPT 里面的字段会有差异,但是不影响。

python 复制代码
import pandas as pd
url = "https://github.com/akmand/datasets/raw/main/breast_cancer_wisconsin.csv"
data = pd.read_csv(url)
data.info()

<class 'pandas.core.frame.DataFrame'>

RangeIndex: 569 entries, 0 to 568

Data columns (total 31 columns):

Column Non-Null Count Dtype


0 mean_radius 569 non-null float64

1 mean_texture 569 non-null float64

2 mean_perimeter 569 non-null float64

3 mean_area 569 non-null float64

4 mean_smoothness 569 non-null float64

5 mean_compactness 569 non-null float64

6 mean_concavity 569 non-null float64

7 mean_concave_points 569 non-null float64

8 mean_symmetry 569 non-null float64

9 mean_fractal_dimension 569 non-null float64

10 radius_error 569 non-null float64

11 texture_error 569 non-null float64

12 perimeter_error 569 non-null float64

13 area_error 569 non-null float64

14 smoothness_error 569 non-null float64

15 compactness_error 569 non-null float64

16 concavity_error 569 non-null float64

17 concave_points_error 569 non-null float64

18 symmetry_error 569 non-null float64

19 fractal_dimension_error 569 non-null float64

20 worst_radius 569 non-null float64

21 worst_texture 569 non-null float64

22 worst_perimeter 569 non-null float64

23 worst_area 569 non-null float64

24 worst_smoothness 569 non-null float64

25 worst_compactness 569 non-null float64

26 worst_concavity 569 non-null float64

27 worst_concave_points 569 non-null float64

28 worst_symmetry 569 non-null float64

29 worst_fractal_dimension 569 non-null float64

30 diagnosis 569 non-null object

dtypes: float64(30), object(1)

memory usage: 137.9+ KB

根据字段信息,以下是翻译后的中文版本:

  1. mean_radius - 平均半径

  2. mean_texture - 平均纹理

  3. mean_perimeter - 平均周长

  4. mean_area - 平均面积

  5. mean_smoothness - 平均平滑度

  6. mean_compactness - 平均致密度

  7. mean_concavity - 平均凹陷

  8. mean_concave_points - 平均凹点

  9. mean_symmetry - 平均对称性

  10. mean_fractal_dimension - 平均分形维数

  11. radius_error - 半径误差

  12. texture_error - 纹理误差

  13. perimeter_error - 周长误差

  14. area_error - 面积误差

  15. smoothness_error - 平滑度误差

  16. compactness_error - 致密度误差

  17. concavity_error - 凹陷误差

  18. concave_points_error - 凹点误差

  19. symmetry_error - 对称性误差

  20. fractal_dimension_error - 分形维数误差

  21. worst_radius - 最差半径

  22. worst_texture - 最差纹理

  23. worst_perimeter - 最差周长

  24. worst_area - 最差面积

  25. worst_smoothness - 最差平滑度

  26. worst_compactness - 最差致密度

  27. worst_concavity - 最差凹陷

  28. worst_concave_points - 最差凹点

  29. worst_symmetry - 最差对称性

  30. worst_fractal_dimension - 最差分形维数

  31. diagnosis - 诊断(通常 "B" 表示良性,"M" 表示恶性)

这些翻译可以帮助你更好地理解数据集中的各个特征。

python 复制代码
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import pandas as pd
import numpy as np

def dm_LogisticRegression():
    # 1. 获取数据
    url = "https://github.com/akmand/datasets/raw/main/breast_cancer_wisconsin.csv"
    data = pd.read_csv(url)
    
    # 2. 基本数据处理
    # 2.1 缺失值处理
    data = data.replace(to_replace="?", value=np.NaN)
    data = data.dropna()
    
    # 2.2 确定特征值和目标值
    x = data.iloc[:, 1:-1]  # 去除 ID 列
    y = data.iloc[:, -1]  # 目标值为最后一列,通常为 'Diagnosis'
    
    # 2.3 按 8:2 比例分割数据集
    x_train, x_test, y_train, y_test = train_test_split(x, y,train_size=0.8, random_state=22)
    
    # 3. 特征工程(标准化)
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    
    # 4. 机器学习模型训练(逻辑回归)
    estimator = LogisticRegression()
    estimator.fit(x_train, y_train)
    
    # 5. 模型评估
    y_predict = estimator.predict(x_test)
    print('y_predict -->', y_predict)
    accuracy = estimator.score(x_test, y_test)
    print('accuracy -->', accuracy)

# 调用函数
dm_LogisticRegression()

结果分析:

1. y_predict --> ['B', 'M', 'M', 'M', ...]:

• 这是模型对测试集的预测结果,其中:

• 'B' 表示 良性(Benign)肿瘤。

• 'M' 表示 恶性(Malignant)肿瘤。

• 模型对每个样本进行了分类,给出了它是良性还是恶性肿瘤。

2. accuracy --> 0.951048951048951:

准确率(Accuracy)是模型在测试集上的表现指标,定义为模型预测正确的样本数量占总测试样本数量的比例。

• 该模型的准确率为 95.1% ,这意味着模型在测试集中 95.1% 的样本被正确分类为良性或恶性肿瘤。

结论:

模型表现良好:95.1% 的准确率表明该逻辑回归模型对乳腺癌数据集有着较好的分类能力,大多数情况下能够正确判断肿瘤是良性还是恶性。

进一步改进:虽然 95.1% 的准确率已经较高,但在实际应用中可以进一步优化模型(例如通过调参、使用更复杂的模型、处理不平衡数据等)以提升准确率和鲁棒性。

总体来说,该结果表明模型在检测良性和恶性肿瘤的分类任务上表现较好,但仍有一定的错误分类。

相关推荐
老艾的AI世界3 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221513 小时前
机器学习系列----关联分析
人工智能·机器学习
FreedomLeo13 小时前
Python数据分析NumPy和pandas(四十、Python 中的建模库statsmodels 和 scikit-learn)
python·机器学习·数据分析·scikit-learn·statsmodels·numpy和pandas
风间琉璃""4 小时前
二进制与网络安全的关系
安全·机器学习·网络安全·逆向·二进制
Java Fans4 小时前
梯度提升树(Gradient Boosting Trees)详解
机器学习·集成学习·boosting
谢眠5 小时前
机器学习day6-线性代数2-梯度下降
人工智能·机器学习
sp_fyf_20246 小时前
【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
sp_fyf_20248 小时前
【大语言模型】ACL2024论文-18 MINPROMPT:基于图的最小提示数据增强用于少样本问答
人工智能·深度学习·神经网络·目标检测·机器学习·语言模型·自然语言处理
爱喝白开水a8 小时前
Sentence-BERT实现文本匹配【分类目标函数】
人工智能·深度学习·机器学习·自然语言处理·分类·bert·大模型微调
封步宇AIGC9 小时前
量化交易系统开发-实时行情自动化交易-4.2.3.指数移动平均线实现
人工智能·python·机器学习·数据挖掘