Skip to content

Commit

Permalink
update pyspark's takeSample
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jun 13, 2014
1 parent 48d954d commit 82dde31
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,44 +362,50 @@ def takeSample(self, withReplacement, num, seed=None):
Return a fixed-size sampled subset of this RDD (currently requires
numpy).
>>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
[4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
>>> rdd = sc.parallelize(range(0, 10))
>>> len(rdd.takeSample(True, 20, 1))
20
>>> len(rdd.takeSample(False, 5, 2))
5
>>> len(rdd.takeSample(False, 15, 3))
10
"""

numStDev = 10.0
initialCount = self.count()

if num < 0:
raise ValueError
raise ValueError("Sample size cannot be negative.")
elif num == 0:
return []

if initialCount == 0 or num == 0:
return list()
initialCount = self.count()
if initialCount == 0:
return []

rand = Random(seed)
if (not withReplacement) and num > initialCount:

if (not withReplacement) and num >= initialCount:
# shuffle current RDD and return
samples = self.collect()
fraction = float(num) / initialCount
num = initialCount
else:
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
if num > maxSampleSize:
raise ValueError

fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)
rand.shuffle(samples)
return samples

numStDev = 10.0
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
if num > maxSampleSize:
raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)

fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
samples = self.sample(withReplacement, fraction, seed).collect()

# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
while len(samples) < num:
# TODO: add log warning for when more than one iteration was run
seed = rand.randint(0, sys.maxint)
samples = self.sample(withReplacement, fraction, seed).collect()

# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
while len(samples) < num:
#TODO add log warning for when more than one iteration was run
seed = rand.randint(0, sys.maxint)
samples = self.sample(withReplacement, fraction, seed).collect()
rand.shuffle(samples)

sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
sampler.shuffle(samples)
return samples[0:num]

@staticmethod
Expand Down

0 comments on commit 82dde31

Please sign in to comment.