Skip to content

Commit

Permalink
YQL-16402: Implement key cache for TKeepTopWrapper
Browse files: browse the repository at this point in the history
  • Loading branch information
igormunkin committed May 22, 2024
1 parent 303d901 commit 229df8b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
72 changes: 41 additions & 31 deletions ydb/library/yql/minikql/comp_nodes/mkql_sort.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -524,7 +524,8 @@ class TKeepTopWrapper : public TMutableComputationNode<TKeepTopWrapper> {
IComputationExternalNode* arg,
IComputationNode* key,
IComputationNode* ascending,
IComputationExternalNode* hotkey)
IComputationExternalNode* hotkey,
IComputationExternalNode* cache)
: TBaseComputation(mutables)
, Description(mutables, std::move(keySchemeTypes), comparators)
, Count(count)
Expand All @@ -534,6 +535,7 @@ class TKeepTopWrapper : public TMutableComputationNode<TKeepTopWrapper> {
, Key(key)
, Ascending(ascending)
, HotKey(hotkey)
, Cache(cache)
{}

NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
Expand All @@ -542,14 +544,27 @@ class TKeepTopWrapper : public TMutableComputationNode<TKeepTopWrapper> {
return ctx.HolderFactory.GetEmptyContainerLazy();
}

auto list = List->GetValue(ctx);
auto item = Item->GetValue(ctx);
auto arg = item;
Arg->SetValue(ctx, arg.Release());

auto list = List->GetValue(ctx);
auto cache = Cache->GetValue(ctx);
auto key = Key->GetValue(ctx);

const auto size = list.GetListLength();
if (cache.IsInvalid()) {
cache = ctx.HolderFactory.CreateDirectListHolder({});
const auto cacheSize = cache.GetListLength();
Y_UNUSED(cacheSize);
}

if (size < count) {
Cache->SetValue(ctx, ctx.HolderFactory.Append(cache.Release(), key.Release()));
return ctx.HolderFactory.Append(list.Release(), item.Release());
}

const auto& ascending = Ascending->GetValue(ctx);
auto hotkey = HotKey->GetValue(ctx);
auto hotkey_prepared = hotkey;

Expand All @@ -559,58 +574,50 @@ class TKeepTopWrapper : public TMutableComputationNode<TKeepTopWrapper> {

if (size == count) {
if (hotkey.IsInvalid()) {
TUnboxedValueVector keys;
keys.reserve(size);

const auto ptr = list.GetElements();
std::transform(ptr, ptr + size, std::back_inserter(keys), [&](const NUdf::TUnboxedValuePod item) {
Arg->SetValue(ctx, item);
return Key->GetValue(ctx);
});

const auto cacheptr = cache.GetElements();
TUnboxedValueVector keys(cacheptr, cacheptr + size);
auto keys_copy = keys;

Description.Prepare(ctx, keys);

const auto& ascending = Ascending->GetValue(ctx);
const auto max = std::max_element(keys.begin(), keys.end(), Description.MakeComparator<TUnboxedValueVector>(ascending));
hotkey_prepared = *max;
HotKey->SetValue(ctx, std::move(keys_copy[max - keys.begin()]));
}
}

const auto copy = item;
Arg->SetValue(ctx, item.Release());
auto key_prepared = Key->GetValue(ctx);
auto key_prepared = key;
Description.PrepareValue(ctx, key_prepared);

const auto& ascending = Ascending->GetValue(ctx);

if (Description.MakeComparator<TUnboxedValueVector>(ascending)(key_prepared, hotkey_prepared)) {
const auto reserve = std::max<ui64>(count << 1ULL, 1ULL << 8ULL);
if (size < reserve) {
return ctx.HolderFactory.Append(list.Release(), Arg->GetValue(ctx).Release());
Cache->SetValue(ctx, ctx.HolderFactory.Append(cache.Release(), key.Release()));
return ctx.HolderFactory.Append(list.Release(), item.Release());
}

TKeyPayloadPairVector items(1U, TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx)));
items.reserve(items.size() + size);
TGatherIterator itemsIt(const_cast<NUdf::TUnboxedValue*>(cache.GetElements()),
const_cast<NUdf::TUnboxedValue*>(list.GetElements()));

const auto ptr = list.GetElements();
std::transform(ptr, ptr + size, std::back_inserter(items), [&](const NUdf::TUnboxedValuePod item) {
Arg->SetValue(ctx, item);
return TKeyPayloadPair(Key->GetValue(ctx), Arg->GetValue(ctx));
});
TKeyPayloadPairVector items(itemsIt, itemsIt + size);
items.emplace_back(key, item);

Description.Prepare(ctx, items);
NYql::FastNthElement(items.begin(), items.begin() + count - 1U, items.end(), Description.MakeComparator<TKeyPayloadPairVector>(ascending));
items.resize(count);

NUdf::TUnboxedValue *inplace = nullptr;
const auto result = ctx.HolderFactory.CreateDirectArrayHolder(count, inplace); /// TODO: Use list holder.
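// TODO: Update hotkey here: HotKey is only assigned while it is invalid, so it is not recomputed from the keys that survive trimming to `count` items.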

NUdf::TUnboxedValue *listptr = nullptr;
NUdf::TUnboxedValue *cacheptr = nullptr;
const auto newList = ctx.HolderFactory.CreateDirectArrayHolder(count, listptr); /// TODO: Use list holder.
const auto newCache = ctx.HolderFactory.CreateDirectArrayHolder(count, cacheptr); /// TODO: Use list holder.
for (auto& item : items) {
*inplace++ = std::move(item.second);
*listptr++ = std::move(item.first);
*cacheptr++ = std::move(item.second);
}
return result;
Cache->SetValue(ctx, newCache);
return newList;
}

return list.Release();
Expand All @@ -625,6 +632,7 @@ class TKeepTopWrapper : public TMutableComputationNode<TKeepTopWrapper> {
DependsOn(Key);
DependsOn(Ascending);
Own(HotKey);
Own(Cache);
}

TCompareDescr Description;
Expand All @@ -635,6 +643,7 @@ class TKeepTopWrapper : public TMutableComputationNode<TKeepTopWrapper> {
IComputationNode* const Key;
IComputationNode* const Ascending;
IComputationExternalNode* const HotKey;
IComputationExternalNode* const Cache;
};

std::vector<std::tuple<NUdf::EDataSlot, bool, TType*>> GetKeySchemeTypes(TType* keyType, TType* ascType) {
Expand Down Expand Up @@ -759,7 +768,7 @@ IComputationNode* WrapTopSort(TCallable& callable, const TComputationNodeFactory
}

IComputationNode* WrapKeepTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() == 7, "Expected 7 args");
MKQL_ENSURE(callable.GetInputsCount() == 8, "Expected 8 args");

const auto keyNode = callable.GetInput(4);
const auto sortNode = callable.GetInput(5);
Expand All @@ -775,9 +784,10 @@ IComputationNode* WrapKeepTop(TCallable& callable, const TComputationNodeFactory
const auto ascending = LocateNode(ctx.NodeLocator, callable, 5);
const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 3);
const auto hotkey = LocateExternalNode(ctx.NodeLocator, callable, 6);
const auto cache = LocateExternalNode(ctx.NodeLocator, callable, 7);

auto comparators = MakeComparators(keyType, ascType->IsTuple());
return new TKeepTopWrapper(ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, count, list, item, itemArg, key, ascending, hotkey);
return new TKeepTopWrapper(ctx.Mutables, GetKeySchemeTypes(keyType, ascType), comparators, count, list, item, itemArg, key, ascending, hotkey, cache);
}

}
Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -1936,6 +1936,7 @@ TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRu

auto key = keyExtractor(itemArg);
const auto hotkey = Arg(key.GetStaticType());
const auto cache = Arg(TListType::Create(key.GetStaticType(), Env));

if (ascendingType->IsTuple()) {
const auto ascendingTuple = AS_TYPE(TTupleType, ascendingType);
Expand All @@ -1957,6 +1958,7 @@ TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRu
callableBuilder.Add(key);
callableBuilder.Add(ascending);
callableBuilder.Add(hotkey);
callableBuilder.Add(cache);
return TRuntimeNode(callableBuilder.Build(), false);
}

Expand Down

0 comments on commit 229df8b

Please sign in to comment.