Skip to content

Commit

Permalink
Update LLVM part of combiner with spilling. (#6277)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tony-Romanov authored Jul 4, 2024
1 parent f256aea commit 29e1e45
Showing 1 changed file with 80 additions and 62 deletions.
142 changes: 80 additions & 62 deletions ydb/library/yql/minikql/comp_nodes/mkql_wide_combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
};

public:
// Three-way outcome code with the enumerators Init, Update and Skip.
// NOTE(review): the two header lines and the two Init lines below look like the
// old (ui8, implicit values) and new (i8, Init = -1) forms of one declaration
// as rendered by the diff — confirm which one is current before relying on the
// numeric values.
enum class ETasteResult: ui8 {
Init,
enum class ETasteResult: i8 {
Init = -1,
Update,
Skip
};
Expand Down Expand Up @@ -372,15 +372,9 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
Tongue = InMemoryProcessingState.Tongue;
Throat = InMemoryProcessingState.Throat;
}
// Destructor with an empty body; no explicit teardown is performed here.
~TSpillingSupportState() {
}

// Tells the caller whether more input has to be fetched: true for every
// InputStatus value other than EFetchResult::Finish.
bool IsFetchRequired() const {
    const bool inputFinished = (InputStatus == EFetchResult::Finish);
    return !inputFinished;
}

// Reports whether SpilledBuckets currently holds at least one bucket.
bool HasAnyData() const {
// A non-zero size converts to true.
return SpilledBuckets.size();
// NOTE(review): unreachable after the return above; it expresses the same
// check and appears to be the replacement line from the diff — confirm.
return !SpilledBuckets.empty();
}

bool IsProcessingRequired() const {
Expand Down Expand Up @@ -456,6 +450,20 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
return ETasteResult::Skip;
}

// Hands out the next extracted value pointer, or nullptr when the current
// source has nothing left.
//
// In InMemory mode the call is forwarded to InMemoryProcessingState.
// Otherwise values are taken from the front spilled bucket, which must be in
// the InMemory bucket state; once that bucket yields nullptr it is removed
// from SpilledBuckets and nullptr is returned to the caller.
NUdf::TUnboxedValuePod* Extract() {
    if (GetMode() == EOperatingMode::InMemory) {
        return static_cast<NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract());
    }

    // The emptiness check must come first: front() on an empty container is
    // undefined behaviour, so it may only be evaluated after this guard.
    MKQL_ENSURE(!SpilledBuckets.empty(), "Internal logic error");
    MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error");

    auto value = static_cast<NUdf::TUnboxedValue*>(SpilledBuckets.front().InMemoryProcessingState->Extract());
    if (!value) {
        // The front bucket is exhausted; drop it.
        SpilledBuckets.pop_front();
    }

    return value;
}
private:
void MoveKeyToBucket(TSpilledBucket& bucket) {
for (size_t i = 0; i < KeyWidth; ++i) {
//jumping into unsafe world, refusing ownership
Expand Down Expand Up @@ -483,20 +491,6 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
BufferForUsedInputItems.resize(0);
}

// Hands out the next extracted value pointer, or nullptr when the current
// source has nothing left.
//
// In InMemory mode the call is forwarded to InMemoryProcessingState.
// Otherwise values are taken from the front spilled bucket, which must be in
// the InMemory bucket state; once that bucket yields nullptr it is removed
// from SpilledBuckets and nullptr is returned to the caller.
NUdf::TUnboxedValuePod* Extract() {
    if (GetMode() == EOperatingMode::InMemory) {
        return static_cast<NUdf::TUnboxedValue*>(InMemoryProcessingState.Extract());
    }

    // The emptiness check must come first: front() on an empty container is
    // undefined behaviour, so it may only be evaluated after this guard.
    MKQL_ENSURE(!SpilledBuckets.empty(), "Internal logic error");
    MKQL_ENSURE(SpilledBuckets.front().BucketState == TSpilledBucket::EBucketState::InMemory, "Internal logic error");

    auto value = static_cast<NUdf::TUnboxedValue*>(SpilledBuckets.front().InMemoryProcessingState->Extract());
    if (!value) {
        // The front bucket is exhausted; drop it.
        SpilledBuckets.pop_front();
    }

    return value;
}

bool FlushSpillingBuffersAndWait() {
UpdateSpillingBuckets();

Expand All @@ -521,7 +515,6 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
return ProcessSpilledDataAndWait();
}

private:
void SplitStateIntoBuckets() {
while (const auto keyAndState = static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Extract())) {
auto hash = Hasher(keyAndState); //Hasher uses only key for hashing
Expand Down Expand Up @@ -1246,7 +1239,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra

EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
if (!state.HasValue()) {
MakeSpillingSupportState(ctx, state);
MakeState(ctx, state);
}

if (const auto ptr = static_cast<TSpillingSupportState*>(state.AsBoxed().Get())) {
Expand Down Expand Up @@ -1306,6 +1299,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
const auto valueType = Type::getInt128Ty(context);
const auto ptrValueType = PointerType::getUnqual(valueType);
const auto statusType = Type::getInt32Ty(context);
const auto wayType = Type::getInt8Ty(context);

TLLVMFieldsStructureState stateFields(context);

Expand All @@ -1332,26 +1326,39 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
const auto state = new LoadInst(valueType, statePtr, "state", block);
const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
const auto boolFuncType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
BranchInst::Create(more, block);

block = more;

const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
const auto full = BasicBlock::Create(context, "full", ctx.Func);
const auto over = BasicBlock::Create(context, "over", ctx.Func);
const auto result = PHINode::Create(statusType, 3U, "result", over);

const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
const auto last = new LoadInst(statusType, statusPtr, "last", block);
const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), "finish", block);

BranchInst::Create(full, loop, finish, block);
const auto result = PHINode::Create(statusType, 4U, "result", over);

{
const auto test = BasicBlock::Create(context, "test", ctx.Func);
const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
const auto rest = BasicBlock::Create(context, "rest", ctx.Func);
const auto proc = BasicBlock::Create(context, "proc", ctx.Func);
const auto good = BasicBlock::Create(context, "good", ctx.Func);

block = loop;
block = more;

const auto waitMoreFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::UpdateAndWait));
const auto waitMoreFuncPtr = CastInst::Create(Instruction::IntToPtr, waitMoreFunc, PointerType::getUnqual(boolFuncType), "wait_more_func", block);
const auto waitMore = CallInst::Create(boolFuncType, waitMoreFuncPtr, { stateArg }, "wait_more", block);

result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);

BranchInst::Create(over, test, waitMore, block);

block = test;

const auto statusPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStatus() }, "last", block);
const auto last = new LoadInst(statusType, statusPtr, "last", block);
const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), "finish", block);

BranchInst::Create(good, pull, finish, block);

block = pull;

const auto getres = GetNodeValues(Flow, ctx, block);

Expand All @@ -1362,12 +1369,19 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), rest);

block = rest;
new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), statusPtr, block);

BranchInst::Create(full, block);
new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
BranchInst::Create(more, block);

block = good;

const auto processingFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::IsProcessingRequired));
const auto processingFuncPtr = CastInst::Create(Instruction::IntToPtr, processingFunc, PointerType::getUnqual(boolFuncType), "processing_func", block);
const auto processing = CallInst::Create(boolFuncType, processingFuncPtr, { stateArg }, "processing", block);

BranchInst::Create(proc, full, processing, block);

block = proc;

std::vector<Value*> items(Nodes.ItemNodes.size(), nullptr);
for (ui32 i = 0U; i < items.size(); ++i) {
if (Nodes.ItemNodes[i]->GetDependencesCount() > 0U)
Expand Down Expand Up @@ -1398,10 +1412,10 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
new StoreInst(key, keyPtr, block);
}

const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::TasteIt));
const auto atType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
const auto atFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::TasteIt));
const auto atType = FunctionType::get(wayType, {stateArg->getType()}, false);
const auto atPtr = CastInst::Create(Instruction::IntToPtr, atFunc, PointerType::getUnqual(atType), "function", block);
const auto newKey = CallInst::Create(atType, atPtr, {stateArg}, "new_key", block);
const auto taste= CallInst::Create(atType, atPtr, {stateArg}, "taste", block);

const auto init = BasicBlock::Create(context, "init", ctx.Func);
const auto next = BasicBlock::Create(context, "next", ctx.Func);
Expand All @@ -1415,7 +1429,9 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
pointers.emplace_back(GetElementPtrInst::CreateInBounds(valueType, throat, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("state_") += ToString(i)).c_str(), block));
}

BranchInst::Create(init, next, newKey, block);
const auto way = SwitchInst::Create(taste, more, 2U, block);
way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Init)), init);
way->addCase(ConstantInt::get(wayType, static_cast<i8>(TSpillingSupportState::ETasteResult::Update)), next);

block = init;

Expand All @@ -1439,7 +1455,7 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
}
}

BranchInst::Create(loop, block);
BranchInst::Create(more, block);

block = next;

Expand Down Expand Up @@ -1484,23 +1500,22 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
}
}

BranchInst::Create(loop, block);
BranchInst::Create(more, block);
}

{
block = full;

const auto good = BasicBlock::Create(context, "good", ctx.Func);
const auto last = BasicBlock::Create(context, "last", ctx.Func);

const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::Extract));
const auto extractType = FunctionType::get(ptrValueType, {stateArg->getType()}, false);
const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(ptrValueType), "has", block);

result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);

BranchInst::Create(good, over, has, block);
BranchInst::Create(good, last, has, block);

block = good;

Expand All @@ -1514,6 +1529,16 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra

result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
BranchInst::Create(over, block);

block = last;

const auto hasDataFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSpillingSupportState::HasAnyData));
const auto hasDataFuncPtr = CastInst::Create(Instruction::IntToPtr, hasDataFunc, PointerType::getUnqual(boolFuncType), "has_data_func", block);
const auto hasData = CallInst::Create(boolFuncType, hasDataFuncPtr, { stateArg }, "has_data", block);

result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);

BranchInst::Create(more, over, hasData, block);
}

block = over;
Expand All @@ -1528,23 +1553,17 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
#endif
private:
// Creates a TState through ctx.HolderFactory and stores it in `state`.
// The state is sized by the number of key nodes and state nodes.
void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
#ifdef MKQL_DISABLE_CODEGEN
// Without codegen only the TMyValueHasher / TMyValueEqual functors are available.
state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(), TMyValueHasher(KeyTypes), TMyValueEqual(KeyTypes));
#else
// With codegen, the Hash / Equals function pointers are used when ctx.ExecuteLLVM
// is set and the pointer is non-null; otherwise the TMyValue* functors are used.
state = ctx.HolderFactory.Create<TState>(Nodes.KeyNodes.size(), Nodes.StateNodes.size(),
ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes))
);
#endif
}

void MakeSpillingSupportState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
state = ctx.HolderFactory.Create<TSpillingSupportState>(WideFieldsIndex,
UsedInputItemType, KeyAndStateType,
Nodes.KeyNodes.size(),
Nodes.ItemNodes.size(),
#ifdef MKQL_DISABLE_CODEGEN
TMyValueHasher(KeyTypes),
TMyValueEqual(KeyTypes),
#else
ctx.ExecuteLLVM && Hash ? THashFunc(std::ptr_fun(Hash)) : THashFunc(TMyValueHasher(KeyTypes)),
ctx.ExecuteLLVM && Equals ? TEqualsFunc(std::ptr_fun(Equals)) : TEqualsFunc(TMyValueEqual(KeyTypes)),
#endif
AllowSpilling,
ctx
);
Expand All @@ -1569,7 +1588,6 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
const ui32 WideFieldsIndex;

const bool AllowSpilling;

#ifndef MKQL_DISABLE_CODEGEN
TEqualsPtr Equals = nullptr;
THashPtr Hash = nullptr;
Expand Down Expand Up @@ -1626,7 +1644,7 @@ IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeF
keyTypes.reserve(keysSize);
for (ui32 i = index; i < index + keysSize; ++i) {
TType *type = callable.GetInput(i).GetStaticType();
keyAndStateItemTypes.push_back(type);
keyAndStateItemTypes.push_back(type);
bool optional;
keyTypes.emplace_back(*UnpackOptionalData(callable.GetInput(i).GetStaticType(), optional)->GetDataSlot(), optional);
}
Expand Down

0 comments on commit 29e1e45

Please sign in to comment.