Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-745] Improve codegen check for expression #751

Merged
merged 2 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,24 @@ case class ColumnarConditionProjectExec(

override def getChild: SparkPlan = child

override def supportColumnarCodegen: Boolean = true

// override def canEqual(that: Any): Boolean = false
override def supportColumnarCodegen: Boolean = {
if (condition != null) {
val colCondExpr = ColumnarExpressionConverter.replaceWithColumnarExpression(condition)
// support codegen if cond expression and proj expression both supports codegen
if (!colCondExpr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(Lists.newArrayList())) {
return false
}
}
if (projectList != null) {
for (expr <- projectList) {
val colExpr = ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
if (!colExpr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(Lists.newArrayList())) {
return false
}
}
}
true
}

def getKernelFunction(childTreeNode: TreeNode): TreeNode = {
val (filterNode, projectNode) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,19 @@ case class ColumnarHashAggregateExec(

override def getChild: SparkPlan = child

override def supportColumnarCodegen: Boolean = true
override def supportColumnarCodegen: Boolean = {
for (expr <- aggregateExpressions) {
val internalExpressionList = expr.aggregateFunction.children
for (expr <- internalExpressionList) {
val colExpr = ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
if (!colExpr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(Lists.newArrayList())) {
return false
}
}

// override def canEqual(that: Any): Boolean = false
}
return true
}

def getKernelFunction: TreeNode = {
ColumnarHashAggregation.prepareKernelFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
(funcNode, resultType)
}
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
return left_val.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) && right_val.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}
}

class ColumnarDivide(left: Expression, right: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ class ColumnarContains(left: Expression, right: Expression, original: Expression
TreeBuilder.makeFunction("is_substr", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}

override def supportColumnarCodegen(args: Object): Boolean = {
false && left.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) && right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}
}

class ColumnarEqualTo(left: Expression, right: Expression, original: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class ColumnarCaseWhen(
})
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
val exprs = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
val exprList = { exprs.filter(expr => !expr.isInstanceOf[Literal]) }
!exprList.map(expr => expr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(Lists.newArrayList())).exists(_ == false)
}
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
logInfo(s"children: ${branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue}")
logInfo(s"branches: $branches")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,10 @@ object ColumnarDateTimeExpressions {
"date_diff", Lists.newArrayList(leftNode, rightNode), outType)
(funcNode, outType)
}

override def supportColumnarCodegen(args: Object): Boolean = {
false && left.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) && right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}
}

class ColumnarMakeDate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ import scala.collection.mutable.ListBuffer

trait ColumnarExpression {

def supportColumnarCodegen(args: java.lang.Object): (Boolean) = {
// TODO: disable all codegen unless manuall enabled
true
}

def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
throw new UnsupportedOperationException(s"Not support doColumnarCodeGen.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ class ColumnarIf(predicate: Expression, trueValue: Expression,
val funcNode = TreeBuilder.makeIf(predicate_node, true_node, false_node, trueType)
(funcNode, trueType)
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
// return true only when all branches are true
val ret = (predicate.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) &&
trueValue.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) &&
falseValue.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args))
return ret
}
}

object ColumnarIfOperator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class ColumnarAlias(child: Expression, name: String)(
child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
}

override def supportColumnarCodegen(args: Object): Boolean = {
child.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

}

class ColumnarAttributeReference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ class ColumnarFloor(child: Expression, original: Expression)
TreeBuilder.makeFunction("floor", Lists.newArrayList(child_node), resultType)
(funcNode, resultType)
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
false && child.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}
}

class ColumnarCeil(child: Expression, original: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ case class ColumnarCollapseCodegenStages(
if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
case j: ColumnarHashAggregateExec =>
if (!j.child.isInstanceOf[ColumnarHashAggregateExec] && existsJoins(j)) {
if (j.supportColumnarCodegen && !j.child.isInstanceOf[ColumnarHashAggregateExec] && existsJoins(j)) {
ColumnarWholeStageCodegenExec(j.withNewChildren(j.children.map(insertInputAdapter)))(
codegenStageCounter.incrementAndGet())
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
check_str_ = validity;
} else {
std::cout << "function name: " << func_name << std::endl;
return arrow::Status::NotImplemented(func_name, " is currently not supported.");
return arrow::Status::NotImplemented(func_name,
" is currently not supported in WSCG.");
}
return arrow::Status::OK();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ class WholeStageCodeGenKernel::Impl {
result_field_node_list,
result_expr_node_list, out));
} else {
return arrow::Status::NotImplemented("Not supported function name:", func_name);
return arrow::Status::NotImplemented("WSCG Not supported function name:",
func_name);
}
return arrow::Status::OK();
}
Expand Down