Spark2.x 入门:高斯混合模型(GMM)聚类算法

模型的训练与分析

Spark的ML库提供的高斯混合模型都在org.apache.spark.ml.clustering包下,和其他的聚类方法类似,其具体实现分为两个类:用于抽象GMM的超参数并进行训练的GaussianMixture类(Estimator)和训练后的模型GaussianMixtureModel类(Transformer),在使用前,引入需要的包:

scala 复制代码
import org.apache.spark.ml.clustering.{GaussianMixture,GaussianMixtureModel}
import org.apache.spark.ml.linalg.{Vector,Vectors}

开启RDD的隐式转换:

scala 复制代码
import spark.implicits._

下文中,我们默认名为sparkSparkSession已经创建。与其他教程相同,本文亦使用模式识别领域广泛使用的UCI数据集中的鸢尾花数据Iris进行实验,它可以在iris获取,Iris数据的样本容量为150,有四个实数值的特征,分别代表花朵四个部位的尺寸,以及该样本对应鸢尾花的亚种类型(共有3种亚种类型),如下所示:

复制代码
5.1,3.5,1.4,0.2,setosa
...
5.4,3.0,4.5,1.5,versicolor
...
7.1,3.0,5.9,2.1,virginica
...

为了便于生成相应的DataFrame,这里定义一个名为model_instancecase class作为DataFrame每一行(一个数据样本)的数据类型。

注:因为是非监督学习,所以不需要数据中的label,只需要使用特征向量数据就可以。

scala 复制代码
scala> case class model_instance(features: org.apache.spark.ml.linalg.Vector)
defined class model_instance

在定义数据类型完成后,即可将数据读入RDD[model_instance]的结构中,并通过RDD的隐式转换.toDF()方法完成RDDDataFrame的转换:

scala 复制代码
val rawData = sc.textFile("file:///root/data/iris.txt")
rawData: org.apache.spark.rdd.RDD[String] = file:///root/data/iris.txt MapPartitionsRDD[220] at textFile at <console>:56

scala> val df = rawData.map(line =>{ model_instance( Vectors.dense(line.split(",").filter(p => p.matches("\\d*(\\.?)\\d*")).map(_.toDouble)) )}).toDF()
df: org.apache.spark.sql.DataFrame = [features: vector]

MLlib版的教程类似,我们使用了filter算子,过滤掉类标签,正则表达式\\d*(\\.?)\\d*可以用于匹配实数类型的数字,\\d*使用了*限定符,表示匹配0次或多次的数字字符,\\.?使用了?限定符,表示匹配0次或1次的小数点。

可以通过创建一个GaussianMixture类,设置相应的超参数,并调用fit(..)方法来训练一个GMM模型GaussianMixtureModel,在该方法调用前需要设置一系列超参数,如下表所示:

参数 含义
K 聚类数目,默认为2
maxIter 最大迭代次数,默认为100
seed 随机数种子,默认为随机Long值
Tol 对数似然函数收敛阈值,默认为0.01

其中,每一个超参数均可通过名为setXXX(...)(如maxIterations即为setMaxIterations())的方法进行设置。

这里,我们建立一个简单的GaussianMixture对象,设定其聚类数目为3,其他参数取默认值。

scala 复制代码
scala> val gm = new GaussianMixture().setK(3).setPredictionCol("Prediction").setProbabilityCol("Probability")
gm: org.apache.spark.ml.clustering.GaussianMixture = GaussianMixture_d8479eee2bdf

scala> val gmm = gm.fit(df)
17/09/07 17:10:07 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
17/09/07 17:10:07 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK
gmm: org.apache.spark.ml.clustering.GaussianMixtureModel = GaussianMixture_d8479eee2bdf

和KMeans等硬聚类方法不同的是,除了可以得到对样本的聚簇归属预测外,还可以得到样本属于各个聚簇的概率(这里我们存在"Probability"列中)。

调用transform()方法处理数据集之后,打印数据集,可以看到每一个样本的预测簇以及其概率分布向量(这里为了明晰起见,省略了大部分行,只选择三行):

scala 复制代码
scala> val result = gmm.transform(df)
result: org.apache.spark.sql.DataFrame = [features: vector, Prediction: int ... 1 more field]

scala> result.show(150, false)
+-----------------+----------+------------------------------------------------------------------+
|features         |Prediction|Probability                                                       |
+-----------------+----------+------------------------------------------------------------------+
|[5.1,3.5,1.4,0.2]|0         |[0.9999999999999951,4.682229962936943E-17,4.868372929925642E-15]  |
|[5.8,2.7,5.1,1.9]|1         |[5.443776218961163E-16,0.5809846403552363,0.41901535964476316]    |
|[5.1,2.5,3.0,1.1]|2         |[1.574811409348219E-14,1.646719617640849E-14,0.9999999999999678]  |

得到模型后,即可查看模型的相关参数,与KMeans方法不同,GMM不直接给出聚类中心,而是给出各个混合成分(多元高斯分布)的参数。在ML的实现中,GMM的每一个混合成分都使用一个MultivariateGaussian类(位于org.apache.spark.ml.stat.distribution包)来存储,我们可以使用GaussianMixtureModel类的weights成员获取到各个混合成分的权重,使用gaussians成员来获取到各个混合成分的参数(均值向量和协方差矩阵):

scala 复制代码
scala> for (i <- 0 until gmm.getK) {println("Component %d : weight is %f \n mu vector is %s \n sigma matrix is %s" format(i, gmm.weights(i), gmm.gaussians(i).mean, gmm.gaussians(i).cov))}
Component 0 : weight is 0.333333 
 mu vector is [5.006000336585284,3.41800074359835,1.4640001090120234,0.24399996278677907] 
 sigma matrix is 0.12176391071215485  0.098291689186003     0.01581595534223468   0.01033602571352466   
0.098291689186003    0.14227526345684152   0.011447885703674401  0.01120804907975396   
0.01581595534223468  0.011447885703674401  0.029504001732923526  0.005584009823879005  
0.01033602571352466  0.01120804907975396   0.005584009823879005  0.011264005407846419  
Component 1 : weight is 0.158358 
 mu vector is [6.683368733405749,2.869615454114306,5.6462886220107,2.0056734271362298] 
 sigma matrix is 0.4932801350542996    0.050374713498106384  0.3573203540815376    0.05001856939219663   
0.050374713498106384  0.04009423452906383   0.004169715059366357  0.020005237661706105  
0.3573203540815376    0.004169715059366357  0.33772537665484004   0.01700691760482919   
0.05001856939219663   0.020005237661706105  0.01700691760482919   0.06935869650451344   
Component 2 : weight is 0.508309 
 mu vector is [6.130726266791144,2.872742630634865,4.675369349848154,1.5732931362537985] 
 sigma matrix is 0.34423978263400257  0.1433295221383842   0.3498831855148469   0.1447023418962772   
0.1433295221383842   0.13127254135550323  0.18483272859443253  0.09799971374720756  
0.3498831855148469   0.18483272859443253  0.5558476131836453   0.26987975624410054  
0.1447023418962772   0.09799971374720756  0.26987975624410054  0.16825697031716608 
相关推荐
JK0x071 小时前
代码随想录算法训练营 Day40 动态规划Ⅷ 股票问题
算法·动态规划
Feliz..1 小时前
关于离散化算法的看法与感悟
算法
水蓝烟雨2 小时前
1128. 等价多米诺骨牌对的数量
算法·hot 100
codists2 小时前
《算法导论(第4版)》阅读笔记:p11-p13
算法
QFIUNE3 小时前
数据分析之药物-基因-代谢物
数据挖掘·数据分析
Kidddddult4 小时前
力扣刷题Day 43:矩阵置零(73)
算法·leetcode·力扣
大龄Python青年6 小时前
C语言 交换算法之加减法,及溢出防范
c语言·开发语言·算法
啊我不会诶6 小时前
CF每日5题
算法
zx437 小时前
聚类后的分析:推断簇的类型
人工智能·python·机器学习·聚类
朱剑君7 小时前
排序算法——基数排序
算法·排序算法