diff --git a/ydb/core/kqp/host/kqp_runner.cpp b/ydb/core/kqp/host/kqp_runner.cpp index 2925bd58a32a..6db4bb3285be 100644 --- a/ydb/core/kqp/host/kqp_runner.cpp +++ b/ydb/core/kqp/host/kqp_runner.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include @@ -20,6 +21,8 @@ #include #include +#include + #include namespace NKikimr { @@ -143,6 +146,7 @@ class TKqpRunner : public IKqpRunner { , OptimizeCtx(MakeIntrusive(cluster, Config, sessionCtx->QueryPtr(), sessionCtx->TablesPtr())) , BuildQueryCtx(MakeIntrusive()) + , Pctx(TKqpProviderContext(*OptimizeCtx, Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel))) { CreateGraphTransformer(typesCtx, sessionCtx, funcRegistry); } @@ -259,8 +263,8 @@ class TKqpRunner : public IKqpRunner { .AddPostTypeAnnotation(/* forSubgraph */ true) .AddCommonOptimization() .Add(CreateKqpConstantFoldingTransformer(OptimizeCtx, *typesCtx, Config), "ConstantFolding") - .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics") - .Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config), "LogicalOptimize") + .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config, Pctx), "Statistics") + .Add(CreateKqpLogOptTransformer(OptimizeCtx, *typesCtx, Config, Pctx), "LogicalOptimize") .Add(CreateLogicalDataProposalsInspector(*typesCtx), "ProvidersLogicalOptimize") .Add(CreateKqpPhyOptTransformer(OptimizeCtx, *typesCtx), "KqpPhysicalOptimize") .Add(CreatePhysicalDataProposalsInspector(*typesCtx), "ProvidersPhysicalOptimize") @@ -293,7 +297,7 @@ class TKqpRunner : public IKqpRunner { .AddTypeAnnotationTransformer(CreateKqpTypeAnnotationTransformer(Cluster, sessionCtx->TablesPtr(), *typesCtx, Config)) .AddPostTypeAnnotation() .Add(CreateKqpBuildPhysicalQueryTransformer(OptimizeCtx, BuildQueryCtx), "BuildPhysicalQuery") - .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config), "Statistics") + .Add(CreateKqpStatisticsTransformer(OptimizeCtx, *typesCtx, Config, Pctx), "Statistics") .Build(false); auto physicalPeepholeTransformer = TTransformationPipeline(typesCtx) @@ -355,6 +359,8 @@ class TKqpRunner : public IKqpRunner { TIntrusivePtr OptimizeCtx; TIntrusivePtr BuildQueryCtx; + TKqpProviderContext Pctx; + TAutoPtr Transformer; }; diff --git a/ydb/core/kqp/opt/kqp_query_plan.cpp b/ydb/core/kqp/opt/kqp_query_plan.cpp index d2778ba244ea..86a1dfa0b1a6 100644 --- a/ydb/core/kqp/opt/kqp_query_plan.cpp +++ b/ydb/core/kqp/opt/kqp_query_plan.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -1357,7 +1358,7 @@ class TxPlanSerializer { } void AddOptimizerEstimates(TOperator& op, const TExprBase& expr) { - if (!SerializerCtx.Config->HasOptEnableCostBasedOptimization()) { + if (SerializerCtx.Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel)==0) { return; } diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp index a5594b43c8b8..ae12cba6c887 100644 --- a/ydb/core/kqp/opt/kqp_statistics_transformer.cpp +++ b/ydb/core/kqp/opt/kqp_statistics_transformer.cpp @@ -3,6 +3,9 @@ #include #include +#include + + #include using namespace NYql; @@ -187,7 +190,7 @@ IGraphTransformer::TStatus TKqpStatisticsTransformer::DoTransform(TExprNode::TPt TExprNode::TPtr& output, TExprContext& ctx) { output = input; - if (!Config->HasOptEnableCostBasedOptimization()) { + if (Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel) == 0) { return IGraphTransformer::TStatus::Ok; } @@ -238,6 +241,6 @@ bool TKqpStatisticsTransformer::AfterLambdasSpecific(const TExprNode::TPtr& inpu } TAutoPtr NKikimr::NKqp::CreateKqpStatisticsTransformer(const TIntrusivePtr& kqpCtx, - TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config) { - return THolder(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config)); + TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config, const TKqpProviderContext& pctx) { + return THolder(new TKqpStatisticsTransformer(kqpCtx, typeCtx, config, pctx)); } diff --git a/ydb/core/kqp/opt/kqp_statistics_transformer.h b/ydb/core/kqp/opt/kqp_statistics_transformer.h index 3f4d9a3a39c5..3c54c7ee768f 100644 --- a/ydb/core/kqp/opt/kqp_statistics_transformer.h +++ b/ydb/core/kqp/opt/kqp_statistics_transformer.h @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -33,8 +34,8 @@ class TKqpStatisticsTransformer : public NYql::NDq::TDqStatisticsTransformerBase public: TKqpStatisticsTransformer(const TIntrusivePtr& kqpCtx, TTypeAnnotationContext& typeCtx, - const TKikimrConfiguration::TPtr& config) : - TDqStatisticsTransformerBase(&typeCtx), + const TKikimrConfiguration::TPtr& config, const TKqpProviderContext& pctx) : + TDqStatisticsTransformerBase(&typeCtx, pctx), Config(config), KqpCtx(*kqpCtx) {} @@ -47,6 +48,6 @@ class TKqpStatisticsTransformer : public NYql::NDq::TDqStatisticsTransformerBase }; TAutoPtr CreateKqpStatisticsTransformer(const TIntrusivePtr& kqpCtx, - TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config); + TTypeAnnotationContext& typeCtx, const TKikimrConfiguration::TPtr& config, const TKqpProviderContext& pctx); } } diff --git a/ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp b/ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp new file mode 100644 index 000000000000..b421ba29757f --- /dev/null +++ b/ydb/core/kqp/opt/logical/kqp_opt_cbo.cpp @@ -0,0 +1,164 @@ +#include "kqp_opt_cbo.h" +#include "kqp_opt_log_impl.h" + +#include +#include + + +namespace NKikimr::NKqp::NOpt { + +using namespace NYql; +using namespace NYql::NCommon; +using namespace NYql::NDq; +using namespace NYql::NNodes; + +namespace { + +/** + * KQP specific rule to check if a LookupJoin is applicable +*/ +bool IsLookupJoinApplicableDetailed(const std::shared_ptr& node, const TVector& joinColumns, const TKqpProviderContext& ctx) { + + auto rel = std::static_pointer_cast(node); + auto expr = TExprBase(rel->Node); + + if (ctx.KqpCtx.IsScanQuery() && !ctx.KqpCtx.Config->EnableKqpScanQueryStreamIdxLookupJoin) { + return false; + } + + if (find_if(joinColumns.begin(), joinColumns.end(), [&] (const TString& s) { return node->Stats->KeyColumns[0] == s;})) { + return true; + } + + auto readMatch = MatchRead(expr); + TMaybeNode maybeTablePrefix; + size_t prefixSize; + + if (readMatch) { + if (readMatch->FlatMap && !IsPassthroughFlatMap(readMatch->FlatMap.Cast(), nullptr)){ + return false; + } + auto read = readMatch->Read.Cast(); + maybeTablePrefix = GetRightTableKeyPrefix(read.Range()); + + if (!maybeTablePrefix) { + return false; + } + + prefixSize = maybeTablePrefix.Cast().ArgCount(); + + if (!prefixSize) { + return true; + } + } + else { + readMatch = MatchRead(expr); + if (readMatch) { + if (readMatch->FlatMap && !IsPassthroughFlatMap(readMatch->FlatMap.Cast(), nullptr)){ + return false; + } + auto read = readMatch->Read.Cast(); + if (TCoVoid::Match(read.Ranges().Raw())) { + return true; + } else { + auto prompt = TKqpReadTableExplainPrompt::Parse(read); + + if (prompt.PointPrefixLen != prompt.UsedKeyColumns.size()) { + return false; + } + + if (prompt.ExpectedMaxRanges != TMaybe(1)) { + return false; + } + prefixSize = prompt.PointPrefixLen; + } + } + } + if (! readMatch) { + return false; + } + + if (prefixSize < node->Stats->KeyColumns.size() && !(find_if(joinColumns.begin(), joinColumns.end(), [&] (const TString& s) { + return node->Stats->KeyColumns[prefixSize] == s; + }))){ + return false; + } + + return true; +} + +bool IsLookupJoinApplicable(std::shared_ptr left, + std::shared_ptr right, + const std::set>& joinConditions, + TKqpProviderContext& ctx) { + + Y_UNUSED(left); + + auto rightStats = right->Stats; + + if (rightStats->Type != EStatisticsType::BaseTable) { + return false; + } + if (joinConditions.size() > rightStats->KeyColumns.size()) { + return false; + } + + for (auto [leftCol, rightCol] : joinConditions) { + if (! find_if(rightStats->KeyColumns.begin(), rightStats->KeyColumns.end(), + [rightCol] (const TString& s) { + return rightCol.AttributeName == s; + } )) { + return false; + } + } + + TVector joinKeys; + for( auto [leftJc, rightJc] : joinConditions ) { + joinKeys.emplace_back( rightJc.AttributeName); + } + + return IsLookupJoinApplicableDetailed(std::static_pointer_cast(right), joinKeys, ctx); +} + +} + +bool TKqpProviderContext::IsJoinApplicable(const std::shared_ptr& left, + const std::shared_ptr& right, + const std::set>& joinConditions, + EJoinAlgoType joinAlgo) { + + switch( joinAlgo ) { + case EJoinAlgoType::LookupJoin: + if (OptLevel==2 && left->Stats->Nrows > 10e3) { + return false; + } + return IsLookupJoinApplicable(left, right, joinConditions, *this); + + case EJoinAlgoType::DictJoin: + return right->Stats->Nrows < 10e5; + case EJoinAlgoType::MapJoin: + return right->Stats->Nrows < 10e6; + case EJoinAlgoType::GraceJoin: + return true; + } +} + +double TKqpProviderContext::ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, EJoinAlgoType joinAlgo) const { + + switch(joinAlgo) { + case EJoinAlgoType::LookupJoin: + if (OptLevel==1) { + return -1; + } + return leftStats.Nrows; + case EJoinAlgoType::DictJoin: + return leftStats.Nrows + 1.7 * rightStats.Nrows; + case EJoinAlgoType::MapJoin: + return leftStats.Nrows + 1.8 * rightStats.Nrows; + case EJoinAlgoType::GraceJoin: + return leftStats.Nrows + 2.0 * rightStats.Nrows; + } +} + + +} \ No newline at end of file diff --git a/ydb/core/kqp/opt/logical/kqp_opt_cbo.h b/ydb/core/kqp/opt/logical/kqp_opt_cbo.h new file mode 100644 index 000000000000..13b6b0200ec5 --- /dev/null +++ b/ydb/core/kqp/opt/logical/kqp_opt_cbo.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +#include + +namespace NKikimr::NKqp::NOpt { + +/** + * KQP specific Rel node, includes a pointer to ExprNode +*/ +struct TKqpRelOptimizerNode : public NYql::TRelOptimizerNode { + const NYql::TExprNode::TPtr Node; + + TKqpRelOptimizerNode(TString label, std::shared_ptr stats, const NYql::TExprNode::TPtr node) : + TRelOptimizerNode(label, stats), Node(node) { } +}; + +/** + * KQP Specific cost function and join applicability cost function +*/ +struct TKqpProviderContext : public NYql::IProviderContext { + TKqpProviderContext(const TKqpOptimizeContext& kqpCtx, const int optLevel) : KqpCtx(kqpCtx), OptLevel(optLevel) {} + + virtual bool IsJoinApplicable(const std::shared_ptr& left, + const std::shared_ptr& right, + const std::set>& joinConditions, + NYql::EJoinAlgoType joinAlgo) override; + + virtual double ComputeJoinCost(const NYql::TOptimizerStatistics& leftStats, const NYql::TOptimizerStatistics& rightStats, NYql::EJoinAlgoType joinAlgo) const override; + + const TKqpOptimizeContext& KqpCtx; + int OptLevel; +}; + +} \ No newline at end of file diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp index 96f7055653fb..5119d7769d20 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp @@ -1,4 +1,5 @@ #include "kqp_opt_log_rules.h" +#include "kqp_opt_cbo.h" #include #include @@ -21,11 +22,12 @@ using namespace NYql::NNodes; class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { public: TKqpLogicalOptTransformer(TTypeAnnotationContext& typesCtx, const TIntrusivePtr& kqpCtx, - const TKikimrConfiguration::TPtr& config) + const TKikimrConfiguration::TPtr& config, TKqpProviderContext& pctx) : TOptimizeTransformerBase(nullptr, NYql::NLog::EComponent::ProviderKqp, {}) , TypesCtx(typesCtx) , KqpCtx(*kqpCtx) , Config(config) + , Pctx(pctx) { #define HNDL(name) "KqpLogical-"#name, Hndl(&TKqpLogicalOptTransformer::name) AddHandler(0, &TCoFlatMapBase::Match, HNDL(PushPredicateToReadTable)); @@ -134,7 +136,10 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { TMaybeNode OptimizeEquiJoinWithCosts(TExprBase node, TExprContext& ctx) { auto maxDPccpDPTableSize = Config->MaxDPccpDPTableSize.Get().GetOrElse(TDqSettings::TDefault::MaxDPccpDPTableSize); - TExprBase output = DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, Config->HasOptEnableCostBasedOptimization(), maxDPccpDPTableSize); + TExprBase output = DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, Config->CostBasedOptimizationLevel.Get().GetOrElse(TDqSettings::TDefault::CostBasedOptimizationLevel), + maxDPccpDPTableSize, Pctx, [](auto& rels, auto label, auto node, auto stat) { + rels.emplace_back(std::make_shared(TString(label), stat, node)); + }); DumpAppliedRule("OptimizeEquiJoinWithCosts", node.Ptr(), output.Ptr(), ctx); return output; } @@ -269,12 +274,14 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { TTypeAnnotationContext& TypesCtx; const TKqpOptimizeContext& KqpCtx; const TKikimrConfiguration::TPtr& Config; + TKqpProviderContext& Pctx; }; TAutoPtr CreateKqpLogOptTransformer(const TIntrusivePtr& kqpCtx, - TTypeAnnotationContext& typesCtx, const TKikimrConfiguration::TPtr& config) + TTypeAnnotationContext& typesCtx, const TKikimrConfiguration::TPtr& config, + TKqpProviderContext& pctx) { - return THolder(new TKqpLogicalOptTransformer(typesCtx, kqpCtx, config)); + return THolder(new TKqpLogicalOptTransformer(typesCtx, kqpCtx, config, pctx)); } } // namespace NKikimr::NKqp::NOpt diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.h b/ydb/core/kqp/opt/logical/kqp_opt_log.h index e833934a54be..3e5e9ed87996 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log.h +++ b/ydb/core/kqp/opt/logical/kqp_opt_log.h @@ -1,12 +1,14 @@ #pragma once #include +#include namespace NKikimr::NKqp::NOpt { struct TKqpOptimizeContext; TAutoPtr CreateKqpLogOptTransformer(const TIntrusivePtr& kqpCtx, - NYql::TTypeAnnotationContext& typesCtx, const NYql::TKikimrConfiguration::TPtr& config); + NYql::TTypeAnnotationContext& typesCtx, const NYql::TKikimrConfiguration::TPtr& config, + NKikimr::NKqp::NOpt::TKqpProviderContext& pctx); } // namespace NKikimr::NKqp::NOpt diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h b/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h index 49f5c74e0324..5ff32edc6eeb 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h +++ b/ydb/core/kqp/opt/logical/kqp_opt_log_impl.h @@ -22,6 +22,8 @@ TMaybe MatchRead(NYql::NNodes::TExprBase node) { return MatchRead(node, [] (NYql::NNodes::TExprBase node) { return node.Maybe().IsValid(); }); } +NYql::NNodes::TMaybeNode GetRightTableKeyPrefix(const NYql::NNodes::TKqlKeyRange& range); + } // NKikimr::NKqp::NOpt diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp index a9a248a19701..47affa3e19bc 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp @@ -167,25 +167,6 @@ TDqJoin FlipLeftSemiJoin(const TDqJoin& join, TExprContext& ctx) { .Done(); } -TMaybeNode GetRightTableKeyPrefix(const TKqlKeyRange& range) { - if (!range.From().Maybe() || !range.To().Maybe()) { - return {}; - } - auto rangeFrom = range.From().Cast(); - auto rangeTo = range.To().Cast(); - - if (rangeFrom.ArgCount() != rangeTo.ArgCount()) { - return {}; - } - for (ui32 i = 0; i < rangeFrom.ArgCount(); ++i) { - if (rangeFrom.Arg(i).Raw() != rangeTo.Arg(i).Raw()) { - return {}; - } - } - - return rangeFrom; -} - TExprBase BuildLookupIndex(TExprContext& ctx, const TPositionHandle pos, const TKqpTable& table, const TCoAtomList& columns, const TExprBase& keysToLookup, const TVector& skipNullColumns, const TString& indexName, @@ -859,6 +840,25 @@ TMaybeNode KqpJoinToIndexLookupImpl(const TDqJoin& join, TExprContext } // anonymous namespace +TMaybeNode GetRightTableKeyPrefix(const TKqlKeyRange& range) { + if (!range.From().Maybe() || !range.To().Maybe()) { + return {}; + } + auto rangeFrom = range.From().Cast(); + auto rangeTo = range.To().Cast(); + + if (rangeFrom.ArgCount() != rangeTo.ArgCount()) { + return {}; + } + for (ui32 i = 0; i < rangeFrom.ArgCount(); ++i) { + if (rangeFrom.Arg(i).Raw() != rangeTo.Arg(i).Raw()) { + return {}; + } + } + + return rangeFrom; +} + TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx) { if ((kqpCtx.IsScanQuery() && !kqpCtx.Config->EnableKqpScanQueryStreamIdxLookupJoin) || !node.Maybe()) { diff --git a/ydb/core/kqp/opt/logical/ya.make b/ydb/core/kqp/opt/logical/ya.make index 8b0b7ad5ab1e..017e5bf87b28 100644 --- a/ydb/core/kqp/opt/logical/ya.make +++ b/ydb/core/kqp/opt/logical/ya.make @@ -12,6 +12,7 @@ SRCS( kqp_opt_log_sqlin.cpp kqp_opt_log_sqlin_compact.cpp kqp_opt_log.cpp + kqp_opt_cbo.cpp ) PEERDIR( diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.cpp b/ydb/core/kqp/provider/yql_kikimr_settings.cpp index c3a4e769f0a0..b1dfad359e33 100644 --- a/ydb/core/kqp/provider/yql_kikimr_settings.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_settings.cpp @@ -65,7 +65,7 @@ TKikimrConfiguration::TKikimrConfiguration() { REGISTER_SETTING(*this, OptEnableOlapProvideComputeSharding); REGISTER_SETTING(*this, OptUseFinalizeByKey); - REGISTER_SETTING(*this, OptEnableCostBasedOptimization); + REGISTER_SETTING(*this, CostBasedOptimizationLevel); REGISTER_SETTING(*this, OptEnableConstantFolding); REGISTER_SETTING(*this, MaxDPccpDPTableSize); @@ -122,10 +122,6 @@ bool TKikimrSettings::HasOptUseFinalizeByKey() const { return GetOptionalFlagValue(OptUseFinalizeByKey.Get()) != EOptionalFlag::Disabled; } -bool TKikimrSettings::HasOptEnableCostBasedOptimization() const { - return GetOptionalFlagValue(OptEnableCostBasedOptimization.Get()) == EOptionalFlag::Enabled; -} - bool TKikimrSettings::HasOptEnableConstantFolding() const { return GetOptionalFlagValue(OptEnableConstantFolding.Get()) == EOptionalFlag::Enabled; } diff --git a/ydb/core/kqp/provider/yql_kikimr_settings.h b/ydb/core/kqp/provider/yql_kikimr_settings.h index fd8f09180052..f6fb2beb1e2e 100644 --- a/ydb/core/kqp/provider/yql_kikimr_settings.h +++ b/ydb/core/kqp/provider/yql_kikimr_settings.h @@ -58,7 +58,7 @@ struct TKikimrSettings { NCommon::TConfSetting OptEnableOlapPushdown; NCommon::TConfSetting OptEnableOlapProvideComputeSharding; NCommon::TConfSetting OptUseFinalizeByKey; - NCommon::TConfSetting OptEnableCostBasedOptimization; + NCommon::TConfSetting CostBasedOptimizationLevel; NCommon::TConfSetting OptEnableConstantFolding; NCommon::TConfSetting MaxDPccpDPTableSize; @@ -81,7 +81,6 @@ struct TKikimrSettings { bool HasOptEnableOlapPushdown() const; bool HasOptEnableOlapProvideComputeSharding() const; bool HasOptUseFinalizeByKey() const; - bool HasOptEnableCostBasedOptimization() const; bool HasOptEnableConstantFolding() const; diff --git a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp index f2e780e40605..a85a5cd80725 100644 --- a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp +++ b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp @@ -17,7 +17,7 @@ using namespace NYdb::NTable; static void CreateSampleTable(TSession session) { UNIT_ASSERT(session.ExecuteSchemeQuery(R"( CREATE TABLE `/Root/R` ( - id Int32, + id Int32 not null, payload1 String, ts Date, PRIMARY KEY (id) @@ -26,7 +26,7 @@ static void CreateSampleTable(TSession session) { UNIT_ASSERT(session.ExecuteSchemeQuery(R"( CREATE TABLE `/Root/S` ( - id Int32, + id Int32 not null, payload2 String, PRIMARY KEY (id) ); @@ -34,7 +34,7 @@ static void CreateSampleTable(TSession session) { UNIT_ASSERT(session.ExecuteSchemeQuery(R"( CREATE TABLE `/Root/T` ( - id Int32, + id Int32 not null, payload3 String, PRIMARY KEY (id) ); @@ -42,7 +42,7 @@ static void CreateSampleTable(TSession session) { UNIT_ASSERT(session.ExecuteSchemeQuery(R"( CREATE TABLE `/Root/U` ( - id Int32, + id Int32 not null, payload4 String, PRIMARY KEY (id) ); @@ -50,7 +50,7 @@ static void CreateSampleTable(TSession session) { UNIT_ASSERT(session.ExecuteSchemeQuery(R"( CREATE TABLE `/Root/V` ( - id Int32, + id Int32 not null, payload5 String, PRIMARY KEY (id) ); @@ -73,15 +73,118 @@ static void CreateSampleTable(TSession session) { REPLACE INTO `/Root/V` (id, payload5) VALUES (1, "blah"); )", TTxControl::BeginTx().CommitTx()).GetValueSync().IsSuccess()); + + UNIT_ASSERT(session.ExecuteSchemeQuery(R"( + CREATE TABLE `/Root/customer` ( + c_acctbal Double, + c_address String, + c_comment String, + c_custkey Int32, -- Identifier + c_mktsegment String , + c_name String , + c_nationkey Int32 , -- FK to N_NATIONKEY + c_phone String , + PRIMARY KEY (c_custkey) +) +; + +CREATE TABLE `/Root/lineitem` ( + l_comment String , + l_commitdate Date , + l_discount Double , -- it should be Decimal(12, 2) + l_extendedprice Double , -- it should be Decimal(12, 2) + l_linenumber Int32 , + l_linestatus String , + l_orderkey Int32 , -- FK to O_ORDERKEY + l_partkey Int32 , -- FK to P_PARTKEY, first part of the compound FK to (PS_PARTKEY, PS_SUPPKEY) with L_SUPPKEY + l_quantity Double , -- it should be Decimal(12, 2) + l_receiptdate Date , + l_returnflag String , + l_shipdate Date , + l_shipinstruct String , + l_shipmode String , + l_suppkey Int32 , -- FK to S_SUPPKEY, second part of the compound FK to (PS_PARTKEY, PS_SUPPKEY) with L_PARTKEY + l_tax Double , -- it should be Decimal(12, 2) + PRIMARY KEY (l_orderkey, l_linenumber) +) +; + +CREATE TABLE `/Root/nation` ( + n_comment String , + n_name String , + n_nationkey Int32 , -- Identifier + n_regionkey Int32 , -- FK to R_REGIONKEY + PRIMARY KEY(n_nationkey) +) +; + +CREATE TABLE `/Root/orders` ( + o_clerk String , + o_comment String , + o_custkey Int32 , -- FK to C_CUSTKEY + o_orderdate Date , + o_orderkey Int32 , -- Identifier + o_orderpriority String , + o_orderstatus String , + o_shippriority Int32 , + o_totalprice Double , -- it should be Decimal(12, 2) + PRIMARY KEY (o_orderkey) +) +; + +CREATE TABLE `/Root/part` ( + p_brand String , + p_comment String , + p_container String , + p_mfgr String , + p_name String , + p_partkey Int32 , -- Identifier + p_retailprice Double , -- it should be Decimal(12, 2) + p_size Int32 , + p_type String , + PRIMARY KEY(p_partkey) +) +; + +CREATE TABLE `/Root/partsupp` ( + ps_availqty Int32 , + ps_comment String , + ps_partkey Int32 , -- FK to P_PARTKEY + ps_suppkey Int32 , -- FK to S_SUPPKEY + ps_supplycost Double , -- it should be Decimal(12, 2) + PRIMARY KEY(ps_partkey, ps_suppkey) +) +; + +CREATE TABLE `/Root/region` ( + r_comment String , + r_name String , + r_regionkey Int32 , -- Identifier + PRIMARY KEY(r_regionkey) +) +; + +CREATE TABLE `/Root/supplier` ( + s_acctbal Double , -- it should be Decimal(12, 2) + s_address String , + s_comment String , + s_name String , + s_nationkey Int32 , -- FK to N_NATIONKEY + s_phone String , + s_suppkey Int32 , -- Identifier + PRIMARY KEY(s_suppkey) +) +;)").GetValueSync().IsSuccess()); + } static TKikimrRunner GetKikimrWithJoinSettings(){ TVector settings; NKikimrKqp::TKqpSetting setting; - - setting.SetName("OptEnableCostBasedOptimization"); - setting.SetValue("true"); + + setting.SetName("CostBasedOptimizationLevel"); + setting.SetValue("1"); settings.push_back(setting); setting.SetName("OptEnableConstantFolding"); @@ -458,8 +561,312 @@ Y_UNIT_TEST_SUITE(KqpJoinOrder) { Cout << result.GetPlan(); } } -} + Y_UNIT_TEST(TPCH21) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( +-- TPC-H/TPC-R Suppliers Who Kept Orders Waiting Query (Q21) +-- TPC TPC-H Parameter Substitution (Version 2.17.2 build 0) +-- using 1680793381 as a seed to the RNG + +$n = select n_nationkey from `/Root/nation` +where n_name = 'EGYPT'; + +$s = select s_name, s_suppkey from `/Root/supplier` as supplier +join $n as nation +on supplier.s_nationkey = nation.n_nationkey; + +$l = select l_suppkey, l_orderkey from `/Root/lineitem` +where l_receiptdate > l_commitdate; + +$j1 = select s_name, l_suppkey, l_orderkey from $l as l1 +join $s as supplier +on l1.l_suppkey = supplier.s_suppkey; + +-- exists +$j2 = select l1.l_orderkey as l_orderkey, l1.l_suppkey as l_suppkey, l1.s_name as s_name, l2.l_receiptdate as l_receiptdate, l2.l_commitdate as l_commitdate from $j1 as l1 +join `/Root/lineitem` as l2 +on l1.l_orderkey = l2.l_orderkey +where l2.l_suppkey <> l1.l_suppkey; + +$j2_1 = select s_name, l1.l_suppkey as l_suppkey, l1.l_orderkey as l_orderkey from $j1 as l1 +left semi join $j2 as l2 +on l1.l_orderkey = l2.l_orderkey; + +-- not exists +$j2_2 = select l_orderkey from $j2 where l_receiptdate > l_commitdate; + +$j3 = select s_name, l_suppkey, l_orderkey from $j2_1 as l1 +left only join $j2_2 as l3 +on l1.l_orderkey = l3.l_orderkey; + +$j4 = select s_name from $j3 as l1 +join `/Root/orders` as orders +on orders.o_orderkey = l1.l_orderkey +where o_orderstatus = 'F'; + +select s_name, + count(*) as numwait from $j4 +group by + s_name +order by + numwait desc, + s_name +limit 100;)"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } + + Y_UNIT_TEST(TPCH5) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( +-- TPC-H/TPC-R Local Supplier Volume Query (Q5) +-- TPC TPC-H Parameter Substitution (Version 2.17.2 build 0) +-- using 1680793381 as a seed to the RNG + +$join1 = ( +select + o.o_orderkey as o_orderkey, + o.o_orderdate as o_orderdate, + c.c_nationkey as c_nationkey +from + `/Root/customer` as c +join + `/Root/orders` as o +on + c.c_custkey = o.o_custkey +); + +$join2 = ( +select + j.o_orderkey as o_orderkey, + j.o_orderdate as o_orderdate, + j.c_nationkey as c_nationkey, + l.l_extendedprice as l_extendedprice, + l.l_discount as l_discount, + l.l_suppkey as l_suppkey +from + $join1 as j +join + `/Root/lineitem` as l +on + l.l_orderkey = j.o_orderkey +); + +$join3 = ( +select + j.o_orderkey as o_orderkey, + j.o_orderdate as o_orderdate, + j.c_nationkey as c_nationkey, + j.l_extendedprice as l_extendedprice, + j.l_discount as l_discount, + j.l_suppkey as l_suppkey, + s.s_nationkey as s_nationkey +from + $join2 as j +join + `/Root/supplier` as s +on + j.l_suppkey = s.s_suppkey +); +$join4 = ( +select + j.o_orderkey as o_orderkey, + j.o_orderdate as o_orderdate, + j.c_nationkey as c_nationkey, + j.l_extendedprice as l_extendedprice, + j.l_discount as l_discount, + j.l_suppkey as l_suppkey, + j.s_nationkey as s_nationkey, + n.n_regionkey as n_regionkey, + n.n_name as n_name +from + $join3 as j +join + `/Root/nation` as n +on + j.s_nationkey = n.n_nationkey + and j.c_nationkey = n.n_nationkey +); +$join5 = ( +select + j.o_orderkey as o_orderkey, + j.o_orderdate as o_orderdate, + j.c_nationkey as c_nationkey, + j.l_extendedprice as l_extendedprice, + j.l_discount as l_discount, + j.l_suppkey as l_suppkey, + j.s_nationkey as s_nationkey, + j.n_regionkey as n_regionkey, + j.n_name as n_name, + r.r_name as r_name +from + $join4 as j +join + `/Root/region` as r +on + j.n_regionkey = r.r_regionkey +); +$border = Date('1995-01-01'); +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + $join5 +where + r_name = 'AFRICA' + and CAST(o_orderdate AS Timestamp) >= $border + and CAST(o_orderdate AS Timestamp) < ($border + Interval('P365D')) +group by + n_name +order by + revenue desc; + + )"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } + + Y_UNIT_TEST(TPCH10) { + + auto kikimr = GetKikimrWithJoinSettings(); + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + CreateSampleTable(session); + + /* join with parameters */ + { + const TString query = Q_(R"( + +-- TPC-H/TPC-R Returned Item Reporting Query (Q10) +-- TPC TPC-H Parameter Substitution (Version 2.17.2 build 0) +-- using 1680793381 as a seed to the RNG + +$border = Date("1993-12-01"); +$join1 = ( +select + c.c_custkey as c_custkey, + c.c_name as c_name, + c.c_acctbal as c_acctbal, + c.c_address as c_address, + c.c_phone as c_phone, + c.c_comment as c_comment, + c.c_nationkey as c_nationkey, + o.o_orderkey as o_orderkey +from + `/Root/customer` as c +join + `/Root/orders` as o +on + c.c_custkey = o.o_custkey +where + cast(o.o_orderdate as timestamp) >= $border and + cast(o.o_orderdate as timestamp) < ($border + Interval("P90D")) +); +$join2 = ( +select + j.c_custkey as c_custkey, + j.c_name as c_name, + j.c_acctbal as c_acctbal, + j.c_address as c_address, + j.c_phone as c_phone, + j.c_comment as c_comment, + j.c_nationkey as c_nationkey, + l.l_extendedprice as l_extendedprice, + l.l_discount as l_discount +from + $join1 as j +join + `/Root/lineitem` as l +on + l.l_orderkey = j.o_orderkey +where + l.l_returnflag = 'R' +); +$join3 = ( +select + j.c_custkey as c_custkey, + j.c_name as c_name, + j.c_acctbal as c_acctbal, + j.c_address as c_address, + j.c_phone as c_phone, + j.c_comment as c_comment, + j.c_nationkey as c_nationkey, + j.l_extendedprice as l_extendedprice, + j.l_discount as l_discount, + n.n_name as n_name +from + $join2 as j +join + `/Root/nation` as n +on + n.n_nationkey = j.c_nationkey +); +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + $join3 +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit 20; + )"); + + auto result = session.ExplainDataQuery(query).ExtractValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::SUCCESS); + + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + Cout << result.GetPlan(); + } + } +} } } diff --git a/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp b/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp index 2ee3d1a0d900..c7ce233d76f5 100644 --- a/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp +++ b/ydb/library/yql/core/cbo/cbo_optimizer_new.cpp @@ -63,12 +63,13 @@ void TRelOptimizerNode::Print(std::stringstream& stream, int ntabs) { } TJoinOptimizerNode::TJoinOptimizerNode(const std::shared_ptr& left, const std::shared_ptr& right, - const std::set>& joinConditions, const EJoinKind joinType, bool nonReorderable) : + const std::set>& joinConditions, const EJoinKind joinType, const EJoinAlgoType joinAlgo, bool nonReorderable) : IBaseOptimizerNode(JoinNodeType), LeftArg(left), RightArg(right), JoinConditions(joinConditions), - JoinType(joinType) { + JoinType(joinType), + JoinAlgo(joinAlgo) { IsReorderable = (JoinType==EJoinKind::InnerJoin) && (nonReorderable==false); } diff --git a/ydb/library/yql/core/cbo/cbo_optimizer_new.h b/ydb/library/yql/core/cbo/cbo_optimizer_new.h index 1f35a0a2310c..256252241e76 100644 --- a/ydb/library/yql/core/cbo/cbo_optimizer_new.h +++ b/ydb/library/yql/core/cbo/cbo_optimizer_new.h @@ -49,8 +49,13 @@ struct IBaseOptimizerNode { struct TRelOptimizerNode : public IBaseOptimizerNode { TString Label; + // Temporary solution to check if a LookupJoin is possible in KQP + //void* Expr; + TRelOptimizerNode(TString label, std::shared_ptr stats) : IBaseOptimizerNode(RelNodeType, stats), Label(label) { } + //TRelOptimizerNode(TString label, std::shared_ptr stats, const TExprNode::TPtr expr) : + // IBaseOptimizerNode(RelNodeType, stats), Label(label), Expr(expr) { } virtual ~TRelOptimizerNode() {} virtual TVector Labels(); @@ -74,6 +79,54 @@ enum EJoinKind: ui32 EJoinKind ConvertToJoinKind(const TString& joinString); TString ConvertToJoinString(const EJoinKind kind); +/** + * This is a temporary structure for KQP provider + * We will soon be supporting multiple providers and we will need to design + * some interfaces to pass provider-specific context to the optimizer +*/ +struct IProviderContext { + virtual ~IProviderContext() = default; + + virtual double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, EJoinAlgoType joinAlgol) const = 0; + + virtual bool IsJoinApplicable(const std::shared_ptr& left, + const std::shared_ptr& right, + const std::set>& joinConditions, + EJoinAlgoType joinAlgo) = 0; + +}; + +/** + * Temporary solution for default provider context +*/ + +struct TDummyProviderContext : public IProviderContext { + TDummyProviderContext() {} + + double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, EJoinAlgoType joinAlgo) const override { + Y_UNUSED(joinAlgo); + return leftStats.Nrows + 2.0 * rightStats.Nrows; + } + + bool IsJoinApplicable(const std::shared_ptr& left, + const std::shared_ptr& right, + const std::set>& joinConditions, + EJoinAlgoType joinAlgo) override { + + Y_UNUSED(left); + Y_UNUSED(right); + Y_UNUSED(joinConditions); + Y_UNUSED(joinAlgo); + + return true; + } + + static const TDummyProviderContext& instance() { + static TDummyProviderContext staticContext; + return staticContext; + } + +}; /** * JoinOptimizerNode records the left and right arguments of the join @@ -86,16 +139,20 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode { std::shared_ptr RightArg; std::set> JoinConditions; EJoinKind JoinType; + EJoinAlgoType JoinAlgo; bool IsReorderable; TJoinOptimizerNode(const std::shared_ptr& left, const std::shared_ptr& right, - const std::set>& joinConditions, const EJoinKind joinType, bool nonReorderable=false); + const std::set>& joinConditions, const EJoinKind joinType, const EJoinAlgoType joinAlgo, bool nonReorderable=false); virtual ~TJoinOptimizerNode() {} virtual TVector Labels(); virtual void Print(std::stringstream& stream, int ntabs=0); }; struct IOptimizerNew { + IProviderContext& Pctx; + + IOptimizerNew(IProviderContext& ctx) : Pctx(ctx) {} virtual ~IOptimizerNew() = default; virtual std::shared_ptr JoinSearch(const std::shared_ptr& joinTree) = 0; }; diff --git a/ydb/library/yql/core/yql_cost_function.cpp b/ydb/library/yql/core/yql_cost_function.cpp index 5724c91e5276..dcf395ca408e 100644 --- a/ydb/library/yql/core/yql_cost_function.cpp +++ b/ydb/library/yql/core/yql_cost_function.cpp @@ -1,5 +1,7 @@ #include "yql_cost_function.h" +#include + using namespace NYql; namespace { @@ -16,6 +18,7 @@ bool IsPKJoin(const TOptimizerStatistics& stats, const TVector& joinKey } return true; } + } bool NDq::operator < (const NDq::TJoinColumn& c1, const NDq::TJoinColumn& c2) { @@ -36,8 +39,7 @@ bool NDq::operator < (const NDq::TJoinColumn& c1, const NDq::TJoinColumn& c2) { */ TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, - const TVector& leftJoinKeys, const TVector& rightJoinKeys, EJoinImplType joinImpl) { - Y_UNUSED(joinImpl); + const TVector& leftJoinKeys, const TVector& rightJoinKeys, EJoinAlgoType joinAlgo, const IProviderContext& ctx) { double newCard; EStatisticsType outputType; @@ -68,7 +70,7 @@ TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStat int newNCols = leftStats.Ncols + rightStats.Ncols; - double cost = leftStats.Nrows + 2.0 * rightStats.Nrows + double cost = ctx.ComputeJoinCost(leftStats, rightStats, joinAlgo) + newCard + leftStats.Cost + rightStats.Cost; @@ -76,7 +78,7 @@ TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStat } TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, - const std::set>& joinConditions, EJoinImplType joinImpl) { + const std::set>& joinConditions, EJoinAlgoType joinAlgo, const IProviderContext& ctx) { TVector leftJoinKeys; TVector rightJoinKeys; @@ -86,5 +88,5 @@ TOptimizerStatistics NYql::ComputeJoinStats(const TOptimizerStatistics& leftStat rightJoinKeys.emplace_back(c.second.AttributeName); } - return ComputeJoinStats(leftStats, rightStats, leftJoinKeys, rightJoinKeys, joinImpl); + return ComputeJoinStats(leftStats, rightStats, leftJoinKeys, rightJoinKeys, joinAlgo, ctx); } diff --git a/ydb/library/yql/core/yql_cost_function.h b/ydb/library/yql/core/yql_cost_function.h index ae0b16de8021..9a040404cfce 100644 --- a/ydb/library/yql/core/yql_cost_function.h +++ b/ydb/library/yql/core/yql_cost_function.h @@ -14,6 +14,8 @@ */ namespace NYql { +struct IProviderContext; + namespace NDq { /** * Join column is a struct that records the relation label and @@ -43,16 +45,19 @@ bool operator < (const TJoinColumn& c1, const TJoinColumn& c2); } -enum EJoinImplType { +enum EJoinAlgoType { DictJoin, MapJoin, - GraceJoin + GraceJoin, + LookupJoin }; +static const EJoinAlgoType AllJoinTypes[] = { DictJoin, MapJoin, GraceJoin, LookupJoin }; + TOptimizerStatistics ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, - const std::set>& joinConditions, EJoinImplType joinType); + const std::set>& joinConditions, EJoinAlgoType joinAlgo, const IProviderContext& ctx); TOptimizerStatistics ComputeJoinStats(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, - const TVector& leftJoinKeys, const TVector& rightJoinKeys, EJoinImplType joinType); + const TVector& leftJoinKeys, const TVector& rightJoinKeys, EJoinAlgoType joinAlgo, const IProviderContext& ctx); } \ No newline at end of file diff --git a/ydb/library/yql/dq/opt/dq_opt_join.h b/ydb/library/yql/dq/opt/dq_opt_join.h index 8d83b2de6466..9b9c0713454c 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join.h +++ b/ydb/library/yql/dq/opt/dq_opt_join.h @@ -8,6 +8,7 @@ namespace NYql { struct TOptimizerStatistics; +struct TRelOptimizerNode; namespace NDq { @@ -25,9 +26,10 @@ NNodes::TExprBase DqBuildJoinDict(const NNodes::TDqJoin& join, TExprContext& ctx NNodes::TDqJoin DqSuppressSortOnJoinInput(const NNodes::TDqJoin& node, TExprContext& ctx); bool DqCollectJoinRelationsWithStats( + TVector>& rels, TTypeAnnotationContext& typesCtx, const NNodes::TCoEquiJoin& equiJoin, - const std::function&)>& collector); + const std::function>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr&)>& collector); } // namespace NDq } // namespace NYql diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp index 79d73f11dec5..92d17378cc31 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp @@ -13,7 +13,6 @@ #include //interface #include //new interface - #include @@ -98,10 +97,86 @@ void ComputeJoinConditions(const TCoEquiJoinTuple& joinTuple, std::shared_ptr MakeJoin(std::shared_ptr left, std::shared_ptr right, const std::set>& joinConditions, - EJoinImplType joinImpl) { + EJoinAlgoType joinAlgo, + IProviderContext& ctx) { + + auto res = std::make_shared(left, right, joinConditions, EJoinKind::InnerJoin, joinAlgo); + res->Stats = std::make_shared( ComputeJoinStats(*left->Stats, *right->Stats, joinConditions, joinAlgo, ctx)); + return res; +} + +/** + * Iterate over all join algorithms and pick the best join that is applicable. + * Also considers commuting joins +*/ +std::shared_ptr PickBestJoin(std::shared_ptr left, + std::shared_ptr right, + const std::set>& leftJoinConditions, + const std::set>& rightJoinConditions, + IProviderContext& ctx) { + + auto res = std::shared_ptr(); + + for ( auto joinType : AllJoinTypes ) { + auto p1 = ctx.IsJoinApplicable(left, right, leftJoinConditions, joinType) ? + MakeJoin(left, right, leftJoinConditions, joinType, ctx) : + std::shared_ptr(); + auto p2 = ctx.IsJoinApplicable(right, left, rightJoinConditions, joinType) ? + MakeJoin(right, left, rightJoinConditions, joinType, ctx) : + std::shared_ptr(); + + if (p1) { + if (res) { + if (p1->Stats->Cost < res->Stats->Cost) { + res = p1; + } + } else { + res = p1; + } + } + if (p2) { + if (res) { + if (p2->Stats->Cost < res->Stats->Cost) { + res = p2; + } + } else { + res = p2; + } + } + } - auto res = std::make_shared(left, right, joinConditions, EJoinKind::InnerJoin); - res->Stats = std::make_shared( ComputeJoinStats(*left->Stats, *right->Stats, joinConditions, joinImpl)); + Y_ENSURE(res,"No join was chosen!"); + return res; +} + +/** + * Iterate over all join algorithms and pick the best join that is applicable +*/ +std::shared_ptr PickBestNonReorderabeJoin(std::shared_ptr left, + std::shared_ptr right, + const std::set>& leftJoinConditions, + IProviderContext& ctx) { + + auto res = std::shared_ptr(); + + for ( auto joinType : AllJoinTypes ) { + auto p = ctx.IsJoinApplicable(left, right, leftJoinConditions, joinType) ? + MakeJoin(left, right, leftJoinConditions, joinType, ctx) : + std::shared_ptr(); + + if (p) { + if (res) { + if (p->Stats->Cost < res->Stats->Cost) { + res = p; + } + } else { + res = p; + } + } + + } + + Y_ENSURE(res,"No join was chosen!"); return res; } @@ -309,8 +384,8 @@ class TDPccpSolver { public: // Construct the DPccp solver based on the join graph and data about input relations - TDPccpSolver(TGraph& g, TVector> rels): - Graph(g), Rels(rels) { + TDPccpSolver(TGraph& g, TVector> rels, IProviderContext& ctx): + Graph(g), Rels(rels), Pctx(ctx) { NNodes = g.NNodes; } @@ -342,6 +417,10 @@ class TDPccpSolver { // List of input relations to DPccp TVector> Rels; + + // Provider specific contexts? + // FIXME: This is a temporary structure that needs to be extended to multiple providers + IProviderContext& Pctx; // Emit connected subgraph void EmitCsg(const std::bitset&, int=0); @@ -548,34 +627,27 @@ template void TDPccpSolver::EmitCsgCmp(const std::bitset& S1, cons std::bitset joined = S1 | S2; + TEdge e1 = Graph.FindCrossingEdge(S1, S2); + TEdge e2 = Graph.FindCrossingEdge(S2, S1); + auto bestJoin = PickBestJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, e2.JoinConditions, Pctx); + if (! DpTable.contains(joined)) { - TEdge e1 = Graph.FindCrossingEdge(S1, S2); - DpTable[joined] = MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin); - TEdge e2 = Graph.FindCrossingEdge(S2, S1); - std::shared_ptr newJoin = - MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin); - if (newJoin->Stats->Cost < DpTable[joined]->Stats->Cost){ - DpTable[joined] = newJoin; - } + DpTable[joined] = bestJoin; } else { - TEdge e1 = Graph.FindCrossingEdge(S1, S2); - std::shared_ptr newJoin1 = - MakeJoin(DpTable[S1], DpTable[S2], e1.JoinConditions, GraceJoin); - TEdge e2 = Graph.FindCrossingEdge(S2, S1); - std::shared_ptr newJoin2 = - MakeJoin(DpTable[S2], DpTable[S1], e2.JoinConditions, GraceJoin); - if (newJoin1->Stats->Cost < DpTable[joined]->Stats->Cost){ - DpTable[joined] = newJoin1; - } - if (newJoin2->Stats->Cost < DpTable[joined]->Stats->Cost){ - DpTable[joined] = newJoin2; + if (bestJoin->Stats->Cost < DpTable[joined]->Stats->Cost) { + DpTable[joined] = bestJoin; } } + /* + * This is a sanity check that slows down the optimizer + * + auto pair = std::make_pair(S1, S2); Y_ENSURE (!CheckTable.contains(pair), "Check table already contains pair S1|S2"); CheckTable[ std::pair,std::bitset>(S1, S2) ] = true; + */ } /** @@ -782,9 +854,10 @@ TExprBase RearrangeEquiJoinTree(TExprContext& ctx, const TCoEquiJoin& equiJoin, } bool DqCollectJoinRelationsWithStats( + TVector>& rels, TTypeAnnotationContext& typesCtx, const TCoEquiJoin& equiJoin, - const std::function&)>& collector) + const std::function>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr&)>& collector) { if (equiJoin.ArgCount() < 3) { return false; @@ -808,7 +881,7 @@ bool DqCollectJoinRelationsWithStats( TStringBuf label = scope.Cast(); auto stats = maybeStat->second; - collector(label, stats); + collector(rels, label, joinArg.Ptr(), stats); } return true; } @@ -861,7 +934,7 @@ std::shared_ptr ConvertToJoinTree(const TCoEquiJoinTuple& jo TJoinColumn(rightScope, rightColumn))); } - return std::make_shared(left,right,joinConds,ConvertToJoinKind(joinTuple.Type().StringValue())); + return std::make_shared(left, right, joinConds, ConvertToJoinKind(joinTuple.Type().StringValue()), EJoinAlgoType::DictJoin); } /** @@ -919,14 +992,14 @@ void ExtractRelsAndJoinConditions(const std::shared_ptr& joi /** * Recursively computes statistics for a join tree */ -void ComputeStatistics(const std::shared_ptr& join) { +void ComputeStatistics(const std::shared_ptr& join, IProviderContext& ctx) { if (join->LeftArg->Kind == EOptimizerNodeKind::JoinNodeType) { - ComputeStatistics(static_pointer_cast(join->LeftArg)); + ComputeStatistics(static_pointer_cast(join->LeftArg), ctx); } if (join->RightArg->Kind == EOptimizerNodeKind::JoinNodeType) { - ComputeStatistics(static_pointer_cast(join->RightArg)); + ComputeStatistics(static_pointer_cast(join->RightArg), ctx); } - join->Stats = std::make_shared(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinImplType::DictJoin)); + join->Stats = std::make_shared(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinAlgoType::DictJoin, ctx)); } /** @@ -934,10 +1007,9 @@ void ComputeStatistics(const std::shared_ptr& join) { * The root of the subtree that needs to be optimizer needs to be reorderable, otherwise we will * only update the statistics for it and return it unchanged */ -std::shared_ptr OptimizeSubtree(const std::shared_ptr& joinTree, ui32 maxDPccpDPTableSize) { +std::shared_ptr OptimizeSubtree(const std::shared_ptr& joinTree, ui32 maxDPccpDPTableSize, IProviderContext& ctx) { if (!joinTree->IsReorderable) { - joinTree->Stats = std::make_shared(ComputeJoinStats(*joinTree->LeftArg->Stats, *joinTree->RightArg->Stats, joinTree->JoinConditions, EJoinImplType::DictJoin)); - return joinTree; + return PickBestNonReorderabeJoin(joinTree->LeftArg, joinTree->RightArg, joinTree->JoinConditions, ctx); } TGraph<64> joinGraph; @@ -954,7 +1026,7 @@ std::shared_ptr OptimizeSubtree(const std::shared_ptr= 64) { - ComputeStatistics(joinTree); + ComputeStatistics(joinTree, ctx); return joinTree; } @@ -981,12 +1053,12 @@ std::shared_ptr OptimizeSubtree(const std::shared_ptr solver(joinGraph,rels); + TDPccpSolver<64> solver(joinGraph, rels, ctx); // Check that the dynamic table of DPccp is not too big // If it is, just compute the statistics for the join tree and return it if (solver.CountCC(maxDPccpDPTableSize) >= maxDPccpDPTableSize) { - ComputeStatistics(joinTree); + ComputeStatistics(joinTree, ctx); return joinTree; } @@ -1005,8 +1077,8 @@ std::shared_ptr OptimizeSubtree(const std::shared_ptr JoinSearch(const std::shared_ptr& joinTree) override { // Traverse the join tree and generate a list of non-orderable joins in a post-order @@ -1016,16 +1088,16 @@ class TOptimizerNativeNew: public IOptimizerNew { // For all non-orderable joins, optimize the children for( auto join : nonOrderables ) { if (join->LeftArg->Kind == EOptimizerNodeKind::JoinNodeType) { - join->LeftArg = OptimizeSubtree(static_pointer_cast(join->LeftArg), MaxDPccpDPTableSize); + join->LeftArg = OptimizeSubtree(static_pointer_cast(join->LeftArg), MaxDPccpDPTableSize, Pctx); } if (join->RightArg->Kind == EOptimizerNodeKind::JoinNodeType) { - join->RightArg = OptimizeSubtree(static_pointer_cast(join->RightArg), MaxDPccpDPTableSize); + join->RightArg = OptimizeSubtree(static_pointer_cast(join->RightArg), MaxDPccpDPTableSize, Pctx); } - join->Stats = std::make_shared(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinImplType::DictJoin)); + join->Stats = std::make_shared(ComputeJoinStats(*join->LeftArg->Stats, *join->RightArg->Stats, join->JoinConditions, EJoinAlgoType::DictJoin, Pctx)); } // Optimize the root - return OptimizeSubtree(joinTree, MaxDPccpDPTableSize); + return OptimizeSubtree(joinTree, MaxDPccpDPTableSize, Pctx); } const ui32 MaxDPccpDPTableSize; @@ -1041,9 +1113,10 @@ class TOptimizerNativeNew: public IOptimizerNew { * and finally optimizes the root of the tree */ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, - bool ruleEnabled, ui32 maxDPccpDPTableSize) { + ui32 optLevel, ui32 maxDPccpDPTableSize, IProviderContext& providerCtx, + const std::function>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr&)>& providerCollect) { - if (!ruleEnabled) { + if (optLevel==0) { return node; } @@ -1065,9 +1138,8 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, // Check that statistics for all inputs of equiJoin were computed // The arguments of the EquiJoin are 1..n-2, n-2 is the actual join tree // of the EquiJoin and n-1 argument are the parameters to EquiJoin - if (!DqCollectJoinRelationsWithStats(typesCtx, equiJoin, [&](auto label, auto stat) { - rels.emplace_back(std::make_shared(TString(label), stat)); - })) { + + if (!DqCollectJoinRelationsWithStats(rels, typesCtx, equiJoin, providerCollect)){ return node; } @@ -1078,7 +1150,7 @@ TExprBase DqOptimizeEquiJoinWithCosts(const TExprBase& node, TExprContext& ctx, // Generate an initial tree auto joinTree = ConvertToJoinTree(joinTuple,rels); - auto opt = TOptimizerNativeNew(maxDPccpDPTableSize); + auto opt = TOptimizerNativeNew(providerCtx, maxDPccpDPTableSize); joinTree = opt.JoinSearch(joinTree); // rewrite the join tree and record the output statistics @@ -1097,7 +1169,8 @@ class TOptimizerNative: public IOptimizer { } TOutput JoinSearch() override { - TDPccpSolver<64> solver(JoinGraph, Rels); + auto dummyProviderCtx = TDummyProviderContext(); + TDPccpSolver<64> solver(JoinGraph, Rels, dummyProviderCtx); std::shared_ptr result = solver.Solve(); if (Log) { std::stringstream str; diff --git a/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp b/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp index 75a1bbf1a107..591737af615f 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_join_cost_based_generic.cpp @@ -157,11 +157,11 @@ TExprBase DqOptimizeEquiJoinWithCosts( TExprContext& ctx, TTypeAnnotationContext& typesCtx, const std::function& optFactory, - bool ruleEnabled) + ui32 optLevel) { Y_UNUSED(ctx); - if (!ruleEnabled) { + if (optLevel==0) { return node; } @@ -184,7 +184,10 @@ TExprBase DqOptimizeEquiJoinWithCosts( TState state(equiJoin); // collect Rels - if (!DqCollectJoinRelationsWithStats(typesCtx, equiJoin, [&](auto label, auto stat) { + TVector> rels; + if (!DqCollectJoinRelationsWithStats(rels, typesCtx, equiJoin, [&](auto r, auto label, auto node, auto stat) { + Y_UNUSED(r); + Y_UNUSED(node); state.CollectRel(label, stat); })) { return node; diff --git a/ydb/library/yql/dq/opt/dq_opt_log.h b/ydb/library/yql/dq/opt/dq_opt_log.h index 83061e1b3de6..0c140b9d99d0 100644 --- a/ydb/library/yql/dq/opt/dq_opt_log.h +++ b/ydb/library/yql/dq/opt/dq_opt_log.h @@ -11,6 +11,9 @@ namespace NYql { struct TTypeAnnotationContext; struct TDqSettings; + struct IProviderContext; + struct TRelOptimizerNode; + struct TOptimizerStatistics; } namespace NYql::NDq { @@ -19,14 +22,21 @@ NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx, NNodes::TExprBase DqRewriteTakeSortToTopSort(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parents); -NNodes::TExprBase DqOptimizeEquiJoinWithCosts(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool isRuleEnabled, ui32 maxDPccpDPTableSize); +NNodes::TExprBase DqOptimizeEquiJoinWithCosts( + const NNodes::TExprBase& node, + TExprContext& ctx, + TTypeAnnotationContext& typesCtx, + ui32 optLevel, + ui32 maxDPccpDPTableSize, + IProviderContext& providerCtx, + const std::function>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr&)>& providerCollect); NNodes::TExprBase DqOptimizeEquiJoinWithCosts( const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, const std::function& optFactory, - bool ruleEnabled); + ui32 optLevel); NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, TExprContext& ctx); diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.cpp b/ydb/library/yql/dq/opt/dq_opt_stat.cpp index 747bda5ca5dc..aed5076e9e31 100644 --- a/ydb/library/yql/dq/opt/dq_opt_stat.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_stat.cpp @@ -93,7 +93,7 @@ bool IsConstantExpr(const TExprNode::TPtr& input) { * Compute statistics for map join * FIX: Currently we treat all join the same from the cost perspective, need to refine cost function */ -void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { +void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx) { auto inputNode = TExprBase(input); auto join = inputNode.Cast(); @@ -118,14 +118,14 @@ void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationCont } typeCtx->SetStats(join.Raw(), std::make_shared( - ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, MapJoin))); + ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, MapJoin, ctx))); } /** * Compute statistics for grace join * FIX: Currently we treat all join the same from the cost perspective, need to refine cost function */ -void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) { +void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx) { auto inputNode = TExprBase(input); auto join = inputNode.Cast(); @@ -150,7 +150,7 @@ void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationCo } typeCtx->SetStats(join.Raw(), std::make_shared( - ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, GraceJoin))); + ComputeJoinStats(*leftStats, *rightStats, leftJoinKeys, rightJoinKeys, GraceJoin, ctx))); } /** diff --git a/ydb/library/yql/dq/opt/dq_opt_stat.h b/ydb/library/yql/dq/opt/dq_opt_stat.h index 7a5f9542764b..4f3497da4fcd 100644 --- a/ydb/library/yql/dq/opt/dq_opt_stat.h +++ b/ydb/library/yql/dq/opt/dq_opt_stat.h @@ -1,6 +1,7 @@ #include "dq_opt.h" #include +#include namespace NYql::NDq { @@ -14,8 +15,8 @@ void PropagateStatisticsToLambdaArgument(const TExprNode::TPtr& input, TTypeAnno void PropagateStatisticsToStageArguments(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); void InferStatisticsForStage(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); void InferStatisticsForDqSource(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); -void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); -void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx); +void InferStatisticsForGraceJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx); +void InferStatisticsForMapJoin(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx, const IProviderContext& ctx); double ComputePredicateSelectivity(const NNodes::TExprBase& input, const std::shared_ptr& stats); bool NeedCalc(NNodes::TExprBase node); bool IsConstantExpr(const TExprNode::TPtr& input); diff --git a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp index 4284d30cecaf..f29e814444ca 100644 --- a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.cpp @@ -7,8 +7,8 @@ namespace NYql::NDq { using namespace NNodes; -TDqStatisticsTransformerBase::TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx) - : TypeCtx(typeCtx) +TDqStatisticsTransformerBase::TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx, const IProviderContext& ctx) + : TypeCtx(typeCtx), Pctx(ctx) { } IGraphTransformer::TStatus TDqStatisticsTransformerBase::DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) { @@ -55,10 +55,10 @@ bool TDqStatisticsTransformerBase::BeforeLambdas(const TExprNode::TPtr& input, T // Join matchers else if(TCoMapJoinCore::Match(input.Get())) { - InferStatisticsForMapJoin(input, TypeCtx); + InferStatisticsForMapJoin(input, TypeCtx, Pctx); } else if(TCoGraceJoinCore::Match(input.Get())) { - InferStatisticsForGraceJoin(input, TypeCtx); + InferStatisticsForGraceJoin(input, TypeCtx, Pctx); } // Do nothing in case of EquiJoin, otherwise the EquiJoin rule won't fire diff --git a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h index 45ff27acc310..8201832b6458 100644 --- a/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h +++ b/ydb/library/yql/dq/opt/dq_opt_stat_transformer_base.h @@ -2,12 +2,13 @@ #include #include +#include namespace NYql::NDq { class TDqStatisticsTransformerBase : public TSyncTransformerBase { public: - TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx); + TDqStatisticsTransformerBase(TTypeAnnotationContext* typeCtx, const IProviderContext& ctx); IGraphTransformer::TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) override; void Rewind() override; @@ -21,6 +22,7 @@ class TDqStatisticsTransformerBase : public TSyncTransformerBase { bool AfterLambdas(const TExprNode::TPtr& input, TExprContext& ctx); TTypeAnnotationContext* TypeCtx; + const IProviderContext& Pctx; }; } // namespace NYql::NDq diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.h b/ydb/library/yql/providers/dq/common/yql_dq_settings.h index b50224a76e67..3530a3178e8b 100644 --- a/ydb/library/yql/providers/dq/common/yql_dq_settings.h +++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.h @@ -56,6 +56,7 @@ struct TDqSettings { static constexpr bool ExportStats = false; static constexpr ETaskRunnerStats TaskRunnerStats = ETaskRunnerStats::Basic; static constexpr ESpillingEngine SpillingEngine = ESpillingEngine::Disable; + static constexpr ui32 CostBasedOptimizationLevel = 0; static constexpr ui32 MaxDPccpDPTableSize = 10000U; }; diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp index e91c019b5a22..886d5787b45f 100644 --- a/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp +++ b/ydb/library/yql/providers/dq/provider/yql_dq_datasource.cpp @@ -50,7 +50,7 @@ class TDqDataProviderSource: public TDataProviderBase { , ExecTransformer_([this, execTransformerFactory] () { return THolder(execTransformerFactory(State_)); }) , TypeAnnotationTransformer_([] () { return CreateDqsDataSourceTypeAnnotationTransformer(); }) , ConstraintsTransformer_([] () { return CreateDqDataSourceConstraintTransformer(); }) - , StatisticsTransformer_([this]() { return CreateDqsStatisticsTransformer(State_); }) + , StatisticsTransformer_([this]() { return CreateDqsStatisticsTransformer(State_, TDummyProviderContext::instance()); }) { } TStringBuf GetName() const override { diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp index 4d44e7cd03fd..36b3e4603155 100644 --- a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp +++ b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.cpp @@ -14,8 +14,8 @@ using namespace NNodes; class TDqsStatisticsTransformer : public NDq::TDqStatisticsTransformerBase { public: - TDqsStatisticsTransformer(const TDqStatePtr& state) - : NDq::TDqStatisticsTransformerBase(state->TypeCtx) + TDqsStatisticsTransformer(const TDqStatePtr& state, const IProviderContext& ctx) + : NDq::TDqStatisticsTransformerBase(state->TypeCtx, ctx) , State(state) { } @@ -55,8 +55,8 @@ class TDqsStatisticsTransformer : public NDq::TDqStatisticsTransformerBase { TDqStatePtr State; }; -THolder CreateDqsStatisticsTransformer(TDqStatePtr state) { - return MakeHolder(state); +THolder CreateDqsStatisticsTransformer(TDqStatePtr state, const IProviderContext& ctx) { + return MakeHolder(state, ctx); } } // namespace NYql diff --git a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h index 460b363bf40c..6a5592c775fc 100644 --- a/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h +++ b/ydb/library/yql/providers/dq/provider/yql_dq_statistics.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -9,6 +10,6 @@ namespace NYql { struct TDqState; using TDqStatePtr = TIntrusivePtr; -THolder CreateDqsStatisticsTransformer(TDqStatePtr state); +THolder CreateDqsStatisticsTransformer(TDqStatePtr state, const IProviderContext& ctx); } // namespace NYql