【IDEA + Spark 3.4.1 + sbt 1.9.3 + Spark MLlib 构建鸢尾花决策树分类预测模型】

决策树进行鸢尾花分类的案例

背景说明:

通过IDEA + Spark 3.4.1 + sbt 1.9.3 + Spark MLlib 构建鸢尾花决策树分类预测模型,这是一个分类模型案例,通过该案例,可以快速了解Spark MLlib分类预测模型的使用方法。

依赖

sbt 复制代码
ThisBuild / version := "0.1.0-SNAPSHOT"  
  
ThisBuild / scalaVersion := "2.13.11"  
  
lazy val root = (project in file("."))  
  .settings(  
    name := "SparkLearning",  
    idePackagePrefix := Some("cn.lh.spark"),  
    libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.4.1",  
    libraryDependencies += "org.apache.spark" %% "spark-core" % "3.4.1",  
    libraryDependencies += "org.apache.hadoop" % "hadoop-auth" % "3.3.6",     libraryDependencies += "org.apache.spark" %% "spark-streaming" % "3.4.1",  
    libraryDependencies += "org.apache.spark" %% "spark-streaming-kafka-0-10" % "3.4.1",  
    libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.4.1",  
    libraryDependencies += "mysql" % "mysql-connector-java" % "8.0.30"  
)

完整代码

scala 复制代码
package cn.lh.spark  
  
import org.apache.spark.ml.{Pipeline, PipelineModel}  
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}  
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator  
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}  
import org.apache.spark.ml.linalg.Vectors  
import org.apache.spark.rdd.RDD  
import org.apache.spark.sql.{DataFrame, SparkSession}  
  
  
/**  
 * 决策树分类器,实现鸢尾花分类  
 */  
  
//case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)  // MLlibLogisticRegression 中存在该样例类,这里不用写,一个包里不存在这个样例类时需要写
  
object MLlibDecisionTreeClassifier {  
  
  def main(args: Array[String]): Unit = {  
  
    val spark: SparkSession = SparkSession.builder().master("local[2]")  
      .appName("Spark MLlib DecisionTreeClassifier").getOrCreate()  
  
    val irisRDD: RDD[Iris] = spark.sparkContext.textFile("F:\\niit\\2023\\2023_2\\Spark\\codes\\data\\iris.txt")  
      .map(_.split(",")).map(p =>  
      Iris(Vectors.dense(p(0).toDouble, p(1).toDouble, p(2).toDouble, p(3).toDouble), p(4).toString()))  
  
    import spark.implicits._  
    val data: DataFrame = irisRDD.toDF()  
    data.show()  
  
    data.createOrReplaceTempView("iris")  
    val df: DataFrame = spark.sql("select * from iris")  
  
    println("鸢尾花原始数据如下:")  
    df.map(t => t(1)+":"+t(0)).collect().foreach(println)  
  
    //    处理特征和标签,以及数据分组  
    val labelIndexer: StringIndexerModel = new StringIndexer().setInputCol("label").setOutputCol(  
      "indexedLabel").fit(df)  
  
    val featureIndexer: VectorIndexerModel = new VectorIndexer().setInputCol("features")  
      .setOutputCol("indexedFeatures").setMaxCategories(4).fit(df)  
    //这里我们设置一个labelConverter,目的是把预测的类别重新转化成字符型的  
    val labelConverter: IndexToString = new IndexToString().setInputCol("prediction")  
      .setOutputCol("predictedLabel").setLabels(labelIndexer.labels)  
  
    //接下来,我们把数据集随机分成训练集和测试集,其中训练集占70%。  
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))  
  
    val dtClassifier: DecisionTreeClassifier = new DecisionTreeClassifier()  
      .setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")  
  
    //在pipeline中进行设置  
    val pipelinedClassifier: Pipeline = new Pipeline()  
      .setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))  
    //训练决策树模型  
    val modelClassifier: PipelineModel = pipelinedClassifier.fit(trainingData)  
    //进行预测  
    val predictionsClassifier: DataFrame = modelClassifier.transform(testData)  
    predictionsClassifier.select("predictedLabel", "label", "features").show(5)  
  
    //    评估决策树分类模型  
    val evaluatorClassifier: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()  
      .setLabelCol("indexedLabel")  
      .setPredictionCol("prediction").setMetricName("accuracy")  
    val accuracy: Double = evaluatorClassifier.evaluate(predictionsClassifier)  
    println("Test Error = " + (1.0 - accuracy))  
  
    val treeModelClassifier: DecisionTreeClassificationModel = modelClassifier.stages(2)  
      .asInstanceOf[DecisionTreeClassificationModel]  
  
    println("Learned classification tree model:\n" + treeModelClassifier.toDebugString)  
  
  
    spark.stop()  
  }  
  
}
相关推荐
就叫飞六吧2 小时前
找不到或无法加载主类 @C:\***\Local\Temp\idea_arg_file...
java·ide·intellij-idea
Mr.朱鹏5 小时前
RocketMQ可视化监控与管理
java·spring boot·spring·spring cloud·maven·intellij-idea·rocketmq
q***16087 小时前
解决 IntelliJ IDEA 中 Tomcat 日志乱码问题的详细指南
java·tomcat·intellij-idea
AI_56787 小时前
从“插件装一堆”到“效率翻一倍”——IntelliJ IDEA的插件化开发革命
java·ide·intellij-idea
B站计算机毕业设计之家7 小时前
机器学习:python智能电商推荐平台 大数据 spark(Django后端+Vue3前端+协同过滤 毕业设计/实战 源码)✅
大数据·python·spark·django·推荐算法·电商
mn_kw7 小时前
Spark Shuffle 深度解析与参数详解
大数据·分布式·spark
红队it8 小时前
【Spark+Hive】基于Spark大数据旅游景点数据分析可视化推荐系统(完整系统源码+数据库+开发笔记+详细部署教程+虚拟机分布式启动教程)✅
大数据·python·算法·数据分析·spark·django·echarts
N***p3658 小时前
IDEA搭建SpringBoot,MyBatis,Mysql工程项目
spring boot·intellij-idea·mybatis
mn_kw9 小时前
Hive On Spark 统计信息收集深度解析
hive·hadoop·spark
mn_kw9 小时前
Spark SQL CBO(基于成本的优化器)参数深度解析
前端·sql·spark