diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala index 2434eba4b..0fbd8b297 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala @@ -39,11 +39,36 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression) extends Add(left: Expression, right: Expression) with ColumnarExpression with Logging { + + // If casting between DecimalType, unnecessary cast is skipped to avoid data loss, + // because res type of "cast" is actually the res type of "add/subtract". + val left_val: Any = left match { + case c: ColumnarCast => + if (c.child.dataType.isInstanceOf[DecimalType] && + c.dataType.isInstanceOf[DecimalType]) { + c.child + } else { + left + } + case _ => + left + } + val right_val: Any = right match { + case c: ColumnarCast => + if (c.child.dataType.isInstanceOf[DecimalType] && + c.dataType.isInstanceOf[DecimalType]) { + c.child + } else { + right + } + case _ => + right + } override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { var (left_node, left_type): (TreeNode, ArrowType) = - left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) var (right_node, right_type): (TreeNode, ArrowType) = - right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) (left_type, right_type) match { case (l: ArrowType.Decimal, r: ArrowType.Decimal) => @@ -76,11 +101,34 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression extends Subtract(left: Expression, right: Expression) with ColumnarExpression with Logging { + + val left_val: Any = left match { + case c: ColumnarCast => + if (c.child.dataType.isInstanceOf[DecimalType] && + c.dataType.isInstanceOf[DecimalType]) { + c.child + } else { + left + } + case _ => + left + } + val right_val: Any = right match { + case c: ColumnarCast => + if (c.child.dataType.isInstanceOf[DecimalType] && + c.dataType.isInstanceOf[DecimalType]) { + c.child + } else { + right + } + case _ => + right + } override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { var (left_node, left_type): (TreeNode, ArrowType) = - left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) var (right_node, right_type): (TreeNode, ArrowType) = - right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) (left_type, right_type) match { case (l: ArrowType.Decimal, r: ArrowType.Decimal) => @@ -113,6 +161,7 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression extends Multiply(left: Expression, right: Expression) with ColumnarExpression with Logging { + override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { var (left_node, left_type): (TreeNode, ArrowType) = left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) @@ -121,10 +170,39 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression (left_type, right_type) match { case (l: ArrowType.Decimal, r: ArrowType.Decimal) => - val resultType = DecimalTypeUtil.getResultTypeForOperation( + var resultType = DecimalTypeUtil.getResultTypeForOperation( DecimalTypeUtil.OperationType.MULTIPLY, l, r) + // Scaling down the unnecessary scale for Literal to avoid precision loss + val newLeftNode = left match { + case literal: ColumnarLiteral => + val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString + val newLeftPrecision = leftStr.length - 1 + val newLeftScale = leftStr.split('.')(1).length + val newLeftType = + new ArrowType.Decimal(newLeftPrecision, newLeftScale, 128) + resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.MULTIPLY, newLeftType, r) + TreeBuilder.makeFunction( + "castDECIMAL", Lists.newArrayList(left_node), newLeftType) + case _ => + left_node + } + val newRightNode = right match { + case literal: ColumnarLiteral => + val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString + val newRightPrecision = rightStr.length - 1 + val newRightScale = rightStr.split('.')(1).length + val newRightType = + new ArrowType.Decimal(newRightPrecision, newRightScale, 128) + resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.MULTIPLY, l, newRightType) + TreeBuilder.makeFunction( + "castDECIMAL", Lists.newArrayList(right_node), newRightType) + case _ => + right_node + } val mulNode = TreeBuilder.makeFunction( - "multiply", Lists.newArrayList(left_node, right_node), resultType) + "multiply", Lists.newArrayList(newLeftNode, newRightNode), resultType) (mulNode, resultType) case _ => val resultType = CodeGeneration.getResultType(left_type, right_type) @@ -150,6 +228,7 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression) extends Divide(left: Expression, right: Expression) with ColumnarExpression with Logging { + override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { var (left_node, left_type): (TreeNode, ArrowType) = left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) @@ -158,10 +237,38 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression) (left_type, right_type) match { case (l: ArrowType.Decimal, r: ArrowType.Decimal) => - val resultType = DecimalTypeUtil.getResultTypeForOperation( + var resultType = DecimalTypeUtil.getResultTypeForOperation( DecimalTypeUtil.OperationType.DIVIDE, l, r) + val newLeftNode = left match { + case literal: ColumnarLiteral => + val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString + val newLeftPrecision = leftStr.length - 1 + val newLeftScale = leftStr.split('.')(1).length + val newLeftType = + new ArrowType.Decimal(newLeftPrecision, newLeftScale, 128) + resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.DIVIDE, newLeftType, r) + TreeBuilder.makeFunction( + "castDECIMAL", Lists.newArrayList(left_node), newLeftType) + case _ => + left_node + } + val newRightNode = right match { + case literal: ColumnarLiteral => + val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString + val newRightPrecision = rightStr.length - 1 + val newRightScale = rightStr.split('.')(1).length + val newRightType = + new ArrowType.Decimal(newRightPrecision, newRightScale, 128) + resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.DIVIDE, l, newRightType) + TreeBuilder.makeFunction( + "castDECIMAL", Lists.newArrayList(right_node), newRightType) + case _ => + right_node + } val divNode = TreeBuilder.makeFunction( - "divide", Lists.newArrayList(left_node, right_node), resultType) + "divide", Lists.newArrayList(newLeftNode, newRightNode), resultType) (divNode, resultType) case _ => val resultType = CodeGeneration.getResultType(left_type, right_type) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala index c46809b42..0c713c8e8 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala @@ -273,21 +273,21 @@ class ColumnarCheckOverflow(child: Expression, original: CheckOverflow) override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = { val (child_node, childType): (TreeNode, ArrowType) = child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - // since spark will call toPrecision in checkOverFlow and rescale from zero, we need to re-calculate result dataType here - val childScale: Int = childType match { - case d: ArrowType.Decimal => d.getScale - case _ => 0 - } val newDataType = DecimalType(dataType.precision, dataType.scale) val resType = CodeGeneration.getResultType(newDataType) - var function = "castDECIMAL" - if (nullOnOverflow) { - function = "castDECIMALNullOnOverflow" + if (resType == childType) { + // If target type is the same as childType, cast is not needed + (child_node, childType) + } else { + var function = "castDECIMAL" + if (nullOnOverflow) { + function = "castDECIMALNullOnOverflow" + } + val funcNode = + TreeBuilder.makeFunction(function, Lists.newArrayList(child_node), resType) + (funcNode, resType) } - val funcNode = - TreeBuilder.makeFunction(function, Lists.newArrayList(child_node), resType) - (funcNode, resType) } } diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index 4b4bd657b..1a332e02f 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -559,6 +559,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) << child_visitor_list[1]->GetResult() << ", " << rightType->precision() << ", " << rightType->scale() << ", " << resType->precision() << ", " << resType->scale() << ")"; + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" @@ -596,6 +597,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) << child_visitor_list[1]->GetResult() << ", " << rightType->precision() << ", " << rightType->scale() << ", " << resType->precision() << ", " << resType->scale() << ")"; + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" @@ -632,7 +634,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) << leftType->precision() << ", " << leftType->scale() << ", " << child_visitor_list[1]->GetResult() << ", " << rightType->precision() << ", " << rightType->scale() << ", " << resType->precision() << ", " - << resType->scale() << ")"; + << resType->scale() << ", &overflow)"; + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" @@ -642,7 +645,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) child_visitor_list[1]->GetPreCheck()}) << ");" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; + if (node.return_type()->id() == arrow::Type::DECIMAL) { + prepare_ss << "bool overflow = false;" << std::endl; + } prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl; + if (node.return_type()->id() == arrow::Type::DECIMAL) { + prepare_ss << "if (overflow) {\n" << validity << " = false;}" << std::endl; + } prepare_ss << "}" << std::endl; for (int i = 0; i < 2; i++) { @@ -669,7 +678,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) << leftType->precision() << ", " << leftType->scale() << ", " << child_visitor_list[1]->GetResult() << ", " << rightType->precision() << ", " << rightType->scale() << ", " << resType->precision() << ", " - << resType->scale() << ")"; + << resType->scale() << ", &overflow)"; + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" @@ -679,7 +689,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) child_visitor_list[1]->GetPreCheck()}) << ");" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; + if (node.return_type()->id() == arrow::Type::DECIMAL) { + prepare_ss << "bool overflow = false;" << std::endl; + } prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl; + if (node.return_type()->id() == arrow::Type::DECIMAL) { + prepare_ss << "if (overflow) {\n" << validity << " = false;}" << std::endl; + } prepare_ss << "}" << std::endl; for (int i = 0; i < 2; i++) { diff --git a/native-sql-engine/cpp/src/precompile/gandiva.h b/native-sql-engine/cpp/src/precompile/gandiva.h index df1416138..a62616a36 100644 --- a/native-sql-engine/cpp/src/precompile/gandiva.h +++ b/native-sql-engine/cpp/src/precompile/gandiva.h @@ -95,14 +95,15 @@ arrow::Decimal128 subtract(arrow::Decimal128 left, int32_t left_precision, arrow::Decimal128 multiply(arrow::Decimal128 left, int32_t left_precision, int32_t left_scale, arrow::Decimal128 right, int32_t right_precision, int32_t right_scale, - int32_t out_precision, int32_t out_scale) { + int32_t out_precision, int32_t out_scale, + bool* overflow_) { gandiva::BasicDecimalScalar128 x(left, left_precision, left_scale); gandiva::BasicDecimalScalar128 y(right, right_precision, right_scale); bool overflow = false; arrow::BasicDecimal128 out = gandiva::decimalops::Multiply(x, y, out_precision, out_scale, &overflow); if (overflow) { - throw std::overflow_error("Decimal multiply overflowed!"); + *overflow_ = true; } return arrow::Decimal128(out); } @@ -110,14 +111,15 @@ arrow::Decimal128 multiply(arrow::Decimal128 left, int32_t left_precision, arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision, int32_t left_scale, arrow::Decimal128 right, int32_t right_precision, int32_t right_scale, - int32_t out_precision, int32_t out_scale) { + int32_t out_precision, int32_t out_scale, + bool* overflow_) { gandiva::BasicDecimalScalar128 x(left, left_precision, left_scale); gandiva::BasicDecimalScalar128 y(right, right_precision, right_scale); bool overflow = false; arrow::BasicDecimal128 out = gandiva::decimalops::Divide(0, x, y, out_precision, out_scale, &overflow); if (overflow) { - throw std::overflow_error("Decimal divide overflowed!"); + *overflow_ = true; } return arrow::Decimal128(out); } diff --git a/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc b/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc index 4f3549749..c74f19822 100644 --- a/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc +++ b/native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc @@ -54,6 +54,10 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) { int32_t out_scale = 10; auto res = castDECIMAL(left, left_precision, left_scale, out_precision, out_scale); ASSERT_EQ(res, arrow::Decimal128("32342423.0128750000")); + bool overflow = false; + res = castDECIMALNullOnOverflow(left, left_precision, left_scale, out_precision, + out_scale, &overflow); + ASSERT_EQ(res, arrow::Decimal128("32342423.0128750000")); res = add(left, left_precision, left_scale, right, right_precision, right_scale, 17, 9); ASSERT_EQ(res, arrow::Decimal128("32344770.025749535")); @@ -61,10 +65,10 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) { 17, 9); ASSERT_EQ(res, arrow::Decimal128("32340076.000000465")); res = multiply(left, left_precision, left_scale, right, right_precision, right_scale, - 28, 15); + 28, 15, &overflow); ASSERT_EQ(res, arrow::Decimal128("75908083204.874689064638125")); res = divide(left, left_precision, left_scale, right, right_precision, right_scale, - out_precision, out_scale); + out_precision, out_scale, &overflow); ASSERT_EQ(res, arrow::Decimal128("13780.2495094037")); }