diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 64ee49fe127bd..392111e992ccd 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -256,14 +256,21 @@ void AggregateCompanionAdapter::ExtractFunction::apply( // Get the raw input types. std::vector rawInputTypes{args.size()}; - std::transform( - args.begin(), - args.end(), - rawInputTypes.begin(), - [](const VectorPtr& arg) { return arg->type(); }); + std::vector constantInputs{args.size()}; + for (auto i = 0; i < args.size(); i++) { + rawInputTypes[i] = args[i]->type(); + if (args[i]->isConstantEncoding()) { + constantInputs[i] = args[i]; + } else { + constantInputs[i] = nullptr; + } + } fn_->initialize( - core::AggregationNode::Step::kFinal, rawInputTypes, outputType, {}); + core::AggregationNode::Step::kFinal, + rawInputTypes, + outputType, + constantInputs); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index f24d4bd4af226..3ae60fcb15b4b 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -486,13 +486,13 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { // A testing aggregation function that uses the function state. class FunctionStateTestAggregate { public: - using InputType = Row; // Input vector type wrapped in Row. - using IntermediateType = int64_t; // Intermediate result type. - using OutputType = int64_t; // Output vector type. + using InputType = Row; // Input vector type wrapped in Row. + using IntermediateType = Row; // Intermediate result type. + using OutputType = double; // Output vector type. struct FunctionState { core::AggregationNode::Step step; - std::vector rawInputType; + std::vector rawInputTypes; TypePtr resultType; std::vector constantInputs; }; @@ -504,16 +504,14 @@ class FunctionStateTestAggregate { const TypePtr& resultType, const std::vector& constantInputs) { state.step = step; - state.rawInputType = rawInputTypes; + state.rawInputTypes = rawInputTypes; state.resultType = resultType; - if (resultType == nullptr) { - LOG(INFO) << "nullptr"; - } state.constantInputs = constantInputs; } struct Accumulator { int64_t sum{0}; + double count{0}; explicit Accumulator( HashStringAllocator* /*allocator*/, @@ -528,9 +526,11 @@ class FunctionStateTestAggregate { void addInput( HashStringAllocator* /*allocator*/, exec::arg_type data, + exec::arg_type increment, const FunctionState& state) { checkpoint(const_cast(&state)); sum += data; + count += increment; } void combine( @@ -538,14 +538,17 @@ class FunctionStateTestAggregate { exec::arg_type other, const FunctionState& state) { checkpoint(const_cast(&state)); - sum += other; + VELOX_CHECK(other.at<0>().has_value()); + VELOX_CHECK(other.at<1>().has_value()); + sum += other.at<0>().value(); + count += other.at<1>().value(); } bool writeIntermediateResult( exec::out_type& out, const FunctionState& state) { checkpoint(const_cast(&state)); - out = sum; + out = std::make_tuple(sum, count); return true; } @@ -553,7 +556,7 @@ class FunctionStateTestAggregate { exec::out_type& out, const FunctionState& state) { checkpoint(const_cast(&state)); - out = sum; + out = sum / count; return true; } }; @@ -565,9 +568,10 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( const std::string& name) { std::vector> signatures{ exec::AggregateFunctionSignatureBuilder() - .returnType("bigint") - .intermediateType("bigint") - .argumentType("bigint") + .returnType("DOUBLE") + .intermediateType("ROW(BIGINT, DOUBLE)") + .argumentType("BIGINT") + .argumentType("BIGINT") .build()}; return exec::registerAggregateFunction( @@ -579,8 +583,7 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( const TypePtr& resultType, const core::QueryConfig& /*config*/) -> std::unique_ptr { - VELOX_CHECK_LE( - argTypes.size(), 1, "{} takes at most one argument", name); + VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name); return std::make_unique< SimpleAggregateAdapter>(resultType); }, @@ -598,97 +601,197 @@ class SimpleFunctionStateAggregationTest : public AggregationTestBase { registerFunctionStateTestAggregate(); } + static void checkRowTypeEqual(TypePtr expected, TypePtr actual) { + VELOX_CHECK(expected->isRow()); + VELOX_CHECK(actual->isRow()); + VELOX_CHECK_EQ(expected->asRow().size(), actual->asRow().size()); + for (auto i = 0; i < expected->asRow().size(); i++) { + VELOX_CHECK_EQ(expected->asRow().childAt(i), actual->asRow().childAt(i)); + } + } + static void checkState( FunctionStateTestAggregate::FunctionState* state, - const std::string& step = "") { + const std::vector& rawInputTypes, + const TypePtr& intermediateType, + const TypePtr& resultType, + const std::vector& constantInputs) { + VELOX_CHECK(!state->rawInputTypes.empty()); VELOX_CHECK_NOT_NULL(state->resultType); - VELOX_CHECK(!state->rawInputType.empty()); - if (!step.empty()) { - VELOX_CHECK_EQ(core::AggregationNode::stepName(state->step), step); + + switch (state->step) { + case core::AggregationNode::Step::kPartial: + case core::AggregationNode::Step::kIntermediate: + if (state->rawInputTypes.size() == 1 && + state->rawInputTypes[0]->isRow()) { + // Merge or merge_extract companion function. + VELOX_CHECK_EQ(rawInputTypes.size(), 1); + VELOX_CHECK(rawInputTypes[0]->isRow()); + checkRowTypeEqual(rawInputTypes[0], state->rawInputTypes[0]); + } else { + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + state->rawInputTypes.begin(), + state->rawInputTypes.end())); + } + if (state->resultType->isRow()) { + checkRowTypeEqual(intermediateType, state->resultType); + } else { + VELOX_CHECK_EQ(resultType, state->resultType) + } + break; + + case core::AggregationNode::Step::kSingle: + case core::AggregationNode::Step::kFinal: + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + state->rawInputTypes.begin(), + state->rawInputTypes.end())); + VELOX_CHECK_EQ(resultType, state->resultType); + break; + + default: + VELOX_FAIL("Unknown aggregate step"); + break; + } + + VELOX_CHECK(!state->constantInputs.empty()); + if (state->step == core::AggregationNode::Step::kPartial || + state->step == core::AggregationNode::Step::kSingle) { + VELOX_CHECK_EQ(constantInputs.size(), state->constantInputs.size()); + for (auto i = 0; i < constantInputs.size(); i++) { + auto expected = constantInputs[i]; + auto actual = state->constantInputs[i]; + if (expected == nullptr && actual == nullptr) { + continue; + } else { + VELOX_CHECK(expected != nullptr && actual != nullptr); + VELOX_CHECK(expected->isConstantEncoding()); + VELOX_CHECK(actual->isConstantEncoding()); + VELOX_CHECK_EQ( + expected->asUnchecked>()->valueAt(0), + actual->asUnchecked>()->valueAt(0)); + } + } + } else { + VELOX_CHECK_EQ(state->constantInputs.size(), 1); + VELOX_CHECK_NULL(state->constantInputs[0]); } } }; TEST_F(SimpleFunctionStateAggregationTest, aggregate) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); - std::vector sum = {10}; - auto expected = makeRowVector({makeFlatVector(sum)}); + std::vector finalResult = {2.5}; + auto expected = makeRowVector({makeFlatVector(finalResult)}); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state); + checkState( + state, + {BIGINT(), BIGINT()}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 4)}); })); - testAggregations( - {inputVectors}, {}, {"simple_function_state_agg(c0)"}, {expected}); - testAggregationsWithCompanion( - {inputVectors}, - [](auto& /*builder*/) {}, - {}, - {"simple_function_state_agg(c0)"}, - {{BIGINT()}}, - {}, - {expected}, - {}); + {inputVectors}, {}, {"simple_function_state_agg(c0, 1)"}, {expected}); } TEST_F(SimpleFunctionStateAggregationTest, window) { auto inputVectors = makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); - auto expected = - makeRowVector({makeFlatVector({2, 2, 4, 4, 6, 6, 4})}); + auto expected = makeRowVector( + {makeFlatVector({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0})}); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "SINGLE"); + checkState( + state, + {BIGINT(), BIGINT()}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 7)}); })); auto plan = PlanBuilder() .values({inputVectors}) - .window({"simple_function_state_agg(c0) over (partition by c0)"}) + .window({"simple_function_state_agg(c0, 1) over (partition by c0)"}) .project({"w0"}) .planNode(); AssertQueryBuilder(plan).assertResults(expected); } -TEST_F(SimpleFunctionStateAggregationTest, aggregateStep) { +TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); - std::vector sum = {10}; - auto expected = makeRowVector({makeFlatVector(sum)}); + std::vector accSum = {10}; + std::vector accCount = {4.0}; + auto intermediateExpected = makeRowVector({ + makeRowVector({ + makeFlatVector(accSum), + makeFlatVector(accCount), + }), + }); + std::vector finalResult = {2.5}; + auto finalExpected = makeRowVector({makeFlatVector(finalResult)}); + SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "PARTIAL"); + checkState( + state, + {BIGINT(), BIGINT()}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 4)}); })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) - .singleAggregation({}, {"simple_function_state_agg_partial(c0)"}) + .singleAggregation({}, {"simple_function_state_agg_partial(c0, 1)"}) .planNode()) - .assertResults(expected); - + .assertResults(intermediateExpected); + + inputVectors = makeRowVector({ + makeRowVector({ + makeFlatVector({1, 2, 3, 4}), + makeFlatVector({1.0, 1.0, 1.0, 1.0}), + }), + }); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "INTERMEDIATE"); + checkState( + state, + {ROW({BIGINT(), DOUBLE()})}, + ROW({BIGINT(), DOUBLE()}), + ROW({BIGINT(), DOUBLE()}), + {nullptr, makeConstant(1, 4)}); })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) .singleAggregation({}, {"simple_function_state_agg_merge(c0)"}) .planNode()) - .assertResults(expected); + .assertResults(intermediateExpected); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "INTERMEDIATE"); + checkState( + state, + {ROW({BIGINT(), DOUBLE()})}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 4)}); })); AssertQueryBuilder( PlanBuilder() @@ -696,20 +799,7 @@ TEST_F(SimpleFunctionStateAggregationTest, aggregateStep) { .singleAggregation( {}, {"simple_function_state_agg_merge_extract(c0)"}) .planNode()) - .assertResults(expected); - - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "FINAL"); - })); - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .finalAggregation({}, {"simple_function_state_agg(c0)"}, {{BIGINT()}}) - .planNode()) - .assertResults(expected); + .assertResults(finalExpected); } } // namespace diff --git a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp index 5b3dd3382dc3b..2a0e0e82a60ce 100644 --- a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp +++ b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp @@ -1231,13 +1231,27 @@ void AggregationTestBase::testIncrementalAggregation( const auto& aggregateExpr = aggregate.call; const auto& functionName = aggregateExpr->name(); auto input = extractArgColumns(aggregateExpr, data, pool()); + const auto& inputType = aggregationNode.sources()[0]->outputType(); HashStringAllocator allocator(pool()); std::vector lambdas; + std::vector constantInputs; for (const auto& arg : aggregate.call->inputs()) { - if (auto lambda = + if (dynamic_cast(arg.get())) { + constantInputs.push_back(nullptr); + } else if ( + auto constant = + dynamic_cast(arg.get())) { + constantInputs.push_back(constant->toConstantVector(pool())); + } else if ( + auto lambda = std::dynamic_pointer_cast(arg)) { lambdas.push_back(lambda); + for (const auto& name : lambda->signature()->names()) { + if (auto captureIndex = inputType->getChildIdxIfExists(name)) { + constantInputs.push_back(nullptr); + } + } } } auto queryCtxConfig = config; @@ -1260,7 +1274,7 @@ void AggregationTestBase::testIncrementalAggregation( aggregationNode.step(), aggregate.rawInputTypes, func->resultType(), - {}); + constantInputs); func->initializeNewGroups(groups.data(), indices); func->addSingleGroupRawInput( group.data(), SelectivityVector(inputSize), input, false); @@ -1300,11 +1314,24 @@ VectorPtr AggregationTestBase::testStreaming( vector_size_t rawInput2Size, const std::unordered_map& config) { std::vector rawInputTypes(rawInput1.size()); + std::vector constantInputs1(rawInput1.size()); + std::vector constantInputs2(rawInput2.size()); + for (auto i = 0; i < rawInput1.size(); i++) { + rawInputTypes[i] = rawInput1[i]->type(); + if (rawInput1[i]->isConstantEncoding()) { + constantInputs1[i] = rawInput1[i]; + } else { + constantInputs1[i] = nullptr; + } + } + std::transform( - rawInput1.begin(), - rawInput1.end(), - rawInputTypes.begin(), - [](const VectorPtr& vec) { return vec->type(); }); + rawInput2.begin(), + rawInput2.end(), + constantInputs2.begin(), + [](const VectorPtr& vec) { + return vec->isConstantEncoding() ? vec : nullptr; + }); HashStringAllocator allocator(pool()); auto func = @@ -1317,7 +1344,7 @@ VectorPtr AggregationTestBase::testStreaming( core::AggregationNode::Step::kSingle, rawInputTypes, func->resultType(), - {}); + constantInputs1); func->initializeNewGroups(groups.data(), indices); if (testGlobal) { func->addSingleGroupRawInput( @@ -1341,7 +1368,7 @@ VectorPtr AggregationTestBase::testStreaming( core::AggregationNode::Step::kSingle, rawInputTypes, func2->resultType(), - {}); + constantInputs2); func2->initializeNewGroups(groups.data(), indices); if (testGlobal) { func2->addSingleGroupIntermediateResults( diff --git a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp index 85f69d1f7c3fb..5b5e63b9bea84 100644 --- a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp +++ b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp @@ -29,6 +29,8 @@ class RegrReplacementAggregate { /*m2*/ double>; using OutputType = double; + struct FunctionState {}; + static bool toIntermediate( exec::out_type>& out, exec::arg_type in) { @@ -41,11 +43,14 @@ class RegrReplacementAggregate { double avg{0.0}; double m2{0.0}; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} void addInput( HashStringAllocator* /*allocator*/, - exec::arg_type data) { + exec::arg_type data, + const FunctionState& /*state*/) { n += 1.0; double delta = data - avg; double deltaN = delta / n; @@ -55,7 +60,8 @@ class RegrReplacementAggregate { void combine( HashStringAllocator* /*allocator*/, - exec::arg_type> other) { + exec::arg_type> other, + const FunctionState& /*state*/) { VELOX_CHECK(other.at<0>().has_value()); VELOX_CHECK(other.at<1>().has_value()); VELOX_CHECK(other.at<2>().has_value()); @@ -72,12 +78,16 @@ class RegrReplacementAggregate { m2 += otherM2 + delta * deltaN * originN * otherN; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::make_tuple(n, avg, m2); return true; } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { if (n == 0.0) { return false; }