Spark2.x 入门:决策树分类器

一、方法简介 ​

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

决策树学习通常包括3个步骤:特征选择、决策树的生成和决策树的剪枝。

示例代码

我们以iris数据集(iris)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。决策树可以用于分类和回归,接下来我们将在代码中分别进行介绍。

1. 导入需要的包:

scala 复制代码
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector,Vectors}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

2. 读取数据,简要分析:

导入spark.implicits._,使其支持把一个RDD隐式转换为一个DataFrame。我们用case class定义一个schema:Iris,Iris就是我们需要的数据的结构;然后读取文本文件,第一个map把每行的数据用","隔开,比如在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类;我们这里把特征存储在Vector中,创建一个Iris模式的RDD,然后转化成dataframe;然后把刚刚得到的数据注册成一个表iris,注册成这个表之后,我们就可以通过sql语句进行数据查询;选出我们需要的数据后,我们可以把结果打印出来查看一下数据。

scala 复制代码
scala> import spark.implicits._
import spark.implicits._

scala> case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
defined class Iris

scala> val data = spark.read.textFile("file:///root/data/iris.txt").map(_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble,p(3).toDouble),p(4).toString())).toDF()

scala> data.createOrReplaceTempView("iris")

scala> val df = spark.sql("select * from iris")
df: org.apache.spark.sql.DataFrame = [features: vector, label: string]

scala> df.map(t => t(1)+":"+t(0)).collect().foreach(println)
Iris-setosa:[5.1,3.5,1.4,0.2]
Iris-setosa:[4.9,3.0,1.4,0.2]
Iris-setosa:[4.7,3.2,1.3,0.2]
Iris-setosa:[4.6,3.1,1.5,0.2]
Iris-setosa:[5.0,3.6,1.4,0.2]
Iris-setosa:[5.4,3.9,1.7,0.4]
Iris-setosa:[4.6,3.4,1.4,0.3]
......

3. 进一步处理特征和标签,以及数据分组:

scala 复制代码
//分别获取标签列和特征列,进行索引,并进行了重命名。
scala> val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df) 
labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_6c3c138d61bf

scala> val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(df)
featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_08c01d7fd953

//这里我们设置一个labelConverter,目的是把预测的类别重新转化成字符型的。
scala> val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_11ce3220e43a


//接下来,我们把数据集随机分成训练集和测试集,其中训练集占70%。
scala> val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3))
trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]

4. 构建决策树分类模型:

scala 复制代码
//导入所需要的包
scala> import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassificationModel

scala> import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassifier

scala> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator


//训练决策树模型,这里我们可以通过setter的方法来设置决策树的参数,也可以用ParamMap来设置(具体的可以查看spark mllib的官网)。具体的可以设置的参数可以通过explainParams()来获取。
scala> val dtClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtClassifier: org.apache.spark.ml.classification.DecisionTreeClassifier = dtc_7948c1724433

//在pipeline中进行设置
scala> val pipelinedClassifier = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))
pipelinedClassifier: org.apache.spark.ml.Pipeline = pipeline_b5a49e693b35

//训练决策树模型
scala> val modelClassifier = pipelinedClassifier.fit(trainingData)
modelClassifier: org.apache.spark.ml.PipelineModel = pipeline_b5a49e693b35

//进行预测
scala> val predictionsClassifier = modelClassifier.transform(testData)
predictionsClassifier: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]

//查看部分预测的结果
scala> predictionsClassifier.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

5. 评估决策树分类模型:

scala 复制代码
scala> val evaluatorClassifier = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
evaluatorClassifier: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_8059f30a8634

scala> val accuracy = evaluatorClassifier.evaluate(predictionsClassifier)
accuracy: Double = 0.94

scala> println("Test Error = " + (1.0 - accuracy))
Test Error = 0.06000000000000005

scala> val treeModelClassifier = modelClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
treeModelClassifier: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodes

scala> println("Learned classification tree model:\n" + treeModelClassifier.toDebugString)
Learned classification tree model:
DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodes
  If (feature 2 <= 1.9)
   Predict: 0.0
  Else (feature 2 > 1.9)
   If (feature 3 <= 1.6)
    If (feature 2 <= 4.9)
     Predict: 1.0
    Else (feature 2 > 4.9)
     If (feature 0 <= 6.0)
      Predict: 1.0
     Else (feature 0 > 6.0)
      Predict: 2.0
   Else (feature 3 > 1.6)
    If (feature 2 <= 4.8)
     If (feature 1 <= 2.8)
      Predict: 2.0
     Else (feature 1 > 2.8)
      Predict: 1.0
    Else (feature 2 > 4.8)
     Predict: 2.0

从上述结果可以看到模型的预测准确率为 0.94 以及训练的决策树模型结构。

6. 构建决策树回归模型:

scala 复制代码
//导入所需要的包
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.regression.DecisionTreeRegressor

//训练决策树模型
scala> val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtRegressor: org.apache.spark.ml.regression.DecisionTreeRegressor = dtr_e98e9ef10e22

//在pipeline中进行设置
scala> val pipelineRegressor = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtRegressor, labelConverter))
pipelineRegressor: org.apache.spark.ml.Pipeline = pipeline_9f0fb530c801

//训练决策树模型
scala> val modelRegressor = pipelineRegressor.fit(trainingData)
modelRegressor: org.apache.spark.ml.PipelineModel = pipeline_9f0fb530c801

//进行预测
scala> val predictionsRegressor = modelRegressor.transform(testData)
predictionsRegressor: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 4 more fields]

//查看部分预测结果
scala> predictionsRegressor.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

7. 评估决策树回归模型:

scala 复制代码
scala> val evaluatorRegressor = new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("rmse")
evaluatorRegressor: org.apache.spark.ml.evaluation.RegressionEvaluator = regEval_162861380a26

scala> val rmse = evaluatorRegressor.evaluate(predictionsRegressor)
rmse: Double = 0.2449489742783178

scala> println("Root Mean Squared Error (RMSE) on test data = " + rmse)
Root Mean Squared Error (RMSE) on test data = 0.2449489742783178

scala> val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[DecisionTreeRegressionModel]
treeModelRegressor: org.apache.spark.ml.regression.DecisionTreeRegressionModel = DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodes

scala> println("Learned regression tree model:\n" + treeModelRegressor.toDebugString)
Learned regression tree model:
DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodes
  If (feature 2 <= 1.9)
   Predict: 0.0
  Else (feature 2 > 1.9)
   If (feature 3 <= 1.6)
    If (feature 2 <= 4.9)
     Predict: 1.0
    Else (feature 2 > 4.9)
     If (feature 0 <= 6.0)
      Predict: 1.0
     Else (feature 0 > 6.0)
      Predict: 2.0
   Else (feature 3 > 1.6)
    If (feature 2 <= 4.8)
     If (feature 1 <= 2.8)
      Predict: 2.0
     Else (feature 1 > 2.8)
      Predict: 1.0
    Else (feature 2 > 4.8)
     Predict: 2.0
相关推荐
xiaoshiguang33 小时前
LeetCode:222.完全二叉树节点的数量
算法·leetcode
爱吃西瓜的小菜鸡3 小时前
【C语言】判断回文
c语言·学习·算法
别NULL3 小时前
机试题——疯长的草
数据结构·c++·算法
TT哇3 小时前
*【每日一题 提高题】[蓝桥杯 2022 国 A] 选素数
java·算法·蓝桥杯
yuanbenshidiaos5 小时前
C++----------函数的调用机制
java·c++·算法
唐叔在学习5 小时前
【唐叔学算法】第21天:超越比较-计数排序、桶排序与基数排序的Java实践及性能剖析
数据结构·算法·排序算法
ALISHENGYA5 小时前
全国青少年信息学奥林匹克竞赛(信奥赛)备考实战之分支结构(switch语句)
数据结构·算法
chengooooooo5 小时前
代码随想录训练营第二十七天| 贪心理论基础 455.分发饼干 376. 摆动序列 53. 最大子序和
算法·leetcode·职场和发展
jackiendsc5 小时前
Java的垃圾回收机制介绍、工作原理、算法及分析调优
java·开发语言·算法