Method Overview
Logistic regression is a classic classification method in statistical learning and belongs to the family of log-linear models. The response variable of a logistic regression can be binary or multi-class.
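In the binomial case, the model links the probability of the positive class to a linear function of the features through the logistic (sigmoid) function, in the standard form

$$P(Y=1 \mid x) = \frac{1}{1 + e^{-(w \cdot x + b)}}$$

where $w$ is the coefficient vector and $b$ is the intercept; we will read both back off the trained model later in this section.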
Example Code
We use the iris dataset as our example. The dataset records measurements of iris flowers and contains 150 samples divided into 3 classes, 50 samples per class; each sample has 4 attributes. It is one of the most commonly used training and test sets in data mining and classification. Each sample's 4 attributes (sepal length, sepal width, petal length, petal width) are used together as its feature vector. spark.ml currently supports both binary and multi-class classification, so we analyze three cases: binomial logistic regression for a binary classification problem, multinomial logistic regression for a binary classification problem, and multinomial logistic regression for a multi-class problem.
Binomial Logistic Regression for a Binary Classification Problem
We first take the last two classes of the dataset and run a binary classification analysis with binomial logistic regression.
1. Import the required packages:
```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector,Vectors}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.{Pipeline,PipelineModel}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
import org.apache.spark.sql.functions
```
2. Read the data and briefly analyze it:
Import spark.implicits._ so that an RDD can be implicitly converted into a DataFrame. We define the schema with a case class, Iris, which describes the structure of the records we need. We then read the text file; the first map splits each line on ",". In our dataset each line splits into 5 parts, the first 4 being the flower's 4 features and the last its class. We store the features in a Vector, build an RDD of Iris objects, convert it into a DataFrame, and finally call show() to inspect some of the data.
```scala
scala> import spark.implicits._
import spark.implicits._
scala> case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
defined class Iris
// Read the file with spark.sparkContext.textFile()
scala> val data = spark.sparkContext.textFile("file:///root/data/iris.txt").map(_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble,p(3).toDouble),p(4).toString())).toDF()
data: org.apache.spark.sql.DataFrame = [features: vector, label: string]
// The file can also be read with spark.read.textFile()
scala> val dataTest = spark.read.textFile("file:///root/data/iris.txt").map(_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble,p(3).toDouble),p(4).toString())).toDF()
dataTest: org.apache.spark.sql.DataFrame = [features: vector, label: string]
scala> data.show()
+-----------------+-----------+
| features| label|
+-----------------+-----------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|
|[4.9,3.0,1.4,0.2]|Iris-setosa|
|[4.7,3.2,1.3,0.2]|Iris-setosa|
|[4.6,3.1,1.5,0.2]|Iris-setosa|
|[5.0,3.6,1.4,0.2]|Iris-setosa|
|[5.4,3.9,1.7,0.4]|Iris-setosa|
|[4.6,3.4,1.4,0.3]|Iris-setosa|
|[5.0,3.4,1.5,0.2]|Iris-setosa|
|[4.4,2.9,1.4,0.2]|Iris-setosa|
|[4.9,3.1,1.5,0.1]|Iris-setosa|
|[5.4,3.7,1.5,0.2]|Iris-setosa|
|[4.8,3.4,1.6,0.2]|Iris-setosa|
|[4.8,3.0,1.4,0.1]|Iris-setosa|
|[4.3,3.0,1.1,0.1]|Iris-setosa|
|[5.8,4.0,1.2,0.2]|Iris-setosa|
|[5.7,4.4,1.5,0.4]|Iris-setosa|
|[5.4,3.9,1.3,0.4]|Iris-setosa|
|[5.1,3.5,1.4,0.3]|Iris-setosa|
|[5.7,3.8,1.7,0.3]|Iris-setosa|
|[5.1,3.8,1.5,0.3]|Iris-setosa|
+-----------------+-----------+
only showing top 20 rows
```
Since this is a binary classification problem we do not need all three classes, so we select two of them. We first register the DataFrame we just built as a temporary view named iris; once it is registered we can query the data with SQL statements, and here we select every row that does not belong to the "Iris-setosa" class. Having selected the rows we need, we print the result to confirm that no "Iris-setosa" rows remain.
```scala
scala> data.createOrReplaceTempView("iris")
scala> val df = spark.sql("select * from iris where label != 'Iris-setosa'")
df: org.apache.spark.sql.DataFrame = [features: vector, label: string]
// The dataset has three classes, Iris-setosa, Iris-versicolor and Iris-virginica, with 50 samples each; after removing the Iris-setosa samples, 100 samples remain.
scala> df.count()
res4: Long = 100
scala> df.map(t => t(1)+":"+t(0)).collect().foreach(println)
Iris-versicolor:[7.0,3.2,4.7,1.4]
Iris-versicolor:[6.4,3.2,4.5,1.5]
Iris-versicolor:[6.9,3.1,4.9,1.5]
Iris-versicolor:[5.5,2.3,4.0,1.3]
Iris-versicolor:[6.5,2.8,4.6,1.5]
Iris-versicolor:[5.7,2.8,4.5,1.3]
Iris-versicolor:[6.3,3.3,4.7,1.6]
Iris-versicolor:[4.9,2.4,3.3,1.0]
Iris-versicolor:[6.6,2.9,4.6,1.3]
Iris-versicolor:[5.2,2.7,3.9,1.4]
Iris-versicolor:[5.0,2.0,3.5,1.0]
Iris-versicolor:[5.9,3.0,4.2,1.5]
Iris-versicolor:[6.0,2.2,4.0,1.0]
Iris-versicolor:[6.1,2.9,4.7,1.4]
Iris-versicolor:[5.6,2.9,3.6,1.3]
Iris-versicolor:[6.7,3.1,4.4,1.4]
Iris-versicolor:[5.6,3.0,4.5,1.5]
Iris-versicolor:[5.8,2.7,4.1,1.0]
Iris-versicolor:[6.2,2.2,4.5,1.5]
Iris-versicolor:[5.6,2.5,3.9,1.1]
Iris-versicolor:[5.9,3.2,4.8,1.8]
Iris-versicolor:[6.1,2.8,4.0,1.3]
Iris-versicolor:[6.3,2.5,4.9,1.5]
Iris-versicolor:[6.1,2.8,4.7,1.2]
Iris-versicolor:[6.4,2.9,4.3,1.3]
Iris-versicolor:[6.6,3.0,4.4,1.4]
Iris-versicolor:[6.8,2.8,4.8,1.4]
Iris-versicolor:[6.7,3.0,5.0,1.7]
Iris-versicolor:[6.0,2.9,4.5,1.5]
Iris-versicolor:[5.7,2.6,3.5,1.0]
Iris-versicolor:[5.5,2.4,3.8,1.1]
Iris-versicolor:[5.5,2.4,3.7,1.0]
Iris-versicolor:[5.8,2.7,3.9,1.2]
Iris-versicolor:[6.0,2.7,5.1,1.6]
Iris-versicolor:[5.4,3.0,4.5,1.5]
Iris-versicolor:[6.0,3.4,4.5,1.6]
Iris-versicolor:[6.7,3.1,4.7,1.5]
Iris-versicolor:[6.3,2.3,4.4,1.3]
Iris-versicolor:[5.6,3.0,4.1,1.3]
Iris-versicolor:[5.5,2.5,4.0,1.3]
Iris-versicolor:[5.5,2.6,4.4,1.2]
Iris-versicolor:[6.1,3.0,4.6,1.4]
Iris-versicolor:[5.8,2.6,4.0,1.2]
Iris-versicolor:[5.0,2.3,3.3,1.0]
Iris-versicolor:[5.6,2.7,4.2,1.3]
Iris-versicolor:[5.7,3.0,4.2,1.2]
Iris-versicolor:[5.7,2.9,4.2,1.3]
Iris-versicolor:[6.2,2.9,4.3,1.3]
Iris-versicolor:[5.1,2.5,3.0,1.1]
Iris-versicolor:[5.7,2.8,4.1,1.3]
Iris-virginica:[6.3,3.3,6.0,2.5]
Iris-virginica:[5.8,2.7,5.1,1.9]
Iris-virginica:[7.1,3.0,5.9,2.1]
Iris-virginica:[6.3,2.9,5.6,1.8]
Iris-virginica:[6.5,3.0,5.8,2.2]
Iris-virginica:[7.6,3.0,6.6,2.1]
Iris-virginica:[4.9,2.5,4.5,1.7]
Iris-virginica:[7.3,2.9,6.3,1.8]
Iris-virginica:[6.7,2.5,5.8,1.8]
Iris-virginica:[7.2,3.6,6.1,2.5]
Iris-virginica:[6.5,3.2,5.1,2.0]
Iris-virginica:[6.4,2.7,5.3,1.9]
Iris-virginica:[6.8,3.0,5.5,2.1]
Iris-virginica:[5.7,2.5,5.0,2.0]
Iris-virginica:[5.8,2.8,5.1,2.4]
Iris-virginica:[6.4,3.2,5.3,2.3]
Iris-virginica:[6.5,3.0,5.5,1.8]
Iris-virginica:[7.7,3.8,6.7,2.2]
Iris-virginica:[7.7,2.6,6.9,2.3]
Iris-virginica:[6.0,2.2,5.0,1.5]
Iris-virginica:[6.9,3.2,5.7,2.3]
Iris-virginica:[5.6,2.8,4.9,2.0]
Iris-virginica:[7.7,2.8,6.7,2.0]
Iris-virginica:[6.3,2.7,4.9,1.8]
Iris-virginica:[6.7,3.3,5.7,2.1]
Iris-virginica:[7.2,3.2,6.0,1.8]
Iris-virginica:[6.2,2.8,4.8,1.8]
Iris-virginica:[6.1,3.0,4.9,1.8]
Iris-virginica:[6.4,2.8,5.6,2.1]
Iris-virginica:[7.2,3.0,5.8,1.6]
Iris-virginica:[7.4,2.8,6.1,1.9]
Iris-virginica:[7.9,3.8,6.4,2.0]
Iris-virginica:[6.4,2.8,5.6,2.2]
Iris-virginica:[6.3,2.8,5.1,1.5]
Iris-virginica:[6.1,2.6,5.6,1.4]
Iris-virginica:[7.7,3.0,6.1,2.3]
Iris-virginica:[6.3,3.4,5.6,2.4]
Iris-virginica:[6.4,3.1,5.5,1.8]
Iris-virginica:[6.0,3.0,4.8,1.8]
Iris-virginica:[6.9,3.1,5.4,2.1]
Iris-virginica:[6.7,3.1,5.6,2.4]
Iris-virginica:[6.9,3.1,5.1,2.3]
Iris-virginica:[5.8,2.7,5.1,1.9]
Iris-virginica:[6.8,3.2,5.9,2.3]
Iris-virginica:[6.7,3.3,5.7,2.5]
Iris-virginica:[6.7,3.0,5.2,2.3]
Iris-virginica:[6.3,2.5,5.0,1.9]
Iris-virginica:[6.5,3.0,5.2,2.0]
Iris-virginica:[6.2,3.4,5.4,2.3]
Iris-virginica:[5.9,3.0,5.1,1.8]
```
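The same selection can be made without SQL. A minimal sketch using the DataFrame API (assuming the same session, with spark.implicits._ imported as above):

```scala
// Equivalent to the SQL query above: keep every row whose label is not
// "Iris-setosa", using the column DSL instead of a temporary view.
val dfApi = data.filter($"label" =!= "Iris-setosa")
dfApi.count()  // should also be 100
```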
3. Build the ML pipeline
Obtain the label column and the feature column, index each of them, and write the results to renamed output columns.
```scala
scala> val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_a43d5773da03
scala> val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(df)
featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_65fcd96b5dbc
```
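To see what the indexers learned, a small sketch (same session as above) that prints the label-to-index mapping produced by StringIndexer:

```scala
// labels is ordered by descending label frequency, so index 0.0 is the most
// frequent label; transform() shows the mapping applied to the data.
println(labelIndexer.labels.mkString(", "))
labelIndexer.transform(df).select("label", "indexedLabel").distinct().show()
```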
Next, we randomly split the dataset into a training set and a test set, with 70% of the data used for training.
```scala
scala> val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3))
trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
```
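Because randomSplit draws a random partition, repeated runs produce different splits (and therefore slightly different numbers below). If reproducibility matters, a seed can be passed, as in this sketch:

```scala
// Fixing the seed makes the 70/30 split deterministic across runs.
val Array(trainingDataFixed, testDataFixed) = df.randomSplit(Array(0.7, 0.3), seed = 1234L)
```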
Then we set the parameters of the logistic regression estimator. Here we use setter methods throughout; the parameters can also be supplied with a ParamMap (see the Spark MLlib documentation for details). We set the maximum number of iterations to 10, the regularization parameter to 0.3, and so on; the full list of configurable parameters, together with the values already set, can be obtained with explainParams().
```scala
scala> val lr = new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_d5c473ad2b66
scala> println("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
LogisticRegression parameters:
aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.8)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label, current: indexedLabel)
lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 10)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
regParam: regularization parameter (>= 0) (default: 0.0, current: 0.3)
standardization: whether to standardize the training features before fitting the model (default: true)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)
```
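As mentioned above, the same parameters can also be supplied through a ParamMap instead of setters; a sketch using the values chosen here:

```scala
import org.apache.spark.ml.param.ParamMap

// A ParamMap can be passed as a second argument to fit() to override
// whatever values the setters configured.
val paramMap = ParamMap(lr.maxIter -> 10, lr.regParam -> 0.3, lr.elasticNetParam -> 0.8)
```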
Here we set up a labelConverter whose purpose is to convert the predicted classes back into string labels.
```scala
scala> val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_9b53b11ef180
```
Build the pipeline, set its stages, and call fit() to train the model.
```scala
scala> val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))
lrPipeline: org.apache.spark.ml.Pipeline = pipeline_2c8791fd18dd
scala> val lrPipelineModel = lrPipeline.fit(trainingData)
17/09/07 10:05:03 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
17/09/07 10:05:03 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
lrPipelineModel: org.apache.spark.ml.PipelineModel = pipeline_2c8791fd18dd
```
A Pipeline is essentially an Estimator: when the pipeline calls fit(), it produces a PipelineModel, which is essentially a Transformer. The PipelineModel can then call transform() to make predictions and generate a new DataFrame; in other words, we use the trained model to validate against the test set.
```scala
scala> val lrPredictions = lrPipelineModel.transform(testData)
lrPredictions: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]
scala> lrPredictions.show()
+-----------------+---------------+------------+------------------+--------------------+--------------------+----------+---------------+
| features| label|indexedLabel| indexedFeatures| rawPrediction| probability|prediction| predictedLabel|
+-----------------+---------------+------------+------------------+--------------------+--------------------+----------+---------------+
|[5.6,3.0,4.5,1.5]|Iris-versicolor| 0.0| [5.6,9.0,4.5,5.0]|[0.00632886769305...|[0.50158221164203...| 0.0|Iris-versicolor|
|[5.8,2.7,4.1,1.0]|Iris-versicolor| 0.0| [5.8,6.0,4.1,0.0]|[0.46667028524646...|[0.61459535539331...| 0.0|Iris-versicolor|
|[6.0,2.2,4.0,1.0]|Iris-versicolor| 0.0| [6.0,1.0,4.0,0.0]|[0.47942122856815...|[0.61761119701304...| 0.0|Iris-versicolor|
|[6.6,3.0,4.4,1.4]|Iris-versicolor| 0.0| [6.6,9.0,4.4,4.0]|[0.15960167914787...|[0.53981593737467...| 0.0|Iris-versicolor|
|[6.7,3.0,5.0,1.7]|Iris-versicolor| 0.0| [6.7,9.0,5.0,7.0]|[-0.1025771337303...|[0.47437817884138...| 1.0| Iris-virginica|
|[4.9,2.5,4.5,1.7]| Iris-virginica| 1.0| [4.9,4.0,4.5,7.0]|[-0.2173356236255...|[0.44587895949586...| 1.0| Iris-virginica|
|[5.5,2.4,3.8,1.1]|Iris-versicolor| 0.0| [5.5,3.0,3.8,1.0]|[0.35802577541757...|[0.58856244608996...| 0.0|Iris-versicolor|
|[5.6,3.0,4.1,1.3]|Iris-versicolor| 0.0| [5.6,9.0,4.1,3.0]|[0.18536505738574...|[0.54620902741999...| 0.0|Iris-versicolor|
|[5.7,2.5,5.0,2.0]| Iris-virginica| 1.0|[5.7,4.0,5.0,10.0]|[-0.4348861348778...|[0.39296017326534...| 1.0| Iris-virginica|
|[5.7,2.9,4.2,1.3]|Iris-versicolor| 0.0| [5.7,8.0,4.2,3.0]|[0.19174052904658...|[0.54778881119272...| 0.0|Iris-versicolor|
|[5.8,2.6,4.0,1.2]|Iris-versicolor| 0.0| [5.8,5.0,4.0,2.0]|[0.28763409555377...|[0.57141682194390...| 0.0|Iris-versicolor|
|[5.8,2.7,3.9,1.2]|Iris-versicolor| 0.0| [5.8,6.0,3.9,2.0]|[0.28763409555377...|[0.57141682194390...| 0.0|Iris-versicolor|
|[5.8,2.7,5.1,1.9]| Iris-virginica| 1.0| [5.8,6.0,5.1,9.0]|[-0.3389925683706...|[0.41605421498910...| 1.0| Iris-virginica|
|[5.9,3.0,5.1,1.8]| Iris-virginica| 1.0| [5.9,9.0,5.1,8.0]|[-0.2430990018634...|[0.43952279234922...| 1.0| Iris-virginica|
|[6.0,2.7,5.1,1.6]|Iris-versicolor| 0.0| [6.0,6.0,5.1,6.0]|[-0.0576873405098...|[0.48558216299242...| 1.0| Iris-virginica|
|[6.0,2.9,4.5,1.5]|Iris-versicolor| 0.0| [6.0,8.0,4.5,5.0]|[0.03183075433644...|[0.50795701676004...| 0.0|Iris-versicolor|
|[6.1,2.6,5.6,1.4]| Iris-virginica| 1.0| [6.1,5.0,5.6,4.0]|[0.12772432084363...|[0.53188774193068...| 0.0|Iris-versicolor|
|[6.1,3.0,4.6,1.4]|Iris-versicolor| 0.0| [6.1,9.0,4.6,4.0]|[0.12772432084363...|[0.53188774193068...| 0.0|Iris-versicolor|
|[6.3,2.5,5.0,1.9]| Iris-virginica| 1.0| [6.3,4.0,5.0,9.0]|[-0.3071152100663...|[0.42381903909131...| 1.0| Iris-virginica|
|[6.3,2.9,5.6,1.8]| Iris-virginica| 1.0| [6.3,8.0,5.6,8.0]|[-0.2175971152200...|[0.44581435344357...| 1.0| Iris-virginica|
+-----------------+---------------+------------+------------------+--------------------+--------------------+----------+---------------+
only showing top 20 rows
```
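The output above already shows a few misclassified rows (for example [6.0,2.7,5.1,1.6]). As a sketch, such rows can be pulled out directly by comparing the true and predicted label columns:

```scala
// Rows where the predicted string label disagrees with the true label.
lrPredictions.filter($"label" =!= $"predictedLabel")
  .select("label", "predictedLabel", "probability")
  .show()
```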
Note: probability is a Vector. For example, in the first row above probability is [0.5015822116420338,0.4984177883579662], the predicted probability of each class; the two values sum to 1.
Finally we can print the predictions: select chooses the columns to output, collect gathers all rows, and foreach prints each row. The printed values are, in order, the row's true class and feature values, the predicted probability of each class, and the predicted class.
```scala
scala> lrPredictions.select("predictedLabel", "label", "features", "probability").collect().foreach { case Row(predictedLabel: String, label: String, features: Vector, prob: Vector) => println(s"($label, $features) --> prob=$prob, predicted Label=$predictedLabel")}
(Iris-versicolor, [5.6,3.0,4.5,1.5]) --> prob=[0.5015822116420338,0.4984177883579662], predicted Label=Iris-versicolor
(Iris-versicolor, [5.8,2.7,4.1,1.0]) --> prob=[0.6145953553933187,0.3854046446066813], predicted Label=Iris-versicolor
(Iris-versicolor, [6.0,2.2,4.0,1.0]) --> prob=[0.617611197013045,0.38238880298695493], predicted Label=Iris-versicolor
(Iris-versicolor, [6.6,3.0,4.4,1.4]) --> prob=[0.539815937374677,0.460184062625323], predicted Label=Iris-versicolor
(Iris-versicolor, [6.7,3.0,5.0,1.7]) --> prob=[0.47437817884138095,0.525621821158619], predicted Label=Iris-virginica
(Iris-virginica, [4.9,2.5,4.5,1.7]) --> prob=[0.44587895949586764,0.5541210405041324], predicted Label=Iris-virginica
(Iris-versicolor, [5.5,2.4,3.8,1.1]) --> prob=[0.5885624460899619,0.41143755391003817], predicted Label=Iris-versicolor
(Iris-versicolor, [5.6,3.0,4.1,1.3]) --> prob=[0.5462090274199968,0.45379097258000317], predicted Label=Iris-versicolor
(Iris-virginica, [5.7,2.5,5.0,2.0]) --> prob=[0.39296017326534144,0.6070398267346586], predicted Label=Iris-virginica
(Iris-versicolor, [5.7,2.9,4.2,1.3]) --> prob=[0.5477888111927275,0.45221118880727246], predicted Label=Iris-versicolor
(Iris-versicolor, [5.8,2.6,4.0,1.2]) --> prob=[0.5714168219439002,0.42858317805609975], predicted Label=Iris-versicolor
(Iris-versicolor, [5.8,2.7,3.9,1.2]) --> prob=[0.5714168219439002,0.42858317805609975], predicted Label=Iris-versicolor
(Iris-virginica, [5.8,2.7,5.1,1.9]) --> prob=[0.41605421498910905,0.583945785010891], predicted Label=Iris-virginica
(Iris-virginica, [5.9,3.0,5.1,1.8]) --> prob=[0.4395227923492292,0.5604772076507708], predicted Label=Iris-virginica
(Iris-versicolor, [6.0,2.7,5.1,1.6]) --> prob=[0.4855821629924292,0.5144178370075708], predicted Label=Iris-virginica
(Iris-versicolor, [6.0,2.9,4.5,1.5]) --> prob=[0.5079570167600491,0.4920429832399509], predicted Label=Iris-versicolor
(Iris-virginica, [6.1,2.6,5.6,1.4]) --> prob=[0.531887741930683,0.46811225806931694], predicted Label=Iris-versicolor
(Iris-versicolor, [6.1,3.0,4.6,1.4]) --> prob=[0.531887741930683,0.46811225806931694], predicted Label=Iris-versicolor
(Iris-virginica, [6.3,2.5,5.0,1.9]) --> prob=[0.4238190390913101,0.5761809609086899], predicted Label=Iris-virginica
(Iris-virginica, [6.3,2.9,5.6,1.8]) --> prob=[0.44581435344357156,0.5541856465564284], predicted Label=Iris-virginica
(Iris-virginica, [6.3,3.3,6.0,2.5]) --> prob=[0.30064595369512165,0.6993540463048784], predicted Label=Iris-virginica
(Iris-virginica, [6.4,3.1,5.5,1.8]) --> prob=[0.4473900414350662,0.5526099585649339], predicted Label=Iris-virginica
(Iris-virginica, [6.7,2.5,5.8,1.8]) --> prob=[0.45212332546733563,0.5478766745326644], predicted Label=Iris-virginica
(Iris-virginica, [6.7,3.0,5.2,2.3]) --> prob=[0.3453175901585396,0.6546824098414604], predicted Label=Iris-virginica
(Iris-versicolor, [6.7,3.1,4.7,1.5]) --> prob=[0.5191054573756626,0.4808945426243374], predicted Label=Iris-versicolor
(Iris-virginica, [6.8,3.2,5.9,2.3]) --> prob=[0.3467603323149756,0.6532396676850243], predicted Label=Iris-virginica
(Iris-virginica, [7.2,3.0,5.8,1.6]) --> prob=[0.5047044410242485,0.49529555897575156], predicted Label=Iris-versicolor
(Iris-virginica, [7.3,2.9,6.3,1.8]) --> prob=[0.4616150767129017,0.5383849232870982], predicted Label=Iris-virginica
(Iris-virginica, [7.7,2.6,6.9,2.3]) --> prob=[0.3598694101321466,0.6401305898678533], predicted Label=Iris-virginica
(Iris-virginica, [7.7,2.8,6.7,2.0]) --> prob=[0.4237551850416873,0.5762448149583126], predicted Label=Iris-virginica
```
4. Model evaluation
Create a MulticlassClassificationEvaluator instance and use setter methods to specify the prediction column and the true-label column; then compute the accuracy and the test error.
```scala
scala> val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_e27111142138
scala> val lrAccuracy = evaluator.evaluate(lrPredictions)
lrAccuracy: Double = 0.8666666666666667
scala> println("Test Error = " + (1.0 - lrAccuracy))
Test Error = 0.1333333333333333
```
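MulticlassClassificationEvaluator supports several metrics (f1, weightedPrecision, weightedRecall, accuracy). To be explicit that accuracy is what is being computed, the metric name can be set, as in this sketch:

```scala
// Request the accuracy metric explicitly instead of relying on the default.
val accEvaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val lrAccuracyExplicit = accEvaluator.evaluate(lrPredictions)
```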
The accuracy above is roughly 86.7%. Next we can retrieve the trained logistic regression model itself. As noted earlier, lrPipelineModel is a PipelineModel, so we can obtain the model through its stages, as follows:
```scala
scala> val lrModel = lrPipelineModel.stages(2).asInstanceOf[LogisticRegressionModel]
lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_d5c473ad2b66
scala> println("Coefficients: " + lrModel.coefficients+"\n"+"Intercept: "+lrModel.intercept+"\n"+"numClasses: "+lrModel.numClasses+"\n"+"numFeatures: "+lrModel.numFeatures)
Coefficients: [-0.06375471660847784,0.0,0.0,0.08951809484634254]
Intercept: -0.09689292891729179
numClasses: 2
numFeatures: 4
```
- Coefficients: the coefficient vector (W)
- Intercept: the intercept (b)
- numClasses: the number of classes
- numFeatures: the number of features
The model's linear predictor is Wx + b, as the sketch below makes concrete.
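A small sketch of how these parameters turn into a prediction (x is a hypothetical point taken from the indexedFeatures column shown earlier, since that is the space the coefficients live in):

```scala
// Compute the margin W·x + b by hand, then apply the logistic function
// to obtain the predicted probability of class 1.
val x = Vectors.dense(5.8, 6.0, 5.1, 9.0)
val margin = lrModel.coefficients.toArray.zip(x.toArray).map { case (w, v) => w * v }.sum + lrModel.intercept
val pClass1 = 1.0 / (1.0 + math.exp(-margin))
```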
5. Model summary
spark.ml also provides a summary of the trained model (summary). At present this is only supported for binomial logistic regression, and the summary must be explicitly cast to BinaryLogisticRegressionSummary. In the code below, we first obtain the summary of the binomial model; we then retrieve the value of the loss function at each iteration and print it, observing that the loss decreases as the iterations proceed (the smaller the loss, the better the model). Next we cast the summary to BinaryLogisticRegressionSummary to obtain the metrics used to evaluate model performance: areaUnderROC reaches 0.9848856209150327, which indicates that our classifier is quite good. Finally, we pick the most suitable classification threshold by maximizing fMeasure, a metric that combines precision and recall.
```scala
scala> val trainingSummary = lrModel.summary
trainingSummary: org.apache.spark.ml.classification.LogisticRegressionTrainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@20e7a10
scala> val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory: Array[Double] = Array(0.6927389617440809, 0.6899308387407462, 0.6865569217265758, 0.6736046289373718, 0.6694299346003736, 0.6681137586393369, 0.6679359916200565, 0.6673276439679328, 0.6668085232695621, 0.6665822228169366, 0.6663915835564641)
scala> objectiveHistory.foreach(loss => println(loss))
0.6927389617440809
0.6899308387407462
0.6865569217265758
0.6736046289373718
0.6694299346003736
0.6681137586393369
0.6679359916200565
0.6673276439679328
0.6668085232695621
0.6665822228169366
0.6663915835564641
scala> val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
binarySummary: org.apache.spark.ml.classification.BinaryLogisticRegressionSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@20e7a10
scala> println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
areaUnderROC: 0.9848856209150327
scala> val fMeasure = binarySummary.fMeasureByThreshold
fMeasure: org.apache.spark.sql.DataFrame = [threshold: double, F-Measure: double]
scala> val maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0)
maxFMeasure: Double = 0.955223880597015
scala> val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).select("threshold").head().getDouble(0)
bestThreshold: Double = 0.5399690045423275
scala> lrModel.setThreshold(bestThreshold)
res17: lrModel.type = logreg_d5c473ad2b66
```
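Besides areaUnderROC, the binary summary also exposes the ROC curve itself as a DataFrame with FPR and TPR columns; a sketch:

```scala
// Each row is one point of the ROC curve, convenient to export for plotting.
binarySummary.roc.show(5)
```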
Multinomial Logistic Regression for a Binary Classification Problem
A binary classification problem can also be analyzed with multinomial logistic regression. Multinomial logistic regression works like the binomial variant; the only difference in the model setup is that the family parameter is set to "multinomial". Here we only list the results:
```scala
scala> val mlr = new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8).setFamily("multinomial")
mlr: org.apache.spark.ml.classification.LogisticRegression = logreg_82bf612d153e
scala> val mlrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, mlr, labelConverter))
mlrPipeline: org.apache.spark.ml.Pipeline = pipeline_016e81dbca1f
scala> val mlrPipelineModel = mlrPipeline.fit(trainingData)
mlrPipelineModel: org.apache.spark.ml.PipelineModel = pipeline_016e81dbca1f
scala> val mlrPredictions = mlrPipelineModel.transform(testData)
mlrPredictions: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]
scala> mlrPredictions.select("predictedLabel", "label", "features", "probability").collect().foreach { case Row(predictedLabel: String, label: String, features: Vector, prob: Vector) => println(s"($label, $features) --> prob=$prob, predictedLabel=$predictedLabel")}
(Iris-virginica, [4.9,2.5,4.5,1.7]) --> prob=[0.4706991400166566,0.5293008599833434], predictedLabel=Iris-virginica
(Iris-versicolor, [5.1,2.5,3.0,1.1]) --> prob=[0.6123754219240134,0.38762457807598644], predictedLabel=Iris-versicolor
(Iris-versicolor, [5.5,2.3,4.0,1.3]) --> prob=[0.5724859784244956,0.42751402157550444], predictedLabel=Iris-versicolor
(Iris-versicolor, [5.5,2.4,3.8,1.1]) --> prob=[0.617700896993959,0.3822991030060409], predictedLabel=Iris-versicolor
(Iris-virginica, [5.8,2.7,5.1,1.9]) --> prob=[0.43670908827255583,0.563290911727444], predictedLabel=Iris-virginica
(Iris-versicolor, [5.9,3.0,4.2,1.5]) --> prob=[0.5316312191190347,0.4683687808809653], predictedLabel=Iris-versicolor
(Iris-versicolor, [6.1,3.0,4.6,1.4]) --> prob=[0.5577018837203559,0.44229811627964416], predictedLabel=Iris-versicolor
(Iris-virginica, [6.2,3.4,5.4,2.3]) --> prob=[0.3525986631597158,0.6474013368402842], predictedLabel=Iris-virginica
(Iris-versicolor, [6.3,2.3,4.4,1.3]) --> prob=[0.5834583948072782,0.4165416051927219], predictedLabel=Iris-versicolor
(Iris-versicolor, [6.3,3.3,4.7,1.6]) --> prob=[0.5138182157242249,0.4861817842757751], predictedLabel=Iris-versicolor
(Iris-virginica, [6.3,3.3,6.0,2.5]) --> prob=[0.31220890350252506,0.6877910964974749], predictedLabel=Iris-virginica
(Iris-virginica, [6.3,3.4,5.6,2.4]) --> prob=[0.33271906258823314,0.6672809374117669], predictedLabel=Iris-virginica
(Iris-virginica, [6.4,2.8,5.6,2.2]) --> prob=[0.3769557902496345,0.6230442097503656], predictedLabel=Iris-virginica
(Iris-versicolor, [6.4,3.2,4.5,1.5]) --> prob=[0.5386254134782479,0.461374586521752], predictedLabel=Iris-versicolor
(Iris-versicolor, [6.5,2.8,4.6,1.5]) --> prob=[0.5400225182386277,0.4599774817613724], predictedLabel=Iris-versicolor
(Iris-virginica, [6.5,3.0,5.5,1.8]) --> prob=[0.4697204592971098,0.5302795407028901], predictedLabel=Iris-virginica
(Iris-virginica, [6.5,3.2,5.1,2.0]) --> prob=[0.42334262258749805,0.576657377412502], predictedLabel=Iris-virginica
(Iris-versicolor, [6.6,2.9,4.6,1.3]) --> prob=[0.587552435510627,0.412447564489373], predictedLabel=Iris-versicolor
(Iris-virginica, [6.7,3.0,5.2,2.3]) --> prob=[0.35904307162043886,0.6409569283795613], predictedLabel=Iris-virginica
(Iris-virginica, [6.7,3.3,5.7,2.1]) --> prob=[0.4033033147932801,0.5966966852067198], predictedLabel=Iris-virginica
(Iris-virginica, [6.8,3.0,5.5,2.1]) --> prob=[0.40465727034257987,0.5953427296574202], predictedLabel=Iris-virginica
(Iris-virginica, [6.8,3.2,5.9,2.3]) --> prob=[0.36033816936772717,0.6396618306322728], predictedLabel=Iris-virginica
(Iris-versicolor, [6.9,3.1,4.9,1.5]) --> prob=[0.5456044340218952,0.4543955659781048], predictedLabel=Iris-versicolor
(Iris-virginica, [6.9,3.2,5.7,2.3]) --> prob=[0.361635302911029,0.638364697088971], predictedLabel=Iris-virginica
(Iris-virginica, [7.2,3.2,6.0,1.8]) --> prob=[0.47953540973471454,0.5204645902652856], predictedLabel=Iris-virginica
(Iris-virginica, [7.2,3.6,6.1,2.5]) --> prob=[0.3231782795184636,0.6768217204815363], predictedLabel=Iris-virginica
(Iris-virginica, [7.3,2.9,6.3,1.8]) --> prob=[0.48093901384053533,0.5190609861594646], predictedLabel=Iris-virginica
(Iris-virginica, [7.4,2.8,6.1,1.9]) --> prob=[0.4589531699542302,0.5410468300457698], predictedLabel=Iris-virginica
(Iris-virginica, [7.7,2.8,6.7,2.0]) --> prob=[0.43989505147330155,0.5601049485266986], predictedLabel=Iris-virginica
scala> val mlrAccuracy = evaluator.evaluate(mlrPredictions)
mlrAccuracy: Double = 1.0
scala> println("Test Error = " + (1.0 - mlrAccuracy))
Test Error = 0.0
scala> val mlrModel = mlrPipelineModel.stages(2).asInstanceOf[LogisticRegressionModel]
mlrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_82bf612d153e
scala> println("Multinomial coefficients: " + mlrModel.coefficientMatrix+"Multin
omial intercepts: "+mlrModel.interceptVector+"numClasses: "+mlrModel.numClasses+
"numFeatures: "+mlrModel.numFeatures)
Multinomial coefficients: 0.028116025552572706 0.0 0.0 -0.046949976003541706
-0.028116025552572695 0.0 0.0 0.046949976003541706 Multinomial intercepts:
[0.13221236569525302,-0.13221236569525302]numClasses: 2numFeatures: 4
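For reference, the multinomial model keeps one coefficient vector and one intercept per class (the rows of coefficientMatrix and the entries of interceptVector printed above), and the class probabilities take the standard softmax form

$$P(Y=k \mid x) = \frac{e^{w_k \cdot x + b_k}}{\sum_{j=1}^{K} e^{w_j \cdot x + b_j}}$$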
Multinomial Logistic Regression for a Multi-class Problem
For a multi-class problem we need multinomial logistic regression. Here we use the full iris dataset, i.e. all three classes. The procedure is essentially the same as above, so again we only list the results:
```scala
scala> mlrPredictions.select("predictedLabel", "label", "features", "probability").collect().foreach { case Row(predictedLabel: String, label: String, features: Vector, prob: Vector) => println(s"($label, $features) --> prob=$prob, predictedLabel=$predictedLabel")}
(Iris-setosa, [4.3,3.0,1.1,0.1]) --> prob=[0.49856067730476944,0.2623440805400292,0.23909524215520148], predictedLabel=Iris-setosa
(Iris-setosa, [4.4,2.9,1.4,0.2]) --> prob=[0.46571089790971687,0.277891570222724,0.25639753186755915], predictedLabel=Iris-setosa
(Iris-setosa, [4.6,3.4,1.4,0.3]) --> prob=[0.5001101367665973,0.25928904940719977,0.24060081382620296], predictedLabel=Iris-setosa
(Iris-setosa, [4.6,3.6,1.0,0.2]) --> prob=[0.5463459284110406,0.236823238870237,0.2168308327187224], predictedLabel=Iris-setosa
(Iris-setosa, [4.7,3.2,1.6,0.2]) --> prob=[0.48370179709200706,0.2689591735381297,0.24733902936986318], predictedLabel=Iris-setosa
(Iris-setosa, [4.8,3.0,1.4,0.1]) --> prob=[0.4851576852808171,0.2693562861247639,0.24548602859441893], predictedLabel=Iris-setosa
(Iris-setosa, [4.9,3.0,1.4,0.2]) --> prob=[0.47467118791268154,0.2733753207454508,0.2519534913418676], predictedLabel=Iris-setosa
(Iris-setosa, [4.9,3.1,1.5,0.1]) --> prob=[0.48967732688779486,0.267131626783035,0.2431910463291701], predictedLabel=Iris-setosa
(Iris-versicolor, [5.0,2.3,3.3,1.0]) --> prob=[0.26303122888674907,0.36560215832179155,0.3713666127914594], predictedLabel=Iris-virginica
(Iris-setosa, [5.0,3.5,1.3,0.3]) --> prob=[0.5135688079539743,0.2524416257183621,0.2339895663276636], predictedLabel=Iris-setosa
(Iris-setosa, [5.0,3.5,1.6,0.6]) --> prob=[0.4686356517088239,0.2713034457686629,0.26006090252251307], predictedLabel=Iris-setosa
(Iris-setosa, [5.1,3.5,1.4,0.3]) --> prob=[0.5091020180722664,0.25475974124614675,0.23613824068158687], predictedLabel=Iris-setosa
(Iris-setosa, [5.1,3.8,1.5,0.3]) --> prob=[0.531570061574297,0.24348517949467904,0.22494475893102403], predictedLabel=Iris-setosa
(Iris-setosa, [5.1,3.8,1.9,0.4]) --> prob=[0.503222274154322,0.25683175058110785,0.23994597526457007], predictedLabel=Iris-setosa
(Iris-setosa, [5.2,3.5,1.5,0.2]) --> prob=[0.5151370941776632,0.2529823490495923,0.2318805567727446], predictedLabel=Iris-setosa
(Iris-setosa, [5.3,3.7,1.5,0.2]) --> prob=[0.5330773525753305,0.24387796024925384,0.22304468717541576], predictedLabel=Iris-setosa
(Iris-versicolor, [5.4,3.0,4.5,1.5]) --> prob=[0.2306542447600023,0.372383222489962,0.39696253275003573], predictedLabel=Iris-virginica
(Iris-setosa, [5.4,3.9,1.3,0.4]) --> prob=[0.5389512877303541,0.23848657002728416,0.2225621422423618], predictedLabel=Iris-setosa
(Iris-versicolor, [5.5,2.4,3.7,1.0]) --> prob=[0.25620601559263473,0.36919246180632764,0.37460152260103763], predictedLabel=Iris-virginica
(Iris-setosa, [5.5,3.5,1.3,0.2]) --> prob=[0.5240613549472979,0.24832602160956213,0.22761262344314004], predictedLabel=Iris-setosa
(Iris-setosa, [5.5,4.2,1.4,0.2]) --> prob=[0.5818115053858839,0.21899706180633755,0.19919143280777854], predictedLabel=Iris-setosa
(Iris-versicolor, [5.6,2.5,3.9,1.1]) --> prob=[0.24827164138938784,0.3712338899987297,0.38049446861188246], predictedLabel=Iris-virginica
(Iris-versicolor, [5.6,2.7,4.2,1.3]) --> prob=[0.23609842674482123,0.3733910806218104,0.39051049263336834], predictedLabel=Iris-virginica
(Iris-virginica, [5.6,2.8,4.9,2.0]) --> prob=[0.17353784667372726,0.38803750951559646,0.43842464381067625], predictedLabel=Iris-virginica
(Iris-versicolor, [5.6,2.9,3.6,1.3]) --> prob=[0.26994082035183004,0.35725015822484213,0.37280902142332784], predictedLabel=Iris-virginica
(Iris-setosa, [5.7,4.4,1.5,0.4]) --> prob=[0.5744990088621882,0.22068271118182198,0.20481827995598978], predictedLabel=Iris-setosa
(Iris-virginica, [5.8,2.8,5.1,2.4]) --> prob=[0.14589555459093273,0.39150544114527663,0.4625990042637906], predictedLabel=Iris-virginica
(Iris-virginica, [5.9,3.0,5.1,1.8]) --> prob=[0.19164845952411863,0.38448782728830505,0.42386371318757643], predictedLabel=Iris-virginica
(Iris-versicolor, [6.0,2.2,4.0,1.0]) --> prob=[0.23300779791940326,0.3802856918956981,0.3867065101848985], predictedLabel=Iris-virginica
(Iris-versicolor, [6.0,2.7,5.1,1.6]) --> prob=[0.18810463050749873,0.3900406691963187,0.42185470029618244], predictedLabel=Iris-virginica
(Iris-versicolor, [6.1,2.8,4.0,1.3]) --> prob=[0.24928433400912278,0.3671520807495573,0.3835635852413199], predictedLabel=Iris-virginica
(Iris-versicolor, [6.2,2.2,4.5,1.5]) --> prob=[0.18351550396686533,0.3934066675024647,0.42307782853066994], predictedLabel=Iris-virginica
(Iris-virginica, [6.2,2.8,4.8,1.8]) --> prob=[0.1888126898204262,0.38539188363903437,0.4257954265405395], predictedLabel=Iris-virginica
(Iris-versicolor, [6.2,2.9,4.3,1.3]) --> prob=[0.24600050420847877,0.3689652108789115,0.38503428491260977], predictedLabel=Iris-virginica
(Iris-virginica, [6.2,3.4,5.4,2.3]) --> prob=[0.17337730890542696,0.3825617039174212,0.44406098717715176], predictedLabel=Iris-virginica
(Iris-virginica, [6.3,2.9,5.6,1.8]) --> prob=[0.1729681423511942,0.3931462837297906,0.4338855739190153], predictedLabel=Iris-virginica
(Iris-virginica, [6.4,3.1,5.5,1.8]) --> prob=[0.18621090846131505,0.3872972795834499,0.42649181195523495], predictedLabel=Iris-virginica
(Iris-versicolor, [6.6,2.9,4.6,1.3]) --> prob=[0.23618909578565045,0.373766365784125,0.3900445384302246], predictedLabel=Iris-virginica
(Iris-virginica, [6.7,3.3,5.7,2.5]) --> prob=[0.1496994275680708,0.38855932284425526,0.4617412495876739], predictedLabel=Iris-virginica
(Iris-virginica, [6.8,3.0,5.5,2.1]) --> prob=[0.16265889090899283,0.39126984184915486,0.4460712672418523], predictedLabel=Iris-virginica
(Iris-virginica, [7.2,3.2,6.0,1.8]) --> prob=[0.1782593898810351,0.3913068582491216,0.43043375186984334], predictedLabel=Iris-virginica
(Iris-virginica, [7.7,2.6,6.9,2.3]) --> prob=[0.10733085394350968,0.41117706558989164,0.4814920804665987], predictedLabel=Iris-virginica
(Iris-virginica, [7.7,3.8,6.7,2.2]) --> prob=[0.16693678799079806,0.38877323991855633,0.44428997209064564], predictedLabel=Iris-virginica
(Iris-virginica, [7.9,3.8,6.4,2.0]) --> prob=[0.18714592916724979,0.3838745095632083,0.42897956126954184], predictedLabel=Iris-virginica
scala> val mlrAccuracy = evaluator.evaluate(mlrPredictions)
mlrAccuracy: Double = 0.6339712918660287
scala> println("Test Error = " + (1.0 - mlrAccuracy))
Test Error = 0.36602870813397126
scala> val mlrModel = mlrPipelineModel.stages(2).asInstanceOf[LogisticRegressionModel]
mlrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_9661a4f56149
scala> println("Multinomial coefficients: " + mlrModel.coefficientMatrix+"Multinomial intercepts: "+mlrModel.interceptVector+"numClasses: "+mlrModel.numClasses+"numFeatures: "+mlrModel.numFeatures)
Multinomial coefficients: 0.0 0.35442627664118775 -0.1787646656602406 -0.36
662299325180614
0.0 0.0 0.0 0.0
0.0 -0.010992364266212548 0.0 0.11193811404312962 Multinomi
al intercepts: [-0.10160079218819881,0.0863062310816332,0.01529456110656562]numC
lasses: 3numFeatures: 4
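Although only the results are listed, the differences from the binary run are small. A minimal sketch of the three-class setup (variable names are illustrative), reusing the full DataFrame data and the multinomial estimator mlr defined earlier:

```scala
// Rebuild the indexers on the full three-class data, split, refit, predict.
val labelIndexer3 = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
val featureIndexer3 = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(data)
val Array(train3, test3) = data.randomSplit(Array(0.7, 0.3))
val labelConverter3 = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer3.labels)
val pipeline3 = new Pipeline().setStages(Array(labelIndexer3, featureIndexer3, mlr, labelConverter3))
val model3 = pipeline3.fit(train3)
val predictions3 = model3.transform(test3)
```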