diff --git a/ydb/core/kqp/opt/kqp_query_plan.cpp b/ydb/core/kqp/opt/kqp_query_plan.cpp index bf0e642eb5d9..e00ff75bfff2 100644 --- a/ydb/core/kqp/opt/kqp_query_plan.cpp +++ b/ydb/core/kqp/opt/kqp_query_plan.cpp @@ -1945,6 +1945,9 @@ TVector RemoveRedundantNodes(NJson::TJsonValue& plan, const T } } + if (!planMap.contains("Node Type")) { + return {}; + } const auto typeName = planMap.at("Node Type").GetStringSafe(); if (redundantNodes.contains(typeName) || typeName.find("Precompute") != TString::npos) { return children; @@ -1953,167 +1956,235 @@ TVector RemoveRedundantNodes(NJson::TJsonValue& plan, const T return {plan}; } -NJson::TJsonValue ReconstructQueryPlanRec(const NJson::TJsonValue& plan, - int operatorIndex, - const THashMap& planIndex, - const THashMap& precomputes, - int& nodeCounter) { - - int currentNodeId = nodeCounter++; - - NJson::TJsonValue result; - result["PlanNodeId"] = currentNodeId; - - if (plan.GetMapSafe().contains("PlanNodeType")) { - result["PlanNodeType"] = plan.GetMapSafe().at("PlanNodeType").GetStringSafe(); - } +struct TQueryPlanReconstructor { + TQueryPlanReconstructor( + const THashMap& planIndex, + const THashMap& precomputes + ) + : PlanIndex(planIndex) + , Precomputes(precomputes) + , NodeIDCounter(0) + , Budget(10'000) + {} - if (plan.GetMapSafe().contains("Stats") && operatorIndex==0) { - result["Stats"] = plan.GetMapSafe().at("Stats"); - } + NJson::TJsonValue Reconstruct( + const NJson::TJsonValue& plan, + int operatorIndex + ) { + int currentNodeId = NodeIDCounter++; - if (!plan.GetMapSafe().contains("Operators")) { - NJson::TJsonValue planInputs; + NJson::TJsonValue result; + result["PlanNodeId"] = currentNodeId; - result["Node Type"] = plan.GetMapSafe().at("Node Type").GetStringSafe(); + if (--Budget <= 0) { + YQL_CLOG(DEBUG, ProviderKqp) << "Can't build the plan - recursion depth has been exceeded!"; + return result; + } - if (plan.GetMapSafe().contains("CTE Name")) { - auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe(); - if (precomputes.contains(precompute)) { - planInputs.AppendValue(ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter)); - } + if (plan.GetMapSafe().contains("PlanNodeType")) { + result["PlanNodeType"] = plan.GetMapSafe().at("PlanNodeType").GetStringSafe(); } - if (!plan.GetMapSafe().contains("Plans")) { - result["Plans"] = planInputs; - return result; + if (plan.GetMapSafe().contains("Stats") && operatorIndex==0) { + result["Stats"] = plan.GetMapSafe().at("Stats"); } - if (plan.GetMapSafe().at("Node Type").GetStringSafe() == "TableLookup") { + if (plan.GetMapSafe().at("Node Type") == "TableLookupJoin" && plan.GetMapSafe().contains("Table")) { + result["Node Type"] = "LookupJoin"; NJson::TJsonValue newOps; NJson::TJsonValue op; - op["Name"] = "TableLookup"; - op["Columns"] = plan.GetMapSafe().at("Columns"); + op["Name"] = "LookupJoin"; op["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns"); - op["Table"] = plan.GetMapSafe().at("Table"); + + newOps.AppendValue(std::move(op)); + result["Operators"] = std::move(newOps); + + NJson::TJsonValue newPlans; + + NJson::TJsonValue lookupPlan; + lookupPlan["Node Type"] = "TableLookup"; + lookupPlan["PlanNodeType"] = "TableLookup"; + + NJson::TJsonValue lookupOps; + NJson::TJsonValue lookupOp; + + lookupOp["Name"] = "TableLookup"; + lookupOp["Columns"] = plan.GetMapSafe().at("Columns"); + lookupOp["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns"); + lookupOp["Table"] = plan.GetMapSafe().at("Table"); if (plan.GetMapSafe().contains("E-Cost")) { - op["E-Cost"] = plan.GetMapSafe().at("E-Cost"); - } + lookupOp["E-Cost"] = plan.GetMapSafe().at("E-Cost"); + } if (plan.GetMapSafe().contains("E-Rows")) { - op["E-Rows"] = plan.GetMapSafe().at("E-Rows"); + lookupOp["E-Rows"] = plan.GetMapSafe().at("E-Rows"); } if (plan.GetMapSafe().contains("E-Size")) { - op["E-Size"] = plan.GetMapSafe().at("E-Size"); + lookupOp["E-Size"] = plan.GetMapSafe().at("E-Size"); } - newOps.AppendValue(op); + lookupOps.AppendValue(std::move(lookupOp)); + lookupPlan["Operators"] = std::move(lookupOps); + + newPlans.AppendValue(Reconstruct(plan.GetMapSafe().at("Plans").GetArraySafe()[0], 0)); + + newPlans.AppendValue(std::move(lookupPlan)); + + result["Plans"] = std::move(newPlans); - result["Operators"] = newOps; return result; } - for (auto p : plan.GetMapSafe().at("Plans").GetArraySafe()) { - if (!p.GetMapSafe().contains("Operators") && p.GetMapSafe().contains("CTE Name")) { - auto precompute = p.GetMapSafe().at("CTE Name").GetStringSafe(); - if (precomputes.contains(precompute)) { - planInputs.AppendValue(ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter)); + if (!plan.GetMapSafe().contains("Operators")) { + NJson::TJsonValue planInputs; + + result["Node Type"] = plan.GetMapSafe().at("Node Type").GetStringSafe(); + + if (plan.GetMapSafe().contains("CTE Name")) { + auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe(); + if (Precomputes.contains(precompute)) { + planInputs.AppendValue(Reconstruct(Precomputes.at(precompute), 0)); } - } else if (p.GetMapSafe().at("Node Type").GetStringSafe().find("Precompute") == TString::npos) { - planInputs.AppendValue(ReconstructQueryPlanRec(p, 0, planIndex, precomputes, nodeCounter)); } - } - result["Plans"] = planInputs; - return result; - } - if (plan.GetMapSafe().contains("CTE Name") && plan.GetMapSafe().at("Node Type").GetStringSafe() == "ConstantExpr") { - auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe(); - if (!precomputes.contains(precompute)) { - result["Node Type"] = plan.GetMapSafe().at("Node Type"); + if (!plan.GetMapSafe().contains("Plans")) { + result["Plans"] = std::move(planInputs); + return result; + } + + if (plan.GetMapSafe().at("Node Type").GetStringSafe() == "TableLookup") { + NJson::TJsonValue newOps; + NJson::TJsonValue op; + + op["Name"] = "TableLookup"; + op["Columns"] = plan.GetMapSafe().at("Columns"); + op["LookupKeyColumns"] = plan.GetMapSafe().at("LookupKeyColumns"); + op["Table"] = plan.GetMapSafe().at("Table"); + + if (plan.GetMapSafe().contains("E-Cost")) { + op["E-Cost"] = plan.GetMapSafe().at("E-Cost"); + } + if (plan.GetMapSafe().contains("E-Rows")) { + op["E-Rows"] = plan.GetMapSafe().at("E-Rows"); + } + if (plan.GetMapSafe().contains("E-Size")) { + op["E-Size"] = plan.GetMapSafe().at("E-Size"); + } + + newOps.AppendValue(std::move(op)); + + result["Operators"] = std::move(newOps); + return result; + } + + for (auto p : plan.GetMapSafe().at("Plans").GetArraySafe()) { + if (!p.GetMapSafe().contains("Operators") && p.GetMapSafe().contains("CTE Name")) { + auto precompute = p.GetMapSafe().at("CTE Name").GetStringSafe(); + if (Precomputes.contains(precompute)) { + planInputs.AppendValue(Reconstruct(Precomputes.at(precompute), 0)); + } + } else if (p.GetMapSafe().at("Node Type").GetStringSafe().find("Precompute") == TString::npos) { + planInputs.AppendValue(Reconstruct(p, 0)); + } + } + result["Plans"] = planInputs; return result; } - return ReconstructQueryPlanRec(precomputes.at(precompute), 0, planIndex, precomputes, nodeCounter); - } + if (plan.GetMapSafe().contains("CTE Name") && plan.GetMapSafe().at("Node Type").GetStringSafe() == "ConstantExpr") { + auto precompute = plan.GetMapSafe().at("CTE Name").GetStringSafe(); + if (!Precomputes.contains(precompute)) { + result["Node Type"] = plan.GetMapSafe().at("Node Type"); + return result; + } - auto ops = plan.GetMapSafe().at("Operators").GetArraySafe(); - auto op = ops[operatorIndex]; + return Reconstruct(Precomputes.at(precompute), 0); + } - TVector planInputs; + auto ops = plan.GetMapSafe().at("Operators").GetArraySafe(); + auto op = ops[operatorIndex]; - auto opName = op.GetMapSafe().at("Name").GetStringSafe(); + TVector planInputs; - THashSet processedExternalOperators; - THashSet processedInternalOperators; - for (auto opInput : op.GetMapSafe().at("Inputs").GetArraySafe()) { + auto opName = op.GetMapSafe().at("Name").GetStringSafe(); - if (opInput.GetMapSafe().contains("ExternalPlanNodeId")) { - auto inputPlanKey = opInput.GetMapSafe().at("ExternalPlanNodeId").GetIntegerSafe(); + THashSet processedExternalOperators; + THashSet processedInternalOperators; + for (auto opInput : op.GetMapSafe().at("Inputs").GetArraySafe()) { - if (processedExternalOperators.contains(inputPlanKey)) { - continue; - } - processedExternalOperators.insert(inputPlanKey); + if (opInput.GetMapSafe().contains("ExternalPlanNodeId")) { + auto inputPlanKey = opInput.GetMapSafe().at("ExternalPlanNodeId").GetIntegerSafe(); - auto inputPlan = planIndex.at(inputPlanKey); - planInputs.push_back( ReconstructQueryPlanRec(inputPlan, 0, planIndex, precomputes, nodeCounter)); - } else if (opInput.GetMapSafe().contains("InternalOperatorId")) { - auto inputPlanId = opInput.GetMapSafe().at("InternalOperatorId").GetIntegerSafe(); + if (processedExternalOperators.contains(inputPlanKey)) { + continue; + } + processedExternalOperators.insert(inputPlanKey); - if (processedInternalOperators.contains(inputPlanId)) { - continue; - } - processedInternalOperators.insert(inputPlanId); + auto inputPlan = PlanIndex.at(inputPlanKey); + planInputs.push_back( Reconstruct(inputPlan, 0) ); + } else if (opInput.GetMapSafe().contains("InternalOperatorId")) { + auto inputPlanId = opInput.GetMapSafe().at("InternalOperatorId").GetIntegerSafe(); - planInputs.push_back( ReconstructQueryPlanRec(plan, inputPlanId, planIndex, precomputes, nodeCounter)); + if (processedInternalOperators.contains(inputPlanId)) { + continue; + } + processedInternalOperators.insert(inputPlanId); + + planInputs.push_back( Reconstruct(plan, inputPlanId) ); + } } - } - if (op.GetMapSafe().contains("Inputs")) { - op.GetMapSafe().erase("Inputs"); - } + if (op.GetMapSafe().contains("Inputs")) { + op.GetMapSafe().erase("Inputs"); + } - if (op.GetMapSafe().contains("Input") - || op.GetMapSafe().contains("ToFlow") - || op.GetMapSafe().contains("Member") - || op.GetMapSafe().contains("AssumeSorted") - || op.GetMapSafe().contains("Iterator")) { + if (op.GetMapSafe().contains("Input") + || op.GetMapSafe().contains("ToFlow") + || op.GetMapSafe().contains("Member") + || op.GetMapSafe().contains("AssumeSorted") + || op.GetMapSafe().contains("Iterator")) { - TString maybePrecompute = ""; - if (op.GetMapSafe().contains("Input")) { - maybePrecompute = op.GetMapSafe().at("Input").GetStringSafe(); - } else if (op.GetMapSafe().contains("ToFlow")) { - maybePrecompute = op.GetMapSafe().at("ToFlow").GetStringSafe(); - } else if (op.GetMapSafe().contains("Member")) { - maybePrecompute = op.GetMapSafe().at("Member").GetStringSafe(); - } else if (op.GetMapSafe().contains("AssumeSorted")) { - maybePrecompute = op.GetMapSafe().at("AssumeSorted").GetStringSafe(); - } else if (op.GetMapSafe().contains("Iterator")) { - maybePrecompute = op.GetMapSafe().at("Iterator").GetStringSafe(); - } + TString maybePrecompute = ""; + if (op.GetMapSafe().contains("Input")) { + maybePrecompute = op.GetMapSafe().at("Input").GetStringSafe(); + } else if (op.GetMapSafe().contains("ToFlow")) { + maybePrecompute = op.GetMapSafe().at("ToFlow").GetStringSafe(); + } else if (op.GetMapSafe().contains("Member")) { + maybePrecompute = op.GetMapSafe().at("Member").GetStringSafe(); + } else if (op.GetMapSafe().contains("AssumeSorted")) { + maybePrecompute = op.GetMapSafe().at("AssumeSorted").GetStringSafe(); + } else if (op.GetMapSafe().contains("Iterator")) { + maybePrecompute = op.GetMapSafe().at("Iterator").GetStringSafe(); + } - if (precomputes.contains(maybePrecompute) && planInputs.empty()) { - planInputs.push_back(ReconstructQueryPlanRec(precomputes.at(maybePrecompute), 0, planIndex, precomputes, nodeCounter)); + if (Precomputes.contains(maybePrecompute) && planInputs.empty()) { + planInputs.push_back(Reconstruct(Precomputes.at(maybePrecompute), 0)); + } } - } - result["Node Type"] = opName; - NJson::TJsonValue newOps; - newOps.AppendValue(op); - result["Operators"] = newOps; + result["Node Type"] = std::move(opName); + NJson::TJsonValue newOps; + newOps.AppendValue(std::move(op)); + result["Operators"] = std::move(newOps); - if (planInputs.size()){ - NJson::TJsonValue plans; - for( auto i : planInputs) { - plans.AppendValue(i); + if (!planInputs.empty()){ + NJson::TJsonValue plans; + for(auto&& i : planInputs) { + plans.AppendValue(std::move(i)); + } + result["Plans"] = std::move(plans); } - result["Plans"] = plans; + + return result; } - return result; -} +private: + const THashMap& PlanIndex; + const THashMap& Precomputes; + ui32 NodeIDCounter; + i32 Budget; // Prevent bugs with inf recursion +}; double ComputeCpuTimes(NJson::TJsonValue& plan) { double currCpuTime = 0; @@ -2209,8 +2280,7 @@ NJson::TJsonValue SimplifyQueryPlan(NJson::TJsonValue& plan) { BuildPlanIndex(plan, planIndex, precomputes); - int nodeCounter = 0; - plan = ReconstructQueryPlanRec(plan, 0, planIndex, precomputes, nodeCounter); + plan = TQueryPlanReconstructor(planIndex, precomputes).Reconstruct(plan, 0); RemoveRedundantNodes(plan, redundantNodes); ComputeCpuTimes(plan); diff --git a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp index c0e8bd96aebd..733693f7251a 100644 --- a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp +++ b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp @@ -77,13 +77,35 @@ static TKikimrRunner GetKikimrWithJoinSettings(bool useStreamLookupJoin = false, class TChainConstructor { public: - TChainConstructor(size_t chainSize) - : Kikimr_(GetKikimrWithJoinSettings()) - , TableClient_(Kikimr_.GetTableClient()) - , Session_(TableClient_.CreateSession().GetValueSync().GetSession()) - , ChainSize_(chainSize) + TChainTester(size_t chainSize) + : Kikimr(GetKikimrWithJoinSettings(false, GetStats(chainSize))) + , TableClient(Kikimr.GetTableClient()) + , Session(TableClient.CreateSession().GetValueSync().GetSession()) + , ChainSize(chainSize) {} +public: + void Test() { + CreateTables(); + JoinTables(); + } + + static TString GetStats(size_t chainSize) { + srand(228); + NJson::TJsonValue stats; + for (size_t i = 0; i < chainSize; ++i) { + ui64 nRows = rand(); + NJson::TJsonValue tableStat; + tableStat["n_rows"] = nRows; + tableStat["byte_size"] = nRows * 10; + + TString table = Sprintf("/Root/table_%ld", i); + stats[table] = std::move(tableStat); + } + return stats.GetStringRobust(); + } + +private: void CreateTables() { for (size_t i = 0; i < ChainSize_; ++i) { TString tableName;