From 705d9631f0d5f1942219fb1e3beff7280c0b0dee Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 5 Aug 2014 15:46:52 -0700 Subject: [PATCH] Add a rule for resolving ORDER BY expressions that reference attributes not present in the SELECT clause. --- .../sql/catalyst/analysis/Analyzer.scala | 46 +++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 50 +++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4913b54535419..4e5f7fd78d4f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: + ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: StarExpansion :: @@ -120,6 +121,51 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT + * clause. This rule detects such queries and adds the required attributes to the original + * projection, so that they will be available during sorting. Another projection is added to + * remove these attributes after sorting. 
+   *
+   * Two plan shapes are rewritten, in both cases only while the `Sort` is still unresolved and
+   * the node directly beneath it is already resolved:
+   *  - `Sort` over `Project`: attributes named in the ordering are looked up with
+   *    `child.resolveChildren`, and the ones absent from the projection's output are appended
+   *    to the inner projection.
+   *  - `Sort` over `Aggregate`: attributes named in the ordering are looked up among the grouping
+   *    expressions that are named expressions, and the ones absent from the aggregate's output
+   *    are appended to the aggregate expressions.
+   *
+   * In either case an outer projection keeps only the original output attributes. When nothing
+   * is missing the `Sort` is returned unchanged.
+   */
+  object ResolveSortReferences extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+      case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
+        val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+        val resolved = unresolved.flatMap(child.resolveChildren)
+        val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+
+        val missingInProject = requiredAttributes -- p.output
+        if (missingInProject.nonEmpty) {
+          // Add missing attributes and then project them away after the sort. The outer
+          // projection selects `p.output` (plain attributes) rather than repeating `projectList`:
+          // repeating it would evaluate aliased or computed expressions a second time, on top of
+          // an inner projection that need not expose the inputs those expressions read. This also
+          // mirrors the Aggregate case below, which projects `a.output`.
+          Project(p.output,
+            Sort(ordering,
+              Project(projectList ++ missingInProject, child)))
+        } else {
+          s // Nothing we can do here. Return original plan.
+        }
+      case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
+        val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+        // A small hack to create an object that will allow us to resolve any references that
+        // refer to named expressions that are present in the grouping expressions.
+        val groupingRelation = LocalRelation(
+          grouping.collect { case ne: NamedExpression => ne.toAttribute }
+        )
+
+        // Diagnostics only. Kept at debug level so that analysing an ordinary query does not
+        // emit warnings each time this case matches.
+        logDebug(s"Grouping expressions: $groupingRelation")
+        val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
+        val missingInAggs = resolved -- a.outputSet
+        logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
+        if (missingInAggs.nonEmpty) {
+          // Add missing grouping exprs and then project them away after the sort.
+          Project(a.output,
+            Sort(ordering,
+              Aggregate(grouping, aggs ++ missingInAggs, child)))
+        } else {
+          s // Nothing we can do here. Return original plan.
+        }
+    }
+  }
+
   /**
    * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]].
*/ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala new file mode 100644 index 0000000000000..959b574e490fc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +/** + * A collection of hive query tests where we generate the answers ourselves instead of depending on + * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is + * valid, but Hive currently cannot execute it. 
+ */
+class SQLQuerySuite extends QueryTest {
+  // In both tests the expected answer is produced by a hand-written equivalent query: the
+  // ordering column is selected and sorted inside a subquery, and the outer SELECT drops it.
+
+  // ORDER BY a column that the plain SELECT list does not include.
+  test("ordering not in select") {
+    checkAnswer(
+      sql("SELECT key FROM src ORDER BY value"),
+      sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq)
+  }
+
+  // ORDER BY a grouping column that the aggregate's SELECT list does not include.
+  test("ordering not in agg") {
+    checkAnswer(
+      sql("SELECT key FROM src GROUP BY key, value ORDER BY value"),
+      sql("""
+        SELECT key
+        FROM (
+          SELECT key, value
+          FROM src
+          GROUP BY key, value
+          ORDER BY value) a""").collect().toSeq)
+  }
+}