diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 284c1ff47222..331a953c6362 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -20,6 +20,7 @@ #include "velox/exec/AggregateFunctionRegistry.h" #include "velox/exec/RowContainer.h" #include "velox/expression/SignatureBinder.h" +#include "velox/expression/SpecialFormRegistry.h" namespace facebook::velox::exec { @@ -420,40 +421,6 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction( overwrite); } -VectorFunctionFactory getVectorFunctionFactory( - const std::string& originalName) { - return [originalName]( - const std::string& name, - const std::vector& inputArgs, - const core::QueryConfig& config) - -> std::shared_ptr { - std::vector argTypes{inputArgs.size()}; - std::transform( - inputArgs.begin(), - inputArgs.end(), - argTypes.begin(), - [](auto inputArg) { return inputArg.type; }); - - auto resultType = resolveVectorFunction(name, argTypes); - if (!resultType) { - // TODO: limitation -- result type must be resolveable given - // intermediate type of the original UDAF. - VELOX_UNREACHABLE( - "Signatures whose result types are not resolvable given intermediate types should have been excluded."); - } - - if (auto func = getAggregateFunctionEntry(originalName)) { - auto fn = func->factory( - core::AggregationNode::Step::kFinal, argTypes, resultType, config); - VELOX_CHECK_NOT_NULL(fn); - return std::make_shared( - std::move(fn)); - } - VELOX_FAIL( - "Original aggregation function {} not found: {}", originalName, name); - }; -} - bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix( const std::string& originalName, const std::vector& signatures, @@ -468,15 +435,12 @@ bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix( continue; } - auto factory = getVectorFunctionFactory(originalName); - registered |= exec::registerStatefulVectorFunction( - CompanionSignatures::extractFunctionNameWithSuffix(originalName, type), - std::move(extractSignatures), - std::move(factory), - exec::VectorFunctionMetadataBuilder() - .defaultNullBehavior(false) - .build(), - overwrite); + auto functionName = + CompanionSignatures::extractFunctionNameWithSuffix(originalName, type); + registerFunctionCallToSpecialForm( + functionName, + std::make_unique(originalName, functionName)); + registered = true; } return registered; } @@ -497,13 +461,11 @@ bool CompanionFunctionsRegistrar::registerExtractFunction( return false; } - auto factory = getVectorFunctionFactory(originalName); - return exec::registerStatefulVectorFunction( - CompanionSignatures::extractFunctionName(originalName), - std::move(extractSignatures), - std::move(factory), - exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(), - overwrite); + auto functionName = CompanionSignatures::extractFunctionName(originalName); + registerFunctionCallToSpecialForm( + functionName, + std::make_unique(originalName, functionName)); + return true; } } // namespace facebook::velox::exec diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 91b7c3a7bed8..088e338efc70 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -17,6 +17,7 @@ #include "velox/common/memory/HashStringAllocator.h" #include "velox/exec/Aggregate.h" +#include "velox/expression/FunctionCallToSpecialForm.h" #include "velox/expression/VectorFunction.h" namespace facebook::velox::exec { @@ -169,6 +170,55 @@ struct AggregateCompanionAdapter { }; }; +class ExtractCallToSpecialForm : public exec::FunctionCallToSpecialForm { + public: + ExtractCallToSpecialForm( + const std::string& originalName, + const std::string& functionName) + : originalName_{originalName}, functionName_{functionName} {} + + TypePtr resolveType(const std::vector& /*argTypes*/) override { + VELOX_FAIL("Extract function does not support type resolution."); + } + + exec::ExprPtr constructSpecialForm( + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& config) override { + std::vector argTypes{args.size()}; + std::transform(args.begin(), args.end(), argTypes.begin(), [](auto arg) { + return arg->type(); + }); + + std::shared_ptr extractFunction; + if (auto func = getAggregateFunctionEntry(originalName_)) { + auto fn = func->factory( + core::AggregationNode::Step::kFinal, argTypes, type, config); + VELOX_CHECK_NOT_NULL(fn); + extractFunction = + std::make_shared( + std::move(fn)); + } + VELOX_FAIL( + "Original aggregation function {} not found: {}", + originalName_, + functionName_); + + return std::make_shared( + type, + std::move(args), + std::move(extractFunction), + exec::VectorFunctionMetadata{}, + functionName_, + trackCpuUsage); + } + + private: + std::string originalName_; + std::string functionName_; +}; + class CompanionFunctionsRegistrar { public: // Register the partial companion function for an aggregation function of