Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YQ-3621] support AFTER MATCH SKIP PAST LAST ROW (#10597) #10739

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ydb/library/yql/core/sql_types/match_recognize.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@

namespace NYql::NMatchRecognize {

// Modes of the MATCH_RECOGNIZE `AFTER MATCH SKIP` clause.
//
// The numeric values are written out explicitly because an enumerator is cast
// to i32 when it is stored as a data literal in a MatchRecognizeCore callable
// and cast back from i32 when that callable is wrapped; they are identical to
// the values the compiler would assign implicitly and must stay stable.
enum class EAfterMatchSkipTo {
    NextRow     = 0, // AFTER MATCH SKIP TO NEXT ROW; used as the fallback mode
    PastLastRow = 1, // AFTER MATCH SKIP PAST LAST ROW
    ToFirst     = 2, // presumably AFTER MATCH SKIP TO FIRST <var> -- confirm
    ToLast      = 3, // presumably AFTER MATCH SKIP TO LAST <var> -- confirm
    To          = 4, // presumably AFTER MATCH SKIP TO <var> -- confirm
};

// An `AFTER MATCH SKIP` setting: the skip mode plus the variable name that
// accompanies it.
//
// The type stays an aggregate, so brace initialization such as
// `{EAfterMatchSkipTo::NextRow, ""}` keeps working.
struct TAfterMatchSkipTo {
    // Skip mode. Defaults to NextRow -- the same value WrapMatchRecognizeCore
    // falls back to when the callable carries no skip inputs -- so that a
    // default-constructed instance never holds an indeterminate mode.
    EAfterMatchSkipTo To = EAfterMatchSkipTo::NextRow;
    // Variable name stored alongside the mode; the NextRow fallback uses "".
    TString Var;

    // Member-wise equality (compares To, then Var).
    [[nodiscard]] bool operator==(const TAfterMatchSkipTo&) const noexcept = default;
};

// Upper bound on how deeply row patterns may nest.
constexpr size_t MaxPatternNesting = 20; //Limit recursion for patterns
// NOTE(review): presumably the maximum number of items allowed inside one
// permuted pattern group -- confirm against the code that reads it.
constexpr size_t MaxPermutedItems = 6;

Expand Down
179 changes: 27 additions & 152 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,131 +39,13 @@ struct TMatchRecognizeProcessorParameters {
TMeasureInputColumnOrder MeasureInputColumnOrder;
TComputationNodePtrVector Measures;
TOutputColumnOrder OutputColumnOrder;
};

// Per-partition MATCH_RECOGNIZE handler that buffers every input row and only
// produces matches once the end of data is reported.
//
// The matching itself is a placeholder (see ProcessEndOfData): the pattern is
// ignored and a "match" is closed on every 5th row. The nested
// TPatternConfiguration / TPatternConfigurationBuilder types carry no data.
class TBackTrackingMatchRecognize {
using TPartitionList = TSimpleList;
using TRange = TPartitionList::TRange;
using TMatchedVars = TMatchedVars<TRange>;
public:
//TODO(YQL-16486): create a tree for backtracking(replace var names with indexes)

// Empty pattern configuration: nothing is saved or loaded, and any two
// instances compare equal.
struct TPatternConfiguration {
void Save(TMrOutputSerializer& /*serializer*/) const {
}

void Load(TMrInputSerializer& /*serializer*/) {
}

friend bool operator==(const TPatternConfiguration&, const TPatternConfiguration&) {
return true;
}
};

// Factory for TPatternConfiguration; both arguments are ignored.
struct TPatternConfigurationBuilder {
using TPatternConfigurationPtr = std::shared_ptr<TPatternConfiguration>;
static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
Y_UNUSED(pattern);
Y_UNUSED(varNameToIndex);
return std::make_shared<TPatternConfiguration>();
}
};

// Takes ownership of the partition key and keeps references to `parameters`
// and `cache` (both must outlive this object). `pattern` is currently unused.
// CurMatchedVars starts with one slot per DEFINE entry.
TBackTrackingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
const TPatternConfigurationBuilder::TPatternConfigurationPtr pattern,
const TContainerCacheOnContext& cache
)
: PartitionKey(std::move(partitionKey))
, Parameters(parameters)
, Cache(cache)
, CurMatchedVars(parameters.Defines.size())
, MatchNumber(0)
{
//TODO(YQL-16486)
Y_UNUSED(pattern);
}

// Buffers the row. Always returns false: no output becomes available while
// rows are still arriving.
bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
Y_UNUSED(ctx);
Rows.Append(std::move(row));
return false;
}
// Emits one output row for the oldest pending match, or an empty
// TUnboxedValue when no match is queued.
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
if (Matches.empty())
return NUdf::TUnboxedValue{};
// Publish the matched variables of this match, then drop it from the queue.
Parameters.MatchedVarsArg->SetValue(ctx, ToValue(ctx.HolderFactory, std::move(Matches.front())));
Matches.pop_front();
// Publish the measure input; MatchNumber is incremented first, so the
// first emitted match is numbered 1.
Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
Parameters.InputDataArg->GetValue(ctx),
Parameters.MeasureInputColumnOrder,
Parameters.MatchedVarsArg->GetValue(ctx),
Parameters.VarNames,
++MatchNumber
));
// Fill one output slot per OutputColumnOrder entry, taking the value either
// from a measure or from an element of the partition key.
NUdf::TUnboxedValue *itemsPtr = nullptr;
const auto result = Cache.NewArray(ctx, Parameters.OutputColumnOrder.size(), itemsPtr);
for (auto const& c: Parameters.OutputColumnOrder) {
switch(c.first) {
case EOutputColumnSource::Measure:
*itemsPtr++ = Parameters.Measures[c.second]->GetValue(ctx);
break;
case EOutputColumnSource::PartitionKey:
*itemsPtr++ = PartitionKey.GetElement(c.second);
break;
}
}
return result;
}
// Runs the (placeholder) matching over all buffered rows and queues the
// resulting matches. Returns true when at least one match was queued.
bool ProcessEndOfData(TComputationContext& ctx) {
//Assume, that data moved to IComputationExternalNode node, will not be modified or released
//till the end of the current function
auto rowsSize = Rows.Size();
Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TPartitionList>>(Rows));
for (size_t i = 0; i != rowsSize; ++i) {
Parameters.CurrentRowIndexArg->SetValue(ctx, NUdf::TUnboxedValuePod(static_cast<ui64>(i)));
// Evaluate every DEFINE for row i; a present, true result adds row i to
// the range collected for that variable.
for (size_t v = 0; v != Parameters.Defines.size(); ++v) {
const auto &d = Parameters.Defines[v]->GetValue(ctx);
if (d && d.GetOptionalValue().Get<bool>()) {
Extend(CurMatchedVars[v], TRange{i});
}
}
//for the sake of dummy usage assume non-overlapped matches at every 5th row of any partition
// (i == 0 also satisfies the condition, so the first row closes a match.)
if (i % 5 == 0) {
TMatchedVars temp;
temp.swap(CurMatchedVars);
Matches.emplace_back(std::move(temp));
// Restore one empty slot per DEFINE entry for the next match.
CurMatchedVars.resize(Parameters.Defines.size());
}
}
return not Matches.empty();
}

void Save(TOutputSerializer& /*serializer*/) const {
// Not used in non-streaming mode.
}

void Load(TMrInputSerializer& /*serializer*/) {
// Not used in non-streaming mode.
}

private:
const NUdf::TUnboxedValue PartitionKey;
const TMatchRecognizeProcessorParameters& Parameters;
const TContainerCacheOnContext& Cache;
// All rows of the partition, in arrival order.
TSimpleList Rows;
// Ranges collected per DEFINE variable for the match being built.
TMatchedVars CurMatchedVars;
// Completed matches waiting to be emitted by GetOutputIfReady.
std::deque<TMatchedVars, TMKQLAllocator<TMatchedVars>> Matches;
// Number of matches emitted so far.
ui64 MatchNumber;
// NOTE(review): not referenced by any method above -- confirm whether this
// member belongs to this class.
TAfterMatchSkipTo SkipTo;
};

class TStreamingMatchRecognize {
using TPartitionList = TSparseList;
using TRange = TPartitionList::TRange;
public:
using TPatternConfiguration = TNfaTransitionGraph;
using TPatternConfigurationBuilder = TNfaTransitionGraphBuilder;
TStreamingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
Expand Down Expand Up @@ -213,6 +95,9 @@ class TStreamingMatchRecognize {
break;
}
}
if (EAfterMatchSkipTo::PastLastRow == Parameters.SkipTo.To) {
Nfa.Clear();
}
return result;
}
bool ProcessEndOfData(TComputationContext& ctx) {
Expand Down Expand Up @@ -243,11 +128,9 @@ class TStreamingMatchRecognize {
ui64 MatchNumber = 0;
};

template <typename Algo>
class TStateForNonInterleavedPartitions
: public TComputationValue<TStateForNonInterleavedPartitions<Algo>>
: public TComputationValue<TStateForNonInterleavedPartitions>
{
using TRowPatternConfigurationBuilder = typename Algo::TPatternConfigurationBuilder;
public:
TStateForNonInterleavedPartitions(
TMemoryUsageInfo* memInfo,
Expand All @@ -265,7 +148,7 @@ class TStateForNonInterleavedPartitions
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
, RowPatternConfiguration(TRowPatternConfigurationBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
, RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
, Cache(cache)
, Terminating(false)
, SerializerContext(ctx, rowType, rowPacker)
Expand Down Expand Up @@ -301,7 +184,7 @@ class TStateForNonInterleavedPartitions
bool validPartitionHandler = in.Read<bool>();
if (validPartitionHandler) {
NUdf::TUnboxedValue key = PartitionKeyPacker.Unpack(CurPartitionPackedKey, SerializerContext.Ctx.HolderFactory);
PartitionHandler.reset(new Algo(
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(key),
Parameters,
RowPatternConfiguration,
Expand All @@ -313,7 +196,7 @@ class TStateForNonInterleavedPartitions
if (validDelayedRow) {
in(DelayedRow);
}
auto restoredRowPatternConfiguration = std::make_shared<typename Algo::TPatternConfiguration>();
auto restoredRowPatternConfiguration = std::make_shared<TNfaTransitionGraph>();
restoredRowPatternConfiguration->Load(in);
MKQL_ENSURE(*restoredRowPatternConfiguration == *RowPatternConfiguration, "Restored and current RowPatternConfiguration is different");
MKQL_ENSURE(in.Empty(), "State is corrupted");
Expand Down Expand Up @@ -367,12 +250,11 @@ class TStateForNonInterleavedPartitions
InputRowArg->SetValue(ctx, NUdf::TUnboxedValue(temp));
auto partitionKey = PartitionKey->GetValue(ctx);
CurPartitionPackedKey = PartitionKeyPacker.Pack(partitionKey);
PartitionHandler.reset(new Algo(
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(partitionKey),
Parameters,
RowPatternConfiguration,
Cache
));
Cache));
PartitionHandler->ProcessInputRow(std::move(temp), ctx);
}
if (Terminating) {
Expand All @@ -382,12 +264,12 @@ class TStateForNonInterleavedPartitions
}
private:
TString CurPartitionPackedKey;
std::unique_ptr<Algo> PartitionHandler;
std::unique_ptr<TStreamingMatchRecognize> PartitionHandler;
IComputationExternalNode* InputRowArg;
IComputationNode* PartitionKey;
TValuePackerGeneric<false> PartitionKeyPacker;
const TMatchRecognizeProcessorParameters& Parameters;
const typename TRowPatternConfigurationBuilder::TPatternConfigurationPtr RowPatternConfiguration;
const TNfaTransitionGraph::TPtr RowPatternConfiguration;
const TContainerCacheOnContext& Cache;
NUdf::TUnboxedValue DelayedRow;
bool Terminating;
Expand Down Expand Up @@ -768,6 +650,11 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
defines.push_back(callable.GetInput(inputIndex++));
}
const auto& streamingMode = callable.GetInput(inputIndex++);
NYql::NMatchRecognize::TAfterMatchSkipTo skipTo = {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""};
if (inputIndex + 2 <= callable.GetInputsCount()) {
skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef();
}
MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");

const auto& [vars, varsLookup] = ConvertListOfStrings(varNames);
Expand All @@ -788,6 +675,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
)
, ConvertVectorOfCallables(measures, ctx)
, GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes)
, skipTo
};
if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) {
return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(ctx.Mutables
Expand All @@ -800,28 +688,15 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
, rowType
);
} else {
const bool useNfaForTables = true; //TODO(YQL-16486) get this flag from an optimizer
if (useNfaForTables) {
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions<TStreamingMatchRecognize>>(ctx.Mutables
, GetValueRepresentation(inputFlow.GetStaticType())
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
, std::move(parameters)
, rowType
);
} else {
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions<TBackTrackingMatchRecognize>>(ctx.Mutables
, GetValueRepresentation(inputFlow.GetStaticType())
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
, std::move(parameters)
, rowType
);
}
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(ctx.Mutables
, GetValueRepresentation(inputFlow.GetStaticType())
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
, std::move(parameters)
, rowType
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ class TNfaTransitionGraphBuilder {
return {input, output};
}
public:
using TPatternConfigurationPtr = TNfaTransitionGraph::TPtr;
static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
static TNfaTransitionGraph::TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
auto result = std::make_shared<TNfaTransitionGraph>();
TNfaTransitionGraphBuilder builder(result);
auto item = builder.BuildTerms(pattern, varNameToIndex);
Expand Down Expand Up @@ -455,6 +454,10 @@ class TNfa {
serializer.Read(EpsilonTransitionsLastRow);
}

// Drops every active state. The streaming handler calls this after producing
// output when the skip mode is PastLastRow.
void Clear() {
ActiveStates.clear();
}

private:
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ namespace NKikimr {
{NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}}
},
getDefines,
streamingMode);
streamingMode,
{NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}
);

auto graph = setup.BuildGraph(pgmReturn);
return graph;
Expand Down
7 changes: 6 additions & 1 deletion ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5902,7 +5902,8 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
bool streamingMode
bool streamingMode,
const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
) {
MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);

Expand Down Expand Up @@ -6056,6 +6057,10 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
callableBuilder.Add(d);
}
callableBuilder.Add(NewDataLiteral(streamingMode));
if (RuntimeVersion >= 52U) {
callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
}
return TRuntimeNode(callableBuilder.Build(), false);
}

Expand Down
3 changes: 2 additions & 1 deletion ydb/library/yql/minikql/mkql_program_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,8 @@ class TProgramBuilder : public TTypeBuilder {
const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
bool streamingMode
bool streamingMode,
const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
);

TRuntimeNode TimeOrderRecover(
Expand Down
2 changes: 1 addition & 1 deletion ydb/library/yql/minikql/mkql_runtime_version.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace NMiniKQL {
// 1. Bump this version every time incompatible runtime nodes are introduced.
// 2. Make sure you provide runtime node generation for previous runtime versions.
#ifndef MKQL_RUNTIME_VERSION
#define MKQL_RUNTIME_VERSION 50U
#define MKQL_RUNTIME_VERSION 52U
#endif

// History:
Expand Down
Loading
Loading