Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support NaN comparison in wscg #2

Merged
merged 1 commit into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,16 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("less_than_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " < " +
child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "less_than_with_nan(" + child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("greater_than") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " +
child_visitor_list[1]->GetResult() + ")";
Expand All @@ -105,15 +106,16 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("greater_than_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " +
child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "greater_than_with_nan(" + child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("less_than_or_equal_to") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" <= " + child_visitor_list[1]->GetResult() + ")";
Expand All @@ -125,15 +127,17 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("less_than_or_equal_to_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" <= " + child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "less_than_or_equal_to_with_nan("
+ child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("greater_than_or_equal_to") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" >= " + child_visitor_list[1]->GetResult() + ")";
Expand All @@ -145,15 +149,17 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("greater_than_or_equal_to_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" >= " + child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "greater_than_or_equal_to_with_nan("
+ child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("equal") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" == " + child_visitor_list[1]->GetResult() + ")";
Expand All @@ -165,15 +171,16 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("equal_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" == " + child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "equal_with_nan(" + child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("not") == 0) {
std::string check_validity;
if (child_visitor_list[0]->GetPreCheck() != "") {
Expand Down
68 changes: 68 additions & 0 deletions cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,71 @@ arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision,
}
return arrow::Decimal128(out);
}

// A comparison with a NaN always returns false even when comparing with itself.
// To get the same result as spark, we can regard NaN as big as Infinity when
// doing comparison.
bool less_than_with_nan(double left, double right) {
bool left_is_nan = std::isnan(left);
bool right_is_nan = std::isnan(right);
if (left_is_nan && right_is_nan) {
return false;
} else if (left_is_nan) {
return false;
} else if (right_is_nan) {
return true;
}
return left < right;
}

bool greater_than_with_nan(double left, double right) {
bool left_is_nan = std::isnan(left);
bool right_is_nan = std::isnan(right);
if (left_is_nan && right_is_nan) {
return false;
} else if (left_is_nan) {
return true;
} else if (right_is_nan) {
return false;
}
return left > right;
}

bool less_than_or_equal_to_with_nan(double left, double right) {
bool left_is_nan = std::isnan(left);
bool right_is_nan = std::isnan(right);
if (left_is_nan && right_is_nan) {
return true;
} else if (left_is_nan) {
return false;
} else if (right_is_nan) {
return true;
}
return left <= right;
}

bool greater_than_or_equal_to_with_nan(double left, double right) {
bool left_is_nan = std::isnan(left);
bool right_is_nan = std::isnan(right);
if (left_is_nan && right_is_nan) {
return true;
} else if (left_is_nan) {
return true;
} else if (right_is_nan) {
return false;
}
return left >= right;
}

bool equal_with_nan(double left, double right) {
bool left_is_nan = std::isnan(left);
bool right_is_nan = std::isnan(right);
if (left_is_nan && right_is_nan) {
return true;
} else if (left_is_nan) {
return false;
} else if (right_is_nan) {
return false;
}
return left == right;
}
25 changes: 25 additions & 0 deletions cpp/src/tests/arrow_compute_test_precompile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,30 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) {
ASSERT_EQ(res, arrow::Decimal128("13780.2495094037"));
}

TEST(TestArrowCompute, ArithmeticComparisonTest) {
double v1 = std::numeric_limits<double>::quiet_NaN();
double v2 = 1.0;
bool res = less_than_with_nan(v1, v2);
ASSERT_EQ(res, false);
res = less_than_with_nan(v1, v1);
ASSERT_EQ(res, false);
res = less_than_or_equal_to_with_nan(v1, v2);
ASSERT_EQ(res, false);
res = less_than_or_equal_to_with_nan(v1, v1);
ASSERT_EQ(res, true);
res = greater_than_with_nan(v1, v2);
ASSERT_EQ(res, true);
res = greater_than_with_nan(v1, v1);
ASSERT_EQ(res, false);
res = greater_than_or_equal_to_with_nan(v1, v2);
ASSERT_EQ(res, true);
res = greater_than_or_equal_to_with_nan(v1, v1);
ASSERT_EQ(res, true);
res = equal_with_nan(v1, v2);
ASSERT_EQ(res, false);
res = equal_with_nan(v1, v1);
ASSERT_EQ(res, true);
}

} // namespace codegen
} // namespace sparkcolumnarplugin