diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index ebcee86acedfd..7a70117f36d4c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -17,118 +17,103 @@
 
 package org.apache.spark.mllib.evaluation
 
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 
 /**
  * Evaluator for multiclass classification.
- * NB: type Double both for prediction and label is retained
- * for compatibility with model.predict that returns Double
- * and MLUtils.loadLibSVMFile that loads class labels as Double
  *
  * @param predictionsAndLabels an RDD of (prediction, label) pairs.
  */
+@Experimental
 class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging {
 
-  /* class = category; label = instance of class; prediction = instance of class */
-
   private lazy val labelCountByClass = predictionsAndLabels.values.countByValue()
-  private lazy val labelCount = labelCountByClass.foldLeft(0L){case(sum, (_, count)) => sum + count}
-  private lazy val tpByClass = predictionsAndLabels.map{ case (prediction, label) =>
-    (label, if(label == prediction) 1 else 0) }.reduceByKey{_ + _}.collectAsMap
-  private lazy val fpByClass = predictionsAndLabels.map{ case (prediction, label) =>
-    (prediction, if(prediction != label) 1 else 0) }.reduceByKey{_ + _}.collectAsMap
+  private lazy val labelCount = labelCountByClass.values.sum
+  private lazy val tpByClass = predictionsAndLabels
+    .map { case (prediction, label) =>
+      (label, if (label == prediction) 1 else 0)
+    }.reduceByKey(_ + _)
+    .collectAsMap()
+  private lazy val fpByClass = predictionsAndLabels
+    .map { case (prediction, label) =>
+      (prediction, if (prediction != label) 1 else 0)
+    }.reduceByKey(_ + _)
+    .collectAsMap()
 
   /**
-   * Returns Precision for a given label (category)
+   * Returns precision for a given label (category)
    * @param label the label.
-   * @return Precision.
    */
-  def precision(label: Double): Double = if(tpByClass(label) + fpByClass.getOrElse(label, 0) == 0) 0
-    else tpByClass(label).toDouble / (tpByClass(label) + fpByClass.getOrElse(label, 0)).toDouble
+  def precision(label: Double): Double = {
+    val tp = tpByClass(label)
+    val fp = fpByClass.getOrElse(label, 0)
+    if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+  }
 
   /**
-   * Returns Recall for a given label (category)
+   * Returns recall for a given label (category)
    * @param label the label.
-   * @return Recall.
    */
-  def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label).toDouble
+  def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label)
 
   /**
-   * Returns F1-measure for a given label (category)
+   * Returns f-measure for a given label (category)
    * @param label the label.
-   * @return F1-measure.
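+   * @param beta the beta parameter of the F-measure; defaults to 1.0 (the F1-measure).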
    */
-  def f1Measure(label: Double): Double ={
+  def fMeasure(label: Double, beta: Double = 1.0): Double = {
     val p = precision(label)
     val r = recall(label)
-    if((p + r) == 0) 0 else 2 * p * r / (p + r)
+    val betaSqrd = beta * beta
+    if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r)
   }
 
   /**
-   * Returns micro-averaged Recall
-   * (equals to microPrecision and microF1measure for multiclass classifier)
-   * @return microRecall.
+   * Returns micro-averaged recall
+   * (equals to precision and fMeasure for the multiclass classifier)
    */
-  lazy val microRecall: Double =
-    tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount
+  lazy val recall: Double = tpByClass.values.sum.toDouble / labelCount
 
   /**
-   * Returns micro-averaged Precision
-   * (equals to microPrecision and microF1measure for multiclass classifier)
-   * @return microPrecision.
+   * Returns micro-averaged precision
+   * (equals to recall and fMeasure for the multiclass classifier)
    */
-  lazy val microPrecision: Double = microRecall
+  lazy val precision: Double = recall
 
   /**
-   * Returns micro-averaged F1-measure
-   * (equals to microPrecision and microRecall for multiclass classifier)
-   * @return microF1measure.
+   * Returns micro-averaged f-measure
+   * (equals to precision and recall for the multiclass classifier)
    */
-  lazy val microF1Measure: Double = microRecall
+  lazy val fMeasure: Double = recall
 
   /**
-   * Returns weighted averaged Recall
-   * @return weightedRecall.
+   * Returns weighted averaged recall
+   * (equals to micro-averaged precision, recall and f-measure)
    */
-  lazy val weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) =>
-    wRecall + recall(category) * count.toDouble / labelCount}
+  lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) =>
+    recall(category) * count.toDouble / labelCount
+  }.sum
 
   /**
-   * Returns weighted averaged Precision
-   * @return weightedPrecision.
+   * Returns weighted averaged precision
    */
-  lazy val weightedPrecision: Double =
-    labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) =>
-    wPrecision + precision(category) * count.toDouble / labelCount}
+  lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) =>
+    precision(category) * count.toDouble / labelCount
+  }.sum
 
   /**
-   * Returns weighted averaged F1-measure
-   * @return weightedF1Measure.
+   * Returns weighted averaged f1-measure
    */
-  lazy val weightedF1Measure: Double =
-    labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) =>
-    wF1measure + f1Measure(category) * count.toDouble / labelCount}
+  lazy val weightedF1Measure: Double = labelCountByClass.map { case (category, count) =>
+    fMeasure(category) * count.toDouble / labelCount
+  }.sum
 
   /**
-   * Returns map with Precisions for individual classes
-   * @return precisionPerClass.
+   * Returns the sequence of labels in ascending order
    */
-  lazy val precisionPerClass =
-    labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap
+  lazy val labels: Seq[Double] = tpByClass.keys.toSeq.sorted
 
-  /**
-   * Returns map with Recalls for individual classes
-   * @return recallPerClass.
-   */
-  lazy val recallPerClass =
-    labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap
-
-  /**
-   * Returns map with F1-measures for individual classes
-   * @return f1MeasurePerClass.
-   */
-  lazy val f1MeasurePerClass =
-    labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap
 }
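
A minimal usage sketch of the renamed API above (illustrative only; the sample predictions and the `sc` handle to a live SparkContext are assumptions made up for this note, not part of the patch):

    import org.apache.spark.mllib.evaluation.MulticlassMetrics

    val predictionAndLabels = sc.parallelize(Seq(
      (0.0, 0.0), (1.0, 1.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)))
    val metrics = new MulticlassMetrics(predictionAndLabels)

    metrics.precision(1.0)      // per-class precision
    metrics.recall(1.0)         // per-class recall
    metrics.fMeasure(1.0, 0.5)  // per-class F-beta; beta defaults to 1.0
    metrics.recall              // micro-averaged (== metrics.precision, metrics.fMeasure)
    metrics.weightedRecall      // class-frequency-weighted average of per-class recall
    metrics.labels              // labels in ascending order
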
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index b4e3664ab7916..4b959b2d542ac 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.evaluation
 
-import org.apache.spark.mllib.util.LocalSparkContext
 import org.scalatest.FunSuite
 
+import org.apache.spark.mllib.util.LocalSparkContext
+
 class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
   test("Multiclass evaluation metrics") {
     /*
@@ -29,12 +30,12 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
     * |0|0|1| true class2 (1 instance)
     *
     */
+    val labels = Seq(0.0, 1.0, 2.0)
     val scoreAndLabels = sc.parallelize(
       Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
         (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
     val metrics = new MulticlassMetrics(scoreAndLabels)
-
-    val delta = 0.00001
+    val delta = 0.0000001
     val precision0 = 2.0 / (2.0 + 1.0)
     val precision1 = 3.0 / (3.0 + 1.0)
     val precision2 = 1.0 / (1.0 + 1.0)
@@ -44,28 +45,26 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
     val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
     val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
     val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
-
     assert(math.abs(metrics.precision(0.0) - precision0) < delta)
     assert(math.abs(metrics.precision(1.0) - precision1) < delta)
     assert(math.abs(metrics.precision(2.0) - precision2) < delta)
     assert(math.abs(metrics.recall(0.0) - recall0) < delta)
     assert(math.abs(metrics.recall(1.0) - recall1) < delta)
     assert(math.abs(metrics.recall(2.0) - recall2) < delta)
-    assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
-    assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
-    assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
-
-    assert(math.abs(metrics.microRecall -
+    assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
+    assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
+    assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
+    assert(math.abs(metrics.recall -
       (2.0 + 3.0 + 1.0) / ((2.0 + 3.0 + 1.0) + (1.0 + 1.0 + 1.0))) < delta)
-    assert(math.abs(metrics.microRecall - metrics.microPrecision) < delta)
-    assert(math.abs(metrics.microRecall - metrics.microF1Measure) < delta)
-    assert(math.abs(metrics.microRecall - metrics.weightedRecall) < delta)
+    assert(math.abs(metrics.recall - metrics.precision) < delta)
+    assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
+    assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
     assert(math.abs(metrics.weightedPrecision -
       ((4.0 / 9.0) * precision0 + (4.0 / 9.0) * precision1 + (1.0 / 9.0) * precision2)) < delta)
     assert(math.abs(metrics.weightedRecall -
       ((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta)
     assert(math.abs(metrics.weightedF1Measure -
       ((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta)
-
+    assert(metrics.labels == labels)
   }
 }
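
For reference, `fMeasure(label, beta)` with the default `beta = 1.0` reproduces the old `f1Measure`. A quick sketch of the formula against label 0.0 of the confusion matrix in the test above (the beta = 0.5 value is computed only for illustration and is not asserted by the suite):

    // F_beta = (1 + beta^2) * p * r / (beta^2 * p + r); beta = 1 gives 2 * p * r / (p + r)
    val p0 = 2.0 / (2.0 + 1.0)                      // precision for label 0.0 = 0.666...
    val r0 = 2.0 / 4.0                              // recall for label 0.0 = 0.5
    val f1 = 2 * p0 * r0 / (p0 + r0)                // ~= 0.5714, matches f1measure0
    val fHalf = 1.25 * p0 * r0 / (0.25 * p0 + r0)   // beta = 0.5 weights precision more, ~= 0.625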