diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a069b4710f38c..583338da57117 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.types._
  * Throws user facing errors when passed invalid queries that fail to analyze.
  */
 trait CheckAnalysis {
-  self: Analyzer =>

   /**
    * Override to provide additional checks for correct analysis.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index a9d396d1faeeb..2ab5cb666fbcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -45,7 +45,7 @@ object HiveTypeCoercion {
       IfCoercion ::
       Division ::
       PropagateTypes ::
-      AddCastForAutoCastInputTypes ::
+      ImplicitTypeCasts ::
       Nil

   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -705,13 +705,13 @@ object HiveTypeCoercion {
    * Casts types according to the expected input types for Expressions that have the trait
    * [[AutoCastInputTypes]].
    */
-  object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
+  object ImplicitTypeCasts extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e

-      case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
-        val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
+      case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes =>
+        val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map {
          case (child, actual, expected) =>
            if (actual == expected) child else Cast(child, expected)
        }
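
The rule renamed above is mechanical: for each resolved expression that mixes in AutoCastInputTypes, every child whose actual type differs from the declared expected type gets wrapped in a Cast. A minimal standalone sketch of that zipped-map step, using simplified stand-in types rather than Spark's classes:

// Illustrative sketch only: simplified stand-ins for Spark's expression/type classes.
sealed trait DataType
case object IntegerType extends DataType
case object DoubleType extends DataType

sealed trait Expr { def dataType: Expr.this.type forSome { type x } = ??? }

// Simpler: model expressions with an explicit dataType field.
sealed trait Expression { def dataType: DataType }
case class Literal(value: Any, dataType: DataType) extends Expression
case class Cast(child: Expression, dataType: DataType) extends Expression

// Mirrors the (children, actual, expected).zipped.map in the rule: keep a child
// whose type already matches, otherwise wrap it in a Cast to the expected type.
def coerce(children: Seq[Expression], expected: Seq[DataType]): Seq[Expression] =
  (children, children.map(_.dataType), expected).zipped.map {
    case (child, actual, exp) => if (actual == exp) child else Cast(child, exp)
  }

// coerce(Seq(Literal(4, IntegerType)), Seq(DoubleType))
//   == Seq(Cast(Literal(4, IntegerType), DoubleType))

So, for example, sqrt over an integer column analyzes to sqrt(cast(... as double)), since UnaryMathExpression (below) declares inputTypes = Seq(DoubleType).
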
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b5063f32fa529..e18a3118945e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -265,17 +265,38 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
   }
 }

+/**
+ * A trait that is mixed in to define the expected input types of an expression.
+ */
+trait ExpectsInputTypes { self: Expression =>
+
+  /**
+   * Expected input types from child expressions. The i-th position in the returned seq indicates
+   * the type requirement for the i-th child.
+   *
+   * The possible values at each position are:
+   * 1. a specific data type, e.g. LongType, StringType.
+   * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
+   * 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
+   */
+  def inputTypes: Seq[Any]
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // We will do the type checking in `HiveTypeCoercion`, so always return success here.
+    TypeCheckResult.TypeCheckSuccess
+  }
+}
+
 /**
  * Expressions that require a specific `DataType` as input should implement this trait
  * so that the proper type conversions can be performed in the analyzer.
  */
-trait AutoCastInputTypes {
-  self: Expression =>
+trait AutoCastInputTypes { self: Expression =>

-  def expectedChildTypes: Seq[DataType]
+  def inputTypes: Seq[DataType]

   override def checkInputDataTypes(): TypeCheckResult = {
-    // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
+    // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
     // so type mismatch error won't be reported here, but for underling `Cast`s.
     TypeCheckResult.TypeCheckSuccess
   }
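
Unlike AutoCastInputTypes, the new ExpectsInputTypes trait types inputTypes as Seq[Any], so a single position can hold a concrete DataType, an abstract (non-leaf) type, or a list of admissible types. This patch only adds the hook; a rough sketch of how a coercion/checking rule could interpret the three documented cases follows. The accepts helper and the mini type hierarchy are invented for illustration, not Spark code:

// Illustrative sketch only: a simplified type hierarchy standing in for Spark's.
sealed trait DataType
sealed trait NumericType extends DataType
case object IntegerType extends NumericType
case object LongType extends NumericType
case object StringType extends DataType
case object BinaryType extends DataType

// One way to interpret a single position of inputTypes against a child's actual type.
def accepts(expected: Any, actual: DataType): Boolean = expected match {
  case dt: DataType => dt == actual             // 1. a specific data type
  case c: Class[_]  => c.isInstance(actual)     // 2. a non-leaf type, e.g. classOf[NumericType]
  case ts: Seq[_]   => ts.contains(actual)      // 3. a list of specific data types
  case _            => false
}

// accepts(LongType, LongType)                      == true
// accepts(classOf[NumericType], IntegerType)       == true
// accepts(Seq(StringType, BinaryType), BinaryType) == true

How Spark ultimately dispatches on cases 2 and 3 lives in HiveTypeCoercion and may differ from this sketch; the point is only that Seq[Any] admits all three shapes.
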
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index da63f2fa970cf..b51318dd5044c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
   extends UnaryExpression with Serializable with AutoCastInputTypes {
   self: Product =>

-  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
+  override def inputTypes: Seq[DataType] = Seq(DoubleType)
   override def dataType: DataType = DoubleType
   override def nullable: Boolean = true
   override def toString: String = s"$name($child)"
@@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
 abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
   extends BinaryExpression with Serializable with AutoCastInputTypes {
   self: Product =>

-  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
+  override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)

   override def toString: String = s"$name($left, $right)"
@@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
 case class Bin(child: Expression)
   extends UnaryExpression with Serializable with AutoCastInputTypes {

-  override def expectedChildTypes: Seq[DataType] = Seq(LongType)
+  override def inputTypes: Seq[DataType] = Seq(LongType)
   override def dataType: DataType = StringType

   override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index a7bcbe46c339a..407023e472081 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -36,7 +36,7 @@ case class Md5(child: Expression)

   override def dataType: DataType = StringType

-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)

   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)
@@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression)

   override def toString: String = s"SHA2($left, $right)"

-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)

   override def eval(input: InternalRow): Any = {
     val evalE1 = left.eval(input)
@@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp

   override def dataType: DataType = StringType

-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)

   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)
@@ -179,7 +179,7 @@ case class Crc32(child: Expression)

   override def dataType: DataType = LongType

-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)

   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 98cd5aa8148c4..a777f77add2db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -72,7 +72,7 @@ trait PredicateHelper {
 case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
   override def toString: String = s"NOT $child"

-  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
+  override def inputTypes: Seq[DataType] = Seq(BooleanType)

   override def eval(input: InternalRow): Any = {
     child.eval(input) match {
@@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any])
 case class And(left: Expression, right: Expression)
   extends BinaryExpression with Predicate with AutoCastInputTypes {

-  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

   override def symbol: String = "&&"
@@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression)
 case class Or(left: Expression, right: Expression)
   extends BinaryExpression with Predicate with AutoCastInputTypes {

-  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

   override def symbol: String = "||"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index ce184e4f32f18..4cbfc4e084948 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes {
   override def nullable: Boolean = left.nullable || right.nullable
   override def dataType: DataType = BooleanType

-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

   // try cache the pattern for Literal
   private lazy val cache: Pattern = right match {
@@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes {
   def convert(v: UTF8String): UTF8String

   override def dataType: DataType = StringType
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType)

   override def eval(input: InternalRow): Any = {
     val evaluated = child.eval(input)
@@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes {

   override def nullable: Boolean = left.nullable || right.nullable

-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

   override def eval(input: InternalRow): Any = {
     val leftEval = left.eval(input)
@@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
     if (str.dataType == BinaryType) str.dataType else StringType
   }

-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+  override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)

   override def children: Seq[Expression] = str :: pos :: len :: Nil
@@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
  */
 case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
   override def dataType: DataType = IntegerType
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType)

   override def eval(input: InternalRow): Any = {
     val string = child.eval(input)