From 07886917b71ad0b23fbe68253a568d29882a21b1 Mon Sep 17 00:00:00 2001 From: Cody Koeninger Date: Thu, 11 Sep 2014 17:05:21 -0500 Subject: [PATCH] SPARK-3462 add tests, handle compatible schema with different aliases, per marmbrus feedback --- .../sql/catalyst/optimizer/Optimizer.scala | 49 ++++++++++----- .../optimizer/UnionPushdownSuite.scala | 62 +++++++++++++++++++ 2 files changed, 95 insertions(+), 16 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 82b0ccc01695e..ba4d58e0acaf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -47,32 +47,49 @@ object Optimizer extends RuleExecutor[LogicalPlan] { ColumnPruning) :: Nil } +/** + * Pushes operations to either side of a Union. + */ object UnionPushdown extends Rule[LogicalPlan] { - def fixProject(project: Project, child: LogicalPlan): Project = { - val pl = project.projectList.map(p => child.output.find { c => - c.name == p.name && c.qualifiers == p.qualifiers - }.getOrElse(p)) - Project(pl, child) + + /** + * Maps Attributes from the left side to the corresponding Attribute on the right side. + */ + def buildRewrites(union: Union): AttributeMap[Attribute] = { + assert(union.left.output.size == union.right.output.size) + + AttributeMap(union.left.output.zip(union.right.output)) } - def fixFilter(filter: Filter, child: LogicalPlan): Filter = { - val cond = filter.condition.transform { - case a: AttributeReference => - child.output.find { c => - c.name == a.name && c.qualifiers == a.qualifiers - }.getOrElse(a) + /** + * Rewrites an expression so that it can be pushed to the right side of a Union operator. + * This method relies on the fact that the output attributes of a union are always equal + * to the left child's output. + */ + def pushToRight[A <: Expression](e: A, union: Union, rewrites: AttributeMap[Attribute]): A = { + val result = e transform { + case a: Attribute => rewrites(a) } - Filter(cond, child) + + // We must promise the compiler that we did not discard the names in the case of project + // expressions. This is safe since the only transformation is from Attribute => Attribute. + result.asInstanceOf[A] } def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down filter into union - case f @ Filter(_, Union(left, right)) => - Union(fixFilter(f, left), fixFilter(f, right)) + case Filter(condition, u @ Union(left, right)) => + val rewrites = buildRewrites(u) + Union( + Filter(condition, left), + Filter(pushToRight(condition, u, rewrites), right)) // Push down projection into union - case p @ Project(_, Union(left, right)) => - Union(fixProject(p, left), fixProject(p, right)) + case Project(projectList, u @ Union(left, right)) => + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, u, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala new file mode 100644 index 0000000000000..dfef87bd9133d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class UnionPushdownSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateAnalysisOperators) :: + Batch("Union Pushdown", Once, + UnionPushdown) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testUnion = Union(testRelation, testRelation2) + + test("union: filter to each side") { + val query = testUnion.where('a === 1) + + val optimized = Optimize(query.analyze) + + val correctAnswer = + Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("union: project to each side") { + val query = testUnion.select('b) + + val optimized = Optimize(query.analyze) + + val correctAnswer = + Union(testRelation.select('b), testRelation2.select('e)).analyze + + comparePlans(optimized, correctAnswer) + } +}