diff --git a/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala b/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala index fbff56e57..d39101abc 100644 --- a/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala +++ b/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala @@ -27,8 +27,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.util.ExecutorManager import org.apache.spark.sql.util.StructTypeFWD import org.apache.spark.{SparkConf, TaskContext} @@ -70,8 +70,10 @@ case class ColumnarConditionProjectExec( ConverterUtils.checkIfTypeSupported(attr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${attr.dataType} is not supported in ColumnarConditionProjector.") + if (!attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in ColumnarConditionProjector.") + } } }) // check expr @@ -80,8 +82,10 @@ case class ColumnarConditionProjectExec( ConverterUtils.checkIfTypeSupported(condExpr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${condExpr.dataType} is not supported in ColumnarConditionProjector.") + if (!condExpr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${condExpr.dataType} is not supported in ColumnarConditionProjector.") + } } ColumnarExpressionConverter.replaceWithColumnarExpression(condExpr) } @@ -91,8 +95,10 @@ case class ColumnarConditionProjectExec( ConverterUtils.checkIfTypeSupported(expr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${expr.dataType} is not supported in ColumnarConditionProjector.") + if (!expr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${expr.dataType} is not supported in ColumnarConditionProjector.") + } } ColumnarExpressionConverter.replaceWithColumnarExpression(expr) } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala index e2319c910..1dd5e3c06 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala @@ -18,19 +18,20 @@ package com.intel.oap.expression import com.google.common.collect.Lists - import org.apache.arrow.gandiva.evaluator._ import org.apache.arrow.gandiva.exceptions.GandivaException import org.apache.arrow.gandiva.expression._ +import org.apache.arrow.vector.types.FloatingPointPrecision import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.arrow.vector.types.pojo.Field - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import scala.collection.mutable.ListBuffer +import org.apache.arrow.gandiva.evaluator.DecimalTypeUtil + /** * A version of add that supports columnar processing for longs. */ @@ -44,22 +45,30 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression) var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.ADD, l, r) + val addNode = TreeBuilder.makeFunction( + "add", Lists.newArrayList(left_node, right_node), resultType) + (addNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType) + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType) + } + //logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "add", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), - } - - //logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("add", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -73,21 +82,30 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.SUBTRACT, l, r) + val subNode = TreeBuilder.makeFunction( + "subtract", Lists.newArrayList(left_node, right_node), resultType) + (subNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType) + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType) + } + //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "subtract", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), - } - //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("subtract", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -101,22 +119,30 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), - } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.MULTIPLY, l, r) + val mulNode = TreeBuilder.makeFunction( + "multiply", Lists.newArrayList(left_node, right_node), resultType) + (mulNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + } + //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "multiply", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - - //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("multiply", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -130,21 +156,30 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression) var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), - } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.DIVIDE, l, r) + val divNode = TreeBuilder.makeFunction( + "divide", Lists.newArrayList(left_node, right_node), resultType) + (divNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + } + //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "divide", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("divide", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -238,8 +273,11 @@ object ColumnarBinaryArithmetic { ConverterUtils.checkIfTypeSupported(right.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic") + if (!left.dataType.isInstanceOf[DecimalType] || + !right.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic") + } } } } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala index f39b68b0b..1991f2e5a 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala @@ -146,16 +146,21 @@ class ColumnarEqualTo(left: Expression, right: Expression, original: Expression) right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "equal" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -183,16 +188,21 @@ class ColumnarEqualNull(left: Expression, right: Expression, original: Expressio right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "equal" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -220,16 +230,21 @@ class ColumnarLessThan(left: Expression, right: Expression, original: Expression right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "less_than" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -257,16 +272,21 @@ class ColumnarLessThanOrEqual(left: Expression, right: Expression, original: Exp right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "less_than_or_equal_to" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -296,15 +316,19 @@ class ColumnarGreaterThan(left: Expression, right: Expression, original: Express right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } var function = "greater_than" @@ -336,16 +360,21 @@ class ColumnarGreaterThanOrEqual(left: Expression, right: Expression, original: right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "greater_than_or_equal_to" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -452,8 +481,11 @@ object ColumnarBinaryOperator { ConverterUtils.checkIfTypeSupported(right.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryOperator") + if (!left.dataType.isInstanceOf[DecimalType] || + !right.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryOperator") + } } } } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala index 64d784030..0dd8ba5cb 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala @@ -42,8 +42,10 @@ class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean ConverterUtils.checkIfTypeSupported(dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${dataType} is not supported in ColumnarBoundReference.") + if (!dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${dataType} is not supported in ColumnarBoundReference.") + } } } override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala index 0bc161d1f..6f5d9849c 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala @@ -52,8 +52,10 @@ class ColumnarCaseWhen( ConverterUtils.checkIfTypeSupported(expr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${dataType} is not supported in ColumnarCaseWhen") + if (!expr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${dataType} is not supported in ColumnarCaseWhen") + } }) } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala index 00b422364..f7907225a 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala @@ -53,8 +53,10 @@ class ColumnarCoalesce(exps: Seq[Expression], original: Expression) ConverterUtils.checkIfTypeSupported(expr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${expr.dataType} is not supported in ColumnarCoalesce") + if (!expr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${expr.dataType} is not supported in ColumnarCoalesce") + } } ) } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala index 7fbf22772..0c2700cbb 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala @@ -43,9 +43,13 @@ class ColumnarIf(predicate: Expression, trueValue: Expression, ConverterUtils.checkIfTypeSupported(falseValue.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${predicate.dataType} or ${trueValue.dataType} or ${falseValue.dataType} " + - s"is not supported in ColumnarIf") + if (!predicate.dataType.isInstanceOf[DecimalType] || + !trueValue.dataType.isInstanceOf[DecimalType] || + !falseValue.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${predicate.dataType} or ${trueValue.dataType} or ${falseValue.dataType} " + + s"is not supported in ColumnarIf") + } } } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala index 6daa85c73..98d49d7c9 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala @@ -249,11 +249,7 @@ class ColumnarCheckOverflow(child: Expression, original: CheckOverflow) val (child_node, childType): (TreeNode, ArrowType) = child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) // since spark will call toPrecision in checkOverFlow and rescale from zero, we need to re-calculate result dataType here - val childScale: Int = childType match { - case d: ArrowType.Decimal => d.getScale - case _ => 0 - } - val newDataType = DecimalType(dataType.precision, dataType.scale + childScale) + val newDataType = DecimalType(dataType.precision, dataType.scale) val resType = CodeGeneration.getResultType(newDataType) val funcNode = TreeBuilder.makeFunction( "castDECIMAL", diff --git a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index 87d59b29f..7198fdcb5 100644 --- a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -277,20 +277,33 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) func_name.compare("castDECIMAL") != 0) { codes_str_ = func_name + "_" + std::to_string(cur_func_id); auto validity = codes_str_ + "_validity"; - std::stringstream fix_ss; - if (node.return_type()->id() == arrow::Type::DOUBLE || - node.return_type()->id() == arrow::Type::FLOAT) { - fix_ss << " * 1.0 "; - } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" << std::endl; prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck() << ";" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; - prepare_ss << codes_str_ << " = static_cast<" << GetCTypeString(node.return_type()) - << ">(" << child_visitor_list[0]->GetResult() << fix_ss.str() << ");" - << std::endl; + + auto childNode = node.children().at(0); + if (childNode->return_type()->id() != arrow::Type::DECIMAL) { + // if not casting form Decimal + std::stringstream fix_ss; + if (node.return_type()->id() == arrow::Type::DOUBLE || + node.return_type()->id() == arrow::Type::FLOAT) { + fix_ss << " * 1.0 "; + } + prepare_ss << codes_str_ << " = static_cast<" << GetCTypeString(node.return_type()) + << ">(" << child_visitor_list[0]->GetResult() << fix_ss.str() << ");" + << std::endl; + } else { + // if casting From Decimal + auto decimal_type = + std::dynamic_pointer_cast(childNode->return_type()); + prepare_ss << codes_str_ << " = static_cast<" << GetCTypeString(node.return_type()) + << ">(castFloatFromDecimal(" << child_visitor_list[0]->GetResult() + << ", " << decimal_type->scale() << "));" << std::endl; + header_list_.push_back(R"(#include "precompile/gandiva.h")"); + } prepare_ss << "}" << std::endl; for (int i = 0; i < 1; i++) { diff --git a/cpp/src/precompile/array.h b/cpp/src/precompile/array.h index 1641ba9f7..e5fa14e88 100644 --- a/cpp/src/precompile/array.h +++ b/cpp/src/precompile/array.h @@ -1,8 +1,8 @@ #pragma once #include - -#include "arrow/util/string_view.h" // IWYU pragma: export +#include +#include // IWYU pragma: export namespace sparkcolumnarplugin { namespace precompile { @@ -130,9 +130,12 @@ class FixedSizeBinaryArray { public: FixedSizeBinaryArray(const std::shared_ptr&); arrow::util::string_view GetView(int64_t i) const { - return arrow::util::string_view(reinterpret_cast(raw_value_[i]), + return arrow::util::string_view(reinterpret_cast(GetValue(i)), byte_width_); } + const uint8_t* GetValue(int64_t i) const { + return raw_value_ + (i + offset_) * byte_width_; + } bool IsNull(int64_t i) const { i += offset_; return null_bitmap_data_ != NULLPTR && @@ -156,6 +159,10 @@ class FixedSizeBinaryArray { class Decimal128Array : public FixedSizeBinaryArray { public: Decimal128Array(const std::shared_ptr& in) : FixedSizeBinaryArray(in) {} + arrow::Decimal128 GetView(int64_t i) const { + const arrow::Decimal128 value(GetValue(i)); + return value; + } }; arrow::Status MakeFixedSizeBinaryArray(const std::shared_ptr&, diff --git a/cpp/src/precompile/gandiva.h b/cpp/src/precompile/gandiva.h index 0c445736f..8fb8a736d 100644 --- a/cpp/src/precompile/gandiva.h +++ b/cpp/src/precompile/gandiva.h @@ -26,10 +26,15 @@ arrow::Decimal128 castDECIMAL(double val, int32_t precision, int32_t scale) { snprintf(buffer, charsNeeded, "%.*f", (int)scale, nextafter(val, val + 0.5)); auto decimal_str = std::string(buffer); free(buffer); - return arrow::Decimal128(decimal_str); + return arrow::Decimal128::FromString(decimal_str).ValueOrDie(); } arrow::Decimal128 castDECIMAL(arrow::Decimal128 in, int32_t original_scale, int32_t new_scale) { return in.Rescale(original_scale, new_scale).ValueOrDie(); -} \ No newline at end of file +} + +double castFloatFromDecimal(arrow::Decimal128 val, int32_t scale) { + std::string str = val.ToString(scale); + return atof(str.c_str()); +}