diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala index 0b8b170ac..bd0d7cd9c 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala @@ -27,7 +27,7 @@ import org.apache.arrow.gandiva.expression.TreeBuilder import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, MakeDecimal, NamedExpression, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, MakeDecimal, NamedExpression, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, Sum} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.execution.SparkPlan @@ -71,45 +71,65 @@ class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val sparkConf = sparkContext.getConf val numaBindingInfo = ColumnarPluginConfig.getConf.numaBindingInfo + def checkAggFunctionSpec(windowSpec: WindowSpecDefinition): Unit = { + if (windowSpec.orderSpec.nonEmpty) { + throw new UnsupportedOperationException("unsupported operation for " + + "aggregation window function: " + windowSpec) + } + } + + def checkRankSpec(windowSpec: WindowSpecDefinition): Unit = { + // leave it empty for now + } + val windowFunctions: Seq[(String, Expression)] = windowExpression .map(e => e.asInstanceOf[Alias]) .map(a => a.child.asInstanceOf[WindowExpression]) - .map(w => w.windowFunction) + .map(w => (w, w.windowFunction)) .map { - case a: AggregateExpression => a.aggregateFunction - case b: WindowFunction => b - case f => - throw new UnsupportedOperationException("unsupported window function type: " + - f) + case (expr, func) => + (expr, func match { + case a: AggregateExpression => a.aggregateFunction + case b: WindowFunction => b + case f => + throw new UnsupportedOperationException("unsupported window function type: " + + f) + }) } - .map { f => - val name = f match { - case _: Sum => "sum" - case _: Average => "avg" - case _: Rank => - val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) { - (desc, s) => - val currentDesc = s.direction match { - case Ascending => false - case Descending => true - case _ => throw new IllegalStateException - } - if (desc.isEmpty) { - Some(currentDesc) - } else if (currentDesc == desc.get) { - Some(currentDesc) - } else { - throw new UnsupportedOperationException("Rank: clashed rank order found") - } - } - desc match { - case Some(true) => "rank_desc" - case Some(false) => "rank_asc" - case None => "rank_asc" - } - case f => throw new UnsupportedOperationException("unsupported window function: " + f) - } - (name, f) + .map { + case (expr, func) => + val name = func match { + case _: Sum => + checkAggFunctionSpec(expr.windowSpec) + "sum" + case _: Average => + checkAggFunctionSpec(expr.windowSpec) + "avg" + case _: Rank => + checkRankSpec(expr.windowSpec) + val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) { + (desc, s) => + val currentDesc = s.direction match { + case Ascending => false + case Descending => true + case _ => throw new IllegalStateException + } + if (desc.isEmpty) { + Some(currentDesc) + } else if (currentDesc == desc.get) { + Some(currentDesc) + } else { + throw new UnsupportedOperationException("Rank: clashed rank order found") + } + } + desc match { + case Some(true) => "rank_desc" + case Some(false) => "rank_asc" + case None => "rank_asc" + } + case f => throw new UnsupportedOperationException("unsupported window function: " + f) + } + (name, func) } if (windowFunctions.isEmpty) { @@ -349,11 +369,6 @@ object ColumnarWindowExec { partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan): SparkPlan = { - //TODO(): this is a quick fix on non-avg window issue - if (!windowExpression.toString.contains("avg")) { - new ColumnarWindowExec(windowExpression, partitionSpec, orderSpec, child) - } else { - createWithProjection(windowExpression, partitionSpec, orderSpec, child) - } + createWithProjection(windowExpression, partitionSpec, orderSpec, child) } }