[SPARK-2053][SQL] Add Catalyst expressions for CASE WHEN.
JIRA ticket: https://issues.apache.org/jira/browse/SPARK-2053

This PR adds support for the two types of CASE statements present in Hive. The first has the form `CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END` and behaves like a chain of if statements. The second has the form `CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END` and behaves like a switch statement on the key `a`. Both forms are implemented in `CaseWhen`.

[This link](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions) contains a more detailed description of their semantics.
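
To make the mapping concrete, here is a minimal sketch (not part of this patch) of how both forms flatten into the `branches` sequence of `CaseWhen`: consecutive (condition, value) pairs plus an optional trailing else value. It assumes the Catalyst expression classes touched by this PR (`CaseWhen`, `Equals`, `GreaterThan`, `Literal`, `AttributeReference`); `key` is a hypothetical integer attribute used only for illustration.

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

// Hypothetical integer attribute, only for illustration.
val key = AttributeReference("key", IntegerType, nullable = true)()

// CASE WHEN key > 10 THEN 1 WHEN key > 5 THEN 2 ELSE 3 END
// branches = (cond1, val1, cond2, val2, elseVal)
val chained = CaseWhen(Seq(
  GreaterThan(key, Literal(10)), Literal(1),
  GreaterThan(key, Literal(5)), Literal(2),
  Literal(3)))

// CASE key WHEN 10 THEN 1 WHEN 5 THEN 2 ELSE 3 END
// is translated at parsing time to CASE WHEN key = 10 THEN 1 WHEN key = 5 THEN 2 ELSE 3 END:
val keyed = CaseWhen(Seq(
  Equals(key, Literal(10)), Literal(1),
  Equals(key, Literal(5)), Literal(2),
  Literal(3)))
```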

Notes / Open issues:

* Please check whether any implicit contracts / invariants are broken in the implementations (especially for the operators). I am not very familiar with them and currently find them tricky to spot.
* We should decide whether a non-boolean condition is allowed in a branch of `CaseWhen`. Hive throws a `SemanticException` in this situation, and I think it would be good to mimic that -- the open question is where in the Spark SQL pipeline we should signal an exception for such a query.

Author: Zongheng Yang <zongheng.y@gmail.com>

Closes #1055 from concretevitamin/caseWhen and squashes the following commits:

4226eb9 [Zongheng Yang] Comment.
79d26fc [Zongheng Yang] Merge branch 'master' into caseWhen
caf9383 [Zongheng Yang] Update a FIXME.
9d26ab8 [Zongheng Yang] Add @transient marker.
788a0d9 [Zongheng Yang] Implement CastNulls, which fixes udf_case and udf_when.
7ef284f [Zongheng Yang] Refactors: remove redundant passes, improve toString, mark transient.
f47ae7b [Zongheng Yang] Modify queries in tests to have shorter golden files.
1c1fbfc [Zongheng Yang] Cleanups per review comments.
7d2b7e2 [Zongheng Yang] Translate CaseKeyWhen to CaseWhen at parsing time.
47d406a [Zongheng Yang] Do toArray once and lazily outside of eval().
bb3d109 [Zongheng Yang] Update scaladoc of a method.
aea3195 [Zongheng Yang] Fix bug that branchesArr is not used; remove unused import.
96870a8 [Zongheng Yang] Turn off scalastyle for some comments.
7392f3a [Zongheng Yang] Minor cleanup.
2cf08bb [Zongheng Yang] Merge branch 'master' into caseWhen
9f84b40 [Zongheng Yang] Add golden outputs from Hive.
db51a85 [Zongheng Yang] Add allCondBooleans check; uncomment tests.
3f9ef0a [Zongheng Yang] Cleanups and bug fixes (mainly in eval() and resolved).
be54bc8 [Zongheng Yang] Rewrite eval() to a low-level implementation. Separate two CASE stmts.
f2bcb9d [Zongheng Yang] WIP
5906f75 [Zongheng Yang] WIP
efd019b [Zongheng Yang] eval() and toString() bug fixes.
7d81e95 [Zongheng Yang] Clean up resolved.
a31d782 [Zongheng Yang] Finish up Case.
concretevitamin authored and marmbrus committed Jun 17, 2014
1 parent f5a4049 commit e243c5f
Showing 15 changed files with 290 additions and 8 deletions.
@@ -31,8 +31,16 @@ import org.apache.spark.sql.catalyst.types._
trait HiveTypeCoercion {

val typeCoercionRules =
List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts,
StringToIntegralCasts, FunctionArgumentConversion)
PropagateTypes ::
ConvertNaNs ::
WidenTypes ::
PromoteStrings ::
BooleanComparisons ::
BooleanCasts ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CastNulls ::
Nil

/**
* Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] data
@@ -282,4 +290,33 @@ trait HiveTypeCoercion {
        Average(Cast(e, DoubleType))
    }
  }

  /**
   * Ensures that NullType gets casted to some other types under certain circumstances.
   */
  object CastNulls extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      case cw @ CaseWhen(branches) =>
        val valueTypes = branches.sliding(2, 2).map {
          case Seq(_, value) if value.resolved => Some(value.dataType)
          case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType)
          case _ => None
        }.toSeq
        if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) {
          val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get
          val transformedBranches = branches.sliding(2, 2).map {
            case Seq(cond, value) if value.resolved && value.dataType == NullType =>
              Seq(cond, Cast(value, otherType))
            case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType =>
              Seq(Cast(elseVal, otherType))
            case s => s
          }.reduce(_ ++ _)
          CaseWhen(transformedBranches)
        } else {
          // It is possible to have more types due to the possibility of short-circuiting.
          cw
        }
    }
  }

}
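
As an illustration of the new `CastNulls` rule (a hedged sketch, not part of the diff above): for a resolved `CASE WHEN key > 2 THEN 3 ELSE NULL END`, the value types are `IntegerType` and `NullType`, so the rule wraps the `NullType` branch in a `Cast` to the other observed type. This reuses the hypothetical `key` attribute and imports from the earlier sketch, and assumes `Literal(null)` produces a `NullType` literal.

```scala
// Before the rule: the ELSE branch is a NullType literal (assumption: Literal(null) has NullType).
val before = CaseWhen(Seq(GreaterThan(key, Literal(2)), Literal(3), Literal(null)))

// After the rule: the NullType branch is cast to IntegerType, the only other value type,
// so the whole CaseWhen resolves to a single data type.
val after = CaseWhen(Seq(GreaterThan(key, Literal(2)), Literal(3), Cast(Literal(null), IntegerType)))
```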
@@ -28,8 +28,6 @@ abstract class Expression extends TreeNode[Expression] {
/** The narrowest possible type that is produced when this expression is evaluated. */
type EvaluatedType <: Any

def dataType: DataType

/**
* Returns true when an expression is a candidate for static evaluation before the query is
* executed.
@@ -53,12 +51,18 @@

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and `false` if it is still contains any unresolved placeholders. Implementations of expressions
* and `false` if it still contains any unresolved placeholders. Implementations of expressions
* should override this if the resolution of this type of expression involves more than just
* the resolution of its children.
*/
lazy val resolved: Boolean = childrenResolved

/**
* Returns the [[types.DataType DataType]] of the result of evaluating this expression. It is
* invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
*/
def dataType: DataType

/**
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any still contains any unresolved placeholders.
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.types.BooleanType


@@ -202,3 +201,78 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi

  override def toString = s"if ($predicate) $trueValue else $falseValue"
}

// scalastyle:off
/**
 * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
 * Refer to this link for the corresponding semantics:
 * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
 *
 * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
 * translated to this form at parsing time. Namely, such a statement gets translated to
 * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
 *
 * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
 * element is the value for the default catch-all case (if provided). Hence, `branches` consists of
 * at least two elements, and can have an odd or even length.
 */
// scalastyle:on
case class CaseWhen(branches: Seq[Expression]) extends Expression {
  type EvaluatedType = Any
  def children = branches
  def references = children.flatMap(_.references).toSet
  def dataType = {
    if (!resolved) {
      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
    }
    branches(1).dataType
  }

  @transient private[this] lazy val branchesArr = branches.toArray
  @transient private[this] lazy val predicates =
    branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
  @transient private[this] lazy val values =
    branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq

  override def nullable = {
    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
    values.exists(_.nullable) || (values.length % 2 == 0)
  }

  override lazy val resolved = {
    if (!childrenResolved) {
      false
    } else {
      val allCondBooleans = predicates.forall(_.dataType == BooleanType)
      val dataTypesEqual = values.map(_.dataType).distinct.size <= 1
      allCondBooleans && dataTypesEqual
    }
  }

  /** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */
  override def eval(input: Row): Any = {
    val len = branchesArr.length
    var i = 0
    // If all branches fail and an elseVal is not provided, the whole statement
    // defaults to null, according to Hive's semantics.
    var res: Any = null
    while (i < len - 1) {
      if (branchesArr(i).eval(input) == true) {
        res = branchesArr(i + 1).eval(input)
        return res
      }
      i += 2
    }
    if (i == len - 1) {
      res = branchesArr(i).eval(input)
    }
    res
  }

  override def toString = {
    "CASE" + branches.sliding(2, 2).map {
      case Seq(cond, value) => s" WHEN $cond THEN $value"
      case Seq(elseValue) => s" ELSE $elseValue"
    }.mkString
  }
}
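
A small usage sketch of the eval semantics above (again hypothetical, not part of the diff): with literal-only branches no attributes are referenced, so a null input row suffices here; the same imports as in the earlier sketches are assumed.

```scala
// The first condition that evaluates to true wins; its paired value is returned.
val expr = CaseWhen(Seq(
  Literal(false), Literal(1),
  Literal(true), Literal(2),
  Literal(3)))
assert(expr.eval(null) == 2)

// All conditions fail and no ELSE value is provided: the result defaults to null.
val noElse = CaseWhen(Seq(Literal(false), Literal(1)))
assert(noElse.eval(null) == null)
```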
@@ -115,7 +115,7 @@ package object util {
}

/* FIX ME
implicit class debugLogging(a: AnyRef) {
implicit class debugLogging(a: Any) {
def debugLogging() {
org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG)
}
@@ -35,7 +35,7 @@ class ExpressionEvaluationSuite extends FunSuite {
/**
* Checks for three-valued-logic. Based on:
* http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29
*
* I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound.
* p     q     p OR q  p AND q  p = q
* True  True  True    True     True
* True  False True    False    False
17 changes: 17 additions & 0 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -811,6 +811,8 @@ private[hive] object HiveQl {
  val IN = "(?i)IN".r
  val DIV = "(?i)DIV".r
  val BETWEEN = "(?i)BETWEEN".r
  val WHEN = "(?i)WHEN".r
  val CASE = "(?i)CASE".r

  protected def nodeToExpr(node: Node): Expression = node match {
    /* Attribute References */
@@ -917,6 +919,21 @@ private[hive] object HiveQl {
    case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right))
    case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))

    /* Case statements */
    case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
      CaseWhen(branches.map(nodeToExpr))
    case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
      val transformed = branches.drop(1).sliding(2, 2).map {
        case Seq(condVal, value) =>
          // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval().
          // Hence effectful / non-deterministic key expressions are *not* supported at the moment.
          // We should consider adding new Expressions to get around this.
          Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)),
            nodeToExpr(value))
        case Seq(elseVal) => Seq(nodeToExpr(elseVal))
      }.toSeq.reduce(_ ++ _)
      CaseWhen(transformed)

    /* Complex datatype manipulation */
    case Token("[", child :: ordinal :: Nil) =>
      GetItem(nodeToExpr(child), nodeToExpr(ordinal))
@@ -0,0 +1,14 @@
NULL
3
3
3
NULL
NULL
3
3
3
3
NULL
3
3
3
@@ -0,0 +1,14 @@
4
3
3
3
4
4
3
3
3
3
4
3
3
3
@@ -0,0 +1,14 @@
2
3
3
3
2
2
3
3
3
3
NULL
3
3
3
@@ -0,0 +1,14 @@
2
3
3
3
2
2
3
3
3
3
0
3
3
3
@@ -0,0 +1,14 @@
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
@@ -0,0 +1,14 @@
0
0
0
0
0
0
0
0
0
0
3
0
0
0
@@ -0,0 +1,14 @@
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
NULL
3
NULL
NULL
NULL
@@ -0,0 +1,14 @@
0
0
0
0
0
0
0
0
0
0
3
0
0
0
@@ -164,6 +164,44 @@ class HiveQuerySuite extends HiveComparisonTest {
hql("SELECT * FROM src").toString
}

createQueryTest("case statements with key #1",
"SELECT (CASE 1 WHEN 2 THEN 3 END) FROM src where key < 15")

createQueryTest("case statements with key #2",
"SELECT (CASE key WHEN 2 THEN 3 ELSE 0 END) FROM src WHERE key < 15")

createQueryTest("case statements with key #3",
"SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 END) FROM src WHERE key < 15")

createQueryTest("case statements with key #4",
"SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 ELSE 0 END) FROM src WHERE key < 15")

createQueryTest("case statements WITHOUT key #1",
"SELECT (CASE WHEN key > 2 THEN 3 END) FROM src WHERE key < 15")

createQueryTest("case statements WITHOUT key #2",
"SELECT (CASE WHEN key > 2 THEN 3 ELSE 4 END) FROM src WHERE key < 15")

createQueryTest("case statements WITHOUT key #3",
"SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 END) FROM src WHERE key < 15")

createQueryTest("case statements WITHOUT key #4",
"SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15")

test("implement identity function using case statement") {
val actual = hql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet
val expected = hql("SELECT key FROM src").collect().toSet
assert(actual === expected)
}

// TODO: adopt this test when Spark SQL has the functionality / framework to report errors.
// See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion.
ignore("non-boolean conditions in a CaseWhen are illegal") {
intercept[Exception] {
hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect()
}
}

private val explainCommandClassName =
classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$")
