Skip to content

Commit

Permalink
Add Scala Naive Bayes example, to use existing example data file (who…
Browse files Browse the repository at this point in the history
…se format needed a tweak)
  • Loading branch information
srowen committed May 6, 2014
1 parent 8c81982 commit 23c9ac3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
17 changes: 13 additions & 4 deletions docs/mllib-naive-bayes.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ can be used for evaluation and prediction.

{% highlight scala %}
import org.apache.spark.mllib.classification.NaiveBayes

val training: RDD[LabeledPoint] = ... // training set
val test: RDD[LabeledPoint] = ... // test set
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.textFile("mllib/data/sample_naive_bayes_data.txt")
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}
// Split data into training (60%) and test (40%).
val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)

val model = NaiveBayes.train(training, lambda = 1.0)
val prediction = model.predict(test.map(_.features))
Expand Down Expand Up @@ -69,7 +78,7 @@ import scala.Tuple2;
JavaRDD<LabeledPoint> training = ... // training set
JavaRDD<LabeledPoint> test = ... // test set

NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);

JavaRDD<Double> prediction =
test.map(new Function<LabeledPoint, Double>() {
Expand Down
12 changes: 6 additions & 6 deletions mllib/data/sample_naive_bayes_data.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
0, 1 0 0
0, 2 0 0
1, 0 1 0
1, 0 2 0
2, 0 0 1
2, 0 0 2
0,1 0 0
0,2 0 0
1,0 1 0
1,0 2 0
2,0 0 1
2,0 0 2

0 comments on commit 23c9ac3

Please sign in to comment.