Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-153] Following NSE-153, optimize fallback conditions for columnar window #189

Merged
merged 1 commit into from
Mar 24, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.arrow.gandiva.expression.TreeBuilder
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, MakeDecimal, NamedExpression, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, MakeDecimal, NamedExpression, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, Sum}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.execution.SparkPlan
Expand Down Expand Up @@ -71,45 +71,65 @@ class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
val sparkConf = sparkContext.getConf
val numaBindingInfo = ColumnarPluginConfig.getConf.numaBindingInfo

/**
 * Validates the window specification used with an aggregation window function.
 *
 * A specification whose `orderSpec` is non-empty is rejected.
 *
 * @param windowSpec the window specification attached to the aggregation function
 * @throws UnsupportedOperationException if `windowSpec.orderSpec` is non-empty
 */
def checkAggFunctionSpec(windowSpec: WindowSpecDefinition): Unit = {
  val hasOrdering = windowSpec.orderSpec.nonEmpty
  if (hasOrdering) {
    val detail = "unsupported operation for " +
      "aggregation window function: " + windowSpec
    throw new UnsupportedOperationException(detail)
  }
}

/**
 * Validation hook for the window specification used with a rank function.
 *
 * Intentionally a no-op for now: every specification is accepted.
 *
 * @param windowSpec the window specification attached to the rank function
 */
def checkRankSpec(windowSpec: WindowSpecDefinition): Unit = ()

// Resolves every entry of windowExpression to a (name, function) pair.
// Each entry is cast to Alias and its child to WindowExpression; the wrapped
// function must be an AggregateExpression (unwrapped to its aggregateFunction)
// or a WindowFunction, otherwise UnsupportedOperationException is thrown.
// Names produced: "sum", "avg", "rank_asc", "rank_desc".
// NOTE(review): this span comes from a diff view and appears to contain both the
// removed and the added lines of the change without +/- markers; confirm against
// the merged file before relying on it.
val windowFunctions: Seq[(String, Expression)] = windowExpression
.map(e => e.asInstanceOf[Alias])
.map(a => a.child.asInstanceOf[WindowExpression])
.map(w => w.windowFunction)
.map(w => (w, w.windowFunction))
.map {
case a: AggregateExpression => a.aggregateFunction
case b: WindowFunction => b
case f =>
throw new UnsupportedOperationException("unsupported window function type: " +
f)
case (expr, func) =>
(expr, func match {
case a: AggregateExpression => a.aggregateFunction
case b: WindowFunction => b
case f =>
throw new UnsupportedOperationException("unsupported window function type: " +
f)
})
}
.map { f =>
val name = f match {
case _: Sum => "sum"
case _: Average => "avg"
case _: Rank =>
val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) {
(desc, s) =>
val currentDesc = s.direction match {
case Ascending => false
case Descending => true
case _ => throw new IllegalStateException
}
if (desc.isEmpty) {
Some(currentDesc)
} else if (currentDesc == desc.get) {
Some(currentDesc)
} else {
throw new UnsupportedOperationException("Rank: clashed rank order found")
}
}
desc match {
case Some(true) => "rank_desc"
case Some(false) => "rank_asc"
case None => "rank_asc"
}
case f => throw new UnsupportedOperationException("unsupported window function: " + f)
}
(name, f)
.map {
case (expr, func) =>
val name = func match {
case _: Sum =>
// aggregation functions validate their window spec before being named
checkAggFunctionSpec(expr.windowSpec)
"sum"
case _: Average =>
checkAggFunctionSpec(expr.windowSpec)
"avg"
case _: Rank =>
checkRankSpec(expr.windowSpec)
// Fold all sort orders into one direction flag (true = descending);
// sort orders with differing directions are rejected.
val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) {
(desc, s) =>
val currentDesc = s.direction match {
case Ascending => false
case Descending => true
case _ => throw new IllegalStateException
}
if (desc.isEmpty) {
Some(currentDesc)
} else if (currentDesc == desc.get) {
Some(currentDesc)
} else {
throw new UnsupportedOperationException("Rank: clashed rank order found")
}
}
desc match {
case Some(true) => "rank_desc"
case Some(false) => "rank_asc"
// empty orderSpec: treated the same as ascending
case None => "rank_asc"
}
case f => throw new UnsupportedOperationException("unsupported window function: " + f)
}
(name, func)
}

if (windowFunctions.isEmpty) {
Expand Down Expand Up @@ -349,11 +369,6 @@ object ColumnarWindowExec {
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan): SparkPlan = {
//TODO(): this is a quick fix on non-avg window issue
if (!windowExpression.toString.contains("avg")) {
new ColumnarWindowExec(windowExpression, partitionSpec, orderSpec, child)
} else {
createWithProjection(windowExpression, partitionSpec, orderSpec, child)
}
createWithProjection(windowExpression, partitionSpec, orderSpec, child)
}
}