diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index ebcee86acedfd..7a70117f36d4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -17,118 +17,103 @@ package org.apache.spark.mllib.evaluation +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.SparkContext._ /** * Evaluator for multiclass classification. - * NB: type Double both for prediction and label is retained - * for compatibility with model.predict that returns Double - * and MLUtils.loadLibSVMFile that loads class labels as Double * * @param predictionsAndLabels an RDD of (prediction, label) pairs. */ +@Experimental class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging { - /* class = category; label = instance of class; prediction = instance of class */ - private lazy val labelCountByClass = predictionsAndLabels.values.countByValue() - private lazy val labelCount = labelCountByClass.foldLeft(0L){case(sum, (_, count)) => sum + count} - private lazy val tpByClass = predictionsAndLabels.map{ case (prediction, label) => - (label, if(label == prediction) 1 else 0) }.reduceByKey{_ + _}.collectAsMap - private lazy val fpByClass = predictionsAndLabels.map{ case (prediction, label) => - (prediction, if(prediction != label) 1 else 0) }.reduceByKey{_ + _}.collectAsMap + private lazy val labelCount = labelCountByClass.values.sum + private lazy val tpByClass = predictionsAndLabels + .map{ case (prediction, label) => + (label, if (label == prediction) 1 else 0) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val fpByClass = predictionsAndLabels + .map{ case (prediction, label) => + (prediction, if (prediction != label) 1 else 0) + }.reduceByKey(_ + _) + .collectAsMap() /** - * Returns Precision for a given label (category) + * Returns precision for a given label (category) * @param label the label. - * @return Precision. */ - def precision(label: Double): Double = if(tpByClass(label) + fpByClass.getOrElse(label, 0) == 0) 0 - else tpByClass(label).toDouble / (tpByClass(label) + fpByClass.getOrElse(label, 0)).toDouble + def precision(label: Double): Double = { + val tp = tpByClass(label) + val fp = fpByClass.getOrElse(label, 0) + if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + } /** - * Returns Recall for a given label (category) + * Returns recall for a given label (category) * @param label the label. - * @return Recall. */ - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label).toDouble + def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) /** - * Returns F1-measure for a given label (category) + * Returns f-measure for a given label (category) * @param label the label. - * @return F1-measure. */ - def f1Measure(label: Double): Double ={ + def fMeasure(label: Double, beta:Double = 1.0): Double = { val p = precision(label) val r = recall(label) - if((p + r) == 0) 0 else 2 * p * r / (p + r) + val betaSqrd = beta * beta + if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r) } /** - * Returns micro-averaged Recall + * Returns micro-averaged recall * (equals to microPrecision and microF1measure for multiclass classifier) - * @return microRecall. */ - lazy val microRecall: Double = - tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount + lazy val recall: Double = + tpByClass.values.sum.toDouble / labelCount /** - * Returns micro-averaged Precision + * Returns micro-averaged precision * (equals to microPrecision and microF1measure for multiclass classifier) - * @return microPrecision. */ - lazy val microPrecision: Double = microRecall + lazy val precision: Double = recall /** - * Returns micro-averaged F1-measure + * Returns micro-averaged f-measure * (equals to microPrecision and microRecall for multiclass classifier) - * @return microF1measure. */ - lazy val microF1Measure: Double = microRecall + lazy val fMeasure: Double = recall /** - * Returns weighted averaged Recall - * @return weightedRecall. + * Returns weighted averaged recall + * (equals to micro-averaged precision, recall and f-measure) */ - lazy val weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) => - wRecall + recall(category) * count.toDouble / labelCount} + lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => + recall(category) * count.toDouble / labelCount + }.sum /** - * Returns weighted averaged Precision - * @return weightedPrecision. + * Returns weighted averaged precision */ - lazy val weightedPrecision: Double = - labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) => - wPrecision + precision(category) * count.toDouble / labelCount} + lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => + precision(category) * count.toDouble / labelCount + }.sum /** - * Returns weighted averaged F1-measure - * @return weightedF1Measure. + * Returns weighted averaged f1-measure */ - lazy val weightedF1Measure: Double = - labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) => - wF1measure + f1Measure(category) * count.toDouble / labelCount} + lazy val weightedF1Measure: Double = labelCountByClass.map { case (category, count) => + fMeasure(category) * count.toDouble / labelCount + }.sum /** - * Returns map with Precisions for individual classes - * @return precisionPerClass. + * Returns the sequence of labels in ascending order */ - lazy val precisionPerClass = - labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap + lazy val labels = tpByClass.unzip._1.toSeq.sorted - /** - * Returns map with Recalls for individual classes - * @return recallPerClass. - */ - lazy val recallPerClass = - labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap - - /** - * Returns map with F1-measures for individual classes - * @return f1MeasurePerClass. - */ - lazy val f1MeasurePerClass = - labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index b4e3664ab7916..4b959b2d542ac 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.mllib.util.LocalSparkContext import org.scalatest.FunSuite +import org.apache.spark.mllib.util.LocalSparkContext + class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { test("Multiclass evaluation metrics") { /* @@ -29,12 +30,12 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { * |0|0|1| true class2 (1 instance) * */ + val labels = Seq(0.0, 1.0, 2.0) val scoreAndLabels = sc.parallelize( Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(scoreAndLabels) - - val delta = 0.00001 + val delta = 0.0000001 val precision0 = 2.0 / (2.0 + 1.0) val precision1 = 3.0 / (3.0 + 1.0) val precision2 = 1.0 / (1.0 + 1.0) @@ -44,28 +45,26 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) assert(math.abs(metrics.precision(1.0) - precision1) < delta) assert(math.abs(metrics.precision(2.0) - precision2) < delta) assert(math.abs(metrics.recall(0.0) - recall0) < delta) assert(math.abs(metrics.recall(1.0) - recall1) < delta) assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta) - - assert(math.abs(metrics.microRecall - + assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) + assert(math.abs(metrics.recall - (2.0 + 3.0 + 1.0) / ((2.0 + 3.0 + 1.0) + (1.0 + 1.0 + 1.0))) < delta) - assert(math.abs(metrics.microRecall - metrics.microPrecision) < delta) - assert(math.abs(metrics.microRecall - metrics.microF1Measure) < delta) - assert(math.abs(metrics.microRecall - metrics.weightedRecall) < delta) + assert(math.abs(metrics.recall - metrics.precision) < delta) + assert(math.abs(metrics.recall - metrics.fMeasure) < delta) + assert(math.abs(metrics.recall - metrics.weightedRecall) < delta) assert(math.abs(metrics.weightedPrecision - ((4.0 / 9.0) * precision0 + (4.0 / 9.0) * precision1 + (1.0 / 9.0) * precision2)) < delta) assert(math.abs(metrics.weightedRecall - ((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta) assert(math.abs(metrics.weightedF1Measure - ((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta) - + assert(metrics.labels == labels) } }