[SPARK-40880][SQL] Reimplement summary with dataframe operations
### What changes were proposed in this pull request?
Reimplement `summary` with dataframe operations

### Why are the changes needed?
1. Do not truncate the SQL plan any more: the old implementation eagerly collected the aggregate result and returned it as a `LocalRelation`, which cut off the plan lineage.
2. Enable SQL optimizations such as column pruning:

```
scala> val df = spark.range(0, 3, 1, 10).withColumn("value", lit("str"))
df: org.apache.spark.sql.DataFrame = [id: bigint, value: string]

scala> df.summary("max", "50%").show
+-------+---+-----+
|summary| id|value|
+-------+---+-----+
|    max|  2|  str|
|    50%|  1| null|
+-------+---+-----+

scala> df.summary("max", "50%").select("id").show
+---+
| id|
+---+
|  2|
|  1|
+---+

scala> df.summary("max", "50%").select("id").queryExecution.optimizedPlan
res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
Project [element_at(id#367, summary#376, None, false) AS id#371]
+- Generate explode([max,50%]), false, [summary#376]
   +- Aggregate [map(max, cast(max(id#153L) as string), 50%, cast(percentile_approx(id#153L, [0.5], 10000, 0, 0)[0] as string)) AS id#367]
      +- Range (0, 3, step=1, splits=Some(10))

```
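For intuition, here is a minimal sketch, not the committed code, of the shape `summary` now produces for a single numeric column (the statistic list and the accuracy literal `10000` mirror the example above; `map`, `get`, `percentile_approx`, `explode`, and `element_at` are the public `org.apache.spark.sql.functions` helpers the patch builds on):

```
import org.apache.spark.sql.functions._

// Assumes a spark-shell session (`spark` in scope), mirroring the df above.
val df = spark.range(0, 3, 1, 10)

// Step 1: one aggregate per column, folding every requested statistic into a
// map<string,string>; all percentiles come from a single percentile_approx call.
val mapped = df.select(
  map(
    lit("max"), max(col("id")).cast("string"),
    lit("50%"), get(percentile_approx(col("id"), lit(Array(0.5)), lit(10000)), lit(0))
      .cast("string")
  ).as("id"))

// Step 2: explode the statistic names into rows, then look each statistic up
// in the map; keys missing from a column's map surface as nulls, as in the
// `value` column of the example above.
val result = mapped
  .withColumn("summary", explode(lit(Array("max", "50%"))))
  .select(col("summary"), element_at(col("id"), col("summary")).as("id"))

result.show()
// +-------+---+
// |summary| id|
// +-------+---+
// |    max|  2|
// |    50%|  1|
// +-------+---+
```

Because the whole pipeline is ordinary dataframe operations, selecting only `id` lets the optimizer drop the `value` map entirely, which is exactly what the optimized plan above shows.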

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing UTs and manual checks.

Closes #38346 from zhengruifeng/sql_stat_summary.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
zhengruifeng authored and HyukjinKwon committed Oct 24, 2022
1 parent 5d3b1e6 commit 6a0713a
Showing 1 changed file with 59 additions and 63 deletions.
```diff
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -21,11 +21,10 @@ import java.util.Locale
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EvalMode, Expression, GenericInternalRow, GetArrayItem, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode, GenericInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -199,9 +198,11 @@ object StatFunctions extends Logging {
 
   /** Calculate selected summary statistics for a dataset */
   def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
-
-    val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
-    val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+    val selectedStatistics = if (statistics.nonEmpty) {
+      statistics.toArray
+    } else {
+      Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+    }
 
     val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
       try {
@@ -213,71 +214,66 @@
     }
     require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
 
-    def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) {
-      Cast(e, DoubleType, evalMode = EvalMode.TRY)
-    } else {
-      e
-    }
-    var percentileIndex = 0
-    val statisticFns = selectedStatistics.map { stats =>
-      if (stats.endsWith("%")) {
-        val index = percentileIndex
-        percentileIndex += 1
-        (child: Expression) =>
-          GetArrayItem(
-            new ApproximatePercentile(castAsDoubleIfNecessary(child),
-              Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false)))
-              .toAggregateExpression(),
-            Literal(index))
-      } else {
-        stats.toLowerCase(Locale.ROOT) match {
-          case "count" => (child: Expression) => Count(child).toAggregateExpression()
-          case "count_distinct" => (child: Expression) =>
-            Count(child).toAggregateExpression(isDistinct = true)
-          case "approx_count_distinct" => (child: Expression) =>
-            HyperLogLogPlusPlus(child).toAggregateExpression()
-          case "mean" => (child: Expression) =>
-            Average(castAsDoubleIfNecessary(child)).toAggregateExpression()
-          case "stddev" => (child: Expression) =>
-            StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression()
-          case "min" => (child: Expression) => Min(child).toAggregateExpression()
-          case "max" => (child: Expression) => Max(child).toAggregateExpression()
-          case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
-        }
-      }
-    }
-
-    val selectedCols = ds.logicalPlan.output
-      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
-
-    val aggExprs = statisticFns.flatMap { func =>
-      selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
-    }
-
-    // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val.
-    lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head
-
-    // We will have one row for each selected statistic in the result.
-    val result = Array.fill[InternalRow](selectedStatistics.length) {
-      // each row has the statistic name, and statistic values of each selected column.
-      new GenericInternalRow(selectedCols.length + 1)
-    }
-
-    var rowIndex = 0
-    while (rowIndex < result.length) {
-      val statsName = selectedStatistics(rowIndex)
-      result(rowIndex).update(0, UTF8String.fromString(statsName))
-      for (colIndex <- selectedCols.indices) {
-        val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
-        result(rowIndex).update(colIndex + 1, statsValue)
-      }
-      rowIndex += 1
-    }
-
-    // All columns are string type
-    val output = AttributeReference("summary", StringType)() +:
-      selectedCols.map(c => AttributeReference(c.name, StringType)())
-
-    Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+    var mapColumns = Seq.empty[Column]
+    var columnNames = Seq.empty[String]
+
+    ds.schema.fields.foreach { field =>
+      if (field.dataType.isInstanceOf[NumericType] || field.dataType.isInstanceOf[StringType]) {
+        val column = col(field.name)
+        var casted = column
+        if (field.dataType.isInstanceOf[StringType]) {
+          casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY))
+        }
+
+        val percentilesCol = if (percentiles.nonEmpty) {
+          percentile_approx(casted, lit(percentiles),
+            lit(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
+        } else null
+
+        var aggColumns = Seq.empty[Column]
+        var percentileIndex = 0
+        selectedStatistics.foreach { stats =>
+          aggColumns :+= lit(stats)
+
+          stats.toLowerCase(Locale.ROOT) match {
+            case "count" => aggColumns :+= count(column)
+
+            case "count_distinct" => aggColumns :+= count_distinct(column)
+
+            case "approx_count_distinct" => aggColumns :+= approx_count_distinct(column)
+
+            case "mean" => aggColumns :+= avg(casted)
+
+            case "stddev" => aggColumns :+= stddev(casted)
+
+            case "min" => aggColumns :+= min(column)
+
+            case "max" => aggColumns :+= max(column)
+
+            case percentile if percentile.endsWith("%") =>
+              aggColumns :+= get(percentilesCol, lit(percentileIndex))
+              percentileIndex += 1
+
+            case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
+          }
+        }
+
+        // map { "count" -> "1024", "min" -> "1.0", ... }
+        mapColumns :+= map(aggColumns.map(_.cast(StringType)): _*).as(field.name)
+        columnNames :+= field.name
+      }
+    }
+
+    if (mapColumns.isEmpty) {
+      ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply))
+        .withColumnRenamed("_1", "summary")
+    } else {
+      val valueColumns = columnNames.map { columnName =>
+        new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName)
+      }
+      ds.select(mapColumns: _*)
+        .withColumn("summary", explode(lit(selectedStatistics)))
+        .select(Array(col("summary")) ++ valueColumns: _*)
+    }
   }
```
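One edge case worth noting in the new code: when the input has no numeric or string columns, `mapColumns` stays empty and the result is built directly from the statistic names. A hypothetical session illustrating that branch (the boolean-only dataframe is an assumed example, not from the patch):

```
scala> spark.range(1).select((col("id") > 0).as("flag")).summary("count", "max").show()
+-------+
|summary|
+-------+
|  count|
|    max|
+-------+
```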
