Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature](Nereids): return residual expr of join #28760

Merged
merged 1 commit into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.exploration.mv.ComparisonResult;
import org.apache.doris.nereids.rules.exploration.mv.LogicalCompatibilityContext;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand All @@ -44,18 +45,21 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
* The graph is a join graph, whose node is the leaf plan and edge is a join operator.
Expand Down Expand Up @@ -268,11 +272,11 @@ private void makeFilterConflictRules(JoinEdge joinEdge) {
filterEdges.forEach(e -> {
if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) {
e.addRejectJoin(joinEdge);
e.addRejectEdge(joinEdge);
}
if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) {
e.addRejectJoin(joinEdge);
e.addRejectEdge(joinEdge);
}
});
}
Expand All @@ -289,19 +293,23 @@ private void makeJoinConflictRules(JoinEdge edgeB) {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
}

for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
}
}
edgeB.setLeftExtendedNodes(leftRequired);
Expand Down Expand Up @@ -593,57 +601,75 @@ public int edgeSize() {
* compare hypergraph
*
* @param viewHG the compared hyper graph
* @return null represents not compatible, or return some expression which can
* be pull up from this hyper graph
* @return Comparison result
*/
public @Nullable List<Expression> isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) {
Map<Edge, Edge> queryToView = constructEdgeMap(viewHG, ctx.getQueryToViewEdgeExpressionMapping());
public ComparisonResult isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) {
// 1 try to construct a map which can be mapped from edge to edge
Map<Edge, Edge> queryToView = constructMapWithNode(viewHG, ctx.getQueryToViewNodeIDMapping());

// All edge in view must have a mapped edge in query
if (queryToView.size() != viewHG.edgeSize()) {
return null;
// 2. compare them by expression and extract residual expr
ComparisonResult.Builder builder = new ComparisonResult.Builder();
ComparisonResult edgeCompareRes = compareEdgesWithExpr(queryToView, ctx.getQueryToViewEdgeExpressionMapping());
if (edgeCompareRes.isInvalid()) {
return ComparisonResult.INVALID;
}
builder.addComparisonResult(edgeCompareRes);

boolean allMatch = queryToView.entrySet().stream()
.allMatch(entry ->
compareEdgeWithNode(entry.getKey(), entry.getValue(), ctx.getQueryToViewNodeIDMapping()));
if (!allMatch) {
return null;
// 3. pull join edge of view is no sense, so reject them
if (!queryToView.values().containsAll(viewHG.joinEdges)) {
return ComparisonResult.INVALID;
}

// join edges must be identical
boolean isJoinIdentical = joinEdges.stream()
.allMatch(queryToView::containsKey);
if (!isJoinIdentical) {
return null;
// 4. process residual edges
List<Expression> residualQueryJoin =
processOrphanEdges(Sets.difference(Sets.newHashSet(joinEdges), queryToView.keySet()));
if (residualQueryJoin == null) {
return ComparisonResult.INVALID;
}
builder.addQueryExpressions(residualQueryJoin);

// extract all top filters
List<FilterEdge> residualFilterEdges = filterEdges.stream()
.filter(e -> !queryToView.containsKey(e))
.collect(ImmutableList.toImmutableList());
if (residualFilterEdges.stream().anyMatch(e -> !e.isTopFilter())) {
return null;
List<Expression> residualQueryFilter =
processOrphanEdges(Sets.difference(Sets.newHashSet(filterEdges), queryToView.keySet()));
if (residualQueryFilter == null) {
return ComparisonResult.INVALID;
}
return residualFilterEdges.stream()
.flatMap(e -> e.getExpressions().stream())
.collect(ImmutableList.toImmutableList());
builder.addQueryExpressions(residualQueryFilter);

List<Expression> residualViewFilter =
processOrphanEdges(
Sets.difference(Sets.newHashSet(viewHG.filterEdges), Sets.newHashSet(queryToView.values())));
if (residualViewFilter == null) {
return ComparisonResult.INVALID;
}
builder.addViewExpressions(residualViewFilter);

return builder.build();
}

private Map<Edge, Edge> constructEdgeMap(HyperGraph viewHG, Map<Expression, Expression> exprMap) {
Map<Expression, Edge> exprToEdge = constructExprMap(viewHG);
Map<Edge, Edge> queryToView = new HashMap<>();
joinEdges.stream()
.filter(e -> !e.getExpressions().isEmpty()
&& exprMap.containsKey(e.getExpression(0))
&& compareEdgeWithExpr(e, exprToEdge.get(exprMap.get(e.getExpression(0))), exprMap))
.forEach(e -> queryToView.put(e, exprToEdge.get(exprMap.get(e.getExpression(0)))));
filterEdges.stream()
.filter(e -> !e.getExpressions().isEmpty()
&& exprMap.containsKey(e.getExpression(0))
&& compareEdgeWithExpr(e, exprToEdge.get(exprMap.get(e.getExpression(0))), exprMap))
.forEach(e -> queryToView.put(e, exprToEdge.get(exprMap.get(e.getExpression(0)))));
return queryToView;
private List<Expression> processOrphanEdges(Set<Edge> edges) {
List<Expression> expressions = new ArrayList<>();
for (Edge edge : edges) {
if (!edge.canPullUp()) {
return null;
}
expressions.addAll(edge.getExpressions());
}
return expressions;
}

private Map<Edge, Edge> constructMapWithNode(HyperGraph viewHG, Map<Integer, Integer> nodeMap) {
// TODO use hash map to reduce loop
Map<Edge, Edge> joinEdgeMap = joinEdges.stream().map(qe -> {
Optional<JoinEdge> viewEdge = viewHG.joinEdges.stream()
.filter(ve -> compareEdgeWithNode(qe, ve, nodeMap)).findFirst();
return Pair.of(qe, viewEdge);
}).filter(e -> e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second.get()));
Map<Edge, Edge> filterEdgeMap = filterEdges.stream().map(qe -> {
Optional<FilterEdge> viewEdge = viewHG.filterEdges.stream()
.filter(ve -> compareEdgeWithNode(qe, ve, nodeMap)).findFirst();
return Pair.of(qe, viewEdge);
}).filter(e -> e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second.get()));
return ImmutableMap.<Edge, Edge>builder().putAll(joinEdgeMap).putAll(filterEdgeMap).build();
}

private boolean compareEdgeWithNode(Edge t, Edge o, Map<Integer, Integer> nodeMap) {
Expand Down Expand Up @@ -686,24 +712,40 @@ private boolean compareNodeMap(long bitmap1, long bitmap2, Map<Integer, Integer>
return bitmap2 == newBitmap1;
}

private boolean compareEdgeWithExpr(Edge t, Edge o, Map<Expression, Expression> expressionMap) {
if (t.getExpressions().size() != o.getExpressions().size()) {
return false;
}
int size = t.getExpressions().size();
for (int i = 0; i < size; i++) {
if (!Objects.equals(expressionMap.get(t.getExpression(i)), o.getExpression(i))) {
return false;
private ComparisonResult compareEdgesWithExpr(Map<Edge, Edge> queryToViewedgeMap,
Map<Expression, Expression> queryToView) {
ComparisonResult.Builder builder = new ComparisonResult.Builder();
for (Entry<Edge, Edge> e : queryToViewedgeMap.entrySet()) {
ComparisonResult res = compareEdgeWithExpr(e.getKey(), e.getValue(), queryToView);
if (res.isInvalid()) {
return ComparisonResult.INVALID;
}
builder.addComparisonResult(res);
}
return true;
return builder.build();
}

private Map<Expression, Edge> constructExprMap(HyperGraph hyperGraph) {
Map<Expression, Edge> exprToEdge = new HashMap<>();
hyperGraph.joinEdges.forEach(edge -> edge.getExpressions().forEach(expr -> exprToEdge.put(expr, edge)));
hyperGraph.filterEdges.forEach(edge -> edge.getExpressions().forEach(expr -> exprToEdge.put(expr, edge)));
return exprToEdge;
private ComparisonResult compareEdgeWithExpr(Edge query, Edge view, Map<Expression, Expression> queryToView) {
Set<? extends Expression> queryExprSet = query.getExpressionSet();
Set<? extends Expression> viewExprSet = view.getExpressionSet();

Set<Expression> equalViewExpr = new HashSet<>();
List<Expression> residualQueryExpr = new ArrayList<>();
for (Expression queryExpr : queryExprSet) {
if (queryToView.containsKey(queryExpr) && viewExprSet.contains(queryToView.get(queryExpr))) {
equalViewExpr.add(queryToView.get(queryExpr));
} else {
residualQueryExpr.add(queryExpr);
}
}
List<Expression> residualViewExpr = ImmutableList.copyOf(Sets.difference(viewExprSet, equalViewExpr));
if (!residualViewExpr.isEmpty() && !view.canPullUp()) {
return ComparisonResult.INVALID;
}
if (!residualQueryExpr.isEmpty() && !query.canPullUp()) {
return ComparisonResult.INVALID;
}
return new ComparisonResult(residualQueryExpr, residualViewExpr);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;

import com.google.common.collect.ImmutableSet;

import java.util.BitSet;
import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -51,6 +53,8 @@ public abstract class Edge {
// record all sub nodes behind in this operator. It's T function in paper
private final long subTreeNodes;

private long rejectNodes = 0;

/**
* Create simple edge.
*/
Expand All @@ -71,6 +75,10 @@ public boolean isSimple() {
return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1;
}

public void addRejectEdge(Edge edge) {
rejectNodes = LongBitmap.newBitmapUnion(edge.getReferenceNodes(), rejectNodes);
}

public void addLeftExtendNode(long left) {
this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, left);
}
Expand Down Expand Up @@ -171,6 +179,20 @@ public double getSelectivity() {

public abstract List<? extends Expression> getExpressions();

public Set<? extends Expression> getExpressionSet() {
return ImmutableSet.copyOf(getExpressions());
}

public boolean canPullUp() {
// Only inner join and filter with none rejectNodes can be pull up
return rejectNodes == 0
&& !(this instanceof JoinEdge && !((JoinEdge) this).getJoinType().isInnerJoin());
}

public long getRejectNodes() {
return rejectNodes;
}

public Expression getExpression(int i) {
return getExpressions().get(i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Set;
Expand All @@ -32,25 +31,11 @@
*/
public class FilterEdge extends Edge {
private final LogicalFilter<? extends Plan> filter;
private final List<Integer> rejectEdges;

public FilterEdge(LogicalFilter<? extends Plan> filter, int index,
BitSet childEdges, long subTreeNodes, long childRequireNodes) {
super(index, childEdges, new BitSet(), subTreeNodes, childRequireNodes, 0L);
this.filter = filter;
rejectEdges = new ArrayList<>();
}

public void addRejectJoin(JoinEdge joinEdge) {
rejectEdges.add(joinEdge.getIndex());
}

public List<Integer> getRejectEdges() {
return rejectEdges;
}

public boolean isTopFilter() {
return rejectEdges.isEmpty();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,14 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
LogicalCompatibilityContext compatibilityContext =
LogicalCompatibilityContext.from(queryToViewTableMapping, queryToViewSlotMapping,
queryStructInfo, viewStructInfo);
List<Expression> pulledUpExpressions = StructInfo.isGraphLogicalEquals(queryStructInfo, viewStructInfo,
ComparisonResult comparisonResult = StructInfo.isGraphLogicalEquals(queryStructInfo, viewStructInfo,
compatibilityContext);
if (pulledUpExpressions == null) {
if (comparisonResult.isInvalid()) {
logger.debug(currentClassName + " graph logical is not equals so continue");
continue;
}
// TODO: Use set of list? And consider view expr
List<Expression> pulledUpExpressions = ImmutableList.copyOf(comparisonResult.getQueryExpressions());
// set pulled up expression to queryStructInfo predicates and update related predicates
if (!pulledUpExpressions.isEmpty()) {
queryStructInfo.addPredicates(pulledUpExpressions);
Expand Down
Loading