diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index 024a430578375..b527480ce6cab 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -165,7 +165,8 @@ A simple aggregation function is implemented as a class as the following. FunctionState& state, const std::vector& rawInputTypes, const TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional companionStep) { state.resultType = resultType; } diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index 7267bd748a408..a27f22e144ac6 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -139,7 +139,9 @@ class Aggregate { core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, - const std::vector& constantInputs) {} + const std::vector& constantInputs, + std::optional companionStep = std::nullopt) { + } // Initializes null flags and accumulators for newly encountered groups. This // function should be called only once for each group. diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 392111e992ccd..cd65af40f453e 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -125,15 +125,17 @@ void AggregateCompanionFunctionBase::extractAccumulators( } void AggregateCompanionAdapter::PartialFunction::initialize( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& rawInputType, const facebook::velox::TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional /*companionStep*/) { fn_->initialize( - core::AggregationNode::Step::kPartial, + step, rawInputType, resultType, - constantInputs); + constantInputs, + core::AggregationNode::Step::kPartial); } void AggregateCompanionAdapter::PartialFunction::extractValues( @@ -144,15 +146,17 @@ void AggregateCompanionAdapter::PartialFunction::extractValues( } void AggregateCompanionAdapter::MergeFunction::initialize( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& rawInputType, const facebook::velox::TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional /*companionStep*/) { fn_->initialize( - core::AggregationNode::Step::kIntermediate, + step, rawInputType, resultType, - constantInputs); + constantInputs, + core::AggregationNode::Step::kIntermediate); } void AggregateCompanionAdapter::MergeFunction::addRawInput( @@ -270,7 +274,8 @@ void AggregateCompanionAdapter::ExtractFunction::apply( core::AggregationNode::Step::kFinal, rawInputTypes, outputType, - constantInputs); + constantInputs, + core::AggregationNode::Step::kIntermediate); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 24b196e3f60ed..d2943df53f979 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -103,7 +103,8 @@ struct AggregateCompanionAdapter { core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, - const std::vector& constantInputs) override; + const std::vector& constantInputs, + std::optional companionStep) override; void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; @@ -120,7 +121,8 @@ struct AggregateCompanionAdapter { core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, - const std::vector& constantInputs) override; + const std::vector& constantInputs, + std::optional companionStep) override; void addRawInput( char** groups, diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index e236e32d1378b..0826016e36e83 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -53,9 +53,16 @@ class SimpleAggregateAdapter : public Aggregate { core::AggregationNode::Step step, const std::vector& rawInputTypes, const TypePtr& resultType, - const std::vector& constantInputs) override { + const std::vector& constantInputs, + std::optional companionStep) override { if constexpr (support_initialize_function_state_) { - FUNC::initialize(state_, step, rawInputTypes, resultType, constantInputs); + FUNC::initialize( + state_, + step, + rawInputTypes, + resultType, + constantInputs, + companionStep); } } diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index bf2a5bbf1e8a5..6f8b44451e69c 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -497,6 +497,7 @@ class FunctionStateTestAggregate { std::vector rawInputTypes; TypePtr resultType; std::vector constantInputs; + core::AggregationNode::Step companionStep; }; static void checkConstantInputs( @@ -517,19 +518,29 @@ class FunctionStateTestAggregate { core::AggregationNode::Step step, const std::vector& rawInputTypes, const TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional companionStep) { auto expectedRawInputTypes = {BIGINT(), BIGINT()}; auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()}); if constexpr (testCompanion) { - if (step == core::AggregationNode::Step::kPartial) { + VELOX_CHECK(companionStep.has_value()); + if (companionStep.value() == core::AggregationNode::Step::kPartial) { VELOX_CHECK(std::equal( rawInputTypes.begin(), rawInputTypes.end(), expectedRawInputTypes.begin(), expectedRawInputTypes.end())); - checkConstantInputs(constantInputs); - } else if (step == core::AggregationNode::Step::kIntermediate) { + if (step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle) { + // Only check constant inputs in partial and single step. + checkConstantInputs(constantInputs); + } else { + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } + } else if ( + companionStep.value() == core::AggregationNode::Step::kIntermediate) { VELOX_CHECK_EQ(rawInputTypes.size(), 1); VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType)); @@ -546,6 +557,7 @@ class FunctionStateTestAggregate { expectedRawInputTypes.end())); if (step == core::AggregationNode::Step::kPartial || step == core::AggregationNode::Step::kSingle) { + // Only check constant inputs in partial and single step. checkConstantInputs(constantInputs); } else { VELOX_CHECK_EQ(constantInputs.size(), 1); @@ -665,51 +677,15 @@ TEST_F(SimpleFunctionStateAggregationTest, aggregate) { {}, {"simple_function_state_agg_main(c0, 1)"}, {expected}); -} - -TEST_F(SimpleFunctionStateAggregationTest, companionAggregate) { - auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); - std::vector accSum = {10}; - std::vector accCount = {4.0}; - auto intermediateExpected = makeRowVector({ - makeRowVector({ - makeFlatVector(accSum), - makeFlatVector(accCount), - }), - }); - std::vector finalResult = {2.5}; - auto finalExpected = makeRowVector({makeFlatVector(finalResult)}); - - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .singleAggregation( - {}, {"simple_function_state_agg_companion_partial(c0, 1)"}) - .planNode()) - .assertResults(intermediateExpected); - - inputVectors = makeRowVector({ - makeRowVector({ - makeFlatVector({1, 2, 3, 4}), - makeFlatVector({1.0, 1.0, 1.0, 1.0}), - }), - }); - - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .singleAggregation( - {}, {"simple_function_state_agg_companion_merge(c0)"}) - .planNode()) - .assertResults(intermediateExpected); - - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .singleAggregation( - {}, {"simple_function_state_agg_companion_merge_extract(c0)"}) - .planNode()) - .assertResults(finalExpected); + testAggregationsWithCompanion( + {inputVectors}, + [](auto& /*builder*/) {}, + {}, + {"simple_function_state_agg_companion(c0, 1)"}, + {{BIGINT(), BIGINT()}}, + {}, + {expected}, + {}); } class SimpleFunctionStateWindowTest : public WindowTestBase {