From 6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464 Mon Sep 17 00:00:00 2001 From: Travis Galoppo Date: Mon, 29 Dec 2014 15:29:15 -0800 Subject: [PATCH] SPARK-4156 [MLLIB] EM algorithm for GMMs Implementation of Expectation-Maximization for Gaussian Mixture Models. This is my maiden contribution to Apache Spark, so I apologize now if I have done anything incorrectly; having said that, this work is my own, and I offer it to the project under the project's open source license. Author: Travis Galoppo Author: Travis Galoppo Author: tgaloppo Author: FlytxtRnD Closes #3022 from tgaloppo/master and squashes the following commits: aaa8f25 [Travis Galoppo] MLUtils: changed privacy of EPSILON from [util] to [mllib] 709e4bf [Travis Galoppo] fixed usage line to include optional maxIterations parameter acf1fba [Travis Galoppo] Fixed parameter comment in GaussianMixtureModel Made maximum iterations an optional parameter to DenseGmmEM 9b2fc2a [Travis Galoppo] Style improvements Changed ExpectationSum to a private class b97fe00 [Travis Galoppo] Minor fixes and tweaks. 1de73f3 [Travis Galoppo] Removed redundant array from array creation 578c2d1 [Travis Galoppo] Removed unused import 227ad66 [Travis Galoppo] Moved prediction methods into model class. 308c8ad [Travis Galoppo] Numerous changes to improve code cff73e0 [Travis Galoppo] Replaced accumulators with RDD.aggregate 20ebca1 [Travis Galoppo] Removed unusued code 42b2142 [Travis Galoppo] Added functionality to allow setting of GMM starting point. Added two cluster test to testing suite. 8b633f3 [Travis Galoppo] Style issue 9be2534 [Travis Galoppo] Style issue d695034 [Travis Galoppo] Fixed style issues c3b8ce0 [Travis Galoppo] Merge branch 'master' of https://github.com/tgaloppo/spark Adds predict() method 2df336b [Travis Galoppo] Fixed style issue b99ecc4 [tgaloppo] Merge pull request #1 from FlytxtRnD/predictBranch f407b4c [FlytxtRnD] Added predict() to return the cluster labels and membership values 97044cf [Travis Galoppo] Fixed style issues dc9c742 [Travis Galoppo] Moved MultivariateGaussian utility class e7d413b [Travis Galoppo] Moved multivariate Gaussian utility class to mllib/stat/impl Improved comments 9770261 [Travis Galoppo] Corrected a variety of style and naming issues. 8aaa17d [Travis Galoppo] Added additional train() method to companion object for cluster count and tolerance parameters. 676e523 [Travis Galoppo] Fixed to no longer ignore delta value provided on command line e6ea805 [Travis Galoppo] Merged with master branch; update test suite with latest context changes. Improved cluster initialization strategy. 86fb382 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' 719d8cc [Travis Galoppo] Added scala test suite with basic test c1a8e16 [Travis Galoppo] Made GaussianMixtureModel class serializable Modified sum function for better performance 5c96c57 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' c15405c [Travis Galoppo] SPARK-4156 --- .../spark/examples/mllib/DenseGmmEM.scala | 67 +++++ .../mllib/clustering/GaussianMixtureEM.scala | 241 ++++++++++++++++++ .../clustering/GaussianMixtureModel.scala | 91 +++++++ .../stat/impl/MultivariateGaussian.scala | 39 +++ .../org/apache/spark/mllib/util/MLUtils.scala | 2 +- .../GMMExpectationMaximizationSuite.scala | 78 ++++++ 6 files changed, 517 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala new file mode 100644 index 0000000000000..948c350953e27 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.clustering.GaussianMixtureEM +import org.apache.spark.mllib.linalg.Vectors + +/** + * An example Gaussian Mixture Model EM app. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DenseGmmEM { + def main(args: Array[String]): Unit = { + if (args.length < 3) { + println("usage: DenseGmmEM [maxIterations]") + } else { + val maxIterations = if (args.length > 3) args(3).toInt else 100 + run(args(0), args(1).toInt, args(2).toDouble, maxIterations) + } + } + + private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { + val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") + val ctx = new SparkContext(conf) + + val data = ctx.textFile(inputFile).map { line => + Vectors.dense(line.trim.split(' ').map(_.toDouble)) + }.cache() + + val clusters = new GaussianMixtureEM() + .setK(k) + .setConvergenceTol(convergenceTol) + .setMaxIterations(maxIterations) + .run(data) + + for (i <- 0 until clusters.k) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) + } + + println("Cluster labels (first <= 100):") + val clusterLabels = clusters.predict(data) + clusterLabels.take(100).foreach { x => + print(" " + x) + } + println() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala new file mode 100644 index 0000000000000..bdf984aee4dae --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable.IndexedSeq + +import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} +import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.util.MLUtils + +/** + * This class performs expectation maximization for multivariate Gaussian + * Mixture Models (GMMs). A GMM represents a composite distribution of + * independent Gaussian distributions with associated "mixing" weights + * specifying each's contribution to the composite. + * + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * less than convergenceTol, or until it has reached the max number of iterations. + * While this process is generally guaranteed to converge, it is not guaranteed + * to find a global optimum. + * + * @param k The number of independent Gaussians in the mixture model + * @param convergenceTol The maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations The maximum number of iterations to perform + */ +class GaussianMixtureEM private ( + private var k: Int, + private var convergenceTol: Double, + private var maxIterations: Int) extends Serializable { + + /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ + def this() = this(2, 0.01, 100) + + // number of samples per cluster to use when initializing Gaussians + private val nSamples = 5 + + // an initializing GMM can be provided rather than using the + // default random starting point + private var initialModel: Option[GaussianMixtureModel] = None + + /** Set the initial GMM starting point, bypassing the random initialization. + * You must call setK() prior to calling this method, and the condition + * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + */ + def setInitialModel(model: GaussianMixtureModel): this.type = { + if (model.k == k) { + initialModel = Some(model) + } else { + throw new IllegalArgumentException("mismatched cluster count (model.k != k)") + } + this + } + + /** Return the user supplied initial GMM, if supplied */ + def getInitialModel: Option[GaussianMixtureModel] = initialModel + + /** Set the number of Gaussians in the mixture model. Default: 2 */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** Return the number of Gaussians in the mixture model */ + def getK: Int = k + + /** Set the maximum number of iterations to run. Default: 100 */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** Return the maximum number of iterations to run */ + def getMaxIterations: Int = maxIterations + + /** + * Set the largest change in log-likelihood at which convergence is + * considered to have occurred. + */ + def setConvergenceTol(convergenceTol: Double): this.type = { + this.convergenceTol = convergenceTol + this + } + + /** Return the largest change in log-likelihood at which convergence is + * considered to have occurred. + */ + def getConvergenceTol: Double = convergenceTol + + /** Perform expectation maximization */ + def run(data: RDD[Vector]): GaussianMixtureModel = { + val sc = data.sparkContext + + // we will operate on the data as breeze data + val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() + + // Get length of the input vectors + val d = breezeData.first.length + + // Determine initial weights and corresponding Gaussians. + // If the user supplied an initial GMM, we use those values, otherwise + // we start with uniform weights, a random mean from the data, and + // diagonal covariance matrices using component variances + // derived from the samples + val (weights, gaussians) = initialModel match { + case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => + new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix) + }) + + case None => { + val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + val slice = samples.view(i * nSamples, (i + 1) * nSamples) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + }) + } + } + + var llh = Double.MinValue // current log-likelihood + var llhp = 0.0 // previous log-likelihood + + var iter = 0 + while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) { + // create and broadcast curried cluster contribution function + val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) + + // aggregate the cluster contribution for all sample points + val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) + + // Create new distributions based on the partial assignments + // (often referred to as the "M" step in literature) + val sumWeights = sums.weights.sum + var i = 0 + while (i < k) { + val mu = sums.means(i) / sums.weights(i) + val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr + weights(i) = sums.weights(i) / sumWeights + gaussians(i) = new MultivariateGaussian(mu, sigma) + i = i + 1 + } + + llhp = llh // current becomes previous + llh = sums.logLikelihood // this is the freshly computed log-likelihood + iter += 1 + } + + // Need to convert the breeze matrices to MLlib matrices + val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) } + val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) } + new GaussianMixtureModel(weights, means, sigmas) + } + + /** Average of dense breeze vectors */ + private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = { + val v = BreezeVector.zeros[Double](x(0).length) + x.foreach(xi => v += xi) + v / x.length.toDouble + } + + /** + * Construct matrix where diagonal entries are element-wise + * variance of input vectors (computes biased variance) + */ + private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = { + val mu = vectorMean(x) + val ss = BreezeVector.zeros[Double](x(0).length) + x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u) + diag(ss / x.length.toDouble) + } +} + +// companion class to provide zero constructor for ExpectationSum +private object ExpectationSum { + def zero(k: Int, d: Int): ExpectationSum = { + new ExpectationSum(0.0, Array.fill(k)(0.0), + Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + } + + // compute cluster contributions for each input point + // (U, T) => U for aggregation + def add( + weights: Array[Double], + dists: Array[MultivariateGaussian]) + (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = { + val p = weights.zip(dists).map { + case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x) + } + val pSum = p.sum + sums.logLikelihood += math.log(pSum) + val xxt = x * new Transpose(x) + var i = 0 + while (i < sums.k) { + p(i) /= pSum + sums.weights(i) += p(i) + sums.means(i) += x * p(i) + sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr + i = i + 1 + } + sums + } +} + +// Aggregation class for partial expectation results +private class ExpectationSum( + var logLikelihood: Double, + val weights: Array[Double], + val means: Array[BreezeVector[Double]], + val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { + + val k = weights.length + + def +=(x: ExpectationSum): ExpectationSum = { + var i = 0 + while (i < k) { + weights(i) += x.weights(i) + means(i) += x.means(i) + sigmas(i) += x.sigmas(i) + i = i + 1 + } + logLikelihood += x.logLikelihood + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala new file mode 100644 index 0000000000000..11a110db1f7ca --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseVector => BreezeVector} + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.util.MLUtils + +/** + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * + * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is + * the weight for Gaussian i, and weight.sum == 1 + * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i + * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the + * covariance matrix for Gaussian i + */ +class GaussianMixtureModel( + val weight: Array[Double], + val mu: Array[Vector], + val sigma: Array[Matrix]) extends Serializable { + + /** Number of gaussians in mixture */ + def k: Int = weight.length + + /** Maps given points to their cluster indices. */ + def predict(points: RDD[Vector]): RDD[Int] = { + val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k) + responsibilityMatrix.map(r => r.indexOf(r.max)) + } + + /** + * Given the input vectors, return the membership value of each vector + * to all mixture components. + */ + def predictMembership( + points: RDD[Vector], + mu: Array[Vector], + sigma: Array[Matrix], + weight: Array[Double], + k: Int): RDD[Array[Double]] = { + val sc = points.sparkContext + val dists = sc.broadcast { + (0 until k).map { i => + new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) + }.toArray + } + val weights = sc.broadcast(weight) + points.map { x => + computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) + } + } + + /** + * Compute the partial assignments for each vector + */ + private def computeSoftAssignments( + pt: BreezeVector[Double], + dists: Array[MultivariateGaussian], + weights: Array[Double], + k: Int): Array[Double] = { + val p = weights.zip(dists).map { + case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) + } + val pSum = p.sum + for (i <- 0 until k) { + p(i) /= pSum + } + p + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala new file mode 100644 index 0000000000000..2eab5d277827d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.impl + +import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, Transpose, det, pinv} + +/** + * Utility class to implement the density function for multivariate Gaussian distribution. + * Breeze provides this functionality, but it requires the Apache Commons Math library, + * so this class is here so-as to not introduce a new dependency in Spark. + */ +private[mllib] class MultivariateGaussian( + val mu: DBV[Double], + val sigma: DBM[Double]) extends Serializable { + private val sigmaInv2 = pinv(sigma) * -0.5 + private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5) + + /** Returns density of this multivariate Gaussian at given point, x */ + def pdf(x: DBV[Double]): Double = { + val delta = x - mu + val deltaTranspose = new Transpose(delta) + U * math.exp(deltaTranspose * sigmaInv2 * delta) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index b0d05ae33e1b5..1d07b5dab8268 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.dstream.DStream */ object MLUtils { - private[util] lazy val EPSILON = { + private[mllib] lazy val EPSILON = { var eps = 1.0 while ((1.0 + (eps / 2.0)) != 1.0) { eps /= 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala new file mode 100644 index 0000000000000..23feb82874b70 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext { + test("single cluster") { + val data = sc.parallelize(Array( + Vectors.dense(6.0, 9.0), + Vectors.dense(5.0, 10.0), + Vectors.dense(4.0, 11.0) + )) + + // expectations + val Ew = 1.0 + val Emu = Vectors.dense(5.0, 10.0) + val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) + + val gmm = new GaussianMixtureEM().setK(1).run(data) + + assert(gmm.weight(0) ~== Ew absTol 1E-5) + assert(gmm.mu(0) ~== Emu absTol 1E-5) + assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + } + + test("two clusters") { + val data = sc.parallelize(Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + )) + + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array(Vectors.dense(-1.0), Vectors.dense(1.0)), + Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0))) + ) + + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val gmm = new GaussianMixtureEM() + .setK(2) + .setInitialModel(initialGmm) + .run(data) + + assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) + assert(gmm.weight(1) ~== Ew(1) absTol 1E-3) + assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) + assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) + assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) + assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) + } +}