Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-33466][ML][PYTHON] Imputer support mode(most_frequent) strategy #30397

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
* The imputation strategy. Currently only "mean" and "median" are supported.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
* If "mode", then replace missing using the most frequent value of the feature.
* Default: mean
*
* @group param
*/
final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))
s"If ${Imputer.median}, then replace missing values using the median value of the feature. " +
s"If ${Imputer.mode}, then replace missing values using the most frequent value of " +
s"the feature.", ParamValidators.inArray[String](Imputer.supportedStrategies))

/** @group getParam */
def getStrategy: String = $(strategy)
Expand Down Expand Up @@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
* For example, if the input column is IntegerType (1, 2, 4, null),
* the output will be IntegerType (1, 2, 4, 2) after mean imputation.
*
* Note that the mean/median value is computed after filtering out missing values.
* Note that the mean/median/mode value is computed after filtering out missing values.
* All Null values in the input columns are treated as missing, and so are also imputed. For
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
*/
Expand Down Expand Up @@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
* Imputation strategy. Available options are ["mean", "median"].
* Imputation strategy. Available options are ["mean", "median", "mode"].
* @group setParam
*/
@Since("2.2.0")
Expand All @@ -151,39 +153,44 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
val spark = dataset.sparkSession

val (inputColumns, _) = getInOutCols()

val cols = inputColumns.map { inputCol =>
when(col(inputCol).equalTo($(missingValue)), null)
.when(col(inputCol).isNaN, null)
.otherwise(col(inputCol))
.cast("double")
.cast(DoubleType)
.as(inputCol)
}
val numCols = cols.length

val results = $(strategy) match {
case Imputer.mean =>
// Function avg will ignore null automatically.
// For a column only containing null, avg will return null.
val row = dataset.select(cols.map(avg): _*).head()
Array.range(0, inputColumns.length).map { i =>
if (row.isNullAt(i)) {
Double.NaN
} else {
row.getDouble(i)
}
}
Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i))

case Imputer.median =>
// Function approxQuantile will ignore null automatically.
// For a column only containing null, approxQuantile will return an empty array.
dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError))
.map { array =>
if (array.isEmpty) {
Double.NaN
} else {
array.head
}
.map(_.headOption.getOrElse(Double.NaN))

case Imputer.mode =>
import spark.implicits._
// If there is more than one mode, choose the smallest one to keep in line
// with sklearn.impute.SimpleImputer (using scipy.stats.mode).
val modes = dataset.select(cols: _*).flatMap { row =>
Iterator.range(0, numCols).flatMap { i =>
// Ignore null.
// negative value to apply the default ranking of [Long, Double]
if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, -row.getDouble(i)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: is None / Some simpler here in the flatMap?

}
}.toDF("index", "negative_value")
.groupBy("index", "negative_value").agg(count(lit(0)).as("count"))
.groupBy("index").agg(max(struct("count", "negative_value")).as("mode"))
.select(col("index"), negate(col("mode.negative_value")))
.as[(Int, Double)].collect().toMap
Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN))
}

val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1)
Expand Down Expand Up @@ -212,6 +219,10 @@ object Imputer extends DefaultParamsReadable[Imputer] {
/** strategy names that Imputer currently supports. */
private[feature] val mean = "mean"
private[feature] val median = "median"
private[feature] val mode = "mode"

/* Set of strategies that Imputer supports */
private[feature] val supportedStrategies = Array(mean, median, mode)

@Since("2.2.0")
override def load(path: String): Imputer = super.load(path)
Expand Down
Loading