Skip to content

Commit

Permalink
Refactor the process of symbols in optimizer. (#4146)
Browse files Browse the repository at this point in the history
* Refactor the process of symbols in optimizer.

* Fix variable shadowing.

* Collect boundary from MatchResult directly.

* Check that variables in TransformResult are not referenced by outside plan
nodes.

Co-authored-by: Sophie <84560950+Sophie-Xie@users.noreply.github.com>
  • Loading branch information
Shylock-Hg and Sophie-Xie committed Apr 13, 2022
1 parent dad1373 commit 8dbcd2e
Show file tree
Hide file tree
Showing 28 changed files with 139 additions and 33 deletions.
1 change: 1 addition & 0 deletions src/graph/context/Symbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ SymbolTable::SymbolTable(ObjectPool* objPool) {

Variable* SymbolTable::newVariable(std::string name) {
VLOG(1) << "New variable for: " << name;
DCHECK(vars_.find(name) == vars_.end());
auto* variable = objPool_->makeAndAdd<Variable>(name);
addVar(std::move(name), variable);
return variable;
Expand Down
15 changes: 15 additions & 0 deletions src/graph/optimizer/OptGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,21 @@ OptGroup::OptGroup(OptContext *ctx) noexcept : ctx_(ctx) {
// Appends an already-constructed group node to this group.
//
// The first plan node seen by the group fixes the group's output variable;
// for every later node a debug-only check verifies that it reports the same
// output variable. After the node is stored, updateSymbols() is invoked on
// its plan node.
void OptGroup::addGroupNode(OptGroupNode *groupNode) {
  DCHECK(groupNode != nullptr);
  DCHECK(groupNode->group() == this);
  auto *planNode = groupNode->node();
  if (!outputVar_.empty()) {
    // Group already has an output variable: the newcomer must agree with it.
    DCHECK_EQ(outputVar_, planNode->outputVar());
  } else {
    outputVar_ = planNode->outputVar();
  }
  groupNodes_.emplace_back(groupNode);
  planNode->updateSymbols();
}

// Creates a new group node wrapping `node`, stores it in this group and
// returns it.
//
// The first plan node seen by the group fixes the group's output variable;
// for every later node a debug-only check verifies that it reports the same
// output variable.
OptGroupNode *OptGroup::makeGroupNode(PlanNode *node) {
  if (!outputVar_.empty()) {
    // Group already has an output variable: the newcomer must agree with it.
    DCHECK_EQ(outputVar_, node->outputVar());
  } else {
    outputVar_ = node->outputVar();
  }
  auto *created = OptGroupNode::create(ctx_, node, this);
  groupNodes_.emplace_back(created);
  return created;
}
Expand All @@ -70,16 +80,21 @@ Status OptGroup::explore(const OptRule *rule) {
NG_RETURN_IF_ERROR(groupNode->explore(rule));

// Find more equivalents
std::vector<OptGroup *> boundary;
auto status = rule->match(ctx_, groupNode);
if (!status.ok()) {
++iter;
continue;
}
ctx_->setChanged(true);
auto matched = std::move(status).value();
matched.collectBoundary(boundary);
auto resStatus = rule->transform(ctx_, matched);
NG_RETURN_IF_ERROR(resStatus);
auto result = std::move(resStatus).value();
DLOG_IF(WARNING, !result.checkDataFlow(boundary))
<< "Plan of transfromed result should keep input variable same with dependencies in rule "
<< rule->toString();
if (result.eraseAll) {
for (auto gnode : groupNodes_) {
gnode->node()->releaseSymbols();
Expand Down
5 changes: 5 additions & 0 deletions src/graph/optimizer/OptGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class OptGroup final {
Status exploreUntilMaxRound(const OptRule *rule);
double getCost() const;
const graph::PlanNode *getPlan() const;
// Read-only access to the output variable name recorded for this group.
const std::string &outputVar() const { return outputVar_; }

private:
explicit OptGroup(OptContext *ctx) noexcept;
Expand All @@ -54,6 +57,8 @@ class OptGroup final {
OptContext *ctx_{nullptr};
std::list<OptGroupNode *> groupNodes_;
std::vector<const OptRule *> exploredRules_;
// The output variable should be same across the whole group.
std::string outputVar_;
};

class OptGroupNode final {
Expand Down
52 changes: 52 additions & 0 deletions src/graph/optimizer/OptRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ const PlanNode *MatchedResult::planNode(const std::vector<int32_t> &pos) const {
return DCHECK_NOTNULL(result->node)->node();
}

// Appends to `boundary` the dependency groups found at the leaves of this
// matched result.
//
// A matched result with child results delegates to each child in order; a
// matched result without children contributes the dependencies of its own
// group node.
void MatchedResult::collectBoundary(std::vector<OptGroup *> &boundary) const {
  if (!dependencies.empty()) {
    for (const auto &child : dependencies) {
      child.collectBoundary(boundary);
    }
    return;
  }
  const auto &leafDeps = node->dependencies();
  boundary.insert(boundary.end(), leafDeps.begin(), leafDeps.end());
}

// Factory helper: forwards `kind` and `patterns` to the Pattern constructor
// and returns the resulting object.
Pattern Pattern::create(graph::PlanNode::Kind kind, std::initializer_list<Pattern> patterns) {
  Pattern created(kind, std::move(patterns));
  return created;
}
Expand Down Expand Up @@ -76,6 +86,48 @@ StatusOr<MatchedResult> Pattern::match(const OptGroup *group) const {
return Status::Error();
}

// Returns true only if every group node in `newGroupNodes` passes the
// per-node data-flow check against `boundary`. Stops at the first failure.
// An empty `newGroupNodes` yields true.
bool OptRule::TransformResult::checkDataFlow(const std::vector<OptGroup *> &boundary) {
  for (const OptGroupNode *candidate : newGroupNodes) {
    if (!checkDataFlow(candidate, boundary)) {
      return false;
    }
  }
  return true;
}

/*static*/ bool OptRule::TransformResult::checkDataFlow(const OptGroupNode *groupNode,
const std::vector<OptGroup *> &boundary) {
const auto &deps = groupNode->dependencies();
// reach the boundary
if (std::all_of(deps.begin(), deps.end(), [&boundary](OptGroup *dep) {
return std::find(boundary.begin(), boundary.end(), dep) != boundary.end();
})) {
return true;
}
const auto *group = groupNode->group();
if (std::find(boundary.begin(), boundary.end(), group) != boundary.end()) {
return true;
}
// Check dataflow
const auto *node = groupNode->node();
if (node->inputVars().size() == deps.size()) {
// Don't check when count of dependencies is different from count of input variables
for (std::size_t i = 0; i < deps.size(); i++) {
const OptGroup *dep = deps[i];
if (node->inputVar(i) != dep->outputVar()) {
return false;
}
// Only use by father plan node
if (node->inputVars()[i]->readBy.size() != 1) {
return false;
}
return std::all_of(
dep->groupNodes().begin(), dep->groupNodes().end(), [&boundary](const OptGroupNode *gn) {
return checkDataFlow(gn, boundary);
});
}
}
return true;
}

StatusOr<MatchedResult> OptRule::match(OptContext *ctx, const OptGroupNode *groupNode) const {
const auto &pattern = this->pattern();
auto status = pattern.match(groupNode);
Expand Down
7 changes: 7 additions & 0 deletions src/graph/optimizer/OptRule.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct MatchedResult {
// {0, 1, 0} | this->dependencies[1].dependencies[0]
// {0, 1, 0, 1} | this->dependencies[1].dependencies[0].dependencies[1]
const graph::PlanNode *planNode(const std::vector<int32_t> &pos = {}) const;

void collectBoundary(std::vector<OptGroup *> &boundary) const;
};

// Match plan node by trait or kind of plan node.
Expand Down Expand Up @@ -86,6 +88,11 @@ class OptRule {
return kNoTrans;
}

// The plan of result should keep dataflow same as dependencies
bool checkDataFlow(const std::vector<OptGroup *> &boundary);
static bool checkDataFlow(const OptGroupNode *groupNode,
const std::vector<OptGroup *> &boundary);

bool eraseCurr{false};
bool eraseAll{false};
std::vector<OptGroupNode *> newGroupNodes;
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/rule/CollapseProjectRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ StatusOr<OptRule::TransformResult> CollapseProjectRule::transform(

// 4. rebuild OptGroupNode
newProj->setInputVar(projBelow->inputVar());
newProj->setOutputVar(projAbove->outputVar());
auto* newGroupNode = OptGroupNode::create(octx, newProj, projGroup);
newGroupNode->setDeps(groupNodeBelow->dependencies());

Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/rule/GetEdgesTransformRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ StatusOr<OptRule::TransformResult> GetEdgesTransformRule::transform(

auto newAppendVertices = appendVertices->clone();
auto colSize = appendVertices->colNames().size();
newAppendVertices->setOutputVar(appendVertices->outputVar());
newAppendVertices->setColNames(
{appendVertices->colNames()[colSize - 2], appendVertices->colNames()[colSize - 1]});
auto newAppendVerticesGroupNode =
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/rule/IndexScanRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ StatusOr<OptRule::TransformResult> IndexScanRule::transform(OptContext* ctx,
const auto* oldIN = groupNode->node();
DCHECK_EQ(oldIN->kind(), graph::PlanNode::Kind::kIndexScan);
auto* newIN = static_cast<IndexScan*>(oldIN->clone());
newIN->setOutputVar(oldIN->outputVar());
newIN->setIndexQueryContext(std::move(iqctx));
auto newGroupNode = OptGroupNode::create(ctx, newIN, groupNode->group());
if (groupNode->dependencies().size() != 1) {
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/rule/MergeGetNbrsAndDedupRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ StatusOr<OptRule::TransformResult> MergeGetNbrsAndDedupRule::transform(
newGN->setDedup();
}
newGN->setInputVar(dedup->inputVar());
newGN->setOutputVar(gn->outputVar());
auto newOptGV = OptGroupNode::create(octx, newGN, optGN->group());
for (auto dep : optDedup->dependencies()) {
newOptGV->dependsOn(dep);
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/rule/MergeGetNbrsAndProjectRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ StatusOr<OptRule::TransformResult> MergeGetNbrsAndProjectRule::transform(
auto srcExpr = column->expr()->clone();
newGN->setSrc(srcExpr);
newGN->setInputVar(project->inputVar());
newGN->setOutputVar(gn->outputVar());
auto newOptGV = OptGroupNode::create(ctx, newGN, optGN->group());
for (auto dep : optProj->dependencies()) {
newOptGV->dependsOn(dep);
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/rule/MergeGetVerticesAndDedupRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ StatusOr<OptRule::TransformResult> MergeGetVerticesAndDedupRule::transform(
newGV->setDedup();
}
newGV->setInputVar(dedup->inputVar());
newGV->setOutputVar(gv->outputVar());
auto newOptGV = OptGroupNode::create(ctx, newGV, optGV->group());
for (auto dep : optDedup->dependencies()) {
newOptGV->dependsOn(dep);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ StatusOr<OptRule::TransformResult> MergeGetVerticesAndProjectRule::transform(
auto srcExpr = column->expr()->clone();
newGV->setSrc(srcExpr);
newGV->setInputVar(project->inputVar());
newGV->setOutputVar(gv->outputVar());
auto newOptGV = OptGroupNode::create(ctx, newGV, optGV->group());
for (auto dep : optProj->dependencies()) {
newOptGV->dependsOn(dep);
Expand Down
5 changes: 3 additions & 2 deletions src/graph/optimizer/rule/PushFilterDownGetNbrsRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ StatusOr<OptRule::TransformResult> PushFilterDownGetNbrsRule::transform(

auto remainedExpr = std::move(visitor).remainedExpr();
OptGroupNode *newFilterGroupNode = nullptr;
PlanNode *newFilter = nullptr;
if (remainedExpr != nullptr) {
auto newFilter = Filter::make(qctx, nullptr, remainedExpr);
newFilter = Filter::make(qctx, nullptr, remainedExpr);
newFilter->setOutputVar(filter->outputVar());
newFilter->setInputVar(filter->inputVar());
newFilterGroupNode = OptGroupNode::create(ctx, newFilter, filterGroupNode->group());
}

Expand All @@ -84,6 +84,7 @@ StatusOr<OptRule::TransformResult> PushFilterDownGetNbrsRule::transform(
// Filter(A&&B)<-GetNeighbors(C) => Filter(A)<-GetNeighbors(B&&C)
auto newGroup = OptGroup::create(ctx);
newGnGroupNode = newGroup->makeGroupNode(newGN);
newFilter->setInputVar(newGN->outputVar());
newFilterGroupNode->dependsOn(newGroup);
} else {
// Filter(A)<-GetNeighbors(C) => GetNeighbors(A&&C)
Expand Down
3 changes: 3 additions & 0 deletions src/graph/optimizer/rule/PushFilterDownProjectRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,15 @@ StatusOr<OptRule::TransformResult> PushFilterDownProjectRule::transform(
auto newProjGroup = OptGroup::create(octx);
auto newProjGroupNode = newProjGroup->makeGroupNode(newProjNode);
newProjGroupNode->setDeps({newBelowFilterGroup});
newProjNode->setInputVar(newBelowFilterNode->outputVar());
newAboveFilterGroupNode->setDeps({newProjGroup});
newAboveFilterNode->setInputVar(newProjNode->outputVar());
result.newGroupNodes.emplace_back(newAboveFilterGroupNode);
} else {
newProjNode->setOutputVar(oldFilterNode->outputVar());
auto newProjGroupNode = OptGroupNode::create(octx, newProjNode, filterGroupNode->group());
newProjGroupNode->setDeps({newBelowFilterGroup});
newProjNode->setInputVar(newBelowFilterNode->outputVar());
result.newGroupNodes.emplace_back(newProjGroupNode);
}

Expand Down
5 changes: 4 additions & 1 deletion src/graph/optimizer/rule/PushFilterDownScanVerticesRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ StatusOr<OptRule::TransformResult> PushFilterDownScanVerticesRule::transform(

auto remainedExpr = std::move(visitor).remainedExpr();
OptGroupNode *newFilterGroupNode = nullptr;
PlanNode *newFilter = nullptr;
if (remainedExpr != nullptr) {
auto newFilter = Filter::make(qctx, nullptr, remainedExpr);
newFilter = Filter::make(qctx, nullptr, remainedExpr);
newFilter->setOutputVar(filter->outputVar());
newFilter->setInputVar(filter->inputVar());
newFilterGroupNode = OptGroupNode::create(ctx, newFilter, filterGroupNode->group());
Expand All @@ -72,8 +73,10 @@ StatusOr<OptRule::TransformResult> PushFilterDownScanVerticesRule::transform(
OptGroupNode *newSvGroupNode = nullptr;
if (newFilterGroupNode != nullptr) {
// Filter(A&&B)<-ScanVertices(C) => Filter(A)<-ScanVertices(B&&C)
// newSV->regenerateOutputVar();
auto newGroup = OptGroup::create(ctx);
newSvGroupNode = newGroup->makeGroupNode(newSV);
newFilter->setInputVar(newSV->outputVar());
newFilterGroupNode->dependsOn(newGroup);
} else {
// Filter(A)<-ScanVertices(C) => ScanVertices(A&&C)
Expand Down
2 changes: 2 additions & 0 deletions src/graph/optimizer/rule/PushLimitDownGetNeighborsRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ StatusOr<OptRule::TransformResult> PushLimitDownGetNeighborsRule::transform(
}

auto newLimit = static_cast<Limit *>(limit->clone());
newLimit->setOutputVar(limit->outputVar());
auto newLimitGroupNode = OptGroupNode::create(octx, newLimit, limitGroupNode->group());

auto newGn = static_cast<GetNeighbors *>(gn->clone());
Expand All @@ -58,6 +59,7 @@ StatusOr<OptRule::TransformResult> PushLimitDownGetNeighborsRule::transform(
auto newGnGroupNode = newGnGroup->makeGroupNode(newGn);

newLimitGroupNode->dependsOn(newGnGroup);
newLimit->setInputVar(newGn->outputVar());
for (auto dep : gnGroupNode->dependencies()) {
newGnGroupNode->dependsOn(dep);
}
Expand Down
2 changes: 2 additions & 0 deletions src/graph/optimizer/rule/PushLimitDownIndexScanRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ StatusOr<OptRule::TransformResult> PushLimitDownIndexScanRule::transform(
}

auto newLimit = static_cast<Limit *>(limit->clone());
newLimit->setOutputVar(limit->outputVar());
auto newLimitGroupNode = OptGroupNode::create(octx, newLimit, limitGroupNode->group());

auto newIndexScan = static_cast<IndexScan *>(indexScan->clone());
Expand All @@ -66,6 +67,7 @@ StatusOr<OptRule::TransformResult> PushLimitDownIndexScanRule::transform(
auto newIndexScanGroupNode = newIndexScanGroup->makeGroupNode(newIndexScan);

newLimitGroupNode->dependsOn(newIndexScanGroup);
newLimit->setInputVar(newIndexScan->outputVar());
for (auto dep : indexScanGroupNode->dependencies()) {
newIndexScanGroupNode->dependsOn(dep);
}
Expand Down
6 changes: 3 additions & 3 deletions src/graph/optimizer/rule/PushLimitDownProjectRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ StatusOr<OptRule::TransformResult> PushLimitDownProjectRule::transform(

auto newLimit = static_cast<Limit *>(limit->clone());
auto newLimitGroup = OptGroup::create(octx);
auto newLimitGroupNode = newLimitGroup->makeGroupNode(newLimit);
auto projInputVar = proj->inputVar();
newLimit->setOutputVar(proj->outputVar());
// newLimit->regenerateOutputVar();
newLimit->setInputVar(projInputVar);
auto *varPtr = octx->qctx()->symTable()->getVar(projInputVar);
DCHECK(!!varPtr);
newLimit->setColNames(varPtr->colNames);
auto newLimitGroupNode = newLimitGroup->makeGroupNode(newLimit);

auto newProj = static_cast<Project *>(proj->clone());
auto newProjGroupNode = OptGroupNode::create(octx, newProj, limitGroupNode->group());
newProj->setOutputVar(limit->outputVar());
newProj->setInputVar(newLimit->outputVar());
auto newProjGroupNode = OptGroupNode::create(octx, newProj, limitGroupNode->group());

newProjGroupNode->dependsOn(const_cast<OptGroup *>(newLimitGroupNode->group()));
for (auto dep : projGroupNode->dependencies()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ StatusOr<OptRule::TransformResult> PushLimitDownScanAppendVerticesRule::transfor
}

auto newLimit = static_cast<Limit *>(limit->clone());
newLimit->setOutputVar(limit->outputVar());
auto newLimitGroupNode = OptGroupNode::create(octx, newLimit, limitGroupNode->group());

auto newAppendVertices = static_cast<AppendVertices *>(appendVertices->clone());
Expand All @@ -85,7 +86,9 @@ StatusOr<OptRule::TransformResult> PushLimitDownScanAppendVerticesRule::transfor
auto newScanVerticesGroupNode = newScanVerticesGroup->makeGroupNode(newScanVertices);

newLimitGroupNode->dependsOn(newAppendVerticesGroup);
newLimit->setInputVar(newAppendVertices->outputVar());
newAppendVerticesGroupNode->dependsOn(newScanVerticesGroup);
newAppendVertices->setInputVar(newScanVertices->outputVar());
for (auto dep : scanVerticesGroupNode->dependencies()) {
newScanVerticesGroupNode->dependsOn(dep);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ StatusOr<OptRule::TransformResult> PushLimitDownScanEdgesAppendVerticesRule::tra
}

auto newLimit = static_cast<Limit *>(limit->clone());
newLimit->setOutputVar(limit->outputVar());
auto newLimitGroupNode = OptGroupNode::create(octx, newLimit, limitGroupNode->group());

auto newAppendVertices = static_cast<AppendVertices *>(appendVertices->clone());
Expand All @@ -93,8 +94,11 @@ StatusOr<OptRule::TransformResult> PushLimitDownScanEdgesAppendVerticesRule::tra
auto newScanEdgesGroupNode = newScanEdgesGroup->makeGroupNode(newScanEdges);

newLimitGroupNode->dependsOn(newAppendVerticesGroup);
newLimit->setInputVar(newAppendVertices->outputVar());
newAppendVerticesGroupNode->dependsOn(newProjGroup);
newAppendVertices->setInputVar(newProj->outputVar());
newProjGroupNode->dependsOn(newScanEdgesGroup);
newProj->setInputVar(newScanEdges->outputVar());
for (auto dep : scanEdgesGroupNode->dependencies()) {
newScanEdgesGroupNode->dependsOn(dep);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ StatusOr<OptRule::TransformResult> PushStepLimitDownGetNeighborsRule::transform(
}

auto newLimit = static_cast<Limit *>(limit->clone());
newLimit->setOutputVar(limit->outputVar());
auto newLimitGroupNode = OptGroupNode::create(octx, newLimit, limitGroupNode->group());

auto newGn = static_cast<GetNeighbors *>(gn->clone());
Expand All @@ -59,6 +60,7 @@ StatusOr<OptRule::TransformResult> PushStepLimitDownGetNeighborsRule::transform(
auto newGnGroupNode = newGnGroup->makeGroupNode(newGn);

newLimitGroupNode->dependsOn(newGnGroup);
newLimit->setInputVar(newGn->outputVar());
for (auto dep : gnGroupNode->dependencies()) {
newGnGroupNode->dependsOn(dep);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ StatusOr<OptRule::TransformResult> PushStepSampleDownGetNeighborsRule::transform
}

auto newSample = static_cast<Sample *>(sample->clone());
newSample->setOutputVar(sample->outputVar());
auto newSampleGroupNode = OptGroupNode::create(octx, newSample, sampleGroupNode->group());

auto newGn = static_cast<GetNeighbors *>(gn->clone());
Expand All @@ -58,6 +59,7 @@ StatusOr<OptRule::TransformResult> PushStepSampleDownGetNeighborsRule::transform
auto newGnGroupNode = newGnGroup->makeGroupNode(newGn);

newSampleGroupNode->dependsOn(newGnGroup);
newSample->setInputVar(newGn->outputVar());
for (auto dep : gnGroupNode->dependencies()) {
newGnGroupNode->dependsOn(dep);
}
Expand Down
Loading

0 comments on commit 8dbcd2e

Please sign in to comment.