[SPARK-22384][SQL][FOLLOWUP] Refine partition pruning when attribute is wrapped in Cast

## What changes were proposed in this pull request?

As mentioned in #21586, `Cast.mayTruncate` is not 100% safe: it still permits string-to-boolean casts, which are not lossless. Since changing `Cast.mayTruncate` would also change the behavior of Dataset, this PR instead adds a new `Cast.canSafeCast` dedicated to partition pruning.
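A minimal sketch of the behavioral difference, assuming a build that includes this patch. Because `AtomicType` is package-private to `org.apache.spark.sql`, the sketch declares itself inside that package; the object name `CanSafeCastDemo` is hypothetical and not part of the commit:

```scala
// Hypothetical demo, not part of this commit. It must live under the
// org.apache.spark.sql package because AtomicType is package-private.
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

object CanSafeCastDemo {
  def main(args: Array[String]): Unit = {
    // The old guard: mayTruncate(StringType, BooleanType) is false, so the
    // Hive shim used to treat cast(stringCol as boolean) as safe to push down.
    println(Cast.mayTruncate(StringType, BooleanType)) // false -- the loophole

    // The new whitelist accepts only provably lossless casts:
    println(Cast.canSafeCast(IntegerType, LongType))   // true  (widening numeric)
    println(Cast.canSafeCast(DateType, TimestampType)) // true
    println(Cast.canSafeCast(IntegerType, StringType)) // true  (anything -> string)
    println(Cast.canSafeCast(StringType, BooleanType)) // false (the unsafe case)
    println(Cast.canSafeCast(LongType, IntegerType))   // false (narrowing)
  }
}
```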

## How was this patch tested?

New test cases.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #21712 from cloud-fan/safeCast.
cloud-fan authored and gatorsmile committed Jul 5, 2018
1 parent ca8243f commit bf764a3
Showing 3 changed files with 41 additions and 4 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

@@ -134,6 +134,26 @@ object Cast {
    toPrecedence > 0 && fromPrecedence > toPrecedence
  }

  /**
   * Returns true iff we can safely cast the `from` type to the `to` type without any truncation
   * or precision loss, e.g. int -> long, date -> timestamp.
   */
  def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match {
    case _ if from == to => true
    case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true
    case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true
    case (from, to) if legalNumericPrecedence(from, to) => true
    case (DateType, TimestampType) => true
    case (_, StringType) => true
    case _ => false
  }

  private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = {
    val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
    val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
    fromPrecedence >= 0 && fromPrecedence < toPrecedence
  }

  def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
    case (NullType, _) => true
    case (_, _) if from == to => false
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala

@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

@@ -660,7 +660,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
    def unapply(expr: Expression): Option[Attribute] = {
      expr match {
        case attr: Attribute => Some(attr)
-       case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
+       case Cast(child @ AtomicType(), dt: AtomicType, _)
+           if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
        case _ => None
      }
    }
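For illustration, a self-contained sketch of the extractor pattern above (the `LosslessAttr` and `LosslessAttrDemo` names are hypothetical; the real extractor is defined inside the shim). A `Cast` is stripped only when `Cast.canSafeCast` proves it lossless; anything else rejects the whole predicate:

```scala
// Hypothetical sketch; placed under org.apache.spark.sql.hive so that the
// package-private AtomicType is visible.
package org.apache.spark.sql.hive

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, Expression}
import org.apache.spark.sql.types._

object LosslessAttr {
  def unapply(expr: Expression): Option[Attribute] = expr match {
    case attr: Attribute => Some(attr)
    // Same effect as the `child @ AtomicType()` pattern above, written with an
    // explicit type check so the sketch does not rely on the private extractor.
    case Cast(child, dt: AtomicType, _) if child.dataType.isInstanceOf[AtomicType] &&
        Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) =>
      unapply(child)
    case _ => None
  }
}

object LosslessAttrDemo {
  def main(args: Array[String]): Unit = {
    val ds = AttributeReference("ds", IntegerType)()      // int partition column
    val chunk = AttributeReference("chunk", StringType)() // string partition column

    println(LosslessAttr.unapply(Cast(ds, LongType)))       // Some(ds#...): lossless
    println(LosslessAttr.unapply(Cast(chunk, BooleanType))) // None: may change values
  }
}
```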
sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala

@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.LongType
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType}

// TODO: Refactor this to `HivePartitionFilteringSuite`
class HiveClientSuite(version: String)
@@ -122,6 +122,22 @@ class HiveClientSuite(version: String)
"aa" :: Nil)
}

test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") {
testMetastorePartitionFiltering(
attr("chunk").cast(IntegerType) === 1,
20170101 to 20170103,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
}

test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") {
testMetastorePartitionFiltering(
attr("chunk").cast(BooleanType) === true,
20170101 to 20170103,
0 to 23,
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
}

test("getPartitionsByFilter: 20170101=ds") {
testMetastorePartitionFiltering(
Literal(20170101) === attr("ds"),
@@ -138,7 +154,7 @@
"aa" :: "ab" :: "ba" :: "bb" :: Nil)
}

test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") {
test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") {
    testMetastorePartitionFiltering(
      attr("ds").cast(LongType) === 20170101L && attr("h") === 10,
      20170101 to 20170101,
