From fcc51e48794a0806e7f3d5858fd6f88a61948abc Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Wed, 23 Aug 2023 09:18:50 +0000 Subject: [PATCH 1/4] Implement Spark decimal multiply and divide --- .../boost/CMakeLists.txt | 1 + velox/docs/functions/spark/math.rst | 24 + velox/functions/sparksql/CMakeLists.txt | 1 + .../functions/sparksql/DecimalArithmetic.cpp | 506 ++++++++++++++++++ velox/functions/sparksql/DecimalUtil.h | 196 +++++++ .../functions/sparksql/RegisterArithmetic.cpp | 3 + velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/DecimalArithmeticTest.cpp | 288 ++++++++++ 8 files changed, 1020 insertions(+) create mode 100644 velox/functions/sparksql/DecimalArithmetic.cpp create mode 100644 velox/functions/sparksql/DecimalUtil.h create mode 100644 velox/functions/sparksql/tests/DecimalArithmeticTest.cpp diff --git a/CMake/resolve_dependency_modules/boost/CMakeLists.txt b/CMake/resolve_dependency_modules/boost/CMakeLists.txt index 8f9bf06ed44b..c26cf2536fb9 100644 --- a/CMake/resolve_dependency_modules/boost/CMakeLists.txt +++ b/CMake/resolve_dependency_modules/boost/CMakeLists.txt @@ -53,6 +53,7 @@ set(BOOST_HEADER_ONLY circular_buffer math multi_index + multiprecision numeric_conversion random uuid diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index 06b671c27ee1..ae367eef9ba4 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -51,12 +51,25 @@ Mathematical Functions .. spark:function:: divide(x, y) -> double Returns the results of dividing x by y. Performs floating point division. + Supported type is DOUBLE. Corresponds to Spark's operator ``/``. :: SELECT 3 / 2; -- 1.5 SELECT 2L / 2L; -- 1.0 SELECT 3 / 0; -- NULL +.. spark:function:: divide(x, y) -> decimal + + Returns the results of dividing x by y. + Supported type is DECIMAL which can be different precision and scale. + Performs floating point division. + The result type depends on the precision and scale of x and y. + Overflow results return null. Corresponds to Spark's operator ``/``. :: + + SELECT CAST(1 as DECIMAL(17, 3)) / CAST(2 as DECIMAL(17, 3)); -- decimal 0.500000000000000000000 + SELECT CAST(1 as DECIMAL(20, 3)) / CAST(20 as DECIMAL(20, 2)); -- decimal 0.0500000000000000000 + SELECT CAST(1 as DECIMAL(20, 3)) / CAST(0 as DECIMAL(20, 3)); -- NULL + .. spark:function:: exp(x) -> double Returns Euler's number raised to the power of ``x``. @@ -89,6 +102,17 @@ Mathematical Functions Returns the result of multiplying x by y. The types of x and y must be the same. For integral types, overflow results in an error. Corresponds to Spark's operator ``*``. +.. spark:function:: multiply(x, y) -> [decimal] + + Returns the result of multiplying x by y. The types of x and y must be decimal which can be different precision and scale. + The result type depends on the precision and scale of x and y. + Overflow results return null. Corresponds to Spark's operator ``*``. :: + + SELECT CAST(1 as DECIMAL(17, 3)) * CAST(2 as DECIMAL(17, 3)); -- decimal 2.000000 + SELECT CAST(1 as DECIMAL(20, 3)) * CAST(20 as DECIMAL(20, 2)); -- decimal 20.00000 + SELECT CAST(1 as DECIMAL(20, 3)) * CAST(0 as DECIMAL(20, 3)); -- decimal 0.000000 + SELECT CAST(201e-38 as DECIMAL(38, 38)) * CAST(301e-38 as DECIMAL(38, 38)); -- decimal 0.0000000000000000000000000000000000000 + .. spark:function:: not(x) -> boolean Logical not. :: diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index 265fe8e87082..525fd1bae494 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -17,6 +17,7 @@ add_library( ArraySort.cpp Bitwise.cpp Comparisons.cpp + DecimalArithmetic.cpp Hash.cpp In.cpp LeastGreatest.cpp diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp new file mode 100644 index 000000000000..1ee7133be1ed --- /dev/null +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -0,0 +1,506 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/CheckedArithmetic.h" +#include "velox/expression/DecodedArgs.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/sparksql/DecimalUtil.h" +#include "velox/type/DecimalUtil.h" + +namespace facebook::velox::functions::sparksql { +namespace { + +inline static std::pair adjustPrecisionScale( + const uint8_t rPrecision, + const uint8_t rScale) { + if (rPrecision <= 38) { + return {rPrecision, rScale}; + } else { + int32_t minScale = std::min(static_cast(rScale), 6); + int32_t delta = rPrecision - 38; + return {38, std::max(rScale - delta, minScale)}; + } +} + +std::string getResultScale(std::string precision, std::string scale) { + return fmt::format( + "({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))", + precision, + scale, + scale, + precision, + scale); +} + +template +inline bool isOverflow(R result, uint8_t rPrecision) { + if (result > velox::DecimalUtil::kPowersOfTen[rPrecision] || + result < -velox::DecimalUtil::kPowersOfTen[rPrecision]) { + return true; + } + return false; +} + +template < + typename R /* Result Type */, + typename A /* Argument1 */, + typename B /* Argument2 */, + typename Operation /* Arithmetic operation */> +class DecimalBaseFunction : public exec::VectorFunction { + public: + DecimalBaseFunction( + uint8_t aRescale, + uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale) + : aRescale_(aRescale), + bRescale_(bRescale), + aPrecision_(aPrecision), + aScale_(aScale), + bPrecision_(bPrecision), + bScale_(bScale), + rPrecision_(rPrecision), + rScale_(rScale) {} + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& resultType, + exec::EvalCtx& context, + VectorPtr& result) const override { + auto rawResults = prepareResults(rows, resultType, context, result); + if (args[0]->isConstantEncoding() && args[1]->isFlatEncoding()) { + // Fast path for (const, flat). + auto constant = args[0]->asUnchecked>()->valueAt(0); + auto flatValues = args[1]->asUnchecked>(); + auto rawValues = flatValues->mutableRawValues(); + context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; + Operation::template apply( + rawResults[row], + constant, + rawValues[row], + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + overflow); + if (overflow || isOverflow(rawResults[row], rPrecision_)) { + result->setNull(row, true); + } + }); + } else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) { + // Fast path for (flat, const). + auto flatValues = args[0]->asUnchecked>(); + auto constant = args[1]->asUnchecked>()->valueAt(0); + auto rawValues = flatValues->mutableRawValues(); + context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; + Operation::template apply( + rawResults[row], + rawValues[row], + constant, + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + overflow); + if (overflow || isOverflow(rawResults[row], rPrecision_)) { + result->setNull(row, true); + } + }); + } else if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) { + // Fast path for (flat, flat). + auto flatA = args[0]->asUnchecked>(); + auto rawA = flatA->mutableRawValues(); + auto flatB = args[1]->asUnchecked>(); + auto rawB = flatB->mutableRawValues(); + + context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; + Operation::template apply( + rawResults[row], + rawA[row], + rawB[row], + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + overflow); + if (overflow || isOverflow(rawResults[row], rPrecision_)) { + result->setNull(row, true); + } + }); + } else { + // Fast path if one or more arguments are encoded. + exec::DecodedArgs decodedArgs(rows, args, context); + auto a = decodedArgs.at(0); + auto b = decodedArgs.at(1); + context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; + Operation::template apply( + rawResults[row], + a->valueAt(row), + b->valueAt(row), + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + overflow); + if (overflow || isOverflow(rawResults[row], rPrecision_)) { + result->setNull(row, true); + } + }); + } + } + + private: + R* prepareResults( + const SelectivityVector& rows, + const TypePtr& resultType, + exec::EvalCtx& context, + VectorPtr& result) const { + context.ensureWritable(rows, resultType, result); + result->clearNulls(rows); + return result->asUnchecked>()->mutableRawValues(); + } + + const uint8_t aRescale_; + const uint8_t bRescale_; + const uint8_t aPrecision_; + const uint8_t aScale_; + const uint8_t bPrecision_; + const uint8_t bScale_; + const uint8_t rPrecision_; + const uint8_t rScale_; +}; + +class Multiply { + public: + // Derive from Arrow. + // https://github.com/apache/arrow/blob/release-12.0.1-rc1/cpp/src/gandiva/precompiled/decimal_ops.cc#L331 + template + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, + bool& overflow) { + if (rPrecision < 38) { + R result = DecimalUtil::multiply(R(a), R(b), overflow); + VELOX_DCHECK(!overflow); + r = DecimalUtil::multiply( + result, + R(velox::DecimalUtil::kPowersOfTen[aRescale + bRescale]), + overflow); + VELOX_DCHECK(!overflow); + } else if (a == 0 && b == 0) { + // Handle this separately to avoid divide-by-zero errors. + r = R(0); + } else { + auto deltaScale = aScale + bScale - rScale; + if (deltaScale == 0) { + // No scale down. + // Multiply when the out_precision is 38, and there is no trimming of + // the scale i.e the intermediate value is the same as the final value. + r = DecimalUtil::multiply(R(a), R(b), overflow); + } else { + // Scale down. + // It's possible that the intermediate value does not fit in 128-bits, + // but the final value will (after scaling down). + int32_t totalLeadingZeros = + bits::countLeadingZeros(DecimalUtil::absValue(a)) + + bits::countLeadingZeros(DecimalUtil::absValue(b)); + // This check is quick, but conservative. In some cases it will + // indicate that converting to 256 bits is necessary, when it's not + // actually the case. + if (UNLIKELY(totalLeadingZeros <= 128)) { + // Needs int256. + int256_t reslarge = + static_cast(a) * static_cast(b); + reslarge = reduceScaleBy(reslarge, deltaScale); + r = DecimalUtil::convert(reslarge, overflow); + } else { + if (LIKELY(deltaScale <= 38)) { + // The largest value that result can have here is (2^64 - 1) * (2^63 + // - 1) = 1.70141E+38,which is greater than + // DecimalUtil::kLongDecimalMax. + R result = DecimalUtil::multiply(R(a), R(b), overflow); + VELOX_DCHECK(!overflow); + // Since deltaScale is greater than zero, result can now be at most + // ((2^64 - 1) * (2^63 - 1)) / 10, which is less than + // DecimalUtil::kLongDecimalMax, so there cannot be any overflow. + DecimalUtil::divideWithRoundUp( + r, + result, + R(velox::DecimalUtil::kPowersOfTen[deltaScale]), + 0, + overflow); + VELOX_DCHECK(!overflow); + } else { + // We are multiplying decimal(38, 38) by decimal(38, 38). The result + // should be a + // decimal(38, 37), so delta scale = 38 + 38 - 37 = 39. Since we are + // not in the 256 bit intermediate value case and we are scaling + // down by 39, then we are guaranteed that the result is 0 (even if + // we try to round). The largest possible intermediate result is 38 + // "9"s. If we scale down by 39, the leftmost 9 is now two digits to + // the right of the rightmost "visible" one. The reason why we have + // to handle this case separately is because a scale multiplier with + // a deltaScale 39 does not fit into 128 bit. + r = R(0); + } + } + } + } + } + + inline static uint8_t + computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { + return 0; + } + + inline static std::pair computeResultPrecisionScale( + const uint8_t aPrecision, + const uint8_t aScale, + const uint8_t bPrecision, + const uint8_t bScale) { + return adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale); + } + + private: + inline static int256_t reduceScaleBy(int256_t in, int32_t reduceBy) { + if (reduceBy == 0) { + return in; + } + + int256_t divisor = velox::DecimalUtil::kPowersOfTen[reduceBy]; + auto result = in / divisor; + auto remainder = in % divisor; + // Round up. + if (abs(remainder) >= (divisor >> 1)) { + result += (in > 0 ? 1 : -1); + } + return result; + } +}; + +class Divide { + public: + template + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t /* bRescale */, + uint8_t /* aPrecision */, + uint8_t /* aScale */, + uint8_t /* bPrecision */, + uint8_t /* bScale */, + uint8_t /* rPrecision */, + uint8_t /* rScale */, + bool& overflow) { + DecimalUtil::divideWithRoundUp(r, a, b, aRescale, overflow); + } + + inline static uint8_t + computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale) { + return rScale - fromScale + toScale; + } + + inline static std::pair computeResultPrecisionScale( + const uint8_t aPrecision, + const uint8_t aScale, + const uint8_t bPrecision, + const uint8_t bScale) { + auto scale = std::max(6, aScale + bPrecision + 1); + auto precision = aPrecision - aScale + bScale + scale; + return adjustPrecisionScale(precision, scale); + } +}; + +std::vector> +decimalMultiplySignature() { + return {exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", "min(38, a_precision + b_precision + 1)") + .integerVariable( + "r_scale", + getResultScale( + "a_precision + b_precision + 1", "a_scale + b_scale")) + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build()}; +} + +std::vector> decimalDivideSignature() { + return { + exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", + "min(38, a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1))") + .integerVariable( + "r_scale", + getResultScale( + "a_precision - a_scale + b_scale + max(6, a_scale + b_precision + 1)", + "max(6, a_scale + b_precision + 1)")) + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build()}; +} + +template +std::shared_ptr createDecimalFunction( + const std::string& name, + const std::vector& inputArgs, + const core::QueryConfig& /*config*/) { + auto aType = inputArgs[0].type; + auto bType = inputArgs[1].type; + auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); + auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); + auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( + aPrecision, aScale, bPrecision, bScale); + uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); + uint8_t bRescale = Operation::computeRescaleFactor(bScale, aScale, rScale); + if (aType->isShortDecimal()) { + if (bType->isShortDecimal()) { + if (rPrecision > ShortDecimalType::kMaxPrecision) { + return std::make_shared>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale); + } else { + return std::make_shared>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale); + } + } else { + return std::make_shared>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale); + } + } else { + if (bType->isShortDecimal()) { + return std::make_shared>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale); + } else { + return std::make_shared>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale); + } + } + VELOX_UNSUPPORTED(); +} +}; // namespace + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_mul, + decimalMultiplySignature(), + createDecimalFunction); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_div, + decimalDivideSignature(), + createDecimalFunction); +}; // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h new file mode 100644 index 000000000000..545111ef18b8 --- /dev/null +++ b/velox/functions/sparksql/DecimalUtil.h @@ -0,0 +1,196 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/type/DecimalUtil.h" +#include "velox/type/Type.h" + +namespace facebook::velox::functions::sparksql { +using int256_t = boost::multiprecision::int256_t; + +// DecimalUtil holds the utility function for Spark sql. +class DecimalUtil { + public: + /// @brief Convert int256 value to int64 or int128, set overflow to true if + /// value cannot convert to specific type. + /// @return The converted value. + template < + class T, + typename = std::enable_if_t< + std::is_same_v || std::is_same_v>> + inline static T convert(int256_t in, bool& overflow) { + typedef typename std:: + conditional, uint64_t, __uint128_t>::type UT; + T result = 0; + constexpr auto uintMask = + static_cast(std::numeric_limits::max()); + + int256_t inAbs = abs(in); + bool isNegative = in < 0; + + UT unsignResult = (inAbs & uintMask).convert_to(); + inAbs >>= sizeof(T) * 8; + + if (inAbs > 0) { + // We've shifted in by bit of T, so nothing should be left. + overflow = true; + } else if (unsignResult > std::numeric_limits::max()) { + overflow = true; + } else { + result = static_cast(unsignResult); + } + return isNegative ? -result : result; + } + + // Returns the abs value of input value. + template >> + FOLLY_ALWAYS_INLINE static uint64_t absValue(int64_t a) { + return a < 0 ? static_cast(-a) : static_cast(a); + } + + // Returns the abs value of input value. + template >> + FOLLY_ALWAYS_INLINE static uint128_t absValue(int128_t a) { + return a < 0 ? static_cast(-a) : static_cast(a); + } + + /// Multiply a and b, set overflow to true if overflow. The caller should + /// treat overflow flag first. + template >> + FOLLY_ALWAYS_INLINE static int64_t + multiply(int64_t a, int64_t b, bool& overflow) { + int64_t value; + overflow = __builtin_mul_overflow(a, b, &value); + return value; + } + + /// Multiply a and b, set overflow to true if overflow. The caller should + /// treat overflow flag first. + template >> + FOLLY_ALWAYS_INLINE static int128_t + multiply(int128_t a, int128_t b, bool& overflow) { + int128_t value; + overflow = __builtin_mul_overflow(a, b, &value); + return value; + } + + /// Derives from Arrow BasicDecimal128 Divide. + /// https://github.com/apache/arrow/blob/release-12.0.1-rc1/cpp/src/gandiva/precompiled/decimal_ops.cc#L350 + /// + /// Divide a and b, set overflow to true if the result overflows. The caller + /// should treat the overflow flag first. Using HALF_UP rounding, the digit 5 + /// is rounded up. + /// Compute the maximum bits required to increase, if it is less or equal than + /// 127 bits, int128_t is enough, if not, we should introduce int256_t as + /// intermediate type, and then convert to real result type with overflow + /// flag. + template + inline static R divideWithRoundUp( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + bool& overflow) { + if (b == 0) { + overflow = true; + return R(-1); + } + int resultSign = 1; + R unsignedDividendRescaled(a); + int aSign = 1; + int bSign = 1; + if (a < 0) { + resultSign = -1; + unsignedDividendRescaled *= -1; + aSign = -1; + } + R unsignedDivisor(b); + if (b < 0) { + resultSign *= -1; + unsignedDivisor *= -1; + bSign = -1; + } + auto bitsRequiredAfterScaling = maxBitsRequiredAfterScaling(a, aRescale); + if (bitsRequiredAfterScaling <= 127) { + // Fast-path. The dividend fits in 128-bit after scaling too. + overflow = __builtin_mul_overflow( + unsignedDividendRescaled, + R(velox::DecimalUtil::kPowersOfTen[aRescale]), + &unsignedDividendRescaled); + VELOX_DCHECK(!overflow); + R quotient = unsignedDividendRescaled / unsignedDivisor; + R remainder = unsignedDividendRescaled % unsignedDivisor; + if (remainder * 2 >= unsignedDivisor) { + ++quotient; + } + r = quotient * resultSign; + return remainder; + } else { + if (aRescale > 38 && bitsRequiredAfterScaling > 255) { + overflow = true; + return R(-1); + } + int256_t aLarge = a; + int256_t aLargeScaledUp = + aLarge * velox::DecimalUtil::kPowersOfTen[aRescale]; + int256_t bLarge = b; + int256_t resultLarge = aLargeScaledUp / bLarge; + int256_t remainderLarge = aLargeScaledUp % bLarge; + /// Since we are scaling up and then, scaling down, round-up the result + /// (+1 for +ve, -1 for -ve), if the remainder is >= 2 * divisor. + if (abs(2 * remainderLarge) >= abs(bLarge)) { + /// x +ve and y +ve, result is +ve => (1 ^ 1) + 1 = 0 + 1 = +1 + /// x +ve and y -ve, result is -ve => (-1 ^ 1) + 1 = -2 + 1 = -1 + /// x +ve and y -ve, result is -ve => (1 ^ -1) + 1 = -2 + 1 = -1 + /// x -ve and y -ve, result is +ve => (-1 ^ -1) + 1 = 0 + 1 = +1 + resultLarge += (aSign ^ bSign) + 1; + } + + auto result = convert(resultLarge, overflow); + if (overflow) { + return R(-1); + } + r = result; + auto remainder = convert(remainderLarge, overflow); + return remainder; + } + } + + private: + /// We rely on the following formula: + /// bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1 + /// We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76 + static constexpr int32_t kMaxBitsRequiredIncreaseAfterScaling[] = { + 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, + 44, 47, 50, 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, + 87, 90, 94, 97, 100, 103, 107, 110, 113, 117, 120, 123, 127, + 130, 133, 137, 140, 143, 147, 150, 153, 157, 160, 163, 167, 170, + 173, 177, 180, 183, 187, 190, 193, 196, 200, 203, 206, 210, 213, + 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253}; + + template + inline static int32_t maxBitsRequiredAfterScaling( + const A& num, + uint8_t aRescale) { + auto valueAbs = absValue(num); + int32_t numOccupied = sizeof(A) * 8 - bits::countLeadingZeros(valueAbs); + return numOccupied + kMaxBitsRequiredIncreaseAfterScaling[aRescale]; + } +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 6f16b5699f4b..5435cdf66010 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -90,6 +90,9 @@ void registerArithmeticFunctions(const std::string& prefix) { registerFunction({prefix + "log2"}); registerFunction({prefix + "log10"}); registerRandFunctions(prefix); + + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "multiply"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "divide"); } } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index e574afcf7bd4..a4960590f5ac 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -21,6 +21,7 @@ add_executable( BitwiseTest.cpp ComparisonsTest.cpp DateTimeFunctionsTest.cpp + DecimalArithmeticTest.cpp ElementAtTest.cpp HashTest.cpp InTest.cpp diff --git a/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp new file mode 100644 index 000000000000..a370809a946f --- /dev/null +++ b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; +using namespace facebook::velox::functions::test; + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class DecimalArithmeticTest : public SparkFunctionBaseTest { + public: + DecimalArithmeticTest() { + options_.parseDecimalAsDouble = false; + } + + protected: + template + void testDecimalExpr( + const VectorPtr& expected, + const std::string& expression, + const std::vector& input) { + using EvalType = typename velox::TypeTraits::NativeType; + auto result = + evaluate>(expression, makeRowVector(input)); + assertEqualVectors(expected, result); + } + + VectorPtr makeLongDecimalVector( + const std::vector& value, + int8_t precision, + int8_t scale) { + if (value.size() == 1) { + return makeConstant( + HugeInt::parse(std::move(value[0])), 1, DECIMAL(precision, scale)); + } + std::vector int128s; + for (auto& v : value) { + int128s.emplace_back(HugeInt::parse(std::move(v))); + } + return makeFlatVector(int128s, DECIMAL(precision, scale)); + } +}; // namespace + +TEST_F(DecimalArithmeticTest, multiply) { + // The result can be obtained by Spark unit test + // test("multiply") { + // val l1 = Literal.create( + // Decimal(BigDecimal(1), 17, 3), + // DecimalType(17, 3)) + // val l2 = Literal.create( + // Decimal(BigDecimal(1), 17, 3), + // DecimalType(17, 3)) + // checkEvaluation(Divide(l1, l2), null) + // } + auto shortFlat = makeFlatVector({1000, 2000}, DECIMAL(17, 3)); + // Multiply short and short, returning long. + testDecimalExpr( + makeFlatVector({1000000, 4000000}, DECIMAL(35, 6)), + "multiply(c0, c1)", + {shortFlat, shortFlat}); + // Multiply short and long, returning long. + auto longFlat = makeFlatVector({1000, 2000}, DECIMAL(20, 3)); + auto expectedLongFlat = + makeFlatVector({1000000, 4000000}, DECIMAL(38, 6)); + testDecimalExpr( + expectedLongFlat, "multiply(c0, c1)", {shortFlat, longFlat}); + // Multiply long and short, returning long. + testDecimalExpr( + expectedLongFlat, "multiply(c0, c1)", {longFlat, shortFlat}); + + // Multiply long and long, returning long. + testDecimalExpr( + makeFlatVector({1000000, 4000000}, DECIMAL(38, 6)), + "multiply(c0, c1)", + {longFlat, longFlat}); + + auto leftFlat0 = makeFlatVector({0, 1, 0}, DECIMAL(20, 3)); + auto rightFlat0 = makeFlatVector({1, 0, 0}, DECIMAL(20, 2)); + testDecimalExpr( + makeFlatVector({0, 0, 0}, DECIMAL(38, 5)), + "multiply(c0, c1)", + {leftFlat0, rightFlat0}); + + // Multiply short and short, returning short. + shortFlat = makeFlatVector({1000, 2000}, DECIMAL(6, 3)); + testDecimalExpr( + makeFlatVector({1000000, 4000000}, DECIMAL(13, 6)), + "c0 * c1", + {shortFlat, shortFlat}); + + auto expectedConstantFlat = + makeFlatVector({100000, 200000}, DECIMAL(10, 5)); + // Constant and Flat arguments. + testDecimalExpr( + expectedConstantFlat, "1.00 * c0", {shortFlat}); + + // Flat and Constant arguments. + testDecimalExpr( + expectedConstantFlat, "c0 * 1.00", {shortFlat}); + + // out_precision == 38, small input values, trimming of scale. + testDecimalExpr( + makeConstant(61, 1, DECIMAL(38, 7)), + "c0 * c1", + {makeConstant(201, 1, DECIMAL(20, 5)), + makeConstant(301, 1, DECIMAL(20, 5))}); + + // out_precision == 38, large values, trimming of scale. + testDecimalExpr( + makeConstant( + HugeInt::parse("201" + std::string(31, '0')), 1, DECIMAL(38, 6)), + "c0 * c1", + {makeConstant(201, 1, DECIMAL(20, 5)), + makeConstant( + HugeInt::parse(std::string(35, '9')), 1, DECIMAL(35, 5))}); + + // out_precision == 38, very large values, trimming of scale (requires convert + // to 256). + testDecimalExpr( + makeConstant( + HugeInt::parse("9999999999999999999999999999999999890"), + 1, + DECIMAL(38, 6)), + "c0 * c1", + {makeConstant( + HugeInt::parse(std::string(35, '9')), 1, DECIMAL(38, 20)), + makeConstant( + HugeInt::parse(std::string(36, '9')), 1, DECIMAL(38, 20))}); + + // out_precision == 38, very large values, trimming of scale (requires convert + // to 256). should cause overflow. + testDecimalExpr( + makeConstant(std::nullopt, 1, DECIMAL(38, 6)), + "c0 * c1", + {makeConstant( + HugeInt::parse(std::string(35, '9')), 1, DECIMAL(38, 4)), + makeConstant( + HugeInt::parse(std::string(36, '9')), 1, DECIMAL(38, 4))}); + + // Big scale * big scale. + testDecimalExpr( + makeConstant(0, 1, DECIMAL(38, 37)), + "c0 * c1", + {makeConstant(201, 1, DECIMAL(38, 38)), + makeConstant(301, 1, DECIMAL(38, 38))}); + + // Long decimal limits. + testDecimalExpr( + makeConstant(std::nullopt, 1, DECIMAL(38, 0)), + "c0 * cast(10.00 as decimal(2,0))", + {makeConstant( + HugeInt::build(0x08FFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF), + 1, + DECIMAL(38, 0))}); + + // Rescaling the final result overflows. + testDecimalExpr( + makeConstant(std::nullopt, 1, DECIMAL(38, 1)), + "c0 * cast(1.00 as decimal(2,1))", + {makeConstant( + HugeInt::build(0x08FFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF), + 1, + DECIMAL(38, 0))}); +} + +TEST_F(DecimalArithmeticTest, decimalDivTest) { + auto shortFlat = makeFlatVector({1000, 2000}, DECIMAL(17, 3)); + // Divide short and short, returning long. + testDecimalExpr( + makeLongDecimalVector( + {"500000000000000000000", "2000000000000000000000"}, 38, 21), + "divide(c0, c1)", + {makeFlatVector({500, 4000}, DECIMAL(17, 3)), shortFlat}); + + // Divide short and long, returning long. + auto longFlat = makeFlatVector({500, 4000}, DECIMAL(20, 2)); + testDecimalExpr( + makeFlatVector( + {500000000000000000, 2000000000000000000}, DECIMAL(38, 17)), + "divide(c0, c1)", + {longFlat, shortFlat}); + + // Divide long and short, returning long. + testDecimalExpr( + makeLongDecimalVector( + {"20" + std::string(20, '0'), "5" + std::string(20, '0')}, 38, 22), + "divide(c0, c1)", + {shortFlat, longFlat}); + + // Divide long and long, returning long. + testDecimalExpr( + makeLongDecimalVector( + {"5" + std::string(18, '0'), "3" + std::string(18, '0')}, 38, 18), + "divide(c0, c1)", + {makeFlatVector({2500, 12000}, DECIMAL(20, 2)), longFlat}); + + // Divide short and short, returning short. + testDecimalExpr( + makeFlatVector({500000000, 300000000}, DECIMAL(13, 11)), + "divide(c0, c1)", + {makeFlatVector({2500, 12000}, DECIMAL(5, 5)), + makeFlatVector({500, 4000}, DECIMAL(5, 2))}); + // This result can be obtained by Spark unit test + // test("divide decimal big") { + // val s = Seq(35, 6, 20, 3) + // var builder = new StringBuffer() + // (0 until 29).foreach(_ => builder = builder.append("9")) + // builder.append(".") + // (0 until 6).foreach(_ => builder = builder.append("9")) + // val str1 = builder.toString + + // val l1 = Literal.create( + // Decimal(BigDecimal(str1), s.head, s(1)), + // DecimalType(s.head, s(1))) + // val l2 = Literal.create( + // Decimal(BigDecimal(0.201), s(2), s(3)), + // DecimalType(s(2), s(3))) + // checkEvaluation(Divide(l1, l2), null) + // } + testDecimalExpr( + makeLongDecimalVector({"497512437810945273631840796019900493"}, 38, 6), + "c0 / c1", + {makeLongDecimalVector({std::string(35, '9')}, 35, 6), + makeConstant(201, 1, DECIMAL(20, 3))}); + + testDecimalExpr( + makeLongDecimalVector( + {"1000" + std::string(17, '0'), "500" + std::string(17, '0')}, + 24, + 20), + "1.00 / c0", + {shortFlat}); + + // Flat and Constant arguments. + testDecimalExpr( + makeLongDecimalVector( + {"500" + std::string(4, '0'), "1000" + std::string(4, '0')}, 23, 7), + "c0 / 2.00", + {shortFlat}); + + // Divide and round-up. + // The result can be obtained by Spark unit test + // test("divide test") { + // spark.sql("create table decimals_test(a decimal(2,1)) using parquet;") + // spark.sql("insert into decimals_test values(6)") + // val df = spark.sql("select a / -6.0 from decimals_test") + // df.printSchema() + // df.show(truncate = false) + // spark.sql("drop table decimals_test;") + // } + testDecimalExpr( + {makeFlatVector( + {566667, -83333, -1083333, -1500000, -33333, 816667}, DECIMAL(8, 6))}, + "c0 / -6.0", + {makeFlatVector({-34, 5, 65, 90, 2, -49}, DECIMAL(2, 1))}); + // Divide by zero. + testDecimalExpr( + makeConstant(std::nullopt, 2, DECIMAL(21, 6)), + "c0 / 0.0", + {shortFlat}); + + // Long decimal limits. + testDecimalExpr( + makeConstant(std::nullopt, 1, DECIMAL(38, 6)), + "c0 / 0.01", + {makeConstant( + DecimalUtil::kLongDecimalMax, 1, DECIMAL(38, 0))}); +} +} // namespace +} // namespace facebook::velox::functions::sparksql::test From 11507b7c3bf317329ced17762d30a1bd1c44fde5 Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Fri, 25 Aug 2023 10:19:12 +0000 Subject: [PATCH 2/4] optimize divideWithRoundUp and add test --- velox/functions/sparksql/DecimalUtil.h | 6 +-- velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/DecimalUtilTest.cpp | 46 +++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 velox/functions/sparksql/tests/DecimalUtilTest.cpp diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index 545111ef18b8..9c41486ca210 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -96,8 +96,8 @@ class DecimalUtil { /// Divide a and b, set overflow to true if the result overflows. The caller /// should treat the overflow flag first. Using HALF_UP rounding, the digit 5 /// is rounded up. - /// Compute the maximum bits required to increase, if it is less or equal than - /// 127 bits, int128_t is enough, if not, we should introduce int256_t as + /// Compute the maximum bits required to increase, if it is less than + /// result type bits, result type is enough, if not, we should introduce int256_t as /// intermediate type, and then convert to real result type with overflow /// flag. template @@ -127,7 +127,7 @@ class DecimalUtil { bSign = -1; } auto bitsRequiredAfterScaling = maxBitsRequiredAfterScaling(a, aRescale); - if (bitsRequiredAfterScaling <= 127) { + if (bitsRequiredAfterScaling < sizeof(R) * 8) { // Fast-path. The dividend fits in 128-bit after scaling too. overflow = __builtin_mul_overflow( unsignedDividendRescaled, diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index a4960590f5ac..1ed313246eb8 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -22,6 +22,7 @@ add_executable( ComparisonsTest.cpp DateTimeFunctionsTest.cpp DecimalArithmeticTest.cpp + DecimalUtilTest.cpp ElementAtTest.cpp HashTest.cpp InTest.cpp diff --git a/velox/functions/sparksql/tests/DecimalUtilTest.cpp b/velox/functions/sparksql/tests/DecimalUtilTest.cpp new file mode 100644 index 000000000000..211bec68e663 --- /dev/null +++ b/velox/functions/sparksql/tests/DecimalUtilTest.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/DecimalUtil.h" +#include "velox/common/base/tests/GTestUtils.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class DecimalUtilTest : public testing::Test { + protected: + template + void testDivideWithRoundUp( + A a, + B b, + int32_t aRescale, + R expectedResult, + bool expectedOverflow) { + R r; + bool overflow; + DecimalUtil::divideWithRoundUp(r, a, b, aRescale, overflow); + ASSERT_EQ(overflow, expectedOverflow); + ASSERT_EQ(r, expectedResult); + } +}; +} // namespace + +TEST_F(DecimalUtilTest, divideWithRoundUp) { + testDivideWithRoundUp(60, 30, 3, 2000, false); + testDivideWithRoundUp( + 6, velox::DecimalUtil::kPowersOfTen[17], 20, 6000, false); +} +} // namespace facebook::velox::functions::sparksql::test \ No newline at end of file From 4a934c04e238df43de58b1b1138580f300aeb660 Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Tue, 29 Aug 2023 16:11:01 +0000 Subject: [PATCH 3/4] fix code style --- velox/functions/sparksql/DecimalUtil.h | 6 +++--- velox/functions/sparksql/tests/DecimalUtilTest.cpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index 9c41486ca210..d1627772e89f 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -97,9 +97,9 @@ class DecimalUtil { /// should treat the overflow flag first. Using HALF_UP rounding, the digit 5 /// is rounded up. /// Compute the maximum bits required to increase, if it is less than - /// result type bits, result type is enough, if not, we should introduce int256_t as - /// intermediate type, and then convert to real result type with overflow - /// flag. + /// result type bits, result type is enough, if not, we should introduce + /// int256_t as intermediate type, and then convert to real result type with + /// overflow flag. template inline static R divideWithRoundUp( R& r, diff --git a/velox/functions/sparksql/tests/DecimalUtilTest.cpp b/velox/functions/sparksql/tests/DecimalUtilTest.cpp index 211bec68e663..350bbf526583 100644 --- a/velox/functions/sparksql/tests/DecimalUtilTest.cpp +++ b/velox/functions/sparksql/tests/DecimalUtilTest.cpp @@ -43,4 +43,4 @@ TEST_F(DecimalUtilTest, divideWithRoundUp) { testDivideWithRoundUp( 6, velox::DecimalUtil::kPowersOfTen[17], 20, 6000, false); } -} // namespace facebook::velox::functions::sparksql::test \ No newline at end of file +} // namespace facebook::velox::functions::sparksql::test From fc5f92b646c25ecca24b8050c3fe0f5ea5d9db93 Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Fri, 8 Sep 2023 10:44:09 +0000 Subject: [PATCH 4/4] use new API --- .../functions/sparksql/DecimalArithmetic.cpp | 42 +++++++------------ velox/functions/sparksql/DecimalUtil.h | 19 +++++++++ 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index 1ee7133be1ed..ffb782da0694 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -23,18 +23,6 @@ namespace facebook::velox::functions::sparksql { namespace { -inline static std::pair adjustPrecisionScale( - const uint8_t rPrecision, - const uint8_t rScale) { - if (rPrecision <= 38) { - return {rPrecision, rScale}; - } else { - int32_t minScale = std::min(static_cast(rScale), 6); - int32_t delta = rPrecision - 38; - return {38, std::max(rScale - delta, minScale)}; - } -} - std::string getResultScale(std::string precision, std::string scale) { return fmt::format( "({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))", @@ -45,15 +33,6 @@ std::string getResultScale(std::string precision, std::string scale) { scale); } -template -inline bool isOverflow(R result, uint8_t rPrecision) { - if (result > velox::DecimalUtil::kPowersOfTen[rPrecision] || - result < -velox::DecimalUtil::kPowersOfTen[rPrecision]) { - return true; - } - return false; -} - template < typename R /* Result Type */, typename A /* Argument1 */, @@ -106,7 +85,9 @@ class DecimalBaseFunction : public exec::VectorFunction { rPrecision_, rScale_, overflow); - if (overflow || isOverflow(rawResults[row], rPrecision_)) { + if (overflow || + !velox::DecimalUtil::valueInPrecisionRange( + rawResults[row], rPrecision_)) { result->setNull(row, true); } }); @@ -130,7 +111,9 @@ class DecimalBaseFunction : public exec::VectorFunction { rPrecision_, rScale_, overflow); - if (overflow || isOverflow(rawResults[row], rPrecision_)) { + if (overflow || + !velox::DecimalUtil::valueInPrecisionRange( + rawResults[row], rPrecision_)) { result->setNull(row, true); } }); @@ -156,7 +139,9 @@ class DecimalBaseFunction : public exec::VectorFunction { rPrecision_, rScale_, overflow); - if (overflow || isOverflow(rawResults[row], rPrecision_)) { + if (overflow || + !velox::DecimalUtil::valueInPrecisionRange( + rawResults[row], rPrecision_)) { result->setNull(row, true); } }); @@ -180,7 +165,9 @@ class DecimalBaseFunction : public exec::VectorFunction { rPrecision_, rScale_, overflow); - if (overflow || isOverflow(rawResults[row], rPrecision_)) { + if (overflow || + !velox::DecimalUtil::valueInPrecisionRange( + rawResults[row], rPrecision_)) { result->setNull(row, true); } }); @@ -305,7 +292,8 @@ class Multiply { const uint8_t aScale, const uint8_t bPrecision, const uint8_t bScale) { - return adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale); + return DecimalUtil::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale); } private: @@ -356,7 +344,7 @@ class Divide { const uint8_t bScale) { auto scale = std::max(6, aScale + bPrecision + 1); auto precision = aPrecision - aScale + bScale + scale; - return adjustPrecisionScale(precision, scale); + return DecimalUtil::adjustPrecisionScale(precision, scale); } }; diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index d1627772e89f..b6aa538f3e05 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -27,6 +27,25 @@ using int256_t = boost::multiprecision::int256_t; // DecimalUtil holds the utility function for Spark sql. class DecimalUtil { public: + /// Scale adjustment implementation is based on Hive's one, which is itself + /// inspired to SQLServer's one. In particular, when a result precision is + /// greater than {LongDecimalType::kMaxPrecision}, the corresponding scale is + /// reduced to prevent the integral part of a result from being truncated. + /// + /// This method is used only when + /// `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. + inline static std::pair adjustPrecisionScale( + const uint8_t rPrecision, + const uint8_t rScale) { + if (rPrecision <= LongDecimalType::kMaxPrecision) { + return {rPrecision, rScale}; + } else { + int32_t minScale = std::min(static_cast(rScale), 6); + int32_t delta = rPrecision - 38; + return {38, std::max(rScale - delta, minScale)}; + } + } + /// @brief Convert int256 value to int64 or int128, set overflow to true if /// value cannot convert to specific type. /// @return The converted value.