Skip to content

Commit

Permalink
Disable spilling in LastWideCombiner if input or state is not serializable. (#4590)
Browse files Browse the repository at this point in the history
  • Loading branch information
Darych authored May 16, 2024
1 parent f9b9e70 commit be4c88f
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions ydb/library/yql/minikql/comp_nodes/mkql_wide_combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
TMemoryUsageInfo* memInfo,
const TCombinerNodes& nodes, IComputationWideFlowNode *const flow, size_t wideFieldsIndex,
const TMultiType* usedInputItemType, const TMultiType* keyAndStateType, ui32 keyWidth,
const THashFunc& hash, const TEqualsFunc& equal
const THashFunc& hash, const TEqualsFunc& equal, bool allowSpilling
)
: TBase(memInfo)
, InMemoryProcessingState(memInfo, keyWidth, keyAndStateType->GetElementsCount() - keyWidth, hash, equal)
Expand All @@ -358,6 +358,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
, Mode(EOperatingMode::InMemory)
, MemInfo(memInfo)
, Equal(equal)
, AllowSpilling(allowSpilling)
{
BufferForUsedInputItems.reserve(usedInputItemType->GetElementsCount());
BufferForKeyAnsState.reserve(keyAndStateType->GetElementsCount());
Expand Down Expand Up @@ -430,7 +431,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
isNew ? nullptr : static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Tongue),
static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Throat)
);
if (ctx.SpillerFactory && IsSwitchToSpillingModeCondition()) {
if (AllowSpilling && ctx.SpillerFactory && IsSwitchToSpillingModeCondition()) {
SwitchMode(EOperatingMode::Spilling, ctx);
return EFetchResult::Yield;
}
Expand Down Expand Up @@ -488,7 +489,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
} else {
bucket.SpilledData->AsyncWriteCompleted(bucket.AsyncWriteOperation->ExtractValue());
bucket.AsyncWriteOperation = std::nullopt;
}
}
}
}
}
Expand Down Expand Up @@ -536,7 +537,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue *>(BufferForKeyAnsState.data()));

auto hash = Hasher(BufferForKeyAnsState.data());

auto bucketId = hash % SpilledBucketCount;

auto& bucket = SpilledBuckets[bucketId];
Expand Down Expand Up @@ -608,7 +609,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
AsyncReadOperation = std::nullopt;
}
while(!SpilledBuckets.empty()){

auto& bucket = SpilledBuckets.front();
//recover spilled state
while(!bucket.SpilledState->Empty()) {
Expand Down Expand Up @@ -656,7 +657,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {
);
BufferForUsedInputItems.resize(0);
}

if (const auto values = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) {
Nodes.FinishItem(ctx, values, output);

Expand Down Expand Up @@ -732,6 +733,7 @@ class TSpillingSupportState : public TComputationValue<TSpillingSupportState> {

TMemoryUsageInfo* MemInfo = nullptr;
TEqualsFunc const Equal;
const bool AllowSpilling;
};

#ifndef MKQL_DISABLE_CODEGEN
Expand Down Expand Up @@ -1205,24 +1207,26 @@ class TWideLastCombinerWrapper: public TStatefulWideFlowCodegeneratorNode<TWideL
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWrapper>;
public:
TWideLastCombinerWrapper(
TComputationMutables& mutables,
IComputationWideFlowNode* flow,
TCombinerNodes&& nodes,
TComputationMutables& mutables,
IComputationWideFlowNode* flow,
TCombinerNodes&& nodes,
const TMultiType* usedInputItemType,
TKeyTypes&& keyTypes,
const TMultiType* keyAndStateType)
const TMultiType* keyAndStateType,
bool allowSpilling)
: TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
, Flow(flow)
, Nodes(std::move(nodes))
, KeyTypes(std::move(keyTypes))
, UsedInputItemType(usedInputItemType)
, KeyAndStateType(keyAndStateType)
, WideFieldsIndex(mutables.IncrementWideFieldsIndex(Nodes.ItemNodes.size()))
, AllowSpilling(allowSpilling)
{}

EFetchResult DoCalculate(NUdf::TUnboxedValue& stateValue, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
if (!stateValue.HasValue()) {
MakeSpillingSupportState(ctx, stateValue);
MakeSpillingSupportState(ctx, stateValue, AllowSpilling);
}
auto *const state = static_cast<TSpillingSupportState *>(stateValue.AsBoxed().Get());
return state->DoCalculate(ctx, output);
Expand Down Expand Up @@ -1466,12 +1470,13 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
#endif
}

void MakeSpillingSupportState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
void MakeSpillingSupportState(TComputationContext& ctx, NUdf::TUnboxedValue& state, bool allowSpilling) const {
state = ctx.HolderFactory.Create<TSpillingSupportState>(Nodes, Flow, WideFieldsIndex,
UsedInputItemType, KeyAndStateType,
Nodes.KeyNodes.size(),
TMyValueHasher(KeyTypes),
TMyValueEqual(KeyTypes)
TMyValueEqual(KeyTypes),
allowSpilling
);
}

Expand All @@ -1493,6 +1498,8 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra

const ui32 WideFieldsIndex;

const bool AllowSpilling;

#ifndef MKQL_DISABLE_CODEGEN
TEqualsPtr Equals = nullptr;
THashPtr Hash = nullptr;
Expand Down Expand Up @@ -1523,6 +1530,11 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideLastCombinerWra
#endif
};

// Whether a value of this type can be spilled (serialized to external
// storage and read back). Runtime-only kinds — resources, type handles,
// streams, callables, Any, flows, and reserved kinds — are rejected.
bool IsTypeSerializable(const TType* type) {
    return !type->IsResource()
        && !type->IsType()
        && !type->IsStream()
        && !type->IsCallable()
        && !type->IsAny()
        && !type->IsFlow()
        && !type->IsReservedKind();
}

}

template<bool Last>
Expand All @@ -1542,13 +1554,17 @@ IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeF

++index += inputWidth;

bool allowSpilling = true;

std::vector<TType*> keyAndStateItemTypes;
keyAndStateItemTypes.reserve(keysSize + stateSize);

TKeyTypes keyTypes;
keyTypes.reserve(keysSize);
for (ui32 i = index; i < index + keysSize; ++i) {
keyAndStateItemTypes.push_back(callable.GetInput(i).GetStaticType());
TType *type = callable.GetInput(i).GetStaticType();
allowSpilling = allowSpilling && IsTypeSerializable(type);
keyAndStateItemTypes.push_back(type);
bool optional;
keyTypes.emplace_back(*UnpackOptionalData(callable.GetInput(i).GetStaticType(), optional)->GetDataSlot(), optional);
}
Expand All @@ -1560,7 +1576,9 @@ IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeF
index += keysSize;
nodes.InitResultNodes.reserve(stateSize);
for (size_t i = 0; i != stateSize; ++i) {
keyAndStateItemTypes.push_back(callable.GetInput(index).GetStaticType());
TType *type = callable.GetInput(index).GetStaticType();
allowSpilling = allowSpilling && IsTypeSerializable(type);
keyAndStateItemTypes.push_back(type);
nodes.InitResultNodes.push_back(LocateNode(ctx.NodeLocator, callable, index++));
}

Expand Down Expand Up @@ -1597,13 +1615,15 @@ IComputationNode* WrapWideCombinerT(TCallable& callable, const TComputationNodeF
usedInputItemTypes.reserve(inputItemTypes.size());
for (size_t i = 0; i != inputItemTypes.size(); ++i) {
if (nodes.IsInputItemNodeUsed(i)) {
allowSpilling = allowSpilling && IsTypeSerializable(inputItemTypes[i]);
usedInputItemTypes.push_back(inputItemTypes[i]);
}
}
return new TWideLastCombinerWrapper(ctx.Mutables, wide, std::move(nodes),
TMultiType::Create(usedInputItemTypes.size(), usedInputItemTypes.data(), ctx.Env),
std::move(keyTypes),
TMultiType::Create(keyAndStateItemTypes.size(),keyAndStateItemTypes.data(), ctx.Env)
TMultiType::Create(keyAndStateItemTypes.size(),keyAndStateItemTypes.data(), ctx.Env),
allowSpilling
);
} else {
if constexpr (RuntimeVersion < 46U) {
Expand Down

0 comments on commit be4c88f

Please sign in to comment.