Skip to content

Commit

Permalink
[SPARK-21332][SQL] Incorrect result type inferred for some decimal ex…
Browse files Browse the repository at this point in the history
…pressions

This PR changes the direction of expression transformation in the DecimalPrecision rule. Previously, the expressions were transformed down, which led to incorrect result types when decimal expressions had other decimal expressions as their operands. The root cause of this issue was in visiting outer nodes before their children. Consider the example below:

```
    val inputSchema = StructType(StructField("col", DecimalType(26, 6)) :: Nil)
    val sc = spark.sparkContext
    val rdd = sc.parallelize(1 to 2).map(_ => Row(BigDecimal(12)))
    val df = spark.createDataFrame(rdd, inputSchema)

    // Works correctly since no nested decimal expression is involved
    // Expected result type: (26, 6) * (26, 6) = (38, 12)
    df.select($"col" * $"col").explain(true)
    df.select($"col" * $"col").printSchema()

    // Gives a wrong result since there is a nested decimal expression that should be visited first
    // Expected result type: ((26, 6) * (26, 6)) * (26, 6) = (38, 12) * (26, 6) = (38, 18)
    df.select($"col" * $"col" * $"col").explain(true)
    df.select($"col" * $"col" * $"col").printSchema()
```

The example above gives the following output:

```
// Correct result without sub-expressions
== Parsed Logical Plan ==
'Project [('col * 'col) AS (col * col)#4]
+- LogicalRDD [col#1]

== Analyzed Logical Plan ==
(col * col): decimal(38,12)
Project [CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS (col * col)#4]
+- LogicalRDD [col#1]

== Optimized Logical Plan ==
Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4]
+- LogicalRDD [col#1]

== Physical Plan ==
*Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4]
+- Scan ExistingRDD[col#1]

// Schema
root
 |-- (col * col): decimal(38,12) (nullable = true)

// Incorrect result with sub-expressions
== Parsed Logical Plan ==
'Project [(('col * 'col) * 'col) AS ((col * col) * col)apache#11]
+- LogicalRDD [col#1]

== Analyzed Logical Plan ==
((col * col) * col): decimal(38,12)
Project [CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS ((col * col) * col)apache#11]
+- LogicalRDD [col#1]

== Optimized Logical Plan ==
Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)apache#11]
+- LogicalRDD [col#1]

== Physical Plan ==
*Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)apache#11]
+- Scan ExistingRDD[col#1]

// Schema
root
 |-- ((col * col) * col): decimal(38,12) (nullable = true)
```

This PR was tested with available unit tests. Moreover, there are tests to cover previously failing scenarios.

Author: aokolnychyi <anton.okolnychyi@sap.com>

Closes apache#18583 from aokolnychyi/spark-21332.

(cherry picked from commit 0be5fb4)
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
  • Loading branch information
aokolnychyi authored and Vinitha Gankidi committed Nov 29, 2017
1 parent 55d5bbe commit 29e60e3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// fix decimal precision for expressions
case q => q.transformExpressions(
case q => q.transformExpressionsUp(
decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,14 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
checkType(Average(d1), DecimalType(6, 5))

checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
checkType(Sum(Add(d1, d1)), DecimalType(13, 1))
}

test("Comparison operations") {
Expand Down

0 comments on commit 29e60e3

Please sign in to comment.