[SPARK-13801][SQL] DataFrame.col should return unresolved attribute #11632

Status: Closed (wants to merge 16 commits)
@@ -172,7 +172,7 @@ class StringIndexerModel (
       case _ => dataset
     }
     filteredDataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
+      indexer(filteredDataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
   }

   override def transformSchema(schema: StructType): StructType = {
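The fix above resolves the input column against `filteredDataset` instead of the upstream `dataset`: once `DataFrame.col` returns an unresolved attribute qualified by the owning Dataset, a column captured from one plan can no longer silently resolve against another. A minimal sketch of the safe pattern, assuming a local SparkSession named `spark` (not part of this diff):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.StringType

    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((0, "a"), (1, null)).toDF("id", "label")
    val filtered = df.filter($"label".isNotNull)
    // Take the column from the DataFrame being selected from, mirroring
    // filteredDataset($(inputCol)) in the patched line:
    val out = filtered.select(filtered("label").cast(StringType))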
@@ -116,11 +116,11 @@ class OneHotEncoderSuite
   test("OneHotEncoder with varying types") {
     val df = stringIndexed()
     val dfWithTypes = df
-      .withColumn("shortLabel", df("labelIndex").cast(ShortType))
-      .withColumn("longLabel", df("labelIndex").cast(LongType))
-      .withColumn("intLabel", df("labelIndex").cast(IntegerType))
-      .withColumn("floatLabel", df("labelIndex").cast(FloatType))
-      .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
+      .withColumn("shortLabel", col("labelIndex").cast(ShortType))
+      .withColumn("longLabel", col("labelIndex").cast(LongType))
+      .withColumn("intLabel", col("labelIndex").cast(IntegerType))
+      .withColumn("floatLabel", col("labelIndex").cast(FloatType))
+      .withColumn("decimalLabel", col("labelIndex").cast(DecimalType(10, 0)))
     val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel",
       "floatLabel", "decimalLabel")
     for (col <- cols) {
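The test now builds the cast columns with `functions.col`, which stays unresolved until each `withColumn` resolves it against the current plan; the old `df("labelIndex")` pinned every cast to the pre-chain plan. The same pattern as a hedged sketch, reusing `spark` and the implicits from the first sketch:

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{LongType, ShortType}

    val base = Seq(0.0, 1.0, 2.0).toDF("labelIndex")
    val typed = base
      .withColumn("shortLabel", col("labelIndex").cast(ShortType)) // resolved at this step
      .withColumn("longLabel", col("labelIndex").cast(LongType))   // and again at this one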
python/pyspark/sql/tests.py (1 addition & 1 deletion)
@@ -636,7 +636,7 @@ def test_column_select(self):
         df = self.df
         self.assertEqual(self.testData, df.select("*").collect())
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
-        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select("value").collect())

     def test_freqItems(self):
         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
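The Python test follows the same rule: after filtering, select by name rather than reusing a Column captured from the parent frame, since under this patch that captured column points at the pre-filter plan. A Scala rendering of the updated assertion, as a sketch reusing `spark` from above:

    val kv = Seq((1, "1"), (2, "2")).toDF("key", "value")
    // Selecting by name resolves against the filtered plan:
    val rows = kv.where(kv("key") === 1).select("value").collect()
    // expected: Array(Row("1"))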
@@ -290,7 +290,7 @@ class Analyzer(
             s"grouping columns (${x.groupByExprs.mkString(",")})")
         }
       case Grouping(col: Expression) =>
-        val idx = x.groupByExprs.indexOf(col)
+        val idx = x.groupByExprs.indexWhere(_ semanticEquals col)
         if (idx >= 0) {
           Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
             Literal(1)), ByteType)
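`indexOf` uses `equals`, which compares attributes field by field, so two references to the same underlying attribute that differ only cosmetically (name case, qualifier) do not match; `indexWhere(_ semanticEquals col)` compares canonicalized forms instead. A hedged illustration against catalyst internals, assuming Spark 2.x-style canonicalization that ignores the attribute's cosmetic name:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("col", IntegerType)()
    val b = a.withName("COL") // same exprId, cosmetically different name
    assert(Seq(b).indexOf(a) == -1)                     // equals misses it
    assert(Seq(b).indexWhere(_ semanticEquals a) == 0)  // semanticEquals finds it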
@@ -554,7 +554,7 @@ class Analyzer(

     def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
       expressions.map {
-        case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
+        case a: Alias => a.newInstance()
         case other => other
       }
     }
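`Alias.newInstance()` recreates the alias with a fresh `exprId` while carrying over the child, name, and metadata; the old hand-rolled copy kept only `isGenerated` and dropped the rest. A short sketch against catalyst internals, assuming the 2.x `Alias` constructor shape:

    import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}

    val alias = Alias(Literal(1), "one")()
    val fresh = alias.newInstance()
    assert(fresh.name == alias.name)      // user-visible name preserved
    assert(fresh.exprId != alias.exprId)  // new identity for deduplication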
@@ -291,7 +291,7 @@ case class AttributeReference(
     exprId :: qualifier :: isGenerated :: Nil
   }

-  override def toString: String = s"$name#${exprId.id}$typeSuffix"
+  override def toString: String = s"$qualifiedName#${exprId.id}$typeSuffix"

   // Since the expression id is not in the first constructor it is missing from the default
   // tree string.
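Building `toString` on `qualifiedName` makes attributes print with their qualifier when one is attached, which is what makes the improved ambiguity errors below readable. A hedged sketch, with the constructor shape inferred from this diff's `exprId :: qualifier :: isGenerated` product list:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    val attr = AttributeReference("id", IntegerType)(qualifier = Some("t"))
    // Under this patch, prints the qualified form, e.g. "t.id#7":
    println(attr)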
@@ -252,14 +252,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

       // No matches.
       case Seq() =>
-        logTrace(s"Could not find $name in ${input.mkString(", ")}")
+        logTrace(s"Could not find $name in ${input.map(_.qualifiedName).mkString(", ")}")
         None

       // More than one match.
       case ambiguousReferences =>
-        val referenceNames = ambiguousReferences.map(_._1).mkString(", ")
-        throw new AnalysisException(
-          s"Reference '$name' is ambiguous, could be: $referenceNames.")
+        val qualifiers = ambiguousReferences.flatMap(_._1.qualifier)
+        if (qualifiers.nonEmpty && qualifiers.distinct.length == qualifiers.length) {
+          throw new AnalysisException(s"Reference '$name' is ambiguous, please add a qualifier " +
+            s"to distinguish it, e.g. '${qualifiers.head}.$name', available qualifiers: " +
+            qualifiers.mkString(", "))
+        } else {
+          val qualifiedNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
+          throw new AnalysisException(
+            s"Input Attributes $qualifiedNames are ambiguous, please eliminate ambiguity " +
+            "from the inputs first, e.g. alias the left and right plan before join them.")
+        }
     }
   }
 }
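The resolver now splits the ambiguous case in two: when every candidate carries a distinct qualifier, the error names the exact prefix to add; otherwise, as in a self-join where both sides expose the same unqualified attribute, it asks the user to alias the inputs first. The user-level shape of that advice, as a sketch reusing `spark` from above:

    import org.apache.spark.sql.functions.col

    val people = Seq((1, "alice"), (2, "bob")).toDF("id", "name")
    // Alias both sides so each "id" gets its own qualifier:
    val joined = people.as("l").join(people.as("r"), col("l.id") === col("r.id"))
    joined.select(col("l.id"), col("r.name")).show()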
sql/core/src/main/scala/org/apache/spark/sql/Column.scala (11 additions & 9 deletions)
@@ -118,11 +118,6 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * Returns the expression for this column either with an existing or auto assigned name.
    */
   private[sql] def named: NamedExpression = expr match {
-    // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
-    // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
-    // make it a NamedExpression.
-    case u: UnresolvedAttribute => UnresolvedAlias(u)
-
     case u: UnresolvedExtractValue => UnresolvedAlias(u)

     case expr: NamedExpression => expr
@@ -133,21 +128,28 @@

     case jt: JsonTuple => MultiAlias(jt, Nil)

-    case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql))
+    case func: UnresolvedFunction => UnresolvedAlias(func, Some(toPresentableString(func)))

     // If we have a top level Cast, there is a chance to give it a better alias, if there is a
     // NamedExpression under this Cast.
     case c: Cast => c.transformUp {
       case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
     } match {
       case ne: NamedExpression => ne
-      case other => Alias(expr, usePrettyExpression(expr).sql)()
+      case other => Alias(expr, toPresentableString(expr))()
     }

-    case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
+    case expr: Expression => Alias(expr, toPresentableString(expr))()
   }

-  override def toString: String = usePrettyExpression(expr).sql
+  override def toString: String = toPresentableString(expr)
+
+  private def toPresentableString(expr: Expression): String = usePrettyExpression(expr transform {
+    // For unresolved attributes that generated by `Dataset.col`, we should ignore the generated
+    // qualifier to not annoy users.
+    case u: UnresolvedAttribute if u.nameParts(0).startsWith(Dataset.aliasPrefix) =>
+      u.copy(nameParts = u.nameParts.drop(1))
+  }).sql

   override def equals(that: Any): Boolean = that match {
     case that: Column => that.expr.equals(this.expr)
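With `Dataset.col` returning an `UnresolvedAttribute` qualified by a generated alias (prefixed with `Dataset.aliasPrefix`), the dedicated `UnresolvedAttribute => UnresolvedAlias` case becomes unnecessary, and every display string routes through `toPresentableString`, which strips the generated qualifier so users see the name they typed. The observable contract, as a hedged sketch reusing `spark`:

    val one = Seq(1).toDF("x")
    val c = one("x") + 1
    println(c)                               // expected: (x + 1), no internal qualifier
    println(one.select(c).columns.mkString)  // expected auto-assigned alias: (x + 1)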
@@ -418,7 +418,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * TODO: This can be optimized to use broadcast join when replacementMap is large.
    */
   private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
-    val keyExpr = df.col(col.name).expr
+    val keyExpr = df.resolve(col.name)
     def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(buildExpr(source), buildExpr(target))
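`df.resolve(col.name)` returns the plan's resolved `NamedExpression` directly; under this patch `df.col(...).expr` would be an unresolved attribute with no `dataType` yet, and the next line needs `keyExpr.dataType`. The user-facing behavior this keeps working, sketched with the public API (reusing `spark`):

    val raw = Seq("a", "unknown", "b").toDF("v")
    // na.replace resolves "v" internally and casts the replacement literals to its type:
    val cleaned = raw.na.replace("v", Map("unknown" -> "c"))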
@@ -129,7 +129,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
       userSpecifiedSchema = userSpecifiedSchema,
       className = source,
       options = extraOptions.toMap)
-    Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation()))
+    Dataset.ofRowsWithAlias(sqlContext, LogicalRelation(dataSource.resolveRelation()))
   }

   /**
@@ -376,7 +376,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
         parsedOptions)
     }

-    Dataset.ofRows(
+    Dataset.ofRowsWithAlias(
       sqlContext,
       LogicalRDD(
         schema.toAttributes,
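`Dataset.ofRowsWithAlias` replaces `ofRows` at the reader entry points; judging by the rest of the patch, it wraps the plan in a generated alias so each loaded Dataset gets its own qualifier. The user-visible intent, sketched with hypothetical input paths (`a.json` and `b.json` are placeholders, reusing `spark`):

    val left = spark.read.json("a.json")   // hypothetical path
    val right = spark.read.json("b.json")  // hypothetical path
    // Each read carries its own generated alias, so these columns stay distinguishable:
    left.join(right, left("id") === right("id"))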