From 68257a56838498a88795fe5ba5c96348128f21f8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 29 Mar 2016 00:25:29 +0800 Subject: [PATCH 1/3] remove trait Queryable --- .../scala/org/apache/spark/sql/Dataset.scala | 88 +++++++++++-- .../spark/sql/KeyValueGroupedDataset.scala | 13 -- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../spark/sql/execution/CacheManager.scala | 17 +-- .../spark/sql/execution/Queryable.scala | 124 ------------------ 5 files changed, 90 insertions(+), 156 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 703ea4d1498ce..41cb799b97141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,8 +22,10 @@ import java.io.CharArrayWriter import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonFactory +import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD @@ -39,7 +41,7 @@ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression -import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator @@ -150,10 +152,10 @@ private[sql] object Dataset { * @since 1.6.0 */ class Dataset[T] private[sql]( - @transient override val sqlContext: SQLContext, - @DeveloperApi @transient override val queryExecution: QueryExecution, + @transient val sqlContext: SQLContext, + @DeveloperApi @transient val queryExecution: QueryExecution, encoder: Encoder[T]) - extends Queryable with Serializable { + extends Serializable { queryExecution.assertAnalyzed() @@ -224,7 +226,7 @@ class Dataset[T] private[sql]( * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) val takeResult = take(numRows + 1) val hasMoreData = takeResult.length > numRows @@ -249,7 +251,75 @@ class Dataset[T] private[sql]( }: Seq[String] } - formatString ( rows, numRows, hasMoreData, truncate ) + val sb = new StringBuilder + val numCols = schema.fieldNames.length + + // Initialise the width of each column to a minimum value of '3' + val colWidths = Array.fill(numCols)(3) + + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + rows.head.zipWithIndex.map { case (cell, i) => + if (truncate) { + 
StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + + sb.append(sep) + + // data + rows.tail.map { + _.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + + // For Data that has more than "numRows" records + if (hasMoreData) { + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") + } + + sb.toString() + } + + override def toString: String = { + try { + val builder = new StringBuilder + val fields = schema.take(2).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("[") + builder.append(fields.mkString(", ")) + if (schema.length > 2) { + if (schema.length - fields.size == 1) { + builder.append(" ... 1 more field") + } else { + builder.append(" ... " + (schema.length - 2) + " more fields") + } + } + builder.append("]").toString() + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } } /** @@ -325,7 +395,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ // scalastyle:off println - override def printSchema(): Unit = println(schema.treeString) + def printSchema(): Unit = println(schema.treeString) // scalastyle:on println /** @@ -334,7 +404,7 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - override def explain(extended: Boolean): Unit = { + def explain(extended: Boolean): Unit = { val explain = ExplainCommand(queryExecution.logical, extended = extended) sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { // scalastyle:off println @@ -349,7 +419,7 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - override def explain(): Unit = explain(extended = false) + def explain(): Unit = explain(extended = false) /** * Returns all column names and their data types as an array. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 07aa1515f3841..f19ad6e707526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -57,13 +57,6 @@ class KeyValueGroupedDataset[K, V] private[sql]( private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext - private def groupedData = { - new RelationalGroupedDataset( - Dataset.ofRows(sqlContext, logicalPlan), - groupingAttributes, - RelationalGroupedDataset.GroupByType) - } - /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the * specified type. The mapping of key columns to the type follows the same rules as `as` on @@ -207,12 +200,6 @@ class KeyValueGroupedDataset[K, V] private[sql]( reduceGroups(f.call _) } - private def withEncoder(c: Column): Column = c match { - case tc: TypedColumn[_, _] => - tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) - case _ => c - } - /** * Internal helper function for building typed aggregations that return tuples. 
For simplicity * and code reuse, we do this without the help of the type system and then use helper functions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e413e77bc1349..1084bb2167cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -272,11 +272,11 @@ class SQLContext private[sql]( } /** - * Returns true if the [[Queryable]] is currently cached in-memory. + * Returns true if the [[Dataset]] is currently cached in-memory. * @group cachemgmt * @since 1.3.0 */ - private[sql] def isCached(qName: Queryable): Boolean = { + private[sql] def isCached(qName: Dataset[_]): Boolean = { cacheManager.lookupCachedData(qName).nonEmpty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 14b8b6fc3b38b..f3478a873a1a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.Dataset import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK @@ -74,12 +75,12 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[Queryable]]. + * Caches the data produced by the logical representation of the given [[Dataset]]. * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because * recomputing the in-memory columnar representation of the underlying table is expensive. 
*/ private[sql] def cacheQuery( - query: Queryable, + query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -99,8 +100,8 @@ private[sql] class CacheManager extends Logging { } } - /** Removes the data for the given [[Queryable]] from the cache */ - private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Dataset]] from the cache */ + private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -108,11 +109,11 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[Queryable]] from the cache + /** Tries to remove the data for the given [[Dataset]] from the cache * if it's cached */ private[sql] def tryUncacheQuery( - query: Queryable, + query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -124,8 +125,8 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[Queryable]] */ - private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[Dataset]] */ + private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala deleted file mode 100644 index 38263af0f7e30..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.control.NonFatal - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.types.StructType - -/** A trait that holds shared code between DataFrames and Datasets. 
*/ -private[sql] trait Queryable { - def schema: StructType - def queryExecution: QueryExecution - def sqlContext: SQLContext - - override def toString: String = { - try { - val builder = new StringBuilder - val fields = schema.take(2).map { - case f => s"${f.name}: ${f.dataType.simpleString(2)}" - } - builder.append("[") - builder.append(fields.mkString(", ")) - if (schema.length > 2) { - if (schema.length - fields.size == 1) { - builder.append(" ... 1 more field") - } else { - builder.append(" ... " + (schema.length - 2) + " more fields") - } - } - builder.append("]").toString() - } catch { - case NonFatal(e) => - s"Invalid tree; ${e.getMessage}:\n$queryExecution" - } - } - - def printSchema(): Unit - - def explain(extended: Boolean): Unit - - def explain(): Unit - - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String - - /** - * Format the string representing rows for output - * @param rows The rows to show - * @param numRows Number of rows to show - * @param hasMoreData Whether some rows are not shown due to the limit - * @param truncate Whether truncate long strings and align cells right - * - */ - private[sql] def formatString ( - rows: Seq[Seq[String]], - numRows: Int, - hasMoreData : Boolean, - truncate: Boolean = true): String = { - val sb = new StringBuilder - val numCols = schema.fieldNames.length - - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) - - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - - sb.append(sep) - - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - - sb.append(sep) - - // For Data that has more than "numRows" records - if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") - } - - sb.toString() - } -} From 9bd20aaa50fd4db49fd740197419c75ecae0ca39 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 29 Mar 2016 01:32:04 +0800 Subject: [PATCH 2/3] fix mima --- project/MimaExcludes.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 208c7a28cf9bc..94621d7fa3723 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -589,6 +589,9 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol") + ) ++ Seq( + // [SPARK-14205][SQL] remove trait Queryable + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset") ) case v if v.startsWith("1.6") => Seq( From 9d0497f4ec7377c3f258a5a8c017a1f626dd02e3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 29 Mar 2016 
07:34:57 +0800 Subject: [PATCH 3/3] fix tests --- .../test/scala/org/apache/spark/sql/QueryTest.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index a1b45ca7ebd19..7ff4ffcaecd49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{LogicalRDD, Queryable} +import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -180,9 +180,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[Queryable]] will be executed using the given number of cached results. + * Asserts that a given [[Dataset]] will be executed using the given number of cached results. */ - def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { + def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached @@ -286,9 +286,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans. + * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. */ - def assertEmptyMissingInput(query: Queryable): Unit = { + def assertEmptyMissingInput(query: Dataset[_]): Unit = { assert(query.queryExecution.analyzed.missingInput.isEmpty, s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
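
---

Note on the first patch: it folds the old `Queryable.formatString` helper directly into `Dataset.showString`. For reference, the table-rendering idea it carries over boils down to the sketch below. This is a simplified, standalone illustration, not the patched code itself: it replaces the commons-lang3 `StringUtils.leftPad`/`rightPad` calls with plain padding, omits the "only showing top N rows" footer, and the names `ShowStringSketch`/`format` are illustrative only.

```scala
// Simplified sketch of the column-width + padding logic that the patch
// inlines into Dataset.showString. Assumes rows.head holds the column names.
object ShowStringSketch {
  def format(rows: Seq[Seq[String]], truncate: Boolean = true): String = {
    val sb = new StringBuilder
    val numCols = rows.head.length

    // Each column is at least 3 characters wide, then grows to fit its widest cell.
    val colWidths = Array.fill(numCols)(3)
    for (row <- rows; (cell, i) <- row.zipWithIndex) {
      colWidths(i) = math.max(colWidths(i), cell.length)
    }

    // Separator line, e.g. +---+-----+
    val sep = colWidths.map("-" * _).mkString("+", "+", "+\n")

    // With truncate the real code right-aligns cells (leftPad), otherwise
    // it left-aligns them (rightPad); plain string padding stands in here.
    def pad(cell: String, width: Int): String =
      if (truncate) " " * (width - cell.length) + cell
      else cell + " " * (width - cell.length)

    sb.append(sep)
    rows.zipWithIndex.foreach { case (row, idx) =>
      sb.append(row.zipWithIndex.map { case (c, i) => pad(c, colWidths(i)) }
        .mkString("|", "|", "|\n"))
      if (idx == 0) sb.append(sep) // separator under the header row
    }
    sb.append(sep)
    sb.toString()
  }

  def main(args: Array[String]): Unit = {
    print(format(Seq(Seq("id", "name"), Seq("1", "alice"), Seq("2", "bob"))))
  }
}
```

The behavior is unchanged by the move; only the `Queryable` indirection goes away, which is why the MiMa exclusion in the second patch is limited to a `MissingTypesProblem` on `org.apache.spark.sql.Dataset`.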