Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-7156][SQL] support RandomSplit in DataFrames #5761

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,27 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new PartitionwiseSampledRDD[T, T](
this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
randomSampleWithRange(x(0), x(1), seed)
}.toArray
}

/**
* Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
* range.
* @param lb lower bound to use for the Bernoulli sampler
* @param ub upper bound to use for the Bernoulli sampler
* @param seed the seed for the Random number generator
* @return A random sub-sample of the RDD without replacement.
*/
private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
val random = new Random(seed)
this.mapPartitions { partition =>
val sampler = new BernoulliCellSampler[T](lb, ub)
sampler.setSeed(random.nextLong)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the partition id here so each partition has its own seed

sampler.sample(partition)
}
}

/**
* Return a fixed-size sampled subset of this RDD in an array
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,11 @@ package object dsl {
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)

def sample(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just remove this since it is not used

fraction: Double,
lb: Double,
ub: Double,
withReplacement: Boolean = true,
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)
Sample(lb, ub, withReplacement, seed, logicalPlan)

// TODO specify the output column names
def generate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,12 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}

case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {
case class Sample(
lb: Double,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lowerBound, upperBound and document this

ub: Double,
withReplacement: Boolean,
seed: Long,
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output
}
Expand Down
27 changes: 26 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class DataFrame private[sql](
* @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
Sample(fraction, withReplacement, seed, logicalPlan)
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}

/**
Expand All @@ -725,6 +725,31 @@ class DataFrame private[sql](
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1
* @param seed Seed for sampling.
* @group dfops
*/
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
}.toArray
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1
* @group dfops
*/
def randomSplit(weights: Array[Double]): Array[DataFrame] = {
randomSplit(weights, Utils.random.nextLong)
}

/**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,31 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {

/**
* :: DeveloperApi ::
* Sample the dataset.
* @param lb Lower-bound of the sampling probability (usually 0.0)
* @param ub Upper-bound of the sampling probability. The expected fraction sampled will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the QueryPlan
*/
@DeveloperApi
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
case class Sample(
lb: Double,
ub: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan)
extends UnaryNode
{
override def output: Seq[Attribute] = child.output

// TODO: How to pick seed?
override def execute(): RDD[Row] = {
child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
if (withReplacement) {
child.execute().map(_.copy()).sample(withReplacement, ub - lb, seed)
} else {
child.execute().map(_.copy()).randomSampleWithRange(lb, ub, seed)
}
}
}

Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}

test("randomSplit") {
val n = 600
val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array(1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")

assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")

val s = splits.map(_.count())
assert(math.abs(s(0) - 100) < 50) // std = 9.13
assert(math.abs(s(1) - 200) < 50) // std = 11.55
assert(math.abs(s(2) - 300) < 50) // std = 12.25
}
}

test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
&& fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
s"Sampling fraction ($fraction) must be on interval [0, 100]")
Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
relation)
case Token("TOK_TABLEBUCKETSAMPLE",
Token(numerator, Nil) ::
Token(denominator, Nil) :: Nil) =>
val fraction = numerator.toDouble / denominator.toDouble
Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
Expand Down