Skip to content

Commit

Permalink
Merge f5ba47a into 9a9711d
Browse files Browse the repository at this point in the history
  • Loading branch information
APozdniakov authored Oct 22, 2024
2 parents 9a9711d + f5ba47a commit e47b4e3
Show file tree
Hide file tree
Showing 21 changed files with 199 additions and 206 deletions.
15 changes: 15 additions & 0 deletions ydb/library/yql/core/sql_types/match_recognize.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@

namespace NYql::NMatchRecognize {

enum class EAfterMatchSkipTo {
NextRow,
PastLastRow,
ToFirst,
ToLast,
To
};

struct TAfterMatchSkipTo {
EAfterMatchSkipTo To;
TString Var;

[[nodiscard]] bool operator==(const TAfterMatchSkipTo&) const noexcept = default;
};

constexpr size_t MaxPatternNesting = 20; //Limit recursion for patterns
constexpr size_t MaxPermutedItems = 6;

Expand Down
179 changes: 27 additions & 152 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,131 +39,13 @@ struct TMatchRecognizeProcessorParameters {
TMeasureInputColumnOrder MeasureInputColumnOrder;
TComputationNodePtrVector Measures;
TOutputColumnOrder OutputColumnOrder;
};

class TBackTrackingMatchRecognize {
using TPartitionList = TSimpleList;
using TRange = TPartitionList::TRange;
using TMatchedVars = TMatchedVars<TRange>;
public:
//TODO(YQL-16486): create a tree for backtracking(replace var names with indexes)

struct TPatternConfiguration {
void Save(TMrOutputSerializer& /*serializer*/) const {
}

void Load(TMrInputSerializer& /*serializer*/) {
}

friend bool operator==(const TPatternConfiguration&, const TPatternConfiguration&) {
return true;
}
};

struct TPatternConfigurationBuilder {
using TPatternConfigurationPtr = std::shared_ptr<TPatternConfiguration>;
static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
Y_UNUSED(pattern);
Y_UNUSED(varNameToIndex);
return std::make_shared<TPatternConfiguration>();
}
};

TBackTrackingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
const TPatternConfigurationBuilder::TPatternConfigurationPtr pattern,
const TContainerCacheOnContext& cache
)
: PartitionKey(std::move(partitionKey))
, Parameters(parameters)
, Cache(cache)
, CurMatchedVars(parameters.Defines.size())
, MatchNumber(0)
{
//TODO(YQL-16486)
Y_UNUSED(pattern);
}

bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
Y_UNUSED(ctx);
Rows.Append(std::move(row));
return false;
}
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
if (Matches.empty())
return NUdf::TUnboxedValue{};
Parameters.MatchedVarsArg->SetValue(ctx, ToValue(ctx.HolderFactory, std::move(Matches.front())));
Matches.pop_front();
Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
Parameters.InputDataArg->GetValue(ctx),
Parameters.MeasureInputColumnOrder,
Parameters.MatchedVarsArg->GetValue(ctx),
Parameters.VarNames,
++MatchNumber
));
NUdf::TUnboxedValue *itemsPtr = nullptr;
const auto result = Cache.NewArray(ctx, Parameters.OutputColumnOrder.size(), itemsPtr);
for (auto const& c: Parameters.OutputColumnOrder) {
switch(c.first) {
case EOutputColumnSource::Measure:
*itemsPtr++ = Parameters.Measures[c.second]->GetValue(ctx);
break;
case EOutputColumnSource::PartitionKey:
*itemsPtr++ = PartitionKey.GetElement(c.second);
break;
}
}
return result;
}
bool ProcessEndOfData(TComputationContext& ctx) {
//Assume, that data moved to IComputationExternalNode node, will not be modified or released
//till the end of the current function
auto rowsSize = Rows.Size();
Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TPartitionList>>(Rows));
for (size_t i = 0; i != rowsSize; ++i) {
Parameters.CurrentRowIndexArg->SetValue(ctx, NUdf::TUnboxedValuePod(static_cast<ui64>(i)));
for (size_t v = 0; v != Parameters.Defines.size(); ++v) {
const auto &d = Parameters.Defines[v]->GetValue(ctx);
if (d && d.GetOptionalValue().Get<bool>()) {
Extend(CurMatchedVars[v], TRange{i});
}
}
//for the sake of dummy usage assume non-overlapped matches at every 5th row of any partition
if (i % 5 == 0) {
TMatchedVars temp;
temp.swap(CurMatchedVars);
Matches.emplace_back(std::move(temp));
CurMatchedVars.resize(Parameters.Defines.size());
}
}
return not Matches.empty();
}

void Save(TOutputSerializer& /*serializer*/) const {
// Not used in not streaming mode.
}

void Load(TMrInputSerializer& /*serializer*/) {
// Not used in not streaming mode.
}

private:
const NUdf::TUnboxedValue PartitionKey;
const TMatchRecognizeProcessorParameters& Parameters;
const TContainerCacheOnContext& Cache;
TSimpleList Rows;
TMatchedVars CurMatchedVars;
std::deque<TMatchedVars, TMKQLAllocator<TMatchedVars>> Matches;
ui64 MatchNumber;
TAfterMatchSkipTo SkipTo;
};

class TStreamingMatchRecognize {
using TPartitionList = TSparseList;
using TRange = TPartitionList::TRange;
public:
using TPatternConfiguration = TNfaTransitionGraph;
using TPatternConfigurationBuilder = TNfaTransitionGraphBuilder;
TStreamingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
const TMatchRecognizeProcessorParameters& parameters,
Expand Down Expand Up @@ -213,6 +95,9 @@ class TStreamingMatchRecognize {
break;
}
}
if (EAfterMatchSkipTo::PastLastRow == Parameters.SkipTo.To) {
Nfa.Clear();
}
return result;
}
bool ProcessEndOfData(TComputationContext& ctx) {
Expand Down Expand Up @@ -243,11 +128,9 @@ class TStreamingMatchRecognize {
ui64 MatchNumber = 0;
};

template <typename Algo>
class TStateForNonInterleavedPartitions
: public TComputationValue<TStateForNonInterleavedPartitions<Algo>>
: public TComputationValue<TStateForNonInterleavedPartitions>
{
using TRowPatternConfigurationBuilder = typename Algo::TPatternConfigurationBuilder;
public:
TStateForNonInterleavedPartitions(
TMemoryUsageInfo* memInfo,
Expand All @@ -265,7 +148,7 @@ class TStateForNonInterleavedPartitions
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, Parameters(parameters)
, RowPatternConfiguration(TRowPatternConfigurationBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
, RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
, Cache(cache)
, Terminating(false)
, SerializerContext(ctx, rowType, rowPacker)
Expand Down Expand Up @@ -301,7 +184,7 @@ class TStateForNonInterleavedPartitions
bool validPartitionHandler = in.Read<bool>();
if (validPartitionHandler) {
NUdf::TUnboxedValue key = PartitionKeyPacker.Unpack(CurPartitionPackedKey, SerializerContext.Ctx.HolderFactory);
PartitionHandler.reset(new Algo(
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(key),
Parameters,
RowPatternConfiguration,
Expand All @@ -313,7 +196,7 @@ class TStateForNonInterleavedPartitions
if (validDelayedRow) {
in(DelayedRow);
}
auto restoredRowPatternConfiguration = std::make_shared<typename Algo::TPatternConfiguration>();
auto restoredRowPatternConfiguration = std::make_shared<TNfaTransitionGraph>();
restoredRowPatternConfiguration->Load(in);
MKQL_ENSURE(*restoredRowPatternConfiguration == *RowPatternConfiguration, "Restored and current RowPatternConfiguration is different");
MKQL_ENSURE(in.Empty(), "State is corrupted");
Expand Down Expand Up @@ -367,12 +250,11 @@ class TStateForNonInterleavedPartitions
InputRowArg->SetValue(ctx, NUdf::TUnboxedValue(temp));
auto partitionKey = PartitionKey->GetValue(ctx);
CurPartitionPackedKey = PartitionKeyPacker.Pack(partitionKey);
PartitionHandler.reset(new Algo(
PartitionHandler.reset(new TStreamingMatchRecognize(
std::move(partitionKey),
Parameters,
RowPatternConfiguration,
Cache
));
Cache));
PartitionHandler->ProcessInputRow(std::move(temp), ctx);
}
if (Terminating) {
Expand All @@ -382,12 +264,12 @@ class TStateForNonInterleavedPartitions
}
private:
TString CurPartitionPackedKey;
std::unique_ptr<Algo> PartitionHandler;
std::unique_ptr<TStreamingMatchRecognize> PartitionHandler;
IComputationExternalNode* InputRowArg;
IComputationNode* PartitionKey;
TValuePackerGeneric<false> PartitionKeyPacker;
const TMatchRecognizeProcessorParameters& Parameters;
const typename TRowPatternConfigurationBuilder::TPatternConfigurationPtr RowPatternConfiguration;
const TNfaTransitionGraph::TPtr RowPatternConfiguration;
const TContainerCacheOnContext& Cache;
NUdf::TUnboxedValue DelayedRow;
bool Terminating;
Expand Down Expand Up @@ -768,6 +650,11 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
defines.push_back(callable.GetInput(inputIndex++));
}
const auto& streamingMode = callable.GetInput(inputIndex++);
NYql::NMatchRecognize::TAfterMatchSkipTo skipTo = {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""};
if (inputIndex + 2 <= callable.GetInputsCount()) {
skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef();
}
MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");

const auto& [vars, varsLookup] = ConvertListOfStrings(varNames);
Expand All @@ -788,6 +675,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
)
, ConvertVectorOfCallables(measures, ctx)
, GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes)
, skipTo
};
if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) {
return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(ctx.Mutables
Expand All @@ -800,28 +688,15 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
, rowType
);
} else {
const bool useNfaForTables = true; //TODO(YQL-16486) get this flag from an optimizer
if (useNfaForTables) {
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions<TStreamingMatchRecognize>>(ctx.Mutables
, GetValueRepresentation(inputFlow.GetStaticType())
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
, std::move(parameters)
, rowType
);
} else {
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions<TBackTrackingMatchRecognize>>(ctx.Mutables
, GetValueRepresentation(inputFlow.GetStaticType())
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
, std::move(parameters)
, rowType
);
}
return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(ctx.Mutables
, GetValueRepresentation(inputFlow.GetStaticType())
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
, std::move(parameters)
, rowType
);
}
}

Expand Down
7 changes: 5 additions & 2 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ class TNfaTransitionGraphBuilder {
return {input, output};
}
public:
using TPatternConfigurationPtr = TNfaTransitionGraph::TPtr;
static TPatternConfigurationPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
static TNfaTransitionGraph::TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
auto result = std::make_shared<TNfaTransitionGraph>();
TNfaTransitionGraphBuilder builder(result);
auto item = builder.BuildTerms(pattern, varNameToIndex);
Expand Down Expand Up @@ -455,6 +454,10 @@ class TNfa {
serializer.Read(EpsilonTransitionsLastRow);
}

void Clear() {
ActiveStates.clear();
}

private:
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ namespace NKikimr {
{NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}}
},
getDefines,
streamingMode);
streamingMode,
{NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}
);

auto graph = setup.BuildGraph(pgmReturn);
return graph;
Expand Down
7 changes: 6 additions & 1 deletion ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5902,7 +5902,8 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
bool streamingMode
bool streamingMode,
const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
) {
MKQL_ENSURE(RuntimeVersion >= 42, "MatchRecognize is not supported in runtime version " << RuntimeVersion);

Expand Down Expand Up @@ -6056,6 +6057,10 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
callableBuilder.Add(d);
}
callableBuilder.Add(NewDataLiteral(streamingMode));
if (RuntimeVersion >= 52U) {
callableBuilder.Add(NewDataLiteral(static_cast<i32>(skipTo.To)));
callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(skipTo.Var));
}
return TRuntimeNode(callableBuilder.Build(), false);
}

Expand Down
3 changes: 2 additions & 1 deletion ydb/library/yql/minikql/mkql_program_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,8 @@ class TProgramBuilder : public TTypeBuilder {
const TArrayRef<std::pair<TStringBuf, TBinaryLambda>>& getMeasures,
const NYql::NMatchRecognize::TRowPattern& pattern,
const TArrayRef<std::pair<TStringBuf, TTernaryLambda>>& getDefines,
bool streamingMode
bool streamingMode,
const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo
);

TRuntimeNode TimeOrderRecover(
Expand Down
2 changes: 1 addition & 1 deletion ydb/library/yql/minikql/mkql_runtime_version.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace NMiniKQL {
// 1. Bump this version every time incompatible runtime nodes are introduced.
// 2. Make sure you provide runtime node generation for previous runtime versions.
#ifndef MKQL_RUNTIME_VERSION
#define MKQL_RUNTIME_VERSION 50U
#define MKQL_RUNTIME_VERSION 52U
#endif

// History:
Expand Down
Loading

0 comments on commit e47b4e3

Please sign in to comment.