Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](Nereids) fix or to in rule #23940

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@

import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rules.OrToIn.OrToInContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
Expand Down Expand Up @@ -57,7 +55,7 @@
* adding any additional rule-specific fields to the default ExpressionRewriteContext. However, the entire expression
* rewrite framework always passes an ExpressionRewriteContext of type context to all rules.
*/
public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OrToInContext is useless? Can we remove it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have removed

ExpressionRewriteRule<ExpressionRewriteContext> {

public static final OrToIn INSTANCE = new OrToIn();
Expand All @@ -66,52 +64,47 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements

@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return expr.accept(this, new OrToInContext());
return expr.accept(this, null);
}

@Override
public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, OrToInContext context) {
if (compoundPredicate instanceof And) {
return compoundPredicate.withChildren(compoundPredicate.child(0).accept(new OrToIn(),
new OrToInContext()),
compoundPredicate.child(1).accept(new OrToIn(),
new OrToInContext()));
}
List<Expression> expressions = ExpressionUtils.extractDisjunction(compoundPredicate);
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
for (Expression expression : expressions) {
if (expression instanceof EqualTo) {
addSlotToLiteralMap((EqualTo) expression, context);
addSlotToLiteralMap((EqualTo) expression, slotNameToLiteral);
}
}
List<Expression> rewrittenOr = new ArrayList<>();
for (Map.Entry<NamedExpression, Set<Literal>> entry : context.slotNameToLiteral.entrySet()) {
for (Map.Entry<NamedExpression, Set<Literal>> entry : slotNameToLiteral.entrySet()) {
Set<Literal> literals = entry.getValue();
if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
InPredicate inPredicate = new InPredicate(entry.getKey(), ImmutableList.copyOf(entry.getValue()));
rewrittenOr.add(inPredicate);
}
}
for (Expression expression : expressions) {
if (!ableToConvertToIn(expression, context)) {
rewrittenOr.add(expression);
if (!ableToConvertToIn(expression, slotNameToLiteral)) {
rewrittenOr.add(expression.accept(this, null));
}
}

return ExpressionUtils.or(rewrittenOr);
}

private void addSlotToLiteralMap(EqualTo equal, OrToInContext context) {
private void addSlotToLiteralMap(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
Expression left = equal.left();
Expression right = equal.right();
if (left instanceof NamedExpression && right instanceof Literal) {
addSlotToLiteral((NamedExpression) left, (Literal) right, context);
addSlotToLiteral((NamedExpression) left, (Literal) right, slotNameToLiteral);
}
if (right instanceof NamedExpression && left instanceof Literal) {
addSlotToLiteral((NamedExpression) right, (Literal) left, context);
addSlotToLiteral((NamedExpression) right, (Literal) left, slotNameToLiteral);
}
}

private boolean ableToConvertToIn(Expression expression, OrToInContext context) {
private boolean ableToConvertToIn(Expression expression, Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
if (!(expression instanceof EqualTo)) {
return false;
}
Expand All @@ -126,24 +119,18 @@ private boolean ableToConvertToIn(Expression expression, OrToInContext context)
namedExpression = (NamedExpression) right;
}
return namedExpression != null
&& findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, context)
&& findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, slotNameToLiteral)
>= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD;
}

public void addSlotToLiteral(NamedExpression namedExpression, Literal literal, OrToInContext context) {
Set<Literal> literals = context.slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
public void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
literals.add(literal);
}

public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression, OrToInContext context) {
return context.slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size();
}

/**
* Context of OrToIn
*/
public static class OrToInContext {
public final Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();

public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression,
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
return slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
import java.util.List;
import java.util.Set;

public class OrToInTest extends ExpressionRewriteTestHelper {
class OrToInTest extends ExpressionRewriteTestHelper {

@Test
public void test1() {
void test1() {
String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expand All @@ -59,7 +59,7 @@ public void test1() {
}

@Test
public void test2() {
void test2() {
String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expand All @@ -68,7 +68,7 @@ public void test2() {
}

@Test
public void test3() {
void test3() {
String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expand All @@ -90,4 +90,23 @@ public void test3() {
}
}

@Test
void test4() {
String expr = "case when col = 1 or col = 2 or col = 3 then 1"
+ " when col = 4 or col = 5 or col = 6 then 1 else 0 end";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END",
rewritten.toSql());
}

@Test
void test5() {
String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))",
rewritten.toSql());
}

}