diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index c9d162fbbb581..f8865c083a23b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -176,23 +176,18 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) .map(_.headOption.getOrElse(Double.NaN)) case Imputer.mode => + // Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode). + // If there is more than one mode, choose the smallest one. val modes = dataset.select(cols: _*).rdd.flatMap { row => Iterator.range(0, numCols).flatMap { i => // Ignore null. if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, row.getDouble(i)), 1L) } - }.reduceByKey(_ + _).map { case ((i, v), c) => (i, (v, c)) - }.reduceByKey { case ((v1, c1), (v2, c2)) => - if (c1 > c2) { - (v1, c1) - } else if (c1 < c2) { - (v2, c2) - } else { - // Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode). - // If there is more than one mode, choose the smallest one. - (math.min(v1, v2), c1) - } - }.mapValues(_._1).collectAsMap() + }.reduceByKey(_ + _).map { case ((i, v), c) => + // negative value to apply the default ranking of [Long, Double] + (i, (c, -v)) + }.reduceByKey(Ordering.apply[(Long, Double)].max + ).mapValues(-_._2).collectAsMap() Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN)) }