Skip to content

Commit

Permalink
[SPARK-33677][SQL] Skip LikeSimplification rule if pattern contains any escapeChar
Browse files Browse the repository at this point in the history
  • Loading branch information
luluorta committed Dec 7, 2020
1 parent e88f0d4 commit 65e0d63
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -543,27 +543,27 @@ object LikeSimplification extends Rule[LogicalPlan] {
// Regex for a LIKE pattern that contains neither `_` nor `%`. It is used as an extractor in
// `apply`, where a successful match binds the whole pattern string.
private val equalTo = "([^_%]*)".r

/**
 * Rewrites `Like` expressions whose pattern is a string literal into simpler string predicates
 * (`StartsWith`, `EndsWith`, `Contains`, `EqualTo`, or a combination guarded by `Length`).
 * A null pattern becomes a null boolean literal; a pattern that contains the escape character,
 * or that matches none of the simple shapes, leaves the `Like` expression unchanged.
 *
 * The `startsWith`, `endsWith`, `startsAndEndsWith` and `contains` extractors used below are
 * defined outside this view.
 *
 * NOTE(review): this listing looks like a rendered diff whose +/- markers were lost, so several
 * statements appear twice (old form immediately followed by the new form). The hunk header
 * (27 lines before, 27 lines after) is consistent with that reading — confirm against the
 * repository before treating this text as compilable code.
 *
 * @param plan the logical plan whose expressions are transformed
 * @return the plan with eligible `Like` expressions replaced
 */
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// NOTE(review): appears to be the pre-change form of the next line (no binder for the match).
case Like(input, Literal(pattern, StringType), escapeChar) =>
// Binds the whole matched expression as `l` so the fall-through cases can return it as-is.
case l @ Like(input, Literal(pattern, StringType), escapeChar) =>
if (pattern == null) {
// If pattern is null, return null value directly, since "col like null" == null.
Literal(null, BooleanType)
} else {
// NOTE(review): appears to be pre-change code; only the `endsWith(escapeStr)` guards use it.
val escapeStr = String.valueOf(escapeChar)
pattern.toString match {
// NOTE(review): pre-change guard — it only rejected a prefix ending with the escape character.
case startsWith(prefix) if !prefix.endsWith(escapeStr) =>
// Leave the expression untouched when the escape character occurs anywhere in the pattern.
// This case comes first, so the cases below only see escape-free patterns.
case p if p.contains(escapeChar) => l
case startsWith(prefix) =>
StartsWith(input, Literal(prefix))
case endsWith(postfix) =>
EndsWith(input, Literal(postfix))
// 'a%a' pattern is basically same with 'a%' && '%a'.
// However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
// NOTE(review): the next line appears to be the pre-change form of the case that follows it.
case startsAndEndsWith(prefix, postfix) if !prefix.endsWith(escapeStr) =>
case startsAndEndsWith(prefix, postfix) =>
And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)),
And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
// NOTE(review): the next line appears to be the pre-change form of the case that follows it.
case contains(infix) if !infix.endsWith(escapeStr) =>
case contains(infix) =>
Contains(input, Literal(infix))
case equalTo(str) =>
EqualTo(input, Literal(str))
// NOTE(review): pre-change fallback — it rebuilt a `Like` from the same parts instead of
// returning the matched expression.
case _ => Like(input, Literal.create(pattern, StringType), escapeChar)
// No simple shape matched: keep the original expression.
case _ => l
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,52 @@ class LikeSimplificationSuite extends PlanTest {
val optimized2 = Optimize.execute(originalQuery2.analyze)
comparePlans(optimized2, originalQuery2.analyze)
}

test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") {
  // Each entry pairs a filter condition with the condition expected after optimization.
  // Every condition ORs together two LIKE predicates: one whose pattern is free of its escape
  // character (expected to be rewritten) and one whose pattern contains it (expected to stay).
  val scenarios = Seq(
    // Backslash inside a prefix pattern, no explicit escape character given.
    (('a like "abc%") || ('a like "\\abc%"),
      StartsWith('a, "abc") || ('a like "\\abc%")),
    // Backslash at the very end of a suffix pattern.
    (('a like "%xyz") || ('a like "%xyz\\"),
      EndsWith('a, "xyz") || ('a like "%xyz\\")),
    // Explicit escape character '@' in a prefix-and-suffix pattern.
    (('a like ("@bc%def", '@')) || ('a like "abc%def"),
      ('a like ("@bc%def", '@')) ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))),
    // The wildcard '%' itself used as the escape character.
    (('a like "%mn%") || ('a like ("%mn%", '%')),
      Contains('a, "mn") || ('a like ("%mn%", '%'))),
    // An ordinary letter used as the escape character in a wildcard-free pattern.
    (('a like "abc") || ('a like ("abbc", 'b')),
      ('a === "abc") || ('a like ("abbc", 'b')))
  )

  scenarios.foreach { case (condition, expectedCondition) =>
    val actualPlan = Optimize.execute(testRelation.where(condition).analyze)
    val expectedPlan = testRelation.where(expectedCondition).analyze
    comparePlans(actualPlan, expectedPlan)
  }
}
}
14 changes: 14 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3718,6 +3718,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
}

test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") {
  withTable("string_tbl") {
    sql("CREATE TABLE string_tbl USING parquet SELECT 'm@ca' AS s")

    // With '%' declared as the escape character, the query is expected to fail with an
    // AnalysisException that reports the invalid pattern.
    val invalidPattern = intercept[AnalysisException] {
      sql("SELECT s LIKE 'm%@ca' ESCAPE '%' FROM string_tbl").collect()
    }
    val expectedMessage = "the pattern 'm%@ca' is invalid, " +
      "the escape character is not allowed to precede '@'"
    assert(invalidPattern.message.contains(expectedMessage))

    // With '@' declared as the escape character, the stored value 'm@ca' is expected to match.
    checkAnswer(sql("SELECT s LIKE 'm@@ca' ESCAPE '@' FROM string_tbl"), Row(true))
  }
}
}

// Case class with a single optional String field `bar`. Its usages are outside this view
// (presumably a fixture for tests in this file — confirm).
case class Foo(bar: Option[String])

0 comments on commit 65e0d63

Please sign in to comment.