Skip to content

Commit

Permalink
[WIP][Spark-SQL] Optimize the Constant Folding for Expression
Browse files Browse the repository at this point in the history
Currently, expression does not support the "constant null" well in constant folding.
e.g. Sum(a, 0) actually always produces Literal(0, NumericType) in runtime.

For example:
```
explain select isnull(key+null)  from src;
== Logical Plan ==
Project [HiveGenericUdf#isnull((key#30 + CAST(null, IntegerType))) AS c_0#28]
 MetastoreRelation default, src, None

== Optimized Logical Plan ==
Project [true AS c_0#28]
 MetastoreRelation default, src, None

== Physical Plan ==
Project [true AS c_0#28]
 HiveTableScan [], (MetastoreRelation default, src, None), None
```

I've create a new Optimization rule called NullPropagation for such kind of constant folding.

Author: Cheng Hao <hao.cheng@intel.com>
Author: Michael Armbrust <michael@databricks.com>

Closes apache#482 from chenghao-intel/optimize_constant_folding and squashes the following commits:

2f14b50 [Cheng Hao] Fix code style issues
68b9fad [Cheng Hao] Remove the Literal pattern matching for NullPropagation
29c8166 [Cheng Hao] Update the code for feedback of code review
50444cc [Cheng Hao] Remove the unnecessary null checking
80f9f18 [Cheng Hao] Update the UnitTest for aggregation constant folding
27ea3d7 [Cheng Hao] Fix Constant Folding Bugs & Add More Unittests
b28e03a [Cheng Hao] Merge pull request #1 from marmbrus/pr/482
9ccefdb [Michael Armbrust] Add tests for optimized expression evaluation.
543ef9d [Cheng Hao] fix code style issues
9cf0396 [Cheng Hao] update code according to the code review comment
536c005 [Cheng Hao] Add Exceptional case for constant folding
3c045c7 [Cheng Hao] Optimize the Constant Folding by adding more rules
2645d4f [Cheng Hao] Constant Folding(null propagation)
  • Loading branch information
chenghao-intel authored and rxin committed May 7, 2014
1 parent 913a0a9 commit 3eb53bd
Show file tree
Hide file tree
Showing 14 changed files with 1,502 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,37 +114,37 @@ package object dsl {
def attr = analysis.UnresolvedAttribute(s)

/** Creates a new AttributeReference of type boolean */
def boolean = AttributeReference(s, BooleanType, nullable = false)()
def boolean = AttributeReference(s, BooleanType, nullable = true)()

/** Creates a new AttributeReference of type byte */
def byte = AttributeReference(s, ByteType, nullable = false)()
def byte = AttributeReference(s, ByteType, nullable = true)()

/** Creates a new AttributeReference of type short */
def short = AttributeReference(s, ShortType, nullable = false)()
def short = AttributeReference(s, ShortType, nullable = true)()

/** Creates a new AttributeReference of type int */
def int = AttributeReference(s, IntegerType, nullable = false)()
def int = AttributeReference(s, IntegerType, nullable = true)()

/** Creates a new AttributeReference of type long */
def long = AttributeReference(s, LongType, nullable = false)()
def long = AttributeReference(s, LongType, nullable = true)()

/** Creates a new AttributeReference of type float */
def float = AttributeReference(s, FloatType, nullable = false)()
def float = AttributeReference(s, FloatType, nullable = true)()

/** Creates a new AttributeReference of type double */
def double = AttributeReference(s, DoubleType, nullable = false)()
def double = AttributeReference(s, DoubleType, nullable = true)()

/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = false)()
def string = AttributeReference(s, StringType, nullable = true)()

/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = false)()
def decimal = AttributeReference(s, DecimalType, nullable = true)()

/** Creates a new AttributeReference of type timestamp */
def timestamp = AttributeReference(s, TimestampType, nullable = false)()
def timestamp = AttributeReference(s, TimestampType, nullable = true)()

/** Creates a new AttributeReference of type binary */
def binary = AttributeReference(s, BinaryType, nullable = false)()
def binary = AttributeReference(s, BinaryType, nullable = true)()
}

implicit class DslAttribute(a: AttributeReference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ abstract class Expression extends TreeNode[Expression] {
* - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
* child is foldable.
*/
// TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
def foldable: Boolean = false
def nullable: Boolean
def references: Set[Attribute]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees

abstract sealed class SortDirection
case object Ascending extends SortDirection
Expand All @@ -27,7 +28,10 @@ case object Descending extends SortDirection
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
*/
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {

override def references = child.references
override def dataType = child.dataType
override def nullable = child.nullable

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
val children = child :: ordinal :: Nil
/** `Null` is returned for invalid ordinals. */
override def nullable = true
override def foldable = child.foldable && ordinal.foldable
override def references = children.flatMap(_.references).toSet
def dataType = child.dataType match {
case ArrayType(dt) => dt
Expand All @@ -40,23 +41,27 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def toString = s"$child[$ordinal]"

override def eval(input: Row): Any = {
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = child.eval(input).asInstanceOf[Seq[_]]
val o = ordinal.eval(input).asInstanceOf[Int]
if (baseValue == null) {
null
} else if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
val value = child.eval(input)
if (value == null) {
null
} else {
val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
if (baseValue == null) {
if (key == null) {
null
} else {
baseValue.get(key).orNull
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = value.asInstanceOf[Seq[_]]
val o = key.asInstanceOf[Int]
if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
} else {
val baseValue = value.asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
baseValue.get(key).orNull
}
}
}
}
Expand All @@ -69,7 +74,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
type EvaluatedType = Any

def dataType = field.dataType
def nullable = field.nullable
override def nullable = field.nullable
override def foldable = child.foldable

protected def structType = child.dataType match {
case s: StructType => s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ abstract class BinaryPredicate extends BinaryExpression with Predicate {
def nullable = left.nullable || right.nullable
}

case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
def references = child.references
case class Not(child: Expression) extends UnaryExpression with Predicate {
override def foldable = child.foldable
def nullable = child.nullable
override def toString = s"NOT $child"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.types._
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
Batch("ConstantFolding", Once,
NullPropagation,
ConstantFolding,
BooleanSimplification,
SimplifyFilters,
Expand Down Expand Up @@ -85,6 +86,72 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
}

/**
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
* equivalent [[catalyst.expressions.Literal Literal]] values. This rule is more specific with
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType)
case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType)
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
case e @ Coalesce(children) => {
val newChildren = children.filter(c => c match {
case Literal(null, _) => false
case _ => true
})
if (newChildren.length == 0) {
Literal(null, e.dataType)
} else if (newChildren.length == 1) {
newChildren(0)
} else {
Coalesce(newChildren)
}
}
case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue
case e @ In(Literal(v, _), list) if (list.exists(c => c match {
case Literal(candidate, _) if candidate == v => true
case _ => false
})) => Literal(true, BooleanType)
case e: UnaryMinus => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Cast => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Not => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
// Put exceptional cases above if any
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: BinaryComparison => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: StringRegexExpression => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
}
}
}

/**
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
* equivalent [[catalyst.expressions.Literal Literal]] values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ class ExpressionEvaluationSuite extends FunSuite {
truthTable.foreach {
case (l,r,answer) =>
val expr = op(Literal(l, BooleanType), Literal(r, BooleanType))
val result = expr.eval(null)
if (result != answer)
fail(s"$expr should not evaluate to $result, expected: $answer")
checkEvaluation(expr, answer)
}
}
}
Expand All @@ -131,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite {

test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
checkEvaluation("abdef" like "abdef", true)
checkEvaluation("a_%b" like "a\\__b", true)
Expand Down Expand Up @@ -159,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))

checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%")))
}

test("RLIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType) rlike "abdef", null)
checkEvaluation("abdef" rlike Literal(null, StringType), null)
checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null)
checkEvaluation("abdef" rlike "abdef", true)
checkEvaluation("abbbbc" rlike "a.*c", true)

Expand Down Expand Up @@ -257,6 +261,8 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(("abcdef" cast DecimalType).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)

checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
}

test("timestamp") {
Expand Down Expand Up @@ -287,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite {
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
}

test("null checking") {
val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)

checkEvaluation(IsNull(c1), false, row)
checkEvaluation(IsNotNull(c1), true, row)

checkEvaluation(IsNull(c2), true, row)
checkEvaluation(IsNotNull(c2), false, row)

checkEvaluation(IsNull(Literal(1, ShortType)), false)
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)

checkEvaluation(IsNull(Literal(null, ShortType)), true)
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)

checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)

checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row)
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)

checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
}

test("complex type") {
val row = new GenericRow(Array[Any](
"^Ba*n", // 0
null.asInstanceOf[String], // 1
new GenericRow(Array[Any]("aa", "bb")), // 2
Map("aa"->"bb"), // 3
Seq("aa", "bb") // 4
))

val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
)
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)

checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal("aa")), "bb", row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal(null, StringType)), null, row)

checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(1)), "bb", row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(null, IntegerType)), null, row)

checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
}

test("arithmetic") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(UnaryMinus(c1), -1, row)
checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100)

checkEvaluation(Add(c1, c4), null, row)
checkEvaluation(Add(c1, c2), 3, row)
checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(Add(Literal(null, IntegerType), c2), null, row)
checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}

test("BinaryComparison") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(LessThan(c1, c4), null, row)
checkEvaluation(LessThan(c1, c2), true, row)
checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}
}

Loading

0 comments on commit 3eb53bd

Please sign in to comment.