From f5eca90f466bbeff9863fd371114e1082260384c Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Tue, 31 Oct 2023 10:43:16 +0800 Subject: [PATCH] introduce computeValidSum for spark sql reuse --- .../lib/aggregates/SumAggregateBase.h | 18 +++++------------- velox/type/DecimalUtil.h | 18 +++++++++++++++++- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/velox/functions/lib/aggregates/SumAggregateBase.h b/velox/functions/lib/aggregates/SumAggregateBase.h index a159ae7f1345a..18423e099ac39 100644 --- a/velox/functions/lib/aggregates/SumAggregateBase.h +++ b/velox/functions/lib/aggregates/SumAggregateBase.h @@ -185,19 +185,11 @@ class DecimalSumAggregate virtual int128_t computeFinalValue( functions::aggregate::LongDecimalWithOverflowState* accumulator) final { - // Value is valid if the conditions below are true. - int128_t sum = accumulator->sum; - if ((accumulator->overflow == 1 && accumulator->sum < 0) || - (accumulator->overflow == -1 && accumulator->sum > 0)) { - sum = static_cast( - DecimalUtil::kOverflowMultiplier * accumulator->overflow + - accumulator->sum); - } else { - VELOX_CHECK(accumulator->overflow == 0, "Decimal overflow"); - } - - DecimalUtil::valueInRange(sum); - return sum; + auto sum = + DecimalUtil::computeValidSum(accumulator->sum, accumulator->overflow); + VELOX_CHECK(sum.has_value(), "Decimal overflow"); + DecimalUtil::valueInRange(sum.value()); + return sum.value(); } }; diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index 88d12b63bc072..3562a448bff84 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -304,7 +304,23 @@ class DecimalUtil { } } - /// Origins from java side BigInteger#bitLength. + inline static std::optional computeValidSum( + int128_t sum, + int64_t overflow) { + // Value is valid if the conditions below are true. + int128_t validSum = sum; + if ((overflow == 1 && sum < 0) || (overflow == -1 && sum > 0)) { + validSum = static_cast( + DecimalUtil::kOverflowMultiplier * overflow + sum); + } else { + if (overflow != 0) { + return std::nullopt; + } + } + return validSum; + } + + / // Origins from java side BigInteger#bitLength. /// /// Returns the number of bits in the minimal two's-complement /// representation of this BigInteger, excluding a sign bit.