From 900b5864c16cc0db93a46ec3a4591a787e5a21a0 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Thu, 5 Mar 2015 11:53:46 -0800 Subject: [PATCH] fixed model call so that uses type argument --- .../apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++---- .../spark/mllib/classification/NaiveBayesSuite.scala | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 4e1c6a63fb01c..bcf5acdada671 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -310,10 +310,10 @@ object NaiveBayes { * * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]]) * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle - * discrete count data and can be called by setting the model type to "Multinomial". + * discrete count data and can be called by setting the model type to "multinomial". * For example, it can be used with word counts or TF_IDF vectors of documents. * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a - * 0-1 vector and setting the model type to "Bernoulli", the fits and predicts as + * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as * Bernoulli NB. * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency @@ -321,10 +321,10 @@ object NaiveBayes { * @param lambda The smoothing parameter * * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be - * Multinomial or Bernoulli + * multinomial or bernoulli */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { - new NaiveBayes(lambda, Multinomial).run(input) + new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 0874bb0b90ce4..f7310bef2bc9b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -124,7 +124,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + val model = NaiveBayes.train(testRDD, 1.0, "multinomial") validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( @@ -161,7 +161,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + val model = NaiveBayes.train(testRDD, 1.0, "bernoulli") validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput(