From ae3ad049161f470549510daf2eefdbe576fb01e8 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Tue, 10 Jun 2014 12:02:34 -0700 Subject: [PATCH] fixed edge cases to prevent overflow --- .../main/scala/org/apache/spark/rdd/RDD.scala | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 39933ec340356..fb4c3c6ebd6cc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -391,41 +391,31 @@ abstract class RDD[T: ClassTag]( seed: Long = Utils.random.nextLong): Array[T] = { var fraction = 0.0 var total = 0 - val multiplier = 3.0 val initialCount = this.count() - var maxSelected = 0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") } + if (initialCount == 0) { + return new Array[T](0) + } + if (!withReplacement && num > initialCount) { throw new IllegalArgumentException("Cannot create sample larger than the original when " + "sampling without replacement") } - if (initialCount == 0) { - return new Array[T](0) - } - if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt + val maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt if (num > maxSelected) { throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " + "5.0 * math.sqrt(Integer.MAX_VALUE)") } - } else { - maxSelected = initialCount.toInt } - if (num > initialCount && !withReplacement) { - // special case not covered in computeFraction - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - } else { - fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement) - total = num - } + fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement) + total = num val rand = new Random(seed) var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()