// Review summary of this patch (placed ahead of the first "diff --git" header;
// nothing inside the hunks below has been altered).
//
// 1) ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp
//    Adds two tests to TMiniKQLStreamKeyExtractorCacheTest: TestStreamTop and
//    TestFlowTop. Both:
//      - reset echoCounter, generate `total` (999) values, each reduced modulo
//        100, from an engine seeded with std::time(nullptr) -- so the input
//        differs from run to run;
//      - wrap every value in a one-element tuple and build a list from them;
//      - call pgmBuilder.Top with limit n = 17, an `ascending` tuple holding
//        `false`, and a key extractor that applies the "CountCalls.EchoU64" UDF
//        to element 0 of the item;
//      - compute the expected result with NYql::FastPartialSort using
//        std::greater and truncate it to n elements;
//      - assert the fetched result has n elements, equals the expected vector,
//        and that echoCounter == total.
//    TestStreamTop feeds Top with pgmBuilder.Iterator(list, {}); TestFlowTop
//    feeds it with ToFlow(list) and converts the result back with FromFlow.
//    NOTE(review): echoCounter is presumably incremented by the EchoU64 UDF, so
//    the last assertion would pin "key extractor runs once per input item" --
//    confirm against the CountCalls module.
//    NOTE(review): `res == test` is an element-wise comparison, so it relies on
//    Top emitting items in the same order FastPartialSort produces -- confirm
//    that Top (as opposed to TopSort) guarantees that order.
//
// 2) ydb/library/yql/minikql/mkql_program_builder.cpp, TProgramBuilder::Top
//    (flow/stream branch only; the rest of the function is outside this hunk).
//    Before: the caller's keyExtractor was passed directly to KeepTop and to
//    the final Top over the condensed list.
//    After: the input is first mapped to a tuple (keyExtractor(item), item);
//    KeepTop and the final Top use Nth(.., 0) as the key, and the final result
//    is mapped through Nth(.., 1) to recover the original items.
//
// NOTE(review): this text looks whitespace-collapsed, and template argument
// lists appear to be missing (e.g. `std::vector test;`, `std::greater()`,
// `NewDataLiteral(false)`) -- verify against the original patch before applying.
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp index 64f1ba1fe927..4373cd6cf2a2 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp @@ -584,6 +584,58 @@ Y_UNIT_TEST_SUITE(TMiniKQLStreamKeyExtractorCacheTest) { UNIT_ASSERT_VALUES_EQUAL(echoCounter, total); } + Y_UNIT_TEST(TestStreamTop) { + echoCounter = 0; + constexpr ui64 total = 999ULL; + + std::uniform_real_distribution urdist; + std::default_random_engine rand; + rand.seed(std::time(nullptr)); + + std::vector test; + test.reserve(total); + std::generate_n(std::back_inserter(test), total, [&]() { return urdist(rand) % 100ULL; }); + + TSetup setup; + NYql::NUdf::AddToStaticUdfRegistry(); + auto mutableRegistry = setup.FunctionRegistry->Clone(); + FillStaticModules(*mutableRegistry); + setup.FunctionRegistry = mutableRegistry; + setup.PgmBuilder.Reset(new TProgramBuilder(*setup.Env, *setup.FunctionRegistry)); + TProgramBuilder& pgmBuilder = *setup.PgmBuilder; + + std::array data; + std::transform(test.cbegin(), test.cend(), data.begin(), [&](const ui64& v) { + return pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral(v)}); + }); + + constexpr ui64 n = 17ULL; + const auto echoUdf = pgmBuilder.Udf("CountCalls.EchoU64"); + const auto tupleType = pgmBuilder.NewTupleType({pgmBuilder.NewDataType(NUdf::TDataType::Id)}); + const auto ascending = pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral(false)}); + const auto list = pgmBuilder.NewList(tupleType, data); + const auto extractor = [&pgmBuilder, echoUdf](TRuntimeNode item) { + return pgmBuilder.NewTuple({ pgmBuilder.Apply(echoUdf, {pgmBuilder.Nth(item, 0U)})}); + }; + const auto limit = pgmBuilder.NewDataLiteral(n); + const auto pgmRoot = pgmBuilder.Top(pgmBuilder.Iterator(list, {}), limit, ascending, extractor); + const auto graph = setup.BuildGraph(pgmRoot); + const auto& value = graph->GetValue(); + + 
NYql::FastPartialSort(test.begin(), test.begin() + n, test.end(), std::greater()); + test.resize(n); + + std::vector res; + res.reserve(n); + for (NUdf::TUnboxedValue item; NUdf::EFetchStatus::Ok == value.Fetch(item);) { + res.emplace_back(item.GetElement(0U).template Get()); + } + + UNIT_ASSERT_VALUES_EQUAL(res.size(), n); + UNIT_ASSERT(res == test); + UNIT_ASSERT_VALUES_EQUAL(echoCounter, total); + } + Y_UNIT_TEST(TestFlowTopSort) { echoCounter = 0; constexpr ui64 total = 999ULL; @@ -635,6 +687,58 @@ Y_UNIT_TEST_SUITE(TMiniKQLStreamKeyExtractorCacheTest) { UNIT_ASSERT(res == test); UNIT_ASSERT_VALUES_EQUAL(echoCounter, total); } + + Y_UNIT_TEST(TestFlowTop) { + echoCounter = 0; + constexpr ui64 total = 999ULL; + + std::uniform_real_distribution urdist; + std::default_random_engine rand; + rand.seed(std::time(nullptr)); + + std::vector test; + test.reserve(total); + std::generate_n(std::back_inserter(test), total, [&]() { return urdist(rand) % 100ULL; }); + + TSetup setup; + NYql::NUdf::AddToStaticUdfRegistry(); + auto mutableRegistry = setup.FunctionRegistry->Clone(); + FillStaticModules(*mutableRegistry); + setup.FunctionRegistry = mutableRegistry; + setup.PgmBuilder.Reset(new TProgramBuilder(*setup.Env, *setup.FunctionRegistry)); + TProgramBuilder& pgmBuilder = *setup.PgmBuilder; + + std::array data; + std::transform(test.cbegin(), test.cend(), data.begin(), [&](const ui64& v) { + return pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral(v)}); + }); + + constexpr ui64 n = 17ULL; + const auto echoUdf = pgmBuilder.Udf("CountCalls.EchoU64"); + const auto tupleType = pgmBuilder.NewTupleType({pgmBuilder.NewDataType(NUdf::TDataType::Id)}); + const auto ascending = pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral(false)}); + const auto list = pgmBuilder.NewList(tupleType, data); + const auto extractor = [&pgmBuilder, echoUdf](TRuntimeNode item) { + return pgmBuilder.NewTuple({ pgmBuilder.Apply(echoUdf, {pgmBuilder.Nth(item, 0U)})}); + }; + const auto limit = 
pgmBuilder.NewDataLiteral(n); + const auto pgmRoot = pgmBuilder.FromFlow(pgmBuilder.Top(pgmBuilder.ToFlow(list), limit, ascending, extractor)); + const auto graph = setup.BuildGraph(pgmRoot); + const auto& value = graph->GetValue(); + + NYql::FastPartialSort(test.begin(), test.begin() + n, test.end(), std::greater()); + test.resize(n); + + std::vector res; + res.reserve(n); + for (NUdf::TUnboxedValue item; NUdf::EFetchStatus::Ok == value.Fetch(item);) { + res.emplace_back(item.GetElement(0U).template Get()); + } + + UNIT_ASSERT_VALUES_EQUAL(res.size(), n); + UNIT_ASSERT(res == test); + UNIT_ASSERT_VALUES_EQUAL(echoCounter, total); + } } } // NMiniKQL } // NKikimr diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index a63b9ac5d494..7e4ad0cdb437 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1887,13 +1887,20 @@ TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callabl TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) { if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) { + const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); }; + const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); }; + const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) { + return NewTuple({keyExtractor(item), item}); + }; - return FlatMap(Condense1(flow, + return FlatMap(Condense1(Map(flow, cacheKeyExtractor), [&](TRuntimeNode item) { return AsList(item); }, [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral(false); }, - [&](TRuntimeNode item, TRuntimeNode state) { return KeepTop(count, state, item, ascending, keyExtractor); } + [&](TRuntimeNode item, TRuntimeNode state) { + return KeepTop(count, state, item, ascending, getKey); + } ), - [&](TRuntimeNode 
list) { return Top(list, count, ascending, keyExtractor); } + [&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); } ); }