Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-130] fix overflow and precision loss #156

Merged
merged 2 commits into from
Mar 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,36 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression)
extends Add(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

// If casting between DecimalType, unnecessary cast is skipped to avoid data loss,
// because res type of "cast" is actually the res type of "add/subtract".
val left_val: Any = left match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
left
}
case _ =>
left
}
val right_val: Any = right match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
right
}
case _ =>
right
}
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
Expand Down Expand Up @@ -76,11 +101,34 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression
extends Subtract(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

val left_val: Any = left match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
left
}
case _ =>
left
}
val right_val: Any = right match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
right
}
case _ =>
right
}
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
Expand Down Expand Up @@ -113,6 +161,7 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
extends Multiply(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -121,10 +170,39 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
var resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, r)
// Scaling down the unnecessary scale for Literal to avoid precision loss
val newLeftNode = left match {
case literal: ColumnarLiteral =>
val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newLeftPrecision = leftStr.length - 1
val newLeftScale = leftStr.split('.')(1).length
val newLeftType =
new ArrowType.Decimal(newLeftPrecision, newLeftScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, newLeftType, r)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(left_node), newLeftType)
case _ =>
left_node
}
val newRightNode = right match {
case literal: ColumnarLiteral =>
val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newRightPrecision = rightStr.length - 1
val newRightScale = rightStr.split('.')(1).length
val newRightType =
new ArrowType.Decimal(newRightPrecision, newRightScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, newRightType)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(right_node), newRightType)
case _ =>
right_node
}
val mulNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
"multiply", Lists.newArrayList(newLeftNode, newRightNode), resultType)
(mulNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
Expand All @@ -150,6 +228,7 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)
extends Divide(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -158,10 +237,38 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
var resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, r)
val newLeftNode = left match {
case literal: ColumnarLiteral =>
val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newLeftPrecision = leftStr.length - 1
val newLeftScale = leftStr.split('.')(1).length
val newLeftType =
new ArrowType.Decimal(newLeftPrecision, newLeftScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, newLeftType, r)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(left_node), newLeftType)
case _ =>
left_node
}
val newRightNode = right match {
case literal: ColumnarLiteral =>
val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newRightPrecision = rightStr.length - 1
val newRightScale = rightStr.split('.')(1).length
val newRightType =
new ArrowType.Decimal(newRightPrecision, newRightScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, newRightType)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(right_node), newRightType)
case _ =>
right_node
}
val divNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
"divide", Lists.newArrayList(newLeftNode, newRightNode), resultType)
(divNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,21 +273,21 @@ class ColumnarCheckOverflow(child: Expression, original: CheckOverflow)
override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
val (child_node, childType): (TreeNode, ArrowType) =
child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
// since spark will call toPrecision in checkOverFlow and rescale from zero, we need to re-calculate result dataType here
val childScale: Int = childType match {
case d: ArrowType.Decimal => d.getScale
case _ => 0
}
val newDataType =
DecimalType(dataType.precision, dataType.scale)
val resType = CodeGeneration.getResultType(newDataType)
var function = "castDECIMAL"
if (nullOnOverflow) {
function = "castDECIMALNullOnOverflow"
if (resType == childType) {
// If target type is the same as childType, cast is not needed
(child_node, childType)
} else {
var function = "castDECIMAL"
if (nullOnOverflow) {
function = "castDECIMALNullOnOverflow"
}
val funcNode =
TreeBuilder.makeFunction(function, Lists.newArrayList(child_node), resType)
(funcNode, resType)
}
val funcNode =
TreeBuilder.makeFunction(function, Lists.newArrayList(child_node), resType)
(funcNode, resType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand Down Expand Up @@ -596,6 +597,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand Down Expand Up @@ -632,7 +634,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< leftType->precision() << ", " << leftType->scale() << ", "
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
<< resType->scale() << ", &overflow)";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand All @@ -642,7 +645,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
child_visitor_list[1]->GetPreCheck()})
<< ");" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "bool overflow = false;" << std::endl;
}
prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "if (overflow) {\n" << validity << " = false;}" << std::endl;
}
prepare_ss << "}" << std::endl;

for (int i = 0; i < 2; i++) {
Expand All @@ -669,7 +678,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< leftType->precision() << ", " << leftType->scale() << ", "
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
<< resType->scale() << ", &overflow)";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand All @@ -679,7 +689,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
child_visitor_list[1]->GetPreCheck()})
<< ");" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "bool overflow = false;" << std::endl;
}
prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "if (overflow) {\n" << validity << " = false;}" << std::endl;
}
prepare_ss << "}" << std::endl;

for (int i = 0; i < 2; i++) {
Expand Down
10 changes: 6 additions & 4 deletions native-sql-engine/cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,31 @@ arrow::Decimal128 subtract(arrow::Decimal128 left, int32_t left_precision,
arrow::Decimal128 multiply(arrow::Decimal128 left, int32_t left_precision,
int32_t left_scale, arrow::Decimal128 right,
int32_t right_precision, int32_t right_scale,
int32_t out_precision, int32_t out_scale) {
int32_t out_precision, int32_t out_scale,
bool* overflow_) {
gandiva::BasicDecimalScalar128 x(left, left_precision, left_scale);
gandiva::BasicDecimalScalar128 y(right, right_precision, right_scale);
bool overflow = false;
arrow::BasicDecimal128 out =
gandiva::decimalops::Multiply(x, y, out_precision, out_scale, &overflow);
if (overflow) {
throw std::overflow_error("Decimal multiply overflowed!");
*overflow_ = true;
}
return arrow::Decimal128(out);
}

arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision,
int32_t left_scale, arrow::Decimal128 right,
int32_t right_precision, int32_t right_scale,
int32_t out_precision, int32_t out_scale) {
int32_t out_precision, int32_t out_scale,
bool* overflow_) {
gandiva::BasicDecimalScalar128 x(left, left_precision, left_scale);
gandiva::BasicDecimalScalar128 y(right, right_precision, right_scale);
bool overflow = false;
arrow::BasicDecimal128 out =
gandiva::decimalops::Divide(0, x, y, out_precision, out_scale, &overflow);
if (overflow) {
throw std::overflow_error("Decimal divide overflowed!");
*overflow_ = true;
}
return arrow::Decimal128(out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,21 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) {
int32_t out_scale = 10;
auto res = castDECIMAL(left, left_precision, left_scale, out_precision, out_scale);
ASSERT_EQ(res, arrow::Decimal128("32342423.0128750000"));
bool overflow = false;
res = castDECIMALNullOnOverflow(left, left_precision, left_scale, out_precision,
out_scale, &overflow);
ASSERT_EQ(res, arrow::Decimal128("32342423.0128750000"));
res = add(left, left_precision, left_scale, right, right_precision, right_scale,
17, 9);
ASSERT_EQ(res, arrow::Decimal128("32344770.025749535"));
res = subtract(left, left_precision, left_scale, right, right_precision, right_scale,
17, 9);
ASSERT_EQ(res, arrow::Decimal128("32340076.000000465"));
res = multiply(left, left_precision, left_scale, right, right_precision, right_scale,
28, 15);
28, 15, &overflow);
ASSERT_EQ(res, arrow::Decimal128("75908083204.874689064638125"));
res = divide(left, left_precision, left_scale, right, right_precision, right_scale,
out_precision, out_scale);
out_precision, out_scale, &overflow);
ASSERT_EQ(res, arrow::Decimal128("13780.2495094037"));
}

Expand Down