Skip to content

Commit

Permalink
Merge pull request #2 from marmbrus/pr/2246
Browse files Browse the repository at this point in the history
reduce code duplication
  • Loading branch information
adrian-wang committed Sep 5, 2014
2 parents ef6f986 + 832e640 commit 595b417
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ object HiveTypeCoercion {
val numericPrecedence =
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil

def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
if (valueTypes.distinct.size > 1) {
// Try and find a promotion rule that contains both types in question.
val applicableConversion =
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))

// If found return the widest common type, otherwise None
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
} else {
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
}
}
}

/**
Expand All @@ -51,22 +65,6 @@ trait HiveTypeCoercion {
Division ::
Nil

trait TypeWidening {
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
if (valueTypes.distinct.size > 1) {
// Try and find a promotion rule that contains both types in question.
val applicableConversion =
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))

// If found return the widest common type, otherwise None
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
} else {
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
}
}
}

/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
Expand Down Expand Up @@ -147,7 +145,8 @@ trait HiveTypeCoercion {
* - LongType to FloatType
* - LongType to DoubleType
*/
object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
object WidenTypes extends Rule[LogicalPlan] {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
Expand Down Expand Up @@ -343,7 +342,9 @@ trait HiveTypeCoercion {
/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
object CaseWhenCoercion extends Rule[LogicalPlan] {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,13 @@ import org.apache.spark.sql.catalyst.types._

class HiveTypeCoercionSuite extends FunSuite {

val rules = new HiveTypeCoercion { }
import rules._

test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
var found = WidenTypes.findTightestCommonType(t1, t2)
var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = WidenTypes.findTightestCommonType(t2, t1)
found = HiveTypeCoercion.findTightestCommonType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}
Expand Down
52 changes: 22 additions & 30 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,39 +125,31 @@ private[sql] object JsonRDD extends Logging {
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
// Try and find a promotion rule that contains both types in question.
val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p
.contains(t2))

// If found return the widest common type, otherwise None
val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last)

if (returnType.isDefined) {
returnType.get
} else {
// t1 or t2 is a StructType, ArrayType, BooleanType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) => {
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) => {
val dataType = fieldTypes.map(field => field.dataType).reduce(
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
StructField(name, dataType, true)
HiveTypeCoercion.findTightestCommonType(t1,t2) match {
case Some(commonType) => commonType
case None =>
// t1 or t2 is a StructType, ArrayType, BooleanType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) => {
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) => {
val dataType = fieldTypes.map(field => field.dataType).reduce(
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
StructField(name, dataType, true)
}
}
StructType(newFields.toSeq.sortBy {
case StructField(name, _, _) => name
})
}
StructType(newFields.toSeq.sortBy {
case StructField(name, _, _) => name
})
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (_, _) => StringType
}
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// TODO: We should use JsonObjectStringType to mark that values of field will be
// strings and every string is a Json object.
case (BooleanType, BooleanType) => BooleanType
case (_, _) => StringType
}
}
}

Expand Down

0 comments on commit 595b417

Please sign in to comment.