Fix logical conflict with PR apache#15703
liancheng committed Nov 22, 2016
1 parent 4b88eed commit 6db5af9
Showing 1 changed file with 62 additions and 73 deletions.
@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
      // A TypedImperativeAggregate function
      val typed = percentile_approx($"c0", 0.5)

-      // A Hive UDAF without partial aggregation support
-      val withoutPartial = function("hive_max", $"c1")
-
      // A Spark SQL native aggregate function with partial aggregation support that can be executed
      // by the Tungsten `HashAggregateExec`
-      val withPartialUnsafe = max($"c2")
+      val withPartialUnsafe = max($"c1")

      // A Spark SQL native aggregate function with partial aggregation support that can only be
      // executed by the `SortAggregateExec`
-      val withPartialSafe = max($"c3")
+      val withPartialSafe = max($"c2")

      // A Spark SQL native distinct aggregate function
-      val withDistinct = countDistinct($"c4")
+      val withDistinct = countDistinct($"c3")

      val allAggs = Seq(
        "typed" -> typed,
-        "without partial" -> withoutPartial,
        "with partial + unsafe" -> withPartialUnsafe,
        "with partial + safe" -> withPartialSafe,
        "with distinct" -> withDistinct
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
      // Generates a random schema for the randomized data generator
      val schema = new StructType()
        .add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true)
-        .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true)
-        .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
-        .add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
-        .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
+        .add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
+        .add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
+        .add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true)

      logInfo(
        s"""Using the following random schema to generate all the randomized aggregation tests:
@@ -325,69 +320,67 @@ class ObjectHashAggregateSuite

-      // Currently Spark SQL doesn't support evaluating distinct aggregate function together
-      // with aggregate functions without partial aggregation support.
-      if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
-        test(
-          s"randomized aggregation test - " +
-            s"${names.mkString("[", ", ", "]")} - " +
-            s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
-            s"with ${if (emptyInput) "empty" else "non-empty"} input"
-        ) {
-          var expected: Seq[Row] = null
-          var actual1: Seq[Row] = null
-          var actual2: Seq[Row] = null
-
-          // Disables `ObjectHashAggregateExec` to obtain a standard answer
-          withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
-            val aggDf = doAggregation(df)
-
-            if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
-              assert(containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(!containsHashAggregateExec(aggDf))
-            } else {
-              assert(!containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(containsHashAggregateExec(aggDf))
-            }
-
-            expected = aggDf.collect().toSeq
-          }
-
-          // Enables `ObjectHashAggregateExec`
-          withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
-            val aggDf = doAggregation(df)
-
-            if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
-              assert(!containsSortAggregateExec(aggDf))
-              assert(containsObjectHashAggregateExec(aggDf))
-              assert(!containsHashAggregateExec(aggDf))
-            } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
-              assert(containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(!containsHashAggregateExec(aggDf))
-            } else {
-              assert(!containsSortAggregateExec(aggDf))
-              assert(!containsObjectHashAggregateExec(aggDf))
-              assert(containsHashAggregateExec(aggDf))
-            }
-
-            // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
-            // big enough) to obtain a result to be checked.
-            withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
-              actual1 = aggDf.collect().toSeq
-            }
-
-            // Enables sort-based aggregation fallback to obtain another result to be checked.
-            withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
-              // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
-              // cached and won't be re-planned using the new fallback threshold.
-              actual2 = doAggregation(df).collect().toSeq
-            }
-          }
-
-          doubleSafeCheckRows(actual1, expected, 1e-4)
-          doubleSafeCheckRows(actual2, expected, 1e-4)
-        }
-      }
+      test(
+        s"randomized aggregation test - " +
+          s"${names.mkString("[", ", ", "]")} - " +
+          s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
+          s"with ${if (emptyInput) "empty" else "non-empty"} input"
+      ) {
+        var expected: Seq[Row] = null
+        var actual1: Seq[Row] = null
+        var actual2: Seq[Row] = null
+
+        // Disables `ObjectHashAggregateExec` to obtain a standard answer
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+          val aggDf = doAggregation(df)
+
+          if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) {
+            assert(containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(!containsHashAggregateExec(aggDf))
+          } else {
+            assert(!containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(containsHashAggregateExec(aggDf))
+          }
+
+          expected = aggDf.collect().toSeq
+        }
+
+        // Enables `ObjectHashAggregateExec`
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+          val aggDf = doAggregation(df)
+
+          if (aggs.contains(typed)) {
+            assert(!containsSortAggregateExec(aggDf))
+            assert(containsObjectHashAggregateExec(aggDf))
+            assert(!containsHashAggregateExec(aggDf))
+          } else if (aggs.contains(withPartialSafe)) {
+            assert(containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(!containsHashAggregateExec(aggDf))
+          } else {
+            assert(!containsSortAggregateExec(aggDf))
+            assert(!containsObjectHashAggregateExec(aggDf))
+            assert(containsHashAggregateExec(aggDf))
+          }
+
+          // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
+          // big enough) to obtain a result to be checked.
+          withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+            actual1 = aggDf.collect().toSeq
+          }
+
+          // Enables sort-based aggregation fallback to obtain another result to be checked.
+          withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
+            // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
+            // cached and won't be re-planned using the new fallback threshold.
+            actual2 = doAggregation(df).collect().toSeq
+          }
+        }
+
+        doubleSafeCheckRows(actual1, expected, 1e-4)
+        doubleSafeCheckRows(actual2, expected, 1e-4)
+      }
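
Note (added for context): the `containsSortAggregateExec` / `containsObjectHashAggregateExec` / `containsHashAggregateExec` predicates asserted above are small plan-inspection helpers. A plausible reconstruction — a sketch, not necessarily the suite's exact code:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.aggregate.{
  HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}

// Sketch: walk the executed physical plan and check whether a node of the
// given operator type occurs anywhere in the tree.
def containsSortAggregateExec(df: DataFrame): Boolean =
  df.queryExecution.executedPlan.collect { case e: SortAggregateExec => e }.nonEmpty

def containsObjectHashAggregateExec(df: DataFrame): Boolean =
  df.queryExecution.executedPlan.collect { case e: ObjectHashAggregateExec => e }.nonEmpty

def containsHashAggregateExec(df: DataFrame): Boolean =
  df.queryExecution.executedPlan.collect { case e: HashAggregateExec => e }.nonEmpty
```

This also makes the "not reusing `aggDf`" comment concrete: `executedPlan` is computed once per `Dataset`, so a changed fallback threshold only takes effect on a freshly planned `doAggregation(df)`.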
@@ -424,10 +417,6 @@ class ObjectHashAggregateSuite
    }
  }

-  private def function(name: String, args: Column*): Column = {
-    Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
-  }
-
  test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") {
    // SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
    // certain aggregate functions. To reproduce this issue, the following conditions must be
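
Note (added for context): the deleted `function` helper existed only to invoke a Hive UDAF by name via Catalyst's `UnresolvedFunction`; once the "without partial" case was dropped it had no remaining callers. Roughly the same effect is available through the public API — a sketch assuming a `hive_max` UDAF has been registered with the session:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.callUDF

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._

// Assumes prior registration of the UDAF, e.g.:
// spark.sql("CREATE TEMPORARY FUNCTION hive_max AS 'SomeHiveUDAFClass'")
val df = Seq(1, 2, 3).toDF("c1")

// callUDF resolves the function by name at analysis time, much like the
// removed UnresolvedFunction-based helper did.
val result = df.select(callUDF("hive_max", $"c1"))
```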
