diff --git a/velox/functions/lib/aggregates/AverageAggregateBase.cpp b/velox/functions/lib/aggregates/AverageAggregateBase.cpp
index efef798b6202..3353caed48b6 100644
--- a/velox/functions/lib/aggregates/AverageAggregateBase.cpp
+++ b/velox/functions/lib/aggregates/AverageAggregateBase.cpp
@@ -21,14 +21,16 @@ namespace facebook::velox::functions::aggregate {
 void checkAvgIntermediateType(const TypePtr& type) {
   VELOX_USER_CHECK(
       type->isRow() || type->isVarbinary(),
-      "Input type for final average must be row type or varbinary type.");
+      "Input type for final average must be row type or varbinary type, found {}",
+      type->toString());
   if (type->kind() == TypeKind::VARBINARY) {
     return;
   }
   VELOX_USER_CHECK(
       type->childAt(0)->kind() == TypeKind::DOUBLE ||
           type->childAt(0)->isLongDecimal(),
-      "Input type for sum in final average must be double or long decimal type.")
+      "Input type for sum in final average must be double or long decimal type, found {}",
+      type->childAt(0)->toString());
   VELOX_USER_CHECK_EQ(
       type->childAt(1)->kind(),
       TypeKind::BIGINT,
diff --git a/velox/functions/lib/aggregates/DecimalAggregate.h b/velox/functions/lib/aggregates/DecimalAggregate.h
index def5c81b52e0..6c0415d82cff 100644
--- a/velox/functions/lib/aggregates/DecimalAggregate.h
+++ b/velox/functions/lib/aggregates/DecimalAggregate.h
@@ -78,7 +78,7 @@ class DecimalAggregate : public exec::Aggregate {
   }
 
   int32_t accumulatorAlignmentSize() const override {
-    return static_cast<int32_t>(sizeof(int128_t));
+    return alignof(LongDecimalWithOverflowState);
   }
 
   void addRawInput(
@@ -275,7 +275,9 @@ class DecimalAggregate : public exec::Aggregate {
   }
 
   virtual TResultType computeFinalValue(
-      LongDecimalWithOverflowState* accumulator) = 0;
+      LongDecimalWithOverflowState* accumulator) {
+    return 0;
+  }
 
   void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
       override {
@@ -327,11 +329,11 @@ class DecimalAggregate : public exec::Aggregate {
     }
   }
 
- private:
   inline LongDecimalWithOverflowState* decimalAccumulator(char* group) {
     return exec::Aggregate::value<LongDecimalWithOverflowState>(group);
   }
 
+ private:
   DecodedVector decodedRaw_;
   DecodedVector decodedPartial_;
 };
diff --git a/velox/functions/sparksql/aggregates/AverageAggregate.cpp b/velox/functions/sparksql/aggregates/AverageAggregate.cpp
index 2149651cbbb3..1d03defe9312 100644
--- a/velox/functions/sparksql/aggregates/AverageAggregate.cpp
+++ b/velox/functions/sparksql/aggregates/AverageAggregate.cpp
@@ -16,6 +16,7 @@
 
 #include "velox/functions/sparksql/aggregates/AverageAggregate.h"
 #include "velox/functions/lib/aggregates/AverageAggregateBase.h"
+#include "velox/functions/sparksql/DecimalUtil.h"
 
 using namespace facebook::velox::functions::aggregate;
 
@@ -74,6 +75,307 @@ class AverageAggregate
   }
 };
 
+template <typename TInputType, typename TResultType>
+class DecimalAverageAggregate
+    : public DecimalAggregate<TInputType, TResultType> {
+ public:
+  explicit DecimalAverageAggregate(TypePtr resultType, TypePtr sumType)
+      : DecimalAggregate<TInputType, TResultType>(resultType),
+        sumType_(sumType) {}
+
+  void addIntermediateResults(
+      char** groups,
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args,
+      bool /* mayPushdown */) override {
+    decodedPartial_.decode(*args[0], rows);
+    auto baseRowVector =
+        dynamic_cast<const RowVector*>(decodedPartial_.base());
+    auto sumVector = baseRowVector->childAt(0)->as<SimpleVector<int128_t>>();
+    auto countVector = baseRowVector->childAt(1)->as<SimpleVector<int64_t>>();
+    VELOX_USER_CHECK_NOT_NULL(sumVector);
+
+    if (decodedPartial_.isConstantMapping()) {
+      if (!decodedPartial_.isNullAt(0)) {
+        auto decodedIndex = decodedPartial_.index(0);
+        auto count = countVector->valueAt(decodedIndex);
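+        // A null sum combined with a non-null, positive count is the marker
+        // a partial aggregate writes when its local sum overflowed.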
+        if (sumVector->isNullAt(decodedIndex) &&
+            !countVector->isNullAt(decodedIndex) && count > 0) {
+          // Overflow detected; set all groups to null.
+          rows.applyToSelected(
+              [&](vector_size_t i) { this->setNull(groups[i]); });
+        } else {
+          auto sum = sumVector->valueAt(decodedIndex);
+          rows.applyToSelected([&](vector_size_t i) {
+            this->clearNull(groups[i]);
+            auto accumulator = this->decimalAccumulator(groups[i]);
+            mergeSumCount(accumulator, sum, count);
+          });
+        }
+      }
+    } else if (decodedPartial_.mayHaveNulls()) {
+      rows.applyToSelected([&](vector_size_t i) {
+        if (decodedPartial_.isNullAt(i)) {
+          return;
+        }
+        auto decodedIndex = decodedPartial_.index(i);
+        auto count = countVector->valueAt(decodedIndex);
+        if (sumVector->isNullAt(decodedIndex) &&
+            !countVector->isNullAt(decodedIndex) && count > 0) {
+          this->setNull(groups[i]);
+        } else {
+          this->clearNull(groups[i]);
+          auto sum = sumVector->valueAt(decodedIndex);
+          auto accumulator = this->decimalAccumulator(groups[i]);
+          mergeSumCount(accumulator, sum, count);
+        }
+      });
+    } else {
+      rows.applyToSelected([&](vector_size_t i) {
+        auto decodedIndex = decodedPartial_.index(i);
+        auto count = countVector->valueAt(decodedIndex);
+        if (sumVector->isNullAt(decodedIndex) &&
+            !countVector->isNullAt(decodedIndex) && count > 0) {
+          this->setNull(groups[i]);
+        } else {
+          this->clearNull(groups[i]);
+          auto sum = sumVector->valueAt(decodedIndex);
+          auto accumulator = this->decimalAccumulator(groups[i]);
+          mergeSumCount(accumulator, sum, count);
+        }
+      });
+    }
+  }
+
+  void addSingleGroupIntermediateResults(
+      char* group,
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args,
+      bool /* mayPushdown */) override {
+    decodedPartial_.decode(*args[0], rows);
+    auto baseRowVector =
+        dynamic_cast<const RowVector*>(decodedPartial_.base());
+    auto sumVector = baseRowVector->childAt(0)->as<SimpleVector<int128_t>>();
+    auto countVector = baseRowVector->childAt(1)->as<SimpleVector<int64_t>>();
+
+    if (decodedPartial_.isConstantMapping()) {
+      if (!decodedPartial_.isNullAt(0)) {
+        auto decodedIndex = decodedPartial_.index(0);
+        if (isPartialSumOverflow(sumVector, countVector, decodedIndex)) {
+          // Overflow detected; set the group to null and return.
+          this->setNull(group);
+          return;
+        } else {
+          if (rows.hasSelections()) {
+            this->clearNull(group);
+          }
+          auto sum = sumVector->valueAt(decodedIndex);
+          auto count = countVector->valueAt(decodedIndex);
+          rows.applyToSelected(
+              [&](vector_size_t i) { mergeAccumulators(group, sum, count); });
+        }
+      }
+    } else if (decodedPartial_.mayHaveNulls()) {
+      rows.applyToSelected([&](vector_size_t i) {
+        if (!decodedPartial_.isNullAt(i)) {
+          this->clearNull(group);
+          auto decodedIndex = decodedPartial_.index(i);
+          if (isPartialSumOverflow(sumVector, countVector, decodedIndex)) {
+            // Overflow detected; set the group to null.
+            this->setNull(group);
+          } else {
+            auto sum = sumVector->valueAt(decodedIndex);
+            auto count = countVector->valueAt(decodedIndex);
+            mergeAccumulators(group, sum, count);
+          }
+        }
+      });
+    } else {
+      if (rows.hasSelections()) {
+        this->clearNull(group);
+      }
+      rows.applyToSelected([&](vector_size_t i) {
+        auto decodedIndex = decodedPartial_.index(i);
+        if (isPartialSumOverflow(sumVector, countVector, decodedIndex)) {
+          // Overflow detected; set the group to null.
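+          // Even though the group was cleared above, an overflowed partial
+          // sum forces the whole group back to null.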
+          this->setNull(group);
+        } else {
+          auto sum = sumVector->valueAt(decodedIndex);
+          auto count = countVector->valueAt(decodedIndex);
+          mergeAccumulators(group, sum, count);
+        }
+      });
+    }
+  }
+
+  void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
+      override {
+    auto rowVector = (*result)->as<RowVector>();
+    auto sumVector = rowVector->childAt(0)->asFlatVector<int128_t>();
+    auto countVector = rowVector->childAt(1)->asFlatVector<int64_t>();
+    VELOX_USER_CHECK_NOT_NULL(sumVector);
+
+    rowVector->resize(numGroups);
+    sumVector->resize(numGroups);
+    countVector->resize(numGroups);
+    rowVector->clearAllNulls();
+
+    int64_t* rawCounts = countVector->mutableRawValues();
+    int128_t* rawSums = sumVector->mutableRawValues();
+    for (auto i = 0; i < numGroups; ++i) {
+      char* group = groups[i];
+      auto* accumulator = this->decimalAccumulator(group);
+      std::optional<int128_t> adjustedSum = DecimalUtil::adjustSumForOverflow(
+          accumulator->sum, accumulator->overflow);
+      if (adjustedSum.has_value()) {
+        rawCounts[i] = accumulator->count;
+        rawSums[i] = adjustedSum.value();
+      } else {
+        // Overflow detected.
+        sumVector->setNull(i, true);
+        rawCounts[i] = accumulator->count;
+      }
+    }
+  }
+
+  void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
+      override {
+    auto vector = (*result)->as<FlatVector<TResultType>>();
+    VELOX_CHECK(vector);
+    vector->resize(numGroups);
+    uint64_t* rawNulls = this->getRawNulls(vector);
+
+    TResultType* rawValues = vector->mutableRawValues();
+    for (int32_t i = 0; i < numGroups; ++i) {
+      char* group = groups[i];
+      auto accumulator = this->decimalAccumulator(group);
+      if (accumulator->count == 0) {
+        // In Spark, if all inputs are null, count will be 0,
+        // and the result of final avg will be null.
+        vector->setNull(i, true);
+      } else {
+        this->clearNull(rawNulls, i);
+        std::optional<TResultType> avg = computeAvg(accumulator);
+        if (avg.has_value()) {
+          rawValues[i] = avg.value();
+        } else {
+          // Overflow detected.
+          vector->setNull(i, true);
+        }
+      }
+    }
+  }
+
+  std::optional<TResultType> computeAvg(
+      LongDecimalWithOverflowState* accumulator) {
+    std::optional<int128_t> validSum = DecimalUtil::adjustSumForOverflow(
+        accumulator->sum, accumulator->overflow);
+    if (!validSum.has_value()) {
+      return std::nullopt;
+    }
+
+    auto [resultPrecision, resultScale] =
+        getDecimalPrecisionScale(*this->resultType());
+    // Spark uses DECIMAL(20, 0) to represent the count's long value.
+    const uint8_t countPrecision = 20, countScale = 0;
+    auto [sumPrecision, sumScale] = getDecimalPrecisionScale(*sumType_);
+    auto [avgPrecision, avgScale] = computeResultPrecisionScale(
+        sumPrecision, sumScale, countPrecision, countScale);
+    auto sumRescale = computeRescaleFactor(sumScale, countScale, avgScale);
+    auto countDecimal = accumulator->count;
+    int128_t avg = 0;
+
+    bool overflow = false;
+    functions::sparksql::DecimalUtil::divideWithRoundUp<
+        int128_t,
+        int128_t,
+        int64_t>(avg, validSum.value(), countDecimal, sumRescale, overflow);
+    if (overflow) {
+      return std::nullopt;
+    }
+    TResultType rescaledValue;
+    const auto status = DecimalUtil::rescaleWithRoundUp(
+        avg,
+        avgPrecision,
+        avgScale,
+        resultPrecision,
+        resultScale,
+        rescaledValue);
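+    // rescaleWithRoundUp() reports failure through its Status rather than
+    // throwing, so overflow at the final rescale also surfaces as null.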
+    return status.ok() ? std::optional<TResultType>(rescaledValue)
+                       : std::nullopt;
+  }
+
+ private:
+  template <typename UnscaledType>
+  inline void mergeSumCount(
+      LongDecimalWithOverflowState* accumulator,
+      UnscaledType sum,
+      int64_t count) {
+    accumulator->count += count;
+    accumulator->overflow +=
+        DecimalUtil::addWithOverflow(accumulator->sum, sum, accumulator->sum);
+  }
+
+  template <bool tableHasNulls = true, typename UnscaledType>
+  void mergeAccumulators(
+      char* group,
+      const UnscaledType& otherSum,
+      const int64_t& otherCount) {
+    if constexpr (tableHasNulls) {
+      exec::Aggregate::clearNull(group);
+    }
+    auto accumulator = this->decimalAccumulator(group);
+    mergeSumCount(accumulator, otherSum, otherCount);
+  }
+
+  inline static bool isPartialSumOverflow(
+      SimpleVector<int128_t>* sumVector,
+      SimpleVector<int64_t>* countVector,
+      int32_t index) {
+    return sumVector->isNullAt(index) && !countVector->isNullAt(index) &&
+        countVector->valueAt(index) > 0;
+  }
+
+  inline static uint8_t
+  computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale) {
+    return rScale - fromScale + toScale;
+  }
+
+  inline static std::pair<uint8_t, uint8_t> computeResultPrecisionScale(
+      const uint8_t aPrecision,
+      const uint8_t aScale,
+      const uint8_t bPrecision,
+      const uint8_t bScale) {
+    uint8_t intDig = aPrecision - aScale + bScale;
+    uint8_t scale = std::max(6, aScale + bPrecision + 1);
+    uint8_t precision = intDig + scale;
+    return functions::sparksql::DecimalUtil::adjustPrecisionScale(
+        precision, scale);
+  }
+
+  inline static std::pair<uint8_t, uint8_t> adjustPrecisionScale(
+      const uint8_t precision,
+      const uint8_t scale) {
+    VELOX_CHECK_GE(precision, scale);
+    if (precision <= 38) {
+      return {precision, scale};
+    } else {
+      uint8_t intDigits = precision - scale;
+      uint8_t minScaleValue = std::min(scale, (uint8_t)6);
+      uint8_t adjustedScale =
+          std::max((uint8_t)(38 - intDigits), minScaleValue);
+      return {38, adjustedScale};
+    }
+  }
+
+  DecodedVector decodedRaw_;
+  DecodedVector decodedPartial_;
+  TypePtr sumType_;
+};
+
+TypePtr getDecimalSumType(
+    const uint8_t rawInputPrecision,
+    const uint8_t rawInputScale) {
+  // This computation follows Spark SQL's definition of the average
+  // intermediate sum type: the precision grows by 10, capped at 38.
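+  // For example, a DECIMAL(10, 1) input accumulates into DECIMAL(20, 1),
+  // while a DECIMAL(35, 4) input saturates at DECIMAL(38, 4).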
+  return DECIMAL(std::min(38, rawInputPrecision + 10), rawInputScale);
+}
+
 } // namespace
 
 /// Count is BIGINT() while sum and the final aggregates type depends on
@@ -99,13 +401,25 @@ exec::AggregateRegistrationResult registerAverage(
           .build());
   }
 
-  signatures.push_back(exec::AggregateFunctionSignatureBuilder()
-                           .integerVariable("a_precision")
-                           .integerVariable("a_scale")
-                           .argumentType("DECIMAL(a_precision, a_scale)")
-                           .intermediateType("varbinary")
-                           .returnType("DECIMAL(a_precision, a_scale)")
-                           .build());
+  signatures.push_back(
+      exec::AggregateFunctionSignatureBuilder()
+          .integerVariable("a_precision")
+          .integerVariable("a_scale")
+          .integerVariable("r_precision", "min(38, a_precision + 4)")
+          .integerVariable("r_scale", "min(38, a_scale + 4)")
+          .argumentType("DECIMAL(a_precision, a_scale)")
+          .intermediateType("ROW(DECIMAL(38, a_scale), BIGINT)")
+          .returnType("DECIMAL(r_precision, r_scale)")
+          .build());
+
+  signatures.push_back(
+      exec::AggregateFunctionSignatureBuilder()
+          .integerVariable("a_precision")
+          .integerVariable("a_scale")
+          .argumentType("DECIMAL(a_precision, a_scale)")
+          .intermediateType("ROW(DECIMAL(a_precision, a_scale), BIGINT)")
+          .returnType("DECIMAL(a_precision, a_scale)")
+          .build());
 
   return exec::registerAggregateFunction(
       name,
@@ -118,7 +432,7 @@ exec::AggregateRegistrationResult registerAverage(
           -> std::unique_ptr<exec::Aggregate> {
         VELOX_CHECK_LE(
             argTypes.size(), 1, "{} takes at most one argument", name);
-        auto inputType = argTypes[0];
+        const auto& inputType = argTypes[0];
         if (exec::isRawInput(step)) {
           switch (inputType->kind()) {
             case TypeKind::SMALLINT:
@@ -129,16 +443,39 @@ exec::AggregateRegistrationResult registerAverage(
                 AverageAggregate<int32_t, double, double>>(resultType);
             case TypeKind::BIGINT: {
               if (inputType->isShortDecimal()) {
-                return std::make_unique<DecimalAverageAggregate<int64_t>>(
-                    resultType);
+                auto inputPrecision = inputType->asShortDecimal().precision();
+                auto inputScale = inputType->asShortDecimal().scale();
+                auto sumType = getDecimalSumType(inputPrecision, inputScale);
+                if (exec::isPartialOutput(step)) {
+                  return std::make_unique<
+                      DecimalAverageAggregate<int64_t, int64_t>>(
+                      resultType, sumType);
+                } else {
+                  if (resultType->isShortDecimal()) {
+                    return std::make_unique<
+                        DecimalAverageAggregate<int64_t, int64_t>>(
+                        resultType, sumType);
+                  } else if (resultType->isLongDecimal()) {
+                    return std::make_unique<
+                        DecimalAverageAggregate<int64_t, int128_t>>(
+                        resultType, sumType);
+                  } else {
+                    VELOX_FAIL("Result type must be decimal");
+                  }
+                }
+              }
               return std::make_unique<
                   AverageAggregate<int64_t, double, double>>(resultType);
             }
             case TypeKind::HUGEINT: {
               if (inputType->isLongDecimal()) {
-                return std::make_unique<DecimalAverageAggregate<int128_t>>(
-                    resultType);
+                auto inputPrecision = inputType->asLongDecimal().precision();
+                auto inputScale = inputType->asLongDecimal().scale();
+                auto sumType = getDecimalSumType(inputPrecision, inputScale);
+                return std::make_unique<
+                    DecimalAverageAggregate<int128_t, int128_t>>(
+                    resultType, sumType);
               }
               VELOX_NYI();
             }
@@ -162,26 +499,37 @@ exec::AggregateRegistrationResult registerAverage(
                   resultType);
             case TypeKind::DOUBLE:
             case TypeKind::ROW:
+              if (inputType->childAt(0)->isLongDecimal()) {
+                return std::make_unique<
+                    DecimalAverageAggregate<int128_t, int128_t>>(
+                    resultType, inputType->childAt(0));
+              }
               return std::make_unique<
                   AverageAggregate<int64_t, double, double>>(resultType);
             case TypeKind::BIGINT:
-              return std::make_unique<DecimalAverageAggregate<int64_t>>(
-                  resultType);
+              VELOX_USER_CHECK(resultType->isShortDecimal());
+              return std::make_unique<
+                  DecimalAverageAggregate<int64_t, int64_t>>(
+                  resultType, inputType->childAt(0));
             case TypeKind::HUGEINT:
-              return std::make_unique<DecimalAverageAggregate<int128_t>>(
-                  resultType);
+              VELOX_USER_CHECK(resultType->isLongDecimal());
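+              // A long decimal result (precision > 18) is extracted into an
+              // int128_t flat vector, hence the int128_t result template
+              // argument below.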
+              return std::make_unique<
+                  DecimalAverageAggregate<int128_t, int128_t>>(
+                  resultType, inputType->childAt(0));
             case TypeKind::VARBINARY:
               if (inputType->isLongDecimal()) {
-                return std::make_unique<DecimalAverageAggregate<int128_t>>(
-                    resultType);
+                return std::make_unique<
+                    DecimalAverageAggregate<int128_t, int128_t>>(
+                    resultType, inputType->childAt(0));
               } else if (
                   inputType->isShortDecimal() ||
                   inputType->kind() == TypeKind::VARBINARY) {
                 // If the input and output types are VARBINARY, then the
                 // LongDecimalWithOverflowState is used and the template type
                 // does not matter.
-                return std::make_unique<DecimalAverageAggregate<int64_t>>(
-                    resultType);
+                return std::make_unique<
+                    DecimalAverageAggregate<int64_t, int64_t>>(
+                    resultType, inputType->childAt(0));
               }
               [[fallthrough]];
             default:
diff --git a/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp
index 93057ef155a5..7ae0b53abe73 100644
--- a/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp
+++ b/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp
@@ -110,5 +110,132 @@ TEST_F(AverageAggregationTest, avgAllNulls) {
   assertQuery(plan, expected);
 }
 
+TEST_F(AverageAggregationTest, avgDecimal) {
+  int64_t kRescale = DecimalUtil::kPowersOfTen[4];
+  // Short decimal aggregation.
+  auto shortDecimal = makeNullableFlatVector<int64_t>(
+      {1'000, 2'000, 3'000, 4'000, 5'000, std::nullopt}, DECIMAL(10, 1));
+  testAggregations(
+      {makeRowVector({shortDecimal})},
+      {},
+      {"spark_avg(c0)"},
+      {},
+      {makeRowVector({makeNullableFlatVector<int64_t>(
+          {3'000 * kRescale}, DECIMAL(14, 5))})});
+
+  // Long decimal aggregation.
+  testAggregations(
+      {makeRowVector({makeNullableFlatVector<int128_t>(
+          {HugeInt::build(10, 100),
+           HugeInt::build(10, 200),
+           HugeInt::build(10, 300),
+           HugeInt::build(10, 400),
+           HugeInt::build(10, 500),
+           std::nullopt},
+          DECIMAL(23, 4))})},
+      {},
+      {"spark_avg(c0)"},
+      {},
+      {makeRowVector({makeFlatVector<int128_t>(
+          std::vector<int128_t>{HugeInt::build(10, 300) * kRescale},
+          DECIMAL(27, 8))})});
+
+  // The total sum overflows the max int128_t limit.
+  std::vector<int128_t> rawVector;
+  for (int i = 0; i < 10; ++i) {
+    rawVector.push_back(DecimalUtil::kLongDecimalMax);
+  }
+  testAggregations(
+      {makeRowVector({makeFlatVector<int128_t>(rawVector, DECIMAL(38, 0))})},
+      {},
+      {"spark_avg(c0)"},
+      {},
+      {makeRowVector({makeNullableFlatVector<int128_t>(
+          std::vector<std::optional<int128_t>>{std::nullopt},
+          DECIMAL(38, 4))})});
+
+  // The total sum underflows the min int128_t limit.
+  rawVector.clear();
+  auto underFlowTestResult = makeNullableFlatVector<int128_t>(
+      std::vector<std::optional<int128_t>>{std::nullopt}, DECIMAL(38, 4));
+  for (int i = 0; i < 10; ++i) {
+    rawVector.push_back(DecimalUtil::kLongDecimalMin);
+  }
+  testAggregations(
+      {makeRowVector({makeFlatVector<int128_t>(rawVector, DECIMAL(38, 0))})},
+      {},
+      {"spark_avg(c0)"},
+      {},
+      {makeRowVector({underFlowTestResult})});
+
+  // Test constant vector.
+  testAggregations(
+      {makeRowVector({makeConstant<int64_t>(100, 10, DECIMAL(10, 2))})},
+      {},
+      {"spark_avg(c0)"},
+      {},
+      {makeRowVector({makeFlatVector<int64_t>(
+          std::vector<int64_t>{100 * kRescale}, DECIMAL(14, 6))})});
+
+  // Test dictionary-encoded input.
+  auto newSize = shortDecimal->size() * 2;
+  auto indices = makeIndices(newSize, [&](int row) { return row / 2; });
+  auto dictVector =
+      VectorTestBase::wrapInDictionary(indices, newSize, shortDecimal);
+
+  testAggregations(
+      {makeRowVector({dictVector})},
+      {},
+      {"spark_avg(c0)"},
+      {},
+      {makeRowVector({makeFlatVector<int64_t>(
+          std::vector<int64_t>{3'000 * kRescale}, DECIMAL(14, 5))})});
+
+  // Decimal average aggregation with multiple groups.
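+  // For example, group 1 holds 372.20, 534.50, and 231.78, so its average
+  // is 1138.48 / 3 = 379.493333 at the result scale of 6.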
+  auto inputRows = {
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({1, 1}),
+           makeFlatVector<int64_t>({37220, 53450}, DECIMAL(15, 2))}),
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({2, 2}),
+           makeFlatVector<int64_t>({10410, 9250}, DECIMAL(15, 2))}),
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({3, 3}),
+           makeFlatVector<int64_t>({-12783, 0}, DECIMAL(15, 2))}),
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({1, 2}),
+           makeFlatVector<int64_t>({23178, 41093}, DECIMAL(15, 2))}),
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({2, 3}),
+           makeFlatVector<int64_t>({-10023, 5290}, DECIMAL(15, 2))}),
+  };
+
+  auto expectedResult = {
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({1}),
+           makeFlatVector<int128_t>(
+               std::vector<int128_t>{379493333}, DECIMAL(19, 6))}),
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({2}),
+           makeFlatVector<int128_t>(
+               std::vector<int128_t>{126825000}, DECIMAL(19, 6))}),
+      makeRowVector(
+          {makeNullableFlatVector<int32_t>({3}),
+           makeFlatVector<int128_t>(
+               std::vector<int128_t>{-24976667}, DECIMAL(19, 6))})};
+
+  testAggregations(inputRows, {"c0"}, {"spark_avg(c1)"}, expectedResult);
+}
+
+TEST_F(AverageAggregationTest, avgDecimalWithMultipleRowVectors) {
+  int64_t kRescale = DecimalUtil::kPowersOfTen[4];
+  auto inputRows = {
+      makeRowVector({makeFlatVector<int64_t>({100, 200}, DECIMAL(15, 2))}),
+      makeRowVector({makeFlatVector<int64_t>({300, 400}, DECIMAL(15, 2))}),
+      makeRowVector({makeFlatVector<int64_t>({500, 600}, DECIMAL(15, 2))}),
+  };
+
+  auto expectedResult = {makeRowVector({makeFlatVector<int128_t>(
+      std::vector<int128_t>{350 * kRescale}, DECIMAL(19, 6))})};
+
+  testAggregations(inputRows, {}, {"spark_avg(c0)"}, expectedResult);
+}
+
 } // namespace
 } // namespace facebook::velox::functions::aggregate::sparksql::test