Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-981] check all expressions in HashAgg (#991)
Browse files Browse the repository at this point in the history
* check all expressions in HashAgg

Signed-off-by: Yuan Zhou <yuan.zhou@intel.com>

* check codegen support

Signed-off-by: Yuan Zhou <yuan.zhou@intel.com>

* check codegen in And/Or exprs

Signed-off-by: Yuan Zhou <yuan.zhou@intel.com>

* check Not expr

Signed-off-by: Yuan Zhou <yuan.zhou@intel.com>
  • Loading branch information
zhouyuan authored Jun 28, 2022
1 parent 84511b5 commit e63ef98
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,24 @@ case class ColumnarHashAggregateExec(
// check project
for (expr <- aggregateExpressions) {
val internalExpressionList = expr.aggregateFunction.children
ColumnarProjection.buildCheck(child.output, internalExpressionList)
try {
ColumnarProjection.buildCheck(child.output, internalExpressionList)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
"internalExpressionList has unsupported type in ColumnarAggregation")
}
}
ColumnarProjection.buildCheck(child.output, groupingExpressions)
ColumnarProjection.buildCheck(child.output, resultExpressions)

try {
ColumnarProjection.buildCheck(child.output, groupingExpressions)
ColumnarProjection.buildCheck(child.output, resultExpressions)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
"groupingExpressions/resultExpressions has unsupported type in ColumnarAggregation")
}

// check the supported types and modes for different aggregate functions
checkTypeAndAggrFunction(aggregateExpressions, aggregateAttributes)
}
Expand Down Expand Up @@ -691,6 +705,20 @@ case class ColumnarHashAggregateExec(
return false
}
}
for (expr <- groupingExpressions) {
val colExpr = ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
if (!colExpr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(
Lists.newArrayList())) {
return false
}
}
for (expr <- resultExpressions) {
val colExpr = ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
if (!colExpr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(
Lists.newArrayList())) {
return false
}
}
}
return true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class ColumnarAnd(left: Expression, right: Expression, original: Expression)
extends And(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

override def supportColumnarCodegen(args: Object): Boolean = {
true && left.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) && right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -55,6 +60,11 @@ class ColumnarOr(left: Expression, right: Expression, original: Expression)
extends Or(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

override def supportColumnarCodegen(args: Object): Boolean = {
true && left.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) && right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class ColumnarCaseWhen(
val exprList = { exprs.filter(expr => !expr.isInstanceOf[Literal]) }
!exprList.map(expr => expr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(Lists.newArrayList())).exists(_ == false)
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
logInfo(s"children: ${branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue}")
logInfo(s"branches: $branches")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ class ColumnarNot(child: Expression, original: Expression)
}
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
true && child.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val (child_node, childType): (TreeNode, ArrowType) =
child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down

0 comments on commit e63ef98

Please sign in to comment.