Spark2.x 入门:高斯混合模型(GMM)聚类算法

模型的训练与分析

Spark的ML库提供的高斯混合模型都在org.apache.spark.ml.clustering包下,和其他的聚类方法类似,其具体实现分为两个类:用于抽象GMM的超参数并进行训练的GaussianMixture类(Estimator)和训练后的模型GaussianMixtureModel类(Transformer),在使用前,引入需要的包:

scala 复制代码
import org.apache.spark.ml.clustering.{GaussianMixture,GaussianMixtureModel}
import org.apache.spark.ml.linalg.{Vector,Vectors}

开启RDD的隐式转换:

scala 复制代码
import spark.implicits._

下文中,我们默认名为sparkSparkSession已经创建。与其他教程相同,本文亦使用模式识别领域广泛使用的UCI数据集中的鸢尾花数据Iris进行实验,它可以在iris获取,Iris数据的样本容量为150,有四个实数值的特征,分别代表花朵四个部位的尺寸,以及该样本对应鸢尾花的亚种类型(共有3种亚种类型),如下所示:

5.1,3.5,1.4,0.2,setosa
...
5.4,3.0,4.5,1.5,versicolor
...
7.1,3.0,5.9,2.1,virginica
...

为了便于生成相应的DataFrame,这里定义一个名为model_instancecase class作为DataFrame每一行(一个数据样本)的数据类型。

注:因为是非监督学习,所以不需要数据中的label,只需要使用特征向量数据就可以。

scala 复制代码
scala> case class model_instance(features: org.apache.spark.ml.linalg.Vector)
defined class model_instance

在定义数据类型完成后,即可将数据读入RDD[model_instance]的结构中,并通过RDD的隐式转换.toDF()方法完成RDDDataFrame的转换:

scala 复制代码
val rawData = sc.textFile("file:///root/data/iris.txt")
rawData: org.apache.spark.rdd.RDD[String] = file:///root/data/iris.txt MapPartitionsRDD[220] at textFile at <console>:56

scala> val df = rawData.map(line =>{ model_instance( Vectors.dense(line.split(",").filter(p => p.matches("\\d*(\\.?)\\d*")).map(_.toDouble)) )}).toDF()
df: org.apache.spark.sql.DataFrame = [features: vector]

MLlib版的教程类似,我们使用了filter算子,过滤掉类标签,正则表达式\\d*(\\.?)\\d*可以用于匹配实数类型的数字,\\d*使用了*限定符,表示匹配0次或多次的数字字符,\\.?使用了?限定符,表示匹配0次或1次的小数点。

可以通过创建一个GaussianMixture类,设置相应的超参数,并调用fit(..)方法来训练一个GMM模型GaussianMixtureModel,在该方法调用前需要设置一系列超参数,如下表所示:

参数 含义
K 聚类数目,默认为2
maxIter 最大迭代次数,默认为100
seed 随机数种子,默认为随机Long值
Tol 对数似然函数收敛阈值,默认为0.01

其中,每一个超参数均可通过名为setXXX(...)(如maxIterations即为setMaxIterations())的方法进行设置。

这里,我们建立一个简单的GaussianMixture对象,设定其聚类数目为3,其他参数取默认值。

scala 复制代码
scala> val gm = new GaussianMixture().setK(3).setPredictionCol("Prediction").setProbabilityCol("Probability")
gm: org.apache.spark.ml.clustering.GaussianMixture = GaussianMixture_d8479eee2bdf

scala> val gmm = gm.fit(df)
17/09/07 17:10:07 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
17/09/07 17:10:07 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK
gmm: org.apache.spark.ml.clustering.GaussianMixtureModel = GaussianMixture_d8479eee2bdf

和KMeans等硬聚类方法不同的是,除了可以得到对样本的聚簇归属预测外,还可以得到样本属于各个聚簇的概率(这里我们存在"Probability"列中)。

调用transform()方法处理数据集之后,打印数据集,可以看到每一个样本的预测簇以及其概率分布向量(这里为了明晰起见,省略了大部分行,只选择三行):

scala 复制代码
scala> val result = gmm.transform(df)
result: org.apache.spark.sql.DataFrame = [features: vector, Prediction: int ... 1 more field]

scala> result.show(150, false)
+-----------------+----------+------------------------------------------------------------------+
|features         |Prediction|Probability                                                       |
+-----------------+----------+------------------------------------------------------------------+
|[5.1,3.5,1.4,0.2]|0         |[0.9999999999999951,4.682229962936943E-17,4.868372929925642E-15]  |
|[5.8,2.7,5.1,1.9]|1         |[5.443776218961163E-16,0.5809846403552363,0.41901535964476316]    |
|[5.1,2.5,3.0,1.1]|2         |[1.574811409348219E-14,1.646719617640849E-14,0.9999999999999678]  |

得到模型后,即可查看模型的相关参数,与KMeans方法不同,GMM不直接给出聚类中心,而是给出各个混合成分(多元高斯分布)的参数。在ML的实现中,GMM的每一个混合成分都使用一个MultivariateGaussian类(位于org.apache.spark.ml.stat.distribution包)来存储,我们可以使用GaussianMixtureModel类的weights成员获取到各个混合成分的权重,使用gaussians成员来获取到各个混合成分的参数(均值向量和协方差矩阵):

scala 复制代码
scala> for (i <- 0 until gmm.getK) {println("Component %d : weight is %f \n mu vector is %s \n sigma matrix is %s" format(i, gmm.weights(i), gmm.gaussians(i).mean, gmm.gaussians(i).cov))}
Component 0 : weight is 0.333333 
 mu vector is [5.006000336585284,3.41800074359835,1.4640001090120234,0.24399996278677907] 
 sigma matrix is 0.12176391071215485  0.098291689186003     0.01581595534223468   0.01033602571352466   
0.098291689186003    0.14227526345684152   0.011447885703674401  0.01120804907975396   
0.01581595534223468  0.011447885703674401  0.029504001732923526  0.005584009823879005  
0.01033602571352466  0.01120804907975396   0.005584009823879005  0.011264005407846419  
Component 1 : weight is 0.158358 
 mu vector is [6.683368733405749,2.869615454114306,5.6462886220107,2.0056734271362298] 
 sigma matrix is 0.4932801350542996    0.050374713498106384  0.3573203540815376    0.05001856939219663   
0.050374713498106384  0.04009423452906383   0.004169715059366357  0.020005237661706105  
0.3573203540815376    0.004169715059366357  0.33772537665484004   0.01700691760482919   
0.05001856939219663   0.020005237661706105  0.01700691760482919   0.06935869650451344   
Component 2 : weight is 0.508309 
 mu vector is [6.130726266791144,2.872742630634865,4.675369349848154,1.5732931362537985] 
 sigma matrix is 0.34423978263400257  0.1433295221383842   0.3498831855148469   0.1447023418962772   
0.1433295221383842   0.13127254135550323  0.18483272859443253  0.09799971374720756  
0.3498831855148469   0.18483272859443253  0.5558476131836453   0.26987975624410054  
0.1447023418962772   0.09799971374720756  0.26987975624410054  0.16825697031716608 
相关推荐
_feivirus_22 分钟前
神经网络_使用TensorFlow预测气温
人工智能·神经网络·算法·tensorflow·预测气温
大柏怎么被偷了31 分钟前
【C++算法】位运算
开发语言·c++·算法
程序猿方梓燚33 分钟前
C/C++实现植物大战僵尸(PVZ)(打地鼠版)
c语言·开发语言·c++·算法·游戏
CPP_ZhouXuyang33 分钟前
C语言——模拟实现strcpy
c语言·开发语言·数据结构·算法·程序员创富
闻缺陷则喜何志丹34 分钟前
【C++前后缀分解 动态规划】2100. 适合野炊的日子|1702
c++·算法·动态规划·力扣·前后缀分解·日子·适合
逝去的秋风1 小时前
【代码随想录训练营第42期 Day57打卡 - 图论Part7 - Prim算法与Kruskal算法
算法·图论·prim算法
QXH2000001 小时前
数据结构—双向链表
c语言·数据结构·算法·链表
旺小仔.1 小时前
【数据结构篇】~排序(1)之插入排序
c语言·数据结构·算法·链表·性能优化·排序算法
绎岚科技2 小时前
深度学习自编码器 - 随机编码器和解码器篇
人工智能·深度学习·算法·机器学习
jingling5552 小时前
后端开发刷题 | 数字字符串转化成IP地址
java·开发语言·javascript·算法