Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring to make cardinality estimation provider-specific #3866

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions ydb/core/kqp/opt/kqp_statistics_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ void InferStatisticsForReadTable(const TExprNode::TPtr& input, TTypeAnnotationCo

YQL_CLOG(TRACE, CoreDq) << "Infer statistics for read table, nrows:" << nRows << ", nattrs: " << nAttrs;

auto outputStats = TOptimizerStatistics(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, 0.0, tableData.Metadata->KeyColumnNames);
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, 0.0, tableData.Metadata->KeyColumnNames));
}

/**
Expand All @@ -63,8 +62,7 @@ void InferStatisticsForKqpTable(const TExprNode::TPtr& input, TTypeAnnotationCon
int nAttrs = tableData.Metadata->Columns.size();
YQL_CLOG(TRACE, CoreDq) << "Infer statistics for table: " << path.Value() << ", nrows: " << nRows << ", nattrs: " << nAttrs << ", nKeyColumns: " << tableData.Metadata->KeyColumnNames.size();

auto outputStats = TOptimizerStatistics(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, 0.0, tableData.Metadata->KeyColumnNames);
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, 0.0, tableData.Metadata->KeyColumnNames));
}

/**
Expand All @@ -84,8 +82,7 @@ void InferStatisticsForSteamLookup(const TExprNode::TPtr& input, TTypeAnnotation
auto inputStats = typeCtx->GetStats(streamLookup.Table().Raw());
auto byteSize = inputStats->ByteSize * (nAttrs / (double) inputStats->Ncols);

auto outputStats = TOptimizerStatistics(EStatisticsType::BaseTable, inputStats->Nrows, nAttrs, byteSize, 0, inputStats->KeyColumns);
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(EStatisticsType::BaseTable, inputStats->Nrows, nAttrs, byteSize, 0, inputStats->KeyColumns));
}

/**
Expand Down Expand Up @@ -116,8 +113,7 @@ void InferStatisticsForLookupTable(const TExprNode::TPtr& input, TTypeAnnotation
byteSize = 10;
}

auto outputStats = TOptimizerStatistics(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, 0, inputStats->KeyColumns);
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, 0, inputStats->KeyColumns));
}

/**
Expand Down Expand Up @@ -151,17 +147,15 @@ void InferStatisticsForRowsSourceSettings(const TExprNode::TPtr& input, TTypeAnn
double cost = inputStats->Cost;
double byteSize = inputStats->ByteSize * (nAttrs / (double)inputStats->Ncols);

auto outputStats = TOptimizerStatistics(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, cost, inputStats->KeyColumns);
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(EStatisticsType::BaseTable, nRows, nAttrs, byteSize, cost, inputStats->KeyColumns));
}

/**
* Compute statistics for index lookup
* Currently we just make up a number for cardinality (5) and set cost to 0
*/
void InferStatisticsForIndexLookup(const TExprNode::TPtr& input, TTypeAnnotationContext* typeCtx) {
auto outputStats = TOptimizerStatistics(EStatisticsType::BaseTable, 5, 5, 20, 0.0);
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(outputStats));
typeCtx->SetStats(input.Get(), std::make_shared<TOptimizerStatistics>(EStatisticsType::BaseTable, 5, 5, 20, 0.0));
}

/***
Expand Down
4 changes: 2 additions & 2 deletions ydb/core/kqp/opt/logical/kqp_opt_cbo.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct TKqpRelOptimizerNode : public NYql::TRelOptimizerNode {
/**
* KQP Specific cost function and join applicability cost function
*/
struct TKqpProviderContext : public NYql::IProviderContext {
struct TKqpProviderContext : public NYql::TBaseProviderContext {
TKqpProviderContext(const TKqpOptimizeContext& kqpCtx, const int optLevel) : KqpCtx(kqpCtx), OptLevel(optLevel) {}

virtual bool IsJoinApplicable(const std::shared_ptr<NYql::IBaseOptimizerNode>& left,
Expand All @@ -35,4 +35,4 @@ struct TKqpProviderContext : public NYql::IProviderContext {
int OptLevel;
};

}
}
119 changes: 119 additions & 0 deletions ydb/library/yql/core/cbo/cbo_optimizer_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,123 @@ void TJoinOptimizerNode::Print(std::stringstream& stream, int ntabs) {
RightArg->Print(stream, ntabs+1);
}

bool IsPKJoin(const TOptimizerStatistics& stats, const TVector<TString>& joinKeys) {
if (stats.KeyColumns.size()==0) {
return false;
}

for(size_t i=0; i<stats.KeyColumns.size(); i++){
if (std::find(joinKeys.begin(), joinKeys.end(), stats.KeyColumns[i]) == joinKeys.end()) {
return false;
}
}
return true;
}

bool TBaseProviderContext::IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
const std::shared_ptr<IBaseOptimizerNode>& right,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) {

Y_UNUSED(left);
Y_UNUSED(right);
Y_UNUSED(joinConditions);
Y_UNUSED(leftJoinKeys);
Y_UNUSED(rightJoinKeys);

return joinAlgo == EJoinAlgoType::MapJoin;
}

double TBaseProviderContext::ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, EJoinAlgoType joinAlgo) const {
Y_UNUSED(outputByteSize);
Y_UNUSED(joinAlgo);
return leftStats.Nrows + 2.0 * rightStats.Nrows + outputRows;
}

/**
* Compute the cost and output cardinality of a join
*
* Currently a very basic computation targeted at GraceJoin
*
* The build is on the right side, so we make the build side a bit more expensive than the probe
*/
TOptimizerStatistics TBaseProviderContext::ComputeJoinStats(
const TOptimizerStatistics& leftStats,
const TOptimizerStatistics& rightStats,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
EJoinAlgoType joinAlgo) const
{
TVector<TString> leftJoinKeys;
TVector<TString> rightJoinKeys;

for (auto c : joinConditions) {
leftJoinKeys.emplace_back(c.first.AttributeName);
rightJoinKeys.emplace_back(c.second.AttributeName);
}

return ComputeJoinStats(leftStats, rightStats, leftJoinKeys, rightJoinKeys, joinAlgo);
}

TOptimizerStatistics TBaseProviderContext::ComputeJoinStats(
const TOptimizerStatistics& leftStats,
const TOptimizerStatistics& rightStats,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) const
{
double newCard;
EStatisticsType outputType;
bool leftKeyColumns = false;
bool rightKeyColumns = false;
double selectivity = 1.0;


if (IsPKJoin(rightStats,rightJoinKeys)) {
newCard = leftStats.Nrows * rightStats.Selectivity;
selectivity = leftStats.Selectivity * rightStats.Selectivity;
leftKeyColumns = true;
if (leftStats.Type == EStatisticsType::BaseTable){
outputType = EStatisticsType::FilteredFactTable;
} else {
outputType = leftStats.Type;
}
}
else if (IsPKJoin(leftStats,leftJoinKeys)) {
newCard = rightStats.Nrows;
newCard = rightStats.Nrows * leftStats.Selectivity;
selectivity = leftStats.Selectivity * rightStats.Selectivity;

rightKeyColumns = true;
if (rightStats.Type == EStatisticsType::BaseTable){
outputType = EStatisticsType::FilteredFactTable;
} else {
outputType = rightStats.Type;
}
}
else {
newCard = 0.2 * leftStats.Nrows * rightStats.Nrows;
outputType = EStatisticsType::ManyManyJoin;
}

int newNCols = leftStats.Ncols + rightStats.Ncols;
double newByteSize = leftStats.Nrows ? (leftStats.ByteSize / leftStats.Nrows) * newCard : 0 +
rightStats.Nrows ? (rightStats.ByteSize / rightStats.Nrows) * newCard : 0;

double cost = ComputeJoinCost(leftStats, rightStats, newCard, newByteSize, joinAlgo)
+ leftStats.Cost + rightStats.Cost;

auto result = TOptimizerStatistics(outputType, newCard, newNCols, newByteSize, cost,
leftKeyColumns ? leftStats.KeyColumns : ( rightKeyColumns ? rightStats.KeyColumns : TOptimizerStatistics::EmptyColumns));
result.Selectivity = selectivity;
return result;
}

const TBaseProviderContext& TBaseProviderContext::instance() {
static TBaseProviderContext staticContext;
return staticContext;
}


} // namespace NYql
135 changes: 71 additions & 64 deletions ydb/library/yql/core/cbo/cbo_optimizer_new.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
#include <map>
#include <sstream>


namespace NYql {

/**
* OptimizerNodes are the internal representations of operators inside the
* Cost-based optimizer. Currently we only support RelOptimizerNode - a node that
* is an input relation to the equi-join, and JoinOptimizerNode - an inner join
* is an input relation to the equi-join, and JoinOptimizerNode - an inner join
* that connects two sets of relations.
*/
enum EOptimizerNodeKind: ui32
Expand All @@ -35,13 +34,76 @@ struct IBaseOptimizerNode {
std::shared_ptr<TOptimizerStatistics> Stats;

IBaseOptimizerNode(EOptimizerNodeKind k) : Kind(k) {}
IBaseOptimizerNode(EOptimizerNodeKind k, std::shared_ptr<TOptimizerStatistics> s) :
IBaseOptimizerNode(EOptimizerNodeKind k, std::shared_ptr<TOptimizerStatistics> s) :
Kind(k), Stats(s) {}

virtual TVector<TString> Labels()=0;
virtual void Print(std::stringstream& stream, int ntabs=0)=0;
};


/**
* This is a temporary structure for KQP provider
* We will soon be supporting multiple providers and we will need to design
* some interfaces to pass provider-specific context to the optimizer
*/
struct IProviderContext {
virtual ~IProviderContext() = default;

virtual double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, EJoinAlgoType joinAlgol) const = 0;

virtual TOptimizerStatistics ComputeJoinStats(
const TOptimizerStatistics& leftStats,
const TOptimizerStatistics& rightStats,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions, EJoinAlgoType joinAlgo) const = 0;

virtual TOptimizerStatistics ComputeJoinStats(
const TOptimizerStatistics& leftStats,
const TOptimizerStatistics& rightStats,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) const = 0;

virtual bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
const std::shared_ptr<IBaseOptimizerNode>& right,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) = 0;
};

/**
* Default provider context with default cost and stats computation.
*/

struct TBaseProviderContext : public IProviderContext {
TBaseProviderContext() {}

double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, EJoinAlgoType joinAlgo) const override;

bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
const std::shared_ptr<IBaseOptimizerNode>& right,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) override;

virtual TOptimizerStatistics ComputeJoinStats(
const TOptimizerStatistics& leftStats,
const TOptimizerStatistics& rightStats,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) const override;

virtual TOptimizerStatistics ComputeJoinStats(
const TOptimizerStatistics& leftStats,
const TOptimizerStatistics& rightStats,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
EJoinAlgoType joinAlgo) const override;

static const TBaseProviderContext& instance();
};

/**
* RelOptimizerNode adds a label to base class
* This is the label assinged to the input by equi-Join
Expand All @@ -52,9 +114,9 @@ struct TRelOptimizerNode : public IBaseOptimizerNode {
// Temporary solution to check if a LookupJoin is possible in KQP
//void* Expr;

TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats) :
TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats) :
IBaseOptimizerNode(RelNodeType, stats), Label(label) { }
//TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats, const TExprNode::TPtr expr) :
//TRelOptimizerNode(TString label, std::shared_ptr<TOptimizerStatistics> stats, const TExprNode::TPtr expr) :
// IBaseOptimizerNode(RelNodeType, stats), Label(label), Expr(expr) { }
virtual ~TRelOptimizerNode() {}

Expand All @@ -79,61 +141,6 @@ enum EJoinKind: ui32
EJoinKind ConvertToJoinKind(const TString& joinString);
TString ConvertToJoinString(const EJoinKind kind);

/**
* This is a temporary structure for KQP provider
* We will soon be supporting multiple providers and we will need to design
* some interfaces to pass provider-specific context to the optimizer
*/
struct IProviderContext {
virtual ~IProviderContext() = default;

virtual double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, EJoinAlgoType joinAlgol) const = 0;

virtual bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
const std::shared_ptr<IBaseOptimizerNode>& right,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) = 0;

};

/**
* Temporary solution for default provider context
*/

struct TDummyProviderContext : public IProviderContext {
TDummyProviderContext() {}

double ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, EJoinAlgoType joinAlgo) const override {
Y_UNUSED(outputByteSize);
Y_UNUSED(joinAlgo);
return leftStats.Nrows + 2.0 * rightStats.Nrows + outputRows;
}

bool IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
const std::shared_ptr<IBaseOptimizerNode>& right,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
const TVector<TString>& leftJoinKeys,
const TVector<TString>& rightJoinKeys,
EJoinAlgoType joinAlgo) override {

Y_UNUSED(left);
Y_UNUSED(right);
Y_UNUSED(joinConditions);
Y_UNUSED(leftJoinKeys);
Y_UNUSED(rightJoinKeys);

return joinAlgo == EJoinAlgoType::MapJoin;
}

static const TDummyProviderContext& instance() {
static TDummyProviderContext staticContext;
return staticContext;
}

};

/**
* JoinOptimizerNode records the left and right arguments of the join
* as well as the set of join conditions.
Expand All @@ -150,11 +157,11 @@ struct TJoinOptimizerNode : public IBaseOptimizerNode {
EJoinAlgoType JoinAlgo;
bool IsReorderable;

TJoinOptimizerNode(const std::shared_ptr<IBaseOptimizerNode>& left,
const std::shared_ptr<IBaseOptimizerNode>& right,
TJoinOptimizerNode(const std::shared_ptr<IBaseOptimizerNode>& left,
const std::shared_ptr<IBaseOptimizerNode>& right,
const std::set<std::pair<NDq::TJoinColumn, NDq::TJoinColumn>>& joinConditions,
const EJoinKind joinType,
const EJoinAlgoType joinAlgo,
const EJoinKind joinType,
const EJoinAlgoType joinAlgo,
bool nonReorderable=false);
virtual ~TJoinOptimizerNode() {}
virtual TVector<TString> Labels();
Expand Down
Loading
Loading