Skip to content

Commit

Permalink
fix code styles
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 9, 2014
1 parent 9dc3518 commit b1b7dab
Showing 1 changed file with 18 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ private case class BinaryConfusionMatrixImpl(
*
* @param scoreAndLabels an RDD of (score, label) pairs.
*/
class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)]) extends Serializable with Logging {
class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)])
extends Serializable with Logging {

private lazy val (
cumCounts: RDD[(Double, LabelCounter)],
Expand All @@ -73,16 +74,18 @@ class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)]) exten
iter.foreach(agg += _)
Iterator(agg)
}, preservesPartitioning = true).collect()
val partitionwiseCumCounts = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
val partitionwiseCumCounts =
agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c)
val totalCount = partitionwiseCumCounts.last
logInfo(s"Total counts: $totalCount")
val cumCounts = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, LabelCounter)]) => {
val cumCount = partitionwiseCumCounts(index)
iter.map { case (score, c) =>
cumCount += c
(score, cumCount.clone())
}
}, preservesPartitioning = true)
val cumCounts = counts.mapPartitionsWithIndex(
(index: Int, iter: Iterator[(Double, LabelCounter)]) => {
val cumCount = partitionwiseCumCounts(index)
iter.map { case (score, c) =>
cumCount += c
(score, cumCount.clone())
}
}, preservesPartitioning = true)
cumCounts.persist()
val confusions = cumCounts.map { case (score, cumCount) =>
(score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
Expand Down Expand Up @@ -136,7 +139,9 @@ class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)]) exten
}

/** Creates a curve of (metricX, metricY). */
private def createCurve(x: BinaryClassificationMetric, y: BinaryClassificationMetric): RDD[(Double, Double)] = {
private def createCurve(
x: BinaryClassificationMetric,
y: BinaryClassificationMetric): RDD[(Double, Double)] = {
confusions.map { case (_, c) =>
(x(c), y(c))
}
Expand All @@ -149,7 +154,9 @@ class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)]) exten
* @param numPositives number of positive labels
* @param numNegatives number of negative labels
*/
private class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
private class LabelCounter(
var numPositives: Long = 0L,
var numNegatives: Long = 0L) extends Serializable {

/** Processes a label. */
def +=(label: Double): LabelCounter = {
Expand All @@ -166,11 +173,6 @@ private class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long =
this
}

/** Sums this counter and another counter and returns the result in a new counter. */
def +(other: LabelCounter): LabelCounter = {
this.clone() += other
}

override def clone: LabelCounter = {
new LabelCounter(numPositives, numNegatives)
}
Expand Down

0 comments on commit b1b7dab

Please sign in to comment.