From f8ff4b04ca9b619cf9879a1aadeb3f7c4b513fb1 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Sat, 20 Apr 2024 16:14:24 +0800 Subject: [PATCH] Address comments --- velox/exec/tests/CMakeLists.txt | 9 +- .../exec/tests/SimpleAggregateAdapterTest.cpp | 276 +++++++----------- .../tests/utils/AggregationTestBase.cpp | 52 +--- 3 files changed, 117 insertions(+), 220 deletions(-) diff --git a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index ddfb25743d23..3cc4f1aee144 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -244,8 +244,13 @@ add_executable(velox_simple_aggregate_test SimpleAggregateAdapterTest.cpp Main.cpp) target_link_libraries( - velox_simple_aggregate_test velox_simple_aggregate velox_exec - velox_functions_aggregates_test_lib gtest gtest_main) + velox_simple_aggregate_test + velox_simple_aggregate + velox_exec + velox_functions_aggregates_test_lib + velox_functions_window_test_lib + gtest + gtest_main) add_library(velox_spiller_join_benchmark_base JoinSpillInputBenchmarkBase.cpp SpillerBenchmarkBase.cpp) diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 3ae60fcb15b4..bf2a5bbf1e8a 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -18,11 +18,13 @@ #include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" +#include "velox/functions/lib/window/tests/WindowTestBase.h" using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; using facebook::velox::common::testutil::TestValue; using facebook::velox::functions::aggregate::test::AggregationTestBase; +using facebook::velox::window::test::WindowTestBase; namespace facebook::velox::aggregate::test { namespace { @@ -30,7 +32,6 @@ namespace { const char* const kSimpleAvg = "simple_avg"; const char* const kSimpleArrayAgg = "simple_array_agg"; const char* const kSimpleCountNulls = "simple_count_nulls"; -const char* const kSimpleFunctionStateAgg = "simple_function_state_agg"; class SimpleAverageAggregationTest : public AggregationTestBase { protected: @@ -484,6 +485,7 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { } // A testing aggregation function that uses the function state. +template class FunctionStateTestAggregate { public: using InputType = Row; // Input vector type wrapped in Row. @@ -497,12 +499,60 @@ class FunctionStateTestAggregate { std::vector constantInputs; }; + static void checkConstantInputs( + const std::vector& constantInputs) { + // Check that the constantInputs is {nullptr, 1} + VELOX_CHECK_EQ(constantInputs.size(), 2); + VELOX_CHECK_NULL(constantInputs[0]); + VELOX_CHECK(constantInputs[1]->isConstantEncoding()); + VELOX_CHECK_EQ( + constantInputs[1] + ->template asUnchecked>() + ->valueAt(0), + 1); + } + static void initialize( FunctionState& state, core::AggregationNode::Step step, const std::vector& rawInputTypes, const TypePtr& resultType, const std::vector& constantInputs) { + auto expectedRawInputTypes = {BIGINT(), BIGINT()}; + auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()}); + + if constexpr (testCompanion) { + if (step == core::AggregationNode::Step::kPartial) { + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + expectedRawInputTypes.begin(), + expectedRawInputTypes.end())); + checkConstantInputs(constantInputs); + } else if (step == core::AggregationNode::Step::kIntermediate) { + VELOX_CHECK_EQ(rawInputTypes.size(), 1); + VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType)); + + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } else { + VELOX_FAIL("Unexpected aggregation step"); + } + } else { + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + expectedRawInputTypes.begin(), + expectedRawInputTypes.end())); + if (step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle) { + checkConstantInputs(constantInputs); + } else { + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } + } + state.step = step; state.rawInputTypes = rawInputTypes; state.resultType = resultType; @@ -517,27 +567,24 @@ class FunctionStateTestAggregate { HashStringAllocator* /*allocator*/, const FunctionState& /*state*/) {} - void checkpoint(FunctionState* state) { - TestValue::adjust( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - state); - } - void addInput( HashStringAllocator* /*allocator*/, exec::arg_type data, - exec::arg_type increment, + exec::arg_type /*increment*/, const FunctionState& state) { - checkpoint(const_cast(&state)); + VELOX_CHECK_EQ(state.constantInputs.size(), 2); + VELOX_CHECK(state.constantInputs[1]->isConstantEncoding()); + auto constant = state.constantInputs[1] + ->template asUnchecked>() + ->valueAt(0); sum += data; - count += increment; + count += constant; } void combine( HashStringAllocator* /*allocator*/, exec::arg_type other, const FunctionState& state) { - checkpoint(const_cast(&state)); VELOX_CHECK(other.at<0>().has_value()); VELOX_CHECK(other.at<1>().has_value()); sum += other.at<0>().value(); @@ -547,7 +594,6 @@ class FunctionStateTestAggregate { bool writeIntermediateResult( exec::out_type& out, const FunctionState& state) { - checkpoint(const_cast(&state)); out = std::make_tuple(sum, count); return true; } @@ -555,7 +601,6 @@ class FunctionStateTestAggregate { bool writeFinalResult( exec::out_type& out, const FunctionState& state) { - checkpoint(const_cast(&state)); out = sum / count; return true; } @@ -564,6 +609,7 @@ class FunctionStateTestAggregate { using AccumulatorType = Accumulator; }; +template exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( const std::string& name) { std::vector> signatures{ @@ -585,100 +631,28 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( -> std::unique_ptr { VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name); return std::make_unique< - SimpleAggregateAdapter>(resultType); + SimpleAggregateAdapter>>( + resultType); }, true /*registerCompanionFunctions*/, true /*overwrite*/); } void registerFunctionStateTestAggregate() { - registerSimpleFunctionStateTestAggregate(kSimpleFunctionStateAgg); + registerSimpleFunctionStateTestAggregate( + "simple_function_state_agg_main"); + registerSimpleFunctionStateTestAggregate( + "simple_function_state_agg_companion"); } class SimpleFunctionStateAggregationTest : public AggregationTestBase { protected: - SimpleFunctionStateAggregationTest() { - registerFunctionStateTestAggregate(); - } - - static void checkRowTypeEqual(TypePtr expected, TypePtr actual) { - VELOX_CHECK(expected->isRow()); - VELOX_CHECK(actual->isRow()); - VELOX_CHECK_EQ(expected->asRow().size(), actual->asRow().size()); - for (auto i = 0; i < expected->asRow().size(); i++) { - VELOX_CHECK_EQ(expected->asRow().childAt(i), actual->asRow().childAt(i)); - } - } - - static void checkState( - FunctionStateTestAggregate::FunctionState* state, - const std::vector& rawInputTypes, - const TypePtr& intermediateType, - const TypePtr& resultType, - const std::vector& constantInputs) { - VELOX_CHECK(!state->rawInputTypes.empty()); - VELOX_CHECK_NOT_NULL(state->resultType); - - switch (state->step) { - case core::AggregationNode::Step::kPartial: - case core::AggregationNode::Step::kIntermediate: - if (state->rawInputTypes.size() == 1 && - state->rawInputTypes[0]->isRow()) { - // Merge or merge_extract companion function. - VELOX_CHECK_EQ(rawInputTypes.size(), 1); - VELOX_CHECK(rawInputTypes[0]->isRow()); - checkRowTypeEqual(rawInputTypes[0], state->rawInputTypes[0]); - } else { - VELOX_CHECK(std::equal( - rawInputTypes.begin(), - rawInputTypes.end(), - state->rawInputTypes.begin(), - state->rawInputTypes.end())); - } - if (state->resultType->isRow()) { - checkRowTypeEqual(intermediateType, state->resultType); - } else { - VELOX_CHECK_EQ(resultType, state->resultType) - } - break; - - case core::AggregationNode::Step::kSingle: - case core::AggregationNode::Step::kFinal: - VELOX_CHECK(std::equal( - rawInputTypes.begin(), - rawInputTypes.end(), - state->rawInputTypes.begin(), - state->rawInputTypes.end())); - VELOX_CHECK_EQ(resultType, state->resultType); - break; - - default: - VELOX_FAIL("Unknown aggregate step"); - break; - } + void SetUp() override { + AggregationTestBase::SetUp(); - VELOX_CHECK(!state->constantInputs.empty()); - if (state->step == core::AggregationNode::Step::kPartial || - state->step == core::AggregationNode::Step::kSingle) { - VELOX_CHECK_EQ(constantInputs.size(), state->constantInputs.size()); - for (auto i = 0; i < constantInputs.size(); i++) { - auto expected = constantInputs[i]; - auto actual = state->constantInputs[i]; - if (expected == nullptr && actual == nullptr) { - continue; - } else { - VELOX_CHECK(expected != nullptr && actual != nullptr); - VELOX_CHECK(expected->isConstantEncoding()); - VELOX_CHECK(actual->isConstantEncoding()); - VELOX_CHECK_EQ( - expected->asUnchecked>()->valueAt(0), - actual->asUnchecked>()->valueAt(0)); - } - } - } else { - VELOX_CHECK_EQ(state->constantInputs.size(), 1); - VELOX_CHECK_NULL(state->constantInputs[0]); - } + disableTestIncremental(); + disableTestStreaming(); + registerFunctionStateTestAggregate(); } }; @@ -686,48 +660,14 @@ TEST_F(SimpleFunctionStateAggregationTest, aggregate) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); std::vector finalResult = {2.5}; auto expected = makeRowVector({makeFlatVector(finalResult)}); - - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {BIGINT(), BIGINT()}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 4)}); - })); testAggregations( - {inputVectors}, {}, {"simple_function_state_agg(c0, 1)"}, {expected}); -} - -TEST_F(SimpleFunctionStateAggregationTest, window) { - auto inputVectors = - makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); - auto expected = makeRowVector( - {makeFlatVector({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0})}); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {BIGINT(), BIGINT()}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 7)}); - })); - auto plan = - PlanBuilder() - .values({inputVectors}) - .window({"simple_function_state_agg(c0, 1) over (partition by c0)"}) - .project({"w0"}) - .planNode(); - AssertQueryBuilder(plan).assertResults(expected); + {inputVectors}, + {}, + {"simple_function_state_agg_main(c0, 1)"}, + {expected}); } -TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { +TEST_F(SimpleFunctionStateAggregationTest, companionAggregate) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); std::vector accSum = {10}; std::vector accCount = {4.0}; @@ -740,21 +680,11 @@ TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { std::vector finalResult = {2.5}; auto finalExpected = makeRowVector({makeFlatVector(finalResult)}); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {BIGINT(), BIGINT()}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 4)}); - })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) - .singleAggregation({}, {"simple_function_state_agg_partial(c0, 1)"}) + .singleAggregation( + {}, {"simple_function_state_agg_companion_partial(c0, 1)"}) .planNode()) .assertResults(intermediateExpected); @@ -764,43 +694,47 @@ TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { makeFlatVector({1.0, 1.0, 1.0, 1.0}), }), }); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {ROW({BIGINT(), DOUBLE()})}, - ROW({BIGINT(), DOUBLE()}), - ROW({BIGINT(), DOUBLE()}), - {nullptr, makeConstant(1, 4)}); - })); + AssertQueryBuilder( PlanBuilder() .values({inputVectors}) - .singleAggregation({}, {"simple_function_state_agg_merge(c0)"}) + .singleAggregation( + {}, {"simple_function_state_agg_companion_merge(c0)"}) .planNode()) .assertResults(intermediateExpected); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {ROW({BIGINT(), DOUBLE()})}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 4)}); - })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) .singleAggregation( - {}, {"simple_function_state_agg_merge_extract(c0)"}) + {}, {"simple_function_state_agg_companion_merge_extract(c0)"}) .planNode()) .assertResults(finalExpected); } +class SimpleFunctionStateWindowTest : public WindowTestBase { + protected: + void SetUp() override { + WindowTestBase::SetUp(); + + registerFunctionStateTestAggregate(); + } +}; + +TEST_F(SimpleFunctionStateWindowTest, window) { + auto inputVectors = + makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); + auto expected = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3, 3, 4}), + makeFlatVector({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0}), + }); + WindowTestBase::testWindowFunction( + {inputVectors}, + "simple_function_state_agg_main(c0, 1)", + {"partition by c0"}, + {}, + expected); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp index 2a0e0e82a60c..16c0ed4f5dad 100644 --- a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp +++ b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp @@ -1231,27 +1231,13 @@ void AggregationTestBase::testIncrementalAggregation( const auto& aggregateExpr = aggregate.call; const auto& functionName = aggregateExpr->name(); auto input = extractArgColumns(aggregateExpr, data, pool()); - const auto& inputType = aggregationNode.sources()[0]->outputType(); HashStringAllocator allocator(pool()); std::vector lambdas; - std::vector constantInputs; for (const auto& arg : aggregate.call->inputs()) { - if (dynamic_cast(arg.get())) { - constantInputs.push_back(nullptr); - } else if ( - auto constant = - dynamic_cast(arg.get())) { - constantInputs.push_back(constant->toConstantVector(pool())); - } else if ( - auto lambda = + if (auto lambda = std::dynamic_pointer_cast(arg)) { lambdas.push_back(lambda); - for (const auto& name : lambda->signature()->names()) { - if (auto captureIndex = inputType->getChildIdxIfExists(name)) { - constantInputs.push_back(nullptr); - } - } } } auto queryCtxConfig = config; @@ -1270,11 +1256,6 @@ void AggregationTestBase::testIncrementalAggregation( std::vector group(kOffset + func->accumulatorFixedWidthSize()); std::vector groups(inputSize, group.data()); std::vector indices(1, 0); - func->initialize( - aggregationNode.step(), - aggregate.rawInputTypes, - func->resultType(), - constantInputs); func->initializeNewGroups(groups.data(), indices); func->addSingleGroupRawInput( group.data(), SelectivityVector(inputSize), input, false); @@ -1314,24 +1295,11 @@ VectorPtr AggregationTestBase::testStreaming( vector_size_t rawInput2Size, const std::unordered_map& config) { std::vector rawInputTypes(rawInput1.size()); - std::vector constantInputs1(rawInput1.size()); - std::vector constantInputs2(rawInput2.size()); - for (auto i = 0; i < rawInput1.size(); i++) { - rawInputTypes[i] = rawInput1[i]->type(); - if (rawInput1[i]->isConstantEncoding()) { - constantInputs1[i] = rawInput1[i]; - } else { - constantInputs1[i] = nullptr; - } - } - std::transform( - rawInput2.begin(), - rawInput2.end(), - constantInputs2.begin(), - [](const VectorPtr& vec) { - return vec->isConstantEncoding() ? vec : nullptr; - }); + rawInput1.begin(), + rawInput1.end(), + rawInputTypes.begin(), + [](const VectorPtr& vec) { return vec->type(); }); HashStringAllocator allocator(pool()); auto func = @@ -1340,11 +1308,6 @@ VectorPtr AggregationTestBase::testStreaming( std::vector group(kOffset + func->accumulatorFixedWidthSize()); std::vector groups(maxRowCount, group.data()); std::vector indices(maxRowCount, 0); - func->initialize( - core::AggregationNode::Step::kSingle, - rawInputTypes, - func->resultType(), - constantInputs1); func->initializeNewGroups(groups.data(), indices); if (testGlobal) { func->addSingleGroupRawInput( @@ -1364,11 +1327,6 @@ VectorPtr AggregationTestBase::testStreaming( // Create a new function picking up the intermediate result. auto func2 = createAggregateFunction(functionName, rawInputTypes, allocator, config); - func2->initialize( - core::AggregationNode::Step::kSingle, - rawInputTypes, - func2->resultType(), - constantInputs2); func2->initializeNewGroups(groups.data(), indices); if (testGlobal) { func2->addSingleGroupIntermediateResults(