Skip to content

Commit

Permalink
[SPARK-8075] [SQL] apply type check interface to more expressions
Browse files Browse the repository at this point in the history
a follow up of apache#6405.
Note: It's not a big change; much of the diff is due to my swapping some code around in `aggregates.scala` to place each aggregate function right below its corresponding aggregate expression.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes apache#6723 from cloud-fan/type-check and squashes the following commits:

2124301 [Wenchen Fan] fix tests
5a658bb [Wenchen Fan] add tests
287d3bb [Wenchen Fan] apply type check interface to more expressions
  • Loading branch information
cloud-fan authored and marmbrus committed Jun 24, 2015
1 parent 7daa702 commit b71d325
Show file tree
Hide file tree
Showing 21 changed files with 337 additions and 290 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,8 @@ class Analyzer(
failAnalysis(
s"""Expect multiple names given for ${g.getClass.getName},
|but only single name '${name}' specified""".stripMargin)
case Alias(g: Generator, name) => Some((g, name :: Nil))
case MultiAlias(g: Generator, names) => Some(g, names)
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ trait HiveTypeCoercion {
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
}
}
Expand Down Expand Up @@ -590,11 +591,12 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case a @ CreateArray(children) if !a.resolved =>
val commonType = a.childTypes.reduce(
(a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
CreateArray(
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))
case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
val types = children.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}

// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
Expand All @@ -620,12 +622,11 @@ trait HiveTypeCoercion {
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
case None => c
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
Expand All @@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {

override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
override def checkInputDataTypes(): TypeCheckResult = {
if (resolve(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"cannot cast ${child.dataType} to $dataType")
}
}

override def foldable: Boolean = child.foldable

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] {
/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`
* TODO: we should remove the default implementation and implement it for all
* expressions with proper error message.
* Note: it's not valid to call this method until `childrenResolved == true`.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ object ExtractValue {
}
}

/**
* A common interface of all kinds of extract value expressions.
* Note: concrete extract value expressions are created only by `ExtractValue.apply`,
* we don't need to do type check for them.
*/
trait ExtractValue extends UnaryExpression {
self: Product =>
}
Expand Down Expand Up @@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

override lazy val resolved = childrenResolved &&
child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]

protected def evalNotNull(value: Any, ordinal: Any) = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
Expand All @@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]

protected def evalNotNull(value: Any, ordinal: Any) = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
Expand Down
Loading

0 comments on commit b71d325

Please sign in to comment.