From a976ba5d03a61af6243e2178ebe3b90db67ee849 Mon Sep 17 00:00:00 2001 From: Daniel Hunte Date: Thu, 10 Oct 2024 15:37:00 -0700 Subject: [PATCH] Fix power function NaN handling so it returns NaN (#11210) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11210 The C++ std::pow function returns 1 when 1 is raised to NaN. It returns NaN for every other number except for 1 which is strange. This change makes sure it will return NaN. Query: select pow(c0, nan()) from (values 1.0) t(c0); Presto-0.289: Returns NaN Velox: Returns 1.0 Reviewed By: kgpai Differential Revision: D64145275 fbshipit-source-id: 2d4890ad47b2c140867328778d7102733cd1cb32 --- velox/functions/prestosql/Arithmetic.h | 4 ++- .../prestosql/tests/ArithmeticTest.cpp | 28 +++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/velox/functions/prestosql/Arithmetic.h b/velox/functions/prestosql/Arithmetic.h index 7d3c44d894c4..c352d90b46d6 100644 --- a/velox/functions/prestosql/Arithmetic.h +++ b/velox/functions/prestosql/Arithmetic.h @@ -229,7 +229,9 @@ struct PowerFunction { template FOLLY_ALWAYS_INLINE void call(double& result, const TInput& a, const TInput& b) { - result = std::pow(a, b); + result = (std::isnan(a) && b != 0) || std::isnan(b) || std::isinf(b) + ? std::numeric_limits::quiet_NaN() + : pow(a, b); } }; diff --git a/velox/functions/prestosql/tests/ArithmeticTest.cpp b/velox/functions/prestosql/tests/ArithmeticTest.cpp index 36bffb0c10c9..cdd29db0ddec 100644 --- a/velox/functions/prestosql/tests/ArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/ArithmeticTest.cpp @@ -262,9 +262,9 @@ TEST_F(ArithmeticTest, modInt) { TEST_F(ArithmeticTest, power) { std::vector baseDouble = { - 0, 0, 0, -1, -1, -1, -9, 9.1, 10.1, 11.1, -11.1, 0, kInf, kInf}; + 0, 0, 0, -1, -1, -1, -9, 9.1, 10.1, 11.1, -11.1}; std::vector exponentDouble = { - 0, 1, -1, 0, 1, -1, -3.3, 123456.432, -99.9, 0, 100000, kInf, 0, kInf}; + 0, 1, -1, 0, 1, -1, -3.3, 123456.432, -99.9, 0, 100000}; std::vector expectedDouble; expectedDouble.reserve(baseDouble.size()); @@ -279,6 +279,30 @@ TEST_F(ArithmeticTest, power) { "pow(c0, c1)", baseDouble, exponentDouble, expectedDouble); } +TEST_F(ArithmeticTest, powerNan) { + std::vector baseDouble = {1, kNan, kNan, kNan}; + std::vector exponentDouble = {kNan, 1, kInf, 0}; + std::vector expectedDouble = {kNan, kNan, kNan, 1}; + + // Check using function name and alias. + assertExpression( + "power(c0, c1)", baseDouble, exponentDouble, expectedDouble); + assertExpression( + "pow(c0, c1)", baseDouble, exponentDouble, expectedDouble); +} + +TEST_F(ArithmeticTest, powerInf) { + std::vector baseDouble = {1, kInf, kInf, kInf}; + std::vector exponentDouble = {kInf, 1, kNan, 0}; + std::vector expectedDouble = {kNan, kInf, kNan, 1}; + + // Check using function name and alias. + assertExpression( + "power(c0, c1)", baseDouble, exponentDouble, expectedDouble); + assertExpression( + "pow(c0, c1)", baseDouble, exponentDouble, expectedDouble); +} + TEST_F(ArithmeticTest, powerInt) { std::vector baseInt = {9, 10, 11, -9, -10, -11, 0}; std::vector exponentInt = {3, -3, 0, -1, 199999, 77, 0};