[SPARK-7156][SQL] support RandomSplit in DataFrames #5761

Status: Closed · wants to merge 8 commits · changes from all commits
19 changes: 17 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -407,11 +407,26 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
-   new PartitionwiseSampledRDD[T, T](
-     this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
+   randomSampleWithRange(x(0), x(1), seed)
}.toArray
}

/**
* Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
* range.
* @param lb lower bound to use for the Bernoulli sampler
* @param ub upper bound to use for the Bernoulli sampler
* @param seed the seed for the Random number generator
* @return A random sub-sample of the RDD without replacement.
*/
private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
this.mapPartitionsWithIndex { case (index, partition) =>
val sampler = new BernoulliCellSampler[T](lb, ub)
sampler.setSeed(seed + index)
sampler.sample(partition)
}
}

/**
* Return a fixed-size sampled subset of this RDD in an array
*
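For context on the change above: randomSplit normalizes the weights into cumulative bounds and pairs them into adjacent probability ranges covering [0.0, 1.0]; each split is then sampled from its own range with the shared seed. A minimal sketch of the range computation (plain Scala, no Spark required; the weight values are illustrative):

// Weights Array(1.0, 2.0, 3.0) normalize to cumulative bounds 0.0, 1/6, 0.5, 1.0;
// sliding(2) pairs them into the [lb, ub) ranges handed to randomSampleWithRange.
val weights = Array(1.0, 2.0, 3.0)
val sum = weights.sum
val cumWeights = weights.map(_ / sum).scanLeft(0.0)(_ + _)
val ranges = cumWeights.sliding(2).map(x => (x(0), x(1))).toArray
// ranges == Array((0.0, 0.1666...), (0.1666..., 0.5), (0.5, 1.0))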
8 changes: 4 additions & 4 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -157,11 +157,11 @@ public void sample() {
public void randomSplit() {
List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
JavaRDD<Integer> rdd = sc.parallelize(ints);
-   JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11);
+   JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
    Assert.assertEquals(3, splits.length);
-   Assert.assertEquals(2, splits[0].count());
-   Assert.assertEquals(3, splits[1].count());
-   Assert.assertEquals(5, splits[2].count());
+   Assert.assertEquals(1, splits[0].count());
+   Assert.assertEquals(2, splits[1].count());
+   Assert.assertEquals(7, splits[2].count());
}

@Test
18 changes: 17 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -425,14 +425,30 @@ def distinct(self):
def sample(self, withReplacement, fraction, seed=None):
"""Returns a sampled subset of this :class:`DataFrame`.

-   >>> df.sample(False, 0.5, 97).count()
+   >>> df.sample(False, 0.5, 42).count()
1
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)

def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
Review comment (Contributor): would be great to add params doc


>>> splits = df4.randomSplit([1.0, 2.0], 24)
>>> splits[0].count()
1

>>> splits[1].count()
3
"""
for w in weights:
assert w >= 0.0, "Negative weight value: %s" % w
Review comment (Contributor): raise ValueError( ... )

seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]

@property
def dtypes(self):
"""Returns all column names and their data types as a list.
@@ -278,12 +278,6 @@ package object dsl {
def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)

-   def sample(
-       fraction: Double,
-       withReplacement: Boolean = true,
-       seed: Int = (math.random * 1000).toInt): LogicalPlan =
-     Sample(fraction, withReplacement, seed, logicalPlan)

// TODO specify the output column names
def generate(
generator: Generator,
@@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}

- case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
-   extends UnaryNode {
/**
* Sample the dataset.
*
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
* @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the LogicalPlan
*/
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output
}
38 changes: 37 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -711,7 +711,7 @@ class DataFrame private[sql](
* @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
-   Sample(fraction, withReplacement, seed, logicalPlan)
+   Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}

@@ -725,6 +725,42 @@
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
* @group dfops
*/
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
}.toArray
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @group dfops
*/
def randomSplit(weights: Array[Double]): Array[DataFrame] = {
randomSplit(weights, Utils.random.nextLong)
}

/**
* Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python API.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
* @group dfops
*/
def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
Review comment (Contributor): private[spark] ? since it is only used for python.

randomSplit(weights.toArray, seed)
}

/**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
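To make the new DataFrame API above concrete, here is a hedged usage sketch (it assumes a Spark 1.x SQLContext named sqlContext with its implicits imported; the names and numbers are illustrative, not part of the patch):

import sqlContext.implicits._

val df = sqlContext.sparkContext.parallelize(1 to 100).toDF("id")
// 70/30 split; the weights are normalized, and passing a seed makes the split reproducible.
val Array(train, test) = df.randomSplit(Array(0.7, 0.3), seed = 12345L)
// The splits are disjoint and together cover the original DataFrame.
assert(train.count() + test.count() == df.count())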
@@ -300,8 +300,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
-     case logical.Sample(fraction, withReplacement, seed, child) =>
-       execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+     case logical.Sample(lb, ub, withReplacement, seed, child) =>
+       execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
@@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {

/**
* :: DeveloperApi ::
* Sample the dataset.
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
* @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the QueryPlan
*/
@DeveloperApi
- case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan)
extends UnaryNode
{
override def output: Seq[Attribute] = child.output

// TODO: How to pick seed?
override def execute(): RDD[Row] = {
-     child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+     if (withReplacement) {
+       child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+     } else {
+       child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+     }
}
}

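The without-replacement branch above relies on a property of BernoulliCellSampler: with a shared seed, each row effectively receives one uniform draw u, and a Sample with bounds [lb, ub) keeps the row iff lb <= u < ub. A simplified, Spark-free model of that behaviour (an assumption-level sketch, not the actual sampler implementation):

import scala.util.Random

val rows = (1 to 10).toArray
// Each split keeps the rows whose draw falls in its half-open range.
def takeRange(lb: Double, ub: Double, seed: Long): Array[Int] = {
  val rng = new Random(seed)   // same seed => the same draw for each row position
  rows.filter { _ => val u = rng.nextDouble(); u >= lb && u < ub }
}
val first = takeRange(0.0, 0.5, seed = 42L)
val second = takeRange(0.5, 1.0, seed = 42L)
// Disjoint and exhaustive: together the two ranges reproduce every row exactly once.
assert((first ++ second).sorted.sameElements(rows))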
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}

test("randomSplit") {
val n = 600
val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")

assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")

val s = splits.map(_.count())
assert(math.abs(s(0) - 100) < 50) // std = 9.13
assert(math.abs(s(1) - 200) < 50) // std = 11.55
assert(math.abs(s(2) - 300) < 50) // std = 12.25
}
}

test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),
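The tolerances in the randomSplit test above follow from a binomial model: if each of the n = 600 rows independently lands in a split of probability p, the split size has standard deviation sqrt(n * p * (1 - p)). A quick check of the values quoted in the test's comments:

val n = 600.0
val stds = Seq(1.0 / 6, 2.0 / 6, 3.0 / 6).map(p => math.sqrt(n * p * (1 - p)))
// stds ≈ Seq(9.13, 11.55, 12.25), so the tolerance of 50 sits roughly four or more standard deviations out.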
@@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
&& fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
s"Sampling fraction ($fraction) must be on interval [0, 100]")
-       Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
+       Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
relation)
case Token("TOK_TABLEBUCKETSAMPLE",
Token(numerator, Nil) ::
Token(denominator, Nil) :: Nil) =>
val fraction = numerator.toDouble / denominator.toDouble
-       Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+       Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :