spark 3.4.4 利用Spark ML中的交叉验证、管道流实现鸢尾花分类预测案例选取最优模型

前情回顾

前面的案例中,介绍了怎么基于管道流实现啊鸢尾花案例,利用逻辑斯蒂回归模型预测。详细内容步骤可以参照相应的博客内容

本案例内容

在 Spark 中使用交叉验证结合逻辑回归(Logistic Regression)以及管道流(Pipeline)实现鸢尾花案例最优模型选择的详细介绍和示例代码(以 Scala 语言为例):

知识点介绍

1. 管道流(Pipeline)概述

在 Spark ML 中,管道流是一种用于将多个数据处理和机器学习阶段组合在一起的机制,使得整个机器学习流程更加清晰、易于管理和复用。一个典型的管道通常包含多个按顺序执行的阶段(Stage),比如数据预处理阶段(例如标准化数据、特征编码等)、特征选择阶段以及最终的模型训练阶段等。

例如,假设我们要处理一个包含文本特征和数值特征的数据集来进行分类任务。可能首先需要对文本特征进行词向量转换(一种特征工程操作),然后对所有特征进行标准化,最后使用分类模型进行训练。这些步骤就可以通过管道流按顺序组织起来,方便地进行整体操作。

2. 逻辑回归结合管道流及交叉验证的优势

  • 整合流程:通过管道流,可以把逻辑回归模型之前需要进行的一系列数据预处理步骤(如特征编码、归一化等)与逻辑回归模型训练本身整合在一个连贯的流程里。这样在进行交叉验证时,每次数据划分后都能自动按照统一的流程依次执行各阶段操作,确保数据处理和模型训练的一致性,避免了手动分别处理不同数据子集时可能出现的错误或不一致情况。
  • 高效与可复用:便于在不同数据集或者不同超参数调整场景下复用整个流程。对于逻辑回归的交叉验证来说,只需改变交叉验证中的参数网格(比如尝试不同的逻辑回归超参数组合)等设置,就能方便地重新运行整个包含数据预处理、模型训练和评估的流程,高效地找到最佳模型配置。

3. 交叉验证

交叉验证是一种统计方法,用于评估机器学习模型的性能。它通过将数据集分割成若干个互斥的子集,然后多次训练和测试模型来提高评估的可靠性和准确性。最常用的交叉验证方法是K折交叉验证(K-Fold Cross Validation),下面是其基本步骤:

K折交叉验证的基本步骤
准备数据:

收集并清洗数据集,确保数据质量。
数据集划分:

将整个数据集随机分成K个大小相等(或尽可能相等)的互斥子集,也称为"折"(fold)。K的常见值有5或10。
训练与验证:

对于每个不同的K折:

将其中的一个子集作为验证集。

剩余的K-1个子集合并起来作为训练集。

使用训练集来训练模型。

利用验证集评估模型的表现,记录评估指标(例如准确率、F1分数、均方误差等)。
汇总评估:

计算所有K次验证过程中的评估指标的平均值,以此作为模型性能的估计。
模型选择:

根据交叉验证的结果,选择最佳的模型参数或者算法。
交叉验证的优点

减少偏差:通过多次训练和测试,减少了由于特定训练/测试集的选择导致的模型性能评估的偏差。

充分利用数据:几乎所有的数据都被用来训练和测试,从而提高了模型评估的可靠性。

参数调整:有助于在不同参数设置之间做出更准确的比较,选择最优的模型配置。
注意事项

在进行K折交叉验证时,重要的是确保每个折的数据分布尽量保持一致,这通常意味着要在分层的基础上进行划分,尤其是在处理分类问题时。

如果数据集非常大,可以考虑使用留一法(Leave-One-Out, LOO)或重复K折交叉验证来增加评估的稳定性。

4. 网格搜索

网格搜索(Grid Search)是一种用于超参数优化的技术,在机器学习中广泛应用于模型选择和调优。它通过在指定的参数范围内系统地遍历所有可能的参数组合,来寻找最优的模型参数设置。以下是网格搜索的详细解释:

工作原理

  1. 定义参数网格 :首先,为模型的每个超参数定义一个搜索范围,这些范围的组合构成了参数网格。例如,对于支持向量机(SVM)模型,你可能想要调整C(正则化参数)和gamma(核函数的系数)的值。

  2. 模型实例化:选择你要调优的模型类型,并实例化一个模型对象。

  3. 交叉验证:对于参数网格中的每一组参数,使用交叉验证(如k折交叉验证)来评估模型的性能。交叉验证通过将数据集划分为多个子集,并在不同的子集上训练和验证模型,来提供更可靠的性能估计。

  4. 性能评估:根据交叉验证的结果,计算每个参数组合对应的性能指标(如准确率、召回率、F1分数等)。

  5. 选择最优参数:比较所有参数组合的性能指标,选择性能最好的参数组合作为最优参数。

优点

  • 系统性:网格搜索通过遍历所有可能的参数组合,确保不会错过任何潜在的最优解。

  • 可靠性:使用交叉验证来评估模型性能,减少了过拟合的风险,并提供了更可靠的性能估计。

缺点

  • 计算成本高:当参数网格很大时,网格搜索的计算成本可能非常高,因为它需要训练多个模型并评估它们的性能。

  • 可能陷入局部最优:尽管网格搜索是系统性的,但它仍然受限于定义的参数范围。如果最优参数不在定义的范围内,网格搜索将无法找到它。

案例代码

Scala 复制代码
package cn.lh.pblh123.spark2024.theorycourse.charpter9

import cn.lh.pblh123.spark2024.theorycourse.charpter9.MLGMM.checkPathExistStatus
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

object CrossValidatorLR {

  def main(args: Array[String]): Unit = {
    // 检查命令行参数数量是否正确,确保程序正确使用
    if (args.length != 3) {
      System.err.println("Usage: <murl> <inputfile> <modelpath>")
      System.exit(1)
    }
    // 从命令行参数中提取变量
    val murl = args(0)
    val inputfile = args(1)
    val modelpath = args(2)
    // 创建SparkSession实例,用于数据处理和模型训练
    val spark = SparkSession.builder().appName(s"${this.getClass.getName}").master(murl).getOrCreate()

    // 加载数据
    val df = spark.read.option("header", true)
      .csv(inputfile)

    // 显示数据样例和 schema
    df.show(3, false)
    df.printSchema()

    // 数据预处理,转换特征和标签
    val dfDouble = df.select(col("sepal_length").cast(DoubleType), col("sepal_width").cast(DoubleType),
      col("petal_length").cast(DoubleType), col("petal_width").cast(DoubleType),
      col("species").alias("label")
    )
    dfDouble.printSchema()

    // 特征处理
    // 创建一个VectorAssembler实例,用于将多列特征组合成单一的特征向量
    val assembler = new VectorAssembler().setInputCols(
      Array("sepal_length", "sepal_width", "petal_length", "petal_width")
    ).setOutputCol("features")

    // 使用VectorAssembler转换原始DataFrame,生成一个新的DataFrame,其中包含特征向量和标签列
    val dataFrame = assembler.transform(dfDouble).select("features", "label")
    // 显示转换后的DataFrame的前3行数据,以验证转换结果
    dataFrame.show(3, 0)

    // 获取标签列和特征列
    // 使用StringIndexer将标签列转换为索引形式,以便后续的机器学习算法能够处理
    val labelIndex = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").fit(dataFrame)

    // 使用VectorIndexer对特征列进行索引,这有助于提高机器学习模型的效率和效果
    val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(dataFrame)

    // 创建Logistic回归模型实例,设置标签列和特征列,以及模型的训练参数
    // 最大迭代次数设为100
    val logisticRegression = new LogisticRegression().setLabelCol("labelIndex").setFeaturesCol("indexedFeatures")
      .setMaxIter(100)

    // 打印Logistic回归模型的参数,以便调试和优化
    println("logistricRegression parameters:\n" + logisticRegression.explainParams() + "\n")

    // 设置indexToString转换器
    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel")
      .setLabels(labelIndex.labels)

    // 设置逻辑回归流水线
    val lrpiple = new Pipeline().setStages(Array(labelIndex, featureIndexer, logisticRegression, labelConverter))

    // 划分训练集和测试集,利用随机种子
    val Array(trainingData, testData) = dataFrame.randomSplit(Array(0.7, 0.3), 1234L)
    trainingData.show(3, 0)
    testData.show(3, 0)

    // 创建参数网格,用于交叉验证
    val paraGrid = new ParamGridBuilder()
      .addGrid(logisticRegression.elasticNetParam, Array(0.2, 0.8))
      .addGrid(logisticRegression.regParam, Array(0.01, 0.1, 0.3, 0.5))
      .build()

    // 打印参数网格,以便调试和优化
    println("paraGrid parameters:\n" + paraGrid + "\n")

    // 创建评估器,用于评估模型性能
    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("labelIndex").setPredictionCol("prediction")

    // 创建交叉验证器,用于模型选择
    val crossValidator = new CrossValidator()
      .setEstimator(lrpiple)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paraGrid)
      .setNumFolds(3)
      .setSeed(1234L)

    // 训练最佳模型
    val CVModel = crossValidator.fit(trainingData)

    // 使用最佳模型进行预测
    val predictions = CVModel.transform(testData)

    // 显示预测结果
    predictions.select("predictedLabel", "label", "features", "probability").show(5, 0)

    // 评估模型性能
    val accuracy = evaluator.evaluate(predictions)
    println("Accuracy = " + accuracy)
    println("Test Error = " + (1.0 - accuracy))

    // 从交叉验证模型中提取最佳的管道模型
    val bestPipleModel = CVModel.bestModel.asInstanceOf[PipelineModel]

    // 从最佳管道模型中提取逻辑回归模型,该模型位于管道的第三个阶段
    val lrModel = bestPipleModel.stages(2).asInstanceOf[LogisticRegressionModel]

    // 打印最佳模型的参数和统计信息
    println("Best Model Parameters:\n" + lrModel.explainParams())
    println("Best Model Coefficients:\n" + lrModel.coefficientMatrix)
    println("Best Model Intercept:\n" + lrModel.interceptVector)
    println("Best Model Summary:\n" + lrModel.summary)
    println("Best Model Summary Accuracy:\n" + lrModel.summary.accuracy)
    println("Best Model Summary False Positive Rate:\n" + lrModel.summary.falsePositiveRateByLabel)
    println("Best Model Summary Precision:\n" + lrModel.summary.precisionByLabel)
    println("Best Model Summary Recall:\n" + lrModel.summary.recallByLabel)
    println("Best Model Summary FMeasure:\n" + lrModel.summary.fMeasureByLabel)

    // 计算AUC
    val binaryEvaluator = new BinaryClassificationEvaluator()
     .setLabelCol("labelIndex")
     .setRawPredictionCol("prediction")
     .setMetricName("areaUnderROC")

    val auc = binaryEvaluator.evaluate(predictions)
    println("AUC = " + auc)

    // 打印Logistic回归模型的参数
    lrModel.explainParams()

    // 检查模型路径是否存在,并保存最佳模型
    checkPathExistStatus(modelpath)
    lrModel.save(modelpath)

    // 停止SparkSession
    spark.stop()
  }

}

运行结果

bash 复制代码
2024-11-25 16:32:54,035 WARN  [main] util.NativeCodeLoader (NativeCodeLoader.java:60) - Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|5.1         |3.5        |1.4         |0.2        |setosa |
|4.9         |3.0        |1.4         |0.2        |setosa |
|4.7         |3.2        |1.3         |0.2        |setosa |
+------------+-----------+------------+-----------+-------+
only showing top 3 rows

root
 |-- sepal_length: string (nullable = true)
 |-- sepal_width: string (nullable = true)
 |-- petal_length: string (nullable = true)
 |-- petal_width: string (nullable = true)
 |-- species: string (nullable = true)

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- label: string (nullable = true)

+-----------------+------+
|features         |label |
+-----------------+------+
|[5.1,3.5,1.4,0.2]|setosa|
|[4.9,3.0,1.4,0.2]|setosa|
|[4.7,3.2,1.3,0.2]|setosa|
+-----------------+------+
only showing top 3 rows

logistricRegression parameters:
aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label, current: labelIndex)
lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
maxBlockSizeInMB: Maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0. (default: 0.0)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 100)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
regParam: regularization parameter (>= 0) (default: 0.0)
standardization: whether to standardize the training features before fitting the model (default: true)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)

+-----------------+------+
|features         |label |
+-----------------+------+
|[4.4,3.0,1.3,0.2]|setosa|
|[4.4,3.2,1.3,0.2]|setosa|
|[4.6,3.1,1.5,0.2]|setosa|
+-----------------+------+
only showing top 3 rows

+-----------------+------+
|features         |label |
+-----------------+------+
|[4.3,3.0,1.1,0.1]|setosa|
|[4.4,2.9,1.4,0.2]|setosa|
|[4.5,2.3,1.3,0.3]|setosa|
+-----------------+------+
only showing top 3 rows

paraGrid parameters:
[Lorg.apache.spark.ml.param.ParamMap;@7cb32ca5

2024-11-25 16:33:05,441 WARN  [main] blas.InstanceBuilder (InstanceBuilder.java:52) - Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
+-----------------+------+--------------+--------------------------------------------------------------+
|features         |label |predictedLabel|probability                                                   |
+-----------------+------+--------------+--------------------------------------------------------------+
|[4.3,3.0,1.1,0.1]|setosa|setosa        |[0.9757051466049288,0.024294044215834764,8.091792365185294E-7]|
|[4.4,2.9,1.4,0.2]|setosa|setosa        |[0.942414263844343,0.057581913147947784,3.823007709168387E-6] |
|[4.5,2.3,1.3,0.3]|setosa|setosa        |[0.6006671038922778,0.39930355820688734,2.9337900834715793E-5]|
|[4.9,3.6,1.4,0.1]|setosa|setosa        |[0.9914556600750227,0.008543921422201815,4.185027753908102E-7]|
|[5.0,3.0,1.6,0.2]|setosa|setosa        |[0.8901960393397007,0.10979575995270213,8.200707597259611E-6] |
+-----------------+------+--------------+--------------------------------------------------------------+
only showing top 5 rows

+--------------+------+-----------------+--------------------------------------------------------------+
|predictedLabel|label |features         |probability                                                   |
+--------------+------+-----------------+--------------------------------------------------------------+
|setosa        |setosa|[4.3,3.0,1.1,0.1]|[0.9757051466049288,0.024294044215834764,8.091792365185294E-7]|
|setosa        |setosa|[4.4,2.9,1.4,0.2]|[0.942414263844343,0.057581913147947784,3.823007709168387E-6] |
|setosa        |setosa|[4.5,2.3,1.3,0.3]|[0.6006671038922778,0.39930355820688734,2.9337900834715793E-5]|
|setosa        |setosa|[4.9,3.6,1.4,0.1]|[0.9914556600750227,0.008543921422201815,4.185027753908102E-7]|
|setosa        |setosa|[5.0,3.0,1.6,0.2]|[0.8901960393397007,0.10979575995270213,8.200707597259611E-6] |
+--------------+------+-----------------+--------------------------------------------------------------+
only showing top 5 rows

Accuracy = 0.9607843137254901
Test Error = 0.03921568627450989
Best Model Parameters:
aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.2)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label, current: labelIndex)
lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
maxBlockSizeInMB: Maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0. (default: 0.0)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 100)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
standardization: whether to standardize the training features before fitting the model (default: true)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)
Best Model Coefficients:
-1.0160525354315355  2.5595568393649253   -0.9181804522594151     -1.9942533505022708  
0.46823886626877925  -1.1501375486462728  -0.0041242496510105284  -0.9534800212737365  
0.25465889870796954  -0.8280356114881423  1.0644125342983854      3.2745417344773724   
Best Model Intercept:
[4.022500795582984,3.9666625659384906,-7.9891633615214745]
Best Model Summary:
org.apache.spark.ml.classification.LogisticRegressionTrainingSummaryImpl@28377ea2
Best Model Summary Accuracy:
0.9595959595959596
Best Model Summary False Positive Rate:
[D@61c9f15d
Best Model Summary Precision:
[D@6a04b1bd
Best Model Summary Recall:
[D@329033ef
Best Model Summary FMeasure:
[D@13461194
AUC = 1.0
相关推荐
郝学胜-神的一滴9 小时前
反向传播:神经网络的「灵魂」修炼法则
人工智能·pytorch·深度学习·神经网络·机器学习·数据挖掘
Fleshy数模15 小时前
基于 Qwen2.5-1.5B-Instruct 实现多轮对话与文本分类实践
人工智能·分类·大模型
Betelgeuse7617 小时前
从爬虫脚本到 AI 智能体:一次数据挖掘实践的完整进化
人工智能·爬虫·数据挖掘
计算机毕业编程指导师1 天前
【计算机毕设推荐】Python+Hadoop+Spark共享单车数据可视化分析系统 毕业设计 选题推荐 毕设选题 数据分析 机器学习 数据挖掘
大数据·hadoop·python·计算机·数据挖掘·spark·课程设计
计算机毕业编程指导师1 天前
【计算机毕设】基于Hadoop的共享单车订单数据分析系统+Python+Django全栈开发 毕业设计 选题推荐 毕设选题 数据分析 机器学习 数据挖掘
大数据·hadoop·python·计算机·数据挖掘·spark·django
夜郎king2 天前
水力模型 INP 文件如何导入 QGIS?超详细实操教程
人工智能·数据挖掘·水力模型·qgis水力制图
计算机毕业编程指导师2 天前
基于Spark的性格行为数据分析与可视化系统源码 毕业设计 选题推荐 毕设选题 数据分析 机器学习 数据挖掘
大数据·python·数据挖掘·数据分析·spark·毕业设计·性格行为
QDYOKR1682 天前
OKR管理系统怎么选?2026主流OKR工具深度解析
大数据·人工智能·信息可视化·数据挖掘·数据分析
Dfreedom.2 天前
【实战篇】分类任务全流程演示——决策树
人工智能·算法·决策树·机器学习·分类
2601_954971132 天前
经济学专业考CDA数据分析师证书值不值?对求职帮助到底有多大
数据挖掘