diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3074af3ed2e83..5908ebc990a56 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -437,6 +437,10 @@ def sample(self, withReplacement, fraction, seed=None): def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. + :param weights: list of doubles as weights with which to split the DataFrame. Weights will + be normalized if they don't sum up to 1.0. + :param seed: The seed for sampling. + >>> splits = df4.randomSplit([1.0, 2.0], 24) >>> splits[0].count() 1 @@ -445,7 +449,8 @@ def randomSplit(self, weights, seed=None): 3 """ for w in weights: - assert w >= 0.0, "Negative weight value: %s" % w + if w < 0.0: + raise ValueError("Weights must be positive. Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 0d02e14c21be0..2669300029545 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -752,7 +752,7 @@ class DataFrame private[sql]( * @param seed Seed for sampling. * @group dfops */ - def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { + private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { randomSplit(weights.toArray, seed) }