diff --git a/velox/functions/prestosql/Arithmetic.h b/velox/functions/prestosql/Arithmetic.h index 7d3c44d894c4..c352d90b46d6 100644 --- a/velox/functions/prestosql/Arithmetic.h +++ b/velox/functions/prestosql/Arithmetic.h @@ -229,7 +229,9 @@ struct PowerFunction { template FOLLY_ALWAYS_INLINE void call(double& result, const TInput& a, const TInput& b) { - result = std::pow(a, b); + result = (std::isnan(a) && b != 0) || std::isnan(b) || std::isinf(b) + ? std::numeric_limits::quiet_NaN() + : pow(a, b); } }; diff --git a/velox/functions/prestosql/tests/ArithmeticTest.cpp b/velox/functions/prestosql/tests/ArithmeticTest.cpp index 36bffb0c10c9..cdd29db0ddec 100644 --- a/velox/functions/prestosql/tests/ArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/ArithmeticTest.cpp @@ -262,9 +262,9 @@ TEST_F(ArithmeticTest, modInt) { TEST_F(ArithmeticTest, power) { std::vector baseDouble = { - 0, 0, 0, -1, -1, -1, -9, 9.1, 10.1, 11.1, -11.1, 0, kInf, kInf}; + 0, 0, 0, -1, -1, -1, -9, 9.1, 10.1, 11.1, -11.1}; std::vector exponentDouble = { - 0, 1, -1, 0, 1, -1, -3.3, 123456.432, -99.9, 0, 100000, kInf, 0, kInf}; + 0, 1, -1, 0, 1, -1, -3.3, 123456.432, -99.9, 0, 100000}; std::vector expectedDouble; expectedDouble.reserve(baseDouble.size()); @@ -279,6 +279,30 @@ TEST_F(ArithmeticTest, power) { "pow(c0, c1)", baseDouble, exponentDouble, expectedDouble); } +TEST_F(ArithmeticTest, powerNan) { + std::vector baseDouble = {1, kNan, kNan, kNan}; + std::vector exponentDouble = {kNan, 1, kInf, 0}; + std::vector expectedDouble = {kNan, kNan, kNan, 1}; + + // Check using function name and alias. + assertExpression( + "power(c0, c1)", baseDouble, exponentDouble, expectedDouble); + assertExpression( + "pow(c0, c1)", baseDouble, exponentDouble, expectedDouble); +} + +TEST_F(ArithmeticTest, powerInf) { + std::vector baseDouble = {1, kInf, kInf, kInf}; + std::vector exponentDouble = {kInf, 1, kNan, 0}; + std::vector expectedDouble = {kNan, kInf, kNan, 1}; + + // Check using function name and alias. + assertExpression( + "power(c0, c1)", baseDouble, exponentDouble, expectedDouble); + assertExpression( + "pow(c0, c1)", baseDouble, exponentDouble, expectedDouble); +} + TEST_F(ArithmeticTest, powerInt) { std::vector baseInt = {9, 10, 11, -9, -10, -11, 0}; std::vector exponentInt = {3, -3, 0, -1, 199999, 77, 0};