Skip to content

Commit

Permalink
[SPARK-2042] Prevent unnecessary shuffle triggered by take()
Browse files Browse the repository at this point in the history
This PR implements `take()` on a `SchemaRDD` by inserting a logical limit that is followed by a `collect()`. This is also accompanied by adding a catalyst optimizer rule for collapsing adjacent limits. Doing so prevents an unnecessary shuffle that is sometimes triggered by `take()`.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #1048 from sameeragarwal/master and squashes the following commits:

3eeb848 [Sameer Agarwal] Fixing Tests
1b76ff1 [Sameer Agarwal] Deprecating limit(limitExpr: Expression) in v1.1.0
b723ac4 [Sameer Agarwal] Added limit folding tests
a0ff7c4 [Sameer Agarwal] Adding catalyst rule to fold two consecutive limits
8d42d03 [Sameer Agarwal] Implement trigger() as limit() followed by collect()
  • Loading branch information
sameeragarwal authored and marmbrus committed Jun 11, 2014
1 parent 4d5c12a commit 4107cce
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ package object dsl {

/** Wraps `logicalPlan` in a `Filter` node that uses `condition` as its predicate. */
def where(condition: Expression) = Filter(condition, logicalPlan)

/** Wraps `logicalPlan` in a `Limit` node whose limit is given by `limitExpr`. */
def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)

def join(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.types._

object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Combine Limits", FixedPoint(100),
CombineLimits) ::
Batch("ConstantFolding", FixedPoint(100),
NullPropagation,
ConstantFolding,
Expand Down Expand Up @@ -362,3 +364,14 @@ object SimplifyCasts extends Rule[LogicalPlan] {
case Cast(e, dataType) if e.dataType == dataType => e
}
}

/**
 * Collapses a [[catalyst.plans.logical.Limit Limit]] that sits directly on top of another
 * [[catalyst.plans.logical.Limit Limit]] into a single operator over the inner limit's child.
 * The merged limit expression is `If(LessThan(inner, outer), inner, outer)`, i.e. the inner
 * limit is used only when it is strictly smaller than the outer one.
 */
object CombineLimits extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Limit(outerLimit, Limit(innerLimit, source)) =>
      val mergedLimit = If(LessThan(innerLimit, outerLimit), innerLimit, outerLimit)
      Limit(mergedLimit, source)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ case class Aggregate(
def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
}

case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
def output = child.output
def references = limit.references
def references = limitExpr.references
}

case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

/**
 * Checks that [[CombineLimits]], followed by the folding rules listed below, reduces a stack of
 * adjacent limits over the same projection to a single limit.
 */
class CombiningLimitsSuite extends OptimizerTest {

  /** Rule executor restricted to limit combining and the folding rules run after it. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Combine Limit", FixedPoint(2), CombineLimits) ::
      Batch("Constant Folding", FixedPoint(3),
        NullPropagation, ConstantFolding, BooleanSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    // Outer limit (5) is smaller than the inner one (10), so a single limit(5) is expected.
    val stackedLimits = testRelation.select('a).limit(10).limit(5)

    val optimizedPlan = Optimize(stackedLimits.analyze)
    val expectedPlan = testRelation.select('a).limit(5).analyze

    comparePlans(optimizedPlan, expectedPlan)
  }

  test("limits: combines three limits") {
    // The innermost limit (2) is the smallest of the three, so a single limit(2) is expected.
    val stackedLimits = testRelation.select('a).limit(2).limit(7).limit(5)

    val optimizedPlan = Optimize(stackedLimits.analyze)
    val expectedPlan = testRelation.select('a).limit(2).analyze

    comparePlans(optimizedPlan, expectedPlan)
  }
}
12 changes: 9 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,18 @@ class SchemaRDD(
def orderBy(sortExprs: SortOrder*): SchemaRDD = {
  // Wrap the current logical plan in a Sort node and expose it as a new SchemaRDD.
  val sortedPlan = Sort(sortExprs, logicalPlan)
  new SchemaRDD(sqlContext, sortedPlan)
}

/**
 * Limits the results by the given expression. Deprecated in favour of the overload that
 * takes an integer.
 */
@deprecated("use limit with integer argument", "1.1.0")
def limit(limitExpr: Expression): SchemaRDD = {
  val limitedPlan = Limit(limitExpr, logicalPlan)
  new SchemaRDD(sqlContext, limitedPlan)
}

/**
* Limits the results by the given expressions.
* Limits the results by the given integer.
* {{{
* schemaRDD.limit(10)
* }}}
*/
def limit(limitExpr: Expression): SchemaRDD =
new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
def limit(limitNum: Int): SchemaRDD = {
  // The integer is wrapped in a Literal so it can serve as the Limit node's expression.
  val limitExpr = Literal(limitNum)
  val limitedPlan = Limit(limitExpr, logicalPlan)
  new SchemaRDD(sqlContext, limitedPlan)
}

/**
* Performs a grouping followed by an aggregation.
Expand Down Expand Up @@ -374,6 +378,8 @@ class SchemaRDD(

override def collect(): Array[Row] = {
  val plan = queryExecution.executedPlan
  plan.executeCollect()
}

override def take(num: Int): Array[Row] = {
  // Express the row cap as a limit on this SchemaRDD, then collect the limited result.
  val limited = limit(num)
  limited.collect()
}

// =======================================================================
// Base RDD functions that do NOT change schema
// =======================================================================
Expand Down

0 comments on commit 4107cce

Please sign in to comment.