diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index faa0f6f407b34..217674019e044 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -172,7 +172,7 @@ class StringIndexerModel (
       case _ => dataset
     }
     filteredDataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
+      indexer(filteredDataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
   }
 
   override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 49803aef71587..9e59503b4ccdb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -116,11 +116,11 @@ class OneHotEncoderSuite
   test("OneHotEncoder with varying types") {
     val df = stringIndexed()
     val dfWithTypes = df
-      .withColumn("shortLabel", df("labelIndex").cast(ShortType))
-      .withColumn("longLabel", df("labelIndex").cast(LongType))
-      .withColumn("intLabel", df("labelIndex").cast(IntegerType))
-      .withColumn("floatLabel", df("labelIndex").cast(FloatType))
-      .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
+      .withColumn("shortLabel", col("labelIndex").cast(ShortType))
+      .withColumn("longLabel", col("labelIndex").cast(LongType))
+      .withColumn("intLabel", col("labelIndex").cast(IntegerType))
+      .withColumn("floatLabel", col("labelIndex").cast(FloatType))
+      .withColumn("decimalLabel", col("labelIndex").cast(DecimalType(10, 0)))
     val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel",
       "floatLabel", "decimalLabel")
     for (col <- cols) {
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1a5d422af9535..a7abe30a0359b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -636,7 +636,7 @@ def test_column_select(self):
         df = self.df
         self.assertEqual(self.testData, df.select("*").collect())
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
-        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select("value").collect())
 
     def test_freqItems(self):
         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8dc0532b3f89a..418771144310a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -290,7 +290,7 @@ class Analyzer(
                 s"grouping columns (${x.groupByExprs.mkString(",")})")
             }
           case Grouping(col: Expression) =>
-            val idx = x.groupByExprs.indexOf(col)
+            val idx = x.groupByExprs.indexWhere(_ semanticEquals col)
             if (idx >= 0) {
               Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
                 Literal(1)), ByteType)
@@ -554,7 +554,7 @@ class Analyzer(
 
     def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
       expressions.map {
-        case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
+        case a: Alias => a.newInstance()
         case other => other
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a5b5758167276..b2d57e7cf2cbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -291,7 +291,7 @@ case class AttributeReference(
     exprId :: qualifier :: isGenerated :: Nil
   }
 
-  override def toString: String = s"$name#${exprId.id}$typeSuffix"
+  override def toString: String = s"$qualifiedName#${exprId.id}$typeSuffix"
 
   // Since the expression id is not in the first constructor it is missing from the default
   // tree string.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ecf4285c46a51..ff36a8dc75599 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -252,14 +252,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
 
       // No matches.
       case Seq() =>
-        logTrace(s"Could not find $name in ${input.mkString(", ")}")
+        logTrace(s"Could not find $name in ${input.map(_.qualifiedName).mkString(", ")}")
         None
 
       // More than one match.
       case ambiguousReferences =>
-        val referenceNames = ambiguousReferences.map(_._1).mkString(", ")
-        throw new AnalysisException(
-          s"Reference '$name' is ambiguous, could be: $referenceNames.")
+        val qualifiers = ambiguousReferences.flatMap(_._1.qualifier)
+        if (qualifiers.nonEmpty && qualifiers.distinct.length == qualifiers.length) {
+          throw new AnalysisException(s"Reference '$name' is ambiguous, please add a qualifier " +
+            s"to distinguish it, e.g. '${qualifiers.head}.$name', available qualifiers: " +
+            qualifiers.mkString(", "))
+        } else {
+          val qualifiedNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
+          throw new AnalysisException(
+            s"Input Attributes $qualifiedNames are ambiguous, please eliminate ambiguity " +
+            "from the inputs first, e.g. alias the left and right plan before join them.")
+        }
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d64736e11110b..2eef40c9a7cfe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -118,11 +118,6 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * Returns the expression for this column either with an existing or auto assigned name.
    */
   private[sql] def named: NamedExpression = expr match {
-    // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
-    // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
-    // make it a NamedExpression.
-    case u: UnresolvedAttribute => UnresolvedAlias(u)
-    case u: UnresolvedExtractValue => UnresolvedAlias(u)
 
     case expr: NamedExpression => expr
 
@@ -133,7 +128,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
 
     case jt: JsonTuple => MultiAlias(jt, Nil)
 
-    case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql))
+    case func: UnresolvedFunction => UnresolvedAlias(func, Some(toPresentableString(func)))
 
     // If we have a top level Cast, there is a chance to give it a better alias, if there is a
     // NamedExpression under this Cast.
@@ -141,13 +136,20 @@ class Column(protected[sql] val expr: Expression) extends Logging {
       case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
     } match {
       case ne: NamedExpression => ne
-      case other => Alias(expr, usePrettyExpression(expr).sql)()
+      case other => Alias(expr, toPresentableString(expr))()
     }
 
-    case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
+    case expr: Expression => Alias(expr, toPresentableString(expr))()
   }
 
-  override def toString: String = usePrettyExpression(expr).sql
+  override def toString: String = toPresentableString(expr)
+
+  private def toPresentableString(expr: Expression): String = usePrettyExpression(expr transform {
+    // For unresolved attributes that generated by `Dataset.col`, we should ignore the generated
+    // qualifier to not annoy users.
+    case u: UnresolvedAttribute if u.nameParts(0).startsWith(Dataset.aliasPrefix) =>
+      u.copy(nameParts = u.nameParts.drop(1))
+  }).sql
 
   override def equals(that: Any): Boolean = that match {
     case that: Column => that.expr.equals(this.expr)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 33588ef72ffbe..06518065414ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -418,7 +418,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * TODO: This can be optimized to use broadcast join when replacementMap is large.
    */
   private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
-    val keyExpr = df.col(col.name).expr
+    val keyExpr = df.resolve(col.name)
     def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(buildExpr(source), buildExpr(target))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 704535adaa60d..dc07c3698049c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -129,7 +129,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
         userSpecifiedSchema = userSpecifiedSchema,
         className = source,
         options = extraOptions.toMap)
-    Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation()))
+    Dataset.ofRowsWithAlias(sqlContext, LogicalRelation(dataSource.resolveRelation()))
   }
 
   /**
@@ -376,7 +376,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
           parsedOptions)
     }
 
-    Dataset.ofRows(
+    Dataset.ofRowsWithAlias(
       sqlContext,
       LogicalRDD(
         schema.toAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 41cb799b97141..2d47760ea7629 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.io.CharArrayWriter
+import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
@@ -55,10 +56,28 @@ private[sql] object Dataset {
     new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]])
   }
 
+  def withAlias[T : Encoder](
+      sqlContext: SQLContext,
+      logicalPlan: LogicalPlan): Dataset[T] = {
+    apply(sqlContext, alias(logicalPlan))
+  }
+
   def ofRows(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
     val qe = sqlContext.executePlan(logicalPlan)
     qe.assertAnalyzed()
-    new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema))
+    new Dataset[Row](sqlContext, qe, RowEncoder(qe.analyzed.schema))
+  }
+
+  def ofRowsWithAlias(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
+    ofRows(sqlContext, alias(logicalPlan))
+  }
+
+  private[this] val nextDatasetId = new AtomicLong(0)
+
+  val aliasPrefix = "dataset_"
+
+  private def alias(plan: LogicalPlan): LogicalPlan = {
+    SubqueryAlias(aliasPrefix + nextDatasetId.getAndIncrement(), plan)
   }
 }
 
@@ -186,6 +205,8 @@ class Dataset[T] private[sql](
     }
   }
 
+  private[sql] def originalLogicalPlan = removeGeneratedSubquery(logicalPlan)
+
   /**
    * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
    * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
@@ -352,7 +373,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan)
+  def as[U : Encoder]: Dataset[U] = Dataset.withAlias(sqlContext, originalLogicalPlan)
 
   /**
   * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed.
@@ -446,7 +467,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
+  def isLocal: Boolean = originalLogicalPlan.isInstanceOf[LocalRelation]
 
   /**
   * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
@@ -542,9 +563,8 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def join(right: DataFrame): DataFrame = withPlan {
-    Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
-  }
+  def join(right: DataFrame): DataFrame = Dataset.ofRows(
+    sqlContext, Join(logicalPlan, right.logicalPlan, joinType = Inner, None))
 
   /**
   * Inner equi-join with another [[DataFrame]] using the given column.
@@ -620,13 +640,13 @@ class Dataset[T] private[sql](
       Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
       .analyzed.asInstanceOf[Join]
 
-    withPlan {
+    val plan =
       Join(
         joined.left,
         joined.right,
         UsingJoin(JoinType(joinType), usingColumns.map(UnresolvedAttribute(_))),
         None)
-    }
+    Dataset.ofRows(sqlContext, plan)
   }
 
   /**
@@ -665,46 +685,8 @@ class Dataset[T] private[sql](
   * @since 2.0.0
   */
   def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
-    // Note that in this function, we introduce a hack in the case of self-join to automatically
-    // resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
-    // Consider this case: df.join(df, df("key") === df("key"))
-    // Since df("key") === df("key") is a trivially true condition, this actually becomes a
-    // cartesian join. However, most likely users expect to perform a self join using "key".
-    // With that assumption, this hack turns the trivially true condition into equality on join
-    // keys that are resolved to both sides.
-
-    // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
-    // After the cloning, left and right side will have distinct expression ids.
-    val plan = withPlan(
-      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)))
-      .queryExecution.analyzed.asInstanceOf[Join]
-
-    // If auto self join alias is disabled, return the plan.
-    if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
-      return withPlan(plan)
-    }
-
-    // If left/right have no output set intersection, return the plan.
-    val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
-    val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
-    if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
-      return withPlan(plan)
-    }
-
-    // Otherwise, find the trivially true predicates and automatically resolves them to both sides.
-    // By the time we get here, since we have already run analysis, all attributes should've been
-    // resolved and become AttributeReference.
-    val cond = plan.condition.map { _.transform {
-      case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
-          if a.sameRef(b) =>
-        catalyst.expressions.EqualTo(
-          withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name))
-    }}
-
-    withPlan {
-      plan.copy(condition = cond)
-    }
+    val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
+    Dataset.ofRows(sqlContext, plan)
   }
 
   /**
@@ -746,13 +728,11 @@ class Dataset[T] private[sql](
       case _ => Alias(CreateStruct(rightOutput), "_2")()
     }
 
-    implicit val tuple2Encoder: Encoder[(T, U)] =
+    val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
-    withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
-      Project(
-        leftData :: rightData :: Nil,
-        joined.analyzed)
-    }
+
+    val plan = Project(leftData :: rightData :: Nil, joined.analyzed)
+    new Dataset(sqlContext, plan, tuple2Encoder)
   }
 
   /**
@@ -868,8 +848,11 @@ class Dataset[T] private[sql](
     case "*" =>
       Column(ResolvedStar(queryExecution.analyzed.output))
     case _ =>
-      val expr = resolve(colName)
-      Column(expr)
+      val col = resolve(colName) match {
+        case attr: Attribute => UnresolvedAttribute(attr.qualifier.toSeq :+ attr.name)
+        case Alias(child, _) => UnresolvedAttribute.quotedString(child.sql)
+      }
+      Column(col)
   }
 
   /**
@@ -878,9 +861,8 @@ class Dataset[T] private[sql](
   * @group typedrel
   * @since 1.6.0
   */
-  def as(alias: String): Dataset[T] = withTypedPlan {
-    SubqueryAlias(alias, logicalPlan)
-  }
+  def as(alias: String): Dataset[T] =
+    new Dataset(sqlContext, SubqueryAlias(alias, originalLogicalPlan), encoder)
 
   /**
   * (Scala-specific) Returns a new [[Dataset]] with an alias set.
@@ -969,15 +951,12 @@ class Dataset[T] private[sql](
   * @since 1.6.0
   */
   @Experimental
-  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](
-      sqlContext,
-      Project(
-        c1.withInputType(
-          boundTEncoder,
-          logicalPlan.output).named :: Nil,
-        logicalPlan),
-      implicitly[Encoder[U1]])
+  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = withTypedPlan {
+    Project(
+      c1.withInputType(
+        boundTEncoder,
+        logicalPlan.output).named :: Nil,
+      logicalPlan)
   }
 
   /**
@@ -989,9 +968,8 @@ class Dataset[T] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
-    val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
-    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+    withTypedPlan(Project(namedColumns, logicalPlan))(ExpressionEncoder.tuple(encoders))
   }
 
   /**
@@ -1238,12 +1216,11 @@ class Dataset[T] private[sql](
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     val inputPlan = logicalPlan
     val withGroupingKey = AppendColumns(func, inputPlan)
-    val executed = sqlContext.executePlan(withGroupingKey)
 
     new KeyValueGroupedDataset(
       encoderFor[K],
      encoderFor[T],
-      executed,
+      Dataset.ofRows(sqlContext, withGroupingKey),
      inputPlan.output,
      withGroupingKey.newColumns)
   }
@@ -1410,7 +1387,9 @@ class Dataset[T] private[sql](
   def union(other: Dataset[T]): Dataset[T] = withTypedPlan {
     // This breaks caching, but it's usually ok because it addresses a very specific use case:
     // using union to union many files or partitions.
-    CombineUnions(Union(logicalPlan, other.logicalPlan))
+    CombineUnions(Union(
+      removeGeneratedSubquery(logicalPlan),
+      removeGeneratedSubquery(other.logicalPlan)))
   }
 
   /**
@@ -1486,8 +1465,9 @@ class Dataset[T] private[sql](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new Dataset[T](
-        sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+      withTypedPlan {
+        Sample(x(0), x(1), withReplacement = false, seed, sorted)()
+      }
     }.toArray
   }
 
@@ -1705,8 +1685,7 @@ class Dataset[T] private[sql](
           u.name, sqlContext.sessionState.analyzer.resolver).getOrElse(u)
       case Column(expr: Expression) => expr
     }
-    val attrs = this.logicalPlan.output
-    val colsAfterDrop = attrs.filter { attr =>
+    val colsAfterDrop = logicalPlan.output.filter { attr =>
       attr != expression
     }.map(attr => Column(attr))
     select(colsAfterDrop : _*)
@@ -1905,11 +1884,8 @@ class Dataset[T] private[sql](
   * @since 1.6.0
   */
   @Experimental
-  def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
-    new Dataset[U](
-      sqlContext,
-      MapPartitions[T, U](func, logicalPlan),
-      implicitly[Encoder[U]])
+  def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = withTypedPlan {
+    MapPartitions[T, U](func, logicalPlan)
   }
 
   /**
@@ -2364,18 +2340,24 @@ class Dataset[T] private[sql](
     }
   }
 
+  private def removeGeneratedSubquery(plan: LogicalPlan): LogicalPlan = {
+    val resolved = sqlContext.executePlan(plan).analyzed
+    resolved transformDown {
+      case SubqueryAlias(alias, child) if alias.startsWith(Dataset.aliasPrefix) => child
+    }
+  }
+
   /** A convenient function to wrap a logical plan and produce a DataFrame. */
-  @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
-    Dataset.ofRows(sqlContext, logicalPlan)
+  @inline private[sql] def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
+    Dataset.ofRowsWithAlias(sqlContext, removeGeneratedSubquery(logicalPlan))
   }
 
-  /** A convenient function to wrap a logical plan and produce a Dataset. */
-  @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
-    new Dataset[T](sqlContext, logicalPlan, encoder)
+  @inline private def withPlanNoAlias(logicalPlan: => LogicalPlan): DataFrame = {
+    Dataset.ofRows(sqlContext, removeGeneratedSubquery(logicalPlan))
   }
 
-  private[sql] def withTypedPlan[R](
-      other: Dataset[_], encoder: Encoder[R])(
-      f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
-    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
+  /** A convenient function to wrap a logical plan and produce a Dataset. */
+  @inline private[sql] def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+    Dataset.withAlias(sqlContext, removeGeneratedSubquery(logicalPlan))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index f19ad6e707526..1cf29ec6cb9fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.QueryExecution
 class KeyValueGroupedDataset[K, V] private[sql](
     kEncoder: Encoder[K],
     vEncoder: Encoder[V],
-    val queryExecution: QueryExecution,
+    val ds: Dataset[_],
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
@@ -54,8 +54,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
   private val resolvedVEncoder =
     unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
 
-  private def logicalPlan = queryExecution.analyzed
-  private def sqlContext = queryExecution.sqlContext
+  private def logicalPlan = ds.logicalPlan
+  private def sqlContext = ds.sqlContext
 
   /**
   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
@@ -68,7 +68,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     new KeyValueGroupedDataset(
       encoderFor[L],
       unresolvedVEncoder,
-      queryExecution,
+      ds,
       dataAttributes,
       groupingAttributes)
 
@@ -77,11 +77,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
   *
   * @since 1.6.0
   */
-  def keys: Dataset[K] = {
-    Dataset[K](
-      sqlContext,
-      Distinct(
-        Project(groupingAttributes, logicalPlan)))
+  def keys: Dataset[K] = ds.withTypedPlan {
+    Distinct(Project(groupingAttributes, logicalPlan))
   }
 
   /**
@@ -102,14 +99,13 @@ class KeyValueGroupedDataset[K, V] private[sql](
   *
   * @since 1.6.0
   */
-  def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
-    Dataset[U](
-      sqlContext,
-      MapGroups(
-        f,
-        groupingAttributes,
-        dataAttributes,
-        logicalPlan))
+  def flatMapGroups[U : Encoder](
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = ds.withTypedPlan {
+    MapGroups(
+      f,
+      groupingAttributes,
+      dataAttributes,
+      logicalPlan)
   }
 
   /**
@@ -218,12 +214,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
       Alias(CreateStruct(groupingAttributes), "key")()
     }
     val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
-    val execution = new QueryExecution(sqlContext, aggregate)
 
-    new Dataset(
-      sqlContext,
-      execution,
-      ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
+    ds.withTypedPlan(aggregate)(ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
   }
 
   /**
@@ -289,8 +281,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
       other: KeyValueGroupedDataset[K, U])(
       f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
     implicit val uEncoder = other.unresolvedVEncoder
-    Dataset[R](
-      sqlContext,
+    ds.withTypedPlan {
       CoGroup(
         f,
         this.groupingAttributes,
@@ -298,7 +289,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
         this.dataAttributes,
         other.dataAttributes,
         this.logicalPlan,
-        other.logicalPlan))
+        other.logicalPlan)
+    }
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 91c02053ae1a3..7e1af9a6080be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -51,30 +51,23 @@ class RelationalGroupedDataset protected[sql](
     val aliasedAgg = aggregates.map(alias)
 
     groupType match {
-      case RelationalGroupedDataset.GroupByType =>
-        Dataset.ofRows(
-          df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
-      case RelationalGroupedDataset.RollupType =>
-        Dataset.ofRows(
-          df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
-      case RelationalGroupedDataset.CubeType =>
-        Dataset.ofRows(
-          df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
-      case RelationalGroupedDataset.PivotType(pivotCol, values) =>
+      case RelationalGroupedDataset.GroupByType => df.withPlan {
+        Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)
+      }
+      case RelationalGroupedDataset.RollupType => df.withPlan {
+        Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)
+      }
+      case RelationalGroupedDataset.CubeType => df.withPlan {
+        Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)
+      }
+      case RelationalGroupedDataset.PivotType(pivotCol, values) => df.withPlan {
        val aliasedGrps = groupingExprs.map(alias)
-        Dataset.ofRows(
-          df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
+        Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)
+      }
     }
   }
 
-  // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
-  // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
-  // make it a NamedExpression.
-  private[this] def alias(expr: Expression): NamedExpression = expr match {
-    case u: UnresolvedAttribute => UnresolvedAlias(u)
-    case expr: NamedExpression => expr
-    case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
-  }
+  private[this] def alias(expr: Expression): NamedExpression = Column(expr).named
 
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
     : DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 0576a1a178ec0..7b7c98d0b7992 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -336,7 +336,7 @@ class SQLContext private[sql](
     val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
     val attributeSeq = schema.toAttributes
     val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
-    Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRDD)(self))
+    Dataset.ofRowsWithAlias(self, LogicalRDD(attributeSeq, rowRDD)(self))
   }
 
   /**
@@ -351,7 +351,7 @@ class SQLContext private[sql](
     SQLContext.setActive(self)
     val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
     val attributeSeq = schema.toAttributes
-    Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
+    Dataset.ofRowsWithAlias(self, LocalRelation.fromProduct(attributeSeq, data))
   }
 
   /**
@@ -361,7 +361,7 @@ class SQLContext private[sql](
   * @since 1.3.0
   */
   def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
-    Dataset.ofRows(this, LogicalRelation(baseRelation))
+    Dataset.ofRowsWithAlias(this, LogicalRelation(baseRelation))
   }
 
   /**
@@ -416,7 +416,7 @@ class SQLContext private[sql](
       rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
     }
     val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
-    Dataset.ofRows(this, logicalPlan)
+    Dataset.ofRowsWithAlias(this, logicalPlan)
   }
 
@@ -426,7 +426,7 @@
     val encoded = data.map(d => enc.toRow(d).copy())
     val plan = new LocalRelation(attributes, encoded)
 
-    Dataset[T](this, plan)
+    Dataset.withAlias(self, plan)
   }
 
   def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
@@ -435,7 +435,7 @@ class SQLContext private[sql](
     val encoded = data.map(d => enc.toRow(d))
     val plan = LogicalRDD(attributes, encoded)(self)
 
-    Dataset[T](this, plan)
+    Dataset.withAlias(self, plan)
   }
 
   def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
@@ -852,9 +852,8 @@ class SQLContext private[sql](
   protected[sql] def applySchemaToPythonRDD(
       rdd: RDD[Array[Any]],
       schema: StructType): DataFrame = {
-
     val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
-    Dataset.ofRows(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
+    Dataset.ofRowsWithAlias(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index f3478a873a1a9..7601cc5a505ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -83,7 +83,7 @@ private[sql] class CacheManager extends Logging {
       query: Dataset[_],
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
-    val planToCache = query.queryExecution.analyzed
+    val planToCache = query.originalLogicalPlan
     if (lookupCachedData(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
@@ -102,7 +102,7 @@ private[sql] class CacheManager extends Logging {
 
   /** Removes the data for the given [[Dataset]] from the cache */
   private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
-    val planToCache = query.queryExecution.analyzed
+    val planToCache = query.originalLogicalPlan
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
     cachedData(dataIndex).cachedRepresentation.uncache(blocking)
@@ -115,7 +115,7 @@ private[sql] class CacheManager extends Logging {
   private[sql] def tryUncacheQuery(
       query: Dataset[_],
       blocking: Boolean = true): Boolean = writeLock {
-    val planToCache = query.queryExecution.analyzed
+    val planToCache = query.originalLogicalPlan
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     val found = dataIndex >= 0
     if (found) {
@@ -127,7 +127,7 @@ private[sql] class CacheManager extends Logging {
 
   /** Optionally returns cached data for the given [[Dataset]] */
   private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
-    lookupCachedData(query.queryExecution.analyzed)
+    lookupCachedData(query.originalLogicalPlan)
   }
 
   /** Optionally returns cached data for the given [[LogicalPlan]]. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 351b03b38bad1..0760850701095 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -176,7 +176,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     val exploded = df.select(explode('intList).as('i))
 
     checkAnswer(
-      exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+      exploded.join(exploded, "i").agg(count("*")),
       Row(3) :: Nil)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 067a62d011ec4..9aaea5deb318e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -115,27 +115,6 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
   }
 
-  test("[SPARK-6231] join - self join auto resolve ambiguity") {
-    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
-    checkAnswer(
-      df.join(df, df("key") === df("key")),
-      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)
-
-    checkAnswer(
-      df.join(df.filter($"value" === "2"), df("key") === df("key")),
-      Row(2, "2", 2, "2") :: Nil)
-
-    checkAnswer(
-      df.join(df, df("key") === df("key") && df("value") === 1),
-      Row(1, "1", 1, "1") :: Nil)
-
-    val left = df.groupBy("key").agg(count("*"))
-    val right = df.groupBy("key").agg(sum("key"))
-    checkAnswer(
-      left.join(right, left("key") === right("key")),
-      Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
-  }
-
   test("broadcast join hint") {
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 86c6405522363..1b6e28557c750 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -553,8 +553,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
         case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) =>
           Row(id, name, age, salary)
       }.toSeq)
+    assert(joinedDf.schema.map(_.name) === Seq("id", "name", "age", "id", "salary"))
     assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
-    assert(df("id") == person("id"))
   }
 
   test("withColumnRenamed") {
@@ -1432,4 +1432,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       getMessage()
     assert(e1.startsWith("Path does not exist"))
   }
+
+  test("Un-direct self-join") {
+    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+    val df2 = df.filter($"i" > 0)
+    checkAnswer(
+      df.join(df2, (df("i") + 1) === df2("i")),
+      Row(1, "a", 2, "b") :: Nil
+    )
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index a5a4ff13de83f..bf5915d763f83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -541,4 +541,24 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       Row(3, 1) ::
         Row(3, 2) :: Nil)
   }
+
+  test("friendly error message for self-join") {
+    withTempTable("tbl") {
+      val df = Seq(1 -> "a").toDF("k", "v")
+      df.registerTempTable("tbl")
+
+      val e1 = intercept[AnalysisException](sql("SELECT k FROM tbl JOIN tbl"))
+      assert(e1.message == "Input Attributes tbl.k, tbl.k are ambiguous, please eliminate " +
+        "ambiguity from the inputs first, e.g. alias the left and right plan before join them.")
+
+      val e2 = intercept[AnalysisException](sql("SELECT k FROM tbl t1 JOIN tbl t2"))
+      assert(e2.message == "Reference 'k' is ambiguous, please add a qualifier to distinguish " +
+        "it, e.g. 't1.k', available qualifiers: t1, t2")
+
+      checkAnswer(
+        sql("SELECT t1.k FROM tbl t1 JOIN tbl t2"),
+        Row(1)
+      )
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b727e88668370..d0d0c2ae04cda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1834,8 +1834,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737.
     // This bug will be triggered when Tungsten is enabled and there are multiple
     // SortMergeJoin operators executed in the same task.
-    val confs = SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: Nil
-    withSQLConf(confs: _*) {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
       val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j")
       val df2 = df1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index bdbcf842ca47d..657130c8e71d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -54,18 +54,18 @@ class PlannerSuite extends SharedSQLContext {
   }
 
   test("count is partially aggregated") {
-    val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
+    val query = testData.groupBy('value).agg(count('key)).originalLogicalPlan
     testPartialAggregationPlan(query)
   }
 
   test("count distinct is partially aggregated") {
-    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
+    val query = testData.groupBy('value).agg(countDistinct('key)).originalLogicalPlan
     testPartialAggregationPlan(query)
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
-      testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
+      testData.groupBy('value).agg(count('value), countDistinct('key)).originalLogicalPlan
     testPartialAggregationPlan(query)
   }
 
@@ -164,25 +164,30 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  private def checkOutput(planned: SparkPlan, df: DataFrame): Unit = {
+    assert(planned.output.map(_.withQualifier(None)) ===
+      df.logicalPlan.output.map(_.withQualifier(None)))
+  }
+
   test("efficient terminal limit -> sort should use TakeOrderedAndProject") {
     val query = testData.select('key, 'value).sort('key).limit(2)
     val planned = query.queryExecution.executedPlan
     assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
-    assert(planned.output === testData.select('key, 'value).logicalPlan.output)
+    checkOutput(planned, testData.select('key, 'value))
   }
 
   test("terminal limit -> project -> sort should use TakeOrderedAndProject") {
     val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2)
     val planned = query.queryExecution.executedPlan
     assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
-    assert(planned.output === testData.select('value, 'key).logicalPlan.output)
+    checkOutput(planned, testData.select('value, 'key))
   }
 
   test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") {
     val query = testData.select('value).limit(2)
     val planned = query.queryExecution.sparkPlan
     assert(planned.isInstanceOf[CollectLimit])
-    assert(planned.output === testData.select('value).logicalPlan.output)
+    checkOutput(planned, testData.select('value))
   }
 
   test("TakeOrderedAndProject can appear in the middle of plans") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 1fa15730bc2e2..76fdf2899116e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -216,7 +216,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
 
   /** Returns a resolved expression for `str` in the context of `df`. */
   def resolve(df: DataFrame, str: String): Expression = {
-    df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head
+    df.select(expr(str)).originalLogicalPlan.expressions.head.children.head
   }
 
   /** Returns a set with all the filters present in the physical plan. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 3cb3ef1ffa2f4..e4fb811a717b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, EqualTo, Expression}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.Join
@@ -218,7 +218,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, one match per row",
       myUpperCaseData,
       myLowerCaseData,
-      () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
+      () => EqualTo(myUpperCaseData.resolve("N"), myLowerCaseData.resolve("n")),
       Seq(
         (1, "A", 1, "a"),
         (2, "B", 2, "b"),
@@ -234,7 +234,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, multiple matches",
       left,
       right,
-      () => (left.col("a") === right.col("a")).expr,
+      () => EqualTo(left.resolve("a"), right.resolve("a")),
       Seq(
         (1, 1, 1, 1),
         (1, 1, 1, 2),
@@ -251,7 +251,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, no matches",
       left,
      right,
-      () => (left.col("a") === right.col("a")).expr,
+      () => EqualTo(left.resolve("a"), right.resolve("a")),
      Seq.empty
    )
  }
@@ -263,7 +263,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, null safe",
       left,
       right,
-      () => (left.col("b") <=> right.col("b")).expr,
+      () => EqualNullSafe(left.resolve("b"), right.resolve("b")),
       Seq(
         (1, 0, 1, 0),
         (2, null, 2, null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 4cacb20aa0791..0c699cffe02a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
+import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Join
@@ -57,8 +57,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     )), new StructType().add("c", IntegerType).add("d", DoubleType))
 
   private lazy val condition = {
-    And((left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And(EqualTo(left.resolve("a"), right.resolve("c")),
+      LessThan(left.resolve("b"), right.resolve("d")))
   }
 
   // Note: the input dataframes and expression must be evaluated lazily because
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 985a96f684541..ed032f1cce045 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
+import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical.Join
@@ -54,8 +54,8 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
     )), new StructType().add("c", IntegerType).add("d", DoubleType))
 
   private lazy val condition = {
-    And((left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And(EqualTo(left.resolve("a"), right.resolve("c")),
+      LessThan(left.resolve("b"), right.resolve("d")))
   }
 
   // Note: the input dataframes and expression must be evaluated lazily because
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index e7d2b5ad96821..596dc7e1ed492 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.util
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.{functions, Dataset, QueryTest}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, SubqueryAlias}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegen}
 import org.apache.spark.sql.test.SharedSQLContext
 
@@ -29,6 +29,16 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
   import functions._
 
+  private def removeGeneratedSubquery(plan: LogicalPlan): LogicalPlan = {
+    plan transformDown {
+      case SubqueryAlias(alias, child) if alias.startsWith(Dataset.aliasPrefix) => child
+    }
+  }
+
+  private def checkPlan[T <: LogicalPlan](qe: QueryExecution): Unit = {
+    assert(removeGeneratedSubquery(qe.analyzed).isInstanceOf[T])
+  }
+
   test("execute callback functions when a DataFrame action finished successfully") {
     val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
     val listener = new QueryExecutionListener {
@@ -48,11 +58,11 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     assert(metrics.length == 2)
     assert(metrics(0)._1 == "collect")
-    assert(metrics(0)._2.analyzed.isInstanceOf[Project])
+    checkPlan[Project](metrics(0)._2)
     assert(metrics(0)._3 > 0)
 
     assert(metrics(1)._1 == "count")
-    assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate])
+    checkPlan[Aggregate](metrics(1)._2)
     assert(metrics(1)._3 > 0)
 
     sqlContext.listenerManager.unregister(listener)
@@ -79,7 +89,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     assert(metrics.length == 1)
     assert(metrics(0)._1 == "collect")
-    assert(metrics(0)._2.analyzed.isInstanceOf[Project])
+    checkPlan[Project](metrics(0)._2)
     assert(metrics(0)._3.getMessage == e.getMessage)
 
     sqlContext.listenerManager.unregister(listener)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala
deleted file mode 100644
index 63cf5030ab8b6..0000000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-
-class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton {
-  import hiveContext.implicits._
-
-  // We should move this into SQL package if we make case sensitivity configurable in SQL.
-  test("join - self join auto resolve ambiguity with case insensitivity") {
-    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
-    checkAnswer(
-      df.join(df, df("key") === df("Key")),
-      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)
-
-    checkAnswer(
-      df.join(df.filter($"value" === "2"), df("key") === df("Key")),
-      Row(2, "2", 2, "2") :: Nil)
-  }
-
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 8f163f27c94cf..2ebcdd7f4fa3a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -53,7 +53,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.resolve("a")),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -67,7 +67,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.resolve("a")),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -82,7 +82,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.resolve("a")),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
@@ -99,7 +99,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.resolve("a")),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),