Skip to content

Commit

Permalink
Addressed follow up comments
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed Apr 29, 2015
1 parent f8cbb0a commit 1ea456f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 6 additions & 1 deletion python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ def sample(self, withReplacement, fraction, seed=None):
def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
:param weights: list of doubles with which to split the DataFrame into. Weights will be
normalized if they don't sum up to 1.0.
:param seed: The seed for sampling.
>>> splits = df4.randomSplit([1.0, 2.0], 24)
>>> splits[0].count()
1
Expand All @@ -445,7 +449,8 @@ def randomSplit(self, weights, seed=None):
3
"""
# Validate eagerly with ValueError rather than assert (asserts are stripped under -O).
for w in weights:
assert w >= 0.0, "Negative weight value: %s" % w
if w < 0.0:
# NOTE(review): w == 0.0 passes this check, yet the message says
# "positive" — confirm whether zero weights are intentionally allowed.
raise ValueError("Weights must be positive. Found weight value: %s" % w)
# Pick a concrete seed when none is supplied so the JVM call always
# receives a definite value (sys.maxsize keeps it in signed-64-bit range).
seed = seed if seed is not None else random.randint(0, sys.maxsize)
# Delegate the actual split to the JVM-side DataFrame, then wrap each
# returned JVM frame back into a Python DataFrame.
rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ class DataFrame private[sql](
* @param seed Seed for sampling.
* @group dfops
*/
// NOTE(review): visibility tightened from public to private[spark] —
// presumably so the List[Double] overload stays Spark-internal while the
// Array[Double] overload remains the public API; confirm against callers.
def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
  // Convert the Scala List to an Array and delegate to the primary overload.
  randomSplit(weights.toArray, seed)
}

Expand Down

0 comments on commit 1ea456f

Please sign in to comment.