[SPARK-33477][SQL] Hive Metastore support filter by date type
### What changes were proposed in this pull request?

The Hive Metastore supports filter pushdown for string and integral types, and it can also support dates. See [HIVE-5679](apache/hive@5106bf1) for more details.

This PR adds that support, so that filters on date-type partition columns can be pushed down to the metastore.
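
As a minimal sketch of the conversion involved (assuming a Spark build matching this commit; the `"UTC"` time zone and the `datecol` name are illustrative), a date partition value, stored internally as days since the Unix epoch, is formatted back into a `yyyy-MM-dd` string that can be embedded in a Hive metastore filter:

```scala
import java.time.LocalDate

import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils}

// Illustrative values; the real code derives the time zone from the table's
// storage properties or the session-local time zone.
val timeZoneId = "UTC"
val dateFormatter = DateFormatter(DateTimeUtils.getZoneId(timeZoneId))

// Catalyst stores DateType values as days since 1970-01-01 (an Int).
val days = LocalDate.of(2019, 1, 1).toEpochDay.toInt

// Prints "2019-01-01"; a predicate such as `datecol = DATE '2019-01-01'`
// is then rendered as the metastore filter string "datecol = 2019-01-01".
println(dateFormatter.format(days))
```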

### Why are the changes needed?

Improve query performance: predicates on date-type partition columns can now be pushed down to the Hive Metastore, so only the matching partition metadata is fetched.
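
For example (a hypothetical session with Hive support enabled; table and column names are illustrative), a query with a date partition predicate no longer needs to list every partition:

```scala
// Hypothetical date-partitioned Hive table.
spark.sql("CREATE TABLE logs (value INT) PARTITIONED BY (dt DATE) STORED AS PARQUET")
spark.sql("INSERT INTO logs PARTITION (dt = '2019-01-01') VALUES (1)")
spark.sql("INSERT INTO logs PARTITION (dt = '2019-01-02') VALUES (2)")

// With this change the predicate can be sent to the metastore as the filter
// string "dt = 2019-01-01", so only the matching partition's metadata is
// fetched instead of all partitions being listed and pruned on the driver.
spark.sql("SELECT * FROM logs WHERE dt = DATE '2019-01-01'").show()
```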

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests, including new date-filter cases in `FiltersSuite` and `HivePartitionFilteringSuite` (see below).

Closes #30408 from wangyum/SPARK-33477.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
wangyum authored and HyukjinKwon committed Nov 25, 2020
1 parent c3ce970 commit 781e19c
Showing 7 changed files with 155 additions and 36 deletions.
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions}
import org.apache.spark.sql.hive.client.HiveClient
@@ -1264,11 +1264,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient {
val rawTable = getRawTable(db, table)
val catalogTable = restoreTableMetadata(rawTable)
val timeZoneId = CaseInsensitiveMap(catalogTable.storage.properties).getOrElse(
DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)

val partColNameMap = buildLowerCasePartColNameMap(catalogTable)

val clientPrunedPartitions =
client.getPartitionsByFilter(rawTable, predicates).map { part =>
client.getPartitionsByFilter(rawTable, predicates, timeZoneId).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId)
@@ -233,7 +233,8 @@ private[hive] trait HiveClient {
/** Returns partitions filtered by predicates for the given table. */
def getPartitionsByFilter(
catalogTable: CatalogTable,
predicates: Seq[Expression]): Seq[CatalogTablePartition]
predicates: Seq[Expression],
timeZoneId: String): Seq[CatalogTablePartition]

/** Loads a static partition into an existing table. */
def loadPartition(
@@ -733,9 +733,11 @@ private[hive] class HiveClientImpl(

override def getPartitionsByFilter(
table: CatalogTable,
predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState {
predicates: Seq[Expression],
timeZoneId: String): Seq[CatalogTablePartition] = withHiveState {
val hiveTable = toHiveTable(table, Some(userName))
val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition)
val parts = shim.getPartitionsByFilter(client, hiveTable, predicates, timeZoneId)
.map(fromHivePartition)
HiveCatalogMetrics.incrementFetchedPartitions(parts.length)
parts
}
@@ -45,9 +45,9 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType}
import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

@@ -79,7 +79,11 @@ private[client] sealed abstract class Shim {

def getAllPartitions(hive: Hive, table: Table): Seq[Partition]

def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]
def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression],
timeZoneId: String): Seq[Partition]

def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor

@@ -349,7 +353,8 @@ private[client] class Shim_v0_12 extends Shim with Logging {
override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
predicates: Seq[Expression],
timeZoneId: String): Seq[Partition] = {
// getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
// See HIVE-4888.
logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
@@ -632,7 +637,9 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
*
* Unsupported predicates are skipped.
*/
def convertFilters(table: Table, filters: Seq[Expression]): String = {
def convertFilters(table: Table, filters: Seq[Expression], timeZoneId: String): String = {
lazy val dateFormatter = DateFormatter(DateTimeUtils.getZoneId(timeZoneId))

/**
* An extractor that matches all binary comparison operators except null-safe equality.
*
@@ -650,6 +657,8 @@
case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs.
case Literal(value, _: IntegralType) => Some(value.toString)
case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString))
case Literal(value, _: DateType) =>
Some(dateFormatter.format(value.asInstanceOf[Int]))
case _ => None
}
}
@@ -700,6 +709,21 @@
}
}

object ExtractableDateValues {
private lazy val valueToLiteralString: PartialFunction[Any, String] = {
case value: Int => dateFormatter.format(value)
}

def unapply(values: Set[Any]): Option[Seq[String]] = {
val extractables = values.toSeq.map(valueToLiteralString.lift)
if (extractables.nonEmpty && extractables.forall(_.isDefined)) {
Some(extractables.map(_.get))
} else {
None
}
}
}

object SupportedAttribute {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
private val varcharKeys = table.getPartitionKeys.asScala
@@ -711,7 +735,8 @@
val resolver = SQLConf.get.resolver
if (varcharKeys.exists(c => resolver(c, attr.name))) {
None
} else if (attr.dataType.isInstanceOf[IntegralType] || attr.dataType == StringType) {
} else if (attr.dataType.isInstanceOf[IntegralType] || attr.dataType == StringType ||
attr.dataType == DateType) {
Some(attr.name)
} else {
None
@@ -748,6 +773,10 @@
convert(And(GreaterThanOrEqual(child, Literal(sortedValues.head, dataType)),
LessThanOrEqual(child, Literal(sortedValues.last, dataType))))

case InSet(child @ ExtractAttribute(SupportedAttribute(name)), ExtractableDateValues(values))
if useAdvanced && child.dataType == DateType =>
Some(convertInToOr(name, values))

case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values))
if useAdvanced =>
Some(convertInToOr(name, values))
@@ -803,11 +832,12 @@
override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
predicates: Seq[Expression],
timeZoneId: String): Seq[Partition] = {

// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key=\"value\" and int_key=1 ..."
val filter = convertFilters(table, predicates)
val filter = convertFilters(table, predicates, timeZoneId)

val partitions =
if (filter.isEmpty) {
@@ -17,6 +17,7 @@

package org.apache.spark.sql.hive.client

import java.sql.Date
import java.util.Collections

import org.apache.hadoop.hive.metastore.api.FieldSchema
@@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* A set of tests for the filter conversion logic used when pushing partition pruning into the
@@ -63,6 +65,28 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
(Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
"1 = intcol and \"a\" = strcol")

filterTest("date filter",
(a("datecol", DateType) === Literal(Date.valueOf("2019-01-01"))) :: Nil,
"datecol = 2019-01-01")

filterTest("date filter with IN predicate",
(a("datecol", DateType) in
(Literal(Date.valueOf("2019-01-01")), Literal(Date.valueOf("2019-01-07")))) :: Nil,
"(datecol = 2019-01-01 or datecol = 2019-01-07)")

filterTest("date and string filter",
(Literal(Date.valueOf("2019-01-01")) === a("datecol", DateType)) ::
(Literal("a") === a("strcol", IntegerType)) :: Nil,
"2019-01-01 = datecol and \"a\" = strcol")

filterTest("date filter with null",
(a("datecol", DateType) === Literal(null)) :: Nil,
"")

filterTest("string filter with InSet predicate",
InSet(a("strcol", StringType), Set("1", "2").map(s => UTF8String.fromString(s))) :: Nil,
"(strcol = \"1\" or strcol = \"2\")")

filterTest("skip varchar",
(Literal("") === a("varchar", StringType)) :: Nil,
"")
@@ -89,7 +113,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
private def filterTest(name: String, filters: Seq[Expression], result: String) = {
test(name) {
withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") {
val converted = shim.convertFilters(testTable, filters)
val converted = shim.convertFilters(testTable, filters, conf.sessionLocalTimeZone)
if (converted != result) {
fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'")
}
@@ -104,7 +128,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
val filters =
(Literal(1) === a("intcol", IntegerType) ||
Literal(2) === a("intcol", IntegerType)) :: Nil
val converted = shim.convertFilters(testTable, filters)
val converted = shim.convertFilters(testTable, filters, conf.sessionLocalTimeZone)
if (enabled) {
assert(converted == "(1 = intcol or 2 = intcol)")
} else {
@@ -116,7 +140,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {

test("SPARK-33416: Avoid Hive metastore stack overflow when InSet predicate have many values") {
def checkConverted(inSet: InSet, result: String): Unit = {
assert(shim.convertFilters(testTable, inSet :: Nil) == result)
assert(shim.convertFilters(testTable, inSet :: Nil, conf.sessionLocalTimeZone) == result)
}

withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD.key -> "15") {
@@ -139,6 +163,11 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
InSet(a("doublecol", DoubleType),
Range(1, 20).map(s => Literal(s.toDouble).eval(EmptyRow)).toSet),
"")

checkConverted(
InSet(a("datecol", DateType),
Range(1, 20).map(d => Literal(d, DateType).eval(EmptyRow)).toSet),
"(datecol >= 1970-01-02 and datecol <= 1970-01-20)")
}
}

@@ -17,6 +17,8 @@

package org.apache.spark.sql.hive.client

import java.sql.Date

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
@@ -28,7 +30,8 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DateType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.util.Utils

class HivePartitionFilteringSuite(version: String)
@@ -38,15 +41,16 @@ class HivePartitionFilteringSuite(version: String)

private val testPartitionCount = 3 * 5 * 4

private def init(tryDirectSql: Boolean): HiveClient = {
val storageFormat = CatalogStorageFormat(
locationUri = None,
inputFormat = None,
outputFormat = None,
serde = None,
compressed = false,
properties = Map.empty)
private val storageFormat = CatalogStorageFormat(
locationUri = None,
inputFormat = Some(classOf[TextInputFormat].getName),
outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName),
serde = Some(classOf[LazySimpleSerDe].getName()),
compressed = false,
properties = Map.empty
)

private def init(tryDirectSql: Boolean): HiveClient = {
val hadoopConf = new Configuration()
hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
hadoopConf.set("hive.metastore.warehouse.dir", Utils.createTempDir().toURI().toString())
@@ -58,14 +62,7 @@ class HivePartitionFilteringSuite(version: String)
tableType = CatalogTableType.MANAGED,
schema = tableSchema,
partitionColumnNames = Seq("ds", "h", "chunk"),
storage = CatalogStorageFormat(
locationUri = None,
inputFormat = Some(classOf[TextInputFormat].getName),
outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName),
serde = Some(classOf[LazySimpleSerDe].getName()),
compressed = false,
properties = Map.empty
))
storage = storageFormat)
client.createTable(table, ignoreIfExists = false)

val partitions =
@@ -102,7 +99,7 @@ class HivePartitionFilteringSuite(version: String)
test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
val client = init(false)
val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
Seq(attr("ds") === 20170101))
Seq(attr("ds") === 20170101), SQLConf.get.sessionLocalTimeZone)

assert(filteredPartitions.size == testPartitionCount)
}
@@ -297,6 +294,63 @@ class HivePartitionFilteringSuite(version: String)
day :: Nil)
}

test("getPartitionsByFilter: date type pruning by metastore") {
val table = CatalogTable(
identifier = TableIdentifier("test_date", Some("default")),
tableType = CatalogTableType.MANAGED,
schema = new StructType().add("value", "int").add("part", "date"),
partitionColumnNames = Seq("part"),
storage = storageFormat)
client.createTable(table, ignoreIfExists = false)

val partitions =
for {
date <- Seq("2019-01-01", "2019-01-02", "2019-01-03", "2019-01-04")
} yield CatalogTablePartition(Map(
"part" -> date
), storageFormat)
assert(partitions.size == 4)

client.createPartitions("default", "test_date", partitions, ignoreIfExists = false)

def testDataTypeFiltering(
filterExprs: Seq[Expression],
expectedPartitionCubes: Seq[Seq[Date]]): Unit = {
val filteredPartitions = client.getPartitionsByFilter(
client.getTable("default", "test_date"),
filterExprs,
SQLConf.get.sessionLocalTimeZone)

val expectedPartitions = expectedPartitionCubes.map {
expectedDt =>
for {
dt <- expectedDt
} yield Set(
"part" -> dt.toString
)
}.reduce(_ ++ _)

assert(filteredPartitions.map(_.spec.toSet).toSet == expectedPartitions.toSet)
}

val dateAttr: Attribute = AttributeReference("part", DateType)()

testDataTypeFiltering(
Seq(dateAttr === Date.valueOf("2019-01-01")),
Seq("2019-01-01").map(Date.valueOf) :: Nil)
testDataTypeFiltering(
Seq(dateAttr > Date.valueOf("2019-01-02")),
Seq("2019-01-03", "2019-01-04").map(Date.valueOf) :: Nil)
testDataTypeFiltering(
Seq(In(dateAttr,
Seq("2019-01-01", "2019-01-02").map(d => Literal(Date.valueOf(d))))),
Seq("2019-01-01", "2019-01-02").map(Date.valueOf) :: Nil)
testDataTypeFiltering(
Seq(InSet(dateAttr,
Set("2019-01-01", "2019-01-02").map(d => Literal(Date.valueOf(d)).eval(EmptyRow)))),
Seq("2019-01-01", "2019-01-02").map(Date.valueOf) :: Nil)
}

private def testMetastorePartitionFiltering(
filterExpr: Expression,
expectedDs: Seq[Int],
@@ -333,7 +387,7 @@ class HivePartitionFilteringSuite(version: String)
val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
Seq(
transform(filterExpr)
))
), SQLConf.get.sessionLocalTimeZone)

val expectedPartitionCount = expectedPartitionCubes.map {
case (expectedDs, expectedH, expectedChunks) =>
@@ -488,7 +488,8 @@ class VersionsSuite extends SparkFunSuite with Logging {
test(s"$version: getPartitionsByFilter") {
// Only one partition [1, 1] for key2 == 1
val result = client.getPartitionsByFilter(client.getTable("default", "src_part"),
Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1))))
Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1))),
versionSpark.conf.sessionLocalTimeZone)

// Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition.
if (version != "0.12") {