Skip to content

Commit

Permalink
[SPARK-7156][SQL] support RandomSplit in DataFrames
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed Apr 28, 2015
1 parent b14cd23 commit e98ebac
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 13 deletions.
20 changes: 18 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,27 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new PartitionwiseSampledRDD[T, T](
this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
randomSampleWithRange(x(0), x(1), seed)
}.toArray
}

/**
* Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
* range.
* @param lb lower bound to use for the Bernoulli sampler
* @param ub upper bound to use for the Bernoulli sampler
* @param seed the seed for the Random number generator
* @return A random sub-sample of the RDD without replacement.
*/
private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
val random = new Random(seed)
this.mapPartitions { partition =>
val sampler = new BernoulliCellSampler[T](lb, ub)
sampler.setSeed(random.nextLong)
sampler.sample(partition)
}
}

/**
* Return a fixed-size sampled subset of this RDD in an array
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,11 @@ package object dsl {
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)

def sample(
fraction: Double,
lb: Double,
ub: Double,
withReplacement: Boolean = true,
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)
Sample(lb, ub, withReplacement, seed, logicalPlan)

// TODO specify the output column names
def generate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,12 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}

case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {
case class Sample(
lb: Double,
ub: Double,
withReplacement: Boolean,
seed: Long,
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output
}
Expand Down
27 changes: 26 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class DataFrame private[sql](
* @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
Sample(fraction, withReplacement, seed, logicalPlan)
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}

/**
Expand All @@ -725,6 +725,31 @@ class DataFrame private[sql](
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1
* @param seed Seed for sampling.
* @group dfops
*/
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
}.toArray
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1
* @group dfops
*/
def randomSplit(weights: Array[Double]): Array[DataFrame] = {
randomSplit(weights, Utils.random.nextLong)
}

/**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,31 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {

/**
* :: DeveloperApi ::
* Sample the dataset.
* @param lb Lower-bound of the sampling probability (usually 0.0)
* @param ub Upper-bound of the sampling probability. The expected fraction sampled will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the QueryPlan
*/
@DeveloperApi
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
case class Sample(
lb: Double,
ub: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan)
extends UnaryNode
{
override def output: Seq[Attribute] = child.output

// TODO: How to pick seed?
override def execute(): RDD[Row] = {
child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
if (withReplacement) {
child.execute().map(_.copy()).sample(withReplacement, ub - lb, seed)
} else {
child.execute().map(_.copy()).randomSampleWithRange(lb, ub, seed)
}
}
}

Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}

test("randomSplit") {
val n = 600
val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array(1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")

assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")

val s = splits.map(_.count())
assert(math.abs(s(0) - 100) < 50) // std = 9.13
assert(math.abs(s(1) - 200) < 50) // std = 11.55
assert(math.abs(s(2) - 300) < 50) // std = 12.25
}
}

test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
&& fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
s"Sampling fraction ($fraction) must be on interval [0, 100]")
Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
relation)
case Token("TOK_TABLEBUCKETSAMPLE",
Token(numerator, Nil) ::
Token(denominator, Nil) :: Nil) =>
val fraction = numerator.toDouble / denominator.toDouble
Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
Expand Down

0 comments on commit e98ebac

Please sign in to comment.