From 7434f98d578c83b71598ba5812b99d17e476d510 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Tue, 20 Jun 2023 10:29:46 +0800 Subject: [PATCH] Spark sql sum agg function support decimal (#5372) --- .../sparksql/aggregates/DecimalSumAggregate.h | 388 ++++++++++++++++++ .../sparksql/aggregates/SumAggregate.cpp | 52 ++- .../aggregates/tests/SumAggregationTest.cpp | 306 ++++++++++++++ 3 files changed, 744 insertions(+), 2 deletions(-) create mode 100644 velox/functions/sparksql/aggregates/DecimalSumAggregate.h diff --git a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h new file mode 100644 index 000000000000..b9645f8d849f --- /dev/null +++ b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -0,0 +1,388 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "velox/exec/Aggregate.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::functions::aggregate::sparksql { + +struct DecimalSum { + int128_t sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; + + void mergeWith(const DecimalSum& other) { + this->overflow += other.overflow; + this->overflow += + DecimalUtil::addWithOverflow(this->sum, other.sum, this->sum); + this->isEmpty &= other.isEmpty; + } +}; + +template +class DecimalSumAggregate : public exec::Aggregate { + public: + explicit DecimalSumAggregate(TypePtr resultType, TypePtr sumType) + : exec::Aggregate(resultType), sumType_(sumType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(DecimalSum); + } + + int32_t accumulatorAlignmentSize() const override { + return alignof(DecimalSum); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) DecimalSum(); + } + } + + int128_t computeFinalValue(DecimalSum* decimalSum, bool& overflow) { + int128_t sum = decimalSum->sum; + if ((decimalSum->overflow == 1 && decimalSum->sum < 0) || + (decimalSum->overflow == -1 && decimalSum->sum > 0)) { + sum = static_cast( + DecimalUtil::kOverflowMultiplier * decimalSum->overflow + + decimalSum->sum); + } else { + if (decimalSum->overflow != 0) { + overflow = true; + return 0; + } + } + + auto [resultPrecision, resultScale] = + getDecimalPrecisionScale(*sumType_.get()); + overflow = !DecimalUtil::valueInPrecisionRange(sum, resultPrecision); + return sum; + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + VELOX_CHECK_EQ((*result)->encoding(), VectorEncoding::Simple::FLAT); + auto vector = (*result)->as>(); + VELOX_CHECK(vector); + vector->resize(numGroups); + uint64_t* rawNulls = getRawNulls(vector); + + TResultType* rawValues = vector->mutableRawValues(); + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + if (isNull(group)) { + vector->setNull(i, true); + } else { + clearNull(rawNulls, i); + auto* decimalSum = accumulator(group); + if (decimalSum->isEmpty) { + // If isEmpty is true, we should set null. + vector->setNull(i, true); + } else { + bool overflow = false; + auto result = (TResultType)computeFinalValue(decimalSum, overflow); + if (overflow) { + // Sum should be set to null on overflow. + vector->setNull(i, true); + } else { + rawValues[i] = result; + } + } + } + } + } + + void extractAccumulators( + char** groups, + int32_t numGroups, + facebook::velox::VectorPtr* result) override { + VELOX_CHECK_EQ((*result)->encoding(), VectorEncoding::Simple::ROW); + auto rowVector = (*result)->as(); + auto sumVector = rowVector->childAt(0)->asFlatVector(); + auto isEmptyVector = rowVector->childAt(1)->asFlatVector(); + + rowVector->resize(numGroups); + sumVector->resize(numGroups); + isEmptyVector->resize(numGroups); + + TResultType* rawSums = sumVector->mutableRawValues(); + // Bool uses compact representation, use mutableRawValues + // and bits::setBit instead. + auto* rawIsEmpty = isEmptyVector->mutableRawValues(); + uint64_t* rawNulls = getRawNulls(rowVector); + + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + clearNull(rawNulls, i); + if (isNull(group)) { + rawSums[i] = 0; + bits::setBit(rawIsEmpty, i, true); + } else { + auto* decimalSum = accumulator(group); + bool overflow = false; + auto result = (TResultType)computeFinalValue(decimalSum, overflow); + if (overflow) { + // Sum should be set to null on overflow, and + // isEmpty should be set to false. + sumVector->setNull(i, true); + bits::setBit(rawIsEmpty, i, false); + } else { + rawSums[i] = result; + bits::setBit(rawIsEmpty, i, decimalSum->isEmpty); + } + } + } + } + + void addRawInput( + char** groups, + const facebook::velox::SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedRaw_.decode(*args[0], rows); + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + auto value = decodedRaw_.valueAt(0); + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], value, false); + }); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedRaw_.isNullAt(i)) { + return; + } + updateNonNullValue( + groups[i], decodedRaw_.valueAt(i), false); + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + auto data = decodedRaw_.data(); + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], data[i], false); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue( + groups[i], decodedRaw_.valueAt(i), false); + }); + } + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedRaw_.decode(*args[0], rows); + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + auto value = decodedRaw_.valueAt(0); + rows.template applyToSelected( + [&](vector_size_t i) { updateNonNullValue(group, value, false); }); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (!decodedRaw_.isNullAt(i)) { + updateNonNullValue(group, decodedRaw_.valueAt(i), false); + } + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + auto data = decodedRaw_.data(); + DecimalSum decimalSum; + rows.applyToSelected([&](vector_size_t i) { + decimalSum.overflow += DecimalUtil::addWithOverflow( + decimalSum.sum, data[i], decimalSum.sum); + decimalSum.isEmpty = false; + }); + mergeAccumulators(group, decimalSum); + } else { + DecimalSum decimalSum; + rows.applyToSelected([&](vector_size_t i) { + decimalSum.overflow += DecimalUtil::addWithOverflow( + decimalSum.sum, decodedRaw_.valueAt(i), decimalSum.sum); + decimalSum.isEmpty = false; + }); + mergeAccumulators(group, decimalSum); + } + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + VELOX_CHECK_EQ( + decodedPartial_.base()->encoding(), VectorEncoding::Simple::ROW); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + auto sumVector = baseRowVector->childAt(0)->as>(); + auto isEmptyVector = baseRowVector->childAt(1)->as>(); + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { + rows.applyToSelected([&](vector_size_t i) { setNull(groups[i]); }); + } else { + auto sum = sumVector->valueAt(decodedIndex); + auto isEmpty = isEmptyVector->valueAt(decodedIndex); + rows.applyToSelected([&](vector_size_t i) { + clearNull(groups[i]); + updateNonNullValue(groups[i], sum, isEmpty); + }); + } + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + auto decodedIndex = decodedPartial_.index(i); + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { + setNull(groups[i]); + } else { + auto sum = sumVector->valueAt(decodedIndex); + auto isEmpty = isEmptyVector->valueAt(decodedIndex); + updateNonNullValue(groups[i], sum, isEmpty); + } + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + clearNull(groups[i]); + auto decodedIndex = decodedPartial_.index(i); + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { + setNull(groups[i]); + } else { + auto sum = sumVector->valueAt(decodedIndex); + auto isEmpty = isEmptyVector->valueAt(decodedIndex); + updateNonNullValue(groups[i], sum, isEmpty); + } + }); + } + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + VELOX_CHECK_EQ( + decodedPartial_.base()->encoding(), VectorEncoding::Simple::ROW); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + auto sumVector = baseRowVector->childAt(0)->as>(); + auto isEmptyVector = baseRowVector->childAt(1)->as>(); + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { + setNull(group); + } else { + auto sum = sumVector->valueAt(decodedIndex); + auto isEmpty = isEmptyVector->valueAt(decodedIndex); + if (rows.hasSelections()) { + clearNull(group); + } + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(group, sum, isEmpty); + }); + } + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + auto decodedIndex = decodedPartial_.index(i); + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { + setNull(group); + return; + } else { + clearNull(group); + auto sum = sumVector->valueAt(decodedIndex); + auto isEmpty = isEmptyVector->valueAt(decodedIndex); + updateNonNullValue(group, sum, isEmpty); + } + }); + } else { + if (rows.hasSelections()) { + clearNull(group); + } + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { + setNull(group); + return; + } else { + auto sum = sumVector->valueAt(decodedIndex); + auto isEmpty = isEmptyVector->valueAt(decodedIndex); + updateNonNullValue(group, sum, isEmpty); + } + }); + } + } + + private: + template + inline void updateNonNullValue(char* group, TResultType value, bool isEmpty) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + auto decimalSum = accumulator(group); + decimalSum->overflow += + DecimalUtil::addWithOverflow(decimalSum->sum, value, decimalSum->sum); + decimalSum->isEmpty &= isEmpty; + } + + template + inline void mergeAccumulators(char* group, DecimalSum other) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + auto decimalSum = accumulator(group); + decimalSum->mergeWith(other); + } + + inline DecimalSum* accumulator(char* group) { + return exec::Aggregate::value(group); + } + + inline bool isIntermediateResultOverflow( + const SimpleVector* isEmptyVector, + const SimpleVector* sumVector, + vector_size_t index) { + // If isEmpty is false and sum is null, it means this intermediate + // result has an overflow. The final accumulator of this group will + // be null. + return !isEmptyVector->valueAt(index) && sumVector->isNullAt(index); + } + + DecodedVector decodedRaw_; + DecodedVector decodedPartial_; + TypePtr sumType_; +}; + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/SumAggregate.cpp b/velox/functions/sparksql/aggregates/SumAggregate.cpp index 9c36ce39efbb..4a683e4c2f22 100644 --- a/velox/functions/sparksql/aggregates/SumAggregate.cpp +++ b/velox/functions/sparksql/aggregates/SumAggregate.cpp @@ -16,6 +16,7 @@ #include "velox/functions/sparksql/aggregates/SumAggregate.h" #include "velox/functions/lib/aggregates/SumAggregateBase.h" +#include "velox/functions/sparksql/aggregates/DecimalSumAggregate.h" using namespace facebook::velox::functions::aggregate; @@ -24,7 +25,19 @@ namespace facebook::velox::functions::aggregate::sparksql { namespace { template using SumAggregate = SumAggregateBase; + +TypePtr getDecimalSumType( + const TypePtr& resultType, + core::AggregationNode::Step step) { + if (exec::isPartialOutput(step)) { + return resultType->childAt(0); + } + if (step == core::AggregationNode::Step::kSingle && resultType->isRow()) { + return resultType->childAt(0); + } + return resultType; } +} // namespace exec::AggregateRegistrationResult registerSum( const std::string& name, @@ -41,6 +54,15 @@ exec::AggregateRegistrationResult registerSum( .intermediateType("double") .argumentType("double") .build(), + exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("r_precision", "min(38, a_precision + 10)") + .integerVariable("r_scale", "min(38, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .intermediateType("ROW(DECIMAL(r_precision, r_scale), boolean)") + .returnType("DECIMAL(r_precision, r_scale)") + .build(), }; for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) { @@ -74,13 +96,26 @@ exec::AggregateRegistrationResult registerSum( BIGINT()); case TypeKind::BIGINT: { if (inputType->isShortDecimal()) { - VELOX_NYI(); + auto sumType = getDecimalSumType(resultType, step); + if (sumType->isShortDecimal()) { + return std::make_unique>( + resultType, sumType); + } else if (sumType->isLongDecimal()) { + return std::make_unique>( + resultType, sumType); + } } return std::make_unique>( BIGINT()); } case TypeKind::HUGEINT: { - VELOX_NYI(); + if (inputType->isLongDecimal()) { + auto sumType = getDecimalSumType(resultType, step); + // If inputType is long decimal, + // its output type always be long decimal. + return std::make_unique>( + resultType, sumType); + } } case TypeKind::REAL: if (resultType->kind() == TypeKind::REAL) { @@ -96,6 +131,19 @@ exec::AggregateRegistrationResult registerSum( } return std::make_unique>( DOUBLE()); + case TypeKind::ROW: { + VELOX_DCHECK(!exec::isRawInput(step)); + auto sumType = getDecimalSumType(resultType, step); + // For intermediate input agg, input intermediate sum type + // is equal to final result sum type. + if (inputType->childAt(0)->isShortDecimal()) { + return std::make_unique>( + resultType, sumType); + } else if (inputType->childAt(0)->isLongDecimal()) { + return std::make_unique>( + resultType, sumType); + } + } default: VELOX_CHECK( false, diff --git a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 10a088c2db20..2f1f722889e2 100644 --- a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -14,9 +14,13 @@ * limitations under the License. */ +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/lib/aggregates/tests/SumTestBase.h" #include "velox/functions/sparksql/aggregates/Register.h" +using facebook::velox::exec::test::PlanBuilder; +using namespace facebook::velox::exec::test; using namespace facebook::velox::functions::aggregate::test; namespace facebook::velox::functions::aggregate::sparksql::test { @@ -28,6 +32,82 @@ class SumAggregationTest : public SumTestBase { SumTestBase::SetUp(); registerAggregateFunctions("spark_"); } + + protected: + // check global partial agg overflow, and final agg output null + void decimalGlobalSumOverflow( + const std::vector>& input, + const std::vector>& output) { + const TypePtr type = DECIMAL(38, 0); + auto in = makeRowVector({makeNullableFlatVector({input}, type)}); + auto expected = + makeRowVector({makeNullableFlatVector({output}, type)}); + testAggregations( + {in}, + {}, + {"spark_sum(c0)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + } + + // check group by partial agg overflow, and final agg output null + void decimalGroupBySumOverflow( + const std::vector>& input) { + const TypePtr type = DECIMAL(38, 0); + auto in = makeRowVector( + {makeFlatVector(20, [](auto row) { return row % 10; }), + makeNullableFlatVector(input, type)}); + auto expected = makeRowVector( + {makeFlatVector(10, [](auto row) { return row; }), + makeNullableFlatVector( + std::vector>(10, std::nullopt), type)}); + testAggregations( + {in}, + {"c0"}, + {"spark_sum(c1)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + } + + template + void decimalSumAllNulls( + const std::vector>& input, + const TypePtr& inputType, + const std::vector>& output, + const TypePtr& outputType) { + std::vector vectors; + FlatVectorPtr inputDecimalVector; + if constexpr (std::is_same_v) { + inputDecimalVector = makeNullableFlatVector(input, inputType); + } else { + inputDecimalVector = makeNullableFlatVector(input, inputType); + } + for (int i = 0; i < 5; ++i) { + vectors.emplace_back(makeRowVector( + {makeFlatVector(20, [](auto row) { return row % 4; }), + inputDecimalVector})); + } + + FlatVectorPtr outputDecimalVector; + if constexpr (std::is_same_v) { + outputDecimalVector = makeNullableFlatVector(output, outputType); + } else { + outputDecimalVector = + makeNullableFlatVector(output, outputType); + } + auto expected = makeRowVector( + {makeFlatVector(std::vector{0, 1, 2, 3}), + outputDecimalVector}); + testAggregations( + {vectors}, + {"c0"}, + {"spark_sum(c1)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + } }; TEST_F(SumAggregationTest, overflow) { @@ -38,5 +118,231 @@ TEST_F(SumAggregationTest, hookLimits) { testHookLimits(); } +TEST_F(SumAggregationTest, decimalSum) { + std::vector> shortDecimalRawVector; + std::vector> longDecimalRawVector; + for (int i = 0; i < 1000; ++i) { + shortDecimalRawVector.emplace_back(i * 1000); + longDecimalRawVector.emplace_back(HugeInt::build(i * 10, i * 100)); + } + shortDecimalRawVector.emplace_back(std::nullopt); + longDecimalRawVector.emplace_back(std::nullopt); + auto input = makeRowVector( + {makeNullableFlatVector(shortDecimalRawVector, DECIMAL(10, 1)), + makeNullableFlatVector(longDecimalRawVector, DECIMAL(23, 4))}); + createDuckDbTable({input}); + testAggregations( + {input}, + {}, + {"spark_sum(c0)", "spark_sum(c1)"}, + "SELECT sum(c0), sum(c1) FROM tmp", + /*config*/ {}, + /*testWithTableScan*/ false); + + // Short decimal sum aggregation with multiple groups. + auto inputShortDecimalRows = { + makeRowVector( + {makeNullableFlatVector({1, 1}), + makeFlatVector( + std::vector{37220, 53450}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 2}), + makeFlatVector( + std::vector{10410, 9250}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({3, 3}), + makeFlatVector( + std::vector{-12783, 0}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({1, 2}), + makeFlatVector( + std::vector{23178, 41093}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 3}), + makeFlatVector( + std::vector{-10023, 5290}, DECIMAL(5, 2))}), + }; + + auto expectedShortDecimalResult = { + makeRowVector( + {makeNullableFlatVector({1}), + makeFlatVector( + std::vector{113848}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({2}), + makeFlatVector( + std::vector{50730}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({3}), + makeFlatVector( + std::vector{-7493}, DECIMAL(15, 2))})}; + + testAggregations( + inputShortDecimalRows, + {"c0"}, + {"spark_sum(c1)"}, + expectedShortDecimalResult, + /*config*/ {}, + /*testWithTableScan*/ false); + + // Long decimal sum aggregation with multiple groups. + auto inputLongDecimalRows = { + makeRowVector( + {makeNullableFlatVector({1, 1}), + makeFlatVector( + {HugeInt::build(13, 113848), HugeInt::build(12, 53450)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 2}), + makeFlatVector( + {HugeInt::build(21, 10410), HugeInt::build(17, 9250)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({3, 3}), + makeFlatVector( + {HugeInt::build(25, 12783), HugeInt::build(19, 0)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({1, 2}), + makeFlatVector( + {HugeInt::build(31, 23178), HugeInt::build(82, 41093)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 3}), + makeFlatVector( + {HugeInt::build(25, 10023), HugeInt::build(43, 5290)}, + DECIMAL(20, 2))}), + }; + + auto expectedLongDecimalResult = { + makeRowVector( + {makeNullableFlatVector({1}), + makeFlatVector( + std::vector{HugeInt::build(56, 190476)}, + DECIMAL(38, 2))}), + makeRowVector( + {makeNullableFlatVector({2}), + makeFlatVector( + std::vector{HugeInt::build(145, 70776)}, + DECIMAL(38, 2))}), + makeRowVector( + {makeNullableFlatVector({3}), + makeFlatVector( + std::vector{HugeInt::build(87, 18073)}, + DECIMAL(38, 2))})}; + + testAggregations( + inputLongDecimalRows, + {"c0"}, + {"spark_sum(c1)"}, + expectedLongDecimalResult, + /*config*/ {}, + /*testWithTableScan*/ false); +} + +TEST_F(SumAggregationTest, decimalGlobalSumOverflow) { + // Test Positive Overflow. + std::vector> longDecimalInput; + std::vector> longDecimalOutput; + // Create input with 2 kLongDecimalMax. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + // The sum must overflow, and will return null + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + // Now add kLongDecimalMin. + // The sum now must not overflow. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMax); + decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); + + // Test Negative Overflow. + longDecimalInput.clear(); + longDecimalOutput.clear(); + + // Create input with 2 kLongDecimalMin. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + + // The sum must overflow, and will return null + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + // Now add kLongDecimalMax. + // The sum now must not overflow. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMin); + decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); + + // Check value in range. + longDecimalInput.clear(); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalInput.emplace_back(1); + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + longDecimalInput.clear(); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalInput.emplace_back(-1); + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); +} + +TEST_F(SumAggregationTest, decimalGroupBySumOverflow) { + // Test Positive Overflow. + decimalGroupBySumOverflow( + std::vector>(20, DecimalUtil::kLongDecimalMax)); + + // Test Negative Overflow. + decimalGroupBySumOverflow( + std::vector>(20, DecimalUtil::kLongDecimalMin)); + + // Check value in range. + auto decimalVector = + std::vector>(10, DecimalUtil::kLongDecimalMax); + auto oneValueVector = std::vector>(10, 1); + decimalVector.insert( + decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); + decimalGroupBySumOverflow(decimalVector); + + decimalVector = + std::vector>(10, DecimalUtil::kLongDecimalMin); + oneValueVector = std::vector>(10, -1); + decimalVector.insert( + decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); + decimalGroupBySumOverflow(decimalVector); +} + +// Test if all values in some groups are null, the final sum of this group +// should be null. +TEST_F(SumAggregationTest, decimalSomeGroupsAllnullValues) { + std::vector> shortDecimalNulls(20); + std::vector> longDecimalNulls(20); + for (int i = 0; i < 20; i++) { + if (i % 4 == 1 || i % 4 == 3) { + // not all groups are null + shortDecimalNulls[i] = 1; + longDecimalNulls[i] = 1; + } + } + + // Test short decimal inputs and the output sum is short decimal. + decimalSumAllNulls( + shortDecimalNulls, + DECIMAL(7, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(17, 2)); + + // Test short decimal inputs and the output sum is long decimal. + decimalSumAllNulls( + shortDecimalNulls, + DECIMAL(17, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(27, 2)); + + // Test long decimal inputs and the output sum is long decimal. + decimalSumAllNulls( + longDecimalNulls, + DECIMAL(25, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(35, 2)); +} } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test