diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_combine.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_combine.cpp index eb77f490bf2d..c28e8dee937c 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_combine.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_combine.cpp @@ -343,8 +343,8 @@ class TSpillingSupportState : public TComputationValue { }; public: - enum class ETasteResult: ui8 { - Init, + enum class ETasteResult: i8 { + Init = -1, Update, Skip }; @@ -372,15 +372,9 @@ class TSpillingSupportState : public TComputationValue { Tongue = InMemoryProcessingState.Tongue; Throat = InMemoryProcessingState.Throat; } - ~TSpillingSupportState() { - } - - bool IsFetchRequired() const { - return InputStatus != EFetchResult::Finish; - } bool HasAnyData() const { - return SpilledBuckets.size(); + return !SpilledBuckets.empty(); } bool IsProcessingRequired() const { @@ -456,6 +450,20 @@ class TSpillingSupportState : public TComputationValue { return ETasteResult::Skip; } + NUdf::TUnboxedValuePod* Extract() { + if (GetMode() == EOperatingMode::InMemory) return static_cast(InMemoryProcessingState.Extract()); + + MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error"); + MKQL_ENSURE(SpilledBuckets.size() > 0, "Internal logic error"); + + auto value = static_cast(SpilledBuckets.front().InMemoryProcessingState->Extract()); + if (!value) { + SpilledBuckets.pop_front(); + } + + return value; + } +private: void MoveKeyToBucket(TSpilledBucket& bucket) { for (size_t i = 0; i < KeyWidth; ++i) { //jumping into unsafe world, refusing ownership @@ -483,20 +491,6 @@ class TSpillingSupportState : public TComputationValue { BufferForUsedInputItems.resize(0); } - NUdf::TUnboxedValuePod* Extract() { - if (GetMode() == EOperatingMode::InMemory) return static_cast(InMemoryProcessingState.Extract()); - - MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error"); - MKQL_ENSURE(SpilledBuckets.size() > 0, "Internal logic error"); - - auto value = static_cast(SpilledBuckets.front().InMemoryProcessingState->Extract()); - if (!value) { - SpilledBuckets.pop_front(); - } - - return value; - } - bool FlushSpillingBuffersAndWait() { UpdateSpillingBuckets(); @@ -521,7 +515,6 @@ class TSpillingSupportState : public TComputationValue { return ProcessSpilledDataAndWait(); } -private: void SplitStateIntoBuckets() { while (const auto keyAndState = static_cast(InMemoryProcessingState.Extract())) { auto hash = Hasher(keyAndState); //Hasher uses only key for hashing @@ -1246,7 +1239,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode(state.AsBoxed().Get())) { @@ -1306,6 +1299,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNodegetType()}, false); BranchInst::Create(more, block); - block = more; - - const auto loop = BasicBlock::Create(context, "loop", ctx.Func); const auto full = BasicBlock::Create(context, "full", ctx.Func); const auto over = BasicBlock::Create(context, "over", ctx.Func); - const auto result = PHINode::Create(statusType, 3U, "result", over); - - const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block); - const auto last = new LoadInst(statusType, statusPtr, "last", block); - const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get(last->getType(), static_cast(EFetchResult::Finish)), "finish", block); - - BranchInst::Create(full, loop, finish, block); + const auto result = PHINode::Create(statusType, 4U, "result", over); { + const auto test = BasicBlock::Create(context, "test", ctx.Func); + const auto pull = BasicBlock::Create(context, "pull", ctx.Func); const auto rest = BasicBlock::Create(context, "rest", ctx.Func); + const auto proc = BasicBlock::Create(context, "proc", ctx.Func); const auto good = BasicBlock::Create(context, "good", ctx.Func); - block = loop; + block = more; + + const auto waitMoreFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::UpdateAndWait)); + const auto waitMoreFuncPtr = CastInst::Create(Instruction::IntToPtr, waitMoreFunc, PointerType::getUnqual(boolFuncType), "wait_more_func", block); + const auto waitMore = CallInst::Create(boolFuncType, waitMoreFuncPtr, { stateArg }, "wait_more", block); + + result->addIncoming(ConstantInt::get(statusType, static_cast(EFetchResult::Yield)), block); + + BranchInst::Create(over, test, waitMore, block); + + block = test; + + const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block); + const auto last = new LoadInst(statusType, statusPtr, "last", block); + const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get(last->getType(), static_cast(EFetchResult::Finish)), "finish", block); + + BranchInst::Create(good, pull, finish, block); + + block = pull; const auto getres = GetNodeValues(Flow, ctx, block); @@ -1362,12 +1369,19 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNodeaddCase(ConstantInt::get(statusType, static_cast(EFetchResult::Finish)), rest); block = rest; - new StoreInst(ConstantInt::get(last->getType(), static_cast(EFetchResult::Finish)), statusPtr, block); - - BranchInst::Create(full, block); + new StoreInst(ConstantInt::get(statusType, static_cast(EFetchResult::Finish)), statusPtr, block); + BranchInst::Create(more, block); block = good; + const auto processingFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::IsProcessingRequired)); + const auto processingFuncPtr = CastInst::Create(Instruction::IntToPtr, processingFunc, PointerType::getUnqual(boolFuncType), "processing_func", block); + const auto processing = CallInst::Create(boolFuncType, processingFuncPtr, { stateArg }, "processing", block); + + BranchInst::Create(proc, full, processing, block); + + block = proc; + std::vector items(Nodes.ItemNodes.size(), nullptr); for (ui32 i = 0U; i < items.size(); ++i) { if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U) @@ -1398,10 +1412,10 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNodegetType()}, false); + const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::TasteIt)); + const auto atType = FunctionType::get(wayType, {stateArg->getType()}, false); const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block); - const auto newKey = CallInst::Create(atType, atPtr, {stateArg}, "new_key", block); + const auto taste= CallInst::Create(atType, atPtr, {stateArg}, "taste", block); const auto init = BasicBlock::Create(context, "init", ctx.Func); const auto next = BasicBlock::Create(context, "next", ctx.Func); @@ -1415,7 +1429,9 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNodeaddCase(ConstantInt::get(wayType, static_cast(TSpillingSupportState::ETasteResult::Init)), init); + way->addCase(ConstantInt::get(wayType, static_cast(TSpillingSupportState::ETasteResult::Update)), next); block = init; @@ -1439,7 +1455,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNodegetType()}, false); const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block); const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block); const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block); - result->addIncoming(ConstantInt::get(statusType, static_cast(EFetchResult::Finish)), block); - - BranchInst::Create(good, over, has, block); + BranchInst::Create(good, last, has, block); block = good; @@ -1514,6 +1529,16 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNodeaddIncoming(ConstantInt::get(statusType, static_cast(EFetchResult::One)), block); BranchInst::Create(over, block); + + block = last; + + const auto hasDataFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::HasAnyData)); + const auto hasDataFuncPtr = CastInst::Create(Instruction::IntToPtr, hasDataFunc, PointerType::getUnqual(boolFuncType), "has_data_func", block); + const auto hasData = CallInst::Create(boolFuncType, hasDataFuncPtr, { stateArg }, "has_data", block); + + result->addIncoming(ConstantInt::get(statusType, static_cast(EFetchResult::Finish)), block); + + BranchInst::Create(more, over, hasData, block); } block = over; @@ -1528,23 +1553,17 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode(Nodes.KeyNodes.size(), Nodes.StateNodes.size(), TMyValueHasher(KeyTypes), TMyValueEqual(KeyTypes)); -#else - state = ctx.HolderFactory.Create(Nodes.KeyNodes.size(), Nodes.StateNodes.size(), - ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)), - ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes)) - ); -#endif - } - - void MakeSpillingSupportState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { state = ctx.HolderFactory.Create(WideFieldsIndex, UsedInputItemType, KeyAndStateType, Nodes.KeyNodes.size(), Nodes.ItemNodes.size(), +#ifdef MKQL_DISABLE_CODEGEN TMyValueHasher(KeyTypes), TMyValueEqual(KeyTypes), +#else + ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)), + ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes)), +#endif AllowSpilling, ctx ); @@ -1569,7 +1588,6 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNodeGetDataSlot(), optional); }