Skip to content

Commit

Permalink
fixed edge cases to prevent overflow
Browse files: browse the repository at this point in the history
  • Loading branch information
dorx committed Jun 10, 2014
1 parent 065ebcd commit ae3ad04
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -391,41 +391,31 @@ abstract class RDD[T: ClassTag](
seed: Long = Utils.random.nextLong): Array[T] = {
var fraction = 0.0
var total = 0
val multiplier = 3.0
val initialCount = this.count()
var maxSelected = 0

if (num < 0) {
throw new IllegalArgumentException("Negative number of elements requested")
}

if (initialCount == 0) {
return new Array[T](0)
}

if (!withReplacement && num > initialCount) {
throw new IllegalArgumentException("Cannot create sample larger than the original when " +
"sampling without replacement")
}

if (initialCount == 0) {
return new Array[T](0)
}

if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt
val maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt
if (num > maxSelected) {
throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " +
"5.0 * math.sqrt(Integer.MAX_VALUE)")
}
} else {
maxSelected = initialCount.toInt
}

if (num > initialCount && !withReplacement) {
// special case not covered in computeFraction
total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
} else {
fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement)
total = num
}
fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement)
total = num

val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
Expand Down

0 comments on commit ae3ad04

Please sign in to comment.