Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
init

py

nit
  • Loading branch information
zhengruifeng committed Nov 17, 2020
1 parent f5e3302 commit 4626614
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 120 deletions.
52 changes: 33 additions & 19 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
* The imputation strategy. Currently "mean", "median" and "mode" are supported.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
* If "mode", then replace missing values using the most frequent value of the feature.
* Default: mean
*
* @group param
*/
final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))
s"If ${Imputer.median}, then replace missing values using the median value of the feature. " +
s"If ${Imputer.mode}, then replace missing values using the most frequent value of " +
s"the feature.", ParamValidators.inArray[String](Imputer.supportedStrategies))

/** @group getParam */
def getStrategy: String = $(strategy)
Expand Down Expand Up @@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
* For example, if the input column is IntegerType (1, 2, 4, null),
* the output will be IntegerType (1, 2, 4, 2) after mean imputation.
*
* Note that the mean/median value is computed after filtering out missing values.
* Note that the mean/median/mode value is computed after filtering out missing values.
* All Null values in the input columns are treated as missing, and so are also imputed. For
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
*/
Expand Down Expand Up @@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
* Imputation strategy. Available options are ["mean", "median"].
* Imputation strategy. Available options are ["mean", "median", "mode"].
* @group setParam
*/
@Since("2.2.0")
Expand All @@ -151,39 +153,47 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
val spark = dataset.sparkSession

val (inputColumns, _) = getInOutCols()

val cols = inputColumns.map { inputCol =>
when(col(inputCol).equalTo($(missingValue)), null)
.when(col(inputCol).isNaN, null)
.otherwise(col(inputCol))
.cast("double")
.cast(DoubleType)
.as(inputCol)
}
val numCols = cols.length

val results = $(strategy) match {
case Imputer.mean =>
// Function avg will ignore null automatically.
// For a column only containing null, avg will return null.
val row = dataset.select(cols.map(avg): _*).head()
Array.range(0, inputColumns.length).map { i =>
if (row.isNullAt(i)) {
Double.NaN
} else {
row.getDouble(i)
}
}
Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i))

case Imputer.median =>
// Function approxQuantile will ignore null automatically.
// For a column only containing null, approxQuantile will return an empty array.
dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError))
.map { array =>
if (array.isEmpty) {
Double.NaN
} else {
array.head
}
.map(_.headOption.getOrElse(Double.NaN))

case Imputer.mode =>
val modes = dataset.select(cols: _*).rdd.flatMap { row =>
Iterator.range(0, numCols).flatMap { i =>
// Ignore null.
if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, row.getDouble(i)), 1L)
}
}.reduceByKey(_ + _).map { case ((i, v), c) => (i, (v, c))
}.reduceByKey { case ((v1, c1), (v2, c2)) =>
if (c1 > c2) {
(v1, c1)
} else if (c1 < c2) {
(v2, c2)
} else {
// Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode).
// If there is more than one mode, choose the smallest one.
(math.min(v1, v2), c1)
}
}.mapValues(_._1).collectAsMap()
Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN))
}

val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1)
Expand Down Expand Up @@ -212,6 +222,10 @@ object Imputer extends DefaultParamsReadable[Imputer] {
/** strategy names that Imputer currently supports. */
private[feature] val mean = "mean"
private[feature] val median = "median"
private[feature] val mode = "mode"

/* Set of strategies that Imputer supports */
private[feature] val supportedStrategies = Array(mean, median, mode)

@Since("2.2.0")
override def load(path: String): Imputer = super.load(path)
Expand Down
Loading

0 comments on commit 4626614

Please sign in to comment.