diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_skip.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_skip.cpp index 881450504cb8..b1bd436740b7 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_skip.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_skip.cpp @@ -1,6 +1,7 @@ #include "mkql_skip.h" #include #include // Y_IGNORE +#include #include namespace NKikimr { @@ -117,105 +118,56 @@ using TBaseComputation = TStatefulFlowCodegeneratorNode; IComputationNode* const Count; }; -class TWideSkipWrapper : public TStatefulWideFlowCodegeneratorNode { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode; +class TWideSkipWrapper : public TSimpleStatefulWideFlowCodegeneratorNode { +using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode; public: TWideSkipWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 size) - : TBaseComputation(mutables, flow, EValueRepresentation::Embedded) + : TBaseComputation(mutables, flow, size, size) , Flow(flow) , Count(count) , StubsIndex(mutables.IncrementWideFieldsIndex(size)) {} - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { - if (state.IsInvalid()) { - state = Count->GetValue(ctx); - } + void InitState(NUdf::TUnboxedValue& cntToSkip, TComputationContext& ctx) const { + cntToSkip = Count->GetValue(ctx); + } - if (auto count = state.Get()) { - do if (const auto result = Flow->FetchValues(ctx, ctx.WideFields.data() + StubsIndex); EFetchResult::One != result) { - state = NUdf::TUnboxedValuePod(count); - return result; - } while (--count); + NUdf::TUnboxedValue*const* PrepareInput(NUdf::TUnboxedValue& cntToSkip, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + return cntToSkip.Get() ? 
ctx.WideFields.data() + StubsIndex : output; + } - state = NUdf::TUnboxedValuePod::Zero(); + TMaybeFetchResult DoProcess(NUdf::TUnboxedValue& cntToSkip, TComputationContext&, TMaybeFetchResult fetchRes, NUdf::TUnboxedValue*const*) const { + if (fetchRes.Get() == EFetchResult::One && cntToSkip.Get()) { + cntToSkip = NUdf::TUnboxedValuePod(cntToSkip.Get() - 1); + return TMaybeFetchResult::None(); } - - return Flow->FetchValues(ctx, output); + return fetchRes; } #ifndef MKQL_DISABLE_CODEGEN - TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + TGenerateResult GenFetchProcess(Value* statePtrVal, const TCodegenContext& ctx, const TResultCodegenerator& fetchGenerator, BasicBlock*& block) const override { auto& context = ctx.Codegen.GetContext(); - - const auto valueType = Type::getInt128Ty(context); - - const auto init = BasicBlock::Create(context, "init", ctx.Func); - const auto main = BasicBlock::Create(context, "main", ctx.Func); - - const auto load = new LoadInst(valueType, statePtr, "load", block); - const auto state = PHINode::Create(valueType, 2U, "state", main); - state->addIncoming(load, block); - BranchInst::Create(init, main, IsInvalid(load, block), block); - - block = init; - - GetNodeValue(statePtr, Count, ctx, block); - const auto save = new LoadInst(valueType, statePtr, "save", block); - state->addIncoming(save, block); - BranchInst::Create(main, block); - - block = main; - - const auto work = BasicBlock::Create(context, "work", ctx.Func); - const auto good = BasicBlock::Create(context, "good", ctx.Func); - const auto pass = BasicBlock::Create(context, "pass", ctx.Func); - const auto exit = BasicBlock::Create(context, "exit", ctx.Func); - const auto skip = BasicBlock::Create(context, "skip", ctx.Func); - const auto done = BasicBlock::Create(context, "done", ctx.Func); - - const auto resultType = Type::getInt32Ty(context); - const auto result = PHINode::Create(resultType, 2U, "result", done); - - const 
auto trunc = GetterFor(state, context, block); - - const auto count = PHINode::Create(trunc->getType(), 2U, "count", work); - count->addIncoming(trunc, block); - - const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, trunc, ConstantInt::get(trunc->getType(), 0ULL), "plus", block); - - BranchInst::Create(work, skip, plus, block); - - block = work; - const auto status = GetNodeValues(Flow, ctx, block).first; - const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, status, ConstantInt::get(status->getType(), 0), "special", block); - BranchInst::Create(pass, good, special, block); - - block = pass; - new StoreInst(SetterFor(count, context, block), statePtr, block); - result->addIncoming(status, block); - BranchInst::Create(done, block); - - block = good; - - const auto decr = BinaryOperator::CreateSub(count, ConstantInt::get(count->getType(), 1ULL), "decr", block); - const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, decr, ConstantInt::get(decr->getType(), 0ULL), "next", block); - count->addIncoming(decr, block); - BranchInst::Create(work, exit, next, block); - - block = exit; - new StoreInst(SetterFor(decr, context, block), statePtr, block); - BranchInst::Create(skip, block); - - block = skip; - auto getres = GetNodeValues(Flow, ctx, block); - result->addIncoming(getres.first, block); - BranchInst::Create(done, block); - - block = done; - return {result, std::move(getres.second)}; + const auto decr = BasicBlock::Create(context, "decr", ctx.Func); + const auto end = BasicBlock::Create(context, "end", ctx.Func); + + const auto fetched = fetchGenerator(ctx, block); + const auto cntToSkipVal = GetterFor(new LoadInst(IntegerType::getInt128Ty(context), statePtrVal, "unboxed_state", block), context, block); + const auto needSkipCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, cntToSkipVal, ConstantInt::get(cntToSkipVal->getType(), 0), "need_skip", block); + const auto gotOneCond = 
CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, fetched.first, ConstantInt::get(fetched.first->getType(), 1), "got_one", block); + const auto willSkipCond = BinaryOperator::Create(Instruction::And, needSkipCond, gotOneCond, "will_skip", block); + BranchInst::Create(decr, end, willSkipCond, block); + + block = decr; + const auto cntToSkipNewVal = BinaryOperator::CreateSub(cntToSkipVal, ConstantInt::get(cntToSkipVal->getType(), 1), "decr", block); + new StoreInst(SetterFor(cntToSkipNewVal, context, block), statePtrVal, block); + BranchInst::Create(end, block); + + block = end; + const auto result = SelectInst::Create(willSkipCond, TMaybeFetchResult::None().LLVMConst(context), TMaybeFetchResult::LLVMFromFetchResult(fetched.first, "fetch_res_ext", block), "result", block); + return {result, fetched.second}; } #endif + private: void RegisterDependencies() const final { if (const auto flow = FlowDependsOn(Flow)) diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp index e54effbe2e7c..d2d262303690 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp @@ -1,6 +1,7 @@ #include "mkql_wide_filter.h" #include #include // Y_IGNORE +#include #include #include @@ -26,7 +27,7 @@ class TBaseWideFilterWrapper { return ctx.WideFields.data() + WideFieldsIndex; } - void PrepareArguments(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + NUdf::TUnboxedValue*const* PrepareArguments(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { auto** fields = GetFields(ctx); for (auto i = 0U; i < Items.size(); ++i) { @@ -35,6 +36,8 @@ class TBaseWideFilterWrapper { else fields[i] = output[i]; } + + return fields; } void FillOutputs(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { @@ -45,6 +48,7 @@ class TBaseWideFilterWrapper { if (Predicate == Items[i] || Items[i]->GetDependencesCount() 
> 0U) *out = *fields[i]; } + #ifndef MKQL_DISABLE_CODEGEN template Value* GenGetPredicate(const TCodegenContext& ctx, @@ -134,97 +138,75 @@ using TBaseComputation = TStatelessWideFlowCodegeneratorNode } }; -class TWideFilterWithLimitWrapper : public TStatefulWideFlowCodegeneratorNode, public TBaseWideFilterWrapper { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode; +class TWideFilterWithLimitWrapper : public TSimpleStatefulWideFlowCodegeneratorNode, public TBaseWideFilterWrapper { +using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode; public: TWideFilterWithLimitWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* limit, TComputationExternalNodePtrVector&& items, IComputationNode* predicate) - : TBaseComputation(mutables, flow, EValueRepresentation::Embedded) + : TBaseComputation(mutables, flow, items.size(), items.size()) , TBaseWideFilterWrapper(mutables, flow, std::move(items), predicate) , Limit(limit) {} - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { - if (state.IsInvalid()) { - state = Limit->GetValue(ctx); - } else if (!state.Get()) { - return EFetchResult::Finish; - } - - auto **fields = GetFields(ctx); - while (true) { - PrepareArguments(ctx, output); + void InitState(NUdf::TUnboxedValue& cntToTake, TComputationContext& ctx) const { + cntToTake = Limit->GetValue(ctx); + } - if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result) - return result; + NUdf::TUnboxedValue*const* PrepareInput(NUdf::TUnboxedValue& cntToTake, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + return cntToTake.Get() ? 
PrepareArguments(ctx, output) : nullptr; + } + TMaybeFetchResult DoProcess(NUdf::TUnboxedValue& cntToTake, TComputationContext& ctx, TMaybeFetchResult fetchRes, NUdf::TUnboxedValue*const* output) const { + if (fetchRes.Empty()) { + return EFetchResult::Finish; + } else if (fetchRes.Get() == EFetchResult::One) { if (Predicate->GetValue(ctx).Get()) { FillOutputs(ctx, output); - - auto todo = state.Get(); - state = NUdf::TUnboxedValuePod(--todo); + cntToTake = NUdf::TUnboxedValuePod(cntToTake.Get() - 1); return EFetchResult::One; } + return TMaybeFetchResult::None(); } + return fetchRes; } -#ifndef MKQL_DISABLE_CODEGEN - TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { - auto& context = ctx.Codegen.GetContext(); - - const auto init = BasicBlock::Create(context, "init", ctx.Func); - const auto test = BasicBlock::Create(context, "test", ctx.Func); - const auto loop = BasicBlock::Create(context, "loop", ctx.Func); - const auto work = BasicBlock::Create(context, "work", ctx.Func); - const auto pass = BasicBlock::Create(context, "pass", ctx.Func); - const auto exit = BasicBlock::Create(context, "exit", ctx.Func); - - const auto valueType = Type::getInt128Ty(context); - const auto resultType = Type::getInt32Ty(context); - const auto result = PHINode::Create(resultType, 3U, "result", exit); - - BranchInst::Create(test, init, IsValid(statePtr, block), block); - - block = init; - - GetNodeValue(statePtr, Limit, ctx, block); - BranchInst::Create(test, block); - - block = test; - - const auto state = new LoadInst(valueType, statePtr, "state", block); - const auto done = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state, GetFalse(context), "done", block); - result->addIncoming(ConstantInt::get(resultType, -1), block); - - BranchInst::Create(exit, loop, done, block); - block = loop; - - auto status = GetNodeValues(Flow, ctx, block); - const auto good = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, 
status.first, ConstantInt::get(status.first->getType(), 0), "good", block); - - result->addIncoming(status.first, block); - - BranchInst::Create(work, exit, good, block); - - block = work; - - const auto predicate = GenGetPredicate(ctx, status.second, block); - - BranchInst::Create(pass, loop, predicate, block); +#ifndef MKQL_DISABLE_CODEGEN + typename TBaseComputation::TGenerateResult GenFetchProcess(Value* statePtrVal, const TCodegenContext& ctx, const TResultCodegenerator& fetchGenerator, BasicBlock*& block) const override { + auto &context = ctx.Codegen.GetContext(); + auto fetch = BasicBlock::Create(context, "fetch", ctx.Func); + auto pass = BasicBlock::Create(context, "pass", ctx.Func); + auto check = BasicBlock::Create(context, "check", ctx.Func); + auto decr = BasicBlock::Create(context, "decr", ctx.Func); + auto maybeResultVal = PHINode::Create(TMaybeFetchResult::LLVMType(context), 4, "maybe_res", pass); + + auto stateVal = new LoadInst(statePtrVal->getType()->getPointerElementType(), statePtrVal, "state", block); + auto needFetchCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, GetFalse(context), stateVal, "need_fetch", block); + maybeResultVal->addIncoming(TMaybeFetchResult(EFetchResult::Finish).LLVMConst(context), block); + BranchInst::Create(fetch, pass, needFetchCond, block); + + block = fetch; + auto [fetchResVal, fetchGetters] = fetchGenerator(ctx, block); + auto passCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, ConstantInt::get(fetchResVal->getType(), static_cast(EFetchResult::One)), fetchResVal, "not_one", block); + maybeResultVal->addIncoming(TMaybeFetchResult::LLVMFromFetchResult(fetchResVal, "fetch_res_ext", block), block); + BranchInst::Create(pass, check, passCond, block); + + block = check; + auto predicateCond = GenGetPredicate(ctx, fetchGetters, block); + maybeResultVal->addIncoming(TMaybeFetchResult::None().LLVMConst(context), block); + BranchInst::Create(decr, pass, predicateCond, block); + + block = decr; + 
auto newStateVal = BinaryOperator::CreateSub(stateVal, ConstantInt::get(stateVal->getType(), 1), "new_state", block); + new StoreInst(newStateVal, statePtrVal, block); + maybeResultVal->addIncoming(TMaybeFetchResult(EFetchResult::One).LLVMConst(context), block); + BranchInst::Create(pass, block); block = pass; - const auto decr = BinaryOperator::CreateSub(state, ConstantInt::get(state->getType(), 1ULL), "decr", block); - new StoreInst(decr, statePtr, block); - - result->addIncoming(status.first, block); - - BranchInst::Create(exit, block); - - block = exit; - return {result, std::move(status.second)}; + return {maybeResultVal, std::move(fetchGetters)}; } #endif + private: void RegisterDependencies() const final { if (const auto flow = FlowDependsOn(Flow)) { @@ -238,77 +220,74 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode -class TWideTakeWhileWrapper : public TStatefulWideFlowCodegeneratorNode>, public TBaseWideFilterWrapper { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode>; +class TWideTakeWhileWrapper : public TSimpleStatefulWideFlowCodegeneratorNode, bool>, public TBaseWideFilterWrapper { +using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode, bool>; public: TWideTakeWhileWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* predicate) - : TBaseComputation(mutables, flow, EValueRepresentation::Embedded) + : TBaseComputation(mutables, flow, items.size(), items.size()) , TBaseWideFilterWrapper(mutables, flow, std::move(items), predicate) {} - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { - if (!state.IsInvalid()) { - return EFetchResult::Finish; - } - - PrepareArguments(ctx, output); - - auto **fields = GetFields(ctx); - - if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result) - return result; + void InitState(NUdf::TUnboxedValue& stop, 
TComputationContext& ) const { + stop = NUdf::TUnboxedValuePod(false); + } - const bool predicate = Predicate->GetValue(ctx).Get(); - if (!predicate) - state = NUdf::TUnboxedValuePod(); + NUdf::TUnboxedValue*const* PrepareInput(NUdf::TUnboxedValue& stop, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + return stop.Get() ? nullptr : PrepareArguments(ctx, output); + } - if (Inclusive || predicate) { - FillOutputs(ctx, output); - return EFetchResult::One; + TMaybeFetchResult DoProcess(NUdf::TUnboxedValue& stop, TComputationContext& ctx, TMaybeFetchResult fetchRes, NUdf::TUnboxedValue*const* output) const { + if (fetchRes.Empty()) { + return EFetchResult::Finish; + } else if (fetchRes.Get() == EFetchResult::One) { + const bool predicate = Predicate->GetValue(ctx).Get(); + if (!predicate) { + stop = NUdf::TUnboxedValuePod(true); + } + if (Inclusive || predicate) { + FillOutputs(ctx, output); + return EFetchResult::One; + } + return EFetchResult::Finish; } - - return EFetchResult::Finish; + return fetchRes; } -#ifndef MKQL_DISABLE_CODEGEN - ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { - auto& context = ctx.Codegen.GetContext(); - - const auto resultType = Type::getInt32Ty(context); - - const auto work = BasicBlock::Create(context, "work", ctx.Func); - const auto test = BasicBlock::Create(context, "test", ctx.Func); - const auto stop = BasicBlock::Create(context, "stop", ctx.Func); - const auto done = BasicBlock::Create(context, "done", ctx.Func); - - const auto result = PHINode::Create(resultType, 4U, "result", done); - result->addIncoming(ConstantInt::get(resultType, static_cast(EFetchResult::Finish)), block); - - BranchInst::Create(done, work, IsValid(statePtr, block), block); - - block = work; - auto status = GetNodeValues(Flow, ctx, block); - const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, status.first, 
ConstantInt::get(resultType, 0), "special", block); - result->addIncoming(status.first, block); - BranchInst::Create(done, test, special, block); - block = test; - - const auto predicate = GenGetPredicate(ctx, status.second, block); - result->addIncoming(status.first, block); - BranchInst::Create(done, stop, predicate, block); - - block = stop; - - new StoreInst(GetEmpty(context), statePtr, block); - result->addIncoming(ConstantInt::get(resultType, static_cast(Inclusive ? EFetchResult::One: EFetchResult::Finish)), block); +#ifndef MKQL_DISABLE_CODEGEN + typename TBaseComputation::TGenerateResult GenFetchProcess(Value* statePtrVal, const TCodegenContext& ctx, const TResultCodegenerator& fetchGenerator, BasicBlock*& block) const override { + auto &context = ctx.Codegen.GetContext(); + auto fetch = BasicBlock::Create(context, "fetch", ctx.Func); + auto pass = BasicBlock::Create(context, "pass", ctx.Func); + auto check = BasicBlock::Create(context, "check", ctx.Func); + auto maybeResultVal = PHINode::Create(TMaybeFetchResult::LLVMType(context), 3, "maybe_res", pass); + + auto stateVal = new LoadInst(statePtrVal->getType()->getPointerElementType(), statePtrVal, "state", block); + auto needFetchCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, GetTrue(context), stateVal, "need_fetch", block); + maybeResultVal->addIncoming(TMaybeFetchResult(EFetchResult::Finish).LLVMConst(context), block); + BranchInst::Create(fetch, pass, needFetchCond, block); + + block = fetch; + auto [fetchResVal, fetchGetters] = fetchGenerator(ctx, block); + auto passCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, ConstantInt::get(fetchResVal->getType(), static_cast(EFetchResult::One)), fetchResVal, "not_one", block); + maybeResultVal->addIncoming(TMaybeFetchResult::LLVMFromFetchResult(fetchResVal, "fetch_res_ext", block), block); + BranchInst::Create(pass, check, passCond, block); + + block = check; + auto predicateCond = GenGetPredicate(ctx, fetchGetters, block); + auto 
newStateVal = SelectInst::Create(predicateCond, GetFalse(context), GetTrue(context), "new_state", block); + new StoreInst(newStateVal, statePtrVal, block); + auto retOneCond = Inclusive ? ConstantInt::getTrue(context) : predicateCond; + auto retStatusVal = SelectInst::Create(retOneCond, TMaybeFetchResult(EFetchResult::One).LLVMConst(context), TMaybeFetchResult(EFetchResult::Finish).LLVMConst(context), "ret_status", block); + maybeResultVal->addIncoming(retStatusVal, block); + BranchInst::Create(pass, block); - BranchInst::Create(done, block); + block = pass; - block = done; - return {result, std::move(status.second)}; + return {maybeResultVal, std::move(fetchGetters)}; } #endif + private: void RegisterDependencies() const final { if (const auto flow = this->FlowDependsOn(Flow)) { @@ -319,72 +298,69 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode -class TWideSkipWhileWrapper : public TStatefulWideFlowCodegeneratorNode>, public TBaseWideFilterWrapper { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode>; +class TWideSkipWhileWrapper : public TSimpleStatefulWideFlowCodegeneratorNode, bool>, public TBaseWideFilterWrapper { +using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode, bool>; public: TWideSkipWhileWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* predicate) - : TBaseComputation(mutables, flow, EValueRepresentation::Embedded) + : TBaseComputation(mutables, flow, items.size(), items.size()) , TBaseWideFilterWrapper(mutables, flow, std::move(items), predicate) {} - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { - if (!state.IsInvalid()) { - return Flow->FetchValues(ctx, output); - } - - auto **fields = GetFields(ctx); - - do { - PrepareArguments(ctx, output); - if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result) - return result; - } while 
(Predicate->GetValue(ctx).Get()); + void InitState(NUdf::TUnboxedValue& start, TComputationContext& ) const { + start = NUdf::TUnboxedValuePod(false); + } - state = NUdf::TUnboxedValuePod(); + NUdf::TUnboxedValue*const* PrepareInput(NUdf::TUnboxedValue& start, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + return start.Get() ? output : PrepareArguments(ctx, output); + } - if constexpr (Inclusive) - return Flow->FetchValues(ctx, output); - else { - FillOutputs(ctx, output); - return EFetchResult::One; + TMaybeFetchResult DoProcess(NUdf::TUnboxedValue& start, TComputationContext& ctx, TMaybeFetchResult fetchRes, NUdf::TUnboxedValue*const* output) const { + if (!start.Get() && fetchRes.Get() == EFetchResult::One) { + const bool predicate = Predicate->GetValue(ctx).Get(); + if (!predicate) { + start = NUdf::TUnboxedValuePod(true); + } + if (!Inclusive && !predicate) { + FillOutputs(ctx, output); + return EFetchResult::One; + } + return TMaybeFetchResult::None(); } + return fetchRes; } + #ifndef MKQL_DISABLE_CODEGEN - ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + typename TBaseComputation::TGenerateResult GenFetchProcess(Value* statePtrVal, const TCodegenContext& ctx, const TResultCodegenerator& fetchGenerator, BasicBlock*& block) const override { auto& context = ctx.Codegen.GetContext(); + auto check = BasicBlock::Create(context, "check", ctx.Func); + auto save = BasicBlock::Create(context, "save", ctx.Func); + auto pass = BasicBlock::Create(context, "pass", ctx.Func); + auto maybeResultVal = PHINode::Create(TMaybeFetchResult::LLVMType(context), 3, "maybe_res", pass); + + auto [fetchResVal, fetchGetters] = fetchGenerator(ctx, block); + auto stateVal = new LoadInst(statePtrVal->getType()->getPointerElementType(), statePtrVal, "state", block); + auto needCheckCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, GetTrue(context), stateVal, 
"need_check", block); + auto oneCond = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, ConstantInt::get(fetchResVal->getType(), static_cast(EFetchResult::One)), fetchResVal, "one", block); + auto willCheckCond = BinaryOperator::Create(Instruction::And, needCheckCond, oneCond, "will_check", block); + maybeResultVal->addIncoming(TMaybeFetchResult::LLVMFromFetchResult(fetchResVal, "fetch_res_ext", block), block); + BranchInst::Create(check, pass, willCheckCond, block); + + block = check; + auto predicateCond = GenGetPredicate(ctx, fetchGetters, block); + maybeResultVal->addIncoming(TMaybeFetchResult::None().LLVMConst(context), block); + BranchInst::Create(pass, save, predicateCond, block); + + block = save; + new StoreInst(GetTrue(context), statePtrVal, block); + maybeResultVal->addIncoming((Inclusive ? TMaybeFetchResult::None() : TMaybeFetchResult(EFetchResult::One)).LLVMConst(context), block); + BranchInst::Create(pass, block); - const auto resultType = Type::getInt32Ty(context); - - const auto work = BasicBlock::Create(context, "work", ctx.Func); - const auto test = BasicBlock::Create(context, "test", ctx.Func); - const auto stop = BasicBlock::Create(context, "stop", ctx.Func); - const auto done = BasicBlock::Create(context, "done", ctx.Func); - - BranchInst::Create(work, block); - - block = work; - - const auto status = GetNodeValues(Flow, ctx, block); - - const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, status.first, ConstantInt::get(resultType, 0), "special", block); - const auto passtrought = BinaryOperator::CreateOr(special, IsValid(statePtr, block), "passtrought", block); - BranchInst::Create(done, test, passtrought, block); - - block = test; - - const auto predicate = GenGetPredicate(ctx, status.second, block); - BranchInst::Create(work, stop, predicate, block); - - block = stop; - - new StoreInst(GetEmpty(context), statePtr, block); - - BranchInst::Create(Inclusive ? 
work : done, block); - - block = done; - return status; + block = pass; + + return {maybeResultVal, std::move(fetchGetters)}; } #endif + private: void RegisterDependencies() const final { if (const auto flow = this->FlowDependsOn(Flow)) { diff --git a/ydb/library/yql/minikql/computation/mkql_simple_codegen.cpp b/ydb/library/yql/minikql/computation/mkql_simple_codegen.cpp new file mode 100644 index 000000000000..4829d23ac221 --- /dev/null +++ b/ydb/library/yql/minikql/computation/mkql_simple_codegen.cpp @@ -0,0 +1,149 @@ +#include // Y_IGNORE +#include "mkql_simple_codegen.h" + +namespace NKikimr { +namespace NMiniKQL { + +#ifndef MKQL_DISABLE_CODEGEN +ICodegeneratorInlineWideNode::TGenerateResult TSimpleStatefulWideFlowCodegeneratorNodeLLVMBase::DoGenGetValues(const NKikimr::NMiniKQL::TCodegenContext &ctx, llvm::Value *statePtrVal, llvm::BasicBlock *&genToBlock) const { + // init stuff (mainly in global entry block) + + auto& context = ctx.Codegen.GetContext(); + + const auto valueType = Type::getInt128Ty(context); + const auto init = BasicBlock::Create(context, "init", ctx.Func); + const auto loop = BasicBlock::Create(context, "loop", ctx.Func); + const auto loopFetch = BasicBlock::Create(context, "loop_fetch", ctx.Func); + const auto loopCalc = BasicBlock::Create(context, "loop_calc", ctx.Func); + const auto loopTail = BasicBlock::Create(context, "loop_tail", ctx.Func); + const auto done = BasicBlock::Create(context, "done", ctx.Func); + const auto entryPos = &ctx.Func->getEntryBlock().back(); + + const auto thisType = StructType::get(context)->getPointerTo(); + const auto thisRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.ThisPtr); + const auto thisVal = CastInst::Create(Instruction::IntToPtr, thisRawVal, thisType, "this", entryPos); + const auto valuePtrType = PointerType::getUnqual(valueType); + const auto valuePtrsPtrType = PointerType::getUnqual(valuePtrType); + const auto statePtrType = statePtrVal->getType(); + const auto ctxType = 
ctx.Ctx->getType(); + const auto i32Type = Type::getInt32Ty(context); + const auto valueNullptrVal = ConstantPointerNull::get(valuePtrType); + const auto valuePtrNullptrVal = ConstantPointerNull::get(valuePtrsPtrType); + const auto oneVal = ConstantInt::get(i32Type, static_cast(EFetchResult::One)); + const auto maybeResType = TMaybeFetchResult::LLVMType(context); + const auto noneVal = TMaybeFetchResult::None().LLVMConst(context); + + const auto outputArrayVal = new AllocaInst(valueType, 0, ConstantInt::get(i32Type, OutWidth), "output_array", entryPos); + const auto outputPtrsVal = new AllocaInst(valuePtrType, 0, ConstantInt::get(Type::getInt64Ty(context), OutWidth), "output_ptrs", entryPos); + for (ui32 pos = 0; pos < OutWidth; pos++) { + const auto posVal = ConstantInt::get(i32Type, pos); + const auto arrayPtrVal = GetElementPtrInst::CreateInBounds(valueType, outputArrayVal, {posVal}, "array_ptr", entryPos); + const auto ptrsPtrVal = GetElementPtrInst::CreateInBounds(valuePtrType, outputPtrsVal, {posVal}, "ptrs_ptr", entryPos); + new StoreInst(arrayPtrVal, ptrsPtrVal, &ctx.Func->getEntryBlock().back()); + } + + auto block = genToBlock; // >>> start of main code chunk + + const auto stateVal = new LoadInst(valueType, statePtrVal, "state", block); + BranchInst::Create(init, loop, IsInvalid(stateVal, block), block); + + block = init; // state initialization block: + + const auto initFuncType = FunctionType::get(Type::getVoidTy(context), {thisType, statePtrType, ctxType}, false); + const auto initFuncRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.InitStateMethPtr); + const auto initFuncVal = CastInst::Create(Instruction::IntToPtr, initFuncRawVal, PointerType::getUnqual(initFuncType), "init_func", block); + CallInst::Create(initFuncType, initFuncVal, {thisVal, statePtrVal, ctx.Ctx}, "", block); + BranchInst::Create(loop, block); + + block = loop; // loop head block: (prepare inputs and decide whether to calculate row or not) + + const auto generated = 
GenFetchProcess(statePtrVal, ctx, std::bind_front(GetNodeValues, SourceFlow), block); + auto processResVal = generated.first; + if (processResVal == nullptr) { + const auto prepareFuncType = FunctionType::get(valuePtrsPtrType, {thisType, statePtrType, ctxType, valuePtrsPtrType}, false); + const auto prepareFuncRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.PrepareInputMethPtr); + const auto prepareFuncVal = CastInst::Create(Instruction::IntToPtr, prepareFuncRawVal, PointerType::getUnqual(prepareFuncType), "prepare_func", block); + const auto inputPtrsVal = CallInst::Create(prepareFuncType, prepareFuncVal, {thisVal, statePtrVal, ctx.Ctx, outputPtrsVal}, "input_ptrs", block); + const auto skipFetchCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, inputPtrsVal, valuePtrNullptrVal, "skip_fetch", block); + BranchInst::Create(loopTail, loopFetch, skipFetchCond, block); + + block = loopFetch; // loop fetch chunk: + + const auto [fetchResVal, getters] = GetNodeValues(SourceFlow, ctx, block); + const auto fetchResExtVal = new ZExtInst(fetchResVal, maybeResType, "res_ext", block); + const auto skipCalcCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, fetchResVal, oneVal, "skip_calc", block); + const auto fetchSourceBlock = block; + BranchInst::Create(loopTail, loopCalc, skipCalcCond, block); + + block = loopCalc; // loop calc chunk: (calculate needed values in the row) + + for (ui32 pos = 0; pos < InWidth; pos++) { + const auto stor = BasicBlock::Create(context, "stor", ctx.Func); + const auto cont = BasicBlock::Create(context, "cont", ctx.Func); + + auto innerBlock = block; // >>> start of inner chunk (calculates and stores the value if needed) + + const auto posVal = ConstantInt::get(i32Type, pos); + const auto inputPtrPtrVal = GetElementPtrInst::CreateInBounds(valuePtrType, inputPtrsVal, {posVal}, "input_ptr_ptr", innerBlock); + const auto inputPtrVal = new LoadInst(valuePtrType, inputPtrPtrVal, "input_ptr", innerBlock); + const auto 
isNullCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, inputPtrVal, valueNullptrVal, "is_null", innerBlock); + BranchInst::Create(cont, stor, isNullCond, innerBlock); + + innerBlock = stor; // calculate & store chunk: + + new StoreInst(getters[pos](ctx, innerBlock), inputPtrVal, innerBlock); + BranchInst::Create(cont, innerBlock); + + innerBlock = cont; // skip input value block: + + /* nothing here yet */ + + block = innerBlock; // <<< end of inner chunk + } + const auto calcSourceBlock = block; + BranchInst::Create(loopTail, block); + + block = loopTail; // loop tail block: (process row) + + const auto maybeFetchResVal = PHINode::Create(maybeResType, 2, "fetch_res", block); + maybeFetchResVal->addIncoming(noneVal, loop); + maybeFetchResVal->addIncoming(fetchResExtVal, fetchSourceBlock); + maybeFetchResVal->addIncoming(fetchResExtVal, calcSourceBlock); + const auto processFuncType = FunctionType::get(maybeResType, {thisType, statePtrType, ctxType, maybeResType, valuePtrsPtrType}, false); + const auto processFuncRawVal = ConstantInt::get(Type::getInt64Ty(context), PtrTable.DoProcessMethPtr); + const auto processFuncVal = CastInst::Create(Instruction::IntToPtr, processFuncRawVal, PointerType::getUnqual(processFuncType), "process_func", block); + processResVal = CallInst::Create(processFuncType, processFuncVal, {thisVal, statePtrVal, ctx.Ctx, maybeFetchResVal, outputPtrsVal}, "process_res", block); + } else { + BranchInst::Create(loopFetch, loopFetch); + BranchInst::Create(loopCalc, loopCalc); + BranchInst::Create(loopTail, loopTail); + } + const auto brkCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, processResVal, noneVal, "brk", block); + BranchInst::Create(done, loop, brkCond, block); + + block = done; // finalization block: + + const auto processResTruncVal = new TruncInst(processResVal, i32Type, "res_trunc", block); + + genToBlock = block; // <<< end of main code chunk + + if (generated.first) { + return {processResTruncVal, 
generated.second}; + } + + ICodegeneratorInlineWideNode::TGettersList new_getters; + new_getters.reserve(OutWidth); + for (size_t pos = 0; pos < OutWidth; pos++) { + new_getters.push_back([pos, outputArrayVal, i32Type, valueType] (const TCodegenContext&, BasicBlock*& block) -> Value* { + const auto posVal = ConstantInt::get(i32Type, pos); + const auto arrayPtrVal = GetElementPtrInst::CreateInBounds(valueType, outputArrayVal, {posVal}, "array_ptr", block); + const auto valueVal = new LoadInst(valueType, arrayPtrVal, "value", block); + return valueVal; + }); + } + return {processResTruncVal, std::move(new_getters)}; +} +#endif + +} +} \ No newline at end of file diff --git a/ydb/library/yql/minikql/computation/mkql_simple_codegen.h b/ydb/library/yql/minikql/computation/mkql_simple_codegen.h new file mode 100644 index 000000000000..b87509006eb1 --- /dev/null +++ b/ydb/library/yql/minikql/computation/mkql_simple_codegen.h @@ -0,0 +1,120 @@ +#pragma once + +namespace NKikimr { +namespace NMiniKQL { +/* TMaybeFetchResult wraps a 64-bit Raw value that holds either an EFetchResult or the distinguished None() marker. It converts implicitly from EFetchResult; the constructor taking a raw ui64 is private. */ +class TMaybeFetchResult final { + ui64 Raw; + + explicit TMaybeFetchResult(ui64 raw) : Raw(raw) {} + +public: + /* implicit */ TMaybeFetchResult(EFetchResult res) : TMaybeFetchResult(static_cast(static_cast(res))) {} + +/* Empty() is true when Raw has any bit set at position 32 or above, which is the case for None(). */ + [[nodiscard]] bool Empty() const { + return Raw >> ui64(32); + } +/* Get() is guarded by Y_ABORT_IF(Empty()), so it must not be called on an empty value; otherwise Raw is cast back to EFetchResult. NOTE(review): the cast target types are missing from this view - confirm against the real header. */ + [[nodiscard]] EFetchResult Get() const { + Y_ABORT_IF(Empty()); + return static_cast(static_cast(Raw)); + } + + [[nodiscard]] ui64 RawU64() const { + return Raw; + } +/* None() yields the marker value Raw == (1 << 32), for which Empty() is true. */ + static TMaybeFetchResult None() { + return TMaybeFetchResult(ui64(1) << ui64(32)); + } +/* Codegen helpers: LLVMType is the 64-bit integer type, LLVMFromFetchResult zero-extends a fetch result value to that type, and LLVMConst emits this object's RawU64() as a constant of that type. */ +#ifndef MKQL_DISABLE_CODEGEN + static Type* LLVMType(LLVMContext& context) { + return Type::getInt64Ty(context); + } + + static Value* LLVMFromFetchResult(Value *fetchRes, const Twine& name, BasicBlock* block) { + return new ZExtInst(fetchRes, LLVMType(fetchRes->getContext()), name, block); + } + + Value* LLVMConst(LLVMContext& context) const { + return ConstantInt::get(LLVMType(context), RawU64()); + } +#endif +}; + +#ifndef 
MKQL_DISABLE_CODEGEN +using TResultCodegenerator = std::function; +#endif +/* Non-template part of the simple stateful wide-flow node: it stores the source flow, the input and output widths and a table of integer-encoded pointers, and declares the shared DoGenGetValues code generator. TResultCodegenerator above is a std::function alias; its signature is not visible in this view. */ +class TSimpleStatefulWideFlowCodegeneratorNodeLLVMBase { +public:/* TMethPtrTable is filled in by the derived template's constructor with the object address and the GetMethodPtr results for TDerived::InitState, PrepareInput and DoProcess. */ + struct TMethPtrTable { + uintptr_t ThisPtr; + uintptr_t InitStateMethPtr; + uintptr_t PrepareInputMethPtr; + uintptr_t DoProcessMethPtr; + }; + + TSimpleStatefulWideFlowCodegeneratorNodeLLVMBase(IComputationWideFlowNode* source, ui32 inWidth, ui32 outWidth, TMethPtrTable ptrTable) + : SourceFlow(source), InWidth(inWidth), OutWidth(outWidth) + , PtrTable(ptrTable) {} +/* GenFetchProcess is an overridable code generation hook. This default ignores its arguments and returns a null first element; DoGenGetValues tests generated.first against nullptr and, when it is null, emits the generic path that calls PrepareInput and DoProcess through PtrTable. */ +#ifndef MKQL_DISABLE_CODEGEN + virtual ICodegeneratorInlineWideNode::TGenerateResult GenFetchProcess(Value* statePtrVal, const TCodegenContext& ctx, const TResultCodegenerator& fetchGenerator, BasicBlock*& block) const { + Y_UNUSED(statePtrVal); + Y_UNUSED(ctx); + Y_UNUSED(fetchGenerator); + Y_UNUSED(block); + return {nullptr, {}}; + } + + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtrVal, BasicBlock*& genToBlock) const; +#endif + +protected: + IComputationWideFlowNode* const SourceFlow; + const ui32 InWidth, OutWidth; + const TMethPtrTable PtrTable; +}; +/* Template node that combines TStatefulWideFlowCodegeneratorNode (aliased as TBase) with the LLVM base class above. TDerived supplies InitState, PrepareInput and DoProcess; the constructor records their GetMethodPtr values, together with this object's address, in the pointer table. NOTE(review): the template parameter list and the cast target types are missing from this view - confirm against the real header. */ +template +class TSimpleStatefulWideFlowCodegeneratorNode + : public TStatefulWideFlowCodegeneratorNode> + , public TSimpleStatefulWideFlowCodegeneratorNodeLLVMBase { + using TBase = TStatefulWideFlowCodegeneratorNode; + using TLLVMBase = TSimpleStatefulWideFlowCodegeneratorNodeLLVMBase; + +protected: + TSimpleStatefulWideFlowCodegeneratorNode(TComputationMutables& mutables, IComputationWideFlowNode* source, ui32 inWidth, ui32 outWidth) + : TBase(mutables, source, StateKind) + , TLLVMBase(source, inWidth, outWidth, { + .ThisPtr = reinterpret_cast(this), + .InitStateMethPtr = GetMethodPtr(&TDerived::InitState), + .PrepareInputMethPtr = GetMethodPtr(&TDerived::PrepareInput), + .DoProcessMethPtr = GetMethodPtr(&TDerived::DoProcess) + }) {} +/* DoCalculate: when the state is invalid it assigns a default TUnboxedValuePod and calls InitState; it then repeats PrepareInput, FetchValues and DoProcess until DoProcess returns a non-empty result, and returns that result via Get(). */ +public: + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, 
NUdf::TUnboxedValue*const* output) const { + if (state.IsInvalid()) { + state = NUdf::TUnboxedValuePod(); + static_cast(this)->InitState(state, ctx); + }/* When no output array is passed and OutWidth is zero, output is pointed at a local null stub so that it is non-null for the calls below. */ + NUdf::TUnboxedValue *const stub = nullptr; + if (!output && !OutWidth) { + output = &stub; + } + auto result = TMaybeFetchResult::None(); + while (result.Empty()) { + NUdf::TUnboxedValue*const* input = static_cast(this)->PrepareInput(state, ctx, output);/* A null input from PrepareInput skips the fetch; DoProcess then receives None() as the fetch result. */ + TMaybeFetchResult fetchResult = input ? SourceFlow->FetchValues(ctx, input) : TMaybeFetchResult::None(); + result = static_cast(this)->DoProcess(state, ctx, fetchResult, output); + } + return result.Get(); + } +}; + +} +} \ No newline at end of file diff --git a/ydb/library/yql/minikql/computation/ya.make.inc b/ydb/library/yql/minikql/computation/ya.make.inc index 35a960acdf4e..1b3dc2b7d498 100644 --- a/ydb/library/yql/minikql/computation/ya.make.inc +++ b/ydb/library/yql/minikql/computation/ya.make.inc @@ -10,6 +10,7 @@ SET(ORIG_SRC_DIR ydb/library/yql/minikql/computation) SET(ORIG_SOURCES mkql_computation_node_codegen.cpp + mkql_simple_codegen.cpp mkql_computation_node_graph.cpp mkql_computation_node_graph_saveload.cpp mkql_computation_node_holders_codegen.cpp