SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes sampling rate > sample_size/total to ensure sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate choice of sampling rate.

Author: Doris Xin <doris.s.xin@gmail.com>
Author: dorx <doris.s.xin@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes apache#916 from dorx/takeSample and squashes the following commits:

5b061ae [Doris Xin] merge master
444e750 [Doris Xin] edge cases
3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939
82dde31 [Xiangrui Meng] update pyspark's takeSample
48d954d [Doris Xin] remove unused imports from RDDSuite
fb1452f [Doris Xin] allowing num to be greater than count in all cases
1481b01 [Doris Xin] washing test tubes and making coffee
dc699f3 [Doris Xin] give back imports removed by accident in rdd.py
64e445b [Doris Xin] logwarnning as soon as it enters the while loop
55518ed [Doris Xin] added TODO for logging in rdd.py
eff89e2 [Doris Xin] addressed reviewer comments.
ecab508 [Doris Xin] "fixed checkstyle violation
0a9b3e3 [Doris Xin] "reviewer comment addressed"
f80f270 [Doris Xin] Merge branch 'master' into takeSample
ae3ad04 [Doris Xin] fixed edge cases to prevent overflow
065ebcd [Doris Xin] Merge branch 'master' into takeSample
9bdd36e [Doris Xin] Check sample size and move computeFraction
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
dorx authored and mengxr committed Jun 13, 2014
1 parent 0154587 commit 1de1d70
Showing 8 changed files with 263 additions and 100 deletions.
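Before the file-by-file diff, a minimal usage sketch of what the refactor means for callers (illustrative only, not part of the commit; assumes a SparkContext `sc` as in spark-shell):

    // Hypothetical session exercising the refactored takeSample.
    // ScaSRS computes a sampling fraction just large enough that a single
    // distributed sampling pass almost always yields >= num elements,
    // replacing the old fixed 3x-multiplier retry scheme.
    val rdd = sc.parallelize(1 to 1000000, 8)

    // Without replacement: 100 distinct elements, order randomized
    val a = rdd.takeSample(withReplacement = false, num = 100, seed = 42L)

    // With replacement: duplicates are possible
    val b = rdd.takeSample(withReplacement = true, num = 100, seed = 42L)

The public signature is unchanged; only the internal sampling strategy differs.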
5 changes: 5 additions & 0 deletions core/pom.xml
@@ -67,6 +67,11 @@
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-lang3</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-math3</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>
52 changes: 31 additions & 21 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Expand Up @@ -42,7 +42,7 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -378,46 +378,56 @@ abstract class RDD[T: ClassTag](
     }.toArray
   }
 
-  def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
-  {
-    var fraction = 0.0
-    var total = 0
-    val multiplier = 3.0
-    val initialCount = this.count()
-    var maxSelected = 0
+  /**
+   * Return a fixed-size sampled subset of this RDD in an array
+   *
+   * @param withReplacement whether sampling is done with replacement
+   * @param num size of the returned sample
+   * @param seed seed for the random number generator
+   * @return sample of specified size in an array
+   */
+  def takeSample(withReplacement: Boolean,
+      num: Int,
+      seed: Long = Utils.random.nextLong): Array[T] = {
+    val numStDev = 10.0
 
     if (num < 0) {
       throw new IllegalArgumentException("Negative number of elements requested")
+    } else if (num == 0) {
+      return new Array[T](0)
     }
 
+    val initialCount = this.count()
     if (initialCount == 0) {
       return new Array[T](0)
     }
 
-    if (initialCount > Integer.MAX_VALUE - 1) {
-      maxSelected = Integer.MAX_VALUE - 1
-    } else {
-      maxSelected = initialCount.toInt
+    val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
+    if (num > maxSampleSize) {
+      throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
+        s"$numStDev * math.sqrt(Int.MaxValue)")
     }
 
-    if (num > initialCount && !withReplacement) {
-      total = maxSelected
-      fraction = multiplier * (maxSelected + 1) / initialCount
-    } else {
-      fraction = multiplier * (num + 1) / initialCount
-      total = num
+    val rand = new Random(seed)
+    if (!withReplacement && num >= initialCount) {
+      return Utils.randomizeInPlace(this.collect(), rand)
     }
 
-    val rand = new Random(seed)
+    val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
+      withReplacement)
+
     var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
 
     // If the first sample didn't turn out large enough, keep trying to take samples;
     // this shouldn't happen often because we use a big multiplier for the initial size
-    while (samples.length < total) {
+    var numIters = 0
+    while (samples.length < num) {
+      logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
       samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+      numIters += 1
     }
 
-    Utils.randomizeInPlace(samples, rand).take(total)
+    Utils.randomizeInPlace(samples, rand).take(num)
   }
 
   /**
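A quick arithmetic check on the new size cap (plain Scala, computed here for exposition; the bound itself is from the patch above):

    // num is capped ~10 standard deviations of a Poisson(Int.MaxValue)
    // count below Int.MaxValue, leaving headroom so the over-sampled
    // intermediate result still fits in an Int.
    val numStDev = 10.0
    val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
    // sqrt(2147483647) ≈ 46340.95, so the cap is
    // 2147483647 - 463409 = 2147020238
    println(maxSampleSize)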
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
   }
 
   /**
-   * Return a sampler with is the complement of the range specified of the current sampler.
+   * Return a sampler that is the complement of the range specified of the current sampler.
    */
   def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
 
55 changes: 55 additions & 0 deletions core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+private[spark] object SamplingUtils {
+
+  /**
+   * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
+   * the time.
+   *
+   * How the sampling rate is determined:
+   * Let p = num / total, where num is the sample size and total is the total number of
+   * datapoints in the RDD. We're trying to compute q > p such that
+   *   - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
+   *     where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
+   *     i.e. the failure rate of not having a sufficiently large sample < 0.0001.
+   *     Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
+   *     num > 12, but we need a slightly larger q (9 empirically determined).
+   *   - when sampling without replacement, we're drawing each datapoint with prob_i
+   *     ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
+   *     rate, where success rate is defined the same as in sampling with replacement.
+   *
+   * @param sampleSizeLowerBound sample size
+   * @param total size of RDD
+   * @param withReplacement whether sampling with replacement
+   * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
+   */
+  def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
+      withReplacement: Boolean): Double = {
+    val fraction = sampleSizeLowerBound.toDouble / total
+    if (withReplacement) {
+      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
+      fraction + numStDev * math.sqrt(fraction / total)
+    } else {
+      val delta = 1e-4
+      val gamma = - math.log(delta) / total
+      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+    }
+  }
+}
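To make the two branches concrete, a worked evaluation of the formulas above (num and total are arbitrary illustrative values, not from the commit):

    // num = 1000 out of total = 1,000,000, so p = 0.001
    val num = 1000
    val total = 1000000L
    val p = num.toDouble / total

    // With replacement: num >= 12, so 5 standard deviations of padding
    val qWith = p + 5 * math.sqrt(p / total)
    // = 0.001 + 5 * sqrt(1e-9) ≈ 0.0011581

    // Without replacement: pad by gamma = -ln(delta) / total
    val delta = 1e-4
    val gamma = -math.log(delta) / total // ≈ 9.2103e-6
    val qWithout = math.min(1, p + gamma + math.sqrt(gamma * gamma + 2 * gamma * p))
    // ≈ 0.0011452

The without-replacement fraction comes out slightly smaller: a Binomial(total, q) sample size concentrates more tightly than the Poisson approximation used for the with-replacement case, so less padding achieves the same 1 - 1e-4 guarantee.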
35 changes: 18 additions & 17 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   }
 
   test("takeSample") {
-    val data = sc.parallelize(1 to 100, 2)
+    val n = 1000000
+    val data = sc.parallelize(1 to n, 2)
 
     for (num <- List(5, 20, 100)) {
       val sample = data.takeSample(withReplacement=false, num=num)
       assert(sample.size === num) // Got exactly num elements
       assert(sample.toSet.size === num) // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20) // Got exactly 20 elements
       assert(sample.toSet.size === 20) // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=false, 200, seed)
+      val sample = data.takeSample(withReplacement=false, 100, seed)
       assert(sample.size === 100) // Got only 100 elements
       assert(sample.toSet.size === 100) // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 20, seed)
       assert(sample.size === 20) // Got exactly 20 elements
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
    }
    {
      val sample = data.takeSample(withReplacement=true, num=20)
      assert(sample.size === 20) // Got exactly 100 elements
      assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
    }
    {
-      val sample = data.takeSample(withReplacement=true, num=100)
-      assert(sample.size === 100) // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, num=n)
+      assert(sample.size === n) // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
    }
    for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 100, seed)
-      assert(sample.size === 100) // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, n, seed)
+      assert(sample.size === n) // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
    }
    for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 200, seed)
-      assert(sample.size === 200) // Got exactly 200 elements
+      val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+      assert(sample.size === 2 * n) // Got exactly 2 * n elements
       // Chance of getting all distinct elements is still quite low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
    }
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+    }
   }
 
46 changes: 46 additions & 0 deletions core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
+import org.scalatest.FunSuite
+
+class SamplingUtilsSuite extends FunSuite {
+
+  test("computeFraction") {
+    // test that the computed fraction guarantees enough data points
+    // in the sample with a failure rate <= 0.0001
+    val n = 100000
+
+    for (s <- 1 to 15) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+      val poisson = new PoissonDistribution(frac * n)
+      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(20, 100, 1000)) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+      val poisson = new PoissonDistribution(frac * n)
+      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(1, 10, 100, 1000)) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
+      val binomial = new BinomialDistribution(n, frac)
+      assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
+    }
+  }
+}
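For readers unfamiliar with commons-math3, the assertions above read as follows (a sketch; PoissonDistribution and inverseCumulativeProbability are real commons-math3 APIs, and SamplingUtils is visible only because the test sits in the same package):

    import org.apache.commons.math3.distribution.PoissonDistribution

    // With fraction q over n elements, the with-replacement sample size is
    // approximately Poisson(q * n). inverseCumulativeProbability(1e-4)
    // gives the 0.01th-percentile sample size, so requiring it to be >= s
    // in effect requires Pr[sample size < s] <= 1e-4.
    val n = 100000
    val s = 1000
    val q = SamplingUtils.computeFractionForSampleSize(s, n, true)
    val poisson = new PoissonDistribution(q * n)
    assert(poisson.inverseCumulativeProbability(1e-4) >= s)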
1 change: 1 addition & 0 deletions project/SparkBuild.scala
@@ -349,6 +349,7 @@ object SparkBuild extends Build {
       libraryDependencies ++= Seq(
         "com.google.guava" % "guava" % "14.0.1",
         "org.apache.commons" % "commons-lang3" % "3.3.2",
+        "org.apache.commons" % "commons-math3" % "3.3" % "test",
         "com.google.code.findbugs" % "jsr305" % "1.3.9",
         "log4j" % "log4j" % "1.2.17",
         "org.slf4j" % "slf4j-api" % slf4jVersion,