diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp index 061a3e96cd12..9bc6e65261cf 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -524,7 +525,8 @@ class TKeepTopWrapper : public TMutableComputationNode { IComputationExternalNode* arg, IComputationNode* key, IComputationNode* ascending, - IComputationExternalNode* hotkey) + IComputationExternalNode* hotkey, + IComputationExternalNode* cache) : TBaseComputation(mutables) , Description(mutables, std::move(keySchemeTypes), comparators) , Count(count) @@ -534,6 +536,7 @@ class TKeepTopWrapper : public TMutableComputationNode { , Key(key) , Ascending(ascending) , HotKey(hotkey) + , Cache(cache) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { @@ -542,14 +545,28 @@ class TKeepTopWrapper : public TMutableComputationNode { return ctx.HolderFactory.GetEmptyContainerLazy(); } + auto cache = Cache->GetValue(ctx); auto list = List->GetValue(ctx); auto item = Item->GetValue(ctx); + auto arg = item; + Arg->SetValue(ctx, arg.Release()); + auto key = Key->GetValue(ctx); + auto key_cached = key; + + if (cache.IsInvalid()) { + cache = ctx.HolderFactory.CreateDirectListHolder({key_cached.Release()}); + } else { + cache = ctx.HolderFactory.Append(cache.Release(), key_cached.Release()); + } const auto size = list.GetListLength(); if (size < count) { + Cache->SetValue(ctx, cache.Release()); return ctx.HolderFactory.Append(list.Release(), item.Release()); } + + const auto& ascending = Ascending->GetValue(ctx); auto hotkey = HotKey->GetValue(ctx); auto hotkey_prepared = hotkey; @@ -559,58 +576,51 @@ class TKeepTopWrapper : public TMutableComputationNode { if (size == count) { if (hotkey.IsInvalid()) { - TUnboxedValueVector keys; - keys.reserve(size); - - const auto ptr = list.GetElements(); - std::transform(ptr, ptr + size, std::back_inserter(keys), [&](const NUdf::TUnboxedValuePod item) { - Arg->SetValue(ctx, item); - return Key->GetValue(ctx); - }); - + const auto cacheptr = cache.GetElements(); + TUnboxedValueVector keys(cacheptr, cacheptr + size); auto keys_copy = keys; Description.Prepare(ctx, keys); - const auto& ascending = Ascending->GetValue(ctx); const auto max = std::max_element(keys.begin(), keys.end(), Description.MakeComparator(ascending)); hotkey_prepared = *max; HotKey->SetValue(ctx, std::move(keys_copy[max - keys.begin()])); } } - const auto copy = item; - Arg->SetValue(ctx, item.Release()); - auto key_prepared = Key->GetValue(ctx); + auto key_prepared = key; Description.PrepareValue(ctx, key_prepared); - const auto& ascending = Ascending->GetValue(ctx); - if (Description.MakeComparator(ascending)(key_prepared, hotkey_prepared)) { const auto reserve = std::max(count << 1ULL, 1ULL << 8ULL); if (size < reserve) { - return ctx.HolderFactory.Append(list.Release(), Arg->GetValue(ctx).Release()); + Cache->SetValue(ctx, cache.Release()); + return ctx.HolderFactory.Append(list.Release(), item.Release()); } - TKeyPayloadPairVector items(1U, TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx))); - items.reserve(items.size() + size); + TGatherIterator itemsIt(const_cast(cache.GetElements()), + const_cast(list.GetElements())); - const auto ptr = list.GetElements(); - std::transform(ptr, ptr + size, std::back_inserter(items), [&](const NUdf::TUnboxedValuePod item) { - Arg->SetValue(ctx, item); - return TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx)); - }); + TKeyPayloadPairVector items(itemsIt, itemsIt + size); + items.emplace_back(key, item); + auto items_copy = items; Description.Prepare(ctx, items); NYql::FastNthElement(items.begin(), items.begin() + count - 1U, items.end(), Description.MakeComparator(ascending)); items.resize(count); - NUdf::TUnboxedValue *inplace = nullptr; - const auto result = ctx.HolderFactory.CreateDirectArrayHolder(count, inplace); /// TODO: Use list holder. + NUdf::TUnboxedValue *cacheptr = nullptr; + NUdf::TUnboxedValue *listptr = nullptr; + const auto newCache = ctx.HolderFactory.CreateDirectArrayHolder(count, cacheptr); /// TODO: Use list holder. + const auto newList = ctx.HolderFactory.CreateDirectArrayHolder(count, listptr); /// TODO: Use list holder. for (auto& item : items) { - *inplace++ = std::move(item.second); + *cacheptr++ = std::move(item.first); + *listptr++ = std::move(item.second); } - return result; + Cache->SetValue(ctx, newCache); + const auto newHotKey = std::max_element(items.begin(), items.end(), Description.MakeComparator(ascending)); + HotKey->SetValue(ctx, std::move(items_copy[newHotKey - items.begin()].first)); + return newList; } return list.Release(); @@ -625,6 +635,7 @@ class TKeepTopWrapper : public TMutableComputationNode { DependsOn(Key); DependsOn(Ascending); Own(HotKey); + Own(Cache); } TCompareDescr Description; @@ -635,6 +646,7 @@ class TKeepTopWrapper : public TMutableComputationNode { IComputationNode* const Key; IComputationNode* const Ascending; IComputationExternalNode* const HotKey; + IComputationExternalNode* const Cache; }; std::vector> GetKeySchemeTypes(TType* keyType, TType* ascType) { @@ -759,7 +771,11 @@ IComputationNode* WrapTopSort(TCallable& callable, const TComputationNodeFactory } IComputationNode* WrapKeepTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - MKQL_ENSURE(callable.GetInputsCount() == 7, "Expected 7 args"); + if constexpr (RuntimeVersion >= 49U) { + MKQL_ENSURE(callable.GetInputsCount() == 8, "Expected 8 args"); + } else { + MKQL_ENSURE(callable.GetInputsCount() == 7, "Expected 7 args"); + } const auto keyNode = callable.GetInput(4); const auto sortNode = callable.GetInput(5); @@ -775,9 +791,13 @@ IComputationNode* WrapKeepTop(TCallable& callable, const TComputationNodeFactory const auto ascending = LocateNode(ctx.NodeLocator, callable, 5); const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 3); const auto hotkey = LocateExternalNode(ctx.NodeLocator, callable, 6); + IComputationExternalNode* cache = nullptr; + if constexpr (RuntimeVersion >= 49U) { + cache = LocateExternalNode(ctx.NodeLocator, callable, 7); + } auto comparators = MakeComparators(keyType, ascType->IsTuple()); - return new TKeepTopWrapper(ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, count, list, item, itemArg, key, ascending, hotkey); + return new TKeepTopWrapper(ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, count, list, item, itemArg, key, ascending, hotkey, cache); } } diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 778aecf8e833..1a2adcd9f5f9 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1957,6 +1957,10 @@ TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRu callableBuilder.Add(key); callableBuilder.Add(ascending); callableBuilder.Add(hotkey); + if constexpr (RuntimeVersion >= 49U) { + const auto cache = Arg(TListType::Create(key.GetStaticType(), Env)); + callableBuilder.Add(cache); + } return TRuntimeNode(callableBuilder.Build(), false); } diff --git a/ydb/library/yql/minikql/mkql_runtime_version.h b/ydb/library/yql/minikql/mkql_runtime_version.h index ea9606ffbe21..d16f0184db2c 100644 --- a/ydb/library/yql/minikql/mkql_runtime_version.h +++ b/ydb/library/yql/minikql/mkql_runtime_version.h @@ -24,7 +24,7 @@ namespace NMiniKQL { // 1. Bump this version every time incompatible runtime nodes are introduced. // 2. Make sure you provide runtime node generation for previous runtime versions. #ifndef MKQL_RUNTIME_VERSION -#define MKQL_RUNTIME_VERSION 48U +#define MKQL_RUNTIME_VERSION 49U #endif // History: