diff --git a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index c0c7e5b73..4b4bd657b 100644 --- a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -85,8 +85,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("less_than_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " < " + - child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "less_than_with_nan(" + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -94,6 +94,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("greater_than") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " + child_visitor_list[1]->GetResult() + ")"; @@ -105,8 +106,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("greater_than_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " + - child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "greater_than_with_nan(" + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -114,6 +115,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("less_than_or_equal_to") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " <= " + child_visitor_list[1]->GetResult() + ")"; @@ -125,8 +127,9 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("less_than_or_equal_to_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + - " <= " + child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "less_than_or_equal_to_with_nan(" + + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -134,6 +137,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("greater_than_or_equal_to") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " >= " + child_visitor_list[1]->GetResult() + ")"; @@ -145,8 +149,9 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("greater_than_or_equal_to_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + - " >= " + child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "greater_than_or_equal_to_with_nan(" + + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -154,6 +159,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("equal") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " == " + child_visitor_list[1]->GetResult() + ")"; @@ -165,8 +171,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("equal_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + - " == " + child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "equal_with_nan(" + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -174,6 +180,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("not") == 0) { std::string check_validity; if (child_visitor_list[0]->GetPreCheck() != "") { diff --git a/cpp/src/precompile/gandiva.h b/cpp/src/precompile/gandiva.h index de8e608c4..df1416138 100644 --- a/cpp/src/precompile/gandiva.h +++ b/cpp/src/precompile/gandiva.h @@ -121,3 +121,71 @@ arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision, } return arrow::Decimal128(out); } + +// A comparison with a NaN always returns false even when comparing with itself. +// To get the same result as spark, we can regard NaN as big as Infinity when +// doing comparison. +bool less_than_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return false; + } else if (left_is_nan) { + return false; + } else if (right_is_nan) { + return true; + } + return left < right; +} + +bool greater_than_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return false; + } else if (left_is_nan) { + return true; + } else if (right_is_nan) { + return false; + } + return left > right; +} + +bool less_than_or_equal_to_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return true; + } else if (left_is_nan) { + return false; + } else if (right_is_nan) { + return true; + } + return left <= right; +} + +bool greater_than_or_equal_to_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return true; + } else if (left_is_nan) { + return true; + } else if (right_is_nan) { + return false; + } + return left >= right; +} + +bool equal_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return true; + } else if (left_is_nan) { + return false; + } else if (right_is_nan) { + return false; + } + return left == right; +} diff --git a/cpp/src/tests/arrow_compute_test_precompile.cc b/cpp/src/tests/arrow_compute_test_precompile.cc index 8edbd0b74..4f3549749 100644 --- a/cpp/src/tests/arrow_compute_test_precompile.cc +++ b/cpp/src/tests/arrow_compute_test_precompile.cc @@ -68,5 +68,30 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) { ASSERT_EQ(res, arrow::Decimal128("13780.2495094037")); } +TEST(TestArrowCompute, ArithmeticComparisonTest) { + double v1 = std::numeric_limits::quiet_NaN(); + double v2 = 1.0; + bool res = less_than_with_nan(v1, v2); + ASSERT_EQ(res, false); + res = less_than_with_nan(v1, v1); + ASSERT_EQ(res, false); + res = less_than_or_equal_to_with_nan(v1, v2); + ASSERT_EQ(res, false); + res = less_than_or_equal_to_with_nan(v1, v1); + ASSERT_EQ(res, true); + res = greater_than_with_nan(v1, v2); + ASSERT_EQ(res, true); + res = greater_than_with_nan(v1, v1); + ASSERT_EQ(res, false); + res = greater_than_or_equal_to_with_nan(v1, v2); + ASSERT_EQ(res, true); + res = greater_than_or_equal_to_with_nan(v1, v1); + ASSERT_EQ(res, true); + res = equal_with_nan(v1, v2); + ASSERT_EQ(res, false); + res = equal_with_nan(v1, v1); + ASSERT_EQ(res, true); +} + } // namespace codegen } // namespace sparkcolumnarplugin