Skip to content

Commit

Permalink
rdd -> df
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Nov 17, 2020
1 parent 91ae454 commit e0605d6
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,20 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
.map(_.headOption.getOrElse(Double.NaN))

case Imputer.mode =>
// Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode).
// If there is more than one mode, choose the smallest one.
val modes = dataset.select(cols: _*).rdd.flatMap { row =>
import spark.implicits._
// If there is more than one mode, choose the smallest one to keep in line
// with sklearn.impute.SimpleImputer (using scipy.stats.mode).
val modes = dataset.select(cols: _*).flatMap { row =>
Iterator.range(0, numCols).flatMap { i =>
// Ignore null.
if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, row.getDouble(i)), 1L)
// negative value to apply the default ranking of [Long, Double]
if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, -row.getDouble(i)))
}
}.reduceByKey(_ + _).map { case ((i, v), c) =>
// negative value to apply the default ranking of [Long, Double]
(i, (c, -v))
}.reduceByKey(Ordering.apply[(Long, Double)].max
).mapValues(-_._2).collectAsMap()
}.toDF("index", "negative_value")
.groupBy("index", "negative_value").agg(count(lit(0)).as("count"))
.groupBy("index").agg(max(struct("count", "negative_value")).as("mode"))
.select(col("index"), negate(col("mode.negative_value")))
.as[(Int, Double)].collect().toMap
Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN))
}

Expand Down

0 comments on commit e0605d6

Please sign in to comment.