Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Apr 26, 2024
1 parent a40de62 commit 9072f7c
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 23 deletions.
11 changes: 9 additions & 2 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,15 @@ class Aggregate {
// UDAF.
// @param step The aggregation step.
// @param rawInputType The raw input type of the UDAF.
// @param resultType The result type of the UDAF.
// @param constantInputs Optional constant inputs.
// @param resultType The result type of the current aggregation step.
// @param constantInputs Optional constant input values for aggregate
// function. constantInputs should be empty if there are no constant inputs,
// aligned with inputTypes if there is at least one constant input, with
// non-constant inputs represented as nullptr, and must be instances of
// ConstantVector.
// @param companionStep The step used to register aggregate companion
// functions: kPartial for the partial companion function, kIntermediate for
// the merge companion function, and kFinal for the merge-extract and extract
// companion functions.
virtual void initialize(
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
Expand Down
16 changes: 15 additions & 1 deletion velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ void AggregateCompanionAdapter::MergeFunction::extractValues(
fn_->extractAccumulators(groups, numGroups, result);
}

// Initializes the wrapped aggregate function on behalf of the merge-extract
// companion function. The step, raw input types, result type and constant
// inputs are forwarded untouched; the companion step supplied by the caller
// is ignored and kFinal is always handed to the wrapped function instead.
void AggregateCompanionAdapter::MergeExtractFunction::initialize(
    core::AggregationNode::Step step,
    const std::vector<TypePtr>& rawInputType,
    const facebook::velox::TypePtr& resultType,
    const std::vector<VectorPtr>& constantInputs,
    std::optional<core::AggregationNode::Step> /*companionStep*/) {
  // Companion step reported to the wrapped function for merge-extract.
  const auto mergeExtractStep = core::AggregationNode::Step::kFinal;
  fn_->initialize(
      step, rawInputType, resultType, constantInputs, mergeExtractStep);
}

void AggregateCompanionAdapter::MergeExtractFunction::extractValues(
char** groups,
int32_t numGroups,
Expand Down Expand Up @@ -275,7 +289,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
rawInputTypes,
outputType,
constantInputs,
core::AggregationNode::Step::kIntermediate);
core::AggregationNode::Step::kFinal);
fn_->initializeNewGroups(groups, allSelectedRange);
fn_->enableValidateIntermediateInputs();
fn_->addIntermediateResults(groups, rows, args, false);
Expand Down
7 changes: 7 additions & 0 deletions velox/exec/AggregateCompanionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ struct AggregateCompanionAdapter {
const TypePtr& resultType)
: MergeFunction{std::move(fn), resultType} {}

// Overrides initialize() for the merge-extract companion function. The
// definition in AggregateCompanionAdapter.cpp forwards step, rawInputType,
// resultType and constantInputs to the wrapped function unchanged, ignores
// the companionStep argument, and passes kFinal in its place.
void initialize(
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) override;

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override;
};
Expand Down
4 changes: 4 additions & 0 deletions velox/exec/AggregateWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,10 @@ class AggregateWindowFunction : public exec::WindowFunction {
std::vector<TypePtr> argTypes_;
std::vector<column_index_t> argIndices_;
std::vector<VectorPtr> argVectors_;
// Constant input values for the aggregate function. This vector is empty if
// there are no constant inputs. If there is at least one constant input, it is
// aligned with inputTypes: non-constant inputs are represented as nullptr and
// every non-null entry must be an instance of ConstantVector.
std::vector<VectorPtr> constantInputs_;

// This is a single aggregate row needed by the aggregate function for its
Expand Down
42 changes: 22 additions & 20 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,10 @@ TEST_F(SimpleCountNullsAggregationTest, basic) {
testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected});
}

// A testing aggregation function that uses the function state.
// A testing aggregate function that calculates a weighted average by taking an
// int64_t input value and a constant weight value as input. The sum of all
// input values is divided by the sum of all weight values to produce the final
// result. It is used to check expectations in the initialize method.
template <bool testCompanion>
class FunctionStateTestAggregate {
public:
Expand Down Expand Up @@ -520,17 +523,14 @@ class FunctionStateTestAggregate {
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) {
auto expectedRawInputTypes = {BIGINT(), BIGINT()};
std::vector<TypePtr> expectedRawInputTypes = {BIGINT(), BIGINT()};
auto expectedIntermediateType = ROW({BIGINT(), DOUBLE()});

if constexpr (testCompanion) {
// Check for companion functions.
VELOX_CHECK(companionStep.has_value());
if (companionStep.value() == core::AggregationNode::Step::kPartial) {
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
expectedRawInputTypes.begin(),
expectedRawInputTypes.end()));
VELOX_CHECK(rawInputTypes == expectedRawInputTypes);
if (step == core::AggregationNode::Step::kPartial ||
step == core::AggregationNode::Step::kSingle) {
// Only check constant inputs in partial and single step.
Expand All @@ -540,7 +540,8 @@ class FunctionStateTestAggregate {
VELOX_CHECK_NULL(constantInputs[0]);
}
} else if (
companionStep.value() == core::AggregationNode::Step::kIntermediate) {
companionStep.value() == core::AggregationNode::Step::kIntermediate ||
companionStep.value() == core::AggregationNode::Step::kFinal) {
VELOX_CHECK_EQ(rawInputTypes.size(), 1);
VELOX_CHECK(rawInputTypes[0]->equivalent(*expectedIntermediateType));

Expand All @@ -550,11 +551,7 @@ class FunctionStateTestAggregate {
VELOX_FAIL("Unexpected aggregation step");
}
} else {
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
expectedRawInputTypes.begin(),
expectedRawInputTypes.end()));
VELOX_CHECK(rawInputTypes == expectedRawInputTypes);
if (step == core::AggregationNode::Step::kPartial ||
step == core::AggregationNode::Step::kSingle) {
// Only check constant inputs in partial and single step.
Expand Down Expand Up @@ -641,12 +638,12 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate(
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name);
VELOX_CHECK_LE(argTypes.size(), 2, "{} takes at most 2 argument", name);
return std::make_unique<
SimpleAggregateAdapter<FunctionStateTestAggregate<testCompanion>>>(
resultType);
},
true /*registerCompanionFunctions*/,
testCompanion,
true /*overwrite*/);
}

Expand Down Expand Up @@ -698,15 +695,20 @@ class SimpleFunctionStateWindowTest : public WindowTestBase {
};

TEST_F(SimpleFunctionStateWindowTest, window) {
auto inputVectors =
makeRowVector({makeFlatVector<int64_t>({1, 1, 2, 2, 3, 3, 4})});
auto inputVectors = makeRowVector({
makeFlatVector<int32_t>({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5}),
makeFlatVector<int64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
});

auto expected = makeRowVector({
makeFlatVector<int64_t>({1, 1, 2, 2, 3, 3, 4}),
makeFlatVector<double>({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0}),
inputVectors->childAt(0),
inputVectors->childAt(1),
makeFlatVector<double>(
{2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 10.5, 10.5, 12.0}),
});
WindowTestBase::testWindowFunction(
{inputVectors},
"simple_function_state_agg_main(c0, 1)",
"simple_function_state_agg_main(c1, 1)",
{"partition by c0"},
{},
expected);
Expand Down

0 comments on commit 9072f7c

Please sign in to comment.