Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented two pass window functions #5247

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ydb/core/kqp/opt/logical/kqp_opt_log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase {
}

TMaybeNode<TExprBase> ExpandWindowFunctions(TExprBase node, TExprContext& ctx) {
TExprBase output = DqExpandWindowFunctions(node, ctx, true);
TExprBase output = DqExpandWindowFunctions(node, ctx, TypesCtx, true);
DumpAppliedRule("ExpandWindowFunctions", node.Ptr(), output.Ptr(), ctx);
return output;
}
Expand Down
22 changes: 16 additions & 6 deletions ydb/library/yql/core/common_opt/yql_co_flow2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1758,22 +1758,32 @@ void RegisterCoFlowCallables2(TCallableOptimizerMap& map) {
};

map["SessionWindowTraits"] = map["SortTraits"] = map["Lag"] = map["Lead"] = map["RowNumber"] = map["Rank"] = map["DenseRank"] =
map["CumeDist"] = map["PercentRank"] = map["NTile"] =
[](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx)
{
auto structType = node->Child(0)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()
->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
if (node->IsCallable("RowNumber")) {
if (node->IsCallable({"RowNumber", "CumeDist", "NTile"})) {
if (structType->GetSize() == 0) {
return node;
}

auto subsetType = ctx.MakeType<TListExprType>(ctx.MakeType<TStructExprType>(TVector<const TItemExprType*>()));
YQL_CLOG(DEBUG, Core) << "FieldSubset for " << node->Content();
return ctx.Builder(node->Pos())
.Callable(node->Content())
.Add(0, ExpandType(node->Pos(), *subsetType, ctx))
.Seal()
.Build();
if (node->IsCallable("NTile")) {
return ctx.Builder(node->Pos())
.Callable(node->Content())
.Add(0, ExpandType(node->Pos(), *subsetType, ctx))
.Add(1, node->TailPtr())
.Seal()
.Build();
} else {
return ctx.Builder(node->Pos())
.Callable(node->Content())
.Add(0, ExpandType(node->Pos(), *subsetType, ctx))
.Seal()
.Build();
}
}

TSet<ui32> lambdaIndexes;
Expand Down
49 changes: 44 additions & 5 deletions ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2691,17 +2691,30 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c
if (isAgg) {
value = BuildAggregationTraits(pos, true, "", p, listTypeNode, &aggId, ctx, optCtx);
} else {
if (name == "row_number") {
if (name == "row_number" || name == "cume_dist") {
value = ctx.Builder(pos)
.Callable("RowNumber")
.Callable(name == "row_number" ? "RowNumber" : "CumeDist")
.Callable(0, "TypeOf")
.Add(0, list)
.Add(0, list)
.Seal()
.Seal()
.Build();
} else if (name == "ntile") {
value = ctx.Builder(pos)
.Callable("NTile")
.Callable(0, "TypeOf")
.Add(0, list)
.Seal()
.Callable(1, "Unwrap")
.Callable(0, "FromPg")
.Add(0, p.first->ChildPtr(3))
.Seal()
.Seal()
.Seal()
.Build();
} else if (name == "rank" || name == "dense_rank") {
} else if (name == "rank" || name == "dense_rank" || name == "percent_rank") {
value = ctx.Builder(pos)
.Callable((name == "rank") ? "Rank" : "DenseRank")
.Callable((name == "rank") ? "Rank" : (name == "dense_rank" ? "DenseRank" : "PercentRank"))
.Callable(0, "TypeOf")
.Add(0, list)
.Seal()
Expand Down Expand Up @@ -2804,6 +2817,32 @@ TExprNode::TPtr BuildWindows(TPositionHandle pos, const TExprNode::TPtr& list, c
.Seal()
.Seal()
.Build();
} else if (node->Head().Content() == "ntile") {
ret = ctx.Builder(node->Pos())
.Callable("ToPg")
.Callable(0, "SafeCast")
.Add(0, ret)
.Atom(1, "Int32")
.Seal()
.Seal()
.Build();
} else if (node->Head().Content() == "cume_dist" || node->Head().Content() == "percent_rank") {
if (node->Head().Content() == "percent_rank") {
ret = ctx.Builder(node->Pos())
.Callable("Nanvl")
.Add(0, ret)
.Callable(1, "Double")
.Atom(0, "0.0")
.Seal()
.Seal()
.Build();
}

ret = ctx.Builder(node->Pos())
.Callable("ToPg")
.Add(0, ret)
.Seal()
.Build();
}

return ret;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8217,9 +8217,6 @@ struct TPeepHoleRules {
{"FoldMap", &CleckClosureOnUpperLambdaOverList<2U>},
{"Fold1Map", &CleckClosureOnUpperLambdaOverList<1U, 2U>},
{"Chain1Map", &CleckClosureOnUpperLambdaOverList<1U, 2U>},
{"CalcOverWindow", &ExpandCalcOverWindow},
{"CalcOverSessionWindow", &ExpandCalcOverWindow},
{"CalcOverWindowGroup", &ExpandCalcOverWindow},
{"PartitionsByKeys", &ExpandPartitionsByKeys},
{"DictItems", &MapForOptionalContainer},
{"DictKeys", &MapForOptionalContainer},
Expand Down Expand Up @@ -8283,7 +8280,10 @@ struct TPeepHoleRules {
{"AggregateFinalize", &ExpandAggregatePeephole},
{"CostsOf", &ExpandCostsOf},
{"JsonQuery", &ExpandJsonQuery},
{"MatchRecognize", &ExpandMatchRecognize}
{"MatchRecognize", &ExpandMatchRecognize},
{"CalcOverWindow", &ExpandCalcOverWindow},
{"CalcOverSessionWindow", &ExpandCalcOverWindow},
{"CalcOverWindowGroup", &ExpandCalcOverWindow},
};

const TPeepHoleOptimizerMap SimplifyStageRules = {
Expand Down
4 changes: 2 additions & 2 deletions ydb/library/yql/core/services/yql_lineage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,9 @@ class TLineageScanner {
const auto& list = f->Child(i);
auto field = list->Head().Content();
auto& res = (*lineage.Fields)[field];
if (list->Tail().IsCallable("RowNumber")) {
if (list->Tail().IsCallable({"RowNumber","CumeDist","NTile"})) {
continue;
} else if (list->Tail().IsCallable({"Lag","Lead","Rank","DenseRank"})) {
} else if (list->Tail().IsCallable({"Lag","Lead","Rank","DenseRank","PercentRank"})) {
const auto& lambda = list->Tail().Child(1);
bool produceStruct = list->Tail().IsCallable({"Lag","Lead"});
MergeLineageFromUsedFields(lambda->Tail(), lambda->Head().Head(), innerLineage, res, produceStruct);
Expand Down
3 changes: 3 additions & 0 deletions ydb/library/yql/core/type_ann/type_ann_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12315,6 +12315,9 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["RowNumber"] = &WinRowNumberWrapper;
Functions["Rank"] = &WinRankWrapper;
Functions["DenseRank"] = &WinRankWrapper;
Functions["PercentRank"] = &WinRankWrapper;
Functions["CumeDist"] = &WinRowNumberWrapper;
Functions["NTile"] = &WinNTileWrapper;
Functions["Ascending"] = &PresortWrapper;
Functions["Descending"] = &PresortWrapper;
Functions["IsKeySwitch"] = &IsKeySwitchWrapper;
Expand Down
27 changes: 24 additions & 3 deletions ydb/library/yql/core/type_ann/type_ann_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ namespace {

const auto paramName = func->Child(0)->Content();
const auto calcSpec = func->Child(1);
YQL_ENSURE(calcSpec->IsCallable({"Lag", "Lead", "RowNumber", "Rank", "DenseRank", "WindowTraits"}));
YQL_ENSURE(calcSpec->IsCallable({"Lag", "Lead", "RowNumber", "Rank", "DenseRank", "WindowTraits", "PercentRank", "CumeDist", "NTile"}));

auto traitsInputTypeNode = calcSpec->Child(0);
YQL_ENSURE(traitsInputTypeNode->GetTypeAnn());
Expand Down Expand Up @@ -5948,7 +5948,7 @@ namespace {
}
auto currColumn = input->Child(i)->Child(0)->Content();
auto calcSpec = input->Child(i)->Child(1);
if (!calcSpec->IsCallable({"WindowTraits", "Lag", "Lead", "RowNumber", "Rank", "DenseRank", "Void"})) {
if (!calcSpec->IsCallable({"WindowTraits", "Lag", "Lead", "RowNumber", "Rank", "DenseRank", "PercentRank", "CumeDist", "NTile", "Void"})) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(calcSpec->Pos()),
"Invalid traits or special function for calculation on window"));
return IGraphTransformer::TStatus::Error;
Expand Down Expand Up @@ -6305,6 +6305,26 @@ namespace {
if (auto status = EnsureTypeRewrite(input->HeadRef(), ctx.Expr); status != IGraphTransformer::TStatus::Ok) {
return status;
}
input->SetTypeAnn(ctx.Expr.MakeType<TDataExprType>(input->IsCallable("CumeDist") ? EDataSlot::Double : EDataSlot::Uint64));
return IGraphTransformer::TStatus::Ok;
}

// Type annotation callback registered for the "NTile" window callable.
//
// Steps, in order:
//   1. The node must pass EnsureArgsCount with a count of 2.
//   2. The first child (head) is passed through EnsureTypeRewrite.
//   3. The second child is passed to TryConvertTo with a Data(Int64) target;
//      any non-Ok status from that call is returned unchanged to the caller.
//   4. On success the node is annotated as Data(Uint64).
//
// `output` is intentionally unused: this wrapper never replaces the node itself.
IGraphTransformer::TStatus WinNTileWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
    Y_UNUSED(output);

    if (!EnsureArgsCount(*input, 2, ctx.Expr)) {
        return IGraphTransformer::TStatus::Error;
    }

    // Propagate whatever EnsureTypeRewrite reports when it is not Ok.
    auto headStatus = EnsureTypeRewrite(input->HeadRef(), ctx.Expr);
    if (headStatus != IGraphTransformer::TStatus::Ok) {
        return headStatus;
    }

    // The second argument is coerced in place (through ChildRef) to Int64.
    auto int64Type = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Int64);
    if (auto convertStatus = TryConvertTo(input->ChildRef(1), *int64Type, ctx.Expr);
        convertStatus.Level != IGraphTransformer::TStatus::Ok) {
        return convertStatus;
    }

    auto resultType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
    input->SetTypeAnn(resultType);
    return IGraphTransformer::TStatus::Ok;
}
Expand Down Expand Up @@ -6403,7 +6423,8 @@ namespace {
return IGraphTransformer::TStatus::Repeat;
}

const TTypeAnnotationNode* outputType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
const TTypeAnnotationNode* outputType = ctx.Expr.MakeType<TDataExprType>(input->IsCallable("PercentRank") ?
EDataSlot::Double : EDataSlot::Uint64);
if (!isAnsi && keyType->GetKind() == ETypeAnnotationKind::Optional) {
outputType = ctx.Expr.MakeType<TOptionalExprType>(outputType);
}
Expand Down
1 change: 1 addition & 0 deletions ydb/library/yql/core/type_ann/type_ann_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus WinLeadLagWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus WinRowNumberWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus WinRankWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus WinNTileWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus HoppingCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus MultiHoppingCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus HoppingTraitsWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
Expand Down
37 changes: 29 additions & 8 deletions ydb/library/yql/core/type_ann/type_ann_pg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,14 +774,6 @@ IGraphTransformer::TStatus PgWindowCallWrapper(const TExprNode::TPtr& input, TEx
return IGraphTransformer::TStatus::Error;
}

auto name = input->Child(4)->GetTypeAnn()->Cast<TPgExprType>()->GetName();
if (name != "int4") {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(4)->Pos()), TStringBuilder() <<
"Expected pgint4 type, but got: " << name));
return IGraphTransformer::TStatus::Error;
}


auto arg = input->Child(3)->GetTypeAnn();
if (arg->IsOptionalOrNull()) {
input->SetTypeAnn(arg);
Expand All @@ -796,6 +788,35 @@ IGraphTransformer::TStatus PgWindowCallWrapper(const TExprNode::TPtr& input, TEx
}

input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(NPg::LookupType("int8").TypeId));
} else if (name == "cume_dist" || name == "percent_rank") {
if (input->ChildrenSize() != 3) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Expected no arguments in function " << name));
return IGraphTransformer::TStatus::Error;
}

input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(NPg::LookupType("float8").TypeId));
} else if (name == "ntile") {
if (input->ChildrenSize() != 4) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Expected exactly one argument in function " << name));
return IGraphTransformer::TStatus::Error;
}

if (input->Child(3)->GetTypeAnn() && input->Child(3)->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Pg) {
auto name = input->Child(3)->GetTypeAnn()->Cast<TPgExprType>()->GetName();
if (name != "int4") {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(3)->Pos()), TStringBuilder() <<
"Expected int4 type, but got: " << name));
return IGraphTransformer::TStatus::Error;
}
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(3)->Pos()), TStringBuilder() <<
"Expected pg type, but got: " << input->Child(3)->GetTypeAnn()->GetKind()));
return IGraphTransformer::TStatus::Error;
}

input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(NPg::LookupType("int4").TypeId));
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Unsupported function: " << name));
Expand Down
Loading
Loading