Skip to content

Commit

Permalink
YQL-16402: Reimplement key extractor for TopSort node
Browse files Browse the repository at this point in the history
This patch fixes the issue with the excess calculations of the sorting
keys for TopSort computation node. While compiling TopSort runtime node,
the input source (either stream or flow) is transformed into {key: item}
mapping, and the resulting iterable is processed by both KeepTop and
TopSort nodes using the trivial key extractor, that obtains the key
value as the first component from the item of the resulting iterable.
The result yielded by TopSort is transformed back returning only the
second component from the item of the mapping being processed.

As a result of the changes described above, the <keyExtractor> callable is
invoked exactly once for each item of the given input iterable.

A static UDF module is used to test the fix. It provides an echo
function that increments a thread-local counter each time it is called.
When the processing is finished, the value of this counter is compared
with the number of items in the Stream/Flow.
  • Loading branch information
igormunkin committed May 30, 2024
1 parent f898f0b commit 5ded4db
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 3 deletions.
117 changes: 117 additions & 0 deletions ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ydb/library/yql/minikql/mkql_node.h>
#include <ydb/library/yql/minikql/mkql_program_builder.h>
#include <ydb/library/yql/minikql/mkql_string_util.h>
#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <ydb/library/yql/utils/sort.h>

Expand Down Expand Up @@ -519,5 +520,121 @@ Y_UNIT_TEST_SUITE(TMiniKQLSortTest) {
UNIT_ASSERT(copy == res);
}
}

// Checks that TopSort over a Stream/Flow invokes the user-supplied key
// extractor exactly once per input item. The extractor is routed through a
// counting UDF, and the number of UDF calls is compared with the input size.
Y_UNIT_TEST_SUITE(TMiniKQLStreamKeyExtractorCacheTest) {
    // Number of TEchoU64 invocations observed on the current thread. Each test
    // resets it to zero before building its graph.
    static thread_local size_t echoCounter;

    // Identity function over ui64 that bumps echoCounter on every call, so the
    // tests can count how many times the key extractor has been evaluated.
    SIMPLE_UDF(TEchoU64, ui64(ui64)) {
        Y_UNUSED(valueBuilder);
        echoCounter++;
        return args[0];
    }

    SIMPLE_MODULE(TCountCallsModule, TEchoU64);

    // Produces `count` pseudo-random keys uniformly distributed in [0, 99].
    //
    // An integer distribution is used on purpose: instantiating
    // std::uniform_real_distribution with an integral type (as this code used
    // to do) is undefined behavior and may yield degenerate, constant data,
    // which would make the ordering check below meaningless.
    //
    // NOTE: the engine is seeded from the wall clock, so the generated data
    // differs between runs.
    static std::vector<ui64> GenerateKeys(ui64 count) {
        std::uniform_int_distribution<ui64> dist(0ULL, 99ULL);
        std::default_random_engine rand;
        rand.seed(std::time(nullptr));

        std::vector<ui64> keys;
        keys.reserve(count);
        std::generate_n(std::back_inserter(keys), count, [&]() { return dist(rand); });
        return keys;
    }

    // TopSort applied to a Stream (Iterator over a list of single-element tuples).
    Y_UNIT_TEST(TestStreamTopSort) {
        echoCounter = 0;
        constexpr ui64 total = 999ULL;

        std::vector<ui64> test = GenerateKeys(total);

        // Rebuild the program builder on top of a registry clone that also
        // contains the statically registered counting module.
        TSetup<false> setup;
        NYql::NUdf::AddToStaticUdfRegistry<TCountCallsModule>();
        auto mutableRegistry = setup.FunctionRegistry->Clone();
        FillStaticModules(*mutableRegistry);
        setup.FunctionRegistry = mutableRegistry;
        setup.PgmBuilder.Reset(new TProgramBuilder(*setup.Env, *setup.FunctionRegistry));
        TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

        // Wrap every generated key into a one-element tuple literal.
        std::array<TRuntimeNode, total> data;
        std::transform(test.cbegin(), test.cend(), data.begin(), [&](const ui64& v) {
            return pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<ui64>(v)});
        });

        constexpr ui64 n = 17ULL;
        const auto echoUdf = pgmBuilder.Udf("CountCalls.EchoU64");
        const auto tupleType = pgmBuilder.NewTupleType({pgmBuilder.NewDataType(NUdf::TDataType<ui64>::Id)});
        // `false` requests descending order; the reference below uses std::greater accordingly.
        const auto ascending = pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<bool>(false)});
        const auto list = pgmBuilder.NewList(tupleType, data);
        // The sort key is the tuple's only element passed through the counting UDF.
        const auto extractor = [&pgmBuilder, echoUdf](TRuntimeNode item) {
            return pgmBuilder.NewTuple({pgmBuilder.Apply(echoUdf, {pgmBuilder.Nth(item, 0U)})});
        };
        const auto limit = pgmBuilder.NewDataLiteral<ui64>(n);
        const auto pgmRoot = pgmBuilder.TopSort(pgmBuilder.Iterator(list, {}), limit, ascending, extractor);
        const auto graph = setup.BuildGraph(pgmRoot);
        const auto& value = graph->GetValue();

        // Reference result: the n greatest keys in descending order.
        NYql::FastPartialSort(test.begin(), test.begin() + n, test.end(), std::greater<ui64>());
        test.resize(n);

        std::vector<ui64> res;
        res.reserve(n);
        for (NUdf::TUnboxedValue item; NUdf::EFetchStatus::Ok == value.Fetch(item);) {
            res.emplace_back(item.GetElement(0U).template Get<ui64>());
        }

        UNIT_ASSERT_VALUES_EQUAL(res.size(), n);
        UNIT_ASSERT(res == test);
        // The key extractor must have been evaluated once per input item.
        UNIT_ASSERT_VALUES_EQUAL(echoCounter, total);
    }

    // Same scenario as TestStreamTopSort, but TopSort is applied to a Flow
    // (ToFlow over the list) and the result is converted back with FromFlow.
    Y_UNIT_TEST(TestFlowTopSort) {
        echoCounter = 0;
        constexpr ui64 total = 999ULL;

        std::vector<ui64> test = GenerateKeys(total);

        // Rebuild the program builder on top of a registry clone that also
        // contains the statically registered counting module.
        TSetup<false> setup;
        NYql::NUdf::AddToStaticUdfRegistry<TCountCallsModule>();
        auto mutableRegistry = setup.FunctionRegistry->Clone();
        FillStaticModules(*mutableRegistry);
        setup.FunctionRegistry = mutableRegistry;
        setup.PgmBuilder.Reset(new TProgramBuilder(*setup.Env, *setup.FunctionRegistry));
        TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

        // Wrap every generated key into a one-element tuple literal.
        std::array<TRuntimeNode, total> data;
        std::transform(test.cbegin(), test.cend(), data.begin(), [&](const ui64& v) {
            return pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<ui64>(v)});
        });

        constexpr ui64 n = 17ULL;
        const auto echoUdf = pgmBuilder.Udf("CountCalls.EchoU64");
        const auto tupleType = pgmBuilder.NewTupleType({pgmBuilder.NewDataType(NUdf::TDataType<ui64>::Id)});
        // `false` requests descending order; the reference below uses std::greater accordingly.
        const auto ascending = pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<bool>(false)});
        const auto list = pgmBuilder.NewList(tupleType, data);
        // The sort key is the tuple's only element passed through the counting UDF.
        const auto extractor = [&pgmBuilder, echoUdf](TRuntimeNode item) {
            return pgmBuilder.NewTuple({pgmBuilder.Apply(echoUdf, {pgmBuilder.Nth(item, 0U)})});
        };
        const auto limit = pgmBuilder.NewDataLiteral<ui64>(n);
        const auto pgmRoot = pgmBuilder.FromFlow(pgmBuilder.TopSort(pgmBuilder.ToFlow(list), limit, ascending, extractor));
        const auto graph = setup.BuildGraph(pgmRoot);
        const auto& value = graph->GetValue();

        // Reference result: the n greatest keys in descending order.
        NYql::FastPartialSort(test.begin(), test.begin() + n, test.end(), std::greater<ui64>());
        test.resize(n);

        std::vector<ui64> res;
        res.reserve(n);
        for (NUdf::TUnboxedValue item; NUdf::EFetchStatus::Ok == value.Fetch(item);) {
            res.emplace_back(item.GetElement(0U).template Get<ui64>());
        }

        UNIT_ASSERT_VALUES_EQUAL(res.size(), n);
        UNIT_ASSERT(res == test);
        // The key extractor must have been evaluated once per input item.
        UNIT_ASSERT_VALUES_EQUAL(echoCounter, total);
    }
}
} // NMiniKQL
} // NKikimr
12 changes: 9 additions & 3 deletions ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1902,14 +1902,20 @@ TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntim

TRuntimeNode TProgramBuilder::TopSort(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
return FlatMap(Condense1(flow,
const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
return NewTuple({keyExtractor(item), item});
};

return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
[&](TRuntimeNode item) { return AsList(item); },
[this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
[&](TRuntimeNode item, TRuntimeNode state) {
return KeepTop(count, state, item, ascending, keyExtractor);
return KeepTop(count, state, item, ascending, getKey);
}
),
[&](TRuntimeNode list) { return TopSort(list, count, ascending, keyExtractor); }
[&](TRuntimeNode list) { return Map(TopSort(list, count, ascending, getKey), getItem); }
);
}

Expand Down

0 comments on commit 5ded4db

Please sign in to comment.