Skip to content

Commit

Permalink
[SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected i…
Browse files Browse the repository at this point in the history
…nput types.

This patch doesn't actually introduce any code that uses the new ExpectsInputTypes trait; it just adds the trait so others can use it. It also renames the old expectedChildTypes method (on AutoCastInputTypes) to just inputTypes.

We should also add implicit type casting in the future.

Author: Reynold Xin <rxin@databricks.com>

Closes apache#7151 from rxin/expects-input-types and squashes the following commits:

16cf07b [Reynold Xin] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.
  • Loading branch information
rxin committed Jul 1, 2015
1 parent 69c5dee commit 4137f76
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.sql.types._
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
trait CheckAnalysis {
self: Analyzer =>

/**
* Override to provide additional checks for correct analysis.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object HiveTypeCoercion {
IfCoercion ::
Division ::
PropagateTypes ::
AddCastForAutoCastInputTypes ::
ImplicitTypeCasts ::
Nil

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
Expand Down Expand Up @@ -705,13 +705,13 @@ object HiveTypeCoercion {
* Casts types according to the expected input types for Expressions that have the trait
* [[AutoCastInputTypes]].
*/
object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
object ImplicitTypeCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e

case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes =>
val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map {
case (child, actual, expected) =>
if (actual == expected) child else Cast(child, expected)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,38 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
}
}

/**
 * A mix-in trait for expressions that declare which data types they accept from their children.
 */
trait ExpectsInputTypes { self: Expression =>

  /**
   * The type requirement for each child expression, in child order: the element at
   * position i constrains the i-th child.
   *
   * Each element may be one of:
   *   - a specific data type, e.g. LongType, StringType;
   *   - a non-leaf data type, e.g. NumericType, IntegralType, FractionalType;
   *   - a list of specific data types, e.g. Seq(StringType, BinaryType).
   */
  def inputTypes: Seq[Any]

  /**
   * Always reports success; checking of the input types is left to `HiveTypeCoercion`
   * rather than being performed here.
   */
  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}

/**
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
*/
trait AutoCastInputTypes {
self: Expression =>
trait AutoCastInputTypes { self: Expression =>

def expectedChildTypes: Seq[DataType]
def inputTypes: Seq[DataType]

override def checkInputDataTypes(): TypeCheckResult = {
// We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
// We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
// so type mismatch errors won't be reported here, but for the underlying `Cast`s.
TypeCheckResult.TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
extends UnaryExpression with Serializable with AutoCastInputTypes {
self: Product =>

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def inputTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
Expand Down Expand Up @@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)

override def toString: String = s"$name($left, $right)"

Expand Down Expand Up @@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
case class Bin(child: Expression)
extends UnaryExpression with Serializable with AutoCastInputTypes {

override def expectedChildTypes: Seq[DataType] = Seq(LongType)
override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType

override def eval(input: InternalRow): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ case class Md5(child: Expression)

override def dataType: DataType = StringType

override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
override def inputTypes: Seq[DataType] = Seq(BinaryType)

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
Expand Down Expand Up @@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression)

override def toString: String = s"SHA2($left, $right)"

override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)

override def eval(input: InternalRow): Any = {
val evalE1 = left.eval(input)
Expand Down Expand Up @@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp

override def dataType: DataType = StringType

override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
override def inputTypes: Seq[DataType] = Seq(BinaryType)

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
Expand Down Expand Up @@ -179,7 +179,7 @@ case class Crc32(child: Expression)

override def dataType: DataType = LongType

override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
override def inputTypes: Seq[DataType] = Seq(BinaryType)

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ trait PredicateHelper {
case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
override def toString: String = s"NOT $child"

override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
override def inputTypes: Seq[DataType] = Seq(BooleanType)

override def eval(input: InternalRow): Any = {
child.eval(input) match {
Expand Down Expand Up @@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any])
case class And(left: Expression, right: Expression)
extends BinaryExpression with Predicate with AutoCastInputTypes {

override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

override def symbol: String = "&&"

Expand Down Expand Up @@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression)
case class Or(left: Expression, right: Expression)
extends BinaryExpression with Predicate with AutoCastInputTypes {

override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

override def symbol: String = "||"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes {

override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

// try cache the pattern for Literal
private lazy val cache: Pattern = right match {
Expand Down Expand Up @@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes {
def convert(v: UTF8String): UTF8String

override def dataType: DataType = StringType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType)

override def eval(input: InternalRow): Any = {
val evaluated = child.eval(input)
Expand Down Expand Up @@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes {

override def nullable: Boolean = left.nullable || right.nullable

override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

override def eval(input: InternalRow): Any = {
val leftEval = left.eval(input)
Expand Down Expand Up @@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
if (str.dataType == BinaryType) str.dataType else StringType
}

override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)

override def children: Seq[Expression] = str :: pos :: len :: Nil

Expand Down Expand Up @@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
*/
case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
override def dataType: DataType = IntegerType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType)

override def eval(input: InternalRow): Any = {
val string = child.eval(input)
Expand Down

0 comments on commit 4137f76

Please sign in to comment.