From 229df8b301cda4621443d9e7b8ad7ac1eb5de53e Mon Sep 17 00:00:00 2001 From: Igor Munkin Date: Wed, 22 May 2024 17:44:06 +0000 Subject: [PATCH] YQL-16402: Implement key cache for TKeepTopWrapper --- .../yql/minikql/comp_nodes/mkql_sort.cpp | 72 +++++++++++-------- .../yql/minikql/mkql_program_builder.cpp | 2 + 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp index 061a3e96cd12..02eb105add3d 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp @@ -524,7 +524,8 @@ class TKeepTopWrapper : public TMutableComputationNode { IComputationExternalNode* arg, IComputationNode* key, IComputationNode* ascending, - IComputationExternalNode* hotkey) + IComputationExternalNode* hotkey, + IComputationExternalNode* cache) : TBaseComputation(mutables) , Description(mutables, std::move(keySchemeTypes), comparators) , Count(count) @@ -534,6 +535,7 @@ class TKeepTopWrapper : public TMutableComputationNode { , Key(key) , Ascending(ascending) , HotKey(hotkey) + , Cache(cache) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { @@ -542,14 +544,27 @@ class TKeepTopWrapper : public TMutableComputationNode { return ctx.HolderFactory.GetEmptyContainerLazy(); } - auto list = List->GetValue(ctx); auto item = Item->GetValue(ctx); + auto arg = item; + Arg->SetValue(ctx, arg.Release()); + + auto list = List->GetValue(ctx); + auto cache = Cache->GetValue(ctx); + auto key = Key->GetValue(ctx); const auto size = list.GetListLength(); + if (cache.IsInvalid()) { + cache = ctx.HolderFactory.CreateDirectListHolder({}); + const auto cacheSize = cache.GetListLength(); + Y_UNUSED(cacheSize); + } if (size < count) { + Cache->SetValue(ctx, ctx.HolderFactory.Append(cache.Release(), key.Release())); return ctx.HolderFactory.Append(list.Release(), item.Release()); } + + const auto& ascending = 
Ascending->GetValue(ctx); auto hotkey = HotKey->GetValue(ctx); auto hotkey_prepared = hotkey; @@ -559,58 +574,50 @@ class TKeepTopWrapper : public TMutableComputationNode { if (size == count) { if (hotkey.IsInvalid()) { - TUnboxedValueVector keys; - keys.reserve(size); - - const auto ptr = list.GetElements(); - std::transform(ptr, ptr + size, std::back_inserter(keys), [&](const NUdf::TUnboxedValuePod item) { - Arg->SetValue(ctx, item); - return Key->GetValue(ctx); - }); - + const auto cacheptr = cache.GetElements(); + TUnboxedValueVector keys(cacheptr, cacheptr + size); auto keys_copy = keys; Description.Prepare(ctx, keys); - const auto& ascending = Ascending->GetValue(ctx); const auto max = std::max_element(keys.begin(), keys.end(), Description.MakeComparator(ascending)); hotkey_prepared = *max; HotKey->SetValue(ctx, std::move(keys_copy[max - keys.begin()])); } } - const auto copy = item; - Arg->SetValue(ctx, item.Release()); - auto key_prepared = Key->GetValue(ctx); + auto key_prepared = key; Description.PrepareValue(ctx, key_prepared); - const auto& ascending = Ascending->GetValue(ctx); - if (Description.MakeComparator(ascending)(key_prepared, hotkey_prepared)) { const auto reserve = std::max(count << 1ULL, 1ULL << 8ULL); if (size < reserve) { - return ctx.HolderFactory.Append(list.Release(), Arg->GetValue(ctx).Release()); + Cache->SetValue(ctx, ctx.HolderFactory.Append(cache.Release(), key.Release())); + return ctx.HolderFactory.Append(list.Release(), item.Release()); } - TKeyPayloadPairVector items(1U, TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx))); - items.reserve(items.size() + size); + TGatherIterator itemsIt(const_cast(cache.GetElements()), + const_cast(list.GetElements())); - const auto ptr = list.GetElements(); - std::transform(ptr, ptr + size, std::back_inserter(items), [&](const NUdf::TUnboxedValuePod item) { - Arg->SetValue(ctx, item); - return TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx)); - }); + TKeyPayloadPairVector 
items(itemsIt, itemsIt + size); + items.emplace_back(key, item); Description.Prepare(ctx, items); NYql::FastNthElement(items.begin(), items.begin() + count - 1U, items.end(), Description.MakeComparator(ascending)); items.resize(count); - NUdf::TUnboxedValue *inplace = nullptr; - const auto result = ctx.HolderFactory.CreateDirectArrayHolder(count, inplace); /// TODO: Use list holder. + // TODO: Update hotkey here. + + NUdf::TUnboxedValue *listptr = nullptr; + NUdf::TUnboxedValue *cacheptr = nullptr; + const auto newList = ctx.HolderFactory.CreateDirectArrayHolder(count, listptr); /// TODO: Use list holder. + const auto newCache = ctx.HolderFactory.CreateDirectArrayHolder(count, cacheptr); /// TODO: Use list holder. for (auto& item : items) { - *inplace++ = std::move(item.second); + *listptr++ = std::move(item.second); /* pairs are (key, payload): the payload is the list element that is returned */ + *cacheptr++ = std::move(item.first); /* the key is stored in the key cache at the same index as its payload; NOTE(review): Description.Prepare(ctx, items) ran on these pairs above, while the hotkey path keeps an unprepared keys_copy - confirm the cache may hold prepared keys */ } - return result; + Cache->SetValue(ctx, newCache); + return newList; } return list.Release(); @@ -625,6 +632,7 @@ class TKeepTopWrapper : public TMutableComputationNode { DependsOn(Key); DependsOn(Ascending); Own(HotKey); + Own(Cache); } TCompareDescr Description; @@ -635,6 +643,7 @@ class TKeepTopWrapper : public TMutableComputationNode { IComputationNode* const Key; IComputationNode* const Ascending; IComputationExternalNode* const HotKey; + IComputationExternalNode* const Cache; }; std::vector> GetKeySchemeTypes(TType* keyType, TType* ascType) { @@ -759,7 +768,7 @@ IComputationNode* WrapTopSort(TCallable& callable, const TComputationNodeFactory } IComputationNode* WrapKeepTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - MKQL_ENSURE(callable.GetInputsCount() == 7, "Expected 7 args"); + MKQL_ENSURE(callable.GetInputsCount() == 8, "Expected 8 args"); const auto keyNode = callable.GetInput(4); const auto sortNode = callable.GetInput(5); @@ -775,9 +784,10 @@ IComputationNode* WrapKeepTop(TCallable& callable, const TComputationNodeFactory const auto ascending = LocateNode(ctx.NodeLocator, 
callable, 5); const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 3); const auto hotkey = LocateExternalNode(ctx.NodeLocator, callable, 6); + const auto cache = LocateExternalNode(ctx.NodeLocator, callable, 7); auto comparators = MakeComparators(keyType, ascType->IsTuple()); - return new TKeepTopWrapper(ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, count, list, item, itemArg, key, ascending, hotkey); + return new TKeepTopWrapper(ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, count, list, item, itemArg, key, ascending, hotkey, cache); } } diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 778aecf8e833..afabab229d23 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1936,6 +1936,7 @@ TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRu auto key = keyExtractor(itemArg); const auto hotkey = Arg(key.GetStaticType()); + const auto cache = Arg(TListType::Create(key.GetStaticType(), Env)); if (ascendingType->IsTuple()) { const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType); @@ -1957,6 +1958,7 @@ TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRu callableBuilder.Add(key); callableBuilder.Add(ascending); callableBuilder.Add(hotkey); + callableBuilder.Add(cache); return TRuntimeNode(callableBuilder.Build(), false); }