spark 3.4.4 机器学习基于逻辑回归算法及管道流实现鸢尾花分类预测案例

知识点介绍:本文基于鸢尾花案例,实现逻辑回归分类预测案例

逻辑回归算法介绍

1. 定义与基本原理
  • 定义:逻辑回归(Logistic Regression)是一种广泛应用于分类问题的统计学习方法,尽管名字里带有 "回归",但它主要用于解决二分类问题,也可以通过扩展的方式(如 One-vs-Rest、One-vs-One 等策略)应用于多分类问题。
  • 基本原理:逻辑回归基于线性回归的思想,首先对输入特征进行线性组合(通过计算特征向量与权重向量的内积,再加上一个偏置项),得到一个线性的预测值,然后将这个线性预测值通过一个非线性的激活函数(通常是 Sigmoid 函数用于二分类,Softmax 函数用于多分类)进行转换,将结果映射到 0 到 1 之间(二分类时表示属于某一类别的概率,多分类时表示属于各个类别的概率分布),最后根据设定的阈值(一般为 0.5 用于二分类)来判断样本属于哪一类。

例如,对于一个简单的二分类逻辑回归,假设输入特征向量为,权重向量为,偏置项为,则线性组合的计算公式为:,然后通过 Sigmoid 函数将转换为概率值,如果,则预测样本属于正类,否则属于负类。

2. 应用场景
  • 医疗领域:判断患者是否患有某种疾病,比如根据患者的年龄、症状、检查指标等特征来预测是否患有糖尿病、心脏病等。例如,通过收集大量已确诊患者和健康人的相关数据,利用逻辑回归构建模型,对新的患者进行疾病风险预测。
  • 金融领域:评估客户的信用风险,决定是否给客户发放贷款等。例如,依据客户的收入、资产、信用历史、负债情况等特征,模型预测客户违约的概率,银行根据这个概率来决定是否批准贷款申请。
  • 市场营销领域:预测客户是否会购买某种产品或服务,基于客户的消费行为、人口统计学特征、浏览历史等,企业可以针对性地开展营销活动,提高营销效果。
3. 优缺点
  • 优点
    • 简单易懂且计算效率高:逻辑回归的原理相对直观,基于线性组合和简单的函数变换,模型的训练和预测过程计算复杂度较低,在大规模数据集上也能较快地得到结果。
    • 可解释性强:模型的权重可以直观地反映出每个特征对分类结果的影响程度,通过分析权重的正负和大小,能够了解特征与目标类别之间的关联关系,有助于业务理解和决策。
    • 模型训练稳定:一般情况下,逻辑回归在合理的参数设置下不容易出现梯度消失、梯度爆炸等训练不稳定的问题,能够较为稳定地收敛到一个较优的解。
  • 缺点
    • 只能处理线性可分问题(原始特征空间下):如果数据在原始特征空间中是非线性可分的,逻辑回归的分类效果可能不佳,需要通过人工特征工程(如添加多项式特征等)或者结合核技巧等方法来扩展特征空间,使其能够处理非线性关系。
    • 对特征要求较高:需要对输入的特征进行合理的选择和预处理,如果存在冗余、无关或者高度相关的特征,可能会影响模型的性能,导致过拟合或者欠拟合等问题,所以往往需要进行特征选择和特征缩放等操作。
    • 容易欠拟合:在处理复杂的非线性关系数据时,由于其模型本身的线性本质,相较于一些复杂的非线性模型(如深度学习模型),可能较难拟合数据中的复杂模式,容易出现欠拟合现象。
4. 重要参数
  • 最大迭代次数(maxIter:控制模型训练时的迭代次数,用于确保模型能够收敛到一个相对稳定的解。如果设置的值过小,模型可能还未充分学习到数据中的模式就停止训练,导致欠拟合;而设置过大则可能导致训练时间过长,甚至可能出现过拟合的情况,需要根据具体数据集和问题进行合理调整,通常通过交叉验证等方法来寻找合适的值。
  • 正则化参数(regParam:为了防止模型过拟合,在损失函数中添加正则项,正则化参数用于控制正则项的强度。常见的正则化方式有 L1 正则化(Lasso 回归)、L2 正则化(岭回归)以及两者结合的 ElasticNet 正则化。合适的正则化参数可以平衡模型的复杂度和对训练数据的拟合程度,不同数据集的最优值不同,同样需要通过实验来确定。
  • ElasticNet 参数(ElasticNetParam:当使用 ElasticNet 正则化时,该参数用于控制 L1 和 L2 正则化的比例,取值范围在 0 到 1 之间,为 0 时表示只使用 L2 正则化,为 1 时表示只使用 L1 正则化,介于两者之间则是两者按相应比例结合。L1 正则化有助于进行特征选择,能将一些不重要的特征对应的系数压缩为 0,而 L2 正则化能使模型的系数更加平滑,避免过大的系数值导致过拟合。

Spark ML 管道流(Pipeline)介绍

1. 概念与作用
  • 概念:Spark ML 中的管道流(Pipeline)是一种将多个机器学习相关的处理步骤按照顺序组合成一个整体工作流的工具。它类似于工业生产中的流水线,数据从一端输入,按照预先设定好的各个阶段依次进行处理,最终在另一端输出处理后的结果,比如完成模型训练或者得到预测数据。
  • 作用
    • 方便流程管理:在实际的机器学习项目中,通常包含数据读取、数据预处理(如特征缩放、缺失值处理、特征编码等)、特征工程(特征选择、特征组合等)以及模型训练、模型评估等多个步骤,使用 Pipeline 可以清晰地将这些步骤组织起来,形成一个逻辑连贯的整体流程,便于代码的编写、维护和理解。
    • 保证处理顺序:确保各个处理步骤按照正确的顺序执行,避免因为顺序混乱导致的错误结果。例如,必须先对数据进行特征编码,再将编码后的特征输入到模型中进行训练,Pipeline 能够严格按照设定的顺序依次调用每个阶段的处理逻辑。
    • 便于模型复用和部署:一旦定义好一个完整的 Pipeline,它可以方便地进行保存和加载,在不同的环境(如开发环境、测试环境、生产环境)中复用,快速进行模型训练或者利用已训练好的模型进行预测,大大简化了模型部署和应用的过程。
2. 组成部分
  • 阶段(Stage) :Pipeline 由多个阶段组成,每个阶段可以是一个数据转换操作(如StringIndexer将字符串标签转换为索引、VectorAssembler将多个特征列组合成特征向量等),也可以是一个机器学习模型(如LogisticRegressionDecisionTreeClassifier等)。这些阶段按照在Pipeline中设置的顺序依次执行,数据在前一个阶段处理完成后会自动传递到下一个阶段进行后续处理。例如,一个简单的包含数据预处理和模型训练的 Pipeline 可能有以下几个阶段:
    • StringIndexer阶段:对字符串类型的标签列进行索引化处理。
    • VectorAssembler阶段:将多个数值特征列组装成一个特征向量列。
    • LogisticRegression阶段:使用组装好的特征向量列和索引化的标签列进行逻辑回归模型训练。

鸢尾花数据介绍

鸢尾花数据集(Iris Dataset)是一类非常经典且常用的数据集,在机器学习、数据分析和统计学习等领域被广泛应用,以下是对它的详细介绍:

1. 数据集来源

鸢尾花数据集是由美国植物学家埃德加・安德森(Edgar Anderson)在 20 世纪 30 年代收集整理的,后经英国统计学家和生物学家罗纳德・费舍尔(Ronald Fisher)在其 1936 年的论文《The use of multiple measurements in taxonomic problems》中使用并推广开来,成为了分类任务中的标准测试数据集之一。

2. 数据内容

  • 特征维度 :该数据集包含了 4 个特征,分别是:
    • 花萼长度(sepal length):通常以厘米为单位,描述鸢尾花花萼部分的长度。
    • 花萼宽度(sepal width):同样以厘米为单位,对应花萼部分的宽度情况。
    • 花瓣长度(petal length):指鸢尾花花瓣的长度,单位厘米,是区分不同鸢尾花种类的重要特征之一。
    • 花瓣宽度(petal width):以厘米为单位衡量花瓣的宽窄程度,在判断鸢尾花类别时也起着关键作用。
  • 类别标签 :总共有 3 种不同的鸢尾花品种类别,分别是:
    • 山鸢尾(Iris-setosa):其花瓣相对较窄较短,在形态上与另外两种有比较明显的区别。
    • 变色鸢尾(Iris-versicolor):花瓣长度、宽度等特征处于另外两种鸢尾花之间的状态,具有一定的过渡特点。
    • 维吉尼亚鸢尾(Iris-virginica):通常花瓣较为宽大,整体花朵形态与前两者不同。

3. 数据规模

鸢尾花数据集一共包含 150 条样本数据,每种鸢尾花类别各有 50 条样本,整体规模较小,便于快速进行模型的训练、测试以及算法验证等操作,尤其适合初学者理解和掌握分类算法的原理及流程。

代码实现

Scala 复制代码
package cn.lh.pblh123.spark2024.theorycourse.charpter9

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

object SparkMLLogicalRegressionPipeLine {

  def main(args: Array[String]): Unit = {

    if (args.length != 3) {
      System.err.println("Usage: <murl> <inputfile> <modelpath>")
      System.exit(1)
    }
    val murl = args(0)
    val inputfile = args(1)
    val modelpath = args(2)
    val spark = SparkSession.builder().appName(s"${this.getClass.getName}").master(murl).getOrCreate()

    // 加载数据
    val df = spark.read.option("header", true)
      .csv(inputfile)

    df.show(3, false)
    df.printSchema()

    val dfDouble = df.select(col("sepal_length").cast(DoubleType), col("sepal_width").cast(DoubleType),
      col("petal_length").cast(DoubleType), col("petal_width").cast(DoubleType),
      col("species").alias("label")
    )
    dfDouble.printSchema()

    // 特征处理
    // 创建一个VectorAssembler实例,用于将多列特征组合成单一的特征向量
    val assembler = new VectorAssembler().setInputCols(
      Array("sepal_length", "sepal_width", "petal_length", "petal_width")
    ).setOutputCol("features")

    // 使用VectorAssembler转换原始DataFrame,生成一个新的DataFrame,其中包含特征向量和标签列
    val dataFrame = assembler.transform(dfDouble).select("features", "label")
    // 显示转换后的DataFrame的前3行数据,以验证转换结果
    dataFrame.show(3, 0)

    // 获取标签列和特征列
    // 使用StringIndexer将标签列转换为索引形式,以便后续的机器学习算法能够处理
    val labelIndex = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").fit(dataFrame)

    // 使用VectorIndexer对特征列进行索引,这有助于提高机器学习模型的效率和效果
    val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(dataFrame)

    // 创建Logistic回归模型实例,设置标签列和特征列,以及模型的训练参数
    // 最大迭代次数设为100,正则化参数设为0.3,ElasticNet参数设为0.8,这样的设置旨在平衡偏差和方差,避免过拟合
    val logisticRegression = new LogisticRegression().setLabelCol("labelIndex").setFeaturesCol("indexedFeatures")
      .setMaxIter(100).setRegParam(0.3).setElasticNetParam(0.8)

    println("logistricRegression parameters:\n" + logisticRegression.explainParams() + "\n")

    // 设置indexToString转换器
    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel")
      .setLabels(labelIndex.labels)

    // 设置逻辑回归流水线
    val lrpiple = new Pipeline().setStages(Array(labelIndex, featureIndexer, logisticRegression, labelConverter))

    // 划分训练集和测试集,利用随机种子
    val Array(trainingData, testData) = dataFrame.randomSplit(Array(0.7, 0.3), 1234L)
    trainingData.show(3, 0)
    testData.show(3, 0)

    // 利用流水线训练模型
    val model = lrpiple.fit(trainingData)
    val predictions = model.transform(testData)
    // 显示预测结果
    predictions.select("predictedLabel", "label", "features", "probability").show(5, 0)

    // 评估模型
    // 创建一个MulticlassClassificationEvaluator实例用于评估分类模型的准确性
    // 设置评估器的标签列名为"labelIndex",预测列名为"prediction",并使用"accuracy"作为评估指标
    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("labelIndex").setPredictionCol("prediction")
      .setMetricName("accuracy")

    // 使用评估器计算预测结果的准确性
    val accuracy = evaluator.evaluate(predictions)

    // 打印测试错误率,即1减去准确率
    println("Test Error = " + (1.0 - accuracy))

    // 通过流水线获取模型参数
    val lrModel = model.stages(2).asInstanceOf[LogisticRegressionModel]
    println("Learned classification logistic regression model:\n" + lrModel.summary.totalIterations)
    println("Coefficients: \n" + lrModel.coefficientMatrix)
    println("Intercepts: \n" + lrModel.interceptVector)
    println("logistic regression model num of Classes" + lrModel.numClasses)
    println("logistic regression model num of features" + lrModel.numFeatures)

    // 保存模型
    lrModel.write.overwrite().save(modelpath)

    spark.stop()
  }

}

代码执行效果如下:

相关推荐
使者大牙5 分钟前
【LLM学习笔记】第四篇:模型压缩方法——量化、剪枝、蒸馏、分解
人工智能·深度学习·算法·机器学习
ydl112837 分钟前
机器学习周志华学习笔记-第7章<贝叶斯分类器>
笔记·学习·机器学习
魅美1 小时前
大数据技术之SparkCore
大数据·spark
研一计算机小白一枚2 小时前
Which Tasks Should Be Learned Together in Multi-task Learning? 译文
人工智能·python·学习·机器学习
xianghan收藏册2 小时前
基于lora的llama2二次预训练
人工智能·深度学习·机器学习·chatgpt·transformer
不去幼儿园2 小时前
【RL Base】多级反馈队列(MFQ)算法
人工智能·python·算法·机器学习·强化学习
Alone--阮泽宇3 小时前
【机器学习】—PCA(主成分分析)
人工智能·机器学习
Easy数模3 小时前
因果机器学习EconML | 客户细分案例——基于机器学习的异质性处理效果估计
人工智能·机器学习
我怎么又饿了呀3 小时前
DataWhale—PumpkinBook(TASK05决策树)
算法·决策树·机器学习