Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
support decimal in project
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 2, 2021
1 parent e1e4b73 commit 9394509
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.util.ExecutorManager
import org.apache.spark.sql.util.StructTypeFWD
import org.apache.spark.{SparkConf, TaskContext}
Expand Down Expand Up @@ -70,8 +70,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarConditionProjector.")
if (!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarConditionProjector.")
}
}
})
// check expr
Expand All @@ -80,8 +82,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(condExpr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${condExpr.dataType} is not supported in ColumnarConditionProjector.")
if (!condExpr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${condExpr.dataType} is not supported in ColumnarConditionProjector.")
}
}
ColumnarExpressionConverter.replaceWithColumnarExpression(condExpr)
}
Expand All @@ -91,8 +95,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(expr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConditionProjector.")
if (!expr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConditionProjector.")
}
}
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
Expand Down
162 changes: 100 additions & 62 deletions core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@
package com.intel.oap.expression

import com.google.common.collect.Lists

import org.apache.arrow.gandiva.evaluator._
import org.apache.arrow.gandiva.exceptions.GandivaException
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

import org.apache.arrow.gandiva.evaluator.DecimalTypeUtil

/**
* A version of add that supports columnar processing for longs.
*/
Expand All @@ -44,22 +45,30 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, l, r)
val addNode = TreeBuilder.makeFunction(
"add", Lists.newArrayList(left_node, right_node), resultType)
(addNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"add", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}

//logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("add", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -73,21 +82,30 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.SUBTRACT, l, r)
val subNode = TreeBuilder.makeFunction(
"subtract", Lists.newArrayList(left_node, right_node), resultType)
(subNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"subtract", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("subtract", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -101,22 +119,30 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, r)
val mulNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
(mulNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}

//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("multiply", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -130,21 +156,30 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, r)
val divNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
(divNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("divide", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand Down Expand Up @@ -238,8 +273,11 @@ object ColumnarBinaryArithmetic {
ConverterUtils.checkIfTypeSupported(right.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic")
if (!left.dataType.isInstanceOf[DecimalType] ||
!right.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic")
}
}
}
}
Loading

0 comments on commit 9394509

Please sign in to comment.