Skip to content

Commit

Permalink
[KQP] Fix recursion problem when computing SimplifiedPlan (#9519) (#9631)
Browse files Browse the repository at this point in the history

Co-authored-by: pilik <pudge1000-7@ydb.tech>
  • Loading branch information
pavelvelikhov and pashandor789 authored Sep 23, 2024
1 parent 16c4b26 commit 979fcef
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 119 deletions.
298 changes: 184 additions & 114 deletions ydb/core/kqp/opt/kqp_query_plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,9 @@ TVector<NJson::TJsonValue> RemoveRedundantNodes(NJson::TJsonValue& plan, const T
}
}

if (!planMap.contains("Node Type")) {
return {};
}
const auto typeName = planMap.at("Node Type").GetStringSafe();
if (redundantNodes.contains(typeName) || typeName.find("Precompute") != TString::npos) {
return children;
Expand All @@ -1953,167 +1956,235 @@ TVector<NJson::TJsonValue> RemoveRedundantNodes(NJson::TJsonValue& plan, const T
return {plan};
}

NJson::TJsonValue ReconstructQueryPlanRec(const NJson::TJsonValue& plan,
int operatorIndex,
const THashMap<int, NJson::TJsonValue>& planIndex,
const THashMap<TString, NJson::TJsonValue>& precomputes,
int& nodeCounter) {

int currentNodeId = nodeCounter++;

NJson::TJsonValue result;
result["PlanNodeId"] = currentNodeId;

if (plan.GetMapSafe().contains("PlanNodeType")) {
result["PlanNodeType"] = plan.GetMapSafe().at("PlanNodeType").GetStringSafe();
}
struct TQueryPlanReconstructor {
TQueryPlanReconstructor(
const THashMap<int, NJson::TJsonValue>& planIndex,
const THashMap<TString, NJson::TJsonValue>& precomputes
)
: PlanIndex(planIndex)
, Precomputes(precomputes)
, NodeIDCounter(0)
, Budget(10'000)
{}

if (plan.GetMapSafe().contains("Stats") && operatorIndex==0) {
result["Stats"] = plan.GetMapSafe().at("Stats");
}
NJson::TJsonValue Reconstruct(
const NJson::TJsonValue& plan,
int operatorIndex
) {
int currentNodeId = NodeIDCounter++;

if (!plan.GetMapSafe().contains("Operators")) {
NJson::TJsonValue planInputs;
NJson::TJsonValue result;
result["PlanNodeId"] = currentNodeId;

result["Node Type"] = plan.GetMapSafe().at("Node Type").GetStringSafe();
if (--Budget <= 0) {
YQL_CLOG(DEBUG, ProviderKqp) << "Can't build the plan - recursion depth has been exceeded!";
return result;
}

if (plan.GetMapSafe().contains("CTE Name")) {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (precomputes.contains(precompute)) {
planInputs.AppendValue(ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter));
}
if (plan.GetMapSafe().contains("PlanNodeType")) {
result["PlanNodeType"] = plan.GetMapSafe().at("PlanNodeType").GetStringSafe();
}

if (!plan.GetMapSafe().contains("Plans")) {
result["Plans"] = planInputs;
return result;
if (plan.GetMapSafe().contains("Stats") && operatorIndex==0) {
result["Stats"] = plan.GetMapSafe().at("Stats");
}

if (plan.GetMapSafe().at("Node Type").GetStringSafe() == "TableLookup") {
if (plan.GetMapSafe().at("Node Type") == "TableLookupJoin" && plan.GetMapSafe().contains("Table")) {
result["Node Type"] = "LookupJoin";
NJson::TJsonValue newOps;
NJson::TJsonValue op;

op["Name"] = "TableLookup";
op["Columns"] = plan.GetMapSafe().at("Columns");
op["Name"] = "LookupJoin";
op["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns");
op["Table"] = plan.GetMapSafe().at("Table");

newOps.AppendValue(std::move(op));
result["Operators"] = std::move(newOps);

NJson::TJsonValue newPlans;

NJson::TJsonValue lookupPlan;
lookupPlan["Node Type"] = "TableLookup";
lookupPlan["PlanNodeType"] = "TableLookup";

NJson::TJsonValue lookupOps;
NJson::TJsonValue lookupOp;

lookupOp["Name"] = "TableLookup";
lookupOp["Columns"] = plan.GetMapSafe().at("Columns");
lookupOp["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns");
lookupOp["Table"] = plan.GetMapSafe().at("Table");

if (plan.GetMapSafe().contains("E-Cost")) {
op["E-Cost"] = plan.GetMapSafe().at("E-Cost");
}
lookupOp["E-Cost"] = plan.GetMapSafe().at("E-Cost");
}
if (plan.GetMapSafe().contains("E-Rows")) {
op["E-Rows"] = plan.GetMapSafe().at("E-Rows");
lookupOp["E-Rows"] = plan.GetMapSafe().at("E-Rows");
}
if (plan.GetMapSafe().contains("E-Size")) {
op["E-Size"] = plan.GetMapSafe().at("E-Size");
lookupOp["E-Size"] = plan.GetMapSafe().at("E-Size");
}

newOps.AppendValue(op);
lookupOps.AppendValue(std::move(lookupOp));
lookupPlan["Operators"] = std::move(lookupOps);

newPlans.AppendValue(Reconstruct(plan.GetMapSafe().at("Plans").GetArraySafe()[0], 0));

newPlans.AppendValue(std::move(lookupPlan));

result["Plans"] = std::move(newPlans);

result["Operators"] = newOps;
return result;
}

for (auto p : plan.GetMapSafe().at("Plans").GetArraySafe()) {
if (!p.GetMapSafe().contains("Operators") && p.GetMapSafe().contains("CTE Name")) {
auto precompute = p.GetMapSafe().at("CTE Name").GetStringSafe();
if (precomputes.contains(precompute)) {
planInputs.AppendValue(ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter));
if (!plan.GetMapSafe().contains("Operators")) {
NJson::TJsonValue planInputs;

result["Node Type"] = plan.GetMapSafe().at("Node Type").GetStringSafe();

if (plan.GetMapSafe().contains("CTE Name")) {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (Precomputes.contains(precompute)) {
planInputs.AppendValue(Reconstruct(Precomputes.at(precompute), 0));
}
} else if (p.GetMapSafe().at("Node Type").GetStringSafe().find("Precompute") == TString::npos) {
planInputs.AppendValue(ReconstructQueryPlanRec(p, 0, planIndex, precomputes, nodeCounter));
}
}
result["Plans"] = planInputs;
return result;
}

if (plan.GetMapSafe().contains("CTE Name") && plan.GetMapSafe().at("Node Type").GetStringSafe() == "ConstantExpr") {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (!precomputes.contains(precompute)) {
result["Node Type"] = plan.GetMapSafe().at("Node Type");
if (!plan.GetMapSafe().contains("Plans")) {
result["Plans"] = std::move(planInputs);
return result;
}

if (plan.GetMapSafe().at("Node Type").GetStringSafe() == "TableLookup") {
NJson::TJsonValue newOps;
NJson::TJsonValue op;

op["Name"] = "TableLookup";
op["Columns"] = plan.GetMapSafe().at("Columns");
op["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns");
op["Table"] = plan.GetMapSafe().at("Table");

if (plan.GetMapSafe().contains("E-Cost")) {
op["E-Cost"] = plan.GetMapSafe().at("E-Cost");
}
if (plan.GetMapSafe().contains("E-Rows")) {
op["E-Rows"] = plan.GetMapSafe().at("E-Rows");
}
if (plan.GetMapSafe().contains("E-Size")) {
op["E-Size"] = plan.GetMapSafe().at("E-Size");
}

newOps.AppendValue(std::move(op));

result["Operators"] = std::move(newOps);
return result;
}

for (auto p : plan.GetMapSafe().at("Plans").GetArraySafe()) {
if (!p.GetMapSafe().contains("Operators") && p.GetMapSafe().contains("CTE Name")) {
auto precompute = p.GetMapSafe().at("CTE Name").GetStringSafe();
if (Precomputes.contains(precompute)) {
planInputs.AppendValue(Reconstruct(Precomputes.at(precompute), 0));
}
} else if (p.GetMapSafe().at("Node Type").GetStringSafe().find("Precompute") == TString::npos) {
planInputs.AppendValue(Reconstruct(p, 0));
}
}
result["Plans"] = planInputs;
return result;
}

return ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter);
}
if (plan.GetMapSafe().contains("CTE Name") && plan.GetMapSafe().at("Node Type").GetStringSafe() == "ConstantExpr") {
auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe();
if (!Precomputes.contains(precompute)) {
result["Node Type"] = plan.GetMapSafe().at("Node Type");
return result;
}

auto ops = plan.GetMapSafe().at("Operators").GetArraySafe();
auto op = ops[operatorIndex];
return Reconstruct(Precomputes.at(precompute), 0);
}

TVector<NJson::TJsonValue> planInputs;
auto ops = plan.GetMapSafe().at("Operators").GetArraySafe();
auto op = ops[operatorIndex];

auto opName = op.GetMapSafe().at("Name").GetStringSafe();
TVector<NJson::TJsonValue> planInputs;

THashSet<ui32> processedExternalOperators;
THashSet<ui32> processedInternalOperators;
for (auto opInput : op.GetMapSafe().at("Inputs").GetArraySafe()) {
auto opName = op.GetMapSafe().at("Name").GetStringSafe();

if (opInput.GetMapSafe().contains("ExternalPlanNodeId")) {
auto inputPlanKey = opInput.GetMapSafe().at("ExternalPlanNodeId").GetIntegerSafe();
THashSet<ui32> processedExternalOperators;
THashSet<ui32> processedInternalOperators;
for (auto opInput : op.GetMapSafe().at("Inputs").GetArraySafe()) {

if (processedExternalOperators.contains(inputPlanKey)) {
continue;
}
processedExternalOperators.insert(inputPlanKey);
if (opInput.GetMapSafe().contains("ExternalPlanNodeId")) {
auto inputPlanKey = opInput.GetMapSafe().at("ExternalPlanNodeId").GetIntegerSafe();

auto inputPlan = planIndex.at(inputPlanKey);
planInputs.push_back( ReconstructQueryPlanRec(inputPlan, 0, planIndex, precomputes, nodeCounter));
} else if (opInput.GetMapSafe().contains("InternalOperatorId")) {
auto inputPlanId = opInput.GetMapSafe().at("InternalOperatorId").GetIntegerSafe();
if (processedExternalOperators.contains(inputPlanKey)) {
continue;
}
processedExternalOperators.insert(inputPlanKey);

if (processedInternalOperators.contains(inputPlanId)) {
continue;
}
processedInternalOperators.insert(inputPlanId);
auto inputPlan = PlanIndex.at(inputPlanKey);
planInputs.push_back( Reconstruct(inputPlan, 0) );
} else if (opInput.GetMapSafe().contains("InternalOperatorId")) {
auto inputPlanId = opInput.GetMapSafe().at("InternalOperatorId").GetIntegerSafe();

planInputs.push_back( ReconstructQueryPlanRec(plan, inputPlanId, planIndex, precomputes, nodeCounter));
if (processedInternalOperators.contains(inputPlanId)) {
continue;
}
processedInternalOperators.insert(inputPlanId);

planInputs.push_back( Reconstruct(plan, inputPlanId) );
}
}
}

if (op.GetMapSafe().contains("Inputs")) {
op.GetMapSafe().erase("Inputs");
}
if (op.GetMapSafe().contains("Inputs")) {
op.GetMapSafe().erase("Inputs");
}

if (op.GetMapSafe().contains("Input")
|| op.GetMapSafe().contains("ToFlow")
|| op.GetMapSafe().contains("Member")
|| op.GetMapSafe().contains("AssumeSorted")
|| op.GetMapSafe().contains("Iterator")) {
if (op.GetMapSafe().contains("Input")
|| op.GetMapSafe().contains("ToFlow")
|| op.GetMapSafe().contains("Member")
|| op.GetMapSafe().contains("AssumeSorted")
|| op.GetMapSafe().contains("Iterator")) {

TString maybePrecompute = "";
if (op.GetMapSafe().contains("Input")) {
maybePrecompute = op.GetMapSafe().at("Input").GetStringSafe();
} else if (op.GetMapSafe().contains("ToFlow")) {
maybePrecompute = op.GetMapSafe().at("ToFlow").GetStringSafe();
} else if (op.GetMapSafe().contains("Member")) {
maybePrecompute = op.GetMapSafe().at("Member").GetStringSafe();
} else if (op.GetMapSafe().contains("AssumeSorted")) {
maybePrecompute = op.GetMapSafe().at("AssumeSorted").GetStringSafe();
} else if (op.GetMapSafe().contains("Iterator")) {
maybePrecompute = op.GetMapSafe().at("Iterator").GetStringSafe();
}
TString maybePrecompute = "";
if (op.GetMapSafe().contains("Input")) {
maybePrecompute = op.GetMapSafe().at("Input").GetStringSafe();
} else if (op.GetMapSafe().contains("ToFlow")) {
maybePrecompute = op.GetMapSafe().at("ToFlow").GetStringSafe();
} else if (op.GetMapSafe().contains("Member")) {
maybePrecompute = op.GetMapSafe().at("Member").GetStringSafe();
} else if (op.GetMapSafe().contains("AssumeSorted")) {
maybePrecompute = op.GetMapSafe().at("AssumeSorted").GetStringSafe();
} else if (op.GetMapSafe().contains("Iterator")) {
maybePrecompute = op.GetMapSafe().at("Iterator").GetStringSafe();
}

if (precomputes.contains(maybePrecompute) && planInputs.empty()) {
planInputs.push_back(ReconstructQueryPlanRec(precomputes.at(maybePrecompute), 0, planIndex, precomputes, nodeCounter));
if (Precomputes.contains(maybePrecompute) && planInputs.empty()) {
planInputs.push_back(Reconstruct(Precomputes.at(maybePrecompute), 0));
}
}
}

result["Node Type"] = opName;
NJson::TJsonValue newOps;
newOps.AppendValue(op);
result["Operators"] = newOps;
result["Node Type"] = std::move(opName);
NJson::TJsonValue newOps;
newOps.AppendValue(std::move(op));
result["Operators"] = std::move(newOps);

if (planInputs.size()){
NJson::TJsonValue plans;
for( auto i : planInputs) {
plans.AppendValue(i);
if (!planInputs.empty()){
NJson::TJsonValue plans;
for(auto&& i : planInputs) {
plans.AppendValue(std::move(i));
}
result["Plans"] = std::move(plans);
}
result["Plans"] = plans;

return result;
}

return result;
}
private:
const THashMap<int, NJson::TJsonValue>& PlanIndex;
const THashMap<TString, NJson::TJsonValue>& Precomputes;
ui32 NodeIDCounter;
i32 Budget; // Prevent bugs with inf recursion
};

double ComputeCpuTimes(NJson::TJsonValue& plan) {
double currCpuTime = 0;
Expand Down Expand Up @@ -2209,8 +2280,7 @@ NJson::TJsonValue SimplifyQueryPlan(NJson::TJsonValue& plan) {

BuildPlanIndex(plan, planIndex, precomputes);

int nodeCounter = 0;
plan = ReconstructQueryPlanRec(plan, 0, planIndex, precomputes, nodeCounter);
plan = TQueryPlanReconstructor(planIndex, precomputes).Reconstruct(plan, 0);

RemoveRedundantNodes(plan, redundantNodes);
ComputeCpuTimes(plan);
Expand Down
Loading

0 comments on commit 979fcef

Please sign in to comment.