逻辑回归OvR策略

逻辑回归本身是为解决二分类问题而设计的,但聪明的前辈们很快就将其应用于多类别分类问题。这种方法被称为一对多(One-vs-Rest, OvR)策略,它通过将多类别问题分解为多个二分类问题来实现。

示例场景:水果分类

假设我们有一个简单的分类问题,需要根据水果的某些特征(如重量、颜色、大小等)将其分类为苹果、橘子或香蕉。

类别及其表示

类别有三种:苹果、橘子、香蕉。

Step1.创建三个二分类逻辑回归模型:

模型A:区分苹果和非苹果。

模型B:区分橘子和非橘子。

模型C:区分香蕉和非香蕉。

Step2.训练每个模型:

对于模型A,我们将所有苹果的样本标记为1(正类),所有非苹果(橘子和香蕉)的样本标记为0(负类),然后进行训练。

类似地,对于模型B和模型C,分别以橘子和香蕉作为正类进行训练。

Step3.进行预测:

当有一个新的水果样本需要分类时,我们将该样本分别输入到这三个模型中。

每个模型都会输出一个概率,分别表示该样本为苹果、橘子或香蕉的概率。

Step4.选择最高概率的类别:

比较三个模型输出的概率,最高概率对应的类别即为该样本的预测类别。

例如,如果模型A输出的概率是0.7,模型B是0.2,模型C是0.1,那么我们将这个样本分类为苹果,因为苹果的概率最高。

这种OvR策略的优点在于它的简单性和直观性,使得我们可以利用已有的二分类逻辑回归算法来解决更复杂的多类别分类问题。然而,它的缺点是每个分类器都是独立工作的,没有考虑不同类别之间可能存在的关系。此外,当类别数量非常多时,这种方法可能会变得低效,因为需要训练大量的分类器。

Scikit-Learn库中内置的鸢尾花数据集非常适合来做多分类,这里我们直接用。简单介绍一下鸢尾花数据集:鸢尾花(Iris)数据集是机器学习和统计学中最著名的数据集之一。它最初由著名的统计学家和生物学家罗纳德·费舍尔(Ronald Fisher)在1936年介绍,并且因其在模式识别文献中的广泛使用而变得非常知名。这个数据集通常用于演示分类算法的效果。

样本分为三个鸢尾花的品种:Setosa、Versicolour 和 Virginica,每个品种包含50个样本。每个样本都有四个特征:萼片长度(sepal length)、萼片宽度(sepal width)、花瓣长度(petal length)和花瓣宽度(petal width)。

OvR分类算法代码:

bash 复制代码
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import classification_report

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建逻辑回归模型,并使用OneVsRestClassifier来实现OvR策略
ovr_classifier = OneVsRestClassifier(LogisticRegression())

# 训练模型
ovr_classifier.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = ovr_classifier.predict(X_test)

# 评估模型
print(classification_report(y_test, y_pred))

分类结果:

类别 精度 (Precision) 召回率 (Recall) F1得分 支持 (Support)
0 1.00 1.00 1.00 19
1 1.00 0.85 0.92 13
2 0.87 1.00 0.93 13
总计 0.96 0.95 0.95 45

指标解释

精度(Precision):

精度是指预测为正类(比如某个特定鸢尾花种类)的样本中,实际为正类的比例。

计算公式为:Precision = True Positives / (True Positives + False Positives)。

高精度表示误将负类判为正类的情况较少。

召回率(Recall):

召回率是指所有正类样本中,被正确预测为正类的比例。

计算公式为:Recall = True Positives / (True Positives + False Negatives)。

高召回率表示模型能够捕捉到大部分的正类样本。

F1得分(F1 Score):

F1得分是精度和召回率的调和平均,用于衡量模型的总体准确性。

计算公式为:F1 Score = 2 * (Precision * Recall) / (Precision + Recall)。

它是一个介于0和1之间的数,越接近1表示模型越好。

支持(Support):

支持数是指每个类别在测试集中的真实样本数。

它可以帮助我们理解每个类别的重要性和数据集中的分布。

可以看到在鸢尾花数据集上,逻辑回归模型配合OvR策略总体准确率达到了96%,效果还是比较好的,尤其是对Setosa(类别0)的分类准确率达到了100%。在Versicolour和Virginica的分类上,虽然也有较高的准确率,但存在一些误分类,这可能与这两个类别在特征空间中的重叠有关。

相关推荐
_WndProc7 分钟前
C++ 日志输出
开发语言·c++·算法
努力学习编程的伍大侠20 分钟前
基础排序算法
数据结构·c++·算法
qq_5290252934 分钟前
Torch.gather
python·深度学习·机器学习
XiaoLeisj1 小时前
【递归,搜索与回溯算法 & 综合练习】深入理解暴搜决策树:递归,搜索与回溯算法综合小专题(二)
数据结构·算法·leetcode·决策树·深度优先·剪枝
IT古董1 小时前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
Jasmine_llq1 小时前
《 火星人 》
算法·青少年编程·c#
闻缺陷则喜何志丹1 小时前
【C++动态规划 图论】3243. 新增道路查询后的最短距离 I|1567
c++·算法·动态规划·力扣·图论·最短路·路径
海棠AI实验室2 小时前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
Lenyiin2 小时前
01.02、判定是否互为字符重排
算法·leetcode