From b97fe00d17c9d6665bb63f1f0bcbd968a875dfe8 Mon Sep 17 00:00:00 2001 From: Travis Galoppo Date: Thu, 18 Dec 2014 20:28:02 -0500 Subject: [PATCH] Minor fixes and tweaks. --- .../org/apache/spark/examples/mllib/DenseGmmEM.scala | 10 +++++----- .../mllib/clustering/GaussianMixtureModelEM.scala | 12 ++++++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala index d59ba49ed1ba3..41adcbb0a3115 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -37,13 +37,13 @@ object DenseGmmEM { } } - def run(inputFile: String, k: Int, convergenceTol: Double) { + private def run(inputFile: String, k: Int, convergenceTol: Double) { val conf = new SparkConf().setAppName("Spark EM Sample") val ctx = new SparkContext(conf) val data = ctx.textFile(inputFile).map{ line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) - }.cache + }.cache() val clusters = new GaussianMixtureModelEM() .setK(k) @@ -55,11 +55,11 @@ object DenseGmmEM { (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) } - println("Cluster labels:") + println("Cluster labels (first <= 100):") val (responsibilityMatrix, clusterLabels) = clusters.predict(data) - for (x <- clusterLabels.collect) { + clusterLabels.take(100).foreach{ x => print(" " + x) } - println + println() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala index 4644259d9fa28..b0d8c7a0aafbd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala @@ -71,10 +71,12 @@ class GaussianMixtureModelEM private ( // (U, U) => U for aggregation private def addExpectationSums(m1: ExpectationSum, m2: ExpectationSum): ExpectationSum = { m1._1(0) += m2._1(0) - for (i <- 0 until m1._2.length) { + var i = 0 + while (i < m1._2.length) { m1._2(i) += m2._2(i) m1._3(i) += m2._3(i) m1._4(i) += m2._4(i) + i = i + 1 } m1 } @@ -90,11 +92,13 @@ class GaussianMixtureModelEM private ( val pSum = p.sum sums._1(0) += math.log(pSum) val xxt = x * new Transpose(x) - for (i <- 0 until k) { + var i = 0 + while (i < k) { p(i) /= pSum sums._2(i) += p(i) sums._3(i) += x * p(i) sums._4(i) += xxt * p(i) + i = i + 1 } sums } @@ -123,7 +127,7 @@ class GaussianMixtureModelEM private ( } /** Return the user supplied initial GMM, if supplied */ - def getInitialiGmm: Option[GaussianMixtureModel] = initialGmm + def getInitialGmm: Option[GaussianMixtureModel] = initialGmm /** Set the number of Gaussians in the mixture model. Default: 2 */ def setK(k: Int): this.type = { @@ -182,7 +186,7 @@ class GaussianMixtureModelEM private ( case None => { val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) - ((0 until k).map(_ => 1.0 / k).toArray, (0 until k).map{ i => + (Array.fill[Double](k)(1.0 / k), (0 until k).map{ i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }.toArray)