diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 8d5dbc7dc90eb..98c78c6312222 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -116,6 +116,15 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case CreateMap(children, useStringTypeWhenEmpty) =>
       CreateMap(children.map(normalize), useStringTypeWhenEmpty)
 
+    case If(cond, trueValue, falseValue) =>
+      If(cond, normalize(trueValue), normalize(falseValue))
+
+    case CaseWhen(branches, elseValue) =>
+      CaseWhen(branches.map(br => (br._1, normalize(br._2))), elseValue.map(normalize))
+
+    case Coalesce(children) =>
+      Coalesce(children.map(normalize))
+
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
       KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
 
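Note on the hunk above: If, CaseWhen, and Coalesce only pass floating-point values through from their children, so instead of wrapping the whole conditional in NormalizeNaNAndZero, the rule now recurses into each value-producing child. A minimal sketch of that recursion pattern, on a hypothetical toy ADT (Expr, DoubleLeaf, etc. are illustration-only names, not Catalyst classes):

object NormalizeRecursionSketch {
  sealed trait Expr
  case class DoubleLeaf(name: String) extends Expr
  case class IfExpr(cond: Expr, trueValue: Expr, falseValue: Expr) extends Expr
  case class CaseWhenExpr(branches: Seq[(Expr, Expr)], elseValue: Option[Expr]) extends Expr
  case class CoalesceExpr(children: Seq[Expr]) extends Expr
  // Stand-in for KnownFloatingPointNormalized(NormalizeNaNAndZero(_)).
  case class Normalized(child: Expr) extends Expr

  def normalize(expr: Expr): Expr = expr match {
    // If: normalize both value branches, leave the predicate alone.
    case IfExpr(cond, t, f) => IfExpr(cond, normalize(t), normalize(f))
    // CaseWhen: normalize every branch value plus the optional else value.
    case CaseWhenExpr(branches, elseValue) =>
      CaseWhenExpr(branches.map { case (c, v) => (c, normalize(v)) }, elseValue.map(normalize))
    // Coalesce: whichever child survives at runtime must already be normalized.
    case CoalesceExpr(children) => CoalesceExpr(children.map(normalize))
    // Base case: a raw floating-point leaf gets wrapped.
    case leaf: DoubleLeaf => Normalized(leaf)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val e = IfExpr(DoubleLeaf("cond"), DoubleLeaf("a"),
      CoalesceExpr(Seq(DoubleLeaf("a"), DoubleLeaf("b"))))
    // Normalization lands in every value position, never on the predicate:
    println(normalize(e))
  }
}

Only value positions are rewritten: the predicates are boolean and never feed join-key bytes, so they are left untouched.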
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
index f5af416602c9d..3f6bdd206535b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, IsNull, KnownFloatingPointNormalized}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -85,8 +85,43 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
     val optimized = Optimize.execute(query)
     val doubleOptimized = Optimize.execute(optimized)
     val joinCond = IsNull(a) === IsNull(b) &&
-      KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) ===
-        KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0)))
+      coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)),
+        KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0))) ===
+          coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(b)),
+            KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0)))
+    val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-32258: normalize the children of If") {
+    val cond = If(a > 0.1D, a, a + 0.2D) === b
+    val query = testRelation1.join(testRelation2, condition = Some(cond))
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+
+    val joinCond = If(a > 0.1D,
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(a)),
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D))) ===
+        KnownFloatingPointNormalized(NormalizeNaNAndZero(b))
+    val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-32258: normalize the children of CaseWhen") {
+    val cond = CaseWhen(
+      Seq((a > 0.1D, a), (a > 0.2D, a + 0.2D)),
+      Some(a + 0.3D)) === b
+    val query = testRelation1.join(testRelation2, condition = Some(cond))
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+
+    val joinCond = CaseWhen(
+      Seq((a > 0.1D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a))),
+        (a > 0.2D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D)))),
+      Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.3D)))) ===
+        KnownFloatingPointNormalized(NormalizeNaNAndZero(b))
     val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
 
     comparePlans(doubleOptimized, correctAnswer)
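The three new tests pin down the expected plan shape: normalization is pushed into every value-producing child, and each floating-point leaf ends up wrapped as KnownFloatingPointNormalized(NormalizeNaNAndZero(...)). For context on why the wrapping is needed at all: Spark SQL treats 0.0 and -0.0 (and all NaN bit patterns) as the same value, but their raw bits differ, so hash-based joins would otherwise place equal keys in different buckets. A self-contained sketch of that effect in plain Scala; the normalize helper here is a hypothetical stand-in for what NormalizeNaNAndZero does at runtime:

object NegativeZeroSketch {
  def normalize(d: Double): Double =
    if (d.isNaN) Double.NaN   // collapse every NaN payload to one canonical NaN
    else if (d == 0.0d) 0.0d  // maps -0.0 to 0.0 (the two compare equal)
    else d

  def main(args: Array[String]): Unit = {
    import java.lang.Double.doubleToRawLongBits
    println(-0.0d == 0.0d)                          // true: same value
    println(doubleToRawLongBits(-0.0d) ==
      doubleToRawLongBits(0.0d))                    // false: different bits
    println(doubleToRawLongBits(normalize(-0.0d)) ==
      doubleToRawLongBits(normalize(0.0d)))         // true after normalization
  }
}

The first test's expected plan also shows the design choice: the join key becomes coalesce-of-normalized rather than normalize-of-coalesce, since once every child of the Coalesce is normalized, wrapping the Coalesce itself again would be redundant.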