Skip to content

Commit

Permalink
Optimize test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Apr 17, 2024
1 parent 6f23305 commit 7a15313
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 76 deletions.
19 changes: 13 additions & 6 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,21 @@ void AggregateCompanionAdapter::ExtractFunction::apply(

// Get the raw input types.
std::vector<TypePtr> rawInputTypes{args.size()};
std::transform(
args.begin(),
args.end(),
rawInputTypes.begin(),
[](const VectorPtr& arg) { return arg->type(); });
std::vector<VectorPtr> constantInputs{args.size()};
for (auto i = 0; i < args.size(); i++) {
rawInputTypes[i] = args[i]->type();
if (args[i]->isConstantEncoding()) {
constantInputs[i] = args[i];
} else {
constantInputs[i] = nullptr;
}
}

fn_->initialize(
core::AggregationNode::Step::kFinal, rawInputTypes, outputType, {});
core::AggregationNode::Step::kFinal,
rawInputTypes,
outputType,
constantInputs);
fn_->initializeNewGroups(groups, allSelectedRange);
fn_->enableValidateIntermediateInputs();
fn_->addIntermediateResults(groups, rows, args, false);
Expand Down
214 changes: 152 additions & 62 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,13 @@ TEST_F(SimpleCountNullsAggregationTest, basic) {
// A testing aggregation function that uses the function state.
class FunctionStateTestAggregate {
public:
using InputType = Row<int64_t>; // Input vector type wrapped in Row.
using IntermediateType = int64_t; // Intermediate result type.
using OutputType = int64_t; // Output vector type.
using InputType = Row<int64_t, int64_t>; // Input vector type wrapped in Row.
using IntermediateType = Row<int64_t, double>; // Intermediate result type.
using OutputType = double; // Output vector type.

struct FunctionState {
core::AggregationNode::Step step;
std::vector<TypePtr> rawInputType;
std::vector<TypePtr> rawInputTypes;
TypePtr resultType;
std::vector<VectorPtr> constantInputs;
};
Expand All @@ -504,16 +504,14 @@ class FunctionStateTestAggregate {
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
state.step = step;
state.rawInputType = rawInputTypes;
state.rawInputTypes = rawInputTypes;
state.resultType = resultType;
if (resultType == nullptr) {
LOG(INFO) << "nullptr";
}
state.constantInputs = constantInputs;
}

struct Accumulator {
int64_t sum{0};
double count{0};

explicit Accumulator(
HashStringAllocator* /*allocator*/,
Expand All @@ -528,32 +526,37 @@ class FunctionStateTestAggregate {
void addInput(
HashStringAllocator* /*allocator*/,
exec::arg_type<int64_t> data,
exec::arg_type<int64_t> increment,
const FunctionState& state) {
checkpoint(const_cast<FunctionState*>(&state));
sum += data;
count += increment;
}

void combine(
HashStringAllocator* /*allocator*/,
exec::arg_type<IntermediateType> other,
const FunctionState& state) {
checkpoint(const_cast<FunctionState*>(&state));
sum += other;
VELOX_CHECK(other.at<0>().has_value());
VELOX_CHECK(other.at<1>().has_value());
sum += other.at<0>().value();
count += other.at<1>().value();
}

bool writeIntermediateResult(
exec::out_type<IntermediateType>& out,
const FunctionState& state) {
checkpoint(const_cast<FunctionState*>(&state));
out = sum;
out = std::make_tuple(sum, count);
return true;
}

bool writeFinalResult(
exec::out_type<OutputType>& out,
const FunctionState& state) {
checkpoint(const_cast<FunctionState*>(&state));
out = sum;
out = sum / count;
return true;
}
};
Expand All @@ -565,9 +568,10 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate(
const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.returnType("bigint")
.intermediateType("bigint")
.argumentType("bigint")
.returnType("DOUBLE")
.intermediateType("ROW(BIGINT, DOUBLE)")
.argumentType("BIGINT")
.argumentType("BIGINT")
.build()};

return exec::registerAggregateFunction(
Expand All @@ -579,8 +583,7 @@ exec::AggregateRegistrationResult registerSimpleFunctionStateTestAggregate(
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_LE(
argTypes.size(), 1, "{} takes at most one argument", name);
VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name);
return std::make_unique<
SimpleAggregateAdapter<FunctionStateTestAggregate>>(resultType);
},
Expand All @@ -598,118 +601,205 @@ class SimpleFunctionStateAggregationTest : public AggregationTestBase {
registerFunctionStateTestAggregate();
}

static void checkRowTypeEqual(TypePtr expected, TypePtr actual) {
VELOX_CHECK(expected->isRow());
VELOX_CHECK(actual->isRow());
VELOX_CHECK_EQ(expected->asRow().size(), actual->asRow().size());
for (auto i = 0; i < expected->asRow().size(); i++) {
VELOX_CHECK_EQ(expected->asRow().childAt(i), actual->asRow().childAt(i));
}
}

static void checkState(
FunctionStateTestAggregate::FunctionState* state,
const std::string& step = "") {
const std::vector<TypePtr>& rawInputTypes,
const TypePtr& intermediateType,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
VELOX_CHECK(!state->rawInputTypes.empty());
VELOX_CHECK_NOT_NULL(state->resultType);
VELOX_CHECK(!state->rawInputType.empty());
if (!step.empty()) {
VELOX_CHECK_EQ(core::AggregationNode::stepName(state->step), step);

switch (state->step) {
case core::AggregationNode::Step::kPartial:
case core::AggregationNode::Step::kIntermediate:
if (state->rawInputTypes.size() == 1 &&
state->rawInputTypes[0]->isRow()) {
// Merge or merge_extract companion function.
VELOX_CHECK_EQ(rawInputTypes.size(), 1);
VELOX_CHECK(rawInputTypes[0]->isRow());
checkRowTypeEqual(rawInputTypes[0], state->rawInputTypes[0]);
} else {
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
state->rawInputTypes.begin(),
state->rawInputTypes.end()));
}
if (state->resultType->isRow()) {
checkRowTypeEqual(intermediateType, state->resultType);
} else {
VELOX_CHECK_EQ(resultType, state->resultType)
}
break;

case core::AggregationNode::Step::kSingle:
case core::AggregationNode::Step::kFinal:
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
state->rawInputTypes.begin(),
state->rawInputTypes.end()));
VELOX_CHECK_EQ(resultType, state->resultType);
break;

default:
VELOX_FAIL("Unknown aggregate step");
break;
}

VELOX_CHECK(!state->constantInputs.empty());
if (state->step == core::AggregationNode::Step::kPartial ||
state->step == core::AggregationNode::Step::kSingle) {
VELOX_CHECK_EQ(constantInputs.size(), state->constantInputs.size());
for (auto i = 0; i < constantInputs.size(); i++) {
auto expected = constantInputs[i];
auto actual = state->constantInputs[i];
if (expected == nullptr && actual == nullptr) {
continue;
} else {
VELOX_CHECK(expected != nullptr && actual != nullptr);
VELOX_CHECK(expected->isConstantEncoding());
VELOX_CHECK(actual->isConstantEncoding());
VELOX_CHECK_EQ(
expected->asUnchecked<SimpleVector<int64_t>>()->valueAt(0),
actual->asUnchecked<SimpleVector<int64_t>>()->valueAt(0));
}
}
} else {
VELOX_CHECK_EQ(state->constantInputs.size(), 1);
VELOX_CHECK_NULL(state->constantInputs[0]);
}
}
};

TEST_F(SimpleFunctionStateAggregationTest, aggregate) {
auto inputVectors = makeRowVector({makeFlatVector<int64_t>({1, 2, 3, 4})});
std::vector<int64_t> sum = {10};
auto expected = makeRowVector({makeFlatVector<int64_t>(sum)});
std::vector<double> finalResult = {2.5};
auto expected = makeRowVector({makeFlatVector<double>(finalResult)});

SCOPED_TESTVALUE_SET(
"facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint",
std::function<void(FunctionStateTestAggregate::FunctionState*)>(
[&](FunctionStateTestAggregate::FunctionState* state) {
checkState(state);
checkState(
state,
{BIGINT(), BIGINT()},
ROW({BIGINT(), DOUBLE()}),
DOUBLE(),
{nullptr, makeConstant<int64_t>(1, 4)});
}));

testAggregations(
{inputVectors}, {}, {"simple_function_state_agg(c0)"}, {expected});
testAggregationsWithCompanion(
{inputVectors},
[](auto& /*builder*/) {},
{},
{"simple_function_state_agg(c0)"},
{{BIGINT()}},
{},
{expected},
{});
{inputVectors}, {}, {"simple_function_state_agg(c0, 1)"}, {expected});
}

TEST_F(SimpleFunctionStateAggregationTest, window) {
auto inputVectors =
makeRowVector({makeFlatVector<int64_t>({1, 1, 2, 2, 3, 3, 4})});
auto expected =
makeRowVector({makeFlatVector<int64_t>({2, 2, 4, 4, 6, 6, 4})});
auto expected = makeRowVector(
{makeFlatVector<double>({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0})});
SCOPED_TESTVALUE_SET(
"facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint",
std::function<void(FunctionStateTestAggregate::FunctionState*)>(
[&](FunctionStateTestAggregate::FunctionState* state) {
checkState(state, "SINGLE");
checkState(
state,
{BIGINT(), BIGINT()},
ROW({BIGINT(), DOUBLE()}),
DOUBLE(),
{nullptr, makeConstant<int64_t>(1, 7)});
}));
auto plan =
PlanBuilder()
.values({inputVectors})
.window({"simple_function_state_agg(c0) over (partition by c0)"})
.window({"simple_function_state_agg(c0, 1) over (partition by c0)"})
.project({"w0"})
.planNode();
AssertQueryBuilder(plan).assertResults(expected);
}

TEST_F(SimpleFunctionStateAggregationTest, aggregateStep) {
TEST_F(SimpleFunctionStateAggregationTest, companionAggregateFunction) {
auto inputVectors = makeRowVector({makeFlatVector<int64_t>({1, 2, 3, 4})});
std::vector<int64_t> sum = {10};
auto expected = makeRowVector({makeFlatVector<int64_t>(sum)});
std::vector<int64_t> accSum = {10};
std::vector<double> accCount = {4.0};
auto intermediateExpected = makeRowVector({
makeRowVector({
makeFlatVector<int64_t>(accSum),
makeFlatVector<double>(accCount),
}),
});
std::vector<double> finalResult = {2.5};
auto finalExpected = makeRowVector({makeFlatVector<double>(finalResult)});

SCOPED_TESTVALUE_SET(
"facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint",
std::function<void(FunctionStateTestAggregate::FunctionState*)>(
[&](FunctionStateTestAggregate::FunctionState* state) {
checkState(state, "PARTIAL");
checkState(
state,
{BIGINT(), BIGINT()},
ROW({BIGINT(), DOUBLE()}),
DOUBLE(),
{nullptr, makeConstant<int64_t>(1, 4)});
}));
AssertQueryBuilder(
PlanBuilder()
.values({inputVectors})
.singleAggregation({}, {"simple_function_state_agg_partial(c0)"})
.singleAggregation({}, {"simple_function_state_agg_partial(c0, 1)"})
.planNode())
.assertResults(expected);

.assertResults(intermediateExpected);

inputVectors = makeRowVector({
makeRowVector({
makeFlatVector<int64_t>({1, 2, 3, 4}),
makeFlatVector<double>({1.0, 1.0, 1.0, 1.0}),
}),
});
SCOPED_TESTVALUE_SET(
"facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint",
std::function<void(FunctionStateTestAggregate::FunctionState*)>(
[&](FunctionStateTestAggregate::FunctionState* state) {
checkState(state, "INTERMEDIATE");
checkState(
state,
{ROW({BIGINT(), DOUBLE()})},
ROW({BIGINT(), DOUBLE()}),
ROW({BIGINT(), DOUBLE()}),
{nullptr, makeConstant<int64_t>(1, 4)});
}));
AssertQueryBuilder(
PlanBuilder()
.values({inputVectors})
.singleAggregation({}, {"simple_function_state_agg_merge(c0)"})
.planNode())
.assertResults(expected);
.assertResults(intermediateExpected);

SCOPED_TESTVALUE_SET(
"facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint",
std::function<void(FunctionStateTestAggregate::FunctionState*)>(
[&](FunctionStateTestAggregate::FunctionState* state) {
checkState(state, "INTERMEDIATE");
checkState(
state,
{ROW({BIGINT(), DOUBLE()})},
ROW({BIGINT(), DOUBLE()}),
DOUBLE(),
{nullptr, makeConstant<int64_t>(1, 4)});
}));
AssertQueryBuilder(
PlanBuilder()
.values({inputVectors})
.singleAggregation(
{}, {"simple_function_state_agg_merge_extract(c0)"})
.planNode())
.assertResults(expected);

SCOPED_TESTVALUE_SET(
"facebook::velox::aggregate::test::FunctionStateTestAggregate::checkpoint",
std::function<void(FunctionStateTestAggregate::FunctionState*)>(
[&](FunctionStateTestAggregate::FunctionState* state) {
checkState(state, "FINAL");
}));
AssertQueryBuilder(
PlanBuilder()
.values({inputVectors})
.finalAggregation({}, {"simple_function_state_agg(c0)"}, {{BIGINT()}})
.planNode())
.assertResults(expected);
.assertResults(finalExpected);
}

} // namespace
Expand Down
Loading

0 comments on commit 7a15313

Please sign in to comment.