From e6f00b876a57204e0538a07933255c2ceaa1f1b8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 30 Jan 2015 22:04:18 -0800 Subject: [PATCH] [SQL][API] ComputableColumn vs IncomputableColumn This patch changes Column from a concrete implementation to a trait, and provides two concrete implementations: IncomputableColumn and ComputableColumn. --- .../scala/org/apache/spark/sql/Column.scala | 241 ++++++++----- .../apache/spark/sql/ComputableColumn.scala | 33 ++ .../org/apache/spark/sql/DataFrame.scala | 290 ++++----------- .../org/apache/spark/sql/DataFrameImpl.scala | 329 ++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dsl.scala | 8 +- .../apache/spark/sql/GroupedDataFrame.scala | 9 +- .../apache/spark/sql/IncomputableColumn.scala | 162 +++++++++ .../org/apache/spark/sql/SQLContext.scala | 18 +- .../apache/spark/sql/execution/commands.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 2 +- .../spark/sql/test/TestSQLContext.scala | 2 +- .../spark/sql/ColumnExpressionSuite.scala | 35 ++ .../sql/parquet/ParquetFilterSuite.scala | 6 +- .../apache/spark/sql/hive/HiveContext.scala | 3 +- .../spark/sql/hive/HiveStrategies.scala | 13 +- 15 files changed, 807 insertions(+), 346 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 174c403059510..6f48d7c3fe1b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,23 +17,26 @@ package org.apache.spark.sql +import scala.annotation.tailrec import scala.language.implicitConversions import org.apache.spark.sql.Dsl.lit -import 
org.apache.spark.sql.catalyst.analysis.{UnresolvedStar, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan} import org.apache.spark.sql.types._ -object Column { - /** - * Creates a [[Column]] based on the given column name. Same as [[Dsl.col]]. - */ - def apply(colName: String): Column = new Column(colName) +private[sql] object Column { + + def apply(colName: String): Column = new IncomputableColumn(colName) + + def apply(expr: Expression): Column = new IncomputableColumn(expr) + + def apply(sqlContext: SQLContext, plan: LogicalPlan, expr: Expression): Column = { + new ComputableColumn(sqlContext, plan, expr) + } - /** For internal pattern matching. */ - private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr) + def unapply(col: Column): Option[Expression] = Some(col.expr) } @@ -53,44 +56,42 @@ object Column { * */ // TODO: Improve documentation. -class Column( - sqlContext: Option[SQLContext], - plan: Option[LogicalPlan], - protected[sql] val expr: Expression) - extends DataFrame(sqlContext, plan) with ExpressionApi { +trait Column extends DataFrame with ExpressionApi { - /** Turns a Catalyst expression into a `Column`. */ - protected[sql] def this(expr: Expression) = this(None, None, expr) + protected[sql] def expr: Expression /** - * Creates a new `Column` expression based on a column or attribute name. - * The resolution of this is the same as SQL. For example: - * - * - "colName" becomes an expression selecting the column named "colName". - * - "*" becomes an expression selecting all columns. - * - "df.*" becomes an expression selecting all columns in data frame "df". + * Returns true iff the [[Column]] is computable. 
*/ - def this(name: String) = this(name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) - case _ => UnresolvedAttribute(name) - }) + def isComputable: Boolean - override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined + private def constructColumn(other: Column)(newExpr: Expression): Column = { + // Removes all the top level projection and subquery so we can get to the underlying plan. + @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match { + case Project(_, child) => stripProject(child) + case Subquery(_, child) => stripProject(child) + case _ => p + } - /** - * An implicit conversion function internal to this class. This function creates a new Column - * based on an expression. If the expression itself is not named, it aliases the expression - * by calling it "col". - */ - private[this] implicit def toColumn(expr: Expression): Column = { - val projectedPlan = plan.map { p => - Project(Seq(expr match { + def computableCol(baseCol: ComputableColumn, expr: Expression) = { + val plan = Project(Seq(expr match { case named: NamedExpression => named case unnamed: Expression => Alias(unnamed, "col")() - }), p) + }), baseCol.plan) + Column(baseCol.sqlContext, plan, expr) + } + + (this, other) match { + case (left: ComputableColumn, right: ComputableColumn) => + if (stripProject(left.plan).sameResult(stripProject(right.plan))) { + computableCol(right, newExpr) + } else { + Column(newExpr) + } + case (left: ComputableColumn, _) => computableCol(left, newExpr) + case (_, right: ComputableColumn) => computableCol(right, newExpr) + case (_, _) => Column(newExpr) } - new Column(sqlContext, projectedPlan, expr) } /** @@ -100,7 +101,7 @@ class Column( * df.select( -df("amount") ) * }}} */ - override def unary_- : Column = UnaryMinus(expr) + override def unary_- : Column = constructColumn(null) { UnaryMinus(expr) } /** * Bitwise NOT. 
@@ -109,7 +110,7 @@ class Column( * df.select( ~df("flags") ) * }}} */ - override def unary_~ : Column = BitwiseNot(expr) + override def unary_~ : Column = constructColumn(null) { BitwiseNot(expr) } /** * Inversion of boolean expression, i.e. NOT. @@ -118,7 +119,7 @@ class Column( * df.select( !df("isActive") ) * }} */ - override def unary_! : Column = Not(expr) + override def unary_! : Column = constructColumn(null) { Not(expr) } /** @@ -129,7 +130,9 @@ class Column( * df.select( df("colA".equalTo(df("colB")) ) * }}} */ - override def === (other: Column): Column = EqualTo(expr, other.expr) + override def === (other: Column): Column = constructColumn(other) { + EqualTo(expr, other.expr) + } /** * Equality test with a literal value. @@ -169,7 +172,9 @@ class Column( * df.select( !(df("colA") === df("colB")) ) * }}} */ - override def !== (other: Column): Column = Not(EqualTo(expr, other.expr)) + override def !== (other: Column): Column = constructColumn(other) { + Not(EqualTo(expr, other.expr)) + } /** * Inequality test with a literal value. @@ -188,7 +193,9 @@ class Column( * people.select( people("age") > Literal(21) ) * }}} */ - override def > (other: Column): Column = GreaterThan(expr, other.expr) + override def > (other: Column): Column = constructColumn(other) { + GreaterThan(expr, other.expr) + } /** * Greater than a literal value. @@ -206,7 +213,9 @@ class Column( * people.select( people("age") < Literal(21) ) * }}} */ - override def < (other: Column): Column = LessThan(expr, other.expr) + override def < (other: Column): Column = constructColumn(other) { + LessThan(expr, other.expr) + } /** * Less than a literal value. @@ -224,7 +233,9 @@ class Column( * people.select( people("age") <= Literal(21) ) * }}} */ - override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr) + override def <= (other: Column): Column = constructColumn(other) { + LessThanOrEqual(expr, other.expr) + } /** * Less than or equal to a literal value. 
@@ -242,7 +253,9 @@ class Column( * people.select( people("age") >= Literal(21) ) * }}} */ - override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr) + override def >= (other: Column): Column = constructColumn(other) { + GreaterThanOrEqual(expr, other.expr) + } /** * Greater than or equal to a literal value. @@ -256,9 +269,11 @@ class Column( /** * Equality test with an expression that is safe for null values. */ - override def <=> (other: Column): Column = other match { - case null => EqualNullSafe(expr, lit(null).expr) - case _ => EqualNullSafe(expr, other.expr) + override def <=> (other: Column): Column = constructColumn(other) { + other match { + case null => EqualNullSafe(expr, lit(null).expr) + case _ => EqualNullSafe(expr, other.expr) + } } /** @@ -269,12 +284,12 @@ class Column( /** * True if the current expression is null. */ - override def isNull: Column = IsNull(expr) + override def isNull: Column = constructColumn(null) { IsNull(expr) } /** * True if the current expression is NOT null. */ - override def isNotNull: Column = IsNotNull(expr) + override def isNotNull: Column = constructColumn(null) { IsNotNull(expr) } /** * Boolean OR with an expression. @@ -283,7 +298,9 @@ class Column( * people.select( people("inSchool") || people("isEmployed") ) * }}} */ - override def || (other: Column): Column = Or(expr, other.expr) + override def || (other: Column): Column = constructColumn(other) { + Or(expr, other.expr) + } /** * Boolean OR with a literal value. @@ -301,7 +318,9 @@ class Column( * people.select( people("inSchool") && people("isEmployed") ) * }}} */ - override def && (other: Column): Column = And(expr, other.expr) + override def && (other: Column): Column = constructColumn(other) { + And(expr, other.expr) + } /** * Boolean AND with a literal value. @@ -315,7 +334,9 @@ class Column( /** * Bitwise AND with an expression. 
*/ - override def & (other: Column): Column = BitwiseAnd(expr, other.expr) + override def & (other: Column): Column = constructColumn(other) { + BitwiseAnd(expr, other.expr) + } /** * Bitwise AND with a literal value. @@ -325,7 +346,9 @@ class Column( /** * Bitwise OR with an expression. */ - override def | (other: Column): Column = BitwiseOr(expr, other.expr) + override def | (other: Column): Column = constructColumn(other) { + BitwiseOr(expr, other.expr) + } /** * Bitwise OR with a literal value. @@ -335,7 +358,9 @@ class Column( /** * Bitwise XOR with an expression. */ - override def ^ (other: Column): Column = BitwiseXor(expr, other.expr) + override def ^ (other: Column): Column = constructColumn(other) { + BitwiseXor(expr, other.expr) + } /** * Bitwise XOR with a literal value. @@ -349,7 +374,9 @@ class Column( * people.select( people("height") + people("weight") ) * }}} */ - override def + (other: Column): Column = Add(expr, other.expr) + override def + (other: Column): Column = constructColumn(other) { + Add(expr, other.expr) + } /** * Sum of this expression and another expression. @@ -367,7 +394,9 @@ class Column( * people.select( people("height") - people("weight") ) * }}} */ - override def - (other: Column): Column = Subtract(expr, other.expr) + override def - (other: Column): Column = constructColumn(other) { + Subtract(expr, other.expr) + } /** * Subtraction. Subtract a literal value from this expression. @@ -385,7 +414,9 @@ class Column( * people.select( people("height") * people("weight") ) * }}} */ - override def * (other: Column): Column = Multiply(expr, other.expr) + override def * (other: Column): Column = constructColumn(other) { + Multiply(expr, other.expr) + } /** * Multiplication this expression and a literal value. 
@@ -403,7 +434,9 @@ class Column( * people.select( people("height") / people("weight") ) * }}} */ - override def / (other: Column): Column = Divide(expr, other.expr) + override def / (other: Column): Column = constructColumn(other) { + Divide(expr, other.expr) + } /** * Division this expression by a literal value. @@ -417,7 +450,9 @@ class Column( /** * Modulo (a.k.a. remainder) expression. */ - override def % (other: Column): Column = Remainder(expr, other.expr) + override def % (other: Column): Column = constructColumn(other) { + Remainder(expr, other.expr) + } /** * Modulo (a.k.a. remainder) expression. @@ -430,29 +465,40 @@ class Column( * by the evaluated values of the arguments. */ @scala.annotation.varargs - override def in(list: Column*): Column = In(expr, list.map(_.expr)) + override def in(list: Column*): Column = { + new IncomputableColumn(In(expr, list.map(_.expr))) + } - override def like(literal: String): Column = Like(expr, lit(literal).expr) + override def like(literal: String): Column = constructColumn(null) { + Like(expr, lit(literal).expr) + } - override def rlike(literal: String): Column = RLike(expr, lit(literal).expr) + override def rlike(literal: String): Column = constructColumn(null) { + RLike(expr, lit(literal).expr) + } /** * An expression that gets an item at position `ordinal` out of an array. */ - override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal)) + override def getItem(ordinal: Int): Column = constructColumn(null) { + GetItem(expr, Literal(ordinal)) + } /** * An expression that gets a field by name in a [[StructField]]. */ - override def getField(fieldName: String): Column = GetField(expr, fieldName) + override def getField(fieldName: String): Column = constructColumn(null) { + GetField(expr, fieldName) + } /** * An expression that returns a substring. * @param startPos expression for the starting position. * @param len expression for the length of the substring. 
*/ - override def substr(startPos: Column, len: Column): Column = - Substring(expr, startPos.expr, len.expr) + override def substr(startPos: Column, len: Column): Column = { + new IncomputableColumn(Substring(expr, startPos.expr, len.expr)) + } /** * An expression that returns a substring. @@ -461,16 +507,21 @@ class Column( */ override def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len)) - override def contains(other: Column): Column = Contains(expr, other.expr) + override def contains(other: Column): Column = constructColumn(other) { + Contains(expr, other.expr) + } override def contains(literal: Any): Column = this.contains(lit(literal)) - - override def startsWith(other: Column): Column = StartsWith(expr, other.expr) + override def startsWith(other: Column): Column = constructColumn(other) { + StartsWith(expr, other.expr) + } override def startsWith(literal: String): Column = this.startsWith(lit(literal)) - override def endsWith(other: Column): Column = EndsWith(expr, other.expr) + override def endsWith(other: Column): Column = constructColumn(other) { + EndsWith(expr, other.expr) + } override def endsWith(literal: String): Column = this.endsWith(lit(literal)) @@ -481,7 +532,7 @@ class Column( * df.select($"colA".as("colB")) * }}} */ - override def as(alias: String): Column = Alias(expr, alias)() + override def as(alias: String): Column = constructColumn(null) { Alias(expr, alias)() } /** * Casts the column to a different data type. 
@@ -494,7 +545,7 @@ class Column( * df.select(df("colA").cast("int")) * }}} */ - override def cast(to: DataType): Column = Cast(expr, to) + override def cast(to: DataType): Column = constructColumn(null) { Cast(expr, to) } /** * Casts the column to a different data type, using the canonical string representation @@ -505,28 +556,30 @@ class Column( * df.select(df("colA").cast("int")) * }}} */ - override def cast(to: String): Column = Cast(expr, to.toLowerCase match { - case "string" => StringType - case "boolean" => BooleanType - case "byte" => ByteType - case "short" => ShortType - case "int" => IntegerType - case "long" => LongType - case "float" => FloatType - case "double" => DoubleType - case "decimal" => DecimalType.Unlimited - case "date" => DateType - case "timestamp" => TimestampType - case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""") - }) - - override def desc: Column = SortOrder(expr, Descending) - - override def asc: Column = SortOrder(expr, Ascending) + override def cast(to: String): Column = constructColumn(null) { + Cast(expr, to.toLowerCase match { + case "string" => StringType + case "boolean" => BooleanType + case "byte" => ByteType + case "short" => ShortType + case "int" => IntegerType + case "long" => LongType + case "float" => FloatType + case "double" => DoubleType + case "decimal" => DecimalType.Unlimited + case "date" => DateType + case "timestamp" => TimestampType + case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""") + }) + } + + override def desc: Column = constructColumn(null) { SortOrder(expr, Descending) } + + override def asc: Column = constructColumn(null) { SortOrder(expr, Ascending) } } -class ColumnName(name: String) extends Column(name) { +class ColumnName(name: String) extends IncomputableColumn(name) { /** Creates a new AttributeReference of type boolean */ def boolean: StructField = StructField(name, BooleanType) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala new file mode 100644 index 0000000000000..ac479b26a7c6a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala @@ -0,0 +1,33 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + + +private[sql] class ComputableColumn protected[sql]( + sqlContext: SQLContext, + protected[sql] val plan: LogicalPlan, + protected[sql] val expr: Expression) + extends DataFrameImpl(sqlContext, plan) with Column { + + override def isComputable: Boolean = true +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 1096e396591df..95830bd1b0f18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -19,26 +19,21 @@ package org.apache.spark.sql import java.util.{List => JList} -import scala.language.implicitConversions import scala.reflect.ClassTag -import scala.collection.JavaConversions._ -import com.fasterxml.jackson.core.JsonFactory - -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{NumericType, StructType} -import org.apache.spark.util.Utils +import org.apache.spark.sql.types.StructType + + +private[sql] object DataFrame { + def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { + new 
DataFrameImpl(sqlContext, logicalPlan) + } +} /** @@ -78,50 +73,14 @@ import org.apache.spark.util.Utils * }}} */ // TODO: Improve documentation. -class DataFrame protected[sql]( - val sqlContext: SQLContext, - private val baseLogicalPlan: LogicalPlan, - operatorsEnabled: Boolean) - extends DataFrameSpecificApi with RDDApi[Row] { - - protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) = - this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined) - - protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true) - - @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) - - @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - baseLogicalPlan - } +trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { - /** - * An implicit conversion function internal to this class for us to avoid doing - * "new DataFrame(...)" everywhere. - */ - private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { - new DataFrame(sqlContext, logicalPlan, true) - } + val sqlContext: SQLContext - /** Returns the list of numeric columns, useful for doing aggregation. */ - protected[sql] def numericColumns: Seq[Expression] = { - schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get - } - } + @DeveloperApi + def queryExecution: SQLContext#QueryExecution - /** Resolves a column name into a Catalyst [[NamedExpression]]. 
*/ - protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { - throw new RuntimeException( - s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") - } - } + protected[sql] def logicalPlan: LogicalPlan /** Left here for compatibility reasons. */ @deprecated("1.3.0", "use toDataFrame") @@ -142,32 +101,19 @@ class DataFrame protected[sql]( * }}} */ @scala.annotation.varargs - def toDataFrame(colName: String, colNames: String*): DataFrame = { - val newNames = colName +: colNames - require(schema.size == newNames.size, - "The number of columns doesn't match.\n" + - "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + - "New column names: " + newNames.mkString(", ")) - - val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) => - apply(oldName).as(newName) - } - select(newCols :_*) - } + def toDataFrame(colName: String, colNames: String*): DataFrame /** Returns the schema of this [[DataFrame]]. */ - override def schema: StructType = queryExecution.analyzed.schema + override def schema: StructType /** Returns all column names and their data types as an array. */ - override def dtypes: Array[(String, String)] = schema.fields.map { field => - (field.name, field.dataType.toString) - } + override def dtypes: Array[(String, String)] /** Returns all column names as an array. */ override def columns: Array[String] = schema.fields.map(_.name) /** Prints the schema to the console in a nice tree format. */ - override def printSchema(): Unit = println(schema.treeString) + override def printSchema(): Unit /** * Cartesian join with another [[DataFrame]]. @@ -176,9 +122,7 @@ class DataFrame protected[sql]( * * @param right Right side of the join operation. 
*/ - override def join(right: DataFrame): DataFrame = { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) - } + override def join(right: DataFrame): DataFrame /** * Inner join with another [[DataFrame]], using the given join expression. @@ -189,9 +133,7 @@ class DataFrame protected[sql]( * df1.join(df2).where($"df1Key" === $"df2Key") * }}} */ - override def join(right: DataFrame, joinExprs: Column): DataFrame = { - Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr)) - } + override def join(right: DataFrame, joinExprs: Column): DataFrame /** * Join with another [[DataFrame]], usin g the given join expression. The following performs @@ -205,9 +147,7 @@ class DataFrame protected[sql]( * @param joinExprs Join expression. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. */ - override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) - } + override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame /** * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. @@ -219,9 +159,7 @@ class DataFrame protected[sql]( * }}} */ @scala.annotation.varargs - override def sort(sortCol: String, sortCols: String*): DataFrame = { - orderBy(apply(sortCol), sortCols.map(apply) :_*) - } + override def sort(sortCol: String, sortCols: String*): DataFrame /** * Returns a new [[DataFrame]] sorted by the given expressions. 
For example: @@ -230,46 +168,26 @@ class DataFrame protected[sql]( * }}} */ @scala.annotation.varargs - override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = { - val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - Sort(sortOrder, global = true, logicalPlan) - } + override def sort(sortExpr: Column, sortExprs: Column*): DataFrame /** * Returns a new [[DataFrame]] sorted by the given expressions. * This is an alias of the `sort` function. */ @scala.annotation.varargs - override def orderBy(sortCol: String, sortCols: String*): DataFrame = { - sort(sortCol, sortCols :_*) - } + override def orderBy(sortCol: String, sortCols: String*): DataFrame /** * Returns a new [[DataFrame]] sorted by the given expressions. * This is an alias of the `sort` function. */ @scala.annotation.varargs - override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = { - sort(sortExpr, sortExprs :_*) - } + override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame /** * Selects column based on the column name and return it as a [[Column]]. */ - override def apply(colName: String): Column = colName match { - case "*" => - new Column(ResolvedStar(schema.fieldNames.map(resolve))) - case _ => - val expr = resolve(colName) - new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr) - } + override def apply(colName: String): Column /** * Selects a set of expressions, wrapped in a Product. 
@@ -279,18 +197,12 @@ class DataFrame protected[sql]( * df.select($"colA", $"colB" + 1) * }}} */ - override def apply(projection: Product): DataFrame = { - require(projection.productArity >= 1) - select(projection.productIterator.map { - case c: Column => c - case o: Any => new Column(Some(sqlContext), None, Literal(o)) - }.toSeq :_*) - } + override def apply(projection: Product): DataFrame /** * Returns a new [[DataFrame]] with an alias set. */ - override def as(name: String): DataFrame = Subquery(name, logicalPlan) + override def as(name: String): DataFrame /** * Selects a set of expressions. @@ -299,15 +211,7 @@ class DataFrame protected[sql]( * }}} */ @scala.annotation.varargs - override def select(cols: Column*): DataFrame = { - val exprs = cols.zipWithIndex.map { - case (Column(expr: NamedExpression), _) => - expr - case (Column(expr: Expression), _) => - Alias(expr, expr.toString)() - } - Project(exprs.toSeq, logicalPlan) - } + override def select(cols: Column*): DataFrame /** * Selects a set of columns. This is a variant of `select` that can only select @@ -320,9 +224,7 @@ class DataFrame protected[sql]( * }}} */ @scala.annotation.varargs - override def select(col: String, cols: String*): DataFrame = { - select((col +: cols).map(new Column(_)) :_*) - } + override def select(col: String, cols: String*): DataFrame /** * Filters rows using the given condition. @@ -333,9 +235,7 @@ class DataFrame protected[sql]( * peopleDf($"age" > 15) * }}} */ - override def filter(condition: Column): DataFrame = { - Filter(condition.expr, logicalPlan) - } + override def filter(condition: Column): DataFrame /** * Filters rows using the given condition. This is an alias for `filter`. @@ -346,7 +246,7 @@ class DataFrame protected[sql]( * peopleDf($"age" > 15) * }}} */ - override def where(condition: Column): DataFrame = filter(condition) + override def where(condition: Column): DataFrame /** * Filters rows using the given condition. This is a shorthand meant for Scala. 
@@ -357,7 +257,7 @@ class DataFrame protected[sql]( * peopleDf($"age" > 15) * }}} */ - override def apply(condition: Column): DataFrame = filter(condition) + override def apply(condition: Column): DataFrame /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -375,9 +275,7 @@ class DataFrame protected[sql]( * }}} */ @scala.annotation.varargs - override def groupBy(cols: Column*): GroupedDataFrame = { - new GroupedDataFrame(this, cols.map(_.expr)) - } + override def groupBy(cols: Column*): GroupedDataFrame /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -398,10 +296,7 @@ class DataFrame protected[sql]( * }}} */ @scala.annotation.varargs - override def groupBy(col1: String, cols: String*): GroupedDataFrame = { - val colNames: Seq[String] = col1 +: cols - new GroupedDataFrame(this, colNames.map(colName => resolve(colName))) - } + override def groupBy(col1: String, cols: String*): GroupedDataFrame /** * Aggregates on the entire [[DataFrame]] without groups. @@ -411,7 +306,7 @@ class DataFrame protected[sql]( * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }} */ - override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + override def agg(exprs: Map[String, String]): DataFrame /** * Aggregates on the entire [[DataFrame]] without groups. @@ -421,7 +316,7 @@ class DataFrame protected[sql]( * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }} */ - override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap) + override def agg(exprs: java.util.Map[String, String]): DataFrame /** * Aggregates on the entire [[DataFrame]] without groups. 
@@ -432,31 +327,31 @@ class DataFrame protected[sql]( * }} */ @scala.annotation.varargs - override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + override def agg(expr: Column, exprs: Column*): DataFrame /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. */ - override def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan) + override def limit(n: Int): DataFrame /** * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. * This is equivalent to `UNION ALL` in SQL. */ - override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + override def unionAll(other: DataFrame): DataFrame /** * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. * This is equivalent to `INTERSECT` in SQL. */ - override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + override def intersect(other: DataFrame): DataFrame /** * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. * This is equivalent to `EXCEPT` in SQL. */ - override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + override def except(other: DataFrame): DataFrame /** * Returns a new [[DataFrame]] by sampling a fraction of rows. @@ -465,9 +360,7 @@ class DataFrame protected[sql]( * @param fraction Fraction of rows to generate. * @param seed Seed for sampling. */ - override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { - Sample(fraction, withReplacement, seed, logicalPlan) - } + override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame /** * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. 
@@ -475,105 +368,85 @@ class DataFrame protected[sql]( * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. */ - override def sample(withReplacement: Boolean, fraction: Double): DataFrame = { - sample(withReplacement, fraction, Utils.random.nextLong) - } + override def sample(withReplacement: Boolean, fraction: Double): DataFrame ///////////////////////////////////////////////////////////////////////////// /** * Returns a new [[DataFrame]] by adding a column. */ - override def addColumn(colName: String, col: Column): DataFrame = { - select(Column("*"), col.as(colName)) - } + override def addColumn(colName: String, col: Column): DataFrame /** * Returns the first `n` rows. */ - override def head(n: Int): Array[Row] = limit(n).collect() + override def head(n: Int): Array[Row] /** * Returns the first row. */ - override def head(): Row = head(1).head + override def head(): Row /** * Returns the first row. Alias for head(). */ - override def first(): Row = head() + override def first(): Row /** * Returns a new RDD by applying a function to all rows of this DataFrame. */ - override def map[R: ClassTag](f: Row => R): RDD[R] = { - rdd.map(f) - } + override def map[R: ClassTag](f: Row => R): RDD[R] /** * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], * and then flattening the results. */ - override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] /** * Returns a new RDD by applying a function to each partition of this DataFrame. */ - override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { - rdd.mapPartitions(f) - } - + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] /** * Applies a function `f` to all rows. 
*/ - override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + override def foreach(f: Row => Unit): Unit /** * Applies a function f to each partition of this [[DataFrame]]. */ - override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + override def foreachPartition(f: Iterator[Row] => Unit): Unit /** * Returns the first `n` rows in the [[DataFrame]]. */ - override def take(n: Int): Array[Row] = head(n) + override def take(n: Int): Array[Row] /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. */ - override def collect(): Array[Row] = rdd.collect() + override def collect(): Array[Row] /** * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + override def collectAsList(): java.util.List[Row] /** * Returns the number of rows in the [[DataFrame]]. */ - override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) + override def count(): Long /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. 
*/ - override def repartition(numPartitions: Int): DataFrame = { - sqlContext.applySchema(rdd.repartition(numPartitions), schema) - } + override def repartition(numPartitions: Int): DataFrame - override def persist(): this.type = { - sqlContext.cacheManager.cacheQuery(this) - this - } + override def persist(): this.type - override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheManager.cacheQuery(this, None, newLevel) - this - } + override def persist(newLevel: StorageLevel): this.type - override def unpersist(blocking: Boolean): this.type = { - sqlContext.cacheManager.tryUncacheQuery(this, blocking) - this - } + override def unpersist(blocking: Boolean): this.type ///////////////////////////////////////////////////////////////////////////// // I/O @@ -582,10 +455,7 @@ class DataFrame protected[sql]( /** * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. */ - override def rdd: RDD[Row] = { - val schema = this.schema - queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) - } + override def rdd: RDD[Row] /** * Registers this RDD as a temporary table using the given name. The lifetime of this temporary @@ -593,18 +463,14 @@ class DataFrame protected[sql]( * * @group schema */ - override def registerTempTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(this, tableName) - } + override def registerTempTable(tableName: String): Unit /** * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. */ - override def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } + override def saveAsParquetFile(path: String): Unit /** * :: Experimental :: @@ -617,31 +483,19 @@ class DataFrame protected[sql]( * be the target of an `insertInto`. 
*/ @Experimental - override def saveAsTable(tableName: String): Unit = { - sqlContext.executePlan( - CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd - } + override def saveAsTable(tableName: String): Unit /** * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. */ @Experimental - override def insertInto(tableName: String, overwrite: Boolean): Unit = { - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd - } + override def insertInto(tableName: String, overwrite: Boolean): Unit /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. */ - override def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val jsonFactory = new JsonFactory() - iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) - } - } + override def toJSON: RDD[String] //////////////////////////////////////////////////////////////////////////// // for Python API @@ -649,16 +503,10 @@ class DataFrame protected[sql]( /** * A helpful function for Py4j, convert a list of Column to an array */ - protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = { - cols.toList.toArray - } + protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] /** * Converts a JavaRDD to a PythonRDD. 
*/ - protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } + protected[sql] def javaToPython: JavaRDD[Array[Byte]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala new file mode 100644 index 0000000000000..0b07010fa8f63 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -0,0 +1,329 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql + +import java.util.{List => JList} + +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.util.Utils + + +/** + * Implementation for [[DataFrame]]. Refer to [[DataFrame]] directly for documentation. + */ +class DataFrameImpl protected[sql]( + override val sqlContext: SQLContext, + private val baseLogicalPlan: LogicalPlan) + extends DataFrame { + + @transient override lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[sql] override val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + baseLogicalPlan + } + + /** + * An implicit conversion function internal to this class for us to avoid doing + * "new DataFrameImpl(...)" everywhere. 
+ */ + @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { + new DataFrameImpl(sqlContext, logicalPlan) + } + + protected[sql] def resolve(colName: String): NamedExpression = { + queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + throw new RuntimeException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + } + + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + } + } + + override def toDataFrame(colName: String, colNames: String*): DataFrame = { + val newNames = colName +: colNames + require(schema.size == newNames.size, + "The number of columns doesn't match.\n" + + "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + + "New column names: " + newNames.mkString(", ")) + + val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) => + apply(oldName).as(newName) + } + select(newCols :_*) + } + + override def schema: StructType = queryExecution.analyzed.schema + + override def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } + + override def columns: Array[String] = schema.fields.map(_.name) + + override def printSchema(): Unit = println(schema.treeString) + + override def join(right: DataFrame): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } + + override def join(right: DataFrame, joinExprs: Column): DataFrame = { + Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr)) + } + + override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + } + + override def sort(sortCol: String, sortCols: String*): DataFrame = { + orderBy(apply(sortCol), 
sortCols.map(apply) :_*) + } + + override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = { + val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = true, logicalPlan) + } + + override def orderBy(sortCol: String, sortCols: String*): DataFrame = { + sort(sortCol, sortCols :_*) + } + + override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = { + sort(sortExpr, sortExprs :_*) + } + + override def apply(colName: String): Column = colName match { + case "*" => + Column(ResolvedStar(schema.fieldNames.map(resolve))) + case _ => + val expr = resolve(colName) + Column(sqlContext, Project(Seq(expr), logicalPlan), expr) + } + + override def apply(projection: Product): DataFrame = { + require(projection.productArity >= 1) + select(projection.productIterator.map { + case c: Column => c + case o: Any => Column(Literal(o)) + }.toSeq :_*) + } + + override def as(name: String): DataFrame = Subquery(name, logicalPlan) + + override def select(cols: Column*): DataFrame = { + val exprs = cols.zipWithIndex.map { + case (Column(expr: NamedExpression), _) => + expr + case (Column(expr: Expression), _) => + Alias(expr, expr.toString)() + } + Project(exprs.toSeq, logicalPlan) + } + + override def select(col: String, cols: String*): DataFrame = { + select((col +: cols).map(Column(_)) :_*) + } + + override def filter(condition: Column): DataFrame = { + Filter(condition.expr, logicalPlan) + } + + override def where(condition: Column): DataFrame = { + filter(condition) + } + + override def apply(condition: Column): DataFrame = { + filter(condition) + } + + override def groupBy(cols: Column*): GroupedDataFrame = { + new GroupedDataFrame(this, cols.map(_.expr)) + } + + override def groupBy(col1: String, cols: String*): GroupedDataFrame = { + val colNames: Seq[String] = col1 +: cols + new GroupedDataFrame(this, 
colNames.map(colName => resolve(colName))) + } + + override def agg(exprs: Map[String, String]): DataFrame = { + groupBy().agg(exprs) + } + + override def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.toMap) + } + + override def agg(expr: Column, exprs: Column*): DataFrame = { + groupBy().agg(expr, exprs :_*) + } + + override def limit(n: Int): DataFrame = { + Limit(Literal(n), logicalPlan) + } + + override def unionAll(other: DataFrame): DataFrame = { + Union(logicalPlan, other.logicalPlan) + } + + override def intersect(other: DataFrame): DataFrame = { + Intersect(logicalPlan, other.logicalPlan) + } + + override def except(other: DataFrame): DataFrame = { + Except(logicalPlan, other.logicalPlan) + } + + override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + Sample(fraction, withReplacement, seed, logicalPlan) + } + + override def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + ///////////////////////////////////////////////////////////////////////////// + + override def addColumn(colName: String, col: Column): DataFrame = { + select(Column("*"), col.as(colName)) + } + + override def head(n: Int): Array[Row] = limit(n).collect() + + override def head(): Row = head(1).head + + override def first(): Row = head() + + override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) + + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + + override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + + override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + + override def take(n: Int): Array[Row] = head(n) + + override def collect(): Array[Row] = rdd.collect() + + override def collectAsList(): java.util.List[Row] = 
java.util.Arrays.asList(rdd.collect() :_*) + + override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) + + override def repartition(numPartitions: Int): DataFrame = { + sqlContext.applySchema(rdd.repartition(numPartitions), schema) + } + + override def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + override def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + ///////////////////////////////////////////////////////////////////////////// + // I/O + ///////////////////////////////////////////////////////////////////////////// + + override def rdd: RDD[Row] = { + val schema = this.schema + queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + } + + override def registerTempTable(tableName: String): Unit = { + sqlContext.registerRDDAsTable(this, tableName) + } + + override def saveAsParquetFile(path: String): Unit = { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + + override def saveAsTable(tableName: String): Unit = { + sqlContext.executePlan( + CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd + } + + override def insertInto(tableName: String, overwrite: Boolean): Unit = { + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite)).toRdd + } + + override def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val jsonFactory = new JsonFactory() + iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) + } + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + protected[sql] override def toColumnArray(cols: 
JList[Column]): Array[Column] = { + cols.toList.toArray + } + + protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala index 3499956023d11..4d6320c05aed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala @@ -55,17 +55,17 @@ object Dsl { } } - private[this] implicit def toColumn(expr: Expression): Column = new Column(expr) + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) /** * Returns a [[Column]] based on the given column name. */ - def col(colName: String): Column = new Column(colName) + def col(colName: String): Column = Column(colName) /** * Returns a [[Column]] based on the given column name. Alias of [[col]]. */ - def column(colName: String): Column = new Column(colName) + def column(colName: String): Column = Column(colName) /** * Creates a [[Column]] of literal value. @@ -94,7 +94,7 @@ object Dsl { case _ => throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal) } - new Column(literalExpr) + Column(literalExpr) } def sum(e: Column): Column = Sum(e.expr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala index 1c948cbbfe58f..3187a92fc7b71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate /** * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. 
*/ -class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) +class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) extends GroupedDataFrameApi { private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { @@ -36,8 +36,8 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.toString)() } - new DataFrame(df.sqlContext, - Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) + DataFrame( + df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) } private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = { @@ -112,8 +112,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.toString)() } - - new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) + DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala new file mode 100644 index 0000000000000..215b57ec63ba9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -0,0 +1,162 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.types.StructType + + +private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column { + + def this(name: String) = this(name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) + case _ => UnresolvedAttribute(name) + }) + + private def err[T](): T = { + throw new UnsupportedOperationException("Cannot run this method on an IncomputableColumn") + } + + override def isComputable: Boolean = false + + override val sqlContext: SQLContext = null + + override def queryExecution = err() + + protected[sql] override def logicalPlan: LogicalPlan = err() + + override def toDataFrame(colName: String, colNames: String*): DataFrame = err() + + override def schema: StructType = err() + + override def dtypes: Array[(String, String)] = err() + + override def columns: Array[String] = err() + + override def printSchema(): Unit = err() + + override def join(right: DataFrame): DataFrame = err() + + override def join(right: DataFrame, joinExprs: Column): DataFrame = err() + + override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = 
err() + + override def sort(sortCol: String, sortCols: String*): DataFrame = err() + + override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = err() + + override def orderBy(sortCol: String, sortCols: String*): DataFrame = err() + + override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err() + + override def apply(colName: String): Column = err() + + override def apply(projection: Product): DataFrame = err() + + override def select(cols: Column*): DataFrame = err() + + override def select(col: String, cols: String*): DataFrame = err() + + override def filter(condition: Column): DataFrame = err() + + override def where(condition: Column): DataFrame = err() + + override def apply(condition: Column): DataFrame = err() + + override def groupBy(cols: Column*): GroupedDataFrame = err() + + override def groupBy(col1: String, cols: String*): GroupedDataFrame = err() + + override def agg(exprs: Map[String, String]): DataFrame = err() + + override def agg(exprs: java.util.Map[String, String]): DataFrame = err() + + override def agg(expr: Column, exprs: Column*): DataFrame = err() + + override def limit(n: Int): DataFrame = err() + + override def unionAll(other: DataFrame): DataFrame = err() + + override def intersect(other: DataFrame): DataFrame = err() + + override def except(other: DataFrame): DataFrame = err() + + override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err() + + override def sample(withReplacement: Boolean, fraction: Double): DataFrame = err() + + ///////////////////////////////////////////////////////////////////////////// + + override def addColumn(colName: String, col: Column): DataFrame = err() + + override def head(n: Int): Array[Row] = err() + + override def head(): Row = err() + + override def first(): Row = err() + + override def map[R: ClassTag](f: Row => R): RDD[R] = err() + + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = err() + + override def 
mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = err() + + override def foreach(f: Row => Unit): Unit = err() + + override def foreachPartition(f: Iterator[Row] => Unit): Unit = err() + + override def take(n: Int): Array[Row] = err() + + override def collect(): Array[Row] = err() + + override def collectAsList(): java.util.List[Row] = err() + + override def count(): Long = err() + + override def repartition(numPartitions: Int): DataFrame = err() + + override def persist(): this.type = err() + + override def persist(newLevel: StorageLevel): this.type = err() + + override def unpersist(blocking: Boolean): this.type = err() + + override def rdd: RDD[Row] = err() + + override def registerTempTable(tableName: String): Unit = err() + + override def saveAsParquetFile(path: String): Unit = err() + + override def saveAsTable(tableName: String): Unit = err() + + override def insertInto(tableName: String, overwrite: Boolean): Unit = err() + + override def toJSON: RDD[String] = err() + + protected[sql] override def toColumnArray(cols: java.util.List[Column]): Array[Column] = err() + + protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 84933dd944837..d0bbb5f7a34f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -171,14 +171,14 @@ class SQLContext(@transient val sparkContext: SparkContext) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) + DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) } /** * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. 
*/ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { - new DataFrame(this, LogicalRelation(baseRelation)) + DataFrame(this, LogicalRelation(baseRelation)) } /** @@ -216,7 +216,7 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) - new DataFrame(this, logicalPlan) + DataFrame(this, logicalPlan) } /** @@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ) : Row } } - new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) + DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -262,7 +262,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): DataFrame = - new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. @@ -365,7 +365,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def sql(sqlText: String): DataFrame = { if (conf.dialect == "sql") { - new DataFrame(this, parseSql(sqlText)) + DataFrame(this, parseSql(sqlText)) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}") } @@ -373,7 +373,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** Returns the specified table as a [[DataFrame]]. 
*/ def table(tableName: String): DataFrame = - new DataFrame(this, catalog.lookupRelation(Seq(tableName))) + DataFrame(this, catalog.lookupRelation(Seq(tableName))) protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -462,7 +462,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * access to the intermediate phases of query execution for developers. */ @DeveloperApi - protected class QueryExecution(val logical: LogicalPlan) { + protected[sql] class QueryExecution(val logical: LogicalPlan) { lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical)) lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed) @@ -556,7 +556,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 6fba76c52171b..e1c9a2be7d20d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -138,7 +138,7 @@ case class CacheTableCommand( override def run(sqlContext: SQLContext) = { plan.foreach { logicalPlan => - sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName) + sqlContext.registerRDDAsTable(DataFrame(sqlContext, logicalPlan), tableName) } sqlContext.cacheTable(tableName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index b4af91a768efb..43f5bb46d3aac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -226,7 +226,7 @@ private 
[sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext) = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) sqlContext.registerRDDAsTable( - new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 906455dd40c0d..4e1ec38bd0158 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -41,7 +41,7 @@ object TestSQLContext * construct [[DataFrame]] directly out of local data without relying on implicits. */ protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - new DataFrame(this, plan) + DataFrame(this, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 2d464c2b53d79..be6f2af085da7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -27,6 +27,41 @@ class ColumnExpressionSuite extends QueryTest { // TODO: Add test cases for bitwise operations. 
+ test("computability check") { + def shouldBeComputable(c: Column): Unit = assert(c.isComputable === true) + def shouldNotBeComputable(c: Column): Unit = assert(c.isComputable === false) + + shouldBeComputable(testData2("a")) + shouldBeComputable(testData2("b")) + + shouldBeComputable(testData2("a") + testData2("b")) + shouldBeComputable(testData2("a") + testData2("b") + 1) + + shouldBeComputable(-testData2("a")) + shouldBeComputable(!testData2("a")) + + shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b")) + shouldBeComputable( + testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d")) + shouldBeComputable( + testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b")) + + // Literals and unresolved columns should not be computable. + shouldNotBeComputable(col("1")) + shouldNotBeComputable(col("1") + 2) + shouldNotBeComputable(lit(100)) + shouldNotBeComputable(lit(100) + 10) + shouldNotBeComputable(-col("1")) + shouldNotBeComputable(!col("1")) + + // Getting data from different frames should not be computable. + shouldNotBeComputable(testData2("a") + testData("key")) + shouldNotBeComputable(testData2("a") + 1 + testData("key")) + + // Aggregate functions alone should not be computable. 
+ shouldNotBeComputable(sum(testData2("a"))) + } + test("star") { checkAnswer(testData.select($"*"), testData.collect().toSeq) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index e78145f4dda5a..ff91a0eb42049 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} import org.apache.spark.sql.types._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -51,8 +51,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { val query = rdd - .select(output.map(e => new org.apache.spark.sql.Column(e)): _*) - .where(new org.apache.spark.sql.Column(predicate)) + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b746942cb1067..5efc3b1e30774 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -72,7 +72,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - new DataFrame(this, 
ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) + DataFrame(this, + ddlParser(sqlText, exceptionOnError = false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ace9329cd5821..81807d290d1ce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.catalyst.expressions.Row + import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate @@ -29,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.sources.CreateTableUsing @@ -56,14 +57,14 @@ private[hive] trait HiveStrategies { @Experimental object ParquetConversion extends Strategy { implicit class LogicalPlanHacks(s: DataFrame) { - def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan) + def lowerCase = DataFrame(s.sqlContext, s.logicalPlan) def addPartitioningAttributes(attrs: Seq[Attribute]) = { // Don't add the partitioning key if its already present in the data. 
if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s } else { - new DataFrame( + DataFrame( s.sqlContext, s.logicalPlan transform { case p: ParquetRelation => p.copy(partitioningAttributes = attrs) @@ -96,13 +97,13 @@ private[hive] trait HiveStrategies { // We are going to throw the predicates and projection back at the whole optimization // sequence so lets unresolve all the attributes, allowing them to be rebound to the // matching parquet attributes. - val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform { + val unresolvedOtherPredicates = Column(otherPredicates.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) }).reduceOption(And).getOrElse(Literal(true))) val unresolvedProjection: Seq[Column] = projectList.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }).map(new Column(_)) + }).map(Column(_)) try { if (relation.hiveQlTable.isPartitioned) {