Skip to content

Commit

Permalink
YQL-16402: Reimplement key extractor for Top node
Browse files Browse the repository at this point in the history
This patch fixes the issue with the excess calculations of the sorting
keys for Top computation node. While compiling Top runtime node, the
input source (either stream or flow) is transformed into {key: item}
mapping, and the resulting iterable is processed by both KeepTop and Top
nodes using the trivial key extractor, that obtains the key value as the
first component from the item of the resulting iterable. The result
yielded by Top is transformed back returning only the second component
from the item of the mapping being processed.

As a result of the changed described above, <keyExtractor> callable is
invoked once for each item of the given input iterable.

A static UDF module is used to test the fix. It provides an echo
function, that increments TLS counter each time being called. When the
processing is finished, the value of this counter is compared with the
number of the items in the Stream/Flow.
  • Loading branch information
igormunkin committed May 30, 2024
1 parent 5ded4db commit 43b854b
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 3 deletions.
104 changes: 104 additions & 0 deletions ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,58 @@ Y_UNIT_TEST_SUITE(TMiniKQLStreamKeyExtractorCacheTest) {
UNIT_ASSERT_VALUES_EQUAL(echoCounter, total);
}

Y_UNIT_TEST(TestStreamTop) {
echoCounter = 0;
constexpr ui64 total = 999ULL;

std::uniform_real_distribution<ui64> urdist;
std::default_random_engine rand;
rand.seed(std::time(nullptr));

std::vector<ui64> test;
test.reserve(total);
std::generate_n(std::back_inserter(test), total, [&]() { return urdist(rand) % 100ULL; });

TSetup<false> setup;
NYql::NUdf::AddToStaticUdfRegistry<TCountCallsModule>();
auto mutableRegistry = setup.FunctionRegistry->Clone();
FillStaticModules(*mutableRegistry);
setup.FunctionRegistry = mutableRegistry;
setup.PgmBuilder.Reset(new TProgramBuilder(*setup.Env, *setup.FunctionRegistry));
TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

std::array<TRuntimeNode, total> data;
std::transform(test.cbegin(), test.cend(), data.begin(), [&](const ui64& v) {
return pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<ui64>(v)});
});

constexpr ui64 n = 17ULL;
const auto echoUdf = pgmBuilder.Udf("CountCalls.EchoU64");
const auto tupleType = pgmBuilder.NewTupleType({pgmBuilder.NewDataType(NUdf::TDataType<ui64>::Id)});
const auto ascending = pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<bool>(false)});
const auto list = pgmBuilder.NewList(tupleType, data);
const auto extractor = [&pgmBuilder, echoUdf](TRuntimeNode item) {
return pgmBuilder.NewTuple({ pgmBuilder.Apply(echoUdf, {pgmBuilder.Nth(item, 0U)})});
};
const auto limit = pgmBuilder.NewDataLiteral<ui64>(n);
const auto pgmRoot = pgmBuilder.Top(pgmBuilder.Iterator(list, {}), limit, ascending, extractor);
const auto graph = setup.BuildGraph(pgmRoot);
const auto& value = graph->GetValue();

NYql::FastPartialSort(test.begin(), test.begin() + n, test.end(), std::greater<ui64>());
test.resize(n);

std::vector<ui64> res;
res.reserve(n);
for (NUdf::TUnboxedValue item; NUdf::EFetchStatus::Ok == value.Fetch(item);) {
res.emplace_back(item.GetElement(0U).template Get<ui64>());
}

UNIT_ASSERT_VALUES_EQUAL(res.size(), n);
UNIT_ASSERT(res == test);
UNIT_ASSERT_VALUES_EQUAL(echoCounter, total);
}

Y_UNIT_TEST(TestFlowTopSort) {
echoCounter = 0;
constexpr ui64 total = 999ULL;
Expand Down Expand Up @@ -635,6 +687,58 @@ Y_UNIT_TEST_SUITE(TMiniKQLStreamKeyExtractorCacheTest) {
UNIT_ASSERT(res == test);
UNIT_ASSERT_VALUES_EQUAL(echoCounter, total);
}

Y_UNIT_TEST(TestFlowTop) {
echoCounter = 0;
constexpr ui64 total = 999ULL;

std::uniform_real_distribution<ui64> urdist;
std::default_random_engine rand;
rand.seed(std::time(nullptr));

std::vector<ui64> test;
test.reserve(total);
std::generate_n(std::back_inserter(test), total, [&]() { return urdist(rand) % 100ULL; });

TSetup<false> setup;
NYql::NUdf::AddToStaticUdfRegistry<TCountCallsModule>();
auto mutableRegistry = setup.FunctionRegistry->Clone();
FillStaticModules(*mutableRegistry);
setup.FunctionRegistry = mutableRegistry;
setup.PgmBuilder.Reset(new TProgramBuilder(*setup.Env, *setup.FunctionRegistry));
TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

std::array<TRuntimeNode, total> data;
std::transform(test.cbegin(), test.cend(), data.begin(), [&](const ui64& v) {
return pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<ui64>(v)});
});

constexpr ui64 n = 17ULL;
const auto echoUdf = pgmBuilder.Udf("CountCalls.EchoU64");
const auto tupleType = pgmBuilder.NewTupleType({pgmBuilder.NewDataType(NUdf::TDataType<ui64>::Id)});
const auto ascending = pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<bool>(false)});
const auto list = pgmBuilder.NewList(tupleType, data);
const auto extractor = [&pgmBuilder, echoUdf](TRuntimeNode item) {
return pgmBuilder.NewTuple({ pgmBuilder.Apply(echoUdf, {pgmBuilder.Nth(item, 0U)})});
};
const auto limit = pgmBuilder.NewDataLiteral<ui64>(n);
const auto pgmRoot = pgmBuilder.FromFlow(pgmBuilder.Top(pgmBuilder.ToFlow(list), limit, ascending, extractor));
const auto graph = setup.BuildGraph(pgmRoot);
const auto& value = graph->GetValue();

NYql::FastPartialSort(test.begin(), test.begin() + n, test.end(), std::greater<ui64>());
test.resize(n);

std::vector<ui64> res;
res.reserve(n);
for (NUdf::TUnboxedValue item; NUdf::EFetchStatus::Ok == value.Fetch(item);) {
res.emplace_back(item.GetElement(0U).template Get<ui64>());
}

UNIT_ASSERT_VALUES_EQUAL(res.size(), n);
UNIT_ASSERT(res == test);
UNIT_ASSERT_VALUES_EQUAL(echoCounter, total);
}
}
} // NMiniKQL
} // NKikimr
13 changes: 10 additions & 3 deletions ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1887,13 +1887,20 @@ TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callabl

TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
return NewTuple({keyExtractor(item), item});
};

return FlatMap(Condense1(flow,
return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
[&](TRuntimeNode item) { return AsList(item); },
[this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
[&](TRuntimeNode item, TRuntimeNode state) { return KeepTop(count, state, item, ascending, keyExtractor); }
[&](TRuntimeNode item, TRuntimeNode state) {
return KeepTop(count, state, item, ascending, getKey);
}
),
[&](TRuntimeNode list) { return Top(list, count, ascending, keyExtractor); }
[&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); }
);
}

Expand Down

0 comments on commit 43b854b

Please sign in to comment.