[SPARK-8075][SQL] apply type check interface to more expressions #6723

Closed
@@ -585,8 +585,8 @@ class Analyzer(
failAnalysis(
s"""Expect multiple names given for ${g.getClass.getName},
|but only single name '${name}' specified""".stripMargin)
case Alias(g: Generator, name) => Some((g, name :: Nil))
case MultiAlias(g: Generator, names) => Some(g, names)
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
case _ => None
}
}
@@ -317,6 +317,7 @@ trait HiveTypeCoercion {
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case SumDistinct(e @ StringType()) => SumDistinct(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
}
}
@@ -590,11 +591,12 @@ trait HiveTypeCoercion {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e

case a @ CreateArray(children) if !a.resolved =>
val commonType = a.childTypes.reduce(
(a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
CreateArray(
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))
case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
val types = children.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}
Contributor Author

I checked with Hive: CreateArray should cast its inputs to string, but following the same rules that Coalesce uses. For example, we should not cast boolean to string for CreateArray.
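
As a rough illustration of that behavior, here is a small standalone sketch of the "find the tightest common type, then promote to string" idea. The type names below (SimpleType, ArrayTypePromotion, and friends) are simplified stand-ins for Catalyst's DataType hierarchy, not the actual findTightestCommonTypeAndPromoteToString implementation.

// Simplified stand-ins for Catalyst's DataTypes, for illustration only.
sealed trait SimpleType
case object IntType extends SimpleType
case object DoubleType extends SimpleType
case object StringType extends SimpleType
case object BooleanType extends SimpleType

object ArrayTypePromotion {
  // Tightest common type: identical types, or numeric widening (Int -> Double).
  private def tightest(a: SimpleType, b: SimpleType): Option[SimpleType] = (a, b) match {
    case _ if a == b => Some(a)
    case (IntType, DoubleType) | (DoubleType, IntType) => Some(DoubleType)
    case _ => None
  }

  // Fall back to StringType when one side is already a string, but never
  // promote Boolean to String, matching the behavior described in the comment above.
  private def tightestOrString(a: SimpleType, b: SimpleType): Option[SimpleType] =
    tightest(a, b).orElse((a, b) match {
      case (StringType, other) if other != BooleanType => Some(StringType)
      case (other, StringType) if other != BooleanType => Some(StringType)
      case _ => None
    })

  // None means "no safe common type": the coercion rule leaves CreateArray
  // unchanged, and the type-check phase reports the error later.
  def commonElementType(types: Seq[SimpleType]): Option[SimpleType] =
    types.foldLeft(Option(types.head)) {
      case (Some(acc), t) => tightestOrString(acc, t)
      case (None, _)      => None
    }
}

object ArrayTypePromotionDemo extends App {
  println(ArrayTypePromotion.commonElementType(Seq(IntType, DoubleType)))      // Some(DoubleType)
  println(ArrayTypePromotion.commonElementType(Seq(IntType, StringType)))      // Some(StringType)
  println(ArrayTypePromotion.commonElementType(Seq(BooleanType, StringType)))  // None
}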


// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
@@ -620,12 +622,11 @@
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
case None => c
}
}
}
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {

override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
override def checkInputDataTypes(): TypeCheckResult = {
if (resolve(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"cannot cast ${child.dataType} to $dataType")
}
}

override def foldable: Boolean = child.foldable

@@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] {
/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`
* TODO: we should remove the default implementation and implement it for all
* expressions with proper error message.
* Note: it's not valid to call this method until `childrenResolved == true`.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}
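
To make this contract concrete, here is a minimal standalone sketch of an expression implementing the hook. The Expression, SimpleDataType, and TypeCheckResult definitions below are greatly simplified stand-ins for illustration, not Catalyst's actual classes.

// Minimal stand-ins for the pieces involved, simplified for illustration.
sealed trait SimpleDataType
case object IntType extends SimpleDataType
case object StrType extends SimpleDataType

sealed trait TypeCheckResult
case object TypeCheckSuccess extends TypeCheckResult
case class TypeCheckFailure(message: String) extends TypeCheckResult

trait Expression {
  def dataType: SimpleDataType
  // Contract: only call this once all children are resolved, i.e. their
  // dataTypes are known; the default implementation accepts anything.
  def checkInputDataTypes(): TypeCheckResult = TypeCheckSuccess
}

case class Literal(value: Any, dataType: SimpleDataType) extends Expression

// Hypothetical integer-only unary expression that overrides the default check.
case class Negate(child: Expression) extends Expression {
  def dataType: SimpleDataType = IntType
  override def checkInputDataTypes(): TypeCheckResult =
    if (child.dataType == IntType) TypeCheckSuccess
    else TypeCheckFailure(s"Negate expects an integer child, got ${child.dataType}")
}

object CheckInputDataTypesDemo extends App {
  println(Negate(Literal(1, IntType)).checkInputDataTypes())    // TypeCheckSuccess
  println(Negate(Literal("a", StrType)).checkInputDataTypes())  // TypeCheckFailure(...)
}

In the real analyzer, a failure message like the one produced here is what gets surfaced to the user as an analysis error, which is the point of moving Cast and the other expressions in this PR onto checkInputDataTypes.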
@@ -96,6 +96,11 @@ object ExtractValue {
}
}

/**
* A common interface of all kinds of extract value expressions.
* Note: concrete extract value expressions are created only by `ExtractValue.apply`,
* so we don't need to do type checks for them.
*/
trait ExtractValue extends UnaryExpression {
self: Product =>
}
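
The reasoning above (validate once in the factory, so the concrete expressions need no type check of their own) can be sketched in isolation as follows. The types and the ExtractValueSketch factory below are simplified stand-ins, not the actual ExtractValue.apply code.

// Simplified stand-ins for Catalyst types, for illustration only.
sealed trait SimpleType
case class SimpleArrayType(elementType: SimpleType) extends SimpleType
case class SimpleMapType(keyType: SimpleType, valueType: SimpleType) extends SimpleType
case object SimpleIntType extends SimpleType
case object SimpleStringType extends SimpleType

case class Input(name: String, dataType: SimpleType)

sealed trait ExtractValueSketch
case class ArrayItem(child: Input, ordinal: Int) extends ExtractValueSketch
case class MapValue(child: Input, key: String) extends ExtractValueSketch

object ExtractValueSketch {
  // All construction funnels through this factory, which rejects unsupported
  // combinations up front, so the concrete case classes never see bad types.
  def apply(child: Input, extraction: Any): ExtractValueSketch =
    (child.dataType, extraction) match {
      case (SimpleArrayType(_), i: Int)                     => ArrayItem(child, i)
      case (SimpleMapType(SimpleStringType, _), k: String)  => MapValue(child, k)
      case (other, _) =>
        throw new IllegalArgumentException(
          s"Can't extract a value from '${child.name}' of type $other")
    }
}

Because everything goes through the factory, the GetArrayItem and GetMapValue analogues can assume well-typed children, which is why their resolved overrides are removed in the hunks below.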
@@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

override lazy val resolved = childrenResolved &&
child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]

protected def evalNotNull(value: Any, ordinal: Any) = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
@@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]

protected def evalNotNull(value: Any, ordinal: Any) = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull