From ccc0f87d04cb8d92b7e0ad0f18b168bc04a6750b Mon Sep 17 00:00:00 2001 From: Valerii Mironov Date: Wed, 4 Sep 2024 17:10:07 +0000 Subject: [PATCH] WIP --- ydb/core/protos/out/out.cpp | 10 +- ydb/core/protos/tx_datashard.proto | 66 ++ ydb/core/tablet_flat/flat_scan_lead.h | 15 +- ydb/core/tx/datashard/buffer_data.h | 76 ++ ydb/core/tx/datashard/build_index.cpp | 102 +- ydb/core/tx/datashard/datashard.h | 15 + ydb/core/tx/datashard/datashard_impl.h | 4 + .../tx/datashard/datashard_ut_build_index.cpp | 5 - .../datashard/datashard_ut_local_kmeans.cpp | 207 ++++ .../tx/datashard/datashard_ut_sample_k.cpp | 5 - ydb/core/tx/datashard/local_kmeans.cpp | 919 ++++++++++++++++++ ydb/core/tx/datashard/scan_common.cpp | 21 + ydb/core/tx/datashard/scan_common.h | 7 + ydb/core/tx/datashard/ut_local_kmeans/ya.make | 39 + ydb/core/tx/datashard/ya.make | 4 + ydb/library/services/services.proto | 1 + ydb/public/api/protos/out/out.cpp | 4 + 17 files changed, 1384 insertions(+), 116 deletions(-) create mode 100644 ydb/core/tx/datashard/buffer_data.h create mode 100644 ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp create mode 100644 ydb/core/tx/datashard/local_kmeans.cpp create mode 100644 ydb/core/tx/datashard/ut_local_kmeans/ya.make diff --git a/ydb/core/protos/out/out.cpp b/ydb/core/protos/out/out.cpp index ba99fa6e3820..b8f193fdfd6b 100644 --- a/ydb/core/protos/out/out.cpp +++ b/ydb/core/protos/out/out.cpp @@ -1,5 +1,3 @@ -#include - #include #include #include @@ -254,6 +252,10 @@ Y_DECLARE_OUT_SPEC(, NKikimrStat::TEvStatisticsResponse::EStatus, stream, value) stream << NKikimrStat::TEvStatisticsResponse::EStatus_Name(value); } -Y_DECLARE_OUT_SPEC(, Ydb::Table::IndexBuildState_State, stream, value) { - stream << IndexBuildState_State_Name(value); +Y_DECLARE_OUT_SPEC(, NKikimrIndexBuilder::EBuildStatus, stream, value) { + stream << NKikimrIndexBuilder::EBuildStatus_Name(value); +} + +Y_DECLARE_OUT_SPEC(, NKikimrTxDataShard::TEvLocalKMeansRequest_EState, stream, value) { + 
stream << NKikimrTxDataShard::TEvLocalKMeansRequest_EState_Name(value); } diff --git a/ydb/core/protos/tx_datashard.proto b/ydb/core/protos/tx_datashard.proto index 1db0740d9769..d9d59bf10228 100644 --- a/ydb/core/protos/tx_datashard.proto +++ b/ydb/core/protos/tx_datashard.proto @@ -17,6 +17,7 @@ import "ydb/core/protos/subdomains.proto"; import "ydb/core/protos/query_stats.proto"; import "ydb/public/api/protos/ydb_issue_message.proto"; import "ydb/public/api/protos/ydb_status_codes.proto"; +import "ydb/public/api/protos/ydb_table.proto"; import "ydb/library/yql/dq/actors/protos/dq_events.proto"; import "ydb/library/yql/dq/actors/protos/dq_stats.proto"; import "ydb/library/yql/dq/proto/dq_tasks.proto"; @@ -1486,6 +1487,71 @@ message TEvSampleKResponse { repeated bytes Rows = 11; } +message TEvLocalKMeansRequest { + optional uint64 Id = 1; + + optional uint64 TabletId = 2; + optional NKikimrProto.TPathID PathId = 3; + + optional uint64 SnapshotTxId = 4; + optional uint64 SnapshotStep = 5; + + optional uint64 SeqNoGeneration = 6; + optional uint64 SeqNoRound = 7; + + optional Ydb.Table.VectorIndexSettings Settings = 8; + + optional uint64 Seed = 9; + optional uint32 K = 10; + + enum EState { + SAMPLE = 1; + KMEANS = 2; + UPLOAD_MAIN_TO_TMP = 3; + UPLOAD_MAIN_TO_POSTING = 4; + UPLOAD_TMP_TO_TMP = 5; + UPLOAD_TMP_TO_POSTING = 6; + DONE = 7; + }; + optional EState Upload = 11; + // State != DONE + optional EState State = 12; + // State != KMEANS || DoneRounds < NeedsRounds + optional uint32 DoneRounds = 13; + optional uint32 NeedsRounds = 14; + + // id of parent cluster + optional uint32 Parent = 15; + // [Child ... 
Child + K] ids reserved for our clusters + optional uint32 Child = 16; + + optional string LevelName = 17; + optional string PostingName = 18; + + optional string EmbeddingColumn = 19; + repeated string DataColumns = 20; +} + +message TEvLocalKMeansProgressResponse { + optional uint64 Id = 1; + + optional uint64 TabletId = 2; + optional NKikimrProto.TPathID PathId = 3; + + optional uint64 RequestSeqNoGeneration = 4; + optional uint64 RequestSeqNoRound = 5; + + optional NKikimrIndexBuilder.EBuildStatus Status = 6; + repeated Ydb.Issue.IssueMessage Issues = 7; + + // TODO(mbkkt) implement slow-path (reliable-path) + // optional uint64 RowsDelta = 8; + // optional uint64 BytesDelta = 9; + + // optional TEvLocalKMeansRequest.EState State = 10; + // optional uint32 DoneRounds = 11; +} + message TEvCdcStreamScanRequest { message TLimits { optional uint32 BatchMaxBytes = 1 [default = 512000]; diff --git a/ydb/core/tablet_flat/flat_scan_lead.h b/ydb/core/tablet_flat/flat_scan_lead.h index 80b4d00f3e6f..f7dd11151923 100644 --- a/ydb/core/tablet_flat/flat_scan_lead.h +++ b/ydb/core/tablet_flat/flat_scan_lead.h @@ -9,9 +9,14 @@ namespace NTable { struct TLead { void To(TTagsRef tags, TArrayRef key, ESeek seek) + { + To(key, seek); + SetTags(tags); + } + + void To(TArrayRef key, ESeek seek) { Valid = true; - Tags.assign(tags.begin(), tags.end()); Relation = seek; Key = TSerializedCellVec(key); StopKey = { }; @@ -24,6 +29,10 @@ namespace NTable { StopKeyInclusive = inclusive; } + void SetTags(TTagsRef tags) { + Tags.assign(tags.begin(), tags.end()); + } + explicit operator bool() const noexcept { return Valid; @@ -34,12 +43,12 @@ namespace NTable { Valid = false; } - bool Valid = false; ESeek Relation = ESeek::Exact; + bool Valid = false; + bool StopKeyInclusive = true; TVector Tags; TSerializedCellVec Key; TSerializedCellVec StopKey; - bool StopKeyInclusive = true; }; } diff --git a/ydb/core/tx/datashard/buffer_data.h b/ydb/core/tx/datashard/buffer_data.h new file mode 100644 
index 000000000000..efe4d68ed6ef --- /dev/null +++ b/ydb/core/tx/datashard/buffer_data.h @@ -0,0 +1,76 @@ +#include "ydb/core/scheme/scheme_tablecell.h" +#include "ydb/core/tx/datashard/upload_stats.h" +#include "ydb/core/tx/tx_proxy/upload_rows.h" + +namespace NKikimr::NDataShard { + +using TTypes = NTxProxy::TUploadTypes; +using TRows = NTxProxy::TUploadRows; + +class TBufferData: public IStatHolder, public TNonCopyable { +public: + TBufferData() + : Rows{std::make_shared()} + { + } + + ui64 GetRows() const override final { + return Rows->size(); + } + + std::shared_ptr GetRowsData() const { + return Rows; + } + + ui64 GetBytes() const override final { + return ByteSize; + } + + void FlushTo(TBufferData& other) { + Y_ABORT_UNLESS(this != &other); + Y_ABORT_UNLESS(other.IsEmpty()); + other.Rows.swap(Rows); + other.ByteSize = std::exchange(ByteSize, 0); + other.LastKey = std::exchange(LastKey, {}); + } + + void Clear() { + Rows->clear(); + ByteSize = 0; + LastKey = {}; + } + + void AddRow(TSerializedCellVec&& key, TSerializedCellVec&& targetPk, TString&& targetValue) { + Rows->emplace_back(std::move(targetPk), std::move(targetValue)); + ByteSize += Rows->back().first.GetBuffer().size() + Rows->back().second.size(); + LastKey = std::move(key); + } + + bool IsEmpty() const { + return Rows->empty(); + } + + size_t Size() const { + return Rows->size(); + } + + bool IsReachLimits(const TUploadLimits& Limits) { + // TODO(mbkkt) why [0..BatchRowsLimit) but [0..BatchBytesLimit] + return Rows->size() >= Limits.BatchRowsLimit || ByteSize > Limits.BatchBytesLimit; + } + + auto&& ExtractLastKey() { + return std::move(LastKey); + } + + const auto& GetLastKey() const { + return LastKey; + } + +private: + std::shared_ptr Rows; + ui64 ByteSize = 0; + TSerializedCellVec LastKey; +}; + +} diff --git a/ydb/core/tx/datashard/build_index.cpp b/ydb/core/tx/datashard/build_index.cpp index ba562f83c4b8..fe75fac1a96e 100644 --- a/ydb/core/tx/datashard/build_index.cpp +++ 
b/ydb/core/tx/datashard/build_index.cpp @@ -2,6 +2,7 @@ #include "range_ops.h" #include "scan_common.h" #include "upload_stats.h" +#include "buffer_data.h" #include #include @@ -27,31 +28,6 @@ namespace NKikimr::NDataShard { #define LOG_W(stream) LOG_WARN_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, stream) #define LOG_E(stream) LOG_ERROR_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, stream) -using TColumnsTypes = THashMap; -using TTypes = NTxProxy::TUploadTypes; -using TRows = NTxProxy::TUploadRows; - -static TColumnsTypes GetAllTypes(const TUserTable& tableInfo) { - TColumnsTypes result; - - for (const auto& it : tableInfo.Columns) { - result[it.second.Name] = it.second.Type; - } - - return result; -} - -static void ProtoYdbTypeFromTypeInfo(Ydb::Type* type, const NScheme::TTypeInfo typeInfo) { - if (typeInfo.GetTypeId() == NScheme::NTypeIds::Pg) { - auto* typeDesc = typeInfo.GetTypeDesc(); - auto* pg = type->mutable_pg_type(); - pg->set_type_name(NPg::PgTypeNameFromTypeDesc(typeDesc)); - pg->set_oid(NPg::PgTypeIdFromTypeDesc(typeDesc)); - } else { - type->set_type_id((Ydb::Type::PrimitiveTypeId)typeInfo.GetTypeId()); - } -} - static std::shared_ptr BuildTypes(const TUserTable& tableInfo, const NKikimrIndexBuilder::TColumnBuildSettings& buildSettings) { auto types = GetAllTypes(tableInfo); @@ -119,74 +95,6 @@ bool BuildExtraColumns(TVector& cells, const NKikimrIndexBuilder::TColumn return true; } -class TBufferData: public IStatHolder, public TNonCopyable { -public: - TBufferData() - : Rows(new TRows) - { - } - - ui64 GetRows() const override final { - return Rows->size(); - } - - std::shared_ptr GetRowsData() const { - return Rows; - } - - ui64 GetBytes() const override final { - return ByteSize; - } - - void FlushTo(TBufferData& other) { - if (this == &other) { - return; - } - - Y_ABORT_UNLESS(other.Rows); - Y_ABORT_UNLESS(other.IsEmpty()); - - other.Rows.swap(Rows); - other.ByteSize = ByteSize; - other.LastKey = std::move(LastKey); - - 
Clear(); - } - - void Clear() { - Rows->clear(); - ByteSize = 0; - LastKey = {}; - } - - void AddRow(TSerializedCellVec&& key, TSerializedCellVec&& targetPk, TString&& targetValue) { - Rows->emplace_back(std::move(targetPk), std::move(targetValue)); - ByteSize += Rows->back().first.GetBuffer().size() + Rows->back().second.size(); - LastKey = std::move(key); - } - - bool IsEmpty() const { - return Rows->empty(); - } - - bool IsReachLimits(const TUploadLimits& Limits) { - return Rows->size() >= Limits.BatchRowsLimit || ByteSize > Limits.BatchBytesLimit; - } - - void ExtractLastKey(TSerializedCellVec& out) { - out = std::move(LastKey); - } - - const TSerializedCellVec& GetLastKey() const { - return LastKey; - } - -private: - std::shared_ptr Rows; - ui64 ByteSize = 0; - TSerializedCellVec LastKey; -}; - template class TBuildScanUpload: public TActor>, public NTable::IScan { using TThis = TBuildScanUpload; @@ -382,11 +290,7 @@ class TBuildScanUpload: public TActor>, public NTable << " WriteBuf empty: " << WriteBuf.IsEmpty() << " " << Debug()); - if (ReadBuf.IsEmpty()) { - return EScan::Feed; - } - - if (WriteBuf.IsEmpty()) { + if (!ReadBuf.IsEmpty() && WriteBuf.IsEmpty()) { ReadBuf.FlushTo(WriteBuf); Upload(); } @@ -433,7 +337,7 @@ class TBuildScanUpload: public TActor>, public NTable if (UploadStatus.IsSuccess()) { Stats.Aggr(&WriteBuf); - WriteBuf.ExtractLastKey(LastUploadedKey); + LastUploadedKey = WriteBuf.ExtractLastKey(); //send progress TAutoPtr progress = new TEvDataShard::TEvBuildIndexProgressResponse; diff --git a/ydb/core/tx/datashard/datashard.h b/ydb/core/tx/datashard/datashard.h index 8c305a3a5ceb..a574bcd2d900 100644 --- a/ydb/core/tx/datashard/datashard.h +++ b/ydb/core/tx/datashard/datashard.h @@ -332,6 +332,9 @@ struct TEvDataShard { EvSampleKRequest, EvSampleKResponse, + EvLocalKMeansRequest, + EvLocalKMeansProgressResponse, + EvEnd }; @@ -1454,6 +1457,18 @@ struct TEvDataShard { TEvDataShard::EvSampleKResponse> { }; + struct TEvLocalKMeansRequest + : 
public TEventPB { + }; + + struct TEvLocalKMeansProgressResponse + : public TEventPB { + }; + struct TEvKqpScan : public TEventPB -template <> -inline void Out(IOutputStream& o, NKikimrIndexBuilder::EBuildStatus status) { - o << NKikimrIndexBuilder::EBuildStatus_Name(status); -} - namespace NKikimr { using namespace NKikimr::NDataShard::NKqpHelpers; diff --git a/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp b/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp new file mode 100644 index 000000000000..81cd4ff7549e --- /dev/null +++ b/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp @@ -0,0 +1,207 @@ +#include "defs.h" +#include "datashard_ut_common_kqp.h" +#include "upload_stats.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace NKikimr { +using namespace Tests; +using Ydb::Table::VectorIndexSettings; +using namespace NTableIndex::NTableVectorKmeansTreeIndex; + +static ui64 sId = 1; +static constexpr const char* kMainTable = "/Root/table-main"; +static constexpr const char* kLevelTable = "/Root/table-level"; +static constexpr const char* kPostingTable = "/Root/table-posting"; + +Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { + static std::tuple DoLocalKMeans(Tests::TServer::TPtr server, TActorId sender, ui64 seed, ui64 k, + VectorIndexSettings::VectorType type, VectorIndexSettings::Distance distance) { + auto id = sId++; + auto& runtime = *server->GetRuntime(); + auto snapshot = CreateVolatileSnapshot(server, {kMainTable}); + auto datashards = GetTableShards(server, sender, kMainTable); + TTableId tableId = ResolveTableId(server, sender, kMainTable); + + TString err; + + for (auto tid : datashards) { + auto ev1 = std::make_unique(); + auto ev2 = std::make_unique(); + auto fill = [&](std::unique_ptr& ev) { + auto& rec = ev->Record; + rec.SetId(1); + + rec.SetSeqNoGeneration(id); + rec.SetSeqNoRound(1); + + rec.SetTabletId(tid); + PathIdFromPathId(tableId.PathId, rec.MutablePathId()); + + 
rec.SetSnapshotTxId(snapshot.TxId); + rec.SetSnapshotStep(snapshot.Step); + + VectorIndexSettings settings; + settings.set_vector_dimension(2); + settings.set_vector_type(type); + settings.set_distance(distance); + *rec.MutableSettings() = settings; + + rec.SetK(k); + rec.SetSeed(seed); + + rec.SetState(NKikimrTxDataShard::TEvLocalKMeansRequest::SAMPLE); + rec.SetUpload(NKikimrTxDataShard::TEvLocalKMeansRequest::UPLOAD_MAIN_TO_POSTING); + + rec.SetDoneRounds(0); + rec.SetNeedsRounds(3); + + rec.SetParent(0); + rec.SetChild(1); + + rec.SetEmbeddingColumn("embedding"); + rec.AddDataColumns("data"); + + rec.SetLevelName(kLevelTable); + rec.SetPostingName(kPostingTable); + }; + fill(ev1); + fill(ev2); + + runtime.SendToPipe(tid, sender, ev1.release(), 0, GetPipeConfigWithRetries()); + runtime.SendToPipe(tid, sender, ev2.release(), 0, GetPipeConfigWithRetries()); + + TAutoPtr handle; + auto reply = runtime.GrabEdgeEventRethrow(handle); + + NYql::TIssues issues; + NYql::IssuesFromMessage(reply->Record.GetIssues(), issues); + UNIT_ASSERT_EQUAL_C(reply->Record.GetStatus(), NKikimrIndexBuilder::EBuildStatus::DONE, issues.ToOneLineString()); + } + + auto level = ReadShardedTable(server, kLevelTable); + auto posting = ReadShardedTable(server, kPostingTable); + return {std::move(level), std::move(posting)}; + } + + static void DropTable(Tests::TServer::TPtr server, TActorId sender, const char* name) { + ui64 txId = AsyncDropTable(server, sender, "/Root", name); + WaitTxNotification(server, sender, txId); + } + + static void CreateLevelTable(Tests::TServer::TPtr server, TActorId sender, TShardedTableOptions options) { + options.AllowSystemColumnNames(true); + options.Columns({ + {LevelTable_ParentIdColumn, "Uint32", true, true}, + {LevelTable_IdColumn, "Uint32", true, true}, + {LevelTable_EmbeddingColumn, "String", false, true}, + }); + CreateShardedTable(server, sender, "/Root", "table-level", options); + } + + static void CreatePostingTable(Tests::TServer::TPtr server, 
TActorId sender, TShardedTableOptions options) { + options.AllowSystemColumnNames(true); + options.Columns({ + {PostingTable_ParentIdColumn, "Uint32", true, true}, + {"key", "Uint32", true, true}, + {"data", "String", false, false}, + }); + CreateShardedTable(server, sender, "/Root", "table-posting", options); + } + + Y_UNIT_TEST (RunScan) { + TPortManager pm; + TServerSettings serverSettings(pm.GetPort(2134)); + serverSettings.SetDomainName("Root"); + + Tests::TServer::TPtr server = new TServer(serverSettings); + auto& runtime = *server->GetRuntime(); + auto sender = runtime.AllocateEdgeActor(); + + runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG); + + InitRoot(server, sender); + + TShardedTableOptions options; + options.EnableOutOfOrder(true); // TODO(mbkkt) what is it? + options.Shards(1); // single shard request + + options.AllowSystemColumnNames(false); + options.Columns({ + {"key", "Uint32", true, true}, + {"embedding", "String", false, false}, + {"data", "String", false, false}, + }); + options.EnableOutOfOrder(true); // TODO(mbkkt) what is it? 
+ + CreateShardedTable(server, sender, "/Root", "table-main", options); + + // Upsert some initial values + ExecSQL(server, sender, R"( + UPSERT INTO `/Root/table-main` + (key, embedding, data) + VALUES )" + "(1, \"\x30\x30\3\", \"one\")," + "(2, \"\x31\x31\3\", \"two\")," + "(3, \"\x32\x32\3\", \"three\")," + "(4, \"\x65\x65\3\", \"four\")," + "(5, \"\x75\x75\3\", \"five\");"); + + auto create = [&] { + CreateLevelTable(server, sender, options); + CreatePostingTable(server, sender, options); + }; + create(); + auto recreate = [&] { + DropTable(server, sender, "table-level"); + DropTable(server, sender, "table-posting"); + create(); + }; + + ui64 seed, k; + + seed = 0; + for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { + k = 2; + auto [level, posting] = DoLocalKMeans(server, sender, seed, k, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 0, __ydb_id = 1, __ydb_embedding = mm\3\n" + "__ydb_parent = 0, __ydb_id = 2, __ydb_embedding = 11\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 1, key = 4, data = four\n" + "__ydb_parent = 1, key = 5, data = five\n" + "__ydb_parent = 2, key = 1, data = one\n" + "__ydb_parent = 2, key = 2, data = two\n" + "__ydb_parent = 2, key = 3, data = three\n"); + recreate(); + } + + seed = 111; + for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { + k = 2; + auto [level, posting] = DoLocalKMeans(server, sender, seed, k, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 0, __ydb_id = 1, __ydb_embedding = 11\3\n" + "__ydb_parent = 0, __ydb_id = 2, __ydb_embedding = mm\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 1, key = 1, data = one\n" + "__ydb_parent = 1, key = 2, data = two\n" + "__ydb_parent = 1, key = 3, data = three\n" + "__ydb_parent = 2, key = 4, data = four\n" + "__ydb_parent = 2, key = 
5, data = five\n"); + recreate(); + } + } +} + +} // namespace NKikimr diff --git a/ydb/core/tx/datashard/datashard_ut_sample_k.cpp b/ydb/core/tx/datashard/datashard_ut_sample_k.cpp index 7a0e8b1a9370..204b9b29078e 100644 --- a/ydb/core/tx/datashard/datashard_ut_sample_k.cpp +++ b/ydb/core/tx/datashard/datashard_ut_sample_k.cpp @@ -12,11 +12,6 @@ #include -template <> -inline void Out(IOutputStream& o, NKikimrIndexBuilder::EBuildStatus status) { - o << NKikimrIndexBuilder::EBuildStatus_Name(status); -} - namespace NKikimr { static ui64 sId = 1; diff --git a/ydb/core/tx/datashard/local_kmeans.cpp b/ydb/core/tx/datashard/local_kmeans.cpp new file mode 100644 index 000000000000..56a8b30650bd --- /dev/null +++ b/ydb/core/tx/datashard/local_kmeans.cpp @@ -0,0 +1,919 @@ +#include "datashard_impl.h" +#include "range_ops.h" +#include "scan_common.h" +#include "upload_stats.h" +#include "buffer_data.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include + +#define LOG_T(stream) LOG_TRACE_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, stream) +#define LOG_D(stream) LOG_DEBUG_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, stream) +#define LOG_N(stream) LOG_NOTICE_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, stream) +#define LOG_E(stream) LOG_ERROR_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, stream) + +template +Y_PURE_FUNCTION TTriWayDotProduct CosineImpl(const float* lhs, const float* rhs, size_t length) noexcept { + auto r = TriWayDotProduct(lhs, rhs, length); + return {static_cast(r.LL), static_cast(r.LR), static_cast(r.RR)}; +} + +template +Y_PURE_FUNCTION TTriWayDotProduct CosineImpl(const i8* lhs, const i8* rhs, size_t length) noexcept { + const auto ll = DotProduct(lhs, lhs, length); + const auto lr = DotProduct(lhs, rhs, length); + const auto rr = DotProduct(rhs, rhs, length); + return {static_cast(ll), static_cast(lr), 
static_cast(rr)}; +} + +template +Y_PURE_FUNCTION TTriWayDotProduct CosineImpl(const ui8* lhs, const ui8* rhs, size_t length) noexcept { + const auto ll = DotProduct(lhs, lhs, length); + const auto lr = DotProduct(lhs, rhs, length); + const auto rr = DotProduct(rhs, rhs, length); + return {static_cast(ll), static_cast(lr), static_cast(rr)}; +} + +namespace NKikimr::NDataShard { + +TTableRange CreateRangeFrom(const TUserTable& table, ui32 parent) { + if (parent == 0) { + return table.GetTableRange(); + } + const auto parentCell = TCell::Make(parent); + const TTableRange parentRange{{&parentCell, 1}}; + return Intersect(table.KeyColumnTypes, table.GetTableRange(), parentRange); +} + +NTable::TLead CreateLeadFrom(const TTableRange& range) { + NTable::TLead lead; + if (range.From) { + lead.To(range.From, range.InclusiveFrom ? NTable::ESeek::Lower : NTable::ESeek::Upper); + } else { + lead.To({}, NTable::ESeek::Lower); + } + if (range.To) { + lead.Until(range.To, range.InclusiveTo); + } + return lead; +} + +// TODO(mbkkt) separate implementation for bit +template +struct TMetric { + using TCoord = T; + // TODO(mbkkt) maybe compute floating sum in double? 
Needs benchmark + using TSum = std::conditional_t, T, int64_t>; + + ui32 Dimensions = 0; + + bool IsExpectedSize(TArrayRef data) const noexcept { + return data.size() == 1 + sizeof(TCoord) * Dimensions; + } + + auto GetCoords(const char* coords) { + return std::span{reinterpret_cast(coords), Dimensions}; + } + + auto GetData(char* data) { + return std::span{reinterpret_cast(data), Dimensions}; + } + + void Fill(TString& d, TSum* embedding, ui64& c) { + const auto count = static_cast(std::exchange(c, 0)); + auto data = GetData(d.MutRef().data()); + for (auto& coord : data) { + coord = *embedding / count; + *embedding++ = 0; + } + } +}; + +template +struct TCosineSimilarity: TMetric { + using TCoord = TMetric::TCoord; + using TSum = TMetric::TSum; + // double used to avoid precision issues + using TRes = double; + + static TRes Init() { + return std::numeric_limits::max(); + } + + auto Distance(const char* cluster, const char* embedding) const noexcept { + const auto r = CosineImpl(reinterpret_cast(cluster), reinterpret_cast(embedding), this->Dimensions); + // sqrt(ll) * sqrt(rr) computed instead of sqrt(ll * rr) to avoid precision issues + const auto norm = std::sqrt(r.LL) * std::sqrt(r.RR); + const TRes similarity = norm != 0 ? 
static_cast(r.LR) / static_cast(norm) : 0; + return -similarity; + } +}; + +template +struct TL1Distance: TMetric { + using TCoord = TMetric::TCoord; + using TSum = TMetric::TSum; + using TRes = std::conditional_t, T, ui64>; + + static TRes Init() { + return std::numeric_limits::max(); + } + + auto Distance(const char* cluster, const char* embedding) const noexcept { + const auto distance = L1Distance(reinterpret_cast(cluster), reinterpret_cast(embedding), this->Dimensions); + return distance; + } +}; + +template +struct TL2Distance: TMetric { + using TCoord = TMetric::TCoord; + using TSum = TMetric::TSum; + using TRes = std::conditional_t, T, ui64>; + + static TRes Init() { + return std::numeric_limits::max(); + } + + auto Distance(const char* cluster, const char* embedding) const noexcept { + const auto distance = L2SqrDistance(reinterpret_cast(cluster), reinterpret_cast(embedding), this->Dimensions); + return distance; + } +}; + +template +struct TMaxInnerProductSimilarity: TMetric { + using TCoord = TMetric::TCoord; + using TSum = TMetric::TSum; + using TRes = std::conditional_t, T, i64>; + + static TRes Init() { + return std::numeric_limits::max(); + } + + auto Distance(const char* cluster, const char* embedding) const noexcept { + const TRes similarity = DotProduct(reinterpret_cast(cluster), reinterpret_cast(embedding), this->Dimensions); + return -similarity; + } +}; + +template +struct TCalculation: TMetric { + using TEmbedding = std::vector; + + struct TAggregated { + TEmbedding Cluster; + ui64 Count = 0; + }; + + ui32 FindClosest(std::span clusters, const char* embedding, std::span aggregated) { + auto min = TMetric::Init(); + ui32 closest = std::numeric_limits::max(); + for (size_t i = 0; const auto& cluster : clusters) { + auto distance = TMetric::Distance(cluster.data(), embedding); + if (distance < min || (distance == min && aggregated[i].Count < aggregated[closest].Count)) { + closest = i; + } + ++i; + } + return closest; + } +}; + +template +class 
TLocalKMeansScan final: public TActor>, public NTable::IScan, private TCalculation { +protected: + using EState = NKikimrTxDataShard::TEvLocalKMeansRequest; + using TTBase = TActor>; + + ui32 Parent = 0; + ui32 Child = 0; + + ui32 Round = 0; + ui32 MaxRounds = 0; + + ui32 K = 0; + + EState::EState State; + EState::EState UploadState; + + IDriver* Driver = nullptr; + + TLead Lead; + + // Sample + ui64 ReadRows = 0; + ui64 ReadBytes = 0; + + ui64 MaxProbability = std::numeric_limits::max(); + TReallyFastRng32 Rng; + + struct TProbability { + ui64 P = 0; + ui64 I = 0; + + bool operator==(const TProbability&) const noexcept = default; + auto operator<=>(const TProbability&) const noexcept = default; + }; + + std::vector MaxRows; + std::vector Clusters; + + // KMeans + std::vector::TAggregated> Aggregated; + + // Upload + std::shared_ptr TargetTypes; + std::shared_ptr NextTypes; + + TString TargetTable; + TString NextTable; + + TBufferData ReadBuf; + TBufferData WriteBuf; + + NTable::TPos EmbeddingPos = 0; + NTable::TPos DataPos = 1; + + ui32 RetryCount = 0; + + TActorId Uploader; + TUploadLimits Limits; + + NTable::TTag KMeansScan; + TTags UploadScan; + + TUploadStatus UploadStatus; + + // Response + TActorId ResponseActorId; + TAutoPtr Response; + +public: + static constexpr NKikimrServices::TActivity::EType ActorActivityType() { + return NKikimrServices::TActivity::LOCAL_KMEANS_SCAN_ACTOR; + } + + TLocalKMeansScan(const TUserTable& table, TLead&& lead, NKikimrTxDataShard::TEvLocalKMeansRequest& request, const TActorId& responseActorId, TAutoPtr&& response) + : TTBase::TActor{&TTBase::TThis::StateWork} + , Parent{request.GetParent()} + , Child{request.GetChild()} + , MaxRounds{request.GetNeedsRounds() - request.GetDoneRounds()} + , K{request.GetK()} + , State{request.GetState()} + , UploadState{request.GetUpload()} + , Lead{std::move(lead)} + , Rng{request.GetSeed()} + , TargetTable{request.GetLevelName()} + , NextTable{request.GetPostingName()} + , 
ResponseActorId{responseActorId} + , Response{std::move(response)} { + this->Dimensions = request.GetSettings().vector_dimension(); + const auto& embedding = request.GetEmbeddingColumn(); + const auto& data = request.GetDataColumns(); + // scan tags + { + auto tags = GetAllTags(table); + KMeansScan = tags.at(embedding); + UploadScan.reserve(1 + data.size()); + if (auto it = std::find(data.begin(), data.end(), embedding); it != data.end()) { + EmbeddingPos = it - data.begin(); + DataPos = 0; + } else { + UploadScan.push_back(KMeansScan); + } + for (const auto& column : data) { + UploadScan.push_back(tags.at(column)); + } + } + // upload types + Ydb::Type type; + if (State <= EState::KMEANS) { + TargetTypes = std::make_shared(3); + type.set_type_id(Ydb::Type::UINT32); + (*TargetTypes)[0] = {NTableIndex::NTableVectorKmeansTreeIndex::LevelTable_ParentIdColumn, type}; + (*TargetTypes)[1] = {NTableIndex::NTableVectorKmeansTreeIndex::LevelTable_IdColumn, type}; + type.set_type_id(Ydb::Type::STRING); + (*TargetTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::LevelTable_EmbeddingColumn, type}; + } + { + auto types = GetAllTypes(table); + + NextTypes = std::make_shared(); + NextTypes->reserve(1 + 1 + std::min(table.KeyColumnTypes.size() + data.size(), types.size())); + + type.set_type_id(Ydb::Type::UINT32); + NextTypes->emplace_back(NTableIndex::NTableVectorKmeansTreeIndex::PostingTable_ParentIdColumn, type); + + auto addType = [&](const auto& column) { + auto it = types.find(column); + Y_ABORT_UNLESS(it != types.end()); + ProtoYdbTypeFromTypeInfo(&type, it->second); + NextTypes->emplace_back(it->first, type); + types.erase(it); + }; + for (const auto& column : table.KeyColumnIds) { + addType(table.Columns.at(column).Name); + } + switch (UploadState) { + case EState::UPLOAD_MAIN_TO_TMP: + case EState::UPLOAD_TMP_TO_TMP: + addType(embedding); + [[fallthrough]]; + case EState::UPLOAD_MAIN_TO_POSTING: + case EState::UPLOAD_TMP_TO_POSTING: { + for (const auto& column : 
data) {
+                        addType(column);
+                    }
+                } break;
+                default:
+                    Y_UNREACHABLE();
+            }
+        }
+    }
+
+    ~TLocalKMeansScan() final = default;
+
+    // Registers this scan as an actor in the caller's mailbox, remembers the driver
+    // and asks the executor to start feeding rows.
+    TInitialState Prepare(IDriver* driver, TIntrusiveConstPtr<TScheme>) noexcept final {
+        TActivationContext::AsActorContext().RegisterWithSameMailbox(this);
+        LOG_T("Prepare " << Debug());
+
+        Driver = driver;
+        return {EScan::Feed, {}};
+    }
+
+    // Chooses the next pass over the table depending on State:
+    //  - upload state: drains ReadBuf/WriteBuf and finishes once both are empty;
+    //  - SAMPLE: the first call (seq == 0) starts the sampling pass, the next call
+    //    switches to KMEANS;
+    //  - KMEANS: recomputes clusters and either starts another round or, once
+    //    Round == MaxRounds, uploads the clusters and switches to the upload state.
+    EScan Seek(TLead& lead, ui64 seq) noexcept final {
+        LOG_T("Seek " << Debug());
+        if (State == UploadState) {
+            if (!WriteBuf.IsEmpty()) {
+                // An upload is still in flight; Handle() touches the driver when it completes.
+                return EScan::Sleep;
+            }
+            if (!ReadBuf.IsEmpty()) {
+                ReadBuf.FlushTo(WriteBuf);
+                Upload(false);
+                return EScan::Sleep;
+            }
+            if (UploadStatus.IsNone()) {
+                UploadStatus.StatusCode = Ydb::StatusIds::SUCCESS;
+            }
+            return EScan::Final;
+        }
+
+        if (State == EState::SAMPLE) {
+            lead = Lead;
+            lead.SetTags({&KMeansScan, 1});
+            if (seq == 0) {
+                return EScan::Feed;
+            }
+            State = EState::KMEANS;
+            if (!InitAggregated()) {
+                // We don't need to do anything,
+                // because this datashard doesn't have valid embeddings for this parent.
+                // Nothing was uploaded, so mark the status as success explicitly:
+                // otherwise Finish() sees a non-successful status and reports BUILD_ERROR.
+                if (UploadStatus.IsNone()) {
+                    UploadStatus.StatusCode = Ydb::StatusIds::SUCCESS;
+                }
+                return EScan::Final;
+            }
+            return EScan::Feed;
+        }
+
+        Y_ASSERT(State == EState::KMEANS);
+        RecomputeClusters();
+        if (Round == MaxRounds) {
+            // Last pass: Lead is not needed afterwards, so it is moved out.
+            lead = std::move(Lead);
+            lead.SetTags(UploadScan);
+
+            // UploadSample() must run before State changes: Upload() keeps the
+            // current target table while State is still KMEANS.
+            UploadSample();
+            State = UploadState;
+        } else {
+            lead = Lead;
+            lead.SetTags({&KMeansScan, 1});
+            ++Round;
+        }
+        return EScan::Feed;
+    }
+
+    // Dispatches a scanned row to the handler of the current State.
+    EScan Feed(TArrayRef<const TCell> key, const TRow& row) noexcept final {
+        LOG_T("Feed " << Debug());
+        switch (State) {
+            case EState::SAMPLE:
+                return FeedSample(row);
+            case EState::KMEANS:
+                return FeedKMeans(row);
+            case EState::UPLOAD_MAIN_TO_TMP:
+                return FeedUploadMain2Tmp(key, row);
+            case EState::UPLOAD_MAIN_TO_POSTING:
+                return FeedUploadMain2Posting(key, row);
+            case EState::UPLOAD_TMP_TO_TMP:
+                return FeedUploadTmp2Tmp(key, row);
+            case EState::UPLOAD_TMP_TO_POSTING:
+                return FeedUploadTmp2Posting(key, row);
+            case EState::DONE:
+                Y_UNREACHABLE();
+        }
+    }
+
+    // Stops the uploader (if any), fills the response status from the abort reason
+    // and UploadStatus, sends the response to ResponseActorId and terminates the actor.
+    TAutoPtr<IDestructable> Finish(EAbort abort) noexcept final {
+        LOG_T("Finish " << Debug());
+        auto ctx = TActivationContext::AsActorContext().MakeFor(this->SelfId());
+
+        if (Uploader) {
+            TAutoPtr<TEvents::TEvPoisonPill> poison = new TEvents::TEvPoisonPill;
+            ctx.Send(Uploader, poison.Release());
+            Uploader = {};
+        }
+
+        auto& record = Response->Record;
+        if (abort != EAbort::None) {
+            record.SetStatus(NKikimrIndexBuilder::EBuildStatus::ABORTED);
+        } else if (UploadStatus.IsSuccess()) {
+            record.SetStatus(NKikimrIndexBuilder::EBuildStatus::DONE);
+        } else {
+            record.SetStatus(NKikimrIndexBuilder::EBuildStatus::BUILD_ERROR);
+        }
+        NYql::IssuesToMessage(UploadStatus.Issues, record.MutableIssues());
+        ctx.Send(ResponseActorId, Response.Release());
+
+        // Handle() treats a null Driver as "scan already finished".
+        Driver = nullptr;
+        this->PassAway();
+        return nullptr;
+    }
+
+    // Writes the same text as Debug() into `out`.
+    void Describe(IOutputStream& out) const noexcept final {
+        out << Debug();
+    }
+
+    // Human-readable snapshot of the scan: request id (while Response is still owned),
+    // state, round counters and buffer sizes.
+    TString Debug() const {
+        auto builder = TStringBuilder() << " TLocalKMeansScan";
+        if (Response) {
+            auto& r = Response->Record;
+            builder << " Id: " << r.GetId();
+        }
+        return builder << " State: " << State
+                       << " Round: " << Round
+                       << " MaxRounds: " << MaxRounds
+                       << " ReadBuf size: " << ReadBuf.Size()
+                       << " WriteBuf size: " << WriteBuf.Size()
+                       << " ";
+    }
+
+    // On a page fault, flushes already-read rows if no upload is in flight.
+    EScan PageFault() noexcept final {
+        LOG_T("PageFault " << Debug());
+
+        if (!ReadBuf.IsEmpty() && WriteBuf.IsEmpty()) {
+            ReadBuf.FlushTo(WriteBuf);
+            Upload(false);
+        }
+
+        return EScan::Feed;
+    }
+
+private:
+    // Actor state function: handles upload responses and retry wakeups,
+    // logs anything else as an error.
+    STFUNC(StateWork) {
+        switch (ev->GetTypeRewrite()) {
+            HFunc(TEvTxUserProxy::TEvUploadRowsResponse, Handle);
+            CFunc(TEvents::TSystem::Wakeup, HandleWakeup);
+            default:
+                LOG_E("TLocalKMeansScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: " << ev->ToString() << " " << Debug());
+        }
+    }
+
+    // Retry timer fired: re-sends the pending WriteBuf as a retry.
+    void HandleWakeup(const NActors::TActorContext& /*ctx*/) {
+        LOG_T("Retry upload " << Debug());
+
+        if (!WriteBuf.IsEmpty()) {
+            Upload(true);
+        }
+    }
+
+    // Processes the uploader's response:
+    //  - success: clears WriteBuf, starts the next batch if ReadBuf reached its limits,
+    //    and resumes the scan;
+    //  - retriable error within the retry budget: schedules a wakeup for a retry;
+    //  - any other error: finishes the scan.
+    void Handle(TEvTxUserProxy::TEvUploadRowsResponse::TPtr& ev, const TActorContext& ctx) {
+        LOG_T("Handle TEvUploadRowsResponse "
+              << Debug()
+              << " Uploader: " << Uploader.ToString()
+              << " ev->Sender: " << ev->Sender.ToString());
+
+        if (Uploader) {
+            Y_VERIFY_S(Uploader == ev->Sender, "Mismatch Uploader: " << Uploader.ToString() << " ev->Sender: " << ev->Sender.ToString() << Debug());
+        } else {
+            // Finish() has already reset Uploader and Driver; ignore the late response.
+            Y_ABORT_UNLESS(Driver == nullptr);
+            return;
+        }
+
+        UploadStatus.StatusCode = ev->Get()->Status;
+        UploadStatus.Issues = ev->Get()->Issues;
+        if (UploadStatus.IsSuccess()) {
+            WriteBuf.Clear();
+            if (!ReadBuf.IsEmpty() && ReadBuf.IsReachLimits(Limits)) {
+                ReadBuf.FlushTo(WriteBuf);
+                // This is a new batch, not a retry: RetryCount must be reset and the
+                // pending switch to NextTable/NextTypes must be applied.
+                Upload(false);
+            }
+
+            Driver->Touch(EScan::Feed);
+            return;
+        }
+
+        if (RetryCount < Limits.MaxUploadRowsRetryCount && UploadStatus.IsRetriable()) {
+            LOG_N("Got retriable error, " << Debug() << UploadStatus.ToString());
+
+            ctx.Schedule(Limits.GetTimeoutBackouff(RetryCount), new TEvents::TEvWakeup());
+            return;
+        }
+
+        LOG_N("Got error, abort scan, " << Debug() << UploadStatus.ToString());
+
+        Driver->Touch(EScan::Final);
+    }
+
+    // Called after a row was added to ReadBuf: starts an upload once ReadBuf reached
+    // its limits, or pauses the scan if the previous upload is still in flight.
+    EScan FeedUpload() {
+        if (!ReadBuf.IsReachLimits(Limits)) {
+            return EScan::Feed;
+        }
+        if (!WriteBuf.IsEmpty()) {
+            return EScan::Sleep;
+        }
+        ReadBuf.FlushTo(WriteBuf);
+        Upload(false);
+        return EScan::Feed;
+    }
+
+    // Prepares per-cluster accumulators after sampling.
+    // Returns false when no cluster was sampled.
+    bool InitAggregated() {
+        if (Clusters.size() == 0) {
+            return false;
+        }
+        if (Clusters.size() < K) {
+            // if this datashard have smaller than K count of valid embeddings for this parent
+            // lets make single centroid for it
+            K = 1;
+            Clusters.resize(K);
+        }
+        Y_ASSERT(Clusters.size() == K);
+        Aggregated.resize(K);
+        for (auto& aggregate : Aggregated) {
+            aggregate.Cluster.resize(this->Dimensions, 0);
+        }
+        return true;
+    }
+
+    // Adds the embedding's coordinates to the accumulator of cluster `pos`;
+    // out-of-range `pos` (e.g. the invalid marker from FeedEmbedding) is ignored.
+    void Aggregate(ui32 pos, const char* embedding) {
+        if (pos >= K) {
+            return;
+        }
+        auto& aggregate = Aggregated[pos];
+        auto* coords = aggregate.Cluster.data();
+        for (auto coord : this->GetCoords(embedding)) {
+            *coords++ += coord;
+        }
+        ++aggregate.Count;
+    }
+
+    // Refills every cluster from its accumulator; clusters that received no rows are kept as is.
+    void RecomputeClusters() {
+        auto* clusters = Clusters.data();
+        for (auto& aggregate : Aggregated) {
+            auto& cluster = *clusters++;
+            if (aggregate.Count == 0) {
+                continue; // TODO(mbkkt) is it impossible?
+            }
+            this->Fill(cluster, aggregate.Cluster.data(), aggregate.Count);
+        }
+    }
+
+    // Random weight used by FeedSample to decide which rows stay in the sample.
+    ui64 GetProbability() {
+        return Rng.GenRand64();
+    }
+
+    // Accounts the row in read statistics and returns the index of the closest cluster,
+    // or std::numeric_limits<ui32>::max() when the embedding has an unexpected size.
+    ui32 FeedEmbedding(const TRow& row, NTable::TPos embeddingPos) {
+        Y_ASSERT(embeddingPos < row.Size());
+        const auto embedding = row.Get(embeddingPos).AsRef();
+        ++ReadRows;
+        ReadBytes += embedding.size(); // TODO(mbkkt) add some constant overhead?
+        if (!this->IsExpectedSize(embedding)) {
+            return std::numeric_limits<ui32>::max();
+        }
+        return this->FindClosest(Clusters, embedding.data(), Aggregated);
+    }
+
+    // Sampling pass: keeps up to K embeddings with the smallest random weights.
+    // MaxRows is a max-heap over the weights, so its front is the current threshold.
+    EScan FeedSample(const TRow& row) noexcept {
+        Y_ASSERT(row.Size() == 1);
+        const auto embedding = row.Get(0).AsRef();
+        ++ReadRows;
+        ReadBytes += embedding.size(); // TODO(mbkkt) add some constant overhead?
+        if (!this->IsExpectedSize(embedding)) {
+            return EScan::Feed;
+        }
+
+        const auto probability = GetProbability();
+        if (Clusters.size() < K) {
+            MaxRows.push_back({probability, Clusters.size()});
+            Clusters.emplace_back(embedding.data(), embedding.size());
+            if (Clusters.size() == K) {
+                std::make_heap(MaxRows.begin(), MaxRows.end());
+                MaxProbability = MaxRows.front().P;
+            }
+        } else if (probability < MaxProbability) {
+            // TODO(mbkkt) use tournament tree to make less compare and swaps
+            std::pop_heap(MaxRows.begin(), MaxRows.end());
+            Clusters[MaxRows.back().I].assign(embedding.data(), embedding.size());
+            MaxRows.back().P = probability;
+            std::push_heap(MaxRows.begin(), MaxRows.end());
+            MaxProbability = MaxRows.front().P;
+        }
+        // A zero threshold cannot be beaten by any further row, so the pass is cut short.
+        return MaxProbability != 0 ? EScan::Feed : EScan::Reset;
+    }
+
+    // K-means pass: assigns the row to its closest cluster and accumulates it.
+    EScan FeedKMeans(const TRow& row) noexcept {
+        Y_ASSERT(row.Size() == 1);
+        const ui32 pos = FeedEmbedding(row, 0);
+        Aggregate(pos, row.Get(0).Data());
+        return EScan::Feed;
+    }
+
+    // Upload pass: buffers the row keyed by (Child + pos, key) with the whole row as data.
+    EScan FeedUploadMain2Tmp(TArrayRef<const TCell> key, const TRow& row) noexcept {
+        const ui32 pos = FeedEmbedding(row, EmbeddingPos);
+        if (pos >= K) {
+            // Valid cluster indexes are [0, K); anything else marks a skipped row.
+            return EScan::Feed;
+        }
+        std::array<TCell, 1> cells;
+        cells[0] = TCell::Make(Child + pos);
+        auto pk = TSerializedCellVec::Serialize(cells);
+        TSerializedCellVec::UnsafeAppendCells(key, pk);
+        ReadBuf.AddRow(TSerializedCellVec{key}, TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(*row));
+        return FeedUpload();
+    }
+
+    // Upload pass: same key as FeedUploadMain2Tmp, but only the cells starting at DataPos are stored as data.
+    EScan FeedUploadMain2Posting(TArrayRef<const TCell> key, const TRow& row) noexcept {
+        const ui32 pos = FeedEmbedding(row, EmbeddingPos);
+        if (pos >= K) {
+            // Valid cluster indexes are [0, K); anything else marks a skipped row.
+            return EScan::Feed;
+        }
+        std::array<TCell, 1> cells;
+        cells[0] = TCell::Make(Child + pos);
+        auto pk = TSerializedCellVec::Serialize(cells);
+        TSerializedCellVec::UnsafeAppendCells(key, pk);
+        ReadBuf.AddRow(TSerializedCellVec{key}, TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize((*row).Slice(DataPos)));
+        return FeedUpload();
+    }
+
+    // Upload pass: like FeedUploadMain2Tmp, but the first key cell is dropped
+    // (key.Slice(1)) before the key is appended after Child + pos.
+    EScan FeedUploadTmp2Tmp(TArrayRef<const TCell> key, const TRow& row) noexcept {
+        const ui32 pos = FeedEmbedding(row, EmbeddingPos);
+        if (pos >= K) {
+            // Valid cluster indexes are [0, K); anything else marks a skipped row.
+            return EScan::Feed;
+        }
+        std::array<TCell, 1> cells;
+        cells[0] = TCell::Make(Child + pos);
+        auto pk = TSerializedCellVec::Serialize(cells);
+        TSerializedCellVec::UnsafeAppendCells(key.Slice(1), pk);
+        ReadBuf.AddRow(TSerializedCellVec{key}, TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(*row));
+        return FeedUpload();
+    }
+
+    // Upload pass: like FeedUploadMain2Posting, but the first key cell is dropped
+    // (key.Slice(1)) before the key is appended after Child + pos.
+    EScan FeedUploadTmp2Posting(TArrayRef<const TCell> key, const TRow& row) noexcept {
+        const ui32 pos = FeedEmbedding(row, EmbeddingPos);
+        if (pos >= K) {
+            // Valid cluster indexes are [0, K); anything else marks a skipped row.
+            return EScan::Feed;
+        }
+        std::array<TCell, 1> cells;
+        cells[0] = TCell::Make(Child + pos);
+        auto pk = TSerializedCellVec::Serialize(cells);
+        TSerializedCellVec::UnsafeAppendCells(key.Slice(1), pk);
+        ReadBuf.AddRow(TSerializedCellVec{key}, TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize((*row).Slice(DataPos)));
+        return FeedUpload();
+    }
+
+    // Sends WriteBuf to the target table through a freshly registered uploader actor.
+    // isRetry == true  : the same batch is re-sent, RetryCount is incremented;
+    // isRetry == false : a new batch starts, RetryCount is reset and, once State has
+    //                    left KMEANS, the pending NextTable/NextTypes target is applied.
+    void Upload(bool isRetry) {
+        if (isRetry) {
+            ++RetryCount;
+        } else {
+            RetryCount = 0;
+            if (State != EState::KMEANS && NextTypes) {
+                TargetTypes = std::exchange(NextTypes, {});
+                TargetTable = std::move(NextTable);
+            }
+        }
+
+        auto actor = NTxProxy::CreateUploadRowsInternal(
+            this->SelfId(), TargetTable,
+            TargetTypes,
+            WriteBuf.GetRowsData(),
+            NTxProxy::EUploadRowsMode::WriteToTableShadow,
+            true /*writeToPrivateTable*/);
+
+        Uploader = TActivationContext::AsActorContext().MakeFor(this->SelfId()).Register(actor);
+    }
+
+    // Uploads the computed clusters: one row per cluster keyed by (Parent, Child + index)
+    // with the cluster itself as the only data cell.
+    void UploadSample() {
+        Y_ASSERT(ReadBuf.IsEmpty());
+        Y_ASSERT(WriteBuf.IsEmpty());
+        std::array<TCell, 2> pk;
+        std::array<TCell, 1> data;
+        for (NTable::TPos pos = 0; const auto& row : Clusters) {
+            pk[0] = TCell::Make(Parent);
+            pk[1] = TCell::Make(Child + pos);
+            data[0] = TCell{row};
+            WriteBuf.AddRow({}, TSerializedCellVec{pk}, TSerializedCellVec::Serialize(data));
+            ++pos;
+        }
+        Upload(false);
+    }
+};
+
+// Transaction wrapper that runs TDataShard::HandleSafe for a local k-means request
+// inside the tablet executor.
+class TDataShard::TTxHandleSafeLocalKMeansScan final: public NTabletFlatExecutor::TTransactionBase<TDataShard> {
+public:
+    TTxHandleSafeLocalKMeansScan(TDataShard* self, TEvDataShard::TEvLocalKMeansRequest::TPtr&& ev)
+        : TTransactionBase(self)
+        , Ev(std::move(ev)) {
+    }
+
+    bool Execute(TTransactionContext&, const TActorContext& ctx) final {
+        Self->HandleSafe(Ev, ctx);
+        return true;
+    }
+
+    void Complete(const TActorContext&) final {
+    }
+
+private:
+    TEvDataShard::TEvLocalKMeansRequest::TPtr Ev;
+};
+
+// Defers processing of the request to a transaction (see TTxHandleSafeLocalKMeansScan).
+void TDataShard::Handle(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const TActorContext&) {
+    Execute(new TTxHandleSafeLocalKMeansScan(this, std::move(ev)));
+}
+
+void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const TActorContext& ctx) {
+    auto& record = ev->Get()->Record;
+    TRowVersion rowVersion(record.GetSnapshotStep(), record.GetSnapshotTxId());
+
+    // Note: it's
very unlikely that we have volatile txs before this snapshot + if (VolatileTxManager.HasVolatileTxsAtSnapshot(rowVersion)) { + VolatileTxManager.AttachWaitingSnapshotEvent(rowVersion, + std::unique_ptr(ev.Release())); + return; + } + const ui64 id = record.GetId(); + + auto response = MakeHolder(); + response->Record.SetId(id); + response->Record.SetTabletId(TabletID()); + + TScanRecord::TSeqNo seqNo = {record.GetSeqNoGeneration(), record.GetSeqNoRound()}; + response->Record.SetRequestSeqNoGeneration(seqNo.Generation); + response->Record.SetRequestSeqNoRound(seqNo.Round); + + auto badRequest = [&](const TString& error) { + response->Record.SetStatus(NKikimrIndexBuilder::EBuildStatus::BAD_REQUEST); + auto issue = response->Record.AddIssues(); + issue->set_severity(NYql::TSeverityIds::S_ERROR); + issue->set_message(error); + ctx.Send(ev->Sender, std::move(response)); + }; + + if (const ui64 shardId = record.GetTabletId(); shardId != TabletID()) { + badRequest(TStringBuilder() << "Wrong shard " << shardId << " this is " << TabletID()); + return; + } + + const auto pathId = PathIdFromPathId(record.GetPathId()); + const auto* userTableIt = GetUserTables().FindPtr(pathId.LocalPathId); + if (!userTableIt) { + badRequest(TStringBuilder() << "Unknown table id: " << pathId.LocalPathId); + return; + } + Y_ABORT_UNLESS(*userTableIt); + const auto& userTable = **userTableIt; + + if (const auto* recCard = ScanManager.Get(id)) { + if (recCard->SeqNo == seqNo) { + // do no start one more scan + return; + } + + CancelScan(userTable.LocalTid, recCard->ScanId); + ScanManager.Drop(id); + } + + const auto range = CreateRangeFrom(userTable, record.GetParent()); + if (range.IsEmptyRange(userTable.KeyColumnTypes)) { + badRequest(TStringBuilder() << " requested range doesn't intersect with table range"); + return; + } + + if (!record.HasSnapshotStep() || !record.HasSnapshotTxId()) { + badRequest(TStringBuilder() << " request doesn't have Shapshot Step or TxId"); + return; + } + + const 
TSnapshotKey snapshotKey(pathId, rowVersion.Step, rowVersion.TxId); + const TSnapshot* snapshot = SnapshotManager.FindAvailable(snapshotKey); + if (!snapshot) { + badRequest(TStringBuilder() + << "no snapshot has been found" + << " , path id is " << pathId.OwnerId << ":" << pathId.LocalPathId + << " , snapshot step is " << snapshotKey.Step + << " , snapshot tx is " << snapshotKey.TxId); + return; + } + + if (!IsStateActive()) { + badRequest(TStringBuilder() << "Shard " << TabletID() << " is not ready for requests"); + return; + } + + if (record.GetK() < 1) { + badRequest(TStringBuilder() << "Should be requested at least single row"); + return; + } + + if (!record.HasEmbeddingColumn()) { + badRequest(TStringBuilder() << "Should be specified embedding column"); + return; + } + + TScanOptions scanOpts; + scanOpts.SetSnapshotRowVersion(rowVersion); + scanOpts.SetResourceBroker("build_index", 10); // TODO(mbkkt) Should be different group? + + const auto& settings = record.GetSettings(); + TAutoPtr scan; + + auto createScan = [&] { + scan = new TLocalKMeansScan{ + userTable, + CreateLeadFrom(range), + record, + ev->Sender, + std::move(response), + }; + }; + + auto handleType = [&]