From 65e0d63ad8bb5db14791f79d8a4e2572be493d3b Mon Sep 17 00:00:00 2001 From: luluorta Date: Sun, 6 Dec 2020 20:40:46 +0800 Subject: [PATCH 1/3] [SPARK-33677][SQL] Skip LikeSimplification rule if pattern contains any escapeChar --- .../sql/catalyst/optimizer/expressions.scala | 12 ++--- .../optimizer/LikeSimplificationSuite.scala | 48 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 14 ++++++ 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1b1e2ad71e7c8..ca790f375075e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -543,27 +543,27 @@ object LikeSimplification extends Rule[LogicalPlan] { private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(input, Literal(pattern, StringType), escapeChar) => + case l @ Like(input, Literal(pattern, StringType), escapeChar) => if (pattern == null) { // If pattern is null, return null value directly, since "col like null" == null. Literal(null, BooleanType) } else { - val escapeStr = String.valueOf(escapeChar) pattern.toString match { - case startsWith(prefix) if !prefix.endsWith(escapeStr) => + case p if p.contains(escapeChar) => l + case startsWith(prefix) => StartsWith(input, Literal(prefix)) case endsWith(postfix) => EndsWith(input, Literal(postfix)) // 'a%a' pattern is basically same with 'a%' && '%a'. // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. - case startsAndEndsWith(prefix, postfix) if !prefix.endsWith(escapeStr) => + case startsAndEndsWith(prefix, postfix) => And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)), And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) - case contains(infix) if !infix.endsWith(escapeStr) => + case contains(infix) => Contains(input, Literal(infix)) case equalTo(str) => EqualTo(input, Literal(str)) - case _ => Like(input, Literal.create(pattern, StringType), escapeChar) + case _ => l } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index 436f62e4225c8..1812dce0da426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -116,4 +116,52 @@ class LikeSimplificationSuite extends PlanTest { val optimized2 = Optimize.execute(originalQuery2.analyze) comparePlans(optimized2, originalQuery2.analyze) } + + test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") { + val originalQuery1 = + testRelation + .where(('a like "abc%") || ('a like "\\abc%")) + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = testRelation + .where(StartsWith('a, "abc") || ('a like "\\abc%")) + .analyze + comparePlans(optimized1, correctAnswer1) + + val originalQuery2 = + testRelation + .where(('a like "%xyz") || ('a like "%xyz\\")) + val optimized2 = Optimize.execute(originalQuery2.analyze) + val correctAnswer2 = testRelation + .where(EndsWith('a, "xyz") || ('a like "%xyz\\")) + .analyze + comparePlans(optimized2, correctAnswer2) + + val originalQuery3 = + testRelation + .where(('a like ("@bc%def", '@')) || ('a like "abc%def")) + val optimized3 = Optimize.execute(originalQuery3.analyze) + val correctAnswer3 = testRelation + .where(('a like ("@bc%def", '@')) || + (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .analyze + comparePlans(optimized3, correctAnswer3) + + val originalQuery4 = + testRelation + .where(('a like "%mn%") || ('a like ("%mn%", '%'))) + val optimized4 = Optimize.execute(originalQuery4.analyze) + val correctAnswer4 = testRelation + .where(Contains('a, "mn") || ('a like ("%mn%", '%'))) + .analyze + comparePlans(optimized4, correctAnswer4) + + val originalQuery5 = + testRelation + .where(('a like "abc") || ('a like ("abbc", 'b'))) + val optimized5 = Optimize.execute(originalQuery5.analyze) + val correctAnswer5 = testRelation + .where(('a === "abc") || ('a like ("abbc", 'b'))) + .analyze + comparePlans(optimized5, correctAnswer5) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 727482e551a8b..24248c9869595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3718,6 +3718,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") { + withTable("string_tbl") { + sql("CREATE TABLE string_tbl USING parquet SELECT 'm@ca' AS s") + + val e = intercept[AnalysisException] { + sql("SELECT s LIKE 'm%@ca' ESCAPE '%' FROM string_tbl").collect() + } + assert(e.message.contains("the pattern 'm%@ca' is invalid, " + + "the escape character is not allowed to precede '@'")) + + checkAnswer(sql("SELECT s LIKE 'm@@ca' ESCAPE '@' FROM string_tbl"), Row(true)) + } + } } case class Foo(bar: Option[String]) From 765591c0030ac8fad6833c8f8e78c8169422b1ec Mon Sep 17 00:00:00 2001 From: luluorta Date: Mon, 7 Dec 2020 20:35:05 +0800 Subject: [PATCH 2/3] add comment --- .../apache/spark/sql/catalyst/optimizer/expressions.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index ca790f375075e..b2fc334ac893e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -549,6 +549,12 @@ object LikeSimplification extends Rule[LogicalPlan] { Literal(null, BooleanType) } else { pattern.toString match { + // There are three different situations when pattern containing escapeChar: + // 1. pattern contains invalid escape sequence, e.g. 'm\aca' + // 2. pattern contains escaped wildcard character, e.g. 'ma\%ca' + // 3. pattern contains escaped escape character, e.g. 'ma\\ca' + // Although there are patterns can be optimized if we handle the escape first, we just + // skip this rule if pattern contains any escapeChar for simplicity. case p if p.contains(escapeChar) => l case startsWith(prefix) => StartsWith(input, Literal(prefix)) From 58b8bd28945a4f9e95a2da2707d62d6eb5f33fa9 Mon Sep 17 00:00:00 2001 From: luluorta Date: Tue, 8 Dec 2020 10:43:48 +0800 Subject: [PATCH 3/3] use temp view in test --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 24248c9869595..2eeb729ece3fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3720,16 +3720,16 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") { - withTable("string_tbl") { - sql("CREATE TABLE string_tbl USING parquet SELECT 'm@ca' AS s") + withTempView("df") { + Seq("m@ca").toDF("s").createOrReplaceTempView("df") val e = intercept[AnalysisException] { - sql("SELECT s LIKE 'm%@ca' ESCAPE '%' FROM string_tbl").collect() + sql("SELECT s LIKE 'm%@ca' ESCAPE '%' FROM df").collect() } assert(e.message.contains("the pattern 'm%@ca' is invalid, " + "the escape character is not allowed to precede '@'")) - checkAnswer(sql("SELECT s LIKE 'm@@ca' ESCAPE '@' FROM string_tbl"), Row(true)) + checkAnswer(sql("SELECT s LIKE 'm@@ca' ESCAPE '@' FROM df"), Row(true)) } } }