Skip to content

Commit

Permalink
Fixed resolving of pg aggregation over state (#10441)
Browse files (browse the repository at this point in the history)
  • Loading branch information
vitstn authored Oct 15, 2024
1 parent b9bc524 commit 168df93
Show file tree
Hide file tree
Showing 14 changed files with 177 additions and 89 deletions.
13 changes: 12 additions & 1 deletion ydb/library/yql/core/yql_aggregate_expander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2523,9 +2523,20 @@ TExprNode::TPtr TAggregateExpander::GeneratePhases() {
.Seal()
.Build();

auto name = TString(originalTrait->ChildPtr(0)->Content());
if (name.StartsWith("pg_")) {
auto func = name.substr(3);
TVector<ui32> argTypes;
bool needRetype = false;
auto status = ExtractPgTypesFromMultiLambda(originalTrait->ChildRef(2), argTypes, needRetype, Ctx);
YQL_ENSURE(status == IGraphTransformer::TStatus::Ok);
const NPg::TAggregateDesc& aggDesc = NPg::LookupAggregation(TString(func), argTypes);
name = "pg_" + aggDesc.Name + "#" + ToString(aggDesc.AggId);
}

mergeTraits.push_back(Ctx.Builder(Node->Pos())
.Callable(many ? "AggApplyManyState" : "AggApplyState")
.Add(0, originalTrait->ChildPtr(0))
.Atom(0, name)
.Add(1, extractorTypeNode)
.Add(2, extractor)
.Add(3, originalExtractorTypeNode)
Expand Down
23 changes: 17 additions & 6 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ struct TAggParams {
ui32 Column_ = 0;
TType* StateType_ = nullptr;
TType* ReturnType_ = nullptr;
ui32 Hint_ = 0;
};

struct TKeyParams {
Expand Down Expand Up @@ -1723,15 +1724,18 @@ std::unique_ptr<IPreparedBlockAggregator<TAggregator>> PrepareBlockAggregator(co
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType);
TType* returnType,
ui32 hint);

template <>
std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineAll>> PrepareBlockAggregator<IBlockAggregatorCombineAll>(const IBlockAggregatorFactory& factory,
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) {
TType* returnType,
ui32 hint) {
Y_UNUSED(hint);
MKQL_ENSURE(!returnType, "Unexpected return type");
return factory.PrepareCombineAll(tupleType, filterColumn, argsColumns, env);
}
Expand All @@ -1742,7 +1746,9 @@ std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineKeys>> PrepareBl
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) {
TType* returnType,
ui32 hint) {
Y_UNUSED(hint);
MKQL_ENSURE(!filterColumn, "Unexpected filter column");
MKQL_ENSURE(!returnType, "Unexpected return type");
return factory.PrepareCombineKeys(tupleType, argsColumns, env);
Expand All @@ -1754,10 +1760,11 @@ std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareB
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) {
TType* returnType,
ui32 hint) {
MKQL_ENSURE(!filterColumn, "Unexpected filter column");
MKQL_ENSURE(returnType, "Missing return type");
return factory.PrepareFinalizeKeys(tupleType, argsColumns, env, returnType);
return factory.PrepareFinalizeKeys(tupleType, argsColumns, env, returnType, hint);
}

template <typename TAggregator>
Expand Down Expand Up @@ -1802,9 +1809,13 @@ ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<
p.Column_ = argColumns[0];
p.StateType_ = AS_TYPE(TBlockType, tupleType->GetElementType(p.Column_))->GetItemType();
p.ReturnType_ = returnTypes[i + keysCount];
TStringBuf left, right;
if (TStringBuf(name).TrySplit('#', left, right)) {
p.Hint_ = FromString<ui32>(right);
}
}

p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), unwrappedTupleType, filterColumn, argColumns, env, p.ReturnType_);
p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), unwrappedTupleType, filterColumn, argColumns, env, p.ReturnType_, p.Hint_);

totalStateSize += p.Prepared_->StateSize;
aggsParams.emplace_back(std::move(p));
Expand Down
8 changes: 6 additions & 2 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,13 @@ class TBlockCountAllFactory : public IBlockAggregatorFactory {
TTupleType* tupleType,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) const final {
TType* returnType,
ui32 hint) const final {
Y_UNUSED(tupleType);
Y_UNUSED(argsColumns);
Y_UNUSED(env);
Y_UNUSED(returnType);
Y_UNUSED(hint);
return PrepareCountAll<TFinalizeKeysTag>(std::optional<ui32>(), argsColumns[0]);
}
};
Expand Down Expand Up @@ -395,11 +397,13 @@ class TBlockCountFactory : public IBlockAggregatorFactory {
TTupleType* tupleType,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) const final {
TType* returnType,
ui32 hint) const final {
Y_UNUSED(tupleType);
Y_UNUSED(argsColumns);
Y_UNUSED(env);
Y_UNUSED(returnType);
Y_UNUSED(hint);
return PrepareCount<TFinalizeKeysTag>(std::optional<ui32>(), argsColumns[0]);
}
};
Expand Down
5 changes: 5 additions & 0 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ struct TAggregatorFactories {

const IBlockAggregatorFactory& GetBlockAggregatorFactory(TStringBuf name) {
const auto& f = Singleton<TAggregatorFactories>()->Factories;
TStringBuf left, right;
if (name.TrySplit('#', left, right)) {
name = left;
}

auto it = f.find(name);
if (it == f.end()) {
throw yexception() << "Unsupported block aggregation function: " << name;
Expand Down
3 changes: 2 additions & 1 deletion ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ class IBlockAggregatorFactory {
TTupleType* tupleType,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) const = 0;
TType* returnType,
ui32 hint) const = 0;
};

const IBlockAggregatorFactory& GetBlockAggregatorFactory(TStringBuf name);
Expand Down
4 changes: 3 additions & 1 deletion ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,9 +1005,11 @@ class TBlockMinMaxFactory : public IBlockAggregatorFactory {
TTupleType* tupleType,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) const final {
TType* returnType,
ui32 hint) const final {
Y_UNUSED(env);
Y_UNUSED(returnType);
Y_UNUSED(hint);
return PrepareMinMax<TFinalizeKeysTag, IsMin>(tupleType, std::optional<ui32>(), argsColumns[0]);
}
};
Expand Down
4 changes: 3 additions & 1 deletion ydb/library/yql/minikql/comp_nodes/mkql_block_agg_some.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,11 @@ class TBlockSomeFactory : public IBlockAggregatorFactory {
TTupleType* tupleType,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) const override {
TType* returnType,
ui32 hint) const override {
Y_UNUSED(env);
Y_UNUSED(returnType);
Y_UNUSED(hint);
return PrepareSome<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0]);
}
};
Expand Down
8 changes: 6 additions & 2 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,11 @@ class TBlockSumFactory : public IBlockAggregatorFactory {
TTupleType* tupleType,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) const final
TType* returnType,
ui32 hint) const final
{
Y_UNUSED(returnType);
Y_UNUSED(hint);
return PrepareSum<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
}
};
Expand Down Expand Up @@ -853,8 +855,10 @@ class TBlockAvgFactory : public IBlockAggregatorFactory {
TTupleType* tupleType,
const std::vector<ui32>& argsColumns,
const TTypeEnvironment& env,
TType* returnType) const final {
TType* returnType,
ui32 hint) const final {
Y_UNUSED(returnType);
Y_UNUSED(hint);
return PrepareAvg<TFinalizeKeysTag>(tupleType, std::optional<ui32>(), argsColumns[0], env);
}
};
Expand Down
14 changes: 13 additions & 1 deletion ydb/library/yql/parser/pg_catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3430,15 +3430,27 @@ const TAggregateDesc& LookupAggregation(const TString& name, const TVector<ui32>
}

const TAggregateDesc& LookupAggregation(const TString& name, ui32 stateType, ui32 resultType) {
TStringBuf realName = name;
TMaybe<ui32> aggId;
TStringBuf left, right;
if (realName.TrySplit('#', left, right)) {
aggId = FromString<ui32>(right);
realName = left;
}

const auto& catalog = TCatalog::Instance();
auto aggIdPtr = catalog.State->AggregationsByName.FindPtr(to_lower(name));
auto aggIdPtr = catalog.State->AggregationsByName.FindPtr(to_lower(TString(realName)));
if (!aggIdPtr) {
throw yexception() << "No such aggregate: " << name;
}

for (const auto& id : *aggIdPtr) {
const auto& d = catalog.State->Aggregations.FindPtr(id);
Y_ENSURE(d);
if (aggId && d->AggId != *aggId) {
continue;
}

if (!ValidateAggregateArgs(*d, stateType, resultType)) {
continue;
}
Expand Down
5 changes: 3 additions & 2 deletions ydb/library/yql/parser/pg_wrapper/arrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ TExecs::TExecs()
#undef RegisterExec
}

const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType, const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType) {
const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType,
const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType, ui32 hint) {
using namespace NKikimr::NMiniKQL;
if (returnType) {
MKQL_ENSURE(argsColumns.size() == 1, "Expected one column");
TType* stateType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType();
TType* returnItemType = AS_TYPE(TBlockType, returnType)->GetItemType();
return NPg::LookupAggregation(name, AS_TYPE(TPgType, stateType)->GetTypeId(), AS_TYPE(TPgType, returnItemType)->GetTypeId());
return NPg::LookupAggregation(name + "#" + ToString(hint), AS_TYPE(TPgType, stateType)->GetTypeId(), AS_TYPE(TPgType, returnItemType)->GetTypeId());
} else {
TVector<ui32> argTypeIds;
for (const auto col : argsColumns) {
Expand Down
3 changes: 2 additions & 1 deletion ydb/library/yql/parser/pg_wrapper/arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,8 @@ SkipCall:;

TExecFunc FindExec(Oid oid);

const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType, const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType);
const NPg::TAggregateDesc& ResolveAggregation(const TString& name, NKikimr::NMiniKQL::TTupleType* tupleType,
const std::vector<ui32>& argsColumns, NKikimr::NMiniKQL::TType* returnType, ui32 hint = 0);

}

5 changes: 3 additions & 2 deletions ydb/library/yql/parser/pg_wrapper/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,9 @@ def main():
" TTupleType* tupleType,\n" \
" const std::vector<ui32>& argsColumns,\n" \
" const TTypeEnvironment& env,\n" \
" TType* returnType) const final {\n" \
" const auto& aggDesc = ResolveAggregation(\"NAME\", tupleType, argsColumns, returnType);\n"
" TType* returnType,\n" \
" ui32 hint) const final {\n" \
" const auto& aggDesc = ResolveAggregation(\"NAME\", tupleType, argsColumns, returnType, hint);\n"
" switch (aggDesc.AggId) {\n" +
"".join([" case " + str(agg_id) + ": return MakePgAgg_NAME_" + str(agg_id) + "().PrepareFinalizeKeys(argsColumns.front(), aggDesc);\n" for agg_id in agg_names[name]]) +
" default: throw yexception() << \"Unsupported agg id: \" << aggDesc.AggId;\n" \
Expand Down
Loading

0 comments on commit 168df93

Please sign in to comment.