Skip to content

Commit

Permalink
[fix](Nereids) fix or to in rule (#23940)
Browse files Browse the repository at this point in the history
or expression context can't propagation cross or expression.

for example:
```
select (a = 1 or a = 2 or a = 3) + (a = 4 or a = 5 or a = 6)
= select a in [1, 2, 3] + a in [4,5,6]
!= select a in [1, 2, 3] + a in [1, 2, 3, 4, 5, 6]
```
  • Loading branch information
keanji-x authored and xiaokang committed Sep 6, 2023
1 parent aec036e commit ab228fd
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@

import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rules.OrToIn.OrToInContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
Expand Down Expand Up @@ -57,7 +55,7 @@
* adding any additional rule-specific fields to the default ExpressionRewriteContext. However, the entire expression
* rewrite framework always passes an ExpressionRewriteContext of type context to all rules.
*/
public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
ExpressionRewriteRule<ExpressionRewriteContext> {

public static final OrToIn INSTANCE = new OrToIn();
Expand All @@ -66,52 +64,47 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements

@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return expr.accept(this, new OrToInContext());
return expr.accept(this, null);
}

@Override
public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, OrToInContext context) {
if (compoundPredicate instanceof And) {
return compoundPredicate.withChildren(compoundPredicate.child(0).accept(new OrToIn(),
new OrToInContext()),
compoundPredicate.child(1).accept(new OrToIn(),
new OrToInContext()));
}
List<Expression> expressions = ExpressionUtils.extractDisjunction(compoundPredicate);
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
for (Expression expression : expressions) {
if (expression instanceof EqualTo) {
addSlotToLiteralMap((EqualTo) expression, context);
addSlotToLiteralMap((EqualTo) expression, slotNameToLiteral);
}
}
List<Expression> rewrittenOr = new ArrayList<>();
for (Map.Entry<NamedExpression, Set<Literal>> entry : context.slotNameToLiteral.entrySet()) {
for (Map.Entry<NamedExpression, Set<Literal>> entry : slotNameToLiteral.entrySet()) {
Set<Literal> literals = entry.getValue();
if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
InPredicate inPredicate = new InPredicate(entry.getKey(), ImmutableList.copyOf(entry.getValue()));
rewrittenOr.add(inPredicate);
}
}
for (Expression expression : expressions) {
if (!ableToConvertToIn(expression, context)) {
rewrittenOr.add(expression);
if (!ableToConvertToIn(expression, slotNameToLiteral)) {
rewrittenOr.add(expression.accept(this, null));
}
}

return ExpressionUtils.or(rewrittenOr);
}

private void addSlotToLiteralMap(EqualTo equal, OrToInContext context) {
private void addSlotToLiteralMap(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
Expression left = equal.left();
Expression right = equal.right();
if (left instanceof NamedExpression && right instanceof Literal) {
addSlotToLiteral((NamedExpression) left, (Literal) right, context);
addSlotToLiteral((NamedExpression) left, (Literal) right, slotNameToLiteral);
}
if (right instanceof NamedExpression && left instanceof Literal) {
addSlotToLiteral((NamedExpression) right, (Literal) left, context);
addSlotToLiteral((NamedExpression) right, (Literal) left, slotNameToLiteral);
}
}

private boolean ableToConvertToIn(Expression expression, OrToInContext context) {
private boolean ableToConvertToIn(Expression expression, Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
if (!(expression instanceof EqualTo)) {
return false;
}
Expand All @@ -126,24 +119,18 @@ private boolean ableToConvertToIn(Expression expression, OrToInContext context)
namedExpression = (NamedExpression) right;
}
return namedExpression != null
&& findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, context)
&& findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, slotNameToLiteral)
>= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD;
}

public void addSlotToLiteral(NamedExpression namedExpression, Literal literal, OrToInContext context) {
Set<Literal> literals = context.slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
public void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
literals.add(literal);
}

public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression, OrToInContext context) {
return context.slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size();
}

/**
* Context of OrToIn
*/
public static class OrToInContext {
public final Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();

public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression,
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
return slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
import java.util.List;
import java.util.Set;

public class OrToInTest extends ExpressionRewriteTestHelper {
class OrToInTest extends ExpressionRewriteTestHelper {

@Test
public void test1() {
void test1() {
String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expand All @@ -59,7 +59,7 @@ public void test1() {
}

@Test
public void test2() {
void test2() {
String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expand All @@ -68,7 +68,7 @@ public void test2() {
}

@Test
public void test3() {
void test3() {
String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expand All @@ -90,4 +90,23 @@ public void test3() {
}
}

@Test
void test4() {
String expr = "case when col = 1 or col = 2 or col = 3 then 1"
+ " when col = 4 or col = 5 or col = 6 then 1 else 0 end";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END",
rewritten.toSql());
}

@Test
void test5() {
String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))",
rewritten.toSql());
}

}

0 comments on commit ab228fd

Please sign in to comment.