From 34e7e42c49ec6e18be9a2851b48711fd6ebf7c71 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Tue, 19 Mar 2024 15:48:09 +0800 Subject: [PATCH 1/8] Add function state --- velox/docs/develop/aggregate-functions.rst | 116 ++++++++++++------ velox/exec/Aggregate.h | 10 ++ velox/exec/AggregateCompanionAdapter.cpp | 16 +++ velox/exec/AggregateCompanionAdapter.h | 5 + velox/exec/AggregateInfo.cpp | 2 + velox/exec/AggregateWindow.cpp | 3 + velox/exec/SimpleAggregateAdapter.h | 65 +++++++--- .../exec/tests/SimpleAggregateAdapterTest.cpp | 20 ++- velox/exec/tests/SimpleArrayAggAggregate.cpp | 18 ++- velox/exec/tests/SimpleAverageAggregate.cpp | 22 +++- .../aggregates/BitwiseXorAggregate.cpp | 24 +++- .../aggregates/GeometricMeanAggregate.cpp | 20 ++- .../sparksql/aggregates/DecimalSumAggregate.h | 20 ++- 13 files changed, 261 insertions(+), 80 deletions(-) diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index c11a5dabc211..85c600035326 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -152,13 +152,29 @@ A simple aggregation function is implemented as a class as the following. using IntermediateType = Array>; using OutputType = Array>; + // If UDAF does not require the use of FunctionState, it is necessary + // to declare an empty FunctionState struct. + struct FunctionState { + // Optional. + TypePtr resultType; + }; + + // Optional. Used only when the UDAF needs to use FunctionState. + static void initialize( + FunctionState& state, + const std::vector& rawInputTypes, + const TypePtr& resultType, + const std::vector& constantInputs) { + state.resultType = resultType; + } + // Optional. Default is true. static constexpr bool default_null_behavior_ = false; // Optional. static bool toIntermediate( - exec::out_type>>& out, - exec::optional_arg_type> in); + exec::out_type>>& out, + exec::optional_arg_type> in); struct AccumulatorType { ... }; }; @@ -169,6 +185,15 @@ function's argument type(s) wrapped in a Row<> even if the function only takes one argument. This is needed for the SimpleAggregateAdapter to parse input types for arbitrary aggregation functions properly. +A FunctionState struct needs to be declared in the simple aggregation function +class, it is used to hold the function-level variables that are typically +computed once and used at every row when adding inputs to accumulators or +extracting values from accumulators. For example, if the UDAF needs to get the +result type or the raw input type of the aggregaiton function, the author can +hold them in the FunctionState struct, and initialize them in the initialize() +method. If the UDAF does not require the use ofFunctionState, it is necessary +to declare an empty FunctionState struct. + The author can define an optional flag `default_null_behavior_` indicating whether the aggregation function has default-null behavior. This flag is true by default. Next, the class can have an optional method `toIntermediate()` @@ -257,17 +282,21 @@ For aggregaiton functions of default-null behavior, the author defines an // Optional. Default is false. static constexpr bool is_aligned_ = true; - explicit AccumulatorType(HashStringAllocator* allocator); + explicit AccumulatorType(HashStringAllocator* allocator, const FunctionState& state); - void addInput(HashStringAllocator* allocator, exec::arg_type value1, ...); + void addInput( + HashStringAllocator* allocator, + exec::arg_type value1, ..., + const FunctionState& state); void combine( HashStringAllocator* allocator, - exec::arg_type other); + exec::arg_type other, + const FunctionState& state); - bool writeIntermediateResult(exec::out_type& out); + bool writeIntermediateResult(exec::out_type& out, const FunctionState& state); - bool writeFinalResult(exec::out_type& out); + bool writeFinalResult(exec::out_type& out, const FunctionState& state); // Optional. Called during destruction. void destroy(HashStringAllocator* allocator); @@ -296,7 +325,8 @@ addInput This method adds raw input values to *this* accumulator. It receives a `HashStringAllocator*` followed by `exec::arg_type`-typed values, one for -each argument type `Ti` wrapped in InputType. +each argument type `Ti` wrapped in InputType. `const FunctionState&` hold the +function-level variables. With default-null behavior, raw-input rows where at least one column is null are ignored before `addInput` is called. After `addInput` is called, *this* @@ -306,31 +336,32 @@ combine """"""" This method adds an input intermediate state to *this* accumulator. It receives -a `HashStringAllocator*` and one `exec::arg_type` value. With -default-null behavior, nulls among the input intermediate states are ignored -before `combine` is called. After `combine` is called, *this* accumulator is -assumed to be non-null. +a `HashStringAllocator*` and one `exec::arg_type` value. +`const FunctionState&` hold the function-level variables. With default-null +behavior, nulls among the input intermediate states are ignored before `combine` +is called. After `combine` is called, *this* accumulator is assumed to be non-null. writeIntermediateResult """"""""""""""""""""""" This method writes *this* accumulator out to an intermediate state vector. It -has an out-parameter of the type `exec::out_type&`. This -method returns true if it writes a non-null value to `out`, or returns false -meaning a null should be written to the intermediate state vector. Accumulators -that are nulls (i.e., no value has been added to them) automatically become -nulls in the intermediate state vector without `writeIntermediateResult` being -called. +has an out-parameter of the type `exec::out_type&`. +`const FunctionState&` hold the function-level variables. This method returns +true if it writes a non-null value to `out`, or returns false meaning a null +should be written to the intermediate state vector. Accumulators that are +nulls (i.e., no value has been added to them) automatically become nulls in +the intermediate state vector without `writeIntermediateResult` being called. writeFinalResult """""""""""""""" This method writes *this* accumulator out to a final result vector. It -has an out-parameter of the type `exec::out_type&`. This -method returns true if it writes a non-null value to `out`, or returns false -meaning a null should be written to the final result vector. Accumulators -that are nulls (i.e., no value has been added to them) automatically become -nulls in the final result vector without `writeFinalResult` being called. +has an out-parameter of the type `exec::out_type&`. +`const FunctionState&` hold the function-level variables. This method returns +true if it writes a non-null value to `out`, or returns false meaning a null +should be written to the final result vector. Accumulators that are +nulls (i.e., no value has been added to them) automatically become nulls in the +final result vector without `writeFinalResult` being called. AccumulatorType of Non-Default-Null Behavior ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -355,15 +386,25 @@ For aggregaiton functions of non-default-null behavior, the author defines an explicit AccumulatorType(HashStringAllocator* allocator); - bool addInput(HashStringAllocator* allocator, exec::optional_arg_type value1, ...); + bool addInput( + HashStringAllocator* allocator, + exec::optional_arg_type value1, ..., + const FunctionState& state); bool combine( HashStringAllocator* allocator, - exec::optional_arg_type other); + exec::optional_arg_type other, + const FunctionState& state); - bool writeIntermediateResult(bool nonNullGroup, exec::out_type& out); + bool writeIntermediateResult( + bool nonNullGroup, + exec::out_type& out, + const FunctionState& state); - bool writeFinalResult(bool nonNullGroup, exec::out_type& out); + bool writeFinalResult( + bool nonNullGroup, + exec::out_type& out, + const FunctionState& state); // Optional. void destroy(HashStringAllocator* allocator); @@ -384,7 +425,7 @@ addInput This method receives a `HashStringAllocator*` followed by `exec::optional_arg_type` values, one for each argument type `Ti` wrapped -in InputType. +in InputType. `const FunctionState&` hold the function-level variables. This method is called on all raw-input rows even if some columns may be null. It returns a boolean meaning whether *this* accumulator is non-null after the @@ -397,26 +438,29 @@ combine """"""" This method receives a `HashStringAllocator*` and an -`exec::optional_arg_type` value. This method is called on -all intermediate states even if some are nulls. Same as `addInput`, this method -returns a boolean meaning whether *this* accumulator is non-null after the call. +`exec::optional_arg_type` value. `const FunctionState&` hold +the function-level variables.This method is called on all intermediate states +even if some are nulls. Same as `addInput`, this method returns a boolean +meaning whether *this* accumulator is non-null after the call. writeIntermediateResult """"""""""""""""""""""" This method has an out-parameter of the type `exec::out_type&` and a boolean flag `nonNullGroup` indicating whether *this* accumulator is -non-null. This method returns true if it writes a non-null value to `out`, or -return false meaning a null should be written to the intermediate state vector. +non-null. `const FunctionState&` hold the function-level variables. This method +returns true if it writes a non-null value to `out`, or return false meaning a +null should be written to the intermediate state vector. writeFinalResult """""""""""""""" This method writes *this* accumulator out to a final result vector. It has an out-parameter of the type `exec::out_type&` and a boolean flag -`nonNullGroup` indicating whether *this* accumulator is non-null. This method -returns true if it writes a non-null value to `out`, or return false meaning a -null should be written to the final result vector. +`nonNullGroup` indicating whether *this* accumulator is non-null. +`const FunctionState&` hold the function-level variables.This method returns +true if it writes a non-null value to `out`, or return false meaning a null +should be written to the final result vector. Limitations ^^^^^^^^^^^ diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index d6bc12aefcde..0830f7b7d9a9 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -129,6 +129,16 @@ class Aggregate { rowSizeOffset); } + // Initialize the function-level state of the simple function interface for + // UDAF. + // @param rawInputType The raw input type of the UDAF. + // @param resultType The result type of the UDAF. + // @param constantInputs Optional constant inputs. + virtual void initialize( + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs) {} + // Initializes null flags and accumulators for newly encountered groups. This // function should be called only once for each group. // diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index c1a087f24d7c..52cda091a1f7 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -73,6 +73,13 @@ void AggregateCompanionFunctionBase::clearInternal() { fn_->clear(); } +void AggregateCompanionFunctionBase::initialize( + const std::vector& rawInputType, + const facebook::velox::TypePtr& resultType, + const std::vector& constantInputs) { + fn_->initialize(rawInputType, resultType, constantInputs); +} + void AggregateCompanionFunctionBase::initializeNewGroups( char** groups, folly::Range indices) { @@ -229,6 +236,15 @@ void AggregateCompanionAdapter::ExtractFunction::apply( // Perform per-row aggregation. std::vector allSelectedRange; rows.applyToSelected([&](auto row) { allSelectedRange.push_back(row); }); + + // Get the raw input types. + std::vector rawInputTypes; + rawInputTypes.reserve(args.size()); + for (const auto& arg : args) { + rawInputTypes.emplace_back(arg->type()); + } + + fn_->initialize(rawInputTypes, outputType, {}); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 91b7c3a7bed8..597d6e0df0b2 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -38,6 +38,11 @@ class AggregateCompanionFunctionBase : public Aggregate { void destroy(folly::Range groups) override final; + void initialize( + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs) override; + void initializeNewGroups( char** groups, folly::Range indices) override final; diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index 9b19ab687965..6e5a2ff25f35 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -103,6 +103,8 @@ std::vector toAggregateInfo( aggResultType, operatorCtx.driverCtx()->queryConfig()); + info.function->initialize( + aggregate.rawInputTypes, aggResultType, info.constantInputs); auto lambdas = extractLambdaInputs(aggregate); if (!lambdas.empty()) { if (expressionEvaluator == nullptr) { diff --git a/velox/exec/AggregateWindow.cpp b/velox/exec/AggregateWindow.cpp index cb32bd0779c3..a4a15b8baccc 100644 --- a/velox/exec/AggregateWindow.cpp +++ b/velox/exec/AggregateWindow.cpp @@ -151,6 +151,7 @@ class AggregateWindowFunction : public exec::WindowFunction { // aggregate_ function object should be initialized. auto singleGroup = std::vector{0}; aggregate_->clear(); + aggregate_->initialize(argTypes_, resultType_, argVectors_); aggregate_->initializeNewGroups(&rawSingleGroupRow_, singleGroup); aggregateInitialized_ = true; } @@ -332,6 +333,7 @@ class AggregateWindowFunction : public exec::WindowFunction { // the aggregation based on the frame changes with each row. This would // require adding new APIs to the Aggregate framework. aggregate_->clear(); + aggregate_->initialize(argTypes_, resultType_, argVectors_); aggregate_->initializeNewGroups(&rawSingleGroupRow_, kSingleGroup); aggregateInitialized_ = true; @@ -349,6 +351,7 @@ class AggregateWindowFunction : public exec::WindowFunction { // This value is returned for rows with empty frames. void computeDefaultAggregateValue(const TypePtr& resultType) { aggregate_->clear(); + aggregate_->initialize(argTypes_, resultType, argVectors_); aggregate_->initializeNewGroups( &rawSingleGroupRow_, std::vector{0}); aggregateInitialized_ = true; diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index d92584c6d647..558deca4f793 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -44,6 +44,20 @@ class SimpleAggregateAdapter : public Aggregate { explicit SimpleAggregateAdapter(TypePtr resultType) : Aggregate(std::move(resultType)) {} + // Function-level states are variables hold by a UDAF instance that are + // typically computed once and used at every row when adding inputs to + // accumulators or extracting values from accumulators. + typename FUNC::FunctionState state_; + + void initialize( + const std::vector& rawInputTypes, + const TypePtr& resultType, + const std::vector& constantInputs) override { + if constexpr (support_initialize_function_state_) { + FUNC::initialize(state_, rawInputTypes, resultType, constantInputs); + } + } + // Assume most aggregate functions have fixed-size accumulators. Functions // that // have non-fixed-size accumulators should overwrite `is_fixed_size_` in their @@ -145,6 +159,19 @@ class SimpleAggregateAdapter : public Aggregate { struct support_to_intermediate> : std::true_type {}; + // Whether the function defines its initialize() method or not. If it is + // defined, SimpleAggregateAdapter::supportInitializeFunctionState() returns + // true. Otherwise, SimpleAggregateAdapter::supportInitializeFunctionState() + // returns false and SimpleAggregateAdapter::initialize() will not initialize + // the UDAF's FunctionState. + template + struct support_initialize_function_state : std::false_type {}; + + template + struct support_initialize_function_state< + T, + std::void_t> : std::true_type {}; + // Whether the accumulator requires aligned access. If it is defined, // SimpleAggregateAdapter::accumulatorAlignmentSize() returns // alignof(typename FUNC::AccumulatorType). @@ -175,6 +202,9 @@ class SimpleAggregateAdapter : public Aggregate { static constexpr bool accumulator_is_aligned_ = accumulator_is_aligned::value; + static constexpr bool support_initialize_function_state_ = + support_initialize_function_state::value; + bool isFixedSize() const override { return accumulator_is_fixed_size_; } @@ -301,12 +331,13 @@ class SimpleAggregateAdapter : public Aggregate { if (isNull(groups[i])) { writer.commitNull(); } else { - bool nonNull = group->writeIntermediateResult(writer.current()); + bool nonNull = + group->writeIntermediateResult(writer.current(), state_); writer.commit(nonNull); } } else { bool nonNull = group->writeIntermediateResult( - !isNull(groups[i]), writer.current()); + !isNull(groups[i]), writer.current(), state_); writer.commit(nonNull); } } @@ -332,12 +363,12 @@ class SimpleAggregateAdapter : public Aggregate { if (isNull(groups[i])) { writer.commitNull(); } else { - bool nonNull = group->writeFinalResult(writer.current()); + bool nonNull = group->writeFinalResult(writer.current(), state_); writer.commit(nonNull); } } else { - bool nonNull = - group->writeFinalResult(!isNull(groups[i]), writer.current()); + bool nonNull = group->writeFinalResult( + !isNull(groups[i]), writer.current(), state_); writer.commit(nonNull); } } @@ -350,7 +381,8 @@ class SimpleAggregateAdapter : public Aggregate { folly::Range indices) override { setAllNulls(groups, indices); for (auto i : indices) { - new (groups[i] + offset_) typename FUNC::AccumulatorType(allocator_); + new (groups[i] + offset_) + typename FUNC::AccumulatorType(allocator_, state_); } } @@ -387,7 +419,7 @@ class SimpleAggregateAdapter : public Aggregate { tracker.emplace(groups[row][rowSizeOffset_], *allocator_); } auto group = value(groups[row]); - group->addInput(allocator_, std::get(readers)[row]...); + group->addInput(allocator_, std::get(readers)[row]..., state_); clearNull(groups[row]); }); } else { @@ -400,7 +432,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = group->addInput( allocator_, OptionalAccessor>{ - &std::get(readers), (int64_t)row}...); + &std::get(readers), (int64_t)row}..., + state_); if (nonNull) { clearNull(groups[row]); } @@ -427,7 +460,8 @@ class SimpleAggregateAdapter : public Aggregate { if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(group[rowSizeOffset_], *allocator_); } - accumulator->addInput(allocator_, std::get(readers)[row]...); + accumulator->addInput( + allocator_, std::get(readers)[row]..., state_); clearNull(group); }); } else { @@ -439,7 +473,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = accumulator->addInput( allocator_, OptionalAccessor>{ - &std::get(readers), (int64_t)row}...); + &std::get(readers), (int64_t)row}..., + state_); if (nonNull) { clearNull(group); } @@ -511,7 +546,7 @@ class SimpleAggregateAdapter : public Aggregate { tracker.emplace(groups[row][rowSizeOffset_], *allocator_); } auto group = value(groups[row]); - group->combine(allocator_, reader[row]); + group->combine(allocator_, reader[row], state_); clearNull(groups[row]); }); } else { @@ -524,7 +559,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = group->combine( allocator_, OptionalAccessor{ - &reader, (int64_t)row}); + &reader, (int64_t)row}, + state_); if (nonNull) { clearNull(groups[row]); } @@ -549,7 +585,7 @@ class SimpleAggregateAdapter : public Aggregate { if constexpr (!accumulator_is_fixed_size_) { tracker.emplace(group[rowSizeOffset_], *allocator_); } - accumulator->combine(allocator_, reader[row]); + accumulator->combine(allocator_, reader[row], state_); clearNull(group); }); } else { @@ -561,7 +597,8 @@ class SimpleAggregateAdapter : public Aggregate { bool nonNull = accumulator->combine( allocator_, OptionalAccessor{ - &reader, (int64_t)row}); + &reader, (int64_t)row}, + state_); if (nonNull) { clearNull(group); } diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 284814801739..92074f9865f9 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -357,6 +357,8 @@ class CountNullsAggregate { using IntermediateType = int64_t; // Intermediate result type. using OutputType = int64_t; // Output vector type. + struct FunctionState {}; + static constexpr bool default_null_behavior_ = false; struct Accumulator { @@ -364,13 +366,16 @@ class CountNullsAggregate { Accumulator() = delete; - explicit Accumulator(HashStringAllocator* /*allocator*/) { + explicit Accumulator( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) { nullsCount_ = 0; } bool addInput( HashStringAllocator* /*allocator*/, - exec::optional_arg_type data) { + exec::optional_arg_type data, + const FunctionState& /*state*/) { if (!data.has_value()) { nullsCount_++; return true; @@ -380,7 +385,8 @@ class CountNullsAggregate { bool combine( HashStringAllocator* /*allocator*/, - exec::optional_arg_type nullsCount) { + exec::optional_arg_type nullsCount, + const FunctionState& /*state*/) { if (nullsCount.has_value()) { nullsCount_ += nullsCount.value(); return true; @@ -388,13 +394,17 @@ class CountNullsAggregate { return false; } - bool writeFinalResult(bool nonNull, exec::out_type& out) { + bool writeFinalResult( + bool nonNull, + exec::out_type& out, + const FunctionState& /*state*/) { return writeResult(nonNull, out); } bool writeIntermediateResult( bool nonNull, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { return writeResult(nonNull, out); } diff --git a/velox/exec/tests/SimpleArrayAggAggregate.cpp b/velox/exec/tests/SimpleArrayAggAggregate.cpp index e3998aadf2a2..591865845b82 100644 --- a/velox/exec/tests/SimpleArrayAggAggregate.cpp +++ b/velox/exec/tests/SimpleArrayAggAggregate.cpp @@ -36,6 +36,8 @@ class ArrayAggAggregate { // Type of output vector. using OutputType = Array>; + struct FunctionState {}; + static constexpr bool default_null_behavior_ = false; static bool toIntermediate( @@ -55,7 +57,9 @@ class ArrayAggAggregate { AccumulatorType() = delete; // Constructor used in initializeNewGroups(). - explicit AccumulatorType(HashStringAllocator* /*allocator*/) + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) : elements_{} {} static constexpr bool is_fixed_size_ = false; @@ -64,7 +68,8 @@ class ArrayAggAggregate { // child-type T wrapped in InputType. bool addInput( HashStringAllocator* allocator, - exec::optional_arg_type> data) { + exec::optional_arg_type> data, + const FunctionState& /*state*/) { elements_.appendValue(data, allocator); return true; } @@ -73,7 +78,8 @@ class ArrayAggAggregate { // exec::optional_arg_type. bool combine( HashStringAllocator* allocator, - exec::optional_arg_type>> other) { + exec::optional_arg_type>> other, + const FunctionState& /*state*/) { if (!other.has_value()) { return false; } @@ -85,7 +91,8 @@ class ArrayAggAggregate { bool writeFinalResult( bool nonNullGroup, - exec::out_type>>& out) { + exec::out_type>>& out, + const FunctionState& /*state*/) { if (!nonNullGroup) { return false; } @@ -95,7 +102,8 @@ class ArrayAggAggregate { bool writeIntermediateResult( bool nonNullGroup, - exec::out_type>>& out) { + exec::out_type>>& out, + const FunctionState& /*state*/) { // If the group's accumulator is null, the corresponding intermediate // result is null too. if (!nonNullGroup) { diff --git a/velox/exec/tests/SimpleAverageAggregate.cpp b/velox/exec/tests/SimpleAverageAggregate.cpp index 9f887f34eadf..9f9494648147 100644 --- a/velox/exec/tests/SimpleAverageAggregate.cpp +++ b/velox/exec/tests/SimpleAverageAggregate.cpp @@ -42,6 +42,8 @@ class AverageAggregate { using OutputType = std::conditional_t, float, double>; + struct FunctionState {}; + static bool toIntermediate( exec::out_type>& out, exec::arg_type in) { @@ -56,14 +58,19 @@ class AverageAggregate { AccumulatorType() = delete; // Constructor used in initializeNewGroups(). - explicit AccumulatorType(HashStringAllocator* /*allocator*/) { + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) { sum_ = 0; count_ = 0; } // addInput expects one parameter of exec::arg_type for each child-type T // wrapped in InputType. - void addInput(HashStringAllocator* /*allocator*/, exec::arg_type data) { + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type data, + const FunctionState& /*state*/) { sum_ += data; count_ = checkedPlus(count_, 1); } @@ -71,7 +78,8 @@ class AverageAggregate { // combine expects one parameter of exec::arg_type. void combine( HashStringAllocator* /*allocator*/, - exec::arg_type> other) { + exec::arg_type> other, + const FunctionState& /*state*/) { // Both field of an intermediate result should be non-null because // writeIntermediateResult() never make an intermediate result with a // single null. @@ -81,12 +89,16 @@ class AverageAggregate { count_ = checkedPlus(count_, other.at<1>().value()); } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = sum_ / count_; return true; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::make_tuple(sum_, count_); return true; } diff --git a/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp b/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp index aa3a52d5fa28..8a8034ec3f8f 100644 --- a/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp +++ b/velox/functions/prestosql/aggregates/BitwiseXorAggregate.cpp @@ -32,6 +32,8 @@ class BitwiseXorAggregate { using OutputType = T; + struct FunctionState {}; + static bool toIntermediate(exec::out_type& out, exec::arg_type in) { out = in; return true; @@ -42,22 +44,34 @@ class BitwiseXorAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} - void addInput(HashStringAllocator* /*allocator*/, exec::arg_type data) { + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type data, + const FunctionState& /*state*/) { xor_ ^= data; } - void combine(HashStringAllocator* /*allocator*/, exec::arg_type other) { + void combine( + HashStringAllocator* /*allocator*/, + exec::arg_type other, + const FunctionState& /*state*/) { xor_ ^= other; } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = xor_; return true; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = xor_; return true; } diff --git a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp index 9bbe9ec8c73b..703a2c5c3367 100644 --- a/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp +++ b/velox/functions/prestosql/aggregates/GeometricMeanAggregate.cpp @@ -35,6 +35,8 @@ class GeometricMeanAggregate { using OutputType = TResult; + struct FunctionState {}; + static bool toIntermediate( exec::out_type>& out, exec::arg_type in) { @@ -50,30 +52,38 @@ class GeometricMeanAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} void addInput( HashStringAllocator* /*allocator*/, - exec::arg_type data) { + exec::arg_type data, + const FunctionState& /*state*/) { logSum_ += std::log(data); count_ = checkedPlus(count_, 1); } void combine( HashStringAllocator* /*allocator*/, - exec::arg_type> other) { + exec::arg_type> other, + const FunctionState& /*state*/) { VELOX_CHECK(other.at<0>().has_value()); VELOX_CHECK(other.at<1>().has_value()); logSum_ += other.at<0>().value(); count_ = checkedPlus(count_, other.at<1>().value()); } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::exp(logSum_ / count_); return true; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::make_tuple(logSum_, count_); return true; } diff --git a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h index 5019ae1a2ca1..ce25c12af763 100644 --- a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -34,6 +34,8 @@ class DecimalSumAggregate { using OutputType = TSumType; + struct FunctionState {}; + /// Spark's decimal sum doesn't have the concept of a null group, each group /// is initialized with an initial value, where sum = 0 and isEmpty = true. /// The final agg may fallback to being executed in Spark, so the meaning of @@ -70,7 +72,9 @@ class DecimalSumAggregate { AccumulatorType() = delete; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} std::optional computeFinalResult() const { if (!sum.has_value()) { @@ -92,7 +96,8 @@ class DecimalSumAggregate { bool addInput( HashStringAllocator* /*allocator*/, - exec::optional_arg_type data) { + exec::optional_arg_type data, + const FunctionState& /*state*/) { if (!data.has_value()) { return false; } @@ -112,7 +117,8 @@ class DecimalSumAggregate { bool combine( HashStringAllocator* /*allocator*/, - exec::optional_arg_type> other) { + exec::optional_arg_type> other, + const FunctionState& /*state*/) { if (!other.has_value()) { return false; } @@ -143,7 +149,8 @@ class DecimalSumAggregate { bool writeIntermediateResult( bool nonNullGroup, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { if (!nonNullGroup) { // If a group is null, all values in this group are null. In Spark, this // group will be the initial value, where sum is 0 and isEmpty is true. @@ -163,7 +170,10 @@ class DecimalSumAggregate { return true; } - bool writeFinalResult(bool nonNullGroup, exec::out_type& out) { + bool writeFinalResult( + bool nonNullGroup, + exec::out_type& out, + const FunctionState& /*state*/) { if (!nonNullGroup || isEmpty) { // If isEmpty is true, we should set null. return false; From 6fbe73098b6c9027a3b85f80d77edc2d41b859d2 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Sat, 13 Apr 2024 21:55:11 +0800 Subject: [PATCH 2/8] Address comments --- velox/docs/develop/aggregate-functions.rst | 1 + velox/exec/Aggregate.h | 2 ++ velox/exec/AggregateCompanionAdapter.cpp | 19 ++++++++++------- velox/exec/AggregateCompanionAdapter.h | 1 + velox/exec/AggregateInfo.cpp | 2 +- velox/exec/AggregateWindow.cpp | 21 ++++++++++++++++--- velox/exec/SimpleAggregateAdapter.h | 5 +++-- .../aggregates/CollectListAggregate.cpp | 18 +++++++++++----- 8 files changed, 50 insertions(+), 19 deletions(-) diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index 85c600035326..024a43057837 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -161,6 +161,7 @@ A simple aggregation function is implemented as a class as the following. // Optional. Used only when the UDAF needs to use FunctionState. static void initialize( + core::AggregationNode::Step step, FunctionState& state, const std::vector& rawInputTypes, const TypePtr& resultType, diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index 0830f7b7d9a9..7267bd748a40 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -131,10 +131,12 @@ class Aggregate { // Initialize the function-level state of the simple function interface for // UDAF. + // @param step The aggregation step. // @param rawInputType The raw input type of the UDAF. // @param resultType The result type of the UDAF. // @param constantInputs Optional constant inputs. virtual void initialize( + core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, const std::vector& constantInputs) {} diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 52cda091a1f7..423ee123a001 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -74,10 +74,11 @@ void AggregateCompanionFunctionBase::clearInternal() { } void AggregateCompanionFunctionBase::initialize( + core::AggregationNode::Step step, const std::vector& rawInputType, const facebook::velox::TypePtr& resultType, const std::vector& constantInputs) { - fn_->initialize(rawInputType, resultType, constantInputs); + fn_->initialize(step, rawInputType, resultType, constantInputs); } void AggregateCompanionFunctionBase::initializeNewGroups( @@ -238,13 +239,15 @@ void AggregateCompanionAdapter::ExtractFunction::apply( rows.applyToSelected([&](auto row) { allSelectedRange.push_back(row); }); // Get the raw input types. - std::vector rawInputTypes; - rawInputTypes.reserve(args.size()); - for (const auto& arg : args) { - rawInputTypes.emplace_back(arg->type()); - } - - fn_->initialize(rawInputTypes, outputType, {}); + std::vector rawInputTypes{args.size()}; + std::transform( + args.begin(), + args.end(), + rawInputTypes.begin(), + [](const VectorPtr& arg) { return arg->type(); }); + + fn_->initialize( + core::AggregationNode::Step::kFinal, rawInputTypes, outputType, {}); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 597d6e0df0b2..b12b87219a1a 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -39,6 +39,7 @@ class AggregateCompanionFunctionBase : public Aggregate { void destroy(folly::Range groups) override final; void initialize( + core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, const std::vector& constantInputs) override; diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index 6e5a2ff25f35..0a51b2cfbfb1 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -104,7 +104,7 @@ std::vector toAggregateInfo( operatorCtx.driverCtx()->queryConfig()); info.function->initialize( - aggregate.rawInputTypes, aggResultType, info.constantInputs); + step, aggregate.rawInputTypes, aggResultType, info.constantInputs); auto lambdas = extractLambdaInputs(aggregate); if (!lambdas.empty()) { if (expressionEvaluator == nullptr) { diff --git a/velox/exec/AggregateWindow.cpp b/velox/exec/AggregateWindow.cpp index a4a15b8baccc..f184757ece54 100644 --- a/velox/exec/AggregateWindow.cpp +++ b/velox/exec/AggregateWindow.cpp @@ -45,8 +45,10 @@ class AggregateWindowFunction : public exec::WindowFunction { argTypes_.reserve(args.size()); argIndices_.reserve(args.size()); argVectors_.reserve(args.size()); + constantInputs_.reserve(args.size()); for (const auto& arg : args) { argTypes_.push_back(arg.type); + constantInputs_.push_back(arg.constantValue); if (arg.constantValue) { argIndices_.push_back(kConstantChannel); argVectors_.push_back(arg.constantValue); @@ -151,7 +153,11 @@ class AggregateWindowFunction : public exec::WindowFunction { // aggregate_ function object should be initialized. auto singleGroup = std::vector{0}; aggregate_->clear(); - aggregate_->initialize(argTypes_, resultType_, argVectors_); + aggregate_->initialize( + core::AggregationNode::Step::kSingle, + argTypes_, + resultType_, + constantInputs_); aggregate_->initializeNewGroups(&rawSingleGroupRow_, singleGroup); aggregateInitialized_ = true; } @@ -333,7 +339,11 @@ class AggregateWindowFunction : public exec::WindowFunction { // the aggregation based on the frame changes with each row. This would // require adding new APIs to the Aggregate framework. aggregate_->clear(); - aggregate_->initialize(argTypes_, resultType_, argVectors_); + aggregate_->initialize( + core::AggregationNode::Step::kSingle, + argTypes_, + resultType_, + constantInputs_); aggregate_->initializeNewGroups(&rawSingleGroupRow_, kSingleGroup); aggregateInitialized_ = true; @@ -351,7 +361,11 @@ class AggregateWindowFunction : public exec::WindowFunction { // This value is returned for rows with empty frames. void computeDefaultAggregateValue(const TypePtr& resultType) { aggregate_->clear(); - aggregate_->initialize(argTypes_, resultType, argVectors_); + aggregate_->initialize( + core::AggregationNode::Step::kSingle, + argTypes_, + resultType, + constantInputs_); aggregate_->initializeNewGroups( &rawSingleGroupRow_, std::vector{0}); aggregateInitialized_ = true; @@ -377,6 +391,7 @@ class AggregateWindowFunction : public exec::WindowFunction { std::vector argTypes_; std::vector argIndices_; std::vector argVectors_; + std::vector constantInputs_; // This is a single aggregate row needed by the aggregate function for its // computation. These values are for the row and its various components. diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index 558deca4f793..e04d34e69eeb 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -45,11 +45,12 @@ class SimpleAggregateAdapter : public Aggregate { : Aggregate(std::move(resultType)) {} // Function-level states are variables hold by a UDAF instance that are - // typically computed once and used at every row when adding inputs to - // accumulators or extracting values from accumulators. + // computed once and used at every row when adding inputs to accumulators or + // extracting values from accumulators. typename FUNC::FunctionState state_; void initialize( + core::AggregationNode::Step step, const std::vector& rawInputTypes, const TypePtr& resultType, const std::vector& constantInputs) override { diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp index e2c14cfa7969..3a05e4cab1e7 100644 --- a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp @@ -32,6 +32,8 @@ class CollectListAggregate { using OutputType = Array>; + struct FunctionState {}; + /// In Spark, when all inputs are null, the output is an empty array instead /// of null. Therefore, in the writeIntermediateResult and writeFinalResult, /// we still need to output the empty element_ when the group is null. This @@ -51,14 +53,17 @@ class CollectListAggregate { struct AccumulatorType { ValueList elements_; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) : elements_{} {} static constexpr bool is_fixed_size_ = false; bool addInput( HashStringAllocator* allocator, - exec::optional_arg_type> data) { + exec::optional_arg_type> data, + const FunctionState& /*state*/) { if (data.has_value()) { elements_.appendValue(data, allocator); return true; @@ -68,7 +73,8 @@ class CollectListAggregate { bool combine( HashStringAllocator* allocator, - exec::optional_arg_type other) { + exec::optional_arg_type other, + const FunctionState& /*state*/) { if (!other.has_value()) { return false; } @@ -80,7 +86,8 @@ class CollectListAggregate { bool writeIntermediateResult( bool /*nonNullGroup*/, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { // If the group's accumulator is null, the corresponding intermediate // result is an empty array. copyValueListToArrayWriter(out, elements_); @@ -89,7 +96,8 @@ class CollectListAggregate { bool writeFinalResult( bool /*nonNullGroup*/, - exec::out_type& out) { + exec::out_type& out, + const FunctionState& /*state*/) { // If the group's accumulator is null, the corresponding result is an // empty array. copyValueListToArrayWriter(out, elements_); From 3906ab7a444490e539964f388e7ebfd4ba416f3e Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Tue, 16 Apr 2024 20:47:23 +0800 Subject: [PATCH 3/8] Add test cases --- velox/exec/AggregateCompanionAdapter.cpp | 32 ++- velox/exec/AggregateCompanionAdapter.h | 18 +- velox/exec/SimpleAggregateAdapter.h | 2 +- .../exec/tests/SimpleAggregateAdapterTest.cpp | 233 +++++++++++++++++- .../tests/utils/AggregationTestBase.cpp | 15 ++ 5 files changed, 284 insertions(+), 16 deletions(-) diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 423ee123a001..64ee49fe127b 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -73,14 +73,6 @@ void AggregateCompanionFunctionBase::clearInternal() { fn_->clear(); } -void AggregateCompanionFunctionBase::initialize( - core::AggregationNode::Step step, - const std::vector& rawInputType, - const facebook::velox::TypePtr& resultType, - const std::vector& constantInputs) { - fn_->initialize(step, rawInputType, resultType, constantInputs); -} - void AggregateCompanionFunctionBase::initializeNewGroups( char** groups, folly::Range indices) { @@ -132,6 +124,18 @@ void AggregateCompanionFunctionBase::extractAccumulators( fn_->extractAccumulators(groups, numGroups, result); } +void AggregateCompanionAdapter::PartialFunction::initialize( + core::AggregationNode::Step /*step*/, + const std::vector& rawInputType, + const facebook::velox::TypePtr& resultType, + const std::vector& constantInputs) { + fn_->initialize( + core::AggregationNode::Step::kPartial, + rawInputType, + resultType, + constantInputs); +} + void AggregateCompanionAdapter::PartialFunction::extractValues( char** groups, int32_t numGroups, @@ -139,6 +143,18 @@ void AggregateCompanionAdapter::PartialFunction::extractValues( fn_->extractAccumulators(groups, numGroups, result); } +void AggregateCompanionAdapter::MergeFunction::initialize( + core::AggregationNode::Step /*step*/, + const std::vector& rawInputType, + const facebook::velox::TypePtr& resultType, + const std::vector& constantInputs) { + fn_->initialize( + core::AggregationNode::Step::kIntermediate, + rawInputType, + resultType, + constantInputs); +} + void AggregateCompanionAdapter::MergeFunction::addRawInput( char** groups, const SelectivityVector& rows, diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index b12b87219a1a..24b196e3f60e 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -38,12 +38,6 @@ class AggregateCompanionFunctionBase : public Aggregate { void destroy(folly::Range groups) override final; - void initialize( - core::AggregationNode::Step step, - const std::vector& rawInputType, - const TypePtr& resultType, - const std::vector& constantInputs) override; - void initializeNewGroups( char** groups, folly::Range indices) override final; @@ -105,6 +99,12 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : AggregateCompanionFunctionBase{std::move(fn), resultType} {} + void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs) override; + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; @@ -116,6 +116,12 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : AggregateCompanionFunctionBase{std::move(fn), resultType} {} + void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs) override; + void addRawInput( char** groups, const SelectivityVector& rows, diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index e04d34e69eeb..e236e32d1378 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -55,7 +55,7 @@ class SimpleAggregateAdapter : public Aggregate { const TypePtr& resultType, const std::vector& constantInputs) override { if constexpr (support_initialize_function_state_) { - FUNC::initialize(state_, rawInputTypes, resultType, constantInputs); + FUNC::initialize(state_, step, rawInputTypes, resultType, constantInputs); } } diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 92074f9865f9..f24d4bd4af22 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -15,12 +15,13 @@ */ #include "velox/exec/SimpleAggregateAdapter.h" -#include "velox/exec/Aggregate.h" #include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; +using facebook::velox::common::testutil::TestValue; using facebook::velox::functions::aggregate::test::AggregationTestBase; namespace facebook::velox::aggregate::test { @@ -29,6 +30,7 @@ namespace { const char* const kSimpleAvg = "simple_avg"; const char* const kSimpleArrayAgg = "simple_array_agg"; const char* const kSimpleCountNulls = "simple_count_nulls"; +const char* const kSimpleFunctionStateAgg = "simple_function_state_agg"; class SimpleAverageAggregationTest : public AggregationTestBase { protected: @@ -481,5 +483,234 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected}); } +// A testing aggregation function that uses the function state. +class FunctionStateTestAggregate { + public: + using InputType = Row; // Input vector type wrapped in Row. + using IntermediateType = int64_t; // Intermediate result type. + using OutputType = int64_t; // Output vector type. + + struct FunctionState { + core::AggregationNode::Step step; + std::vector rawInputType; + TypePtr resultType; + std::vector constantInputs; + }; + + static void initialize( + FunctionState& state, + core::AggregationNode::Step step, + const std::vector& rawInputTypes, + const TypePtr& resultType, + const std::vector& constantInputs) { + state.step = step; + state.rawInputType = rawInputTypes; + state.resultType = resultType; + if (resultType == nullptr) { + LOG(INFO) << "nullptr"; + } + state.constantInputs = constantInputs; + } + + struct Accumulator { + int64_t sum{0}; + + explicit Accumulator( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} + + void checkpoint(FunctionState* state) { + TestValue::adjust( + "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", + state); + } + + void addInput( + HashStringAllocator* /*allocator*/, + exec::arg_type data, + const FunctionState& state) { + checkpoint(const_cast(&state)); + sum += data; + } + + void combine( + HashStringAllocator* /*allocator*/, + exec::arg_type other, + const FunctionState& state) { + checkpoint(const_cast(&state)); + sum += other; + } + + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& state) { + checkpoint(const_cast(&state)); + out = sum; + return true; + } + + bool writeFinalResult( + exec::out_type& out, + const FunctionState& state) { + checkpoint(const_cast(&state)); + out = sum; + return true; + } + }; + + using AccumulatorType = Accumulator; +}; + +exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( + const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .returnType("bigint") + .intermediateType("bigint") + .argumentType("bigint") + .build()}; + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step /*step*/, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_LE( + argTypes.size(), 1, "{} takes at most one argument", name); + return std::make_unique< + SimpleAggregateAdapter>(resultType); + }, + true /*registerCompanionFunctions*/, + true /*overwrite*/); +} + +void registerFunctionStateTestAggregate() { + registerSimpleFunctionStateTestAggregate(kSimpleFunctionStateAgg); +} + +class SimpleFunctionStateAggregationTest : public AggregationTestBase { + protected: + SimpleFunctionStateAggregationTest() { + registerFunctionStateTestAggregate(); + } + + static void checkState( + FunctionStateTestAggregate::FunctionState* state, + const std::string& step = "") { + VELOX_CHECK_NOT_NULL(state->resultType); + VELOX_CHECK(!state->rawInputType.empty()); + if (!step.empty()) { + VELOX_CHECK_EQ(core::AggregationNode::stepName(state->step), step); + } + } +}; + +TEST_F(SimpleFunctionStateAggregationTest, aggregate) { + auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); + std::vector sum = {10}; + auto expected = makeRowVector({makeFlatVector(sum)}); + + SCOPED_TESTVALUE_SET( + "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", + std::function( + [&](FunctionStateTestAggregate::FunctionState* state) { + checkState(state); + })); + + testAggregations( + {inputVectors}, {}, {"simple_function_state_agg(c0)"}, {expected}); + testAggregationsWithCompanion( + {inputVectors}, + [](auto& /*builder*/) {}, + {}, + {"simple_function_state_agg(c0)"}, + {{BIGINT()}}, + {}, + {expected}, + {}); +} + +TEST_F(SimpleFunctionStateAggregationTest, window) { + auto inputVectors = + makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); + auto expected = + makeRowVector({makeFlatVector({2, 2, 4, 4, 6, 6, 4})}); + SCOPED_TESTVALUE_SET( + "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", + std::function( + [&](FunctionStateTestAggregate::FunctionState* state) { + checkState(state, "SINGLE"); + })); + auto plan = + PlanBuilder() + .values({inputVectors}) + .window({"simple_function_state_agg(c0) over (partition by c0)"}) + .project({"w0"}) + .planNode(); + AssertQueryBuilder(plan).assertResults(expected); +} + +TEST_F(SimpleFunctionStateAggregationTest, aggregateStep) { + auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); + std::vector sum = {10}; + auto expected = makeRowVector({makeFlatVector(sum)}); + SCOPED_TESTVALUE_SET( + "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", + std::function( + [&](FunctionStateTestAggregate::FunctionState* state) { + checkState(state, "PARTIAL"); + })); + AssertQueryBuilder( + PlanBuilder() + .values({inputVectors}) + .singleAggregation({}, {"simple_function_state_agg_partial(c0)"}) + .planNode()) + .assertResults(expected); + + SCOPED_TESTVALUE_SET( + "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", + std::function( + [&](FunctionStateTestAggregate::FunctionState* state) { + checkState(state, "INTERMEDIATE"); + })); + AssertQueryBuilder( + PlanBuilder() + .values({inputVectors}) + .singleAggregation({}, {"simple_function_state_agg_merge(c0)"}) + .planNode()) + .assertResults(expected); + + SCOPED_TESTVALUE_SET( + "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", + std::function( + [&](FunctionStateTestAggregate::FunctionState* state) { + checkState(state, "INTERMEDIATE"); + })); + AssertQueryBuilder( + PlanBuilder() + .values({inputVectors}) + .singleAggregation( + {}, {"simple_function_state_agg_merge_extract(c0)"}) + .planNode()) + .assertResults(expected); + + SCOPED_TESTVALUE_SET( + "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", + std::function( + [&](FunctionStateTestAggregate::FunctionState* state) { + checkState(state, "FINAL"); + })); + AssertQueryBuilder( + PlanBuilder() + .values({inputVectors}) + .finalAggregation({}, {"simple_function_state_agg(c0)"}, {{BIGINT()}}) + .planNode()) + .assertResults(expected); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp index 16c0ed4f5dad..5b3dd3382dc3 100644 --- a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp +++ b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp @@ -1256,6 +1256,11 @@ void AggregationTestBase::testIncrementalAggregation( std::vector group(kOffset + func->accumulatorFixedWidthSize()); std::vector groups(inputSize, group.data()); std::vector indices(1, 0); + func->initialize( + aggregationNode.step(), + aggregate.rawInputTypes, + func->resultType(), + {}); func->initializeNewGroups(groups.data(), indices); func->addSingleGroupRawInput( group.data(), SelectivityVector(inputSize), input, false); @@ -1308,6 +1313,11 @@ VectorPtr AggregationTestBase::testStreaming( std::vector group(kOffset + func->accumulatorFixedWidthSize()); std::vector groups(maxRowCount, group.data()); std::vector indices(maxRowCount, 0); + func->initialize( + core::AggregationNode::Step::kSingle, + rawInputTypes, + func->resultType(), + {}); func->initializeNewGroups(groups.data(), indices); if (testGlobal) { func->addSingleGroupRawInput( @@ -1327,6 +1337,11 @@ VectorPtr AggregationTestBase::testStreaming( // Create a new function picking up the intermediate result. auto func2 = createAggregateFunction(functionName, rawInputTypes, allocator, config); + func2->initialize( + core::AggregationNode::Step::kSingle, + rawInputTypes, + func2->resultType(), + {}); func2->initializeNewGroups(groups.data(), indices); if (testGlobal) { func2->addSingleGroupIntermediateResults( From 75a908b3915e88a42c0349691d95720ca713209d Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Wed, 17 Apr 2024 20:21:18 +0800 Subject: [PATCH 4/8] Optimize test cases Fix --- velox/exec/AggregateCompanionAdapter.cpp | 19 +- .../exec/tests/SimpleAggregateAdapterTest.cpp | 214 +++++++++++++----- .../tests/utils/AggregationTestBase.cpp | 43 +++- .../aggregates/RegrReplacementAggregate.cpp | 20 +- 4 files changed, 215 insertions(+), 81 deletions(-) diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 64ee49fe127b..392111e992cc 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -256,14 +256,21 @@ void AggregateCompanionAdapter::ExtractFunction::apply( // Get the raw input types. std::vector rawInputTypes{args.size()}; - std::transform( - args.begin(), - args.end(), - rawInputTypes.begin(), - [](const VectorPtr& arg) { return arg->type(); }); + std::vector constantInputs{args.size()}; + for (auto i = 0; i < args.size(); i++) { + rawInputTypes[i] = args[i]->type(); + if (args[i]->isConstantEncoding()) { + constantInputs[i] = args[i]; + } else { + constantInputs[i] = nullptr; + } + } fn_->initialize( - core::AggregationNode::Step::kFinal, rawInputTypes, outputType, {}); + core::AggregationNode::Step::kFinal, + rawInputTypes, + outputType, + constantInputs); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index f24d4bd4af22..3ae60fcb15b4 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -486,13 +486,13 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { // A testing aggregation function that uses the function state. class FunctionStateTestAggregate { public: - using InputType = Row; // Input vector type wrapped in Row. - using IntermediateType = int64_t; // Intermediate result type. - using OutputType = int64_t; // Output vector type. + using InputType = Row; // Input vector type wrapped in Row. + using IntermediateType = Row; // Intermediate result type. + using OutputType = double; // Output vector type. struct FunctionState { core::AggregationNode::Step step; - std::vector rawInputType; + std::vector rawInputTypes; TypePtr resultType; std::vector constantInputs; }; @@ -504,16 +504,14 @@ class FunctionStateTestAggregate { const TypePtr& resultType, const std::vector& constantInputs) { state.step = step; - state.rawInputType = rawInputTypes; + state.rawInputTypes = rawInputTypes; state.resultType = resultType; - if (resultType == nullptr) { - LOG(INFO) << "nullptr"; - } state.constantInputs = constantInputs; } struct Accumulator { int64_t sum{0}; + double count{0}; explicit Accumulator( HashStringAllocator* /*allocator*/, @@ -528,9 +526,11 @@ class FunctionStateTestAggregate { void addInput( HashStringAllocator* /*allocator*/, exec::arg_type data, + exec::arg_type increment, const FunctionState& state) { checkpoint(const_cast(&state)); sum += data; + count += increment; } void combine( @@ -538,14 +538,17 @@ class FunctionStateTestAggregate { exec::arg_type other, const FunctionState& state) { checkpoint(const_cast(&state)); - sum += other; + VELOX_CHECK(other.at<0>().has_value()); + VELOX_CHECK(other.at<1>().has_value()); + sum += other.at<0>().value(); + count += other.at<1>().value(); } bool writeIntermediateResult( exec::out_type& out, const FunctionState& state) { checkpoint(const_cast(&state)); - out = sum; + out = std::make_tuple(sum, count); return true; } @@ -553,7 +556,7 @@ class FunctionStateTestAggregate { exec::out_type& out, const FunctionState& state) { checkpoint(const_cast(&state)); - out = sum; + out = sum / count; return true; } }; @@ -565,9 +568,10 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( const std::string& name) { std::vector> signatures{ exec::AggregateFunctionSignatureBuilder() - .returnType("bigint") - .intermediateType("bigint") - .argumentType("bigint") + .returnType("DOUBLE") + .intermediateType("ROW(BIGINT, DOUBLE)") + .argumentType("BIGINT") + .argumentType("BIGINT") .build()}; return exec::registerAggregateFunction( @@ -579,8 +583,7 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( const TypePtr& resultType, const core::QueryConfig& /*config*/) -> std::unique_ptr { - VELOX_CHECK_LE( - argTypes.size(), 1, "{} takes at most one argument", name); + VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name); return std::make_unique< SimpleAggregateAdapter>(resultType); }, @@ -598,97 +601,197 @@ class SimpleFunctionStateAggregationTest : public AggregationTestBase { registerFunctionStateTestAggregate(); } + static void checkRowTypeEqual(TypePtr expected, TypePtr actual) { + VELOX_CHECK(expected->isRow()); + VELOX_CHECK(actual->isRow()); + VELOX_CHECK_EQ(expected->asRow().size(), actual->asRow().size()); + for (auto i = 0; i < expected->asRow().size(); i++) { + VELOX_CHECK_EQ(expected->asRow().childAt(i), actual->asRow().childAt(i)); + } + } + static void checkState( FunctionStateTestAggregate::FunctionState* state, - const std::string& step = "") { + const std::vector& rawInputTypes, + const TypePtr& intermediateType, + const TypePtr& resultType, + const std::vector& constantInputs) { + VELOX_CHECK(!state->rawInputTypes.empty()); VELOX_CHECK_NOT_NULL(state->resultType); - VELOX_CHECK(!state->rawInputType.empty()); - if (!step.empty()) { - VELOX_CHECK_EQ(core::AggregationNode::stepName(state->step), step); + + switch (state->step) { + case core::AggregationNode::Step::kPartial: + case core::AggregationNode::Step::kIntermediate: + if (state->rawInputTypes.size() == 1 && + state->rawInputTypes[0]->isRow()) { + // Merge or merge_extract companion function. + VELOX_CHECK_EQ(rawInputTypes.size(), 1); + VELOX_CHECK(rawInputTypes[0]->isRow()); + checkRowTypeEqual(rawInputTypes[0], state->rawInputTypes[0]); + } else { + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + state->rawInputTypes.begin(), + state->rawInputTypes.end())); + } + if (state->resultType->isRow()) { + checkRowTypeEqual(intermediateType, state->resultType); + } else { + VELOX_CHECK_EQ(resultType, state->resultType) + } + break; + + case core::AggregationNode::Step::kSingle: + case core::AggregationNode::Step::kFinal: + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + state->rawInputTypes.begin(), + state->rawInputTypes.end())); + VELOX_CHECK_EQ(resultType, state->resultType); + break; + + default: + VELOX_FAIL("Unknown aggregate step"); + break; + } + + VELOX_CHECK(!state->constantInputs.empty()); + if (state->step == core::AggregationNode::Step::kPartial || + state->step == core::AggregationNode::Step::kSingle) { + VELOX_CHECK_EQ(constantInputs.size(), state->constantInputs.size()); + for (auto i = 0; i < constantInputs.size(); i++) { + auto expected = constantInputs[i]; + auto actual = state->constantInputs[i]; + if (expected == nullptr && actual == nullptr) { + continue; + } else { + VELOX_CHECK(expected != nullptr && actual != nullptr); + VELOX_CHECK(expected->isConstantEncoding()); + VELOX_CHECK(actual->isConstantEncoding()); + VELOX_CHECK_EQ( + expected->asUnchecked>()->valueAt(0), + actual->asUnchecked>()->valueAt(0)); + } + } + } else { + VELOX_CHECK_EQ(state->constantInputs.size(), 1); + VELOX_CHECK_NULL(state->constantInputs[0]); } } }; TEST_F(SimpleFunctionStateAggregationTest, aggregate) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); - std::vector sum = {10}; - auto expected = makeRowVector({makeFlatVector(sum)}); + std::vector finalResult = {2.5}; + auto expected = makeRowVector({makeFlatVector(finalResult)}); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state); + checkState( + state, + {BIGINT(), BIGINT()}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 4)}); })); - testAggregations( - {inputVectors}, {}, {"simple_function_state_agg(c0)"}, {expected}); - testAggregationsWithCompanion( - {inputVectors}, - [](auto& /*builder*/) {}, - {}, - {"simple_function_state_agg(c0)"}, - {{BIGINT()}}, - {}, - {expected}, - {}); + {inputVectors}, {}, {"simple_function_state_agg(c0, 1)"}, {expected}); } TEST_F(SimpleFunctionStateAggregationTest, window) { auto inputVectors = makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); - auto expected = - makeRowVector({makeFlatVector({2, 2, 4, 4, 6, 6, 4})}); + auto expected = makeRowVector( + {makeFlatVector({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0})}); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "SINGLE"); + checkState( + state, + {BIGINT(), BIGINT()}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 7)}); })); auto plan = PlanBuilder() .values({inputVectors}) - .window({"simple_function_state_agg(c0) over (partition by c0)"}) + .window({"simple_function_state_agg(c0, 1) over (partition by c0)"}) .project({"w0"}) .planNode(); AssertQueryBuilder(plan).assertResults(expected); } -TEST_F(SimpleFunctionStateAggregationTest, aggregateStep) { +TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); - std::vector sum = {10}; - auto expected = makeRowVector({makeFlatVector(sum)}); + std::vector accSum = {10}; + std::vector accCount = {4.0}; + auto intermediateExpected = makeRowVector({ + makeRowVector({ + makeFlatVector(accSum), + makeFlatVector(accCount), + }), + }); + std::vector finalResult = {2.5}; + auto finalExpected = makeRowVector({makeFlatVector(finalResult)}); + SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "PARTIAL"); + checkState( + state, + {BIGINT(), BIGINT()}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 4)}); })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) - .singleAggregation({}, {"simple_function_state_agg_partial(c0)"}) + .singleAggregation({}, {"simple_function_state_agg_partial(c0, 1)"}) .planNode()) - .assertResults(expected); - + .assertResults(intermediateExpected); + + inputVectors = makeRowVector({ + makeRowVector({ + makeFlatVector({1, 2, 3, 4}), + makeFlatVector({1.0, 1.0, 1.0, 1.0}), + }), + }); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "INTERMEDIATE"); + checkState( + state, + {ROW({BIGINT(), DOUBLE()})}, + ROW({BIGINT(), DOUBLE()}), + ROW({BIGINT(), DOUBLE()}), + {nullptr, makeConstant(1, 4)}); })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) .singleAggregation({}, {"simple_function_state_agg_merge(c0)"}) .planNode()) - .assertResults(expected); + .assertResults(intermediateExpected); SCOPED_TESTVALUE_SET( "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", std::function( [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "INTERMEDIATE"); + checkState( + state, + {ROW({BIGINT(), DOUBLE()})}, + ROW({BIGINT(), DOUBLE()}), + DOUBLE(), + {nullptr, makeConstant(1, 4)}); })); AssertQueryBuilder( PlanBuilder() @@ -696,20 +799,7 @@ TEST_F(SimpleFunctionStateAggregationTest, aggregateStep) { .singleAggregation( {}, {"simple_function_state_agg_merge_extract(c0)"}) .planNode()) - .assertResults(expected); - - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState(state, "FINAL"); - })); - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .finalAggregation({}, {"simple_function_state_agg(c0)"}, {{BIGINT()}}) - .planNode()) - .assertResults(expected); + .assertResults(finalExpected); } } // namespace diff --git a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp index 5b3dd3382dc3..2a0e0e82a60c 100644 --- a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp +++ b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp @@ -1231,13 +1231,27 @@ void AggregationTestBase::testIncrementalAggregation( const auto& aggregateExpr = aggregate.call; const auto& functionName = aggregateExpr->name(); auto input = extractArgColumns(aggregateExpr, data, pool()); + const auto& inputType = aggregationNode.sources()[0]->outputType(); HashStringAllocator allocator(pool()); std::vector lambdas; + std::vector constantInputs; for (const auto& arg : aggregate.call->inputs()) { - if (auto lambda = + if (dynamic_cast(arg.get())) { + constantInputs.push_back(nullptr); + } else if ( + auto constant = + dynamic_cast(arg.get())) { + constantInputs.push_back(constant->toConstantVector(pool())); + } else if ( + auto lambda = std::dynamic_pointer_cast(arg)) { lambdas.push_back(lambda); + for (const auto& name : lambda->signature()->names()) { + if (auto captureIndex = inputType->getChildIdxIfExists(name)) { + constantInputs.push_back(nullptr); + } + } } } auto queryCtxConfig = config; @@ -1260,7 +1274,7 @@ void AggregationTestBase::testIncrementalAggregation( aggregationNode.step(), aggregate.rawInputTypes, func->resultType(), - {}); + constantInputs); func->initializeNewGroups(groups.data(), indices); func->addSingleGroupRawInput( group.data(), SelectivityVector(inputSize), input, false); @@ -1300,11 +1314,24 @@ VectorPtr AggregationTestBase::testStreaming( vector_size_t rawInput2Size, const std::unordered_map& config) { std::vector rawInputTypes(rawInput1.size()); + std::vector constantInputs1(rawInput1.size()); + std::vector constantInputs2(rawInput2.size()); + for (auto i = 0; i < rawInput1.size(); i++) { + rawInputTypes[i] = rawInput1[i]->type(); + if (rawInput1[i]->isConstantEncoding()) { + constantInputs1[i] = rawInput1[i]; + } else { + constantInputs1[i] = nullptr; + } + } + std::transform( - rawInput1.begin(), - rawInput1.end(), - rawInputTypes.begin(), - [](const VectorPtr& vec) { return vec->type(); }); + rawInput2.begin(), + rawInput2.end(), + constantInputs2.begin(), + [](const VectorPtr& vec) { + return vec->isConstantEncoding() ? vec : nullptr; + }); HashStringAllocator allocator(pool()); auto func = @@ -1317,7 +1344,7 @@ VectorPtr AggregationTestBase::testStreaming( core::AggregationNode::Step::kSingle, rawInputTypes, func->resultType(), - {}); + constantInputs1); func->initializeNewGroups(groups.data(), indices); if (testGlobal) { func->addSingleGroupRawInput( @@ -1341,7 +1368,7 @@ VectorPtr AggregationTestBase::testStreaming( core::AggregationNode::Step::kSingle, rawInputTypes, func2->resultType(), - {}); + constantInputs2); func2->initializeNewGroups(groups.data(), indices); if (testGlobal) { func2->addSingleGroupIntermediateResults( diff --git a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp index 85f69d1f7c3f..5b5e63b9bea8 100644 --- a/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp +++ b/velox/functions/sparksql/aggregates/RegrReplacementAggregate.cpp @@ -29,6 +29,8 @@ class RegrReplacementAggregate { /*m2*/ double>; using OutputType = double; + struct FunctionState {}; + static bool toIntermediate( exec::out_type>& out, exec::arg_type in) { @@ -41,11 +43,14 @@ class RegrReplacementAggregate { double avg{0.0}; double m2{0.0}; - explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + explicit AccumulatorType( + HashStringAllocator* /*allocator*/, + const FunctionState& /*state*/) {} void addInput( HashStringAllocator* /*allocator*/, - exec::arg_type data) { + exec::arg_type data, + const FunctionState& /*state*/) { n += 1.0; double delta = data - avg; double deltaN = delta / n; @@ -55,7 +60,8 @@ class RegrReplacementAggregate { void combine( HashStringAllocator* /*allocator*/, - exec::arg_type> other) { + exec::arg_type> other, + const FunctionState& /*state*/) { VELOX_CHECK(other.at<0>().has_value()); VELOX_CHECK(other.at<1>().has_value()); VELOX_CHECK(other.at<2>().has_value()); @@ -72,12 +78,16 @@ class RegrReplacementAggregate { m2 += otherM2 + delta * deltaN * originN * otherN; } - bool writeIntermediateResult(exec::out_type& out) { + bool writeIntermediateResult( + exec::out_type& out, + const FunctionState& /*state*/) { out = std::make_tuple(n, avg, m2); return true; } - bool writeFinalResult(exec::out_type& out) { + bool writeFinalResult( + exec::out_type& out, + const FunctionState& /*state*/) { if (n == 0.0) { return false; } From cb89b749e8f31a2f42a931a80e0420b71782321b Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Sat, 20 Apr 2024 16:14:24 +0800 Subject: [PATCH 5/8] Address comments --- velox/exec/tests/CMakeLists.txt | 9 +- .../exec/tests/SimpleAggregateAdapterTest.cpp | 276 +++++++----------- .../tests/utils/AggregationTestBase.cpp | 52 +--- 3 files changed, 117 insertions(+), 220 deletions(-) diff --git a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index ddfb25743d23..3cc4f1aee144 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -244,8 +244,13 @@ add_executable(velox_simple_aggregate_test SimpleAggregateAdapterTest.cpp Main.cpp) target_link_libraries( - velox_simple_aggregate_test velox_simple_aggregate velox_exec - velox_functions_aggregates_test_lib gtest gtest_main) + velox_simple_aggregate_test + velox_simple_aggregate + velox_exec + velox_functions_aggregates_test_lib + velox_functions_window_test_lib + gtest + gtest_main) add_library(velox_spiller_join_benchmark_base JoinSpillInputBenchmarkBase.cpp SpillerBenchmarkBase.cpp) diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 3ae60fcb15b4..bf2a5bbf1e8a 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -18,11 +18,13 @@ #include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" +#include "velox/functions/lib/window/tests/WindowTestBase.h" using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; using facebook::velox::common::testutil::TestValue; using facebook::velox::functions::aggregate::test::AggregationTestBase; +using facebook::velox::window::test::WindowTestBase; namespace facebook::velox::aggregate::test { namespace { @@ -30,7 +32,6 @@ namespace { const char* const kSimpleAvg = "simple_avg"; const char* const kSimpleArrayAgg = "simple_array_agg"; const char* const kSimpleCountNulls = "simple_count_nulls"; -const char* const kSimpleFunctionStateAgg = "simple_function_state_agg"; class SimpleAverageAggregationTest : public AggregationTestBase { protected: @@ -484,6 +485,7 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { } // A testing aggregation function that uses the function state. +template class FunctionStateTestAggregate { public: using InputType = Row; // Input vector type wrapped in Row. @@ -497,12 +499,60 @@ class FunctionStateTestAggregate { std::vector constantInputs; }; + static void checkConstantInputs( + const std::vector& constantInputs) { + // Check that the constantInputs is {nullptr, 1} + VELOX_CHECK_EQ(constantInputs.size(), 2); + VELOX_CHECK_NULL(constantInputs[0]); + VELOX_CHECK(constantInputs[1]->isConstantEncoding()); + VELOX_CHECK_EQ( + constantInputs[1] + ->template asUnchecked>() + ->valueAt(0), + 1); + } + static void initialize( FunctionState& state, core::AggregationNode::Step step, const std::vector& rawInputTypes, const TypePtr& resultType, const std::vector& constantInputs) { + auto expectedRawInputTypes = {BIGINT(), BIGINT()}; + auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()}); + + if constexpr (testCompanion) { + if (step == core::AggregationNode::Step::kPartial) { + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + expectedRawInputTypes.begin(), + expectedRawInputTypes.end())); + checkConstantInputs(constantInputs); + } else if (step == core::AggregationNode::Step::kIntermediate) { + VELOX_CHECK_EQ(rawInputTypes.size(), 1); + VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType)); + + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } else { + VELOX_FAIL("Unexpected aggregation step"); + } + } else { + VELOX_CHECK(std::equal( + rawInputTypes.begin(), + rawInputTypes.end(), + expectedRawInputTypes.begin(), + expectedRawInputTypes.end())); + if (step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle) { + checkConstantInputs(constantInputs); + } else { + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } + } + state.step = step; state.rawInputTypes = rawInputTypes; state.resultType = resultType; @@ -517,27 +567,24 @@ class FunctionStateTestAggregate { HashStringAllocator* /*allocator*/, const FunctionState& /*state*/) {} - void checkpoint(FunctionState* state) { - TestValue::adjust( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - state); - } - void addInput( HashStringAllocator* /*allocator*/, exec::arg_type data, - exec::arg_type increment, + exec::arg_type /*increment*/, const FunctionState& state) { - checkpoint(const_cast(&state)); + VELOX_CHECK_EQ(state.constantInputs.size(), 2); + VELOX_CHECK(state.constantInputs[1]->isConstantEncoding()); + auto constant = state.constantInputs[1] + ->template asUnchecked>() + ->valueAt(0); sum += data; - count += increment; + count += constant; } void combine( HashStringAllocator* /*allocator*/, exec::arg_type other, const FunctionState& state) { - checkpoint(const_cast(&state)); VELOX_CHECK(other.at<0>().has_value()); VELOX_CHECK(other.at<1>().has_value()); sum += other.at<0>().value(); @@ -547,7 +594,6 @@ class FunctionStateTestAggregate { bool writeIntermediateResult( exec::out_type& out, const FunctionState& state) { - checkpoint(const_cast(&state)); out = std::make_tuple(sum, count); return true; } @@ -555,7 +601,6 @@ class FunctionStateTestAggregate { bool writeFinalResult( exec::out_type& out, const FunctionState& state) { - checkpoint(const_cast(&state)); out = sum / count; return true; } @@ -564,6 +609,7 @@ class FunctionStateTestAggregate { using AccumulatorType = Accumulator; }; +template exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( const std::string& name) { std::vector> signatures{ @@ -585,100 +631,28 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( -> std::unique_ptr { VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name); return std::make_unique< - SimpleAggregateAdapter>(resultType); + SimpleAggregateAdapter>>( + resultType); }, true /*registerCompanionFunctions*/, true /*overwrite*/); } void registerFunctionStateTestAggregate() { - registerSimpleFunctionStateTestAggregate(kSimpleFunctionStateAgg); + registerSimpleFunctionStateTestAggregate( + "simple_function_state_agg_main"); + registerSimpleFunctionStateTestAggregate( + "simple_function_state_agg_companion"); } class SimpleFunctionStateAggregationTest : public AggregationTestBase { protected: - SimpleFunctionStateAggregationTest() { - registerFunctionStateTestAggregate(); - } - - static void checkRowTypeEqual(TypePtr expected, TypePtr actual) { - VELOX_CHECK(expected->isRow()); - VELOX_CHECK(actual->isRow()); - VELOX_CHECK_EQ(expected->asRow().size(), actual->asRow().size()); - for (auto i = 0; i < expected->asRow().size(); i++) { - VELOX_CHECK_EQ(expected->asRow().childAt(i), actual->asRow().childAt(i)); - } - } - - static void checkState( - FunctionStateTestAggregate::FunctionState* state, - const std::vector& rawInputTypes, - const TypePtr& intermediateType, - const TypePtr& resultType, - const std::vector& constantInputs) { - VELOX_CHECK(!state->rawInputTypes.empty()); - VELOX_CHECK_NOT_NULL(state->resultType); - - switch (state->step) { - case core::AggregationNode::Step::kPartial: - case core::AggregationNode::Step::kIntermediate: - if (state->rawInputTypes.size() == 1 && - state->rawInputTypes[0]->isRow()) { - // Merge or merge_extract companion function. - VELOX_CHECK_EQ(rawInputTypes.size(), 1); - VELOX_CHECK(rawInputTypes[0]->isRow()); - checkRowTypeEqual(rawInputTypes[0], state->rawInputTypes[0]); - } else { - VELOX_CHECK(std::equal( - rawInputTypes.begin(), - rawInputTypes.end(), - state->rawInputTypes.begin(), - state->rawInputTypes.end())); - } - if (state->resultType->isRow()) { - checkRowTypeEqual(intermediateType, state->resultType); - } else { - VELOX_CHECK_EQ(resultType, state->resultType) - } - break; - - case core::AggregationNode::Step::kSingle: - case core::AggregationNode::Step::kFinal: - VELOX_CHECK(std::equal( - rawInputTypes.begin(), - rawInputTypes.end(), - state->rawInputTypes.begin(), - state->rawInputTypes.end())); - VELOX_CHECK_EQ(resultType, state->resultType); - break; - - default: - VELOX_FAIL("Unknown aggregate step"); - break; - } + void SetUp() override { + AggregationTestBase::SetUp(); - VELOX_CHECK(!state->constantInputs.empty()); - if (state->step == core::AggregationNode::Step::kPartial || - state->step == core::AggregationNode::Step::kSingle) { - VELOX_CHECK_EQ(constantInputs.size(), state->constantInputs.size()); - for (auto i = 0; i < constantInputs.size(); i++) { - auto expected = constantInputs[i]; - auto actual = state->constantInputs[i]; - if (expected == nullptr && actual == nullptr) { - continue; - } else { - VELOX_CHECK(expected != nullptr && actual != nullptr); - VELOX_CHECK(expected->isConstantEncoding()); - VELOX_CHECK(actual->isConstantEncoding()); - VELOX_CHECK_EQ( - expected->asUnchecked>()->valueAt(0), - actual->asUnchecked>()->valueAt(0)); - } - } - } else { - VELOX_CHECK_EQ(state->constantInputs.size(), 1); - VELOX_CHECK_NULL(state->constantInputs[0]); - } + disableTestIncremental(); + disableTestStreaming(); + registerFunctionStateTestAggregate(); } }; @@ -686,48 +660,14 @@ TEST_F(SimpleFunctionStateAggregationTest, aggregate) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); std::vector finalResult = {2.5}; auto expected = makeRowVector({makeFlatVector(finalResult)}); - - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {BIGINT(), BIGINT()}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 4)}); - })); testAggregations( - {inputVectors}, {}, {"simple_function_state_agg(c0, 1)"}, {expected}); -} - -TEST_F(SimpleFunctionStateAggregationTest, window) { - auto inputVectors = - makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); - auto expected = makeRowVector( - {makeFlatVector({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0})}); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {BIGINT(), BIGINT()}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 7)}); - })); - auto plan = - PlanBuilder() - .values({inputVectors}) - .window({"simple_function_state_agg(c0, 1) over (partition by c0)"}) - .project({"w0"}) - .planNode(); - AssertQueryBuilder(plan).assertResults(expected); + {inputVectors}, + {}, + {"simple_function_state_agg_main(c0, 1)"}, + {expected}); } -TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { +TEST_F(SimpleFunctionStateAggregationTest, companionAggregate) { auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); std::vector accSum = {10}; std::vector accCount = {4.0}; @@ -740,21 +680,11 @@ TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { std::vector finalResult = {2.5}; auto finalExpected = makeRowVector({makeFlatVector(finalResult)}); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {BIGINT(), BIGINT()}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 4)}); - })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) - .singleAggregation({}, {"simple_function_state_agg_partial(c0, 1)"}) + .singleAggregation( + {}, {"simple_function_state_agg_companion_partial(c0, 1)"}) .planNode()) .assertResults(intermediateExpected); @@ -764,43 +694,47 @@ TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) { makeFlatVector({1.0, 1.0, 1.0, 1.0}), }), }); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {ROW({BIGINT(), DOUBLE()})}, - ROW({BIGINT(), DOUBLE()}), - ROW({BIGINT(), DOUBLE()}), - {nullptr, makeConstant(1, 4)}); - })); + AssertQueryBuilder( PlanBuilder() .values({inputVectors}) - .singleAggregation({}, {"simple_function_state_agg_merge(c0)"}) + .singleAggregation( + {}, {"simple_function_state_agg_companion_merge(c0)"}) .planNode()) .assertResults(intermediateExpected); - SCOPED_TESTVALUE_SET( - "facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint", - std::function( - [&](FunctionStateTestAggregate::FunctionState* state) { - checkState( - state, - {ROW({BIGINT(), DOUBLE()})}, - ROW({BIGINT(), DOUBLE()}), - DOUBLE(), - {nullptr, makeConstant(1, 4)}); - })); AssertQueryBuilder( PlanBuilder() .values({inputVectors}) .singleAggregation( - {}, {"simple_function_state_agg_merge_extract(c0)"}) + {}, {"simple_function_state_agg_companion_merge_extract(c0)"}) .planNode()) .assertResults(finalExpected); } +class SimpleFunctionStateWindowTest : public WindowTestBase { + protected: + void SetUp() override { + WindowTestBase::SetUp(); + + registerFunctionStateTestAggregate(); + } +}; + +TEST_F(SimpleFunctionStateWindowTest, window) { + auto inputVectors = + makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); + auto expected = makeRowVector({ + makeFlatVector({1, 1, 2, 2, 3, 3, 4}), + makeFlatVector({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0}), + }); + WindowTestBase::testWindowFunction( + {inputVectors}, + "simple_function_state_agg_main(c0, 1)", + {"partition by c0"}, + {}, + expected); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp index 2a0e0e82a60c..16c0ed4f5dad 100644 --- a/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp +++ b/velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp @@ -1231,27 +1231,13 @@ void AggregationTestBase::testIncrementalAggregation( const auto& aggregateExpr = aggregate.call; const auto& functionName = aggregateExpr->name(); auto input = extractArgColumns(aggregateExpr, data, pool()); - const auto& inputType = aggregationNode.sources()[0]->outputType(); HashStringAllocator allocator(pool()); std::vector lambdas; - std::vector constantInputs; for (const auto& arg : aggregate.call->inputs()) { - if (dynamic_cast(arg.get())) { - constantInputs.push_back(nullptr); - } else if ( - auto constant = - dynamic_cast(arg.get())) { - constantInputs.push_back(constant->toConstantVector(pool())); - } else if ( - auto lambda = + if (auto lambda = std::dynamic_pointer_cast(arg)) { lambdas.push_back(lambda); - for (const auto& name : lambda->signature()->names()) { - if (auto captureIndex = inputType->getChildIdxIfExists(name)) { - constantInputs.push_back(nullptr); - } - } } } auto queryCtxConfig = config; @@ -1270,11 +1256,6 @@ void AggregationTestBase::testIncrementalAggregation( std::vector group(kOffset + func->accumulatorFixedWidthSize()); std::vector groups(inputSize, group.data()); std::vector indices(1, 0); - func->initialize( - aggregationNode.step(), - aggregate.rawInputTypes, - func->resultType(), - constantInputs); func->initializeNewGroups(groups.data(), indices); func->addSingleGroupRawInput( group.data(), SelectivityVector(inputSize), input, false); @@ -1314,24 +1295,11 @@ VectorPtr AggregationTestBase::testStreaming( vector_size_t rawInput2Size, const std::unordered_map& config) { std::vector rawInputTypes(rawInput1.size()); - std::vector constantInputs1(rawInput1.size()); - std::vector constantInputs2(rawInput2.size()); - for (auto i = 0; i < rawInput1.size(); i++) { - rawInputTypes[i] = rawInput1[i]->type(); - if (rawInput1[i]->isConstantEncoding()) { - constantInputs1[i] = rawInput1[i]; - } else { - constantInputs1[i] = nullptr; - } - } - std::transform( - rawInput2.begin(), - rawInput2.end(), - constantInputs2.begin(), - [](const VectorPtr& vec) { - return vec->isConstantEncoding() ? vec : nullptr; - }); + rawInput1.begin(), + rawInput1.end(), + rawInputTypes.begin(), + [](const VectorPtr& vec) { return vec->type(); }); HashStringAllocator allocator(pool()); auto func = @@ -1340,11 +1308,6 @@ VectorPtr AggregationTestBase::testStreaming( std::vector group(kOffset + func->accumulatorFixedWidthSize()); std::vector groups(maxRowCount, group.data()); std::vector indices(maxRowCount, 0); - func->initialize( - core::AggregationNode::Step::kSingle, - rawInputTypes, - func->resultType(), - constantInputs1); func->initializeNewGroups(groups.data(), indices); if (testGlobal) { func->addSingleGroupRawInput( @@ -1364,11 +1327,6 @@ VectorPtr AggregationTestBase::testStreaming( // Create a new function picking up the intermediate result. auto func2 = createAggregateFunction(functionName, rawInputTypes, allocator, config); - func2->initialize( - core::AggregationNode::Step::kSingle, - rawInputTypes, - func2->resultType(), - constantInputs2); func2->initializeNewGroups(groups.data(), indices); if (testGlobal) { func2->addSingleGroupIntermediateResults( From a40de6294c63a8d438e47ec6cd4622d3fbf95d14 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Tue, 23 Apr 2024 20:06:03 +0800 Subject: [PATCH 6/8] Add companionStep --- velox/docs/develop/aggregate-functions.rst | 3 +- velox/exec/Aggregate.h | 4 +- velox/exec/AggregateCompanionAdapter.cpp | 23 +++--- velox/exec/AggregateCompanionAdapter.h | 6 +- velox/exec/SimpleAggregateAdapter.h | 11 ++- .../exec/tests/SimpleAggregateAdapterTest.cpp | 74 +++++++------------ 6 files changed, 57 insertions(+), 64 deletions(-) diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index 024a43057837..b527480ce6ca 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -165,7 +165,8 @@ A simple aggregation function is implemented as a class as the following. FunctionState& state, const std::vector& rawInputTypes, const TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional companionStep) { state.resultType = resultType; } diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index 7267bd748a40..a27f22e144ac 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -139,7 +139,9 @@ class Aggregate { core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, - const std::vector& constantInputs) {} + const std::vector& constantInputs, + std::optional companionStep = std::nullopt) { + } // Initializes null flags and accumulators for newly encountered groups. This // function should be called only once for each group. diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 392111e992cc..cd65af40f453 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -125,15 +125,17 @@ void AggregateCompanionFunctionBase::extractAccumulators( } void AggregateCompanionAdapter::PartialFunction::initialize( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& rawInputType, const facebook::velox::TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional /*companionStep*/) { fn_->initialize( - core::AggregationNode::Step::kPartial, + step, rawInputType, resultType, - constantInputs); + constantInputs, + core::AggregationNode::Step::kPartial); } void AggregateCompanionAdapter::PartialFunction::extractValues( @@ -144,15 +146,17 @@ void AggregateCompanionAdapter::PartialFunction::extractValues( } void AggregateCompanionAdapter::MergeFunction::initialize( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& rawInputType, const facebook::velox::TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional /*companionStep*/) { fn_->initialize( - core::AggregationNode::Step::kIntermediate, + step, rawInputType, resultType, - constantInputs); + constantInputs, + core::AggregationNode::Step::kIntermediate); } void AggregateCompanionAdapter::MergeFunction::addRawInput( @@ -270,7 +274,8 @@ void AggregateCompanionAdapter::ExtractFunction::apply( core::AggregationNode::Step::kFinal, rawInputTypes, outputType, - constantInputs); + constantInputs, + core::AggregationNode::Step::kIntermediate); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 24b196e3f60e..d2943df53f97 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -103,7 +103,8 @@ struct AggregateCompanionAdapter { core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, - const std::vector& constantInputs) override; + const std::vector& constantInputs, + std::optional companionStep) override; void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; @@ -120,7 +121,8 @@ struct AggregateCompanionAdapter { core::AggregationNode::Step step, const std::vector& rawInputType, const TypePtr& resultType, - const std::vector& constantInputs) override; + const std::vector& constantInputs, + std::optional companionStep) override; void addRawInput( char** groups, diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index e236e32d1378..0826016e36e8 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -53,9 +53,16 @@ class SimpleAggregateAdapter : public Aggregate { core::AggregationNode::Step step, const std::vector& rawInputTypes, const TypePtr& resultType, - const std::vector& constantInputs) override { + const std::vector& constantInputs, + std::optional companionStep) override { if constexpr (support_initialize_function_state_) { - FUNC::initialize(state_, step, rawInputTypes, resultType, constantInputs); + FUNC::initialize( + state_, + step, + rawInputTypes, + resultType, + constantInputs, + companionStep); } } diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index bf2a5bbf1e8a..6f8b44451e69 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -497,6 +497,7 @@ class FunctionStateTestAggregate { std::vector rawInputTypes; TypePtr resultType; std::vector constantInputs; + core::AggregationNode::Step companionStep; }; static void checkConstantInputs( @@ -517,19 +518,29 @@ class FunctionStateTestAggregate { core::AggregationNode::Step step, const std::vector& rawInputTypes, const TypePtr& resultType, - const std::vector& constantInputs) { + const std::vector& constantInputs, + std::optional companionStep) { auto expectedRawInputTypes = {BIGINT(), BIGINT()}; auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()}); if constexpr (testCompanion) { - if (step == core::AggregationNode::Step::kPartial) { + VELOX_CHECK(companionStep.has_value()); + if (companionStep.value() == core::AggregationNode::Step::kPartial) { VELOX_CHECK(std::equal( rawInputTypes.begin(), rawInputTypes.end(), expectedRawInputTypes.begin(), expectedRawInputTypes.end())); - checkConstantInputs(constantInputs); - } else if (step == core::AggregationNode::Step::kIntermediate) { + if (step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle) { + // Only check constant inputs in partial and single step. + checkConstantInputs(constantInputs); + } else { + VELOX_CHECK_EQ(constantInputs.size(), 1); + VELOX_CHECK_NULL(constantInputs[0]); + } + } else if ( + companionStep.value() == core::AggregationNode::Step::kIntermediate) { VELOX_CHECK_EQ(rawInputTypes.size(), 1); VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType)); @@ -546,6 +557,7 @@ class FunctionStateTestAggregate { expectedRawInputTypes.end())); if (step == core::AggregationNode::Step::kPartial || step == core::AggregationNode::Step::kSingle) { + // Only check constant inputs in partial and single step. checkConstantInputs(constantInputs); } else { VELOX_CHECK_EQ(constantInputs.size(), 1); @@ -665,51 +677,15 @@ TEST_F(SimpleFunctionStateAggregationTest, aggregate) { {}, {"simple_function_state_agg_main(c0, 1)"}, {expected}); -} - -TEST_F(SimpleFunctionStateAggregationTest, companionAggregate) { - auto inputVectors = makeRowVector({makeFlatVector({1, 2, 3, 4})}); - std::vector accSum = {10}; - std::vector accCount = {4.0}; - auto intermediateExpected = makeRowVector({ - makeRowVector({ - makeFlatVector(accSum), - makeFlatVector(accCount), - }), - }); - std::vector finalResult = {2.5}; - auto finalExpected = makeRowVector({makeFlatVector(finalResult)}); - - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .singleAggregation( - {}, {"simple_function_state_agg_companion_partial(c0, 1)"}) - .planNode()) - .assertResults(intermediateExpected); - - inputVectors = makeRowVector({ - makeRowVector({ - makeFlatVector({1, 2, 3, 4}), - makeFlatVector({1.0, 1.0, 1.0, 1.0}), - }), - }); - - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .singleAggregation( - {}, {"simple_function_state_agg_companion_merge(c0)"}) - .planNode()) - .assertResults(intermediateExpected); - - AssertQueryBuilder( - PlanBuilder() - .values({inputVectors}) - .singleAggregation( - {}, {"simple_function_state_agg_companion_merge_extract(c0)"}) - .planNode()) - .assertResults(finalExpected); + testAggregationsWithCompanion( + {inputVectors}, + [](auto& /*builder*/) {}, + {}, + {"simple_function_state_agg_companion(c0, 1)"}, + {{BIGINT(), BIGINT()}}, + {}, + {expected}, + {}); } class SimpleFunctionStateWindowTest : public WindowTestBase { From 9072f7c3948aa89bf84411fb858ab460667cb986 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Wed, 24 Apr 2024 17:22:00 +0800 Subject: [PATCH 7/8] Address comments --- velox/exec/Aggregate.h | 11 ++++- velox/exec/AggregateCompanionAdapter.cpp | 16 ++++++- velox/exec/AggregateCompanionAdapter.h | 7 ++++ velox/exec/AggregateWindow.cpp | 4 ++ .../exec/tests/SimpleAggregateAdapterTest.cpp | 42 ++++++++++--------- 5 files changed, 57 insertions(+), 23 deletions(-) diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index a27f22e144ac..de3e6b529dac 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -133,8 +133,15 @@ class Aggregate { // UDAF. // @param step The aggregation step. // @param rawInputType The raw input type of the UDAF. - // @param resultType The result type of the UDAF. - // @param constantInputs Optional constant inputs. + // @param resultType The result type of the current aggregation step. + // @param constantInputs Optional constant input values for aggregate + // function. constantInputs should be empty if there are no constant inputs, + // aligned with inputTypes if there is at least one constant input, with + // non-constant inputs represented as nullptr, and must be instances of + // ConstantVector. + // @param companionStep The step used to register aggregate companion + // functions. kPartial for partial companion function, kIntermediate for merge + // and merge extract companion function. virtual void initialize( core::AggregationNode::Step step, const std::vector& rawInputType, diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index cd65af40f453..5b4135428bc2 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -184,6 +184,20 @@ void AggregateCompanionAdapter::MergeFunction::extractValues( fn_->extractAccumulators(groups, numGroups, result); } +void AggregateCompanionAdapter::MergeExtractFunction::initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const facebook::velox::TypePtr& resultType, + const std::vector& constantInputs, + std::optional /*companionStep*/) { + fn_->initialize( + step, + rawInputType, + resultType, + constantInputs, + core::AggregationNode::Step::kFinal); +} + void AggregateCompanionAdapter::MergeExtractFunction::extractValues( char** groups, int32_t numGroups, @@ -275,7 +289,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply( rawInputTypes, outputType, constantInputs, - core::AggregationNode::Step::kIntermediate); + core::AggregationNode::Step::kFinal); fn_->initializeNewGroups(groups, allSelectedRange); fn_->enableValidateIntermediateInputs(); fn_->addIntermediateResults(groups, rows, args, false); diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index d2943df53f97..6c316a74b888 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -147,6 +147,13 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : MergeFunction{std::move(fn), resultType} {} + void initialize( + core::AggregationNode::Step step, + const std::vector& rawInputType, + const TypePtr& resultType, + const std::vector& constantInputs, + std::optional companionStep) override; + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; diff --git a/velox/exec/AggregateWindow.cpp b/velox/exec/AggregateWindow.cpp index f184757ece54..83650f74322e 100644 --- a/velox/exec/AggregateWindow.cpp +++ b/velox/exec/AggregateWindow.cpp @@ -391,6 +391,10 @@ class AggregateWindowFunction : public exec::WindowFunction { std::vector argTypes_; std::vector argIndices_; std::vector argVectors_; + // Constant input values for aggregate function. it should be empty if there + // are no constant inputs, aligned with inputTypes if there is at least one + // constant input, with non-constant inputs represented as nullptr, and must + // be instances of ConstantVector. std::vector constantInputs_; // This is a single aggregate row needed by the aggregate function for its diff --git a/velox/exec/tests/SimpleAggregateAdapterTest.cpp b/velox/exec/tests/SimpleAggregateAdapterTest.cpp index 6f8b44451e69..18ea484db144 100644 --- a/velox/exec/tests/SimpleAggregateAdapterTest.cpp +++ b/velox/exec/tests/SimpleAggregateAdapterTest.cpp @@ -484,7 +484,10 @@ TEST_F(SimpleCountNullsAggregationTest, basic) { testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected}); } -// A testing aggregation function that uses the function state. +// A testing aggregate function calculates a weighted average by taking an +// int64_t input value and a constant weight value as input. The sum of all +// input values is divided by the sum of all weight values to produce the final +// result. It is used to check for expectations in initialize method. template class FunctionStateTestAggregate { public: @@ -520,17 +523,14 @@ class FunctionStateTestAggregate { const TypePtr& resultType, const std::vector& constantInputs, std::optional companionStep) { - auto expectedRawInputTypes = {BIGINT(), BIGINT()}; + std::vector expectedRawInputTypes = {BIGINT(), BIGINT()}; auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()}); if constexpr (testCompanion) { + // Check for companion functions. VELOX_CHECK(companionStep.has_value()); if (companionStep.value() == core::AggregationNode::Step::kPartial) { - VELOX_CHECK(std::equal( - rawInputTypes.begin(), - rawInputTypes.end(), - expectedRawInputTypes.begin(), - expectedRawInputTypes.end())); + VELOX_CHECK(rawInputTypes == expectedRawInputTypes); if (step == core::AggregationNode::Step::kPartial || step == core::AggregationNode::Step::kSingle) { // Only check constant inputs in partial and single step. @@ -540,7 +540,8 @@ class FunctionStateTestAggregate { VELOX_CHECK_NULL(constantInputs[0]); } } else if ( - companionStep.value() == core::AggregationNode::Step::kIntermediate) { + companionStep.value() == core::AggregationNode::Step::kIntermediate || + companionStep.value() == core::AggregationNode::Step::kFinal) { VELOX_CHECK_EQ(rawInputTypes.size(), 1); VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType)); @@ -550,11 +551,7 @@ class FunctionStateTestAggregate { VELOX_FAIL("Unexpected aggregation step"); } } else { - VELOX_CHECK(std::equal( - rawInputTypes.begin(), - rawInputTypes.end(), - expectedRawInputTypes.begin(), - expectedRawInputTypes.end())); + VELOX_CHECK(rawInputTypes == expectedRawInputTypes); if (step == core::AggregationNode::Step::kPartial || step == core::AggregationNode::Step::kSingle) { // Only check constant inputs in partial and single step. @@ -641,12 +638,12 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate( const TypePtr& resultType, const core::QueryConfig& /*config*/) -> std::unique_ptr { - VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name); + VELOX_CHECK_LE(argTypes.size(), 2, "{} takes at most 2 argument", name); return std::make_unique< SimpleAggregateAdapter>>( resultType); }, - true /*registerCompanionFunctions*/, + testCompanion, true /*overwrite*/); } @@ -698,15 +695,20 @@ class SimpleFunctionStateWindowTest : public WindowTestBase { }; TEST_F(SimpleFunctionStateWindowTest, window) { - auto inputVectors = - makeRowVector({makeFlatVector({1, 1, 2, 2, 3, 3, 4})}); + auto inputVectors = makeRowVector({ + makeFlatVector({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5}), + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), + }); + auto expected = makeRowVector({ - makeFlatVector({1, 1, 2, 2, 3, 3, 4}), - makeFlatVector({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0}), + inputVectors->childAt(0), + inputVectors->childAt(1), + makeFlatVector( + {2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 10.5, 10.5, 12.0}), }); WindowTestBase::testWindowFunction( {inputVectors}, - "simple_function_state_agg_main(c0, 1)", + "simple_function_state_agg_main(c1, 1)", {"partition by c0"}, {}, expected); From 28855c38cb025245d95c4961ccbab379a2aef2db Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Fri, 26 Apr 2024 11:02:04 +0800 Subject: [PATCH 8/8] Fix rst doc --- velox/docs/develop/aggregate-functions.rst | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index b527480ce6ca..e872ae5dda94 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -152,14 +152,16 @@ A simple aggregation function is implemented as a class as the following. using IntermediateType = Array>; using OutputType = Array>; - // If UDAF does not require the use of FunctionState, it is necessary - // to declare an empty FunctionState struct. + // Define a struct for function-level states. Even if the aggregation function + // doesn't use function-level states, it is still necessary to define an empty + // FunctionState struct. struct FunctionState { // Optional. TypePtr resultType; }; - // Optional. Used only when the UDAF needs to use FunctionState. + // Optional. Defined only when the aggregation function needs to use function-level states. + // This method is called once when the aggregation function is created. static void initialize( core::AggregationNode::Step step, FunctionState& state, @@ -188,13 +190,13 @@ one argument. This is needed for the SimpleAggregateAdapter to parse input types for arbitrary aggregation functions properly. A FunctionState struct needs to be declared in the simple aggregation function -class, it is used to hold the function-level variables that are typically -computed once and used at every row when adding inputs to accumulators or -extracting values from accumulators. For example, if the UDAF needs to get the -result type or the raw input type of the aggregaiton function, the author can -hold them in the FunctionState struct, and initialize them in the initialize() -method. If the UDAF does not require the use ofFunctionState, it is necessary -to declare an empty FunctionState struct. +class. FunctionState is initialized once when the aggregation function is +created and used at every row when adding inputs to accumulators or extracting +values from accumulators. For example, if the aggregation function needs to get +the result type or the raw input type of the aggregaiton function, the author +can hold them in the FunctionState struct, and initialize them in the +initialize() method. If the aggregation function does not require the use of +FunctionState, it is necessary to declare an empty FunctionState struct. The author can define an optional flag `default_null_behavior_` indicating whether the aggregation function has default-null behavior. This flag is true