Skip to content

Commit

Permalink
[SPARK-32258][SQL] NormalizeFloatingNumbers directly normalizes IF/Ca…
Browse files Browse the repository at this point in the history
…seWhen/Coalesce child expressions

### What changes were proposed in this pull request?

This patch proposes to let the `NormalizeFloatingNumbers` rule directly normalize certain child expressions. This simplifies the resulting expression tree.

### Why are the changes needed?

Currently the NormalizeFloatingNumbers rule treats some expressions as black boxes, but we can optimize it a bit by directly normalizing their inner child expressions.

Also see apache#28962 (comment).

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests.

Closes apache#29061 from viirya/SPARK-32258.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
viirya authored and dongjoon-hyun committed Jul 12, 2020
1 parent bc3d4ba commit b6229df
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
Expand Down Expand Up @@ -116,6 +116,15 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case CreateMap(children, useStringTypeWhenEmpty) =>
CreateMap(children.map(normalize), useStringTypeWhenEmpty)

case If(cond, trueValue, falseValue) =>
If(cond, normalize(trueValue), normalize(falseValue))

case CaseWhen(branches, elseVale) =>
CaseWhen(branches.map(br => (br._1, normalize(br._2))), elseVale.map(normalize))

case Coalesce(children) =>
Coalesce(children.map(normalize))

case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized}
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, IsNull, KnownFloatingPointNormalized}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -85,8 +85,43 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
val optimized = Optimize.execute(query)
val doubleOptimized = Optimize.execute(optimized)
val joinCond = IsNull(a) === IsNull(b) &&
KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) ===
KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0)))
coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)),
KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0))) ===
coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(b)),
KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0)))
val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))

comparePlans(doubleOptimized, correctAnswer)
}

test("SPARK-32258: normalize the children of If") {
  // Join condition whose left side is an If over double-typed branches.
  val originalCond = If(a > 0.1D, a, a + 0.2D) === b
  val plan = testRelation1.join(testRelation2, condition = Some(originalCond))
  // Optimize twice to confirm the rule is idempotent (no re-wrapping on a second pass).
  val onceOptimized = Optimize.execute(plan)
  val twiceOptimized = Optimize.execute(onceOptimized)

  // Each If branch should be normalized individually, as should the right-hand key.
  val expectedCond = If(a > 0.1D,
    KnownFloatingPointNormalized(NormalizeNaNAndZero(a)),
    KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D))) ===
    KnownFloatingPointNormalized(NormalizeNaNAndZero(b))
  val expectedPlan = testRelation1.join(testRelation2, condition = Some(expectedCond))

  comparePlans(twiceOptimized, expectedPlan)
}

test("SPARK-32258: normalize the children of CaseWhen") {
val cond = CaseWhen(
Seq((a > 0.1D, a), (a > 0.2D, a + 0.2D)),
Some(a + 0.3D)) === b
val query = testRelation1.join(testRelation2, condition = Some(cond))
val optimized = Optimize.execute(query)
val doubleOptimized = Optimize.execute(optimized)

val joinCond = CaseWhen(
Seq((a > 0.1D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a))),
(a > 0.2D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D)))),
Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.3D)))) ===
KnownFloatingPointNormalized(NormalizeNaNAndZero(b))
val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))

comparePlans(doubleOptimized, correctAnswer)
Expand Down

0 comments on commit b6229df

Please sign in to comment.