Skip to content

Commit

Permalink
Generalize DoGenGetValues of wide While-variations
Browse files Browse the repository at this point in the history
  • Loading branch information
avevad committed Feb 21, 2024
1 parent 5cc4e57 commit 70651ae
Showing 1 changed file with 48 additions and 125 deletions.
173 changes: 48 additions & 125 deletions ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ class TBaseWideFilterWrapper {
if (Predicate == Items[i] || Items[i]->GetDependencesCount() > 0U)
*out = *fields[i];
}

bool ApplyPredicate(TComputationContext& ctx, NUdf::TUnboxedValue*const* values) const {
auto **fields = GetFields(ctx);
PrepareArguments(ctx, values);
for (size_t idx = 0; idx < Items.size(); idx++) {
*fields[idx] = *values[idx];
}
return Predicate->GetValue(ctx).Get<bool>();
}

#ifndef MKQL_DISABLE_CODEGEN
template<bool ReplaceOriginalGetter = true>
Value* GenGetPredicate(const TCodegenContext& ctx,
Expand Down Expand Up @@ -154,12 +164,7 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideFilterWit
return EProcessResult::Finish;
}
if (fetchRes == EFetchResult::One) {
auto **fields = GetFields(ctx);
PrepareArguments(ctx, values);
for (size_t idx = 0; idx < Items.size(); idx++) {
*fields[idx] = *values[idx];
}
if (Predicate->GetValue(ctx).Get<bool>()) {
if (ApplyPredicate(ctx, values)) {
FillOutputs(ctx, values);
limit--;
return EProcessResult::One;
Expand All @@ -182,80 +187,37 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideFilterWit
};

template<bool Inclusive>
class TWideTakeWhileWrapper : public TStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>>, public TBaseWideFilterWrapper {
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>>;
class TWideTakeWhileWrapper : public TSimpleStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>, bool>, public TBaseWideFilterWrapper {
using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>, bool>;
public:
TWideTakeWhileWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items,
IComputationNode* predicate)
: TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
, TBaseWideFilterWrapper(mutables, flow, std::move(items), predicate)
{}

EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
if (!state.IsInvalid()) {
return EFetchResult::Finish;
}

PrepareArguments(ctx, output);

auto **fields = GetFields(ctx);

if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
return result;

const bool predicate = Predicate->GetValue(ctx).Get<bool>();
if (!predicate)
state = NUdf::TUnboxedValuePod();
void InitState(bool& stop, TComputationContext& ctx) const {
stop = false;
}

if (Inclusive || predicate) {
FillOutputs(ctx, output);
return EFetchResult::One;
TBaseComputation::EProcessResult DoProcess(bool& stop, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const* values) const {
if (stop) {
return TBaseComputation::EProcessResult::Finish;
}

return EFetchResult::Finish;
if (fetchRes == EFetchResult::One) {
const bool predicate = ApplyPredicate(ctx, values);
if (!predicate) {
stop = true;
}
if (Inclusive || predicate) {
FillOutputs(ctx, values);
return TBaseComputation::EProcessResult::One;
}
return TBaseComputation::EProcessResult::Finish;
}
return static_cast<TBaseComputation::EProcessResult>(fetchRes);
}
#ifndef MKQL_DISABLE_CODEGEN
ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
auto& context = ctx.Codegen.GetContext();

const auto resultType = Type::getInt32Ty(context);

const auto work = BasicBlock::Create(context, "work", ctx.Func);
const auto test = BasicBlock::Create(context, "test", ctx.Func);
const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
const auto done = BasicBlock::Create(context, "done", ctx.Func);

const auto result = PHINode::Create(resultType, 4U, "result", done);
result->addIncoming(ConstantInt::get(resultType, static_cast<i32>(EFetchResult::Finish)), block);

const auto state = new LoadInst(Type::getInt128Ty(context), statePtr, "state", block);
const auto finished = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state, GetTrue(context), "finished", block);

BranchInst::Create(done, work, IsValid(statePtr, block), block);

block = work;
auto status = GetNodeValues(Flow, ctx, block);
const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, status.first, ConstantInt::get(resultType, 0), "special", block);
result->addIncoming(status.first, block);
BranchInst::Create(done, test, special, block);

block = test;

const auto predicate = GenGetPredicate(ctx, status.second, block);
result->addIncoming(status.first, block);
BranchInst::Create(done, stop, predicate, block);

block = stop;

new StoreInst(GetEmpty(context), statePtr, block);
result->addIncoming(ConstantInt::get(resultType, static_cast<i32>(Inclusive ? EFetchResult::One: EFetchResult::Finish)), block);

BranchInst::Create(done, block);

block = done;
return {result, std::move(status.second)};
}
#endif
private:
void RegisterDependencies() const final {
if (const auto flow = this->FlowDependsOn(Flow)) {
Expand All @@ -266,72 +228,33 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrappe
};

template<bool Inclusive>
class TWideSkipWhileWrapper : public TStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>>, public TBaseWideFilterWrapper {
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>>;
class TWideSkipWhileWrapper : public TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>, bool>, public TBaseWideFilterWrapper {
using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>, bool>;
public:
TWideSkipWhileWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* predicate)
: TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
, TBaseWideFilterWrapper(mutables, flow, std::move(items), predicate)
{}

EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
if (!state.IsInvalid()) {
return Flow->FetchValues(ctx, output);
}

auto **fields = GetFields(ctx);

do {
PrepareArguments(ctx, output);
if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
return result;
} while (Predicate->GetValue(ctx).Get<bool>());

state = NUdf::TUnboxedValuePod();
void InitState(bool& start, TComputationContext& ctx) const {
start = false;
}

if constexpr (Inclusive)
return Flow->FetchValues(ctx, output);
else {
FillOutputs(ctx, output);
return EFetchResult::One;
TBaseComputation::EProcessResult DoProcess(bool& start, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const* values) const {
if (!start && fetchRes == EFetchResult::One) {
const bool predicate = ApplyPredicate(ctx, values);
if (!predicate) {
start = true;
}
if (!Inclusive && !predicate) {
FillOutputs(ctx, values);
return TBaseComputation::EProcessResult::One;
}
return TBaseComputation::EProcessResult::Fetch;
}
return static_cast<TBaseComputation::EProcessResult>(fetchRes);
}
#ifndef MKQL_DISABLE_CODEGEN
ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
auto& context = ctx.Codegen.GetContext();

const auto resultType = Type::getInt32Ty(context);

const auto work = BasicBlock::Create(context, "work", ctx.Func);
const auto test = BasicBlock::Create(context, "test", ctx.Func);
const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
const auto done = BasicBlock::Create(context, "done", ctx.Func);

BranchInst::Create(work, block);

block = work;

const auto status = GetNodeValues(Flow, ctx, block);

const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, status.first, ConstantInt::get(resultType, 0), "special", block);
const auto passtrought = BinaryOperator::CreateOr(special, IsValid(statePtr, block), "passtrought", block);
BranchInst::Create(done, test, passtrought, block);

block = test;

const auto predicate = GenGetPredicate<false>(ctx, status.second, block);
BranchInst::Create(work, stop, predicate, block);

block = stop;

new StoreInst(GetEmpty(context), statePtr, block);

BranchInst::Create(Inclusive ? work : done, block);

block = done;
return status;
}
#endif
private:
void RegisterDependencies() const final {
if (const auto flow = this->FlowDependsOn(Flow)) {
Expand Down

0 comments on commit 70651ae

Please sign in to comment.