From c9adcad81a493595d075715efb2ae395d88a9ec2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 25 May 2015 23:09:22 -0700 Subject: [PATCH 001/254] [SQL][minor] Removed unused Catalyst logical plan DSL. The Catalyst DSL is no longer used as a public facing API. This pull request removes the UDF and writeToFile feature from it since they are not used in unit tests. Author: Reynold Xin Closes #6350 from rxin/unused-logical-dsl and squashes the following commits: 90b3de6 [Reynold Xin] [SQL][minor] Removed unused Catalyst logical plan DSL. --- .../spark/sql/catalyst/dsl/package.scala | 129 ++++-------------- 1 file changed, 27 insertions(+), 102 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 307a9ca9b0070..60ab9fba4885a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ @@ -61,7 +60,7 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- : Expression= UnaryMinus(expr) + def unary_- : Expression = UnaryMinus(expr) def unary_! : Predicate = Not(expr) def unary_~ : Expression = BitwiseNot(expr) @@ -234,133 +233,59 @@ package object dsl { implicit class DslAttribute(a: AttributeReference) { def notNull: AttributeReference = a.withNullability(false) def nullable: AttributeReference = a.withNullability(true) - - // Protobuf terminology - def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } - object expressions extends ExpressionConversions // scalastyle:ignore - abstract class LogicalPlanFunctions { - def logicalPlan: LogicalPlan - - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + object plans { // scalastyle:ignore + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) - def join( + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, condition: Option[Expression] = None): LogicalPlan = - Join(logicalPlan, otherPlan, joinType, condition) + Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { + val aliasedExprs = aggregateExprs.map { + case ne: NamedExpression => ne + case e => Alias(e, e.toString)() + } + Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - Aggregate(groupingExprs, aliasedExprs, logicalPlan) - } - - def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) + def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) - def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) + def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names - def generate( + // TODO specify the output column names + def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = - InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = + InsertIntoTable( + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) - } - - object plans { // scalastyle:ignore - implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } } - - case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*): ScalaUdf = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) - } - } - - // scalastyle:off - /** functionToUdfBuilder 1-22 were generated by this script - - (1 to 22).map { x => - val argTypes = Seq.fill(x)("_").mkString(", ") - s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func)" - } - */ - - implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - // scalastyle:on } From 43aa819c041f6e8301ad1b8f82eb68e14254f636 Mon Sep 17 00:00:00 2001 From: Konstantin Shaposhnikov Date: Tue, 26 May 2015 07:49:32 +0100 Subject: [PATCH 002/254] [SPARK-7042] [BUILD] use the standard akka artifacts with hadoop-2.x Both akka 2.3.x and hadoop-2.x use protobuf 2.5 so only hadoop-1 build needs custom 2.3.4-spark akka version that shades protobuf-2.5 This partially fixes SPARK-7042 (for hadoop-2.x builds) Author: Konstantin Shaposhnikov Closes #6341 from kostya-sh/SPARK-7042 and squashes the following commits: 7eb8c60 [Konstantin Shaposhnikov] [SPARK-7042][BUILD] use the standard akka artifacts with hadoop-2.x --- pom.xml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index c72d7cbf843ef..8e936ab5ab9de 100644 --- a/pom.xml +++ b/pom.xml @@ -114,8 +114,8 @@ UTF-8 UTF-8 - org.spark-project.akka - 2.3.4-spark + com.typesafe.akka + 2.3.4 1.6 spark 2.0.1 @@ -1664,6 +1664,8 @@ 0.98.7-hadoop1 hadoop1 1.8.8 + org.spark-project.akka + 2.3.4-spark From bf49c22130af9a729dcc510743e4c1ea4c5d2439 Mon Sep 17 00:00:00 2001 From: scwf Date: Tue, 26 May 2015 08:42:52 -0500 Subject: [PATCH 003/254] [CORE] [TEST] Fix SimpleDateParamTest ``` sbt.ForkMain$ForkError: 1424424077190 was not equal to 1424474477190 at org.scalatest.MatchersHelper$.newTestFailedException(MatchersHelper.scala:160) at org.scalatest.Matchers$ShouldMethodHelper$.shouldMatcher(Matchers.scala:6231) at org.scalatest.Matchers$AnyShouldWrapper.should(Matchers.scala:6265) at org.apache.spark.status.api.v1.SimpleDateParamTest$$anonfun$1.apply$mcV$sp(SimpleDateParamTest.scala:25) at org.apache.spark.status.api.v1.SimpleDateParamTest$$anonfun$1.apply(SimpleDateParamTest.scala:23) at org.apache.spark.status.api.v1.SimpleDateParamTest$$anonfun$1.apply(SimpleDateParamTest.scala:23) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.scalatest.Suite$class.withFixture(Suite.scala: ``` Set timezone to fix SimpleDateParamTest Author: scwf Author: Fei Wang Closes #6377 from scwf/fix-SimpleDateParamTest and squashes the following commits: b8df1e5 [Fei Wang] Update SimpleDateParamSuite.scala 8bb74f0 [scwf] fix SimpleDateParamSuite --- ...SimpleDateParamTest.scala => SimpleDateParamSuite.scala} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename core/src/test/scala/org/apache/spark/status/api/v1/{SimpleDateParamTest.scala => SimpleDateParamSuite.scala} (86%) diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamTest.scala b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala similarity index 86% rename from core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamTest.scala rename to core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala index 5274df904d395..731d1f557ed33 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamTest.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.status.api.v1 import org.scalatest.{Matchers, FunSuite} -class SimpleDateParamTest extends FunSuite with Matchers { +class SimpleDateParamSuite extends FunSuite with Matchers { test("date parsing") { new SimpleDateParam("2015-02-20T23:21:17.190GMT").timestamp should be (1424474477190L) - new SimpleDateParam("2015-02-20T17:21:17.190CST").timestamp should be (1424474477190L) - new SimpleDateParam("2015-02-20").timestamp should be (1424390400000L) // GMT + new SimpleDateParam("2015-02-20T17:21:17.190EST").timestamp should be (1424470877190L) + new SimpleDateParam("2015-02-20").timestamp should be (1424390400000L) // GMT } } From 8948ad3fb5d5d095d3942855960d735f27d97dd5 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Tue, 26 May 2015 08:35:39 -0700 Subject: [PATCH 004/254] [SPARK-7339] [PYSPARK] PySpark shuffle spill memory sometimes are not correct In PySpark we get memory used before and after spill, then use the difference of these two value as memorySpilled, but if the before value is small than after value, then we will get a negative value, but this scenario 0 value may be more reasonable. Below is the result in HistoryServer we have tested: Index ID Attempt Status Locality Level Executor ID / Host Launch Time Duration GC Time Input Size / Records Write Time Shuffle Write Size / Records Shuffle Spill (Memory) Shuffle Spill (Disk) Errors 0 0 0 SUCCESS NODE_LOCAL 3 / vm119 2015/05/04 17:31:06 21 s 0.1 s 128.1 MB (hadoop) / 3237 70 ms 10.1 MB / 2529 0.0 B 5.7 MB 2 2 0 SUCCESS NODE_LOCAL 1 / vm118 2015/05/04 17:31:06 22 s 89 ms 128.1 MB (hadoop) / 3205 0.1 s 10.1 MB / 2529 -1048576.0 B 5.9 MB 1 1 0 SUCCESS NODE_LOCAL 2 / vm117 2015/05/04 17:31:06 22 s 0.1 s 128.1 MB (hadoop) / 3271 68 ms 10.1 MB / 2529 -1048576.0 B 5.6 MB 4 4 0 SUCCESS NODE_LOCAL 2 / vm117 2015/05/04 17:31:06 22 s 0.1 s 128.1 MB (hadoop) / 3192 51 ms 10.1 MB / 2529 -1048576.0 B 5.9 MB 3 3 0 SUCCESS NODE_LOCAL 3 / vm119 2015/05/04 17:31:06 22 s 0.1 s 128.1 MB (hadoop) / 3262 51 ms 10.1 MB / 2529 1024.0 KB 5.8 MB 5 5 0 SUCCESS NODE_LOCAL 1 / vm118 2015/05/04 17:31:06 22 s 89 ms 128.1 MB (hadoop) / 3256 93 ms 10.1 MB / 2529 -1048576.0 B 5.7 MB /cc davies Author: linweizhong Closes #5887 from Sephiroth-Lin/spark-7339 and squashes the following commits: 9186c81 [linweizhong] Use max function to get a nonnegative value d41672b [linweizhong] Update MemoryBytesSpilled when memorySpilled > 0 --- python/pyspark/shuffle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 1d0b16cade8bb..81c420ce16541 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -362,7 +362,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def items(self): """ Return all merged items as iterator """ @@ -515,7 +515,7 @@ def load(f): gc.collect() batch //= 2 limit = self._next_limit() - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close @@ -630,7 +630,7 @@ def _spill(self): self.values = [] gc.collect() DiskBytesSpilled += self._file.tell() - pos - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 class ExternalListOfList(ExternalList): @@ -794,7 +794,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def _merged_items(self, index): size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) From 8dbe777703e0aaf47cbdfe98f66d22f723352fb5 Mon Sep 17 00:00:00 2001 From: meawoppl Date: Tue, 26 May 2015 09:02:25 -0700 Subject: [PATCH 005/254] [SPARK-7806][EC2] Fixes that allow the spark_ec2.py tool to run with Python3 I have used this script to launch, destroy, start, and stop clusters successfully. Author: meawoppl Closes #6336 from meawoppl/py3ec2spark and squashes the following commits: 2e87046 [meawoppl] Py3 compat fixes. --- ec2/spark_ec2.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index c6d5a1f0d0a81..724811eaa1bdd 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,8 +19,9 @@ # limitations under the License. # -from __future__ import with_statement, print_function +from __future__ import division, print_function, with_statement +import codecs import hashlib import itertools import logging @@ -47,6 +48,8 @@ else: from urllib.request import urlopen, Request from urllib.error import HTTPError + raw_input = input + xrange = range SPARK_EC2_VERSION = "1.3.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -423,13 +426,14 @@ def get_spark_ami(opts): b=opts.spark_ec2_git_branch) ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) + reader = codecs.getreader("ascii") try: - ami = urlopen(ami_path).read().strip() - print("Spark AMI: " + ami) + ami = reader(urlopen(ami_path)).read().strip() except: print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) + print("Spark AMI: " + ami) return ami @@ -750,7 +754,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): 'mapreduce', 'spark-standalone', 'tachyon'] if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) + modules = list(filter(lambda x: x != "mapreduce", modules)) if opts.ganglia: modules.append('ganglia') @@ -1160,7 +1164,7 @@ def get_zones(conn, opts): # Gets the number of items in a partition def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions + num_slaves_this_zone = total // num_partitions if (total % num_partitions) - current_partitions > 0: num_slaves_this_zone += 1 return num_slaves_this_zone From e5a63a0e39e35ba5f6b1732e09128c30fff17a2c Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Tue, 26 May 2015 17:05:58 +0100 Subject: [PATCH 006/254] [DOCS] [MLLIB] Fixing misformatted links in v1.4 MLlib Naive Bayes documentation by removing space and newline characters. A couple of links in the MLlib Naive Bayes documentation for v1.4 were broken due to the addition of either space or newline characters between the link title and link URL in the markdown doc. (Interestingly enough, they are rendered correctly in the GitHub viewer, but not when compiled to HTML by Jekyll.) Author: Mike Dusenberry Closes #6412 from dusenberrymw/Fix_Broken_Links_In_MLlib_Naive_Bayes_Docs and squashes the following commits: 91a4028 [Mike Dusenberry] Fixing misformatted links by removing space and newline characters. --- docs/mllib-naive-bayes.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 56a2e9ca86bb1..acdcc371487f8 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -14,9 +14,8 @@ and use it for prediction. MLlib supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -These models are typically used for [document classification] -(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). From 63099122deb93553f5a03bb3ff74e7c2f1ee2164 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 26 May 2015 17:08:16 +0100 Subject: [PATCH 007/254] [SPARK-7854] [TEST] refine Kryo test suite this modification is according to JoshRosen 's comments, for details, please refer to [#5934](https://github.com/apache/spark/pull/5934/files#r30949751). Author: Zhang, Liye Closes #6395 from liyezhang556520/kryoTest and squashes the following commits: da214c8 [Zhang, Liye] refine Kryo test suite accroding to Josh's comments --- .../serializer/KryoSerializerSuite.scala | 51 ++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 5faf108b394a1..8c384bd358ebc 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -34,38 +34,41 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) - test("configuration limits") { - val conf1 = conf.clone() + test("SPARK-7392 configuration limits") { val kryoBufferProperty = "spark.kryoserializer.buffer" val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" - conf1.set(kryoBufferProperty, "64k") - conf1.set(kryoBufferMaxProperty, "64m") - new KryoSerializer(conf1).newInstance() + + def newKryoInstance( + conf: SparkConf, + bufferSize: String = "64k", + maxBufferSize: String = "64m"): SerializerInstance = { + val kryoConf = conf.clone() + kryoConf.set(kryoBufferProperty, bufferSize) + kryoConf.set(kryoBufferMaxProperty, maxBufferSize) + new KryoSerializer(kryoConf).newInstance() + } + + // test default values + newKryoInstance(conf, "64k", "64m") // 2048m = 2097152k - conf1.set(kryoBufferProperty, "2097151k") - conf1.set(kryoBufferMaxProperty, "64m") // should not throw exception when kryoBufferMaxProperty < kryoBufferProperty - new KryoSerializer(conf1).newInstance() - conf1.set(kryoBufferMaxProperty, "2097151k") - new KryoSerializer(conf1).newInstance() - val conf2 = conf.clone() - conf2.set(kryoBufferProperty, "2048m") - val thrown1 = intercept[IllegalArgumentException](new KryoSerializer(conf2).newInstance()) + newKryoInstance(conf, "2097151k", "64m") + // test maximum size with unit of KiB + newKryoInstance(conf, "2097151k", "2097151k") + // should throw exception with bufferSize out of bound + val thrown1 = intercept[IllegalArgumentException](newKryoInstance(conf, "2048m")) assert(thrown1.getMessage.contains(kryoBufferProperty)) - val conf3 = conf.clone() - conf3.set(kryoBufferMaxProperty, "2048m") - val thrown2 = intercept[IllegalArgumentException](new KryoSerializer(conf3).newInstance()) + // should throw exception with maxBufferSize out of bound + val thrown2 = intercept[IllegalArgumentException]( + newKryoInstance(conf, maxBufferSize = "2048m")) assert(thrown2.getMessage.contains(kryoBufferMaxProperty)) - val conf4 = conf.clone() - conf4.set(kryoBufferProperty, "2g") - conf4.set(kryoBufferMaxProperty, "3g") - val thrown3 = intercept[IllegalArgumentException](new KryoSerializer(conf4).newInstance()) + // should throw exception when both bufferSize and maxBufferSize out of bound + // exception should only contain "spark.kryoserializer.buffer" + val thrown3 = intercept[IllegalArgumentException](newKryoInstance(conf, "2g", "3g")) assert(thrown3.getMessage.contains(kryoBufferProperty)) assert(!thrown3.getMessage.contains(kryoBufferMaxProperty)) - val conf5 = conf.clone() - conf5.set(kryoBufferProperty, "8m") - conf5.set(kryoBufferMaxProperty, "9m") - new KryoSerializer(conf5).newInstance() + // test configuration with mb is supported properly + newKryoInstance(conf, "8m", "9m") } test("basic types") { From b7d8085942c564d6c5b81a14d31789e1b215e62b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 26 May 2015 10:05:13 -0700 Subject: [PATCH 008/254] Revert "[SPARK-7042] [BUILD] use the standard akka artifacts with hadoop-2.x" This reverts commit 43aa819c041f6e8301ad1b8f82eb68e14254f636. --- pom.xml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index 8e936ab5ab9de..c72d7cbf843ef 100644 --- a/pom.xml +++ b/pom.xml @@ -114,8 +114,8 @@ UTF-8 UTF-8 - com.typesafe.akka - 2.3.4 + org.spark-project.akka + 2.3.4-spark 1.6 spark 2.0.1 @@ -1664,8 +1664,6 @@ 0.98.7-hadoop1 hadoop1 1.8.8 - org.spark-project.akka - 2.3.4-spark From 61664732b25b35f94be35a42cde651cbfd0e02b7 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 26 May 2015 13:21:00 -0700 Subject: [PATCH 009/254] [SPARK-7844] [MLLIB] Fix broken tests in KernelDensity The densities in KernelDensity are scaled down by (number of parallel processes X number of points). It should be just no.of samples. This results in broken tests in KernelDensitySuite which haven't been tested properly. Author: MechCoder Closes #6383 from MechCoder/spark-7844 and squashes the following commits: ab81302 [MechCoder] Math->math 9b8ed50 [MechCoder] Make one pass to update count a92fe50 [MechCoder] [SPARK-7844] Fix broken tests in KernelDensity --- .../org/apache/spark/mllib/stat/KernelDensity.scala | 2 +- .../apache/spark/mllib/stat/KernelDensitySuite.scala | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index a6bfe26e1e4f5..58a50f9c19f14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -93,7 +93,7 @@ class KernelDensity extends Serializable { x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i)) i += 1 } - (x._1, n) + (x._1, x._2 + 1) }, (x, y) => { blas.daxpy(n, 1.0, y._1, 1, x._1, 1) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala index 14bb1cebf0b8f..a309c942cf8ff 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -29,8 +29,8 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal = new NormalDistribution(5.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - normal.density(5.0) < acceptableErr) - assert(densities(0) - normal.density(6.0) < acceptableErr) + assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr) + assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr) } test("kernel density multiple samples") { @@ -40,7 +40,9 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { val normal1 = new NormalDistribution(5.0, 3.0) val normal2 = new NormalDistribution(10.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr) - assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr) + assert(math.abs( + densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr) + assert(math.abs( + densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr) } } From 2e9a5f229e1a2ccffa74fa59fa6a55b2704d9c1a Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Tue, 26 May 2015 15:01:27 -0700 Subject: [PATCH 010/254] [SPARK-3674] YARN support in Spark EC2 This corresponds to https://github.com/mesos/spark-ec2/pull/116 in the spark-ec2 repo. The only changes required on the spark_ec2.py script is to open the RM port. cc andrewor14 Author: Shivaram Venkataraman Closes #6376 from shivaram/spark-ec2-yarn and squashes the following commits: 961504a [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into spark-ec2-yarn 152c94c [Shivaram Venkataraman] Open 8088 for YARN in EC2 --- ec2/spark_ec2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 724811eaa1bdd..ee0904c9e5d54 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -491,6 +491,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('udp', 2049, 2049, authorized_address) master_group.authorize('tcp', 4242, 4242, authorized_address) master_group.authorize('udp', 4242, 4242, authorized_address) + # RM in YARN mode uses 8088 + master_group.authorize('tcp', 8088, 8088, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created From 9f742241cbf07e5e2dadfee8dcc9b382bb2dbea1 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 26 May 2015 15:28:49 -0700 Subject: [PATCH 011/254] [SPARK-6602] [CORE] Remove some places in core that calling SparkEnv.actorSystem Author: zsxwing Closes #6333 from zsxwing/remove-actor-system-usage and squashes the following commits: f125aa6 [zsxwing] Fix YarnAllocatorSuite ceadcf6 [zsxwing] Change the "port" parameter type of "AkkaUtils.address" to "int"; update ApplicationMaster and YarnAllocator to get the driverUrl from RpcEnv 3239380 [zsxwing] Remove some places in core that calling SparkEnv.actorSystem --- .../spark/scheduler/TaskSchedulerImpl.scala | 17 +++++++++++------ .../mesos/CoarseMesosSchedulerBackend.scala | 9 ++++----- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 18 ++++++++++++------ .../spark/deploy/yarn/YarnAllocator.scala | 12 ++---------- .../spark/deploy/yarn/YarnRMClient.scala | 3 ++- .../spark/deploy/yarn/YarnAllocatorSuite.scala | 1 + 7 files changed, 33 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index b4b8a630694bb..ed3dde0fc3055 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{TimerTask, Timer} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -64,6 +64,9 @@ private[spark] class TaskSchedulerImpl( // How often to check for speculative tasks val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") + private val speculationScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation") + // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") @@ -142,10 +145,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, - SPECULATION_INTERVAL_MS milliseconds) { - Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - }(sc.env.actorSystem.dispatcher) + speculationScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + checkSpeculatableTasks() + } + }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS) } } @@ -412,6 +416,7 @@ private[spark] class TaskSchedulerImpl( } override def stop() { + speculationScheduler.shutdown() if (backend != null) { backend.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index dc59545b43314..aff086594c73f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -25,9 +25,10 @@ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** @@ -115,11 +116,9 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(sc.env.actorSystem), + val driverUrl = sc.env.rpcEnv.uriOf( SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index de3316d083a22..7513b1b795dea 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -235,7 +235,7 @@ private[spark] object AkkaUtils extends Logging { protocol: String, systemName: String, host: String, - port: Any, + port: Int, actorName: String): String = { s"$protocol://$systemName@$host:$port/user/$actorName" } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index af4927b0e4bf7..760e458972d98 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -34,7 +34,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, Spar import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer -import org.apache.spark.scheduler.cluster.YarnSchedulerBackend +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util._ @@ -220,7 +220,7 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -231,8 +231,14 @@ private[spark] class ApplicationMaster( .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } .getOrElse("") - allocator = client.register(yarnConf, - if (sc != null) sc.getConf else sparkConf, + val _sparkConf = if (sc != null) sc.getConf else sparkConf + val driverUrl = _rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + allocator = client.register(driverUrl, + yarnConf, + _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, @@ -279,7 +285,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -289,7 +295,7 @@ private[spark] class ApplicationMaster( rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 8a08f561a2df2..21193e7c625e3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -34,10 +34,8 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.AkkaUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -53,6 +51,7 @@ import org.apache.spark.util.AkkaUtils * synchronized. */ private[yarn] class YarnAllocator( + driverUrl: String, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -107,13 +106,6 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - private val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(securityMgr.akkaSSLOptions.enabled), - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index ffe71dfd7d257..7f533ee55e8bb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,6 +55,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * @param uiHistoryAddress Address of the application on the History Server. */ def register( + driverUrl: String, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -72,7 +73,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 455f1019d86dd..b343cbb0c7569 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -90,6 +90,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach "--jar", "somejar.jar", "--class", "SomeClass") new YarnAllocator( + "not used", conf, sparkConf, rmClient, From 836a75898fdc4b10d4d00676ef29e24cc96f09fd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 26 May 2015 15:51:31 -0700 Subject: [PATCH 012/254] [SPARK-7748] [MLLIB] Graduate spark.ml from alpha With descent coverage of feature transformers, algorithms, and model tuning support, it is time to graduate `spark.ml` from alpha. This PR changes all `AlphaComponent` annotations to either `DeveloperApi` or `Experimental`, depending on whether we expect a class/method to be used by end users (who use the pipeline API to assemble/tune their ML pipelines but not to create new pipeline components.) `UnaryTransformer` becomes a `DeveloperApi` in this PR. jkbradley harsha2010 Author: Xiangrui Meng Closes #6417 from mengxr/SPARK-7748 and squashes the following commits: effbccd [Xiangrui Meng] organize imports c15028e [Xiangrui Meng] added missing docs 1b2e5f8 [Xiangrui Meng] update package doc 73ca791 [Xiangrui Meng] alpha -> ex/dev for the rest 93819db [Xiangrui Meng] alpha -> ex/dev in ml.param 55ca073 [Xiangrui Meng] alpha -> ex/dev in ml.feature 83572f1 [Xiangrui Meng] add Experimental and DeveloperApi tags (wip) --- .../scala/org/apache/spark/ml/Estimator.scala | 8 +-- .../scala/org/apache/spark/ml/Model.scala | 6 +- .../scala/org/apache/spark/ml/Pipeline.scala | 14 ++--- .../scala/org/apache/spark/ml/Predictor.scala | 3 - .../org/apache/spark/ml/Transformer.scala | 10 ++-- .../spark/ml/attribute/AttributeGroup.scala | 9 ++- .../spark/ml/attribute/AttributeType.scala | 5 ++ .../spark/ml/attribute/attributes.scala | 29 ++++++++- .../DecisionTreeClassifier.scala | 15 +++-- .../ml/classification/GBTClassifier.scala | 15 +++-- .../classification/LogisticRegression.scala | 19 +++--- .../spark/ml/classification/OneVsRest.scala | 7 +-- .../RandomForestClassifier.scala | 15 +++-- .../BinaryClassificationEvaluator.scala | 8 +-- .../spark/ml/evaluation/Evaluator.scala | 6 +- .../ml/evaluation/RegressionEvaluator.scala | 7 +-- .../apache/spark/ml/feature/Binarizer.scala | 6 +- .../apache/spark/ml/feature/Bucketizer.scala | 6 +- .../spark/ml/feature/ElementwiseProduct.scala | 6 +- .../apache/spark/ml/feature/HashingTF.scala | 10 ++-- .../org/apache/spark/ml/feature/IDF.scala | 10 ++-- .../apache/spark/ml/feature/Normalizer.scala | 6 +- .../spark/ml/feature/OneHotEncoder.scala | 7 ++- .../ml/feature/PolynomialExpansion.scala | 6 +- .../spark/ml/feature/StandardScaler.scala | 10 ++-- .../spark/ml/feature/StringIndexer.scala | 10 ++-- .../apache/spark/ml/feature/Tokenizer.scala | 10 ++-- .../spark/ml/feature/VectorAssembler.scala | 6 +- .../spark/ml/feature/VectorIndexer.scala | 12 ++-- .../apache/spark/ml/feature/Word2Vec.scala | 10 ++-- .../org/apache/spark/ml/package-info.java | 6 +- .../scala/org/apache/spark/ml/package.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 59 ++++++++++++++----- .../apache/spark/ml/recommendation/ALS.scala | 14 ++++- .../ml/regression/DecisionTreeRegressor.scala | 15 +++-- .../spark/ml/regression/GBTRegressor.scala | 14 ++--- .../ml/regression/LinearRegression.scala | 12 ++-- .../ml/regression/RandomForestRegressor.scala | 15 +++-- .../scala/org/apache/spark/ml/tree/Node.scala | 8 ++- .../org/apache/spark/ml/tree/Split.scala | 7 +++ .../org/apache/spark/ml/tree/treeParams.scala | 9 --- .../spark/ml/tuning/CrossValidator.scala | 10 ++-- .../spark/ml/tuning/ParamGridBuilder.scala | 6 +- 43 files changed, 267 insertions(+), 201 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 9e16e60270141..e9a5d7c0e7988 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for estimators that fit models to data. */ -@AlphaComponent +@DeveloperApi abstract class Estimator[M <: Model[M]] extends PipelineStage { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 70e7495ac616c..186bf7ae7a2f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. * * @tparam M model type */ -@AlphaComponent +@DeveloperApi abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 43bee1b770e67..9da3ff65c744e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -20,17 +20,17 @@ package org.apache.spark.ml import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. */ -@AlphaComponent +@DeveloperApi abstract class PipelineStage extends Params with Logging { /** @@ -69,7 +69,7 @@ abstract class PipelineStage extends Params with Logging { } /** - * :: AlphaComponent :: + * :: Experimental :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will @@ -80,7 +80,7 @@ abstract class PipelineStage extends Params with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ -@AlphaComponent +@Experimental class Pipeline(override val uid: String) extends Estimator[PipelineModel] { def this() = this(Identifiable.randomUID("pipeline")) @@ -169,10 +169,10 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { } /** - * :: AlphaComponent :: + * :: Experimental :: * Represents a fitted pipeline. */ -@AlphaComponent +@Experimental class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index ec0f76aa668bd..e752b81a14282 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,6 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * * Abstraction for prediction problems (regression and classification). * * @tparam FeaturesType Type of features. @@ -113,7 +112,6 @@ abstract class Predictor[ * * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. */ - @DeveloperApi private[ml] def featuresDataType: DataType = new VectorUDT override def transformSchema(schema: StructType): StructType = { @@ -134,7 +132,6 @@ abstract class Predictor[ /** * :: DeveloperApi :: - * * Abstraction for a model for prediction tasks (regression and classification). * * @tparam FeaturesType Type of features. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 38bb6a5a5391e..f07f733a5ddb5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml import scala.annotation.varargs import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame @@ -28,10 +28,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for transformers that transform one dataset into another. */ -@AlphaComponent +@DeveloperApi abstract class Transformer extends PipelineStage { /** @@ -73,10 +73,12 @@ abstract class Transformer extends PipelineStage { } /** + * :: DeveloperApi :: * Abstract class for transformers that take one input column, apply transformation, and output the * result as a new column. */ -private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] +@DeveloperApi +abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index f5f37aa77929c..457c15830fd38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -19,10 +19,12 @@ package org.apache.spark.ml.attribute import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Attributes that describe a vector ML column. * * @param name name of the attribute group (the ML column name) @@ -31,6 +33,7 @@ import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} * @param attrs optional array of attributes. Attribute will be copied with their corresponding * indices in the array. */ +@DeveloperApi class AttributeGroup private ( val name: String, val numAttributes: Option[Int], @@ -182,7 +185,11 @@ class AttributeGroup private ( } } -/** Factory methods to create attribute groups. */ +/** + * :: DeveloperApi :: + * Factory methods to create attribute groups. + */ +@DeveloperApi object AttributeGroup { import AttributeKeys._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala index a83febd7de2cc..5c7089b491677 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -17,12 +17,17 @@ package org.apache.spark.ml.attribute +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], * and [[AttributeType$#Binary]]. */ +@DeveloperApi sealed abstract class AttributeType(val name: String) +@DeveloperApi object AttributeType { /** Numeric type. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index e8f7f152784a1..ce43a450daad0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -19,11 +19,14 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Abstract class for ML attributes. */ +@DeveloperApi sealed abstract class Attribute extends Serializable { name.foreach { n => @@ -135,6 +138,10 @@ private[attribute] trait AttributeFactory { } } +/** + * :: DeveloperApi :: + */ +@DeveloperApi object Attribute extends AttributeFactory { private[attribute] override def fromMetadata(metadata: Metadata): Attribute = { @@ -163,6 +170,7 @@ object Attribute extends AttributeFactory { /** + * :: DeveloperApi :: * A numeric attribute with optional summary statistics. * @param name optional name * @param index optional index @@ -171,6 +179,7 @@ object Attribute extends AttributeFactory { * @param std optional standard deviation * @param sparsity optional sparsity (ratio of zeros) */ +@DeveloperApi class NumericAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -278,8 +287,10 @@ class NumericAttribute private[ml] ( } /** + * :: DeveloperApi :: * Factory methods for numeric attributes. */ +@DeveloperApi object NumericAttribute extends AttributeFactory { /** The default numeric attribute. */ @@ -298,6 +309,7 @@ object NumericAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A nominal attribute. * @param name optional name * @param index optional index @@ -306,6 +318,7 @@ object NumericAttribute extends AttributeFactory { * defined. * @param values optional values. At most one of `numValues` and `values` can be defined. */ +@DeveloperApi class NominalAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -430,7 +443,11 @@ class NominalAttribute private[ml] ( } } -/** Factory methods for nominal attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for nominal attributes. + */ +@DeveloperApi object NominalAttribute extends AttributeFactory { /** The default nominal attribute. */ @@ -450,11 +467,13 @@ object NominalAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A binary attribute. * @param name optional name * @param index optional index * @param values optionla values. If set, its size must be 2. */ +@DeveloperApi class BinaryAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -526,7 +545,11 @@ class BinaryAttribute private[ml] ( } } -/** Factory methods for binary attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for binary attributes. + */ +@DeveloperApi object BinaryAttribute extends AttributeFactory { /** The default binary attribute. */ @@ -543,8 +566,10 @@ object BinaryAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * An unresolved attribute. */ +@DeveloperApi object UnresolvedAttribute extends Attribute { override def attrType: AttributeType = AttributeType.Unresolved diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 7c961332bf5b6..8030e0728a56c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,14 +31,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class DecisionTreeClassifier(override val uid: String) extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { @@ -89,19 +88,19 @@ final class DecisionTreeClassifier(override val uid: String) } } +@Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, override val rootNode: Node) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d504d84beb91e..d8592eb2d947d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -36,14 +36,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. */ -@AlphaComponent +@Experimental final class GBTClassifier(override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { @@ -144,6 +143,7 @@ final class GBTClassifier(override val uid: String) } } +@Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ @@ -151,8 +151,7 @@ object GBTClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for classification. * It supports binary labels, as well as both continuous and categorical features. @@ -160,7 +159,7 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class GBTClassificationModel( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8694c96e4c5b6..d13109d9da4c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import breeze.linalg.{norm => brzNorm, DenseVector => BDV} -import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import breeze.optimize.{CachedDiffFunction, DiffFunction} +import breeze.linalg.{DenseVector => BDV, norm => brzNorm} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable @@ -35,7 +35,6 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel -import org.apache.spark.{SparkException, Logging} /** * Params for logistic regression. @@ -45,12 +44,11 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas with HasThreshold /** - * :: AlphaComponent :: - * + * :: Experimental :: * Logistic regression. * Currently, this class only supports binary classification. */ -@AlphaComponent +@Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with Logging { @@ -221,11 +219,10 @@ class LogisticRegression(override val uid: String) } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LogisticRegression]]. */ -@AlphaComponent +@Experimental class LogisticRegressionModel private[ml] ( override val uid: String, val weights: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1543f051ccd17..36735cd834cd4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.Param @@ -54,8 +54,7 @@ private[ml] trait OneVsRestParams extends PredictorParams { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[OneVsRest]]. * This stores the models resulting from training k binary classifiers: one for each class. * Each example is scored against all k models, and the model with the highest score @@ -67,7 +66,7 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ -@AlphaComponent +@Experimental final class OneVsRestModel private[ml] ( override val uid: String, labelMetadata: Metadata, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index a1de7919859eb..67600ebd7b38e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -33,14 +33,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for * classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class RandomForestClassifier(override val uid: String) extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { @@ -100,6 +99,7 @@ final class RandomForestClassifier(override val uid: String) } } +@Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities @@ -110,15 +110,14 @@ object RandomForestClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ -@AlphaComponent +@Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index ddbdd00ceb159..f695ddaeefc72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -28,11 +27,10 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** - * :: AlphaComponent :: - * + * :: Experimental :: * Evaluator for binary classification, which expects two input columns: score and label. */ -@AlphaComponent +@Experimental class BinaryClassificationEvaluator(override val uid: String) extends Evaluator with HasRawPredictionCol with HasLabelCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index cabd1c97c085c..61e937e693699 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for evaluators that compute metrics from predictions. */ -@AlphaComponent +@DeveloperApi abstract class Evaluator extends Params { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 80458928c5439..1771177e1ea91 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.{Param, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -26,11 +26,10 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** - * :: AlphaComponent :: - * + * :: Experimental :: * Evaluator for regression, which expects two input columns: prediction and label. */ -@AlphaComponent +@Experimental final class RegressionEvaluator(override val uid: String) extends Evaluator with HasPredictionCol with HasLabelCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 62f4a6343423e..b06122d733853 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ @@ -28,10 +28,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Binarize a column of continuous features given a threshold. */ -@AlphaComponent +@Experimental final class Binarizer(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ac8dfb5632a7b..a3d1f6f65ccaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -31,10 +31,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * `Bucketizer` maps a column of continuous features to a column of feature buckets. */ -@AlphaComponent +@Experimental final class Bucketizer(override val uid: String) extends Model[Bucketizer] with HasInputCol with HasOutputCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 8b32eee0e490a..3ae1833390152 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.Identifiable @@ -26,12 +26,12 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. */ -@AlphaComponent +@Experimental class ElementwiseProduct(override val uid: String) extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 8942d45219177..f936aef80f8af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,22 +17,22 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.{udf, col} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. */ -@AlphaComponent +@Experimental class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("hashingTF")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 788c392050c2d..376b84530cd57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -58,10 +58,10 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol } /** - * :: AlphaComponent :: + * :: Experimental :: * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ -@AlphaComponent +@Experimental final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { def this() = this(Identifiable.randomUID("idf")) @@ -85,10 +85,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[IDF]]. */ -@AlphaComponent +@Experimental class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 3f689d1585cd6..8282e5ffa17f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.ml.util.Identifiable @@ -26,10 +26,10 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Normalize a vector to have unit norm using the given p-norm. */ -@AlphaComponent +@Experimental class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { def this() = this(Identifiable.randomUID("normalizer")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 1fb9b9ae75091..eb6ec49f854be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -18,16 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** + * :: Experimental :: * A one-hot encoder that maps a column of label indices to a column of binary vectors, with * at most a single one-value. By default, the binary vector has an element for each category, so * with 5 categories, an input value of 2.0 would map to an output vector of @@ -36,7 +37,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns * linearly dependent because they sum up to one. */ -@AlphaComponent +@Experimental class OneHotEncoder(override val uid: String) extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 8ddf9d6a1e138..442e95820217a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamValidators} import org.apache.spark.ml.util.Identifiable @@ -27,14 +27,14 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an * expansion of a product of sums expresses it as a sum of products by using the fact that * multiplication distributes over addition". Take a 2-variable feature vector as an example: * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. */ -@AlphaComponent +@Experimental class PolynomialExpansion(override val uid: String) extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 5ccda15d872ed..fdd2494fc87a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -51,11 +51,11 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with } /** - * :: AlphaComponent :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. */ -@AlphaComponent +@Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] with StandardScalerParams { @@ -95,10 +95,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StandardScaler]]. */ -@AlphaComponent +@Experimental class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3f79b67309f07..a2dc8a8b960c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -52,13 +52,13 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha } /** - * :: AlphaComponent :: + * :: Experimental :: * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ -@AlphaComponent +@Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] with StringIndexerBase { @@ -86,10 +86,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StringIndexer]]. */ -@AlphaComponent +@Experimental class StringIndexerModel private[ml] ( override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 31f3a1aa4c76b..21c15b6c33f6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,19 +17,19 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** - * :: AlphaComponent :: + * :: Experimental :: * A tokenizer that converts the input string to lowercase and then splits it by white spaces. * * @see [[RegexTokenizer]] */ -@AlphaComponent +@Experimental class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { def this() = this(Identifiable.randomUID("tok")) @@ -46,13 +46,13 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S } /** - * :: AlphaComponent :: + * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split * the text (default) or repeatedly matching the regex (if `gaps` is true). * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. */ -@AlphaComponent +@Experimental class RegexTokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], RegexTokenizer] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 181b62f46fce8..514ffb03c0509 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable @@ -30,10 +30,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: Experimental :: * A feature transformer that merges multiple columns into a vector column. */ -@AlphaComponent +@Experimental class VectorAssembler(override val uid: String) extends Transformer with HasInputCols with HasOutputCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e238fb310ed37..1d0f23b4fb3db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,7 +22,7 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} @@ -56,8 +56,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Class for indexing categorical feature columns in a dataset of [[Vector]]. * * This has 2 usage modes: @@ -91,7 +90,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Add warning if a categorical feature has only 1 category. * - Add option for allowing unknown categories. */ -@AlphaComponent +@Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] with VectorIndexerParams { @@ -230,8 +229,7 @@ private object VectorIndexer { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Transform categorical features to use 0-based indices instead of their original values. * - Categorical features are mapped to indices. * - Continuous features (columns) are left unchanged. @@ -246,7 +244,7 @@ private object VectorIndexer { * Values are maps from original features values to 0-based category indices. * If a feature is not in this map, it is treated as continuous. */ -@AlphaComponent +@Experimental class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index ed032669229ce..36f19509f0cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -82,11 +82,11 @@ private[feature] trait Word2VecBase extends Params } /** - * :: AlphaComponent :: + * :: Experimental :: * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further * natural language processing or machine learning process. */ -@AlphaComponent +@Experimental final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { def this() = this(Identifiable.randomUID("w2v")) @@ -135,10 +135,10 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[Word2Vec]]. */ -@AlphaComponent +@Experimental class Word2VecModel private[ml] ( override val uid: String, wordVectors: feature.Word2VecModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java index 00d9c802e930d..87f4223964ada 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -16,10 +16,10 @@ */ /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. */ -@AlphaComponent +@Experimental package org.apache.spark.ml; -import org.apache.spark.annotation.AlphaComponent; +import org.apache.spark.annotation.Experimental; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index ac75e9de1a8f2..c589d06d9f7e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -18,7 +18,7 @@ package org.apache.spark /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. * * @groupname param Parameters diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 12fc5b561f76e..1afa59c994c2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,11 +24,11 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A param with self-contained documentation and optionally default value. Primitive-typed param * should use the specialized versions, which are more friendly to Java users. * @@ -39,7 +39,7 @@ import org.apache.spark.ml.util.Identifiable * See [[ParamValidators]] for factory methods for common validation functions. * @tparam T param value type */ -@AlphaComponent +@DeveloperApi class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean) extends Serializable { @@ -174,7 +174,11 @@ object ParamValidators { // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... -/** Specialized version of [[Param[Double]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Double]]] for Java. + */ +@DeveloperApi class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean) extends Param[Double](parent, name, doc, isValid) { @@ -189,7 +193,11 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => override def w(value: Double): ParamPair[Double] = super.w(value) } -/** Specialized version of [[Param[Int]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Int]]] for Java. + */ +@DeveloperApi class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean) extends Param[Int](parent, name, doc, isValid) { @@ -204,7 +212,11 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea override def w(value: Int): ParamPair[Int] = super.w(value) } -/** Specialized version of [[Param[Float]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Float]]] for Java. + */ +@DeveloperApi class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean) extends Param[Float](parent, name, doc, isValid) { @@ -219,7 +231,11 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo override def w(value: Float): ParamPair[Float] = super.w(value) } -/** Specialized version of [[Param[Long]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Long]]] for Java. + */ +@DeveloperApi class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean) extends Param[Long](parent, name, doc, isValid) { @@ -234,7 +250,11 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool override def w(value: Long): ParamPair[Long] = super.w(value) } -/** Specialized version of [[Param[Boolean]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Boolean]]] for Java. + */ +@DeveloperApi class BooleanParam(parent: String, name: String, doc: String) // No need for isValid extends Param[Boolean](parent, name, doc) { @@ -243,7 +263,11 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } -/** Specialized version of [[Param[Array[String]]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[String]]]] for Java. + */ +@DeveloperApi class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean) extends Param[Array[String]](parent, name, doc, isValid) { @@ -256,7 +280,11 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) } -/** Specialized version of [[Param[Array[Double]]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[Double]]]] for Java. + */ +@DeveloperApi class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean) extends Param[Array[Double]](parent, name, doc, isValid) { @@ -270,8 +298,10 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array } /** + * :: Experimental :: * A param amd its value. */ +@Experimental case class ParamPair[T](param: Param[T], value: T) { // This is *the* place Param.validate is called. Whenever a parameter is specified, we should // always construct a ParamPair so that validate is called. @@ -279,11 +309,11 @@ case class ParamPair[T](param: Param[T], value: T) { } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Trait for components that take parameters. This also provides an internal param map to store * parameter values attached to the instance. */ -@AlphaComponent +@DeveloperApi trait Params extends Identifiable with Serializable { /** @@ -541,10 +571,10 @@ trait Params extends Identifiable with Serializable { abstract class JavaParams extends Params /** - * :: AlphaComponent :: + * :: Experimental :: * A param to value map. */ -@AlphaComponent +@Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -665,6 +695,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) def size: Int = map.size } +@Experimental object ParamMap { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2a5ddbfae5cdf..900b637ff8ad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.netlib.util.intW import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -169,8 +169,10 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR } /** + * :: Experimental :: * Model fitted by ALS. */ +@Experimental class ALSModel private[ml] ( override val uid: String, k: Int, @@ -208,6 +210,7 @@ class ALSModel private[ml] ( /** + * :: Experimental :: * Alternating Least Squares (ALS) matrix factorization. * * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, @@ -236,6 +239,7 @@ class ALSModel private[ml] ( * indicated user * preferences rather than explicit ratings given to items. */ +@Experimental class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating @@ -326,7 +330,11 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { @DeveloperApi object ALS extends Logging { - /** Rating class for better code readability. */ + /** + * :: DeveloperApi :: + * Rating class for better code readability. + */ + @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) /** Trait for least squares solvers applied to the normal equation. */ @@ -487,8 +495,10 @@ object ALS extends Logging { } /** + * :: DeveloperApi :: * Implementation of the ALS algorithm. */ + @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore ratings: RDD[Rating[ID]], rank: Int = 10, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index e67df21b2e4ae..43b68e7bb20fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,13 +31,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class DecisionTreeRegressor(override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeParams with TreeRegressorParams { @@ -79,19 +78,19 @@ final class DecisionTreeRegressor(override val uid: String) } } +@Experimental object DecisionTreeRegressor { /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. * It supports both continuous and categorical features. * @param rootNode Root of the decision tree */ -@AlphaComponent +@Experimental final class DecisionTreeRegressionModel private[ml] ( override val uid: String, override val rootNode: Node) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 4249ff5c1ebc7..69f4f5414c8c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -35,13 +35,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class GBTRegressor(override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTParams with TreeRegressorParams with Logging { @@ -134,6 +133,7 @@ final class GBTRegressor(override val uid: String) } } +@Experimental object GBTRegressor { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @@ -141,7 +141,7 @@ object GBTRegressor { } /** - * :: AlphaComponent :: + * :: Experimental :: * * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for regression. @@ -149,7 +149,7 @@ object GBTRegressor { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class GBTRegressionModel( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 3ebb78f79201a..7c40db1a40040 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -23,7 +23,7 @@ import breeze.linalg.{DenseVector => BDV, norm => brzNorm} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} @@ -44,8 +44,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** - * :: AlphaComponent :: - * + * :: Experimental :: * Linear regression. * * The learning objective is to minimize the squared error, with regularization. @@ -58,7 +57,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L1 (Lasso) * - L2 + L1 (elastic net) */ -@AlphaComponent +@Experimental class LinearRegression(override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with Logging { @@ -190,11 +189,10 @@ class LinearRegression(override val uid: String) } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LinearRegression]]. */ -@AlphaComponent +@Experimental class LinearRegressionModel private[ml] ( override val uid: String, val weights: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82437aa8de294..ae767a17329d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,12 +31,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class RandomForestRegressor(override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestParams with TreeRegressorParams { @@ -89,6 +88,7 @@ final class RandomForestRegressor(override val uid: String) } } +@Experimental object RandomForestRegressor { /** Accessor for supported impurity settings: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -99,13 +99,12 @@ object RandomForestRegressor { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. * @param _trees Decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class RandomForestRegressionModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d2dec0c76cb12..6a84176efb086 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} - /** + * :: DeveloperApi :: * Decision tree node interface. */ +@DeveloperApi sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree @@ -89,10 +91,12 @@ private[ml] object Node { } /** + * :: DeveloperApi :: * Decision tree leaf node. * @param prediction Prediction this node makes * @param impurity Impurity measure at this node (for training data) */ +@DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double) extends Node { @@ -118,6 +122,7 @@ final class LeafNode private[ml] ( } /** + * :: DeveloperApi :: * Internal Decision Tree node. * @param prediction Prediction this node would make if it were a leaf node * @param impurity Impurity measure at this node (for training data) @@ -127,6 +132,7 @@ final class LeafNode private[ml] ( * @param rightChild Right-hand child node * @param split Information about the test used to split to the left or right child. */ +@DeveloperApi final class InternalNode private[ml] ( override val prediction: Double, override val impurity: Double, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 90f1d052764d3..7acdeeee72d23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} /** + * :: DeveloperApi :: * Interface for a "Split," which specifies a test made at a decision tree node * to choose the left or right path. */ +@DeveloperApi sealed trait Split extends Serializable { /** Index of feature which this split tests */ @@ -52,12 +55,14 @@ private[tree] object Split { } /** + * :: DeveloperApi :: * Split which tests a categorical feature. * @param featureIndex Index of the feature to test * @param _leftCategories If the feature value is in this set of categories, then the split goes * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ +@DeveloperApi final class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], @@ -125,11 +130,13 @@ final class CategoricalSplit private[ml] ( } /** + * :: DeveloperApi :: * Split which tests a continuous feature. * @param featureIndex Index of the feature to test * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. */ +@DeveloperApi final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) extends Split { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 816fcedf2efb3..a0c5238d966bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} @@ -26,12 +25,10 @@ import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldG import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} /** - * :: DeveloperApi :: * Parameters for Decision Tree-based algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait DecisionTreeParams extends PredictorParams { /** @@ -265,12 +262,10 @@ private[ml] object TreeRegressorParams { } /** - * :: DeveloperApi :: * Parameters for Decision Tree-based ensemble algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @@ -307,12 +302,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { } /** - * :: DeveloperApi :: * Parameters for Random Forest algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait RandomForestParams extends TreeEnsembleParams { /** @@ -377,12 +370,10 @@ private[ml] object RandomForestParams { } /** - * :: DeveloperApi :: * Parameters for Gradient-Boosted Tree algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index e21ff94a20f54..2e5a629561180 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ @@ -79,10 +79,10 @@ private[ml] trait CrossValidatorParams extends Params { } /** - * :: AlphaComponent :: + * :: Experimental :: * K-fold cross validation. */ -@AlphaComponent +@Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { @@ -150,10 +150,10 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } /** - * :: AlphaComponent :: + * :: Experimental :: * Model from k-fold cross validation. */ -@AlphaComponent +@Experimental class CrossValidatorModel private[ml] ( override val uid: String, val bestModel: Model[_]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index dafe73d82c00a..98a8f0330ca45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ /** - * :: AlphaComponent :: + * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ -@AlphaComponent +@Experimental class ParamGridBuilder { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] From 8f2082426828c15704426ebca1d015bf956c6841 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 26 May 2015 16:31:34 -0700 Subject: [PATCH 013/254] [SPARK-7864] [UI] Do not kill innocent stages from visualization **Reproduction.** Run a long-running job, go to the job page, expand the DAG visualization, and click into a stage. Your stage is now killed. Why? This is because the visualization code just reaches into the stage table and grabs the first link it finds. In our case, this first link happens to be the kill link instead of the one to the stage page. **Fix.** Use proper CSS selectors to avoid ambiguity. This is an alternative to #6407. Thanks carsonwang for catching this. Author: Andrew Or Closes #6419 from andrewor14/fix-ui-viz-kill and squashes the following commits: 25203bd [Andrew Or] Do not kill innocent stages --- .../main/resources/org/apache/spark/ui/static/spark-dag-viz.js | 2 +- .../main/resources/org/apache/spark/ui/static/timeline-view.js | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index aaeba5b1027c9..e96af8768daa0 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -193,7 +193,7 @@ function renderDagVizForJob(svgContainer) { // Use the link from the stage table so it also works for the history server var attemptId = 0 var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) - .select("a") + .select("a.name-link") .attr("href") + "&expandDagViz=true"; container = svgContainer .append("a") diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 604c29994145a..28ac998e8d065 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -105,7 +105,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var stagePagePath = $(getSelectorForStageEntry(this)).find("a").attr("href") + var stagePagePath = $(getSelectorForStageEntry(this)).find("a.name-link").attr("href") window.location.href = stagePagePath }); diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 82ba561eefb16..99812db4912a3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -93,7 +93,7 @@ private[ui] class StageTableBase( } val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" - val nameLink = {s.name} + val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { From 0463428b6e8f364f0b1f39445a60cd85ae7c07bc Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Tue, 26 May 2015 18:08:57 -0700 Subject: [PATCH 014/254] [SPARK-7883] [DOCS] [MLLIB] Fixing broken trainImplicit Scala example in MLlib Collaborative Filtering documentation. Fixing broken trainImplicit Scala example in MLlib Collaborative Filtering documentation to match one of the possible ALS.trainImplicit function signatures. Author: Mike Dusenberry Closes #6422 from dusenberrymw/Fix_MLlib_Collab_Filtering_trainImplicit_Example and squashes the following commits: 36492f4 [Mike Dusenberry] Fixing broken trainImplicit example in MLlib Collaborative Filtering documentation to match one of the possible ALS.trainImplicit function signatures. --- docs/mllib-collaborative-filtering.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7b397e30b2d90..dfdf6216b270c 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -107,7 +107,8 @@ other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 -val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) +val lambda = 0.01 +val model = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha) {% endhighlight %} From 03668348e29eb52c1a7d57a1e0ed7fca6c323890 Mon Sep 17 00:00:00 2001 From: rowan Date: Tue, 26 May 2015 18:17:16 -0700 Subject: [PATCH 015/254] [SPARK-7637] [SQL] O(N) merge implementation for StructType merge Contribution is my original work and I license the work to the project under the projects open source license. Author: rowan Closes #6259 from rowan000/SPARK-7637 and squashes the following commits: c479df4 [rowan] SPARK-7637: rename mapFields to fieldsMap as per comments on github. 8d2e419 [rowan] SPARK-7637: fix up whitespace changes 0e9d662 [rowan] SPARK-7637: O(N) merge implementatio for StructType merge --- .../apache/spark/sql/types/StructType.scala | 12 ++- .../spark/sql/types/DataTypeSuite.scala | 73 ++++++++++++++++++- 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 7e00a27dfe724..a4f30c825befb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -230,10 +230,10 @@ object StructType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightFields - .find(_.name == leftName) + rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => leftField.copy( dataType = merge(leftType, rightType), @@ -243,8 +243,9 @@ object StructType { .foreach(newFields += _) } + val leftMapped = fieldsMap(leftFields) rightFields - .filterNot(f => leftFields.map(_.name).contains(f.name)) + .filterNot(f => leftMapped.get(f.name).nonEmpty) .foreach(newFields += _) StructType(newFields) @@ -264,4 +265,9 @@ object StructType { case _ => throw new SparkException(s"Failed to merge incompatible data types $left and $right") } + + private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { + import scala.collection.breakOut + fields.map(s => (s.name, s))(breakOut) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index d797510f36685..a73317c86916b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.types +import org.apache.spark.SparkException import org.scalatest.FunSuite class DataTypeSuite extends FunSuite { @@ -69,6 +70,76 @@ class DataTypeSuite extends FunSuite { } } + test("fieldsMap returns map of name to StructField") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val mapped = StructType.fieldsMap(struct.fields) + + val expected = Map( + "a" -> StructField("a", LongType), + "b" -> StructField("b", FloatType)) + + assert(mapped === expected) + } + + test("merge where right is empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType(List()) + val merged = left.merge(right) + + assert(merged === left) + } + + test("merge where left is empty") { + + val left = StructType(List()) + + val right = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val merged = left.merge(right) + + assert(right === merged) + + } + + test("merge where both are non-empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("c", LongType) :: Nil) + + val expected = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: + StructField("c", LongType) :: Nil) + + val merged = left.merge(right) + + assert(merged === expected) + } + + test("merge where right contains type conflict") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("b", LongType) :: Nil) + + intercept[SparkException] { + left.merge(right) + } + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) @@ -120,7 +191,7 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) checkDefaultSize(DateType, 4) - checkDefaultSize(TimestampType,12) + checkDefaultSize(TimestampType, 12) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) From 0c33c7b4a66e47f6246f1b7f2b96f2c33126ec63 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 26 May 2015 20:24:35 -0700 Subject: [PATCH 016/254] [SPARK-7858] [SQL] Use output schema, not relation schema, for data source input conversion In `DataSourceStrategy.createPhysicalRDD`, we use the relation schema as the target schema for converting incoming rows into Catalyst rows. However, we should be using the output schema instead, since our scan might return a subset of the relation's columns. This patch incorporates #6414 by liancheng, which fixes an issue in `SimpleTestRelation` that prevented this bug from being caught by our old tests: > In `SimpleTextRelation`, we specified `needsConversion` to `true`, indicating that values produced by this testing relation should be of Scala types, and need to be converted to Catalyst types when necessary. However, we also used `Cast` to convert strings to expected data types. And `Cast` always produces values of Catalyst types, thus no conversion is done at all. This PR makes `SimpleTextRelation` produce Scala values so that data conversion code paths can be properly tested. Closes #5986. Author: Josh Rosen Author: Cheng Lian Author: Cheng Lian Closes #6400 from JoshRosen/SPARK-7858 and squashes the following commits: e71c866 [Josh Rosen] Re-fix bug so that the tests pass again 56b13e5 [Josh Rosen] Add regression test to hadoopFsRelationSuites 2169a0f [Josh Rosen] Remove use of SpecificMutableRow and BufferedIterator 6cd7366 [Josh Rosen] Fix SPARK-7858 by using output types for conversion. 5a00e66 [Josh Rosen] Add assertions in order to reproduce SPARK-7858 8ba195c [Cheng Lian] Merge 9968fba9979287aaa1f141ba18bfb9d4c116a3b3 into 61664732b25b35f94be35a42cde651cbfd0e02b7 9968fba [Cheng Lian] Tests the data type conversion code paths --- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/ExistingRDD.scala | 62 +++++++------------ .../sql/sources/DataSourceStrategy.scala | 2 +- .../sql/sources/SimpleTextRelation.scala | 6 +- .../sql/sources/hadoopFsRelationSuites.scala | 6 ++ 5 files changed, 37 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1ea596dddff02..3935f7b321b85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -392,7 +392,7 @@ class SQLContext(@transient val sparkContext: SparkContext) SparkPlan.currentContext.set(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index a500269f3cdcf..f931dc95ef575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -21,9 +21,9 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} /** @@ -31,26 +31,19 @@ import org.apache.spark.sql.{Row, SQLContext} */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r.productElement(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r.productElement(i)) + i += 1 } + + mutableRow } } } @@ -58,26 +51,19 @@ object RDDConversions { /** * Convert the objects inside Row into the types Catalyst expected. */ - def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r(i)) + i += 1 } + + mutableRow } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index dacd967cff856..c6a4dabbab05e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -309,7 +309,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { output: Seq[Attribute], rdd: RDD[Row]): SparkPlan = { val converted = if (relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { rdd } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index de907846b9180..0f959b3d0b86d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputForma import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext} @@ -108,7 +109,10 @@ class SimpleTextRelation( sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => Row(record.split(",").zip(fields).map { case (value, dataType) => - Cast(Literal(value), dataType).eval() + // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) + val catalystValue = Cast(Literal(value), dataType).eval() + // Here we're converting Catalyst values to Scala values to test `needsConversion` + CatalystTypeConverters.convertToScala(catalystValue, dataType) }: _*) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 70328e1ef810d..7c02d563f8d9a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -76,6 +76,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { df.filter('a > 1 && 'p1 < 2).select('b, 'p1), for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) + // Project many copies of columns with different types (reproduction for SPARK-7858) + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) + yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1)) + // Self-join df.registerTempTable("t") withTempTable("t") { From b463e6d618e69c535297e51f41eca4f91bd33cc8 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 May 2015 20:48:56 -0700 Subject: [PATCH 017/254] [SPARK-7868] [SQL] Ignores _temporary directories in HadoopFsRelation So that potential partial/corrupted data files left by failed tasks/jobs won't affect normal data scan. Author: Cheng Lian Closes #6411 from liancheng/spark-7868 and squashes the following commits: 273ea36 [Cheng Lian] Ignores _temporary directories --- .../apache/spark/sql/sources/interfaces.scala | 20 ++++++++++++------- .../sql/sources/hadoopFsRelationSuites.scala | 16 +++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index aaabbadcd651b..c06026e042d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -31,7 +31,7 @@ import org.apache.spark.SerializableWritable import org.apache.spark.sql.{Row, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType /** * ::DeveloperApi:: @@ -378,16 +378,22 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] def refresh(): Unit = { + // We don't filter files/directories whose name start with "_" or "." here, as specific data + // sources may take advantages over them (e.g. Parquet _metadata and _common_metadata files). + // But "_temporary" directories are explicitly ignored since failed tasks/jobs may leave + // partial/corrupted data files there. def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] - files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + if (status.getPath.getName.toLowerCase == "_temporary") { + Set.empty + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] + files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + } } leafFiles.clear() - // We don't filter files/directories like _temporary/_SUCCESS here, as specific data sources - // may take advantages over them (e.g. Parquet _metadata and _common_metadata files). val statuses = paths.flatMap { path => val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) @@ -395,7 +401,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) } - val (dirs, files) = statuses.partition(_.isDir) + val files = statuses.filterNot(_.isDir) leafFiles ++= files.map(f => f.getPath -> f).toMap leafDirToChildrenFiles ++= files.groupBy(_.getPath.getParent) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 7c02d563f8d9a..cf5ae88dc4bee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -548,4 +548,20 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) } } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } } From a9f1c0c57b9be586dbada09dab91dcfce31141d9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 26 May 2015 23:51:32 -0700 Subject: [PATCH 018/254] [SPARK-7535] [.1] [MLLIB] minor changes to the pipeline API 1. removed `Params.validateParams(extra)` 2. added `Evaluate.evaluate(dataset, paramPairs*)` 3. updated `RegressionEvaluator` doc jkbradley Author: Xiangrui Meng Closes #6392 from mengxr/SPARK-7535.1 and squashes the following commits: 5ff5af8 [Xiangrui Meng] add unit test for CV.validateParams f1f8369 [Xiangrui Meng] update CV.validateParams() to test estimatorParamMaps 607445d [Xiangrui Meng] merge master 8716f5f [Xiangrui Meng] specify default metric name in RegressionEvaluator e4e5631 [Xiangrui Meng] update RegressionEvaluator doc 801e864 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7535.1 fcbd3e2 [Xiangrui Meng] Merge branch 'master' into SPARK-7535.1 2192316 [Xiangrui Meng] remove validateParams(extra); add evaluate(dataset, extra*) --- .../scala/org/apache/spark/ml/Pipeline.scala | 9 ++-- .../ml/evaluation/RegressionEvaluator.scala | 4 +- .../org/apache/spark/ml/param/params.scala | 13 ----- .../spark/ml/tuning/CrossValidator.scala | 23 +++++--- .../apache/spark/ml/param/ParamsSuite.scala | 2 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 52 ++++++++++++++++++- 6 files changed, 71 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 9da3ff65c744e..11a4722722ea1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -97,12 +97,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { /** @group getParam */ def getStages: Array[PipelineStage] = $(stages).clone() - override def validateParams(paramMap: ParamMap): Unit = { - val map = extractParamMap(paramMap) - getStages.foreach { - case pStage: Params => pStage.validateParams(map) - case _ => - } + override def validateParams(): Unit = { + super.validateParams() + $(stages).foreach(_.validateParams()) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 1771177e1ea91..abb1b35bedea5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -36,8 +36,8 @@ final class RegressionEvaluator(override val uid: String) def this() = this(Identifiable.randomUID("regEval")) /** - * param for metric name in evaluation - * @group param supports mse, rmse, r2, mae as valid metric names. + * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * @group param */ val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1afa59c994c2b..473488dce9b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -333,19 +333,6 @@ trait Params extends Identifiable with Serializable { .map(m => m.invoke(this).asInstanceOf[Param[_]]) } - /** - * Validates parameter values stored internally plus the input parameter map. - * Raises an exception if any parameter is invalid. - * - * This only needs to check for interactions between parameters. - * Parameter value checks which do not depend on other parameters are handled by - * [[Param.validate()]]. This method does not handle input/output column parameters; - * those are checked during schema validation. - */ - def validateParams(paramMap: ParamMap): Unit = { - copy(paramMap).validateParams() - } - /** * Validates parameter values stored internally. * Raise an exception if any parameter value is invalid. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2e5a629561180..6434b64aed15d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def validateParams(paramMap: ParamMap): Unit = { - getEstimatorParamMaps.foreach { eMap => - getEstimator.validateParams(eMap ++ paramMap) - } - } - override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -147,6 +141,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } } /** @@ -159,8 +161,8 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def validateParams(paramMap: ParamMap): Unit = { - bestModel.validateParams(paramMap) + override def validateParams(): Unit = { + bestModel.validateParams() } override def transform(dataset: DataFrame): DataFrame = { @@ -171,4 +173,9 @@ class CrossValidatorModel private[ml] ( override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + + override def copy(extra: ParamMap): CrossValidatorModel = { + val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) + copyValues(copied, extra) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index d270ad7613af1..04f2af4727ea4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validateParams() } - solver.validateParams(ParamMap(inputCol -> "input")) + solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 05313d440fbf6..65972ec79b9a5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -19,11 +19,15 @@ package org.apache.spark.ml.tuning import org.scalatest.FunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.StructType class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +57,48 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) } + + test("validateParams should check estimatorParamMaps") { + import CrossValidatorSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new CrossValidator() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + + cv.validateParams() // This should pass. + + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object CrossValidatorSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = ??? + + override def transformSchema(schema: StructType): StructType = ??? + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = ??? + + override val uid: String = "eval" + } } From 6dd645870d34d97ac992032bfd6cf39f20a0c50f Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Wed, 27 May 2015 00:18:42 -0700 Subject: [PATCH 019/254] [SPARK-7850][BUILD] Hive 0.12.0 profile in POM should be removed I grep'ed hive-0.12.0 in the source code and removed all the profiles and doc references. Author: Cheolsoo Park Closes #6393 from piaozhexiu/SPARK-7850 and squashes the following commits: fb429ce [Cheolsoo Park] Remove hive-0.13.1 profile 82bf09a [Cheolsoo Park] Remove hive 0.12.0 shim code f3722da [Cheolsoo Park] Remove hive-0.12.0 profile and references from POM and build docs --- docs/building-spark.md | 6 +- pom.xml | 16 - .../spark/sql/hive/thriftserver/Shim12.scala | 278 ------------------ sql/hive/pom.xml | 10 - .../org/apache/spark/sql/hive/Shim12.scala | 265 ----------------- 5 files changed, 1 insertion(+), 574 deletions(-) delete mode 100644 sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala delete mode 100644 sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala diff --git a/docs/building-spark.md b/docs/building-spark.md index 4dbccb9e6e46c..3ca7f2746e678 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -118,14 +118,10 @@ mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` and `Phive-thriftserver` profiles to your existing build options. -By default Spark will build with Hive 0.13.1 bindings. You can also build for -Hive 0.12.0 using the `-Phive-0.12.0` profile. +By default Spark will build with Hive 0.13.1 bindings. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package - -# Apache Hadoop 2.4.X with Hive 12 support -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-0.12.0 -Phive-thriftserver -DskipTests clean package {% endhighlight %} # Building for Scala 2.11 diff --git a/pom.xml b/pom.xml index c72d7cbf843ef..711edf9efad2b 100644 --- a/pom.xml +++ b/pom.xml @@ -1753,22 +1753,6 @@ sql/hive-thriftserver - - hive-0.12.0 - - 0.12.0-protobuf-2.5 - 0.12.0 - 10.4.2.0 - - - - hive-0.13.1 - - 0.13.1a - 0.13.1 - 10.10.1.1 - - scala-2.10 diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala deleted file mode 100644 index b3a79ba1c7d6b..0000000000000 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} - -import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} -import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.types._ - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.12.0" - - def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { - val serverUserName = ShimLoader.getHadoopShims.getShortUserName(sparkServiceUGI) - setSuperField(sparkCliService, "serverUserName", serverUserName) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JArrayList[String]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.addAll(hiveResponse) - hiveResponse = null - true - } - } -} - -private[hive] class SparkExecuteStatementOperation( - parentSession: HiveSession, - statement: String, - confOverlay: JMap[String, String])( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - - private var result: DataFrame = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - if (!iter.hasNext) { - new RowSet() - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) - - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = new Row() - var curCol = 0 - - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - addNullColumnValue(sparkRow, row, curCol) - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - rowSet += row - curRow += 1 - } - new RowSet(rowSet, 0) - } - } - - def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(from(ordinal).asInstanceOf[String]) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType() => - val hiveDecimal = from.getDecimal(ordinal) - to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case DateType => - to.addColumnValue(ColumnValue.dateValue(from(ordinal).asInstanceOf[Date])) - case TimestampType => - to.addColumnValue( - ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) - to.addColumnValue(ColumnValue.stringValue(hiveString)) - } - } - - def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(null) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(null)) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(null)) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(null)) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType() => - to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) - case LongType => - to.addColumnValue(ColumnValue.longValue(null)) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(null)) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(null)) - case DateType => - to.addColumnValue(ColumnValue.dateValue(null)) - case TimestampType => - to.addColumnValue(ColumnValue.timestampValue(null)) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - to.addColumnValue(ColumnValue.stringValue(null: String)) - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - val statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - HiveThriftServer2.listener.onStatementStart( - statementId, parentSession.getSessionHandle.getSessionId.toString, statement, statementId) - hiveContext.sparkContext.setJobGroup(statementId, statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - try { - result = hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => - sessionToActivePool(parentSession.getSessionHandle) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) - iter = { - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - result.rdd.toLocalIterator - } else { - result.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - setState(OperationState.ERROR) - HiveThriftServer2.listener.onStatementError( - statementId, e.getMessage, e.getStackTraceString) - logError("Error executing query:",e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - HiveThriftServer2.listener.onStatementFinish(statementId) - } -} - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def openSession( - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - username, passwd, sessionConf, withImpersonation, delegationToken) - HiveThriftServer2.listener.onSessionCreated("UNKNOWN", sessionHandle.getSessionId.toString) - sessionHandle - } - - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() - } -} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index e322340094e6f..615b07e74d535 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -136,16 +136,6 @@ - - hive-0.12.0 - - - com.twitter - parquet-hive-bundle - 1.5.0 - - - diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala deleted file mode 100644 index 33e96eaabfbf6..0000000000000 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI -import java.util.{ArrayList => JArrayList, Properties} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions - -import org.apache.hadoop.{io => hadoopIo} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.stats.StatsSetupConst -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat - -import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} - -private[hive] case class HiveFunctionWrapper(functionClassName: String) - extends java.io.Serializable { - - // for Serialization - def this() = this(null) - - import org.apache.spark.util.Utils._ - def createFunction[UDFType <: AnyRef](): UDFType = { - getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - } -} - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[hive] object HiveShim { - val version = "0.12.0" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) - } - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, - getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, - getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, - getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, - getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, - getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, - getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, - getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, - getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, - getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, - getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, - getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DECIMAL, - getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.VOID, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[String] - - def processResults(results: JArrayList[String]) = results - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd(0), conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - new HiveDecimal(bd) - } - - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - ColumnProjectionUtils.appendReadColumnIDs(conf, ids) - ColumnProjectionUtils.appendReadColumnNames(conf, names) - } - - def getExternalTmpPath(context: Context, uri: URI) = { - context.getExternalTmpFileURI(uri) - } - - def getDataLocationPath(p: Partition) = p.getPartitionPath - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) - - def compatibilityBlackList = Seq( - "decimal_.*", - "udf7", - "drop_partitions_filter2", - "show_.*", - "serde_regex", - "udf_to_date", - "udaf_collect_set", - "udf_concat" - ) - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) - } - - def decimalMetastoreString(decimalType: DecimalType): String = "decimal" - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = - TypeInfoFactory.decimalTypeInfo - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - DecimalType.Unlimited - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - } - } - - def getConvertedOI( - inputOI: ObjectInspector, - outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) - } - - def prepareWritable(w: Writable): Writable = { - w - } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {} -} - -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends FileSinkDesc(dir, tableInfo, compressed) { -} From 4f98d7a7f1715273bc91f1903bb7e0f287cc7394 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 27 May 2015 00:27:39 -0700 Subject: [PATCH 020/254] [SPARK-7697][SQL] Use LongType for unsigned int in JDBCRDD JIRA: https://issues.apache.org/jira/browse/SPARK-7697 The reported problem case is mysql. But for h2 db, there is no unsigned int. So it is not able to add corresponding test. Author: Liang-Chi Hsieh Closes #6229 from viirya/unsignedint_as_long and squashes the following commits: dc4b5d8 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into unsignedint_as_long 608695b [Liang-Chi Hsieh] Use LongType for unsigned int in JDBCRDD. --- .../scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index be03a237b6c4e..244bd3ebfeb7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -46,7 +46,11 @@ private[sql] object JDBCRDD extends Logging { * @param sqlType - A field of java.sql.Types * @return The Catalyst type corresponding to sqlType. */ - private def getCatalystType(sqlType: Int, precision: Int, scale: Int): DataType = { + private def getCatalystType( + sqlType: Int, + precision: Int, + scale: Int, + signed: Boolean): DataType = { val answer = sqlType match { case java.sql.Types.ARRAY => null case java.sql.Types.BIGINT => LongType @@ -64,7 +68,7 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType case java.sql.Types.FLOAT => FloatType - case java.sql.Types.INTEGER => IntegerType + case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType } case java.sql.Types.JAVA_OBJECT => null case java.sql.Types.LONGNVARCHAR => StringType case java.sql.Types.LONGVARBINARY => BinaryType @@ -123,11 +127,12 @@ private[sql] object JDBCRDD extends Logging { val typeName = rsmd.getColumnTypeName(i + 1) val fieldSize = rsmd.getPrecision(i + 1) val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( - getCatalystType(dataType, fieldSize, fieldScale)) + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } From 9f48bf6b3761d66c7dc50f076ed92aff21b7eea0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 27 May 2015 01:12:59 -0700 Subject: [PATCH 021/254] [SPARK-7887][SQL] Remove EvaluatedType from SQL Expression. This type is not really used. Might as well remove it. Author: Reynold Xin Closes #6427 from rxin/evalutedType and squashes the following commits: 51a319a [Reynold Xin] [SPARK-7887][SQL] Remove EvaluatedType from SQL Expression. --- .../spark/sql/catalyst/analysis/unresolved.scala | 10 +++++----- .../catalyst/expressions/BoundAttribute.scala | 2 -- .../spark/sql/catalyst/expressions/Cast.scala | 2 -- .../sql/catalyst/expressions/Expression.scala | 8 ++------ .../sql/catalyst/expressions/ExtractValue.scala | 2 -- .../sql/catalyst/expressions/ScalaUdf.scala | 2 -- .../sql/catalyst/expressions/SortOrder.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 4 +--- .../sql/catalyst/expressions/arithmetic.scala | 16 ++++------------ .../sql/catalyst/expressions/complexTypes.scala | 6 ++---- .../catalyst/expressions/decimalFunctions.scala | 2 -- .../sql/catalyst/expressions/generators.scala | 2 -- .../sql/catalyst/expressions/literals.scala | 2 -- .../catalyst/expressions/mathfuncs/binary.scala | 3 ++- .../catalyst/expressions/mathfuncs/unary.scala | 1 - .../catalyst/expressions/namedExpressions.scala | 6 ++---- .../sql/catalyst/expressions/nullFunctions.scala | 1 - .../sql/catalyst/expressions/predicates.scala | 6 ------ .../spark/sql/catalyst/expressions/random.scala | 2 -- .../spark/sql/catalyst/expressions/sets.scala | 4 ---- .../catalyst/expressions/stringOperations.scala | 8 -------- .../catalyst/expressions/windowExpressions.scala | 12 +++++------- .../catalyst/plans/physical/partitioning.scala | 4 ++-- .../spark/sql/catalyst/trees/TreeNodeSuite.scala | 9 ++++----- .../expressions/MonotonicallyIncreasingID.scala | 2 -- .../execution/expressions/SparkPartitionID.scala | 2 -- .../apache/spark/sql/execution/pythonUdfs.scala | 2 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 5 +---- 28 files changed, 32 insertions(+), 95 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 2999c2ef3efe1..bbb150c1e83c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -67,7 +67,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name" @@ -85,7 +85,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -107,7 +107,7 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] { override lazy val resolved = false // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] @@ -166,7 +166,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child AS $names" @@ -200,7 +200,7 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child[$extraction]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index c6217f07c452d..1ffc95c676f6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -30,8 +30,6 @@ import org.apache.spark.sql.catalyst.trees case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends NamedExpression with trees.LeafNode[Expression] { - type EvaluatedType = Any - override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d8cf2b2e32435..df3cdf2cdf992 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -105,8 +105,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def toString: String = s"CAST($child, $dataType)" - type EvaluatedType = Any - // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c7ae9da7fce49..d19928784442e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -25,9 +25,6 @@ import org.apache.spark.sql.types._ abstract class Expression extends TreeNode[Expression] { self: Product => - /** The narrowest possible type that is produced when this expression is evaluated. */ - type EvaluatedType <: Any - /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. @@ -44,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] { def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ - def eval(input: Row = null): EvaluatedType + def eval(input: Row = null): Any /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -117,8 +114,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio // not like a real expressions. case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => - type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = false override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index e05926cbfe74b..b5f4e16745c1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -92,8 +92,6 @@ object ExtractValue { trait ExtractValue extends UnaryExpression { self: Product => - - type EvaluatedType = Any } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fe2873e0be34d..5b45347872cca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types.DataType case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { - type EvaluatedType = Any - override def nullable: Boolean = true override def toString: String = s"scalaUDF(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 83074eb1e6310..195eec8e5cdc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -36,7 +36,7 @@ case class SortOrder(child: Expression, direction: SortDirection) extends Expres override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index f3830c6d3bcf2..72eff5fe961f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -37,7 +37,7 @@ abstract class AggregateExpression extends Expression { * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are * replaced with a physical aggregate operator at runtime. */ - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -74,8 +74,6 @@ abstract class AggregateFunction extends AggregateExpression with Serializable with trees.LeafNode[Expression] { self: Product => - override type EvaluatedType = Any - /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c7a37ad966df6..34c833b260dc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ case class UnaryMinus(child: Expression) extends UnaryExpression { - type EvaluatedType = Any override def dataType: DataType = child.dataType override def foldable: Boolean = child.foldable @@ -45,7 +44,6 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { } case class Sqrt(child: Expression) extends UnaryExpression { - type EvaluatedType = Any override def dataType: DataType = DoubleType override def foldable: Boolean = child.foldable @@ -72,8 +70,6 @@ case class Sqrt(child: Expression) extends UnaryExpression { abstract class BinaryArithmetic extends BinaryExpression { self: Product => - type EvaluatedType = Any - override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType && @@ -101,7 +97,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = + def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryExpressions must either override eval or evalInternal") } @@ -244,7 +240,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case other => sys.error(s"Unsupported bitwise & operation on $other") } - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = and(evalE1, evalE2) + override def evalInternal(evalE1: Any, evalE2: Any): Any = and(evalE1, evalE2) } /** @@ -265,7 +261,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case other => sys.error(s"Unsupported bitwise | operation on $other") } - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = or(evalE1, evalE2) + override def evalInternal(evalE1: Any, evalE2: Any): Any = or(evalE1, evalE2) } /** @@ -286,14 +282,13 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case other => sys.error(s"Unsupported bitwise ^ operation on $other") } - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = xor(evalE1, evalE2) + override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) } /** * A function that calculates bitwise not(~) of a number. */ case class BitwiseNot(child: Expression) extends UnaryExpression { - type EvaluatedType = Any override def dataType: DataType = child.dataType override def foldable: Boolean = child.foldable @@ -323,7 +318,6 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { } case class MaxOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any override def foldable: Boolean = left.foldable && right.foldable @@ -368,7 +362,6 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } case class MinOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any override def foldable: Boolean = left.foldable && right.foldable @@ -416,7 +409,6 @@ case class MinOf(left: Expression, right: Expression) extends Expression { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryExpression { - type EvaluatedType = Any override def dataType: DataType = child.dataType override def foldable: Boolean = child.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 956a2429b0b61..e7cd7131a9e56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.types._ * Returns an Array containing the evaluation of all children expressions. */ case class CreateArray(children: Seq[Expression]) extends Expression { - override type EvaluatedType = Any - + override def foldable: Boolean = children.forall(_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -54,7 +53,6 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * TODO: [[CreateStruct]] does not support codegen. */ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { - override type EvaluatedType = Row override def foldable: Boolean = children.forall(_.foldable) @@ -71,7 +69,7 @@ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { override def nullable: Boolean = false - override def eval(input: Row): EvaluatedType = { + override def eval(input: Row): Any = { Row(children.map(_.eval(input)): _*) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index adb94df7d1c7b..65ba18924afe1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ case class UnscaledValue(child: Expression) extends UnaryExpression { - override type EvaluatedType = Any override def dataType: DataType = LongType override def foldable: Boolean = child.foldable @@ -40,7 +39,6 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { /** Create a Decimal from an unscaled Long value */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { - override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) override def foldable: Boolean = child.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 747a47bdde953..cab40feb72d47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -40,8 +40,6 @@ import org.apache.spark.sql.types._ abstract class Generator extends Expression { self: Product => - override type EvaluatedType = TraversableOnce[Row] - // TODO ideally we should return the type of ArrayType(StructType), // however, we don't keep the output field names in the Generator. override def dataType: DataType = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 5f8c7354aede1..d3ca3d9a4b18b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -78,14 +78,12 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" - type EvaluatedType = Any override def eval(input: Row): Any = value } // TODO: Specialize case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { - type EvaluatedType = Any def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index fcc06d3aa1036..d5be44626f8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - type EvaluatedType = Any + override def symbol: String = null override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -68,6 +68,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) case class Atan2( left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index dc68469e060cb..cdcb8e2840baf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.types._ abstract class MathematicalExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => - type EvaluatedType = Any override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 50be26d0b08b5..00565ec651a59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -111,7 +111,6 @@ case class Alias(child: Expression, name: String)( val explicitMetadata: Option[Metadata] = None) extends NamedExpression with trees.UnaryNode[Expression] { - override type EvaluatedType = Any // Alias(Generator, xx) need to be transformed into Generate(generator, ...) override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] @@ -229,7 +228,7 @@ case class AttributeReference( } // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$name#${exprId.id}$typeSuffix" @@ -240,7 +239,6 @@ case class AttributeReference( * expression id or the unresolved indicator. */ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - type EvaluatedType = Any override def toString: String = name @@ -252,7 +250,7 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index f9161cf34f0c9..5070570b4740d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { - type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ override def nullable: Boolean = !children.exists(!_.nullable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1d72a9eb834b9..e2d1c8115e051 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -35,8 +35,6 @@ trait Predicate extends Expression { self: Product => override def dataType: DataType = BooleanType - - type EvaluatedType = Any } trait PredicateHelper { @@ -341,8 +339,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi trueValue.dataType } - type EvaluatedType = Any - override def eval(input: Row): Any = { if (true == predicate.eval(input)) { trueValue.eval(input) @@ -357,8 +353,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi trait CaseWhenLike extends Expression { self: Product => - type EvaluatedType = Any - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last // element is the value for the default catch-all case (if provided). // Hence, `branches` consists of at least two elements, and can have an odd or even length. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 66d7c8b07cce8..de82c15680607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -38,8 +38,6 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { */ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) - override type EvaluatedType = Double - override def nullable: Boolean = false override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 4c44182278207..b65bf165f21db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -51,7 +51,6 @@ private[sql] class OpenHashSetUDT( * Creates a new set of the specified type */ case class NewSet(elementType: DataType) extends LeafExpression { - type EvaluatedType = Any override def nullable: Boolean = false @@ -69,7 +68,6 @@ case class NewSet(elementType: DataType) extends LeafExpression { * For performance, this expression mutates its input during evaluation. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { - type EvaluatedType = Any override def children: Seq[Expression] = item :: set :: Nil @@ -101,7 +99,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { * For performance, this expression mutates its left input set during evaluation. */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable @@ -133,7 +130,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres * Returns the number of elements in the input set. */ case class CountSet(child: Expression) extends UnaryExpression { - type EvaluatedType = Any override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 7683e0990ce80..5da93fe9c6cf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.types._ trait StringRegexExpression extends ExpectsInputTypes { self: BinaryExpression => - type EvaluatedType = Any - def escape(v: String): String def matches(regex: Pattern, str: String): Boolean @@ -114,8 +112,6 @@ case class RLike(left: Expression, right: Expression) trait CaseConversionExpression extends ExpectsInputTypes { self: UnaryExpression => - type EvaluatedType = Any - def convert(v: UTF8String): UTF8String override def foldable: Boolean = child.foldable @@ -159,8 +155,6 @@ trait StringComparison extends ExpectsInputTypes { def compare(l: UTF8String, r: UTF8String): Boolean - override type EvaluatedType = Any - override def nullable: Boolean = left.nullable || right.nullable override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) @@ -211,8 +205,6 @@ case class EndsWith(left: Expression, right: Expression) */ case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression with ExpectsInputTypes { - - type EvaluatedType = Any override def foldable: Boolean = str.foldable && pos.foldable && len.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 099d67ca7fee3..2729b34a0833f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -66,8 +66,6 @@ case class WindowSpecDefinition( } } - type EvaluatedType = Any - override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = @@ -76,7 +74,7 @@ case class WindowSpecDefinition( override def toString: String = simpleString - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -299,7 +297,7 @@ case class UnresolvedWindowFunction( override def get(index: Int): Any = throw new UnresolvedException(this, "get") // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -311,25 +309,25 @@ case class UnresolvedWindowFunction( case class UnresolvedWindowExpression( child: UnresolvedWindowFunction, windowSpec: WindowSpecReference) extends UnaryExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } case class WindowExpression( windowFunction: WindowFunction, windowSpec: WindowSpecDefinition) extends Expression { - override type EvaluatedType = Any override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil - override def eval(input: Row): EvaluatedType = + override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def dataType: DataType = windowFunction.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index fb4217a44807b..80ba57a082a60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -169,7 +169,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def keyExpressions: Seq[Expression] = expressions - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -213,6 +213,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def keyExpressions: Seq[Expression] = ordering.map(_.child) - override def eval(input: Row): EvaluatedType = + override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 3d10dab5ba34c..e5f77dcd962a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegerType, StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children: Seq[Expression] = optKey.toSeq - def nullable: Boolean = true - def dataType: NullType = NullType + override def children: Seq[Expression] = optKey.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType override lazy val resolved = true - type EvaluatedType = Any - def eval(input: Row): Any = null.asInstanceOf[Any] + override def eval(input: Row): Any = null.asInstanceOf[Any] } class TreeNodeSuite extends FunSuite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 9ac732b55b188..e228a60c9029f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -39,8 +39,6 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L - override type EvaluatedType = Long - override def nullable: Boolean = false override def dataType: DataType = LongType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index c2c6cbd491598..1272793f88cd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types.{IntegerType, DataType} */ private[sql] case object SparkPartitionID extends LeafExpression { - override type EvaluatedType = Int - override def nullable: Boolean = false override def dataType: DataType = IntegerType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 11b2897f76786..55f3ff4709013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -56,7 +56,7 @@ private[spark] case class PythonUDF( def nullable: Boolean = true - override def eval(input: Row): PythonUDF.this.EvaluatedType = { + override def eval(input: Row): Any = { sys.error("PythonUDFs can not be directly evaluated.") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bc6b3a2d58c38..7ec4f7332502e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -75,7 +75,7 @@ private[hive] abstract class HiveFunctionRegistry private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { - type EvaluatedType = Any + type UDFType = UDF override def nullable: Boolean = true @@ -139,7 +139,6 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF - type EvaluatedType = Any override def nullable: Boolean = true @@ -336,8 +335,6 @@ private[hive] case class HiveWindowFunction( def nullable: Boolean = true - override type EvaluatedType = Any - override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") From 3e7d7d6b3d6e07b52b1a138f7aa2ef628597fe05 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 27 May 2015 01:13:57 -0700 Subject: [PATCH 022/254] [SQL] Rename MathematicalExpression UnaryMathExpression, and specify BinaryMathExpression's output data type as DoubleType. Two minor changes. cc brkyvz Author: Reynold Xin Closes #6428 from rxin/math-func-cleanup and squashes the following commits: 5910df5 [Reynold Xin] [SQL] Rename MathematicalExpression UnaryMathExpression, and specify BinaryMathExpression's output data type as DoubleType. --- .../expressions/mathfuncs/binary.scala | 9 +--- .../expressions/mathfuncs/unary.scala | 46 +++++++++---------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index d5be44626f8e1..890efc9f52ca3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs -import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} import org.apache.spark.sql.types._ @@ -41,13 +40,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) left.dataType == right.dataType && !DecimalType.isFixed(left.dataType) - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } + override def dataType: DataType = DoubleType override def eval(input: Row): Any = { val evalE1 = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index cdcb8e2840baf..41b422346a02d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ * input format, therefore these functions extend `ExpectsInputTypes`. * @param name The short name of the function */ -abstract class MathematicalExpression(f: Double => Double, name: String) +abstract class UnaryMathExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => @@ -46,46 +46,44 @@ abstract class MathematicalExpression(f: Double => Double, name: String) } } -case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS") +case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") -case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN") +case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") -case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN") +case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") -case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT") +case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") -case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL") +case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") -case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS") +case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") -case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH") +case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") -case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP") +case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") -case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1") +case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") -case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR") +case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") -case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG") +case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") -case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10") +case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") -case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P") +case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") -case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND") +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") -case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM") +case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") -case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN") +case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") -case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH") +case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") -case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN") +case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") -case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH") +case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") -case class ToDegrees(child: Expression) - extends MathematicalExpression(math.toDegrees, "DEGREES") +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") -case class ToRadians(child: Expression) - extends MathematicalExpression(math.toRadians, "RADIANS") +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") From 4615081d7a10b023491e25478d19b8161e030974 Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 27 May 2015 09:12:18 -0500 Subject: [PATCH 023/254] [CORE] [TEST] HistoryServerSuite failed due to timezone issue follow up for #6377 Change time to the equivalent in GMT /cc squito Author: scwf Closes #6425 from scwf/fix-HistoryServerSuite and squashes the following commits: 4d37935 [scwf] fix HistoryServerSuite --- .../org/apache/spark/deploy/history/HistoryServerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 4adb5122bcf1a..e10dd4cf837aa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -82,7 +82,7 @@ class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with "running app list json" -> "applications?status=running", "minDate app list json" -> "applications?minDate=2015-02-10", "maxDate app list json" -> "applications?maxDate=2015-02-10", - "maxDate2 app list json" -> "applications?maxDate=2015-02-03T10:42:40.000CST", + "maxDate2 app list json" -> "applications?maxDate=2015-02-03T16:42:40.000GMT", "one app json" -> "applications/local-1422981780767", "one app multi-attempt json" -> "applications/local-1426533911241", "job list json" -> "applications/local-1422981780767/jobs", From ff0ddff46935ae3d036b7dbc437fff8a6c19d6a4 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 27 May 2015 09:32:29 -0700 Subject: [PATCH 024/254] [SPARK-7878] Rename Stage.jobId to firstJobId The previous name was confusing, because each stage can be associated with many jobs, and jobId is just the ID of the first job that was associated with the Stage. This commit also renames some of the method parameters in DAGScheduler.scala to clarify when the jobId refers to the first job ID associated with the stage (as opposed to the jobId associated with a job that's currently being scheduled). cc markhamstra JoshRosen (hopefully this will help prevent future bugs like SPARK-6880) Author: Kay Ousterhout Closes #6418 from kayousterhout/SPARK-7878 and squashes the following commits: b71a9b8 [Kay Ousterhout] [SPARK-7878] Rename Stage.jobId to firstJobId --- .../apache/spark/scheduler/DAGScheduler.scala | 58 +++++++++---------- .../apache/spark/scheduler/ResultStage.scala | 4 +- .../spark/scheduler/ShuffleMapStage.scala | 4 +- .../org/apache/spark/scheduler/Stage.scala | 4 +- 4 files changed, 33 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5d812918a13d1..a083be2448aa3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -208,19 +208,17 @@ class DAGScheduler( /** * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) */ private def getShuffleMapStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, jobId) + registerShuffleDependencies(shuffleDep, firstJobId) // Then register current shuffleDep - val stage = newOrUsedShuffleStage(shuffleDep, jobId) + val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage @@ -230,15 +228,15 @@ class DAGScheduler( /** * Helper function to eliminate some code re-use when creating new stages. */ - private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { - val parentStages = getParentStages(rdd, jobId) + private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, firstJobId) val id = nextStageId.getAndIncrement() (parentStages, id) } /** * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in - * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * newOrUsedShuffleStage. The stage will be associated with the provided firstJobId. * Production of shuffle map stages should always use newOrUsedShuffleStage, not * newShuffleMapStage directly. */ @@ -246,21 +244,19 @@ class DAGScheduler( rdd: RDD[_], numTasks: Int, shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, + firstJobId: Int, callSite: CallSite): ShuffleMapStage = { - val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId) val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, - jobId, callSite, shuffleDep) + firstJobId, callSite, shuffleDep) stageIdToStage(id) = stage - updateJobIdStageIdMaps(jobId, stage) + updateJobIdStageIdMaps(firstJobId, stage) stage } /** - * Create a ResultStage -- either directly for use as a result stage, or as part of the - * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated - * with the provided jobId. + * Create a ResultStage associated with the provided jobId. */ private def newResultStage( rdd: RDD[_], @@ -277,16 +273,16 @@ class DAGScheduler( /** * Create a shuffle map Stage for the given RDD. The stage will also be associated with the - * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd val numTasks = rdd.partitions.size - val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) @@ -304,10 +300,10 @@ class DAGScheduler( } /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. + * Get or create the list of parent stages for a given RDD. The new Stages will be created with + * the provided firstJobId. */ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError @@ -321,7 +317,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - parents += getShuffleMapStage(shufDep, jobId) + parents += getShuffleMapStage(shufDep, firstJobId) case _ => waitingForVisit.push(dep.rdd) } @@ -336,11 +332,11 @@ class DAGScheduler( } /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, jobId) + val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } @@ -390,7 +386,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { missing += mapStage } @@ -577,7 +573,7 @@ class DAGScheduler( private[scheduler] def doCancelAllJobs() { // Cancel all running jobs. - runningStages.map(_.jobId).foreach(handleJobCancellation(_, + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, reason = "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... @@ -603,7 +599,7 @@ class DAGScheduler( clearCacheLocs() val failedStagesCopy = failedStages.toArray failedStages.clear() - for (stage <- failedStagesCopy.sortBy(_.jobId)) { + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -623,7 +619,7 @@ class DAGScheduler( logTrace("failed: " + failedStages) val waitingStagesCopy = waitingStages.toArray waitingStages.clear() - for (stage <- waitingStagesCopy.sortBy(_.jobId)) { + for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -843,7 +839,7 @@ class DAGScheduler( } } - val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -909,7 +905,7 @@ class DAGScheduler( stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1323,7 +1319,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index c0f3d5a13d623..bf81b9aca4810 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -28,9 +28,9 @@ private[spark] class ResultStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { // The active job for this result stage. Will be empty if the job has already finished // (e.g., because the job was cancelled). diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index d02210743484c..66c75f325fcde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -30,10 +30,10 @@ private[spark] class ShuffleMapStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite, val shuffleDep: ShuffleDependency[_, _, _]) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { override def toString: String = "ShuffleMapStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5d0ddb8377c33..c59d6e4f5bc04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.CallSite * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes * that each output partition is on. * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO + * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * @@ -51,7 +51,7 @@ private[spark] abstract class Stage( val rdd: RDD[_], val numTasks: Int, val parents: List[Stage], - val jobId: Int, + val firstJobId: Int, val callSite: CallSite) extends Logging { From 15459db4f6867e95076cf53fade2fca833c4cf4e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 27 May 2015 10:09:12 -0700 Subject: [PATCH 025/254] [SPARK-7847] [SQL] Fixes dynamic partition directory escaping Please refer to [SPARK-7847] [1] for details. [1]: https://issues.apache.org/jira/browse/SPARK-7847 Author: Cheng Lian Closes #6389 from liancheng/spark-7847 and squashes the following commits: 935c652 [Cheng Lian] Adds test case for writing various data types as dynamic partition value f4fc398 [Cheng Lian] Converts partition columns to Scala type when writing dynamic partitions d0aeca0 [Cheng Lian] Fixes dynamic partition directory escaping --- .../apache/spark/sql/parquet/newParquet.scala | 22 ++++-- .../spark/sql/sources/PartitioningUtils.scala | 76 ++++++++++++++++++- .../apache/spark/sql/sources/commands.scala | 57 ++------------ .../ParquetPartitionDiscoverySuite.scala | 57 +++++++++++++- 4 files changed, 152 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index cb1e60883df1e..8b3e1b2b59bf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import java.net.URI import java.util.{List => JList} import scala.collection.JavaConversions._ @@ -282,21 +283,28 @@ private[sql] class ParquetRelation2( val cacheMetadata = useMetadataCache @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as /, - // we need to use the string returned by the URI of the path to create a new Path. - val pathWithAuthority = new Path(f.getPath.toUri.toString) - + // In order to encode the authority of a Path containing special characters such as '/' + // (which does happen in some S3N credentials), we need to use the string returned by the + // URI of the path to create a new Path. + val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) new FileStatus( f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithAuthority) + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) }.toSeq @transient val cachedFooters = footers.map { f => // In order to encode the authority of a Path containing special characters such as /, // we need to use the string returned by the URI of the path to create a new Path. - new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) + new Footer(escapePathUserInfo(f.getFile), f.getParquetMetadata) }.toSeq + private def escapePathUserInfo(path: Path): Path = { + val uri = path.toUri + new Path(new URI( + uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, + uri.getQuery, uri.getFragment)) + } + // Overridden so we can inject our own cached files statuses. override def getPartitions: Array[SparkPartition] = { val inputFormat = if (cacheMetadata) { @@ -377,7 +385,7 @@ private[sql] class ParquetRelation2( .orElse(readSchema()) .orElse(maybeMetastoreSchema) .getOrElse(sys.error("Failed to get the schema.")) - + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case // case insensitivity issue and possible schema mismatch (probably caused by schema // evolution). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index e0ead23d786f9..dafdf0f8b4564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.Shell import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -221,7 +222,7 @@ private[sql] object PartitioningUtils { // Then falls back to string .getOrElse { if (raw == defaultPartitionName) Literal.create(null, NullType) - else Literal.create(raw, StringType) + else Literal.create(unescapePathName(raw), StringType) } } @@ -243,4 +244,77 @@ private[sql] object PartitioningUtils { Literal.create(Cast(l, desiredType).eval(), desiredType) } } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). + ////////////////////////////////////////////////////////////////////////////////////////////////// + + val charToEscape = { + val bitSet = new java.util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. + */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c) + } + + def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02x") + } else { + builder.append(c) + } + } + + builder.toString() + } + + def unescapePathName(path: String): String = { + val sb = new StringBuilder + var i = 0 + + while (i < path.length) { + val c = path.charAt(i) + if (c == '%' && i + 2 < path.length) { + val code: Int = try { + Integer.valueOf(path.substring(i + 1, i + 3), 16) + } catch { case e: Exception => + -1: Integer + } + if (code >= 0) { + sb.append(code.asInstanceOf[Char]) + i += 3 + } else { + sb.append(c) + i += 1 + } + } else { + sb.append(c) + i += 1 + } + } + + sb.toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index fbd98ef0380e1..3132067d562f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.hadoop.util.Shell import parquet.hadoop.util.ContextUtil import org.apache.spark._ @@ -35,7 +34,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.{SQLConf, DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -208,9 +208,11 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) } } else { + val partitionSchema = StructType.fromAttributes(partitionOutput) + val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) while (iterator.hasNext) { val row = iterator.next() - val partitionPart = partitionProj(row) + val partitionPart = converter(partitionProj(row)).asInstanceOf[Row] val dataPart = dataProj(row) writerContainer.outputWriterForRow(partitionPart).write(dataPart) } @@ -416,7 +418,7 @@ private[sql] class DynamicPartitionWriterContainer( val valueString = if (string == null || string.isEmpty) { defaultPartitionName } else { - DynamicPartitionWriterContainer.escapePathName(string) + PartitioningUtils.escapePathName(string) } s"/$col=$valueString" }.mkString.stripPrefix(Path.SEPARATOR) @@ -448,50 +450,3 @@ private[sql] class DynamicPartitionWriterContainer( } } } - -private[sql] object DynamicPartitionWriterContainer { - ////////////////////////////////////////////////////////////////////////////////////////////////// - // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). - ////////////////////////////////////////////////////////////////////////////////////////////////// - - val charToEscape = { - val bitSet = new java.util.BitSet(128) - - /** - * ASCII 01-1F are HTTP control characters that need to be escaped. - * \u000A and \u000D are \n and \r, respectively. - */ - val clist = Array( - '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', - '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', - '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', - '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', - '{', '[', ']', '^') - - clist.foreach(bitSet.set(_)) - - if (Shell.WINDOWS) { - Array(' ', '<', '>', '|').foreach(bitSet.set(_)) - } - - bitSet - } - - def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) - } - - def escapePathName(path: String): String = { - val builder = new StringBuilder() - path.foreach { c => - if (DynamicPartitionWriterContainer.needsEscaping(c)) { - builder.append('%') - builder.append(f"${c.asInstanceOf[Int]}%02x") - } else { - builder.append(c) - } - } - - builder.toString() - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 90d4528efca48..f231589e9674d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet import java.io.File +import java.math.BigInteger +import java.sql.{Timestamp, Date} import scala.collection.mutable.ArrayBuffer @@ -27,7 +29,7 @@ import org.apache.spark.sql.sources.PartitioningUtils._ import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext} // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -377,4 +379,57 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } } } + + test("SPARK-7847: Dynamic partition directory path escaping and unescaping") { + withTempPath { dir => + val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") + df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) + checkAnswer(read.parquet(dir.getCanonicalPath), df.collect()) + } + } + + test("Various partition value types") { + val row = + Row( + 100.toByte, + 40000.toShort, + Int.MaxValue, + Long.MaxValue, + 1.5.toFloat, + 4.5, + new java.math.BigDecimal(new BigInteger("212500"), 5), + new java.math.BigDecimal(2.125), + java.sql.Date.valueOf("2015-05-23"), + new Timestamp(0), + "This is a string, /[]?=:", + "This is not a partition column") + + // BooleanType is not supported yet + val partitionColumnTypes = + Seq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(10, 5), + DecimalType.Unlimited, + DateType, + TimestampType, + StringType) + + val partitionColumns = partitionColumnTypes.zipWithIndex.map { + case (t, index) => StructField(s"p_$index", t) + } + + val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) + val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema) + + withTempPath { dir => + df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(read.load(dir.toString).select(fields: _*), row) + } + } } From 0db76c90ad5f84d7a5640c41de74876b906ddc90 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 27 May 2015 11:41:35 -0700 Subject: [PATCH 026/254] [SPARK-7864] [UI] Fix the logic grabbing the link from table in AllJobPage This issue is related to #6419 . Now AllJobPage doesn't have a "kill link" but I think fix the issue mentioned in #6419 just in case to avoid accidents in the future. So, it's minor issue for now and I don't file this issue in JIRA. Author: Kousuke Saruta Closes #6432 from sarutak/remove-ambiguity-of-link and squashes the following commits: cd1a503 [Kousuke Saruta] Fixed ambiguity link issue in AllJobPage --- .../main/resources/org/apache/spark/ui/static/timeline-view.js | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 28ac998e8d065..ca74ef9d7e94e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -46,7 +46,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var jobPagePath = $(getSelectorForJobEntry(this)).find("a").attr("href") + var jobPagePath = $(getSelectorForJobEntry(this)).find("a.name-link").attr("href") window.location.href = jobPagePath }); diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index e010ebef3b34a..2ce670ad02e97 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -231,7 +231,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { {lastStageDescription} - {lastStageName} + {lastStageName} {formattedSubmissionTime} From 6fec1a9409b34d8ce58ea1c330b52cc7ef3e7e7e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 27 May 2015 11:54:35 -0700 Subject: [PATCH 027/254] Removed Guava dependency from JavaTypeInference's type signature. This should also close #6243. Author: Reynold Xin Closes #6431 from rxin/JavaTypeInference-guava and squashes the following commits: e58df3c [Reynold Xin] Removed Gauva dependency from JavaTypeInference's type signature. --- .../apache/spark/sql/catalyst/JavaTypeInference.scala | 11 ++++++++++- .../main/scala/org/apache/spark/sql/SQLContext.scala | 4 +--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 625c8d3a62125..9a3f9694e4c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -38,12 +38,21 @@ private [sql] object JavaTypeInference { private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + /** + * Infers the corresponding SQL data type of a JavaClean class. + * @param beanClass Java type + * @return (SQL data type, nullable) + */ + def inferDataType(beanClass: Class[_]): (DataType, Boolean) = { + inferDataType(TypeToken.of(beanClass)) + } + /** * Infers the corresponding SQL data type of a Java type. * @param typeToken Java type * @return (SQL data type, nullable) */ - private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 3935f7b321b85..15c30352bee69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -27,8 +27,6 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import com.google.common.reflect.TypeToken - import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} @@ -1011,7 +1009,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) + val (dataType, _) = JavaTypeInference.inferDataType(beanClass) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() } From 8161562eabc1eff430cfd9d8eaf413a8c4ef2cfb Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 27 May 2015 12:42:13 -0700 Subject: [PATCH 028/254] [SPARK-7790] [SQL] date and decimal conversion for dynamic partition key Author: Daoyuan Wang Closes #6318 from adrian-wang/dynpart and squashes the following commits: ad73b61 [Daoyuan Wang] not use sqlTestUtils for try catch because dont have sqlcontext here 6c33b51 [Daoyuan Wang] fix according to liancheng f0f8074 [Daoyuan Wang] some specific types as dynamic partition --- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../spark/sql/hive/hiveWriterContainers.scala | 17 ++++++++-- .../sql/hive/execution/SQLQuerySuite.scala | 33 +++++++++++++++++++ 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c0b0b104e9142..7a6ca48b54a24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -106,7 +106,7 @@ case class InsertIntoHiveTable( } writerContainer - .getLocalFileWriter(row) + .getLocalFileWriter(row, table.schema) .write(serializer.serialize(outputData, standardOI)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index cbc381cc81b59..50b209f7ccbb8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,8 +34,10 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.types._ /** * Internal helper class that saves an RDD using a Hive OutputFormat. @@ -92,7 +94,7 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer def close() { // Seems the boolean value passed into close does not matter. @@ -195,11 +197,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + def convertToHiveRawString(col: String, value: Any): String = { + val raw = String.valueOf(value) + schema(col).dataType match { + case DateType => DateUtils.toString(raw.toInt) + case _: DecimalType => BigDecimal(raw).toString() + case _ => raw + } + } + val dynamicPartPath = dynamicPartColNames .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => - val string = if (rawVal == null) null else String.valueOf(rawVal) + val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) val colString = if (string == null || string.isEmpty) { defaultPartName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index b707f5e68489b..538e66125c5fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -837,4 +837,37 @@ class SQLQuerySuite extends QueryTest { java.lang.Math.exp(1.0).toString, java.lang.Math.floor(1.9).toString)) } + + test("dynamic partition value test") { + try { + sql("set hive.exec.dynamic.partition.mode=nonstrict") + // date + sql("drop table if exists dynparttest1") + sql("create table dynparttest1 (value int) partitioned by (pdate date)") + sql( + """ + |insert into table dynparttest1 partition(pdate) + | select count(*), cast('2015-05-21' as date) as pdate from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest1"), + Seq(Row(500, java.sql.Date.valueOf("2015-05-21")))) + + // decimal + sql("drop table if exists dynparttest2") + sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") + sql( + """ + |insert into table dynparttest2 partition(pdec) + | select count(*), cast('100.12' as decimal(5, 1)) as pdec from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest2"), + Seq(Row(500, new java.math.BigDecimal("100.1")))) + } finally { + sql("drop table if exists dynparttest1") + sql("drop table if exists dynparttest2") + sql("set hive.exec.dynamic.partition.mode=strict") + } + } } From b97ddff000b99adca3dd8fe13d01054fd5014fa0 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 27 May 2015 13:09:33 -0700 Subject: [PATCH 029/254] [SPARK-7684] [SQL] Refactoring MetastoreDataSourcesSuite to workaround SPARK-7684 As stated in SPARK-7684, currently `TestHive.reset` has some execution order specific bug, which makes running specific test suites locally pretty frustrating. This PR refactors `MetastoreDataSourcesSuite` (which relies on `TestHive.reset` heavily) using various `withXxx` utility methods in `SQLTestUtils` to ask each test case to cleanup their own mess so that we can avoid calling `TestHive.reset`. Author: Cheng Lian Author: Yin Huai Closes #6353 from liancheng/workaround-spark-7684 and squashes the following commits: 26939aa [Yin Huai] Move the initialization of jsonFilePath to beforeAll. a423d48 [Cheng Lian] Fixes Scala style issue dfe45d0 [Cheng Lian] Refactors MetastoreDataSourcesSuite to workaround SPARK-7684 92a116d [Cheng Lian] Fixes minor styling issues --- .../org/apache/spark/sql/QueryTest.scala | 4 + .../apache/spark/sql/test/SQLTestUtils.scala | 12 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 1372 +++++++++-------- 3 files changed, 722 insertions(+), 666 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index bbf9ab113ca43..98ba3c99283a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -67,6 +67,10 @@ class QueryTest extends PlanTest { checkAnswer(df, Seq(expectedAnswer)) } + protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ca66cdc48272d..17a8b0cca09df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -75,14 +75,18 @@ trait SQLTestUtils { /** * Drops temporary table `tableName` after calling `f`. */ - protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.dropTempTable(tableName) + protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(sqlContext.dropTempTable) } /** * Drops table `tableName` after calling `f`. */ - protected def withTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.sql(s"DROP TABLE IF EXISTS $tableName") + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + sqlContext.sql(s"DROP TABLE IF EXISTS $name") + } + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 9623ef06aa9b0..58e2d1fbfa73e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -21,770 +21,818 @@ import java.io.File import scala.collection.mutable.ArrayBuffer +import org.scalatest.BeforeAndAfterAll + import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException -import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override val sqlContext = TestHive + + var jsonFilePath: String = _ - override def afterEach(): Unit = { - reset() - Utils.deleteRecursively(tempPath) + override def beforeAll(): Unit = { + jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } - val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var tempPath: File = Utils.createTempDir() - tempPath.delete() - - test ("persistent JSON table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(filePath).collect().toSeq) + test("persistent JSON table") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } - test ("persistent JSON table with a user specified schema") { - sql( - s""" - |CREATE TABLE jsonTable ( - |a string, - |b String, - |`c_!@(3)` int, - |`` Struct<`d!`:array, `=`:array>>) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - read.json(filePath).registerTempTable("expectedJsonTable") - - checkAnswer( - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) + } + } } - test ("persistent JSON table with a user specified schema with a subset of fields") { - // This works because JSON objects are self-describing and JSONRelation can get needed - // field values based on field names. - sql( - s""" - |CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val innerStruct = StructType( - StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))) :: Nil) - val expectedSchema = StructType( - StructField("", innerStruct, true) :: - StructField("b", StringType, true) :: Nil) - - assert(expectedSchema === table("jsonTable").schema) - - read.json(filePath).registerTempTable("expectedJsonTable") - - checkAnswer( - sql("SELECT b, ``.`=` FROM jsonTable"), - sql("SELECT b, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema with a subset of fields") { + withTable("jsonTable") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. + sql( + s"""CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val innerStruct = StructType(Seq( + StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))))) + + val expectedSchema = StructType(Seq( + StructField("", innerStruct, true), + StructField("b", StringType, true))) + + assert(expectedSchema === table("jsonTable").schema) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT b, ``.`=` FROM jsonTable"), + sql("SELECT b, ``.`=` FROM expectedJsonTable")) + } + } } test("resolve shortened provider names") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(filePath).collect().toSeq) + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } test("drop table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(filePath).collect().toSeq) - - sql("DROP TABLE jsonTable") - - intercept[Exception] { - sql("SELECT * FROM jsonTable").collect() - } + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath)) + + sql("DROP TABLE jsonTable") - assert( - (new File(filePath)).exists(), - "The table with specified path is considered as an external table, " + - "its data should not deleted after DROP TABLE.") + intercept[Exception] { + sql("SELECT * FROM jsonTable").collect() + } + + assert( + new File(jsonFilePath).exists(), + "The table with specified path is considered as an external table, " + + "its data should not deleted after DROP TABLE.") + } } test("check change without refresh") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - // Schema is cached so the new column does not show. The updated values in existing columns - // will show. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1")) - - sql("REFRESH TABLE jsonTable") - - // Check that the refresh worked - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1", "c1")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + withTable("jsonTable") { + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a1", "b1", "c1") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + // Schema is cached so the new column does not show. The updated values in existing columns + // will show. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1")) + + sql("REFRESH TABLE jsonTable") + + // Check that the refresh worked + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1", "c1")) + } + } } test("drop, change, recreate") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql("DROP TABLE jsonTable") - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - // New table should reflect new schema. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b", "c")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql("DROP TABLE jsonTable") + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + // New table should reflect new schema. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b", "c")) + } + } } test("invalidate cache and reload") { - sql( - s""" - |CREATE TABLE jsonTable (`c_!@(3)` int) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable (`c_!@(3)` int) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - read.json(filePath).registerTempTable("expectedJsonTable") + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - // Discard the cached relation. - invalidateTable("jsonTable") + // Discard the cached relation. + invalidateTable("jsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - invalidateTable("jsonTable") - val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) + invalidateTable("jsonTable") + val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema === table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) + } + } } test("CTAS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { tempPath => + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS with IF NOT EXISTS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - // Create the table again should trigger a AnalysisException. - val message = intercept[AnalysisException] { - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - }.getMessage - assert(message.contains("Table ctasJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // The following statement should be fine if it has IF NOT EXISTS. - // It tries to create a table ctasJsonTable with a new schema. - // The actual table's schema and data should not be changed. - sql( - s""" - |CREATE TABLE IF NOT EXISTS ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT a FROM jsonTable - """.stripMargin) - - // Discard the cached relation. - invalidateTable("ctasJsonTable") - - // Schema should not be changed. - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - // Table data should not be changed. - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + // Create the table again should trigger a AnalysisException. + val message = intercept[AnalysisException] { + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + }.getMessage + + assert( + message.contains("Table ctasJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // The following statement should be fine if it has IF NOT EXISTS. + // It tries to create a table ctasJsonTable with a new schema. + // The actual table's schema and data should not be changed. + sql( + s"""CREATE TABLE IF NOT EXISTS ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT a FROM jsonTable + """.stripMargin) + + // Discard the cached relation. + invalidateTable("ctasJsonTable") + + // Schema should not be changed. + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + // Table data should not be changed. + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS a managed table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") - val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) - - // It is a managed table when we do not specify the location. - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - - sql( - s""" - |CREATE TABLE loadedTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${expectedPath}' - |) - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("loadedTable").schema) - - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM loadedTable").collect() - ) - - sql("DROP TABLE ctasJsonTable") - assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") - } - - test("SPARK-5286 Fail to drop an invalid table when using the data source API") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path 'it is not a path at all!' - |) - """.stripMargin) - - sql("DROP TABLE jsonTable").collect().foreach(println) - } - - test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = read.json(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). - df.write.saveAsTable("savedJsonTable") - - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) + withTable("jsonTable", "ctasJsonTable", "loadedTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val filesystemPath = new Path(expectedPath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) + + // It is a managed table when we do not specify the location. + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |AS + |SELECT * FROM jsonTable + """.stripMargin) - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - invalidateTable("savedJsonTable") + sql( + s"""CREATE TABLE loadedTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$expectedPath' + |) + """.stripMargin) - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) + assert(table("ctasJsonTable").schema === table("loadedTable").schema) - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable")) - // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") + sql("DROP TABLE ctasJsonTable") + assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + } + } - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + test("SPARK-5286 Fail to drop an invalid table when using the data source API") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path 'it is not a path at all!' + |) + """.stripMargin) + + sql("DROP TABLE jsonTable").collect().foreach(println) + } } - test("save table") { - val originalDefaultSource = conf.defaultDataSourceName + test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { + withTable("savedJsonTable") { + // Save the df as a managed table (by not specifying the path). + (1 to 10) + .map(i => i -> s"str$i") + .toDF("a", "b") + .write + .format("json") + .saveAsTable("savedJsonTable") - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = read.json(rdd) + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). - df.write.saveAsTable("savedJsonTable") + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) + invalidateTable("savedJsonTable") - // Right now, we cannot append to an existing JSON table. - intercept[RuntimeException] { - df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") - } + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - // We can overwrite it. - df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // When the save mode is Ignore, we will do nothing when the table already exists. - df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") - assert(df.schema === table("savedJsonTable").schema) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") - intercept[InvalidInputException] { - read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) } + } - // Create an external table by specifying the path. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.write - .format("org.apache.spark.sql.json") - .mode(SaveMode.Append) - .option("path", tempPath.toString) - .saveAsTable("savedJsonTable") - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Data should not be deleted after we drop the table. - sql("DROP TABLE savedJsonTable") - checkAnswer( - read.json(tempPath.toString), - df.collect()) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + test("save table") { + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("savedJsonTable") { + val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b") + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + // Save the df as a managed table (by not specifying the path). + df.write.saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") + } + + // We can overwrite it. + df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") + assert(df.schema === table("savedJsonTable").schema) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + } + + // Create an external table by specifying the path. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + df.write + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + } + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer(read.json(tempPath.toString), df) + } + } } test("create external table") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = read.json(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.write.format("org.apache.spark.sql.json") - .mode(SaveMode.Append) - .option("path", tempPath.toString) - .saveAsTable("savedJsonTable") - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - createExternalTable("createdJsonTable", tempPath.toString) - assert(table("createdJsonTable").schema === df.schema) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - df.collect()) - - var message = intercept[AnalysisException] { - createExternalTable("createdJsonTable", filePath.toString) - }.getMessage - assert(message.contains("Table createdJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // Data should not be deleted. - sql("DROP TABLE createdJsonTable") - checkAnswer( - read.json(tempPath.toString), - df.collect()) - - // Try to specify the schema. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - val schema = StructType(StructField("b", StringType, true) :: Nil) - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map("path" -> tempPath.toString)) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - sql("SELECT b FROM savedJsonTable").collect()) - - sql("DROP TABLE createdJsonTable") - - message = intercept[RuntimeException] { - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map.empty[String, String]) - }.getMessage - assert( - message.contains("'path' must be specified for json data."), - "We should complain that path is not specified.") - - sql("DROP TABLE savedJsonTable") - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + withTempPath { tempPath => + withTable("savedJsonTable", "createdJsonTable") { + val df = read.json(sparkContext.parallelize((1 to 10).map { i => + s"""{ "a": $i, "b": "str$i" }""" + })) + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + df.write + .format("json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + } + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + createExternalTable("createdJsonTable", tempPath.toString) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer(sql("SELECT * FROM createdJsonTable"), df) + + assert( + intercept[AnalysisException] { + createExternalTable("createdJsonTable", jsonFilePath.toString) + }.getMessage.contains("Table createdJsonTable already exists."), + "We should complain that createdJsonTable already exists") + } + + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") + checkAnswer(read.json(tempPath.toString), df) + + // Try to specify the schema. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable")) + + sql("DROP TABLE createdJsonTable") + + assert( + intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage.contains("'path' must be specified for json data."), + "We should complain that path is not specified.") + } + } + } } if (HiveShim.version == "0.13.1") { test("scan a parquet table created through a CTAS statement") { - val originalConvertMetastore = getConf("spark.sql.hive.convertMetastoreParquet", "true") - val originalUseDataSource = getConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - setConf("spark.sql.hive.convertMetastoreParquet", "true") - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - read.json(rdd).registerTempTable("jt") - sql( - """ - |create table test_parquet_ctas STORED AS parquET - |AS select tmp.a from jt tmp where tmp.a < 5 - """.stripMargin) - - checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), - Row(3) :: Row(4) :: Nil - ) - - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") + withSQLConf( + "spark.sql.hive.convertMetastoreParquet" -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + } + } + } } - - // Clenup and reset confs. - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS test_parquet_ctas") - setConf("spark.sql.hive.convertMetastoreParquet", originalConvertMetastore) - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource) } } test("Pre insert nullability check (ArrayType)") { - val df1 = - createDataFrame(Tuple1(Seq(Int.box(1), null.asInstanceOf[Integer])) :: Nil).toDF("a") - val expectedSchema1 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("arrayInParquet") - - val df2 = - createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a") - val expectedSchema2 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.write.mode(SaveMode.Append).insertInto("arrayInParquet") - createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) - .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a").write - .mode(SaveMode.Append).saveAsTable("arrayInParquet") - refreshTable("arrayInParquet") - - checkAnswer( - sql("SELECT a FROM arrayInParquet"), - Row(ArrayBuffer(1, null)) :: - Row(ArrayBuffer(2, 3)) :: - Row(ArrayBuffer(4, 5)) :: - Row(ArrayBuffer(6, null)) :: Nil) - - sql("DROP TABLE arrayInParquet") + withTable("arrayInParquet") { + { + val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("arrayInParquet") + } + + { + val df = (Tuple1(Seq(2, 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("arrayInParquet") + } + + (Tuple1(Seq(4, 5)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + + (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") + + refreshTable("arrayInParquet") + + checkAnswer( + sql("SELECT a FROM arrayInParquet"), + Row(ArrayBuffer(1, null)) :: + Row(ArrayBuffer(2, 3)) :: + Row(ArrayBuffer(4, 5)) :: + Row(ArrayBuffer(6, null)) :: Nil) + } } test("Pre insert nullability check (MapType)") { - val df1 = - createDataFrame(Tuple1(Map(1 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") - val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val expectedSchema1 = - StructType( - StructField("a", mapType1, nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("mapInParquet") - - val df2 = - createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a") - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = false) - val expectedSchema2 = - StructType( - StructField("a", mapType2, nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.write.mode(SaveMode.Append).insertInto("mapInParquet") - createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) - .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a").write - .format("parquet").mode(SaveMode.Append).saveAsTable("mapInParquet") - refreshTable("mapInParquet") - - checkAnswer( - sql("SELECT a FROM mapInParquet"), - Row(Map(1 -> null)) :: - Row(Map(2 -> 3)) :: - Row(Map(4 -> 5)) :: - Row(Map(6 -> null)) :: Nil) - - sql("DROP TABLE mapInParquet") + withTable("mapInParquet") { + { + val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("mapInParquet") + } + + { + val df = (Tuple1(Map(2 -> 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("mapInParquet") + } + + (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + + (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") + + refreshTable("mapInParquet") + + checkAnswer( + sql("SELECT a FROM mapInParquet"), + Row(Map(1 -> null)) :: + Row(Map(2 -> 3)) :: + Row(Map(4 -> 5)) :: + Row(Map(6 -> null)) :: Nil) + } } test("SPARK-6024 wide schema support") { - // We will need 80 splits for this schema if the threshold is 4000. - val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true))) - assert( - schema.json.size > conf.schemaStringLengthThreshold, - "To correctly test the fix of SPARK-6024, the value of " + - s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}") - // Manually create a metastore data source table. - catalog.createDataSourceTable( - tableName = "wide_schema", - userSpecifiedSchema = Some(schema), - partitionColumns = Array.empty[String], - provider = "json", - options = Map("path" -> "just a dummy path"), - isExternal = false) - - invalidateTable("wide_schema") - - val actualSchema = table("wide_schema").schema - assert(schema === actualSchema) + withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD -> "4000") { + withTable("wide_schema") { + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) + + // Manually create a metastore data source table. + catalog.createDataSourceTable( + tableName = "wide_schema", + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + provider = "json", + options = Map("path" -> "just a dummy path"), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } + } } test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { val tableName = "spark6655" - val schema = StructType(StructField("int", IntegerType, true) :: Nil) - - val hiveTable = HiveTable( - specifiedDatabase = Some("default"), - name = tableName, - schema = Seq.empty, - partitionColumns = Seq.empty, - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema" -> schema.json, - "EXTERNAL" -> "FALSE"), - tableType = ManagedTable, - serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(tableName))) - - catalog.client.createTable(hiveTable) - - invalidateTable(tableName) - val actualSchema = table(tableName).schema - assert(schema === actualSchema) - sql(s"drop table $tableName") + withTable(tableName) { + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + val hiveTable = HiveTable( + specifiedDatabase = Some("default"), + name = tableName, + schema = Seq.empty, + partitionColumns = Seq.empty, + properties = Map( + "spark.sql.sources.provider" -> "json", + "spark.sql.sources.schema" -> schema.json, + "EXTERNAL" -> "FALSE"), + tableType = ManagedTable, + serdeProperties = Map( + "path" -> catalog.hiveDefaultTableFilePath(tableName))) + + catalog.client.createTable(hiveTable) + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + } } test("Saving partition columns information") { - val df = - sparkContext.parallelize(1 to 10, 4).map { i => - Tuple4(i, i + 1, s"str$i", s"str${i + 1}") - }.toDF("a", "b", "c", "d") - + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") val tableName = s"partitionInfo_${System.currentTimeMillis()}" - df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) - invalidateTable(tableName) - val metastoreTable = catalog.client.getTable("default", tableName) - val expectedPartitionColumns = - StructType(df.schema("d") :: df.schema("b") :: Nil) - val actualPartitionColumns = - StructType( - metastoreTable.partitionColumns.map(c => - StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) - // Make sure partition columns are correctly stored in metastore. - assert( - expectedPartitionColumns.sameType(actualPartitionColumns), - s"Partitions columns stored in metastore $actualPartitionColumns is not the " + - s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.") - - // Check the content of the saved table. - checkAnswer( - table(tableName).selectExpr("c", "b", "d", "a"), - df.selectExpr("c", "b", "d", "a").collect()) - - sql(s"drop table $tableName") + + withTable(tableName) { + df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) + val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val actualPartitionColumns = + StructType( + metastoreTable.partitionColumns.map(c => + StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) + // Make sure partition columns are correctly stored in metastore. + assert( + expectedPartitionColumns.sameType(actualPartitionColumns), + s"Partitions columns stored in metastore $actualPartitionColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } } test("insert into a table") { - def createDF(from: Int, to: Int): DataFrame = - createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") + def createDF(from: Int, to: Int): DataFrame = { + (from to to).map(i => i -> s"str$i").toDF("c1", "c2") + } - createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 9).map(i => Row(i, s"str$i"))) + withTable("insertParquet") { + createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 9).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") - } + intercept[AnalysisException] { + createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + } - createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 19).map(i => Row(i, s"str$i"))) + createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), - (6 to 24).map(i => Row(i, s"str$i"))) + createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + (6 to 24).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(30, 39).write.saveAsTable("insertParquet") - } + intercept[AnalysisException] { + createDF(30, 39).write.saveAsTable("insertParquet") + } + + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + (6 to 34).map(i => Row(i, s"str$i"))) - createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), - (6 to 34).map(i => Row(i, s"str$i"))) - - createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), - (6 to 44).map(i => Row(i, s"str$i"))) - - createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), - (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (50 to 59).map(i => Row(i, s"str$i"))) - - createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (70 to 79).map(i => Row(i, s"str$i"))) + createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + (6 to 44).map(i => Row(i, s"str$i"))) + + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + (52 to 54).map(i => Row(i, s"str$i"))) + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (50 to 59).map(i => Row(i, s"str$i"))) + + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (70 to 79).map(i => Row(i, s"str$i"))) + } } } From db3fd054f240c7e38aba0732e471df65cd14011a Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 27 May 2015 14:21:00 -0700 Subject: [PATCH 030/254] [SPARK-7853] [SQL] Fixes a class loader issue in Spark SQL This PR is based on PR #6396 authored by chenghao-intel. Essentially, Spark SQL should use context classloader to load SerDe classes. yhuai helped updating the test case, and I fixed a bug in the original `CliSuite`: while testing the CLI tool with `runCliWithin`, we don't append `\n` to the last query, thus the last query is never executed. Original PR description is pasted below. ---- ``` bin/spark-sql --jars ./sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar CREATE TABLE t1(a string, b string) ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; ``` Throws exception like ``` 15/05/26 00:16:33 ERROR SparkSQLDriver: Failed in [CREATE TABLE t1(a string, b string) ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'] org.apache.spark.sql.execution.QueryExecutionException: FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. Cannot validate serde: org.apache.hive.hcatalog.data.JsonSerDe at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$runHive$1.apply(ClientWrapper.scala:333) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$runHive$1.apply(ClientWrapper.scala:310) at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:139) at org.apache.spark.sql.hive.client.ClientWrapper.runHive(ClientWrapper.scala:310) at org.apache.spark.sql.hive.client.ClientWrapper.runSqlHive(ClientWrapper.scala:300) at org.apache.spark.sql.hive.HiveContext.runSqlHive(HiveContext.scala:457) at org.apache.spark.sql.hive.execution.HiveNativeCommand.run(HiveNativeCommand.scala:33) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult$lzycompute(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.doExecute(commands.scala:68) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:88) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:88) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:148) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:87) at org.apache.spark.sql.SQLContext$QueryExecution.toRdd$lzycompute(SQLContext.scala:922) at org.apache.spark.sql.SQLContext$QueryExecution.toRdd(SQLContext.scala:922) at org.apache.spark.sql.DataFrame.(DataFrame.scala:147) at org.apache.spark.sql.DataFrame.(DataFrame.scala:131) at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:727) at org.apache.spark.sql.hive.thriftserver.AbstractSparkSQLDriver.run(AbstractSparkSQLDriver.scala:57) ``` Author: Cheng Hao Author: Cheng Lian Author: Yin Huai Closes #6435 from liancheng/classLoader and squashes the following commits: d4c4845 [Cheng Lian] Fixes CliSuite 75e80e2 [Yin Huai] Update the fix. fd26533 [Cheng Hao] scalastyle dd78775 [Cheng Hao] workaround for classloader of IsolatedClientLoader --- .../sql/hive/thriftserver/CliSuite.scala | 41 +++++++++++++++++-- .../apache/spark/sql/hive/HiveContext.scala | 18 ++++++-- .../apache/spark/sql/hive/TableReader.scala | 2 +- 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index b070fa8eaa469..cc07db827d359 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,11 +25,15 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.Logging import org.apache.spark.util.Utils +/** + * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary + * Hive metastore and warehouse. + */ class CliSuite extends FunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() @@ -58,13 +62,13 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --driver-class-path ${sys.props("java.class.path")} """.stripMargin.split("\\s+").toSeq ++ extraArgs } var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. + val queryStream = new ByteArrayInputStream(queries.map(_ + "\n").mkString.getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object @@ -124,7 +128,7 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "Time taken: " + -> "OK" ) } @@ -151,4 +155,33 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { -> "hive_test" ) } + + test("Commands using SerDe provided in --jars") { + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + runCliWithin(1.minute, Seq("--jars", s"$jarFile"))( + """CREATE TABLE t1(key string, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; + """.stripMargin + -> "OK", + "CREATE TABLE sourceTable (key INT, val STRING);" + -> "OK", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" + -> "OK", + "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" + -> "Time taken:", + "SELECT count(key) FROM t1;" + -> "5", + "DROP TABLE t1;" + -> "OK", + "DROP TABLE sourceTable;" + -> "OK" + ) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b64768ababef9..9ab98fdcce725 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.net.{URL, URLClassLoader} import java.sql.Timestamp import java.util.{ArrayList => JArrayList} @@ -25,7 +26,7 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} @@ -188,8 +189,19 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") } - val jars = getClass.getClassLoader match { - case urlClassLoader: java.net.URLClassLoader => urlClassLoader.getURLs + // We recursively add all jars in the class loader chain, + // starting from the given urlClassLoader. + def addJars(urlClassLoader: URLClassLoader): Array[URL] = { + val jarsInParent = urlClassLoader.getParent match { + case parent: URLClassLoader => addJars(parent) + case other => Array.empty[URL] + } + + urlClassLoader.getURLs ++ jarsInParent + } + + val jars = Utils.getContextOrSparkClassLoader match { + case urlClassLoader: URLClassLoader => addJars(urlClassLoader) case other => throw new IllegalArgumentException( "Unable to locate hive jars to connect to metastore " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 0b6f7a334a715..294fc3bd7d5e9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -79,7 +79,7 @@ class HadoopTableReader( makeRDDForTable( hiveTable, Class.forName( - relation.tableDesc.getSerdeClassName, true, Utils.getSparkClassLoader) + relation.tableDesc.getSerdeClassName, true, Utils.getContextOrSparkClassLoader) .asInstanceOf[Class[Deserializer]], filterOpt = None) From a1e092eae57172909ff2af06d8b461742595734c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 27 May 2015 18:51:36 -0700 Subject: [PATCH 031/254] [SPARK-7897][SQL] Use DecimalType to represent unsigned bigint in JDBCRDD JIRA: https://issues.apache.org/jira/browse/SPARK-7897 Author: Liang-Chi Hsieh Closes #6438 from viirya/jdbc_unsigned_bigint and squashes the following commits: ccb3c3f [Liang-Chi Hsieh] Use DecimalType to represent unsigned bigint. --- sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 244bd3ebfeb7e..88f1b02549e21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -53,7 +53,7 @@ private[sql] object JDBCRDD extends Logging { signed: Boolean): DataType = { val answer = sqlType match { case java.sql.Types.ARRAY => null - case java.sql.Types.BIGINT => LongType + case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType.Unlimited } case java.sql.Types.BINARY => BinaryType case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks case java.sql.Types.BLOB => BinaryType From 3c1f1baaf003d50786d3eee1e288f4bac69096f2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 27 May 2015 20:04:29 -0700 Subject: [PATCH 032/254] [SPARK-7907] [SQL] [UI] Rename tab ThriftServer to SQL. This PR has three changes: 1. Renaming the table of `ThriftServer` to `SQL`; 2. Renaming the title of the tab from `ThriftServer` to `JDBC/ODBC Server`; and 3. Renaming the title of the session page from `ThriftServer` to `JDBC/ODBC Session`. https://issues.apache.org/jira/browse/SPARK-7907 Author: Yin Huai Closes #6448 from yhuai/JDBCServer and squashes the following commits: eadcc3d [Yin Huai] Update test. 9168005 [Yin Huai] Use SQL as the tab name. 221831e [Yin Huai] Rename ThriftServer to JDBCServer. --- .../spark/sql/hive/thriftserver/ui/ThriftServerPage.scala | 4 ++-- .../sql/hive/thriftserver/ui/ThriftServerSessionPage.scala | 2 +- .../spark/sql/hive/thriftserver/ui/ThriftServerTab.scala | 4 +++- .../apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 6a2be4a58e5cb..7c48ff4b35df5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,7 +47,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" ++ generateSessionStatsTable() ++ generateSQLStatsTable() - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -143,7 +143,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/ThriftServer/session?id=%s" + val sessionLink = "%s/sql/session?id=%s" .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) {session.userName} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 33ba038ecce73..d9d66dcd8517e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -55,7 +55,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Total run {sessionStat._2.totalExecution} SQL ++ generateSQLStatsTable(sessionStat._2.sessionId) - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 343031f10c75c..94fd8a6bb60b9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,7 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. */ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "ThriftServer") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + + override val name = "SQL" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index a286dc5825f77..e1466e0423033 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -84,11 +84,11 @@ class UISeleniumSuite eventually(timeout(10 seconds), interval(50 milliseconds)) { go to baseURL - find(cssSelector("""ul li a[href*="ThriftServer"]""")) should not be None + find(cssSelector("""ul li a[href*="sql"]""")) should not be None } eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (baseURL + "/ThriftServer") + go to (baseURL + "/sql") find(id("sessionstat")) should not be None find(id("sqlstat")) should not be None From 852f4de2d3d0c5fff2fa66000a7a3088bb3dbe74 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 May 2015 20:19:53 -0700 Subject: [PATCH 033/254] [SPARK-7873] Allow KryoSerializerInstance to create multiple streams at the same time This is a somewhat obscure bug, but I think that it will seriously impact KryoSerializer users who use custom registrators which disabled auto-reset. When auto-reset is disabled, then this breaks things in some of our shuffle paths which actually end up creating multiple OutputStreams from the same shared SerializerInstance (which is unsafe). This was introduced by a patch (SPARK-3386) which enables serializer re-use in some of the shuffle paths, since constructing new serializer instances is actually pretty costly for KryoSerializer. We had already fixed another corner-case (SPARK-7766) bug related to this, but missed this one. I think that the root problem here is that KryoSerializerInstance can be used in a way which is unsafe even within a single thread, e.g. by creating multiple open OutputStreams from the same instance or by interleaving deserialize and deserializeStream calls. I considered a smaller patch which adds assertions to guard against this type of "misuse" but abandoned that approach after I realized how convoluted the Scaladoc became. This patch fixes this bug by making it legal to create multiple streams from the same KryoSerializerInstance. Internally, KryoSerializerInstance now implements a `borrowKryo()` / `releaseKryo()` API that's backed by a "pool" of capacity 1. Each call to a KryoSerializerInstance method will borrow the Kryo, do its work, then release the serializer instance back to the pool. If the pool is empty and we need an instance, it will allocate a new Kryo on-demand. This makes it safe for multiple OutputStreams to be opened from the same serializer. If we try to release a Kryo back to the pool but the pool already contains a Kryo, then we'll just discard the new Kryo. I don't think there's a clear benefit to having a larger pool since our usages tend to fall into two cases, a) where we only create a single OutputStream and b) where we create a huge number of OutputStreams with the same lifecycle, then destroy the KryoSerializerInstance (this is what's happening in the bypassMergeSort code path that my regression test hits). Author: Josh Rosen Closes #6415 from JoshRosen/SPARK-7873 and squashes the following commits: 00b402e [Josh Rosen] Initialize eagerly to fix a failing test ba55d20 [Josh Rosen] Add explanatory comments 3f1da96 [Josh Rosen] Guard against duplicate close() ab457ca [Josh Rosen] Sketch a loan/release based solution. 9816e8f [Josh Rosen] Add a failing test showing how deserialize() and deserializeStream() can interfere. 7350886 [Josh Rosen] Add failing regression test for SPARK-7873 --- .../spark/serializer/KryoSerializer.scala | 129 ++++++++++++++---- .../apache/spark/serializer/Serializer.scala | 5 + .../serializer/KryoSerializerSuite.scala | 37 ++++- 3 files changed, 147 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 217957963437d..3f909885dbd66 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,8 +17,9 @@ package org.apache.spark.serializer -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer +import javax.annotation.Nullable import scala.reflect.ClassTag @@ -136,21 +137,45 @@ class KryoSerializer(conf: SparkConf) } private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) +class KryoSerializationStream( + serInstance: KryoSerializerInstance, + outStream: OutputStream) extends SerializationStream { + + private[this] var output: KryoOutput = new KryoOutput(outStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - override def flush() { output.flush() } - override def close() { output.close() } + override def flush() { + if (output == null) { + throw new IOException("Stream is closed") + } + output.flush() + } + + override def close() { + if (output != null) { + try { + output.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + output = null + } + } + } } private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - private val input = new KryoInput(inStream) +class KryoDeserializationStream( + serInstance: KryoSerializerInstance, + inStream: InputStream) extends DeserializationStream { + + private[this] var input: KryoInput = new KryoInput(inStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def readObject[T: ClassTag](): T = { try { @@ -163,52 +188,105 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } override def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() + if (input != null) { + try { + // Kryo's Input automatically closes the input stream it is using. + input.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + input = null + } + } } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - private val kryo = ks.newKryo() - // Make these lazy vals to avoid creating a buffer unless we use them + /** + * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do + * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching + * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are + * not synchronized. + */ + @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + + /** + * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; + * otherwise, it allocates a new instance. + */ + private[serializer] def borrowKryo(): Kryo = { + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have been modified + // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } + } + + /** + * Release a borrowed [[Kryo]] instance. If this serializer instance already has a cached Kryo + * instance, then the given Kryo instance is discarded; otherwise, the Kryo is stored for later + * re-use. + */ + private[serializer] def releaseKryo(kryo: Kryo): Unit = { + if (cachedKryo == null) { + cachedKryo = kryo + } + } + + // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() private lazy val input = new KryoInput() override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() - kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766) + val kryo = borrowKryo() try { kryo.writeClassAndObject(output, t) } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + "increase spark.kryoserializer.buffer.max value.") + } finally { + releaseKryo(kryo) } ByteBuffer.wrap(output.toBytes) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] + val kryo = borrowKryo() + try { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + releaseKryo(kryo) + } } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + val kryo = borrowKryo() val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj + try { + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + kryo.setClassLoader(oldClassLoader) + releaseKryo(kryo) + } } override def serializeStream(s: OutputStream): SerializationStream = { - kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766) - new KryoSerializationStream(kryo, s) + new KryoSerializationStream(this, s) } override def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) + new KryoDeserializationStream(this, s) } /** @@ -218,7 +296,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ def getAutoReset(): Boolean = { val field = classOf[Kryo].getDeclaredField("autoReset") field.setAccessible(true) - field.get(kryo).asInstanceOf[Boolean] + val kryo = borrowKryo() + try { + field.get(kryo).asInstanceOf[Boolean] + } finally { + releaseKryo(kryo) + } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 6078c9d433ebf..f1bdff96d3df1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag @@ -114,8 +115,12 @@ object Serializer { /** * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. + * + * It is legal to create multiple serialization / deserialization streams from the same + * SerializerInstance as long as those streams are all used within the same thread. */ @DeveloperApi +@NotThreadSafe abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 8c384bd358ebc..ef50bc9438f95 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable import scala.reflect.ClassTag @@ -361,6 +361,41 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } +class KryoSerializerAutoResetDisabledSuite extends FunSuite with SharedSparkContext { + conf.set("spark.serializer", classOf[KryoSerializer].getName) + conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) + conf.set("spark.kryo.referenceTracking", "true") + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.sort.bypassMergeThreshold", "200") + + test("sort-shuffle with bypassMergeSort (SPARK-7873)") { + val myObject = ("Hello", "World") + assert(sc.parallelize(Seq.fill(100)(myObject)).repartition(2).collect().toSet === Set(myObject)) + } + + test("calling deserialize() after deserializeStream()") { + val serInstance = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] + assert(!serInstance.getAutoReset()) + val hello = "Hello" + val world = "World" + // Here, we serialize the same value twice, so the reference-tracking should cause us to store + // references to some of these values + val helloHello = serInstance.serialize((hello, hello)) + // Here's a stream which only contains one value + val worldWorld: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val serStream = serInstance.serializeStream(baos) + serStream.writeObject(world) + serStream.writeObject(world) + serStream.close() + baos.toByteArray + } + val deserializationStream = serInstance.deserializeStream(new ByteArrayInputStream(worldWorld)) + assert(deserializationStream.readValue[Any]() === world) + deserializationStream.close() + assert(serInstance.deserialize[Any](helloHello) === (hello, hello)) + } +} class ClassLoaderTestingObject From bd11b01ebaf62df8b0d8c0b63b51b66e58f50960 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 27 May 2015 22:23:22 -0700 Subject: [PATCH 034/254] [SPARK-7896] Allow ChainedBuffer to store more than 2 GB Author: Sandy Ryza Closes #6440 from sryza/sandy-spark-7896 and squashes the following commits: 49d8a0d [Sandy Ryza] Fix bug introduced when reading over record boundaries 6006856 [Sandy Ryza] Fix overflow issues 006b4b2 [Sandy Ryza] Fix scalastyle by removing non ascii characters 8b000ca [Sandy Ryza] Add ascii art to describe layout of data in metaBuffer f2053c0 [Sandy Ryza] Fix negative overflow issue 0368c78 [Sandy Ryza] Initialize size as 0 a5a4820 [Sandy Ryza] Use explicit types for all numbers in ChainedBuffer b7e0213 [Sandy Ryza] SPARK-7896. Allow ChainedBuffer to store more than 2 GB --- .../spark/util/collection/ChainedBuffer.scala | 46 +++++++++-------- .../PartitionedSerializedPairBuffer.scala | 51 +++++++++++-------- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index a60bffe611f14..516aaa44d03fc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -28,11 +28,13 @@ import scala.collection.mutable.ArrayBuffer * occupy a contiguous segment of memory. */ private[spark] class ChainedBuffer(chunkSize: Int) { - private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt - assert(math.pow(2, chunkSizeLog2).toInt == chunkSize, + + private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros( + java.lang.Long.highestOneBit(chunkSize)) + assert((1 << chunkSizeLog2) == chunkSize, s"ChainedBuffer chunk size $chunkSize must be a power of two") private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() - private var _size: Int = _ + private var _size: Long = 0 /** * Feed bytes from this buffer into a BlockObjectWriter. @@ -41,16 +43,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param os OutputStream to read into. * @param len Number of bytes to read. */ - def read(pos: Int, os: OutputStream, len: Int): Unit = { + def read(pos: Long, os: OutputStream, len: Int): Unit = { if (pos + len > _size) { throw new IndexOutOfBoundsException( s"Read of $len bytes at position $pos would go past size ${_size} of buffer") } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toRead = math.min(len - written, chunkSize - posInChunk) + val toRead: Int = math.min(len - written, chunkSize - posInChunk) os.write(chunks(chunkIndex), posInChunk, toRead) written += toRead chunkIndex += 1 @@ -66,16 +68,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param offs Offset in the byte array to read to. * @param len Number of bytes to read. */ - def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { if (pos + len > _size) { throw new IndexOutOfBoundsException( s"Read of $len bytes at position $pos would go past size of buffer") } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toRead = math.min(len - written, chunkSize - posInChunk) + val toRead: Int = math.min(len - written, chunkSize - posInChunk) System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) written += toRead chunkIndex += 1 @@ -91,22 +93,22 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param offs Offset in the byte array to write from. * @param len Number of bytes to write. */ - def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { if (pos > _size) { throw new IndexOutOfBoundsException( s"Write at position $pos starts after end of buffer ${_size}") } // Grow if needed - val endChunkIndex = (pos + len - 1) >> chunkSizeLog2 + val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt while (endChunkIndex >= chunks.length) { chunks += new Array[Byte](chunkSize) } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toWrite = math.min(len - written, chunkSize - posInChunk) + val toWrite: Int = math.min(len - written, chunkSize - posInChunk) System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) written += toWrite chunkIndex += 1 @@ -119,19 +121,19 @@ private[spark] class ChainedBuffer(chunkSize: Int) { /** * Total size of buffer that can be written to without allocating additional memory. */ - def capacity: Int = chunks.size * chunkSize + def capacity: Long = chunks.size.toLong * chunkSize /** * Size of the logical buffer. */ - def size: Int = _size + def size: Long = _size } /** * Output stream that writes to a ChainedBuffer. */ private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { - private var pos = 0 + private var pos: Long = 0 override def write(b: Int): Unit = { throw new UnsupportedOperationException() diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index ac9ea6393628f..554d88206e221 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -41,6 +41,13 @@ import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ * * Currently, only sorting by partition is supported. * + * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across + * two integers: + * + * +-------------+------------+------------+-------------+ + * | keyStart | keyValLen | partitionId | + * +-------------+------------+------------+-------------+ + * * @param metaInitialRecords The initial number of entries in the metadata buffer. * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. * @param serializerInstance the serializer used for serializing inserted records. @@ -68,19 +75,15 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( } val keyStart = kvBuffer.size - if (keyStart < 0) { - throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes") - } kvSerializationStream.writeKey[Any](key) - kvSerializationStream.flush() - val valueStart = kvBuffer.size kvSerializationStream.writeValue[Any](value) kvSerializationStream.flush() - val valueEnd = kvBuffer.size + val keyValLen = (kvBuffer.size - keyStart).toInt - metaBuffer.put(keyStart) - metaBuffer.put(valueStart) - metaBuffer.put(valueEnd) + // keyStart, a long, gets split across two ints + metaBuffer.put(keyStart.toInt) + metaBuffer.put((keyStart >> 32).toInt) + metaBuffer.put(keyValLen) metaBuffer.put(partition) } @@ -114,7 +117,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( } } - override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity + override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { @@ -128,10 +131,10 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( var pos = 0 def writeNext(writer: BlockObjectWriter): Unit = { - val keyStart = metaBuffer.get(pos + KEY_START) - val valueEnd = metaBuffer.get(pos + VAL_END) + val keyStart = getKeyStartPos(metaBuffer, pos) + val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, valueEnd - keyStart) + kvBuffer.read(keyStart, writer, keyValLen) writer.recordWritten() } def nextPartition(): Int = metaBuffer.get(pos + PARTITION) @@ -163,9 +166,11 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) extends InputStream { + import PartitionedSerializedPairBuffer._ + private var metaBufferPos = 0 private var kvBufferPos = - if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0 + if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) @@ -173,13 +178,14 @@ private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: Chained if (metaBufferPos >= metaBuffer.position) { return -1 } - val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos + val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - + (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt val toRead = math.min(bytesRemainingInRecord, len) kvBuffer.read(kvBufferPos, bytes, offs, toRead) if (toRead == bytesRemainingInRecord) { metaBufferPos += RECORD_SIZE if (metaBufferPos < metaBuffer.position) { - kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START) + kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) } } else { kvBufferPos += toRead @@ -246,9 +252,14 @@ private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuf } private[spark] object PartitionedSerializedPairBuffer { - val KEY_START = 0 - val VAL_START = 1 - val VAL_END = 2 + val KEY_START = 0 // keyStart, a long, gets split across two ints + val KEY_VAL_LEN = 2 val PARTITION = 3 - val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata + val RECORD_SIZE = PARTITION + 1 // num ints of metadata + + def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { + val lower32 = metaBuffer.get(metaBufferPos + KEY_START) + val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) + (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) + } } From 35410614deb7feea1c9d5cca00a6fa7970404f21 Mon Sep 17 00:00:00 2001 From: Matt Wise Date: Wed, 27 May 2015 22:39:19 -0700 Subject: [PATCH 035/254] [DOCS] Fix typo in documentation for Java UDF registration This contribution is my original work and I license the work to the project under the project's open source license Author: Matt Wise Closes #6447 from wisematthew/fix-typo-in-java-udf-registration-doc and squashes the following commits: e7ef5f7 [Matt Wise] Fix typo in documentation for Java UDF registration --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5b41c0ee6e430..ab646f65bb5eb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1939,7 +1939,7 @@ sqlContext.udf.register("strLen", (s: String) => s.length())
{% highlight java %} -sqlContext.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType); {% endhighlight %}
From e838a25bdb5603ef05e779225704c972ce436145 Mon Sep 17 00:00:00 2001 From: zuxqoj Date: Wed, 27 May 2015 23:13:13 -0700 Subject: [PATCH 036/254] [SPARK-7782] fixed sort arrow issue Current behaviour:: In spark UI ![screen shot 2015-05-27 at 3 27 51 pm](https://cloud.githubusercontent.com/assets/3919211/7837541/47d330ba-04a5-11e5-89d1-e5b11da1a513.png) In YARN ![screen shot 2015-05-27 at 3](https://cloud.githubusercontent.com/assets/3919211/7837594/aebd1d36-04a5-11e5-8216-86e03c07d2bd.png) In jira ![screen shot 2015-05-27 at 3_2](https://cloud.githubusercontent.com/assets/3919211/7837616/d3fedce2-04a5-11e5-9e68-960ed54e5d83.png) Author: zuxqoj Closes #6437 from zuxqoj/SPARK-7782_PR and squashes the following commits: cd068b9 [zuxqoj] [SPARK-7782] fixed sort arrow issue --- .../main/resources/org/apache/spark/ui/static/sorttable.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index dbacbf19beee5..dde6069000bc4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -100,7 +100,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortfwdind')); sortrevind = document.createElement('span'); sortrevind.id = "sorttable_sortrevind"; - sortrevind.innerHTML = stIsIE ? ' 5' : ' ▴'; + sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾'; this.appendChild(sortrevind); return; } @@ -113,7 +113,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortrevind')); sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); return; } @@ -134,7 +134,7 @@ sorttable = { this.className += ' sorttable_sorted'; sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); // build an array to sort. This is a Schwartzian transform thing, From 000df2f0d6af068bb188e81bbb207f0c2f43bf16 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 28 May 2015 09:04:12 -0700 Subject: [PATCH 037/254] [SPARK-7895] [STREAMING] [EXAMPLES] Move Kafka examples from scala-2.10/src to src Since `spark-streaming-kafka` now is published for both Scala 2.10 and 2.11, we can move `KafkaWordCount` and `DirectKafkaWordCount` from `examples/scala-2.10/src/` to `examples/src/` so that they will appear in `spark-examples-***-jar` for Scala 2.11. Author: zsxwing Closes #6436 from zsxwing/SPARK-7895 and squashes the following commits: c6052f1 [zsxwing] Update examples/pom.xml 0bcfa87 [zsxwing] Fix the sleep time b9d1256 [zsxwing] Move Kafka examples from scala-2.10/src to src --- examples/pom.xml | 44 +++---------------- .../streaming/JavaDirectKafkaWordCount.java | 0 .../streaming/JavaKafkaWordCount.java | 0 .../streaming/DirectKafkaWordCount.scala | 0 .../examples/streaming/KafkaWordCount.scala | 2 +- 5 files changed, 6 insertions(+), 40 deletions(-) rename examples/{scala-2.10 => }/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java (100%) rename examples/{scala-2.10 => }/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java (100%) rename examples/{scala-2.10 => }/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala (100%) rename examples/{scala-2.10 => }/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala (99%) diff --git a/examples/pom.xml b/examples/pom.xml index 5b04b4f8d6ca0..e4efee7b5e647 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -97,6 +97,11 @@ + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + org.apache.hbase hbase-testing-util @@ -392,45 +397,6 @@ - - - scala-2.10 - - !scala-2.11 - - - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-scala-sources - generate-sources - - add-source - - - - src/main/scala - scala-2.10/src/main/scala - scala-2.10/src/main/java - - - - - - - - diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala similarity index 100% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 99% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index f407367a54f6c..9ae1b045c2c76 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -96,7 +96,7 @@ object KafkaWordCountProducer { producer.send(message) } - Thread.sleep(100) + Thread.sleep(1000) } } From 530efe3e80c62b25c869b85167e00330eb1ddea6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 28 May 2015 12:03:46 -0700 Subject: [PATCH 038/254] [SPARK-7911] [MLLIB] A workaround for VectorUDT serialize (or deserialize) being called multiple times ~~A PythonUDT shouldn't be serialized into external Scala types in PythonRDD. I'm not sure whether this should fix one of the bugs related to SQL UDT/UDF in PySpark.~~ The fix above didn't work. So I added a workaround for this. If a Python UDF is applied to a Python UDT. This will put the Python SQL types as inputs. Still incorrect, but at least it doesn't throw exceptions on the Scala side. davies harsha2010 Author: Xiangrui Meng Closes #6442 from mengxr/SPARK-7903 and squashes the following commits: c257d2a [Xiangrui Meng] add a workaround for VectorUDT --- .../apache/spark/mllib/linalg/Vectors.scala | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index f6bcdf83cd337..2ffa497a99d93 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } override def serialize(obj: Any): Row = { - val row = new GenericMutableRow(4) obj match { case SparseVector(size, indices, values) => + val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, indices.toSeq) row.update(3, values.toSeq) + row case DenseVector(values) => + val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) row.update(3, values.toSeq) + row + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case row: Row => + row } - row } override def deserialize(datum: Any): Vector = { datum match { - // TODO: something wrong with UDT serialization - case v: Vector => - v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") @@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case v: Vector => + v } } From c771589c96403b2a518fb77d5162eca8f495f37b Mon Sep 17 00:00:00 2001 From: Li Yao Date: Thu, 28 May 2015 13:39:39 -0700 Subject: [PATCH 039/254] [MINOR] Fix the a minor bug in PageRank Example. Fix the bug that entering only 1 arg will cause array out of bounds exception in PageRank example. Author: Li Yao Closes #6455 from lastland/patch-1 and squashes the following commits: de06128 [Li Yao] Fix the bug that entering only 1 arg will cause array out of bounds exception. --- .../main/scala/org/apache/spark/examples/SparkPageRank.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 8d092b6506d33..bd7894f184c4c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -51,7 +51,7 @@ object SparkPageRank { showWarning() val sparkConf = new SparkConf().setAppName("PageRank") - val iters = if (args.length > 0) args(1).toInt else 10 + val iters = if (args.length > 1) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => From 3e312a5ed0154527c66eeeee0d2cc3bfce0a820e Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Thu, 28 May 2015 17:15:10 -0400 Subject: [PATCH 040/254] [DOCS] Fixing broken "IDE setup" link in the Building Spark documentation. The location of the IDE setup information has changed, so this just updates the link on the Building Spark page. Author: Mike Dusenberry Closes #6467 from dusenberrymw/Fix_Broken_Link_On_Building_Spark_Doc and squashes the following commits: 75c533a [Mike Dusenberry] Fixing broken "IDE setup" link in the Building Spark documentation by pointing to new location. --- docs/building-spark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 3ca7f2746e678..b2649d1ee2a53 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -176,7 +176,7 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m # Building Spark with IntelliJ IDEA or Eclipse For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-IDESetup). +[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IDESetup). # Running Java 8 Test Suites From 7859ab659eecbcf2d8b9a274a4e9e4f5186a528c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 28 May 2015 16:32:51 -0700 Subject: [PATCH 041/254] [SPARK-7198] [MLLIB] VectorAssembler should output ML attributes `VectorAssembler` should carry over ML attributes. For unknown attributes, we assume numeric values. This PR handles the following cases: 1. DoubleType with ML attribute: carry over 2. DoubleType without ML attribute: numeric value 3. Scalar type: numeric value 4. VectorType with all ML attributes: carry over and update names 5. VectorType with number of ML attributes: assume all numeric 6. VectorType without ML attributes: check the first row and get the number of attributes jkbradley Author: Xiangrui Meng Closes #6452 from mengxr/SPARK-7198 and squashes the following commits: a9d2469 [Xiangrui Meng] add space facdb1f [Xiangrui Meng] VectorAssembler should output ML attributes --- .../spark/ml/feature/VectorAssembler.scala | 51 +++++++++++++++++-- .../ml/feature/VectorAssemblerSuite.scala | 37 ++++++++++++++ 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 514ffb03c0509..229ee27ec5942 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} @@ -37,7 +38,7 @@ import org.apache.spark.sql.types._ class VectorAssembler(override val uid: String) extends Transformer with HasInputCols with HasOutputCol { - def this() = this(Identifiable.randomUID("va")) + def this() = this(Identifiable.randomUID("vecAssembler")) /** @group setParam */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -46,19 +47,59 @@ class VectorAssembler(override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + // Schema transformation. + val schema = dataset.schema + lazy val first = dataset.first() + val attrs = $(inputCols).flatMap { c => + val field = schema(c) + val index = schema.fieldIndex(c) + field.dataType match { + case DoubleType => + val attr = Attribute.fromStructField(field) + // If the input column doesn't have ML attribute, assume numeric. + if (attr == UnresolvedAttribute) { + Some(NumericAttribute.defaultAttr.withName(c)) + } else { + Some(attr.withName(c)) + } + case _: NumericType | BooleanType => + // If the input column type is a compatible scalar type, assume numeric. + Some(NumericAttribute.defaultAttr.withName(c)) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(field) + if (group.attributes.isDefined) { + // If attributes are defined, copy them with updated names. + group.attributes.get.map { attr => + if (attr.name.isDefined) { + // TODO: Define a rigorous naming scheme. + attr.withName(c + "_" + attr.name.get) + } else { + attr + } + } + } else { + // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes + // from metadata, check the first row. + val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) + Array.fill(numAttrs)(NumericAttribute.defaultAttr) + } + } + } + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + + // Data transformation. val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } - val schema = dataset.schema - val inputColNames = $(inputCols) - val args = inputColNames.map { c => + val args = $(inputCols).map { c => schema(c).dataType match { case DoubleType => dataset(c) case _: VectorUDT => dataset(c) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol))) + + dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index d0cd62c5e4864..43534e89928b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.feature import org.scalatest.FunSuite import org.apache.spark.SparkException +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { @@ -61,4 +63,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) } } + + test("ML attributes") { + val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") + val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) + val user = new AttributeGroup("user", Array( + NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), + NumericAttribute.defaultAttr.withName("salary"))) + val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) + val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + .select( + col("browser").as("browser", browser.toMetadata()), + col("hour").as("hour", hour.toMetadata()), + col("count"), // "count" is an integer column without ML attribute + col("user").as("user", user.toMetadata()), + col("ad")) // "ad" is a vector column without ML attribute + val assembler = new VectorAssembler() + .setInputCols(Array("browser", "hour", "count", "user", "ad")) + .setOutputCol("features") + val output = assembler.transform(df) + val schema = output.schema + val features = AttributeGroup.fromStructField(schema("features")) + assert(features.size === 7) + val browserOut = features.getAttr(0) + assert(browserOut === browser.withIndex(0).withName("browser")) + val hourOut = features.getAttr(1) + assert(hourOut === hour.withIndex(1).withName("hour")) + val countOut = features.getAttr(2) + assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2)) + val userGenderOut = features.getAttr(3) + assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) + val userSalaryOut = features.getAttr(4) + assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + } } From 0077af22ca5fcb2e50dcf7daa4f6804ae722bfbe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 16:56:59 -0700 Subject: [PATCH 042/254] Remove SizeEstimator from o.a.spark package. See comments on https://github.com/apache/spark/pull/3913 Author: Reynold Xin Closes #6471 from rxin/sizeestimator and squashes the following commits: c057095 [Reynold Xin] Fixed import. 2da478b [Reynold Xin] Remove SizeEstimator from o.a.spark package. --- .../org/apache/spark/SizeEstimator.scala | 44 ------------------- .../org/apache/spark/util/SizeEstimator.scala | 20 +++++++-- 2 files changed, 17 insertions(+), 47 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/SizeEstimator.scala diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala deleted file mode 100644 index 54fc3a856adfa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SizeEstimator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.apache.spark.annotation.DeveloperApi - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -@DeveloperApi -object SizeEstimator { - /** - * :: DeveloperApi :: - * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate - * includes space taken up by objects referenced by the given object, their references, and so on - * and so forth. - * - * This is useful for determining the amount of heap space a broadcast variable will occupy on - * each executor or the amount of space each object will take when caching objects in - * deserialized form. This is not the same as the serialized size of the object, which will - * typically be much smaller. - */ - @DeveloperApi - def estimate(obj: AnyRef): Long = org.apache.spark.util.SizeEstimator.estimate(obj) -} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 968a72d5adae9..f38949c3cb846 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -21,21 +21,37 @@ import java.lang.management.ManagementFactory import java.lang.reflect.{Field, Modifier} import java.util.{IdentityHashMap, Random} import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable.ArrayBuffer import scala.runtime.ScalaRunTime import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.OpenHashSet /** + * :: DeveloperApi :: * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in * memory-aware caches. * * Based on the following JavaWorld article: * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html */ -private[spark] object SizeEstimator extends Logging { +@DeveloperApi +object SizeEstimator extends Logging { + + /** + * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate + * includes space taken up by objects referenced by the given object, their references, and so on + * and so forth. + * + * This is useful for determining the amount of heap space a broadcast variable will occupy on + * each executor or the amount of space each object will take when caching objects in + * deserialized form. This is not the same as the serialized size of the object, which will + * typically be much smaller. + */ + def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) // Sizes of primitive types private val BYTE_SIZE = 1 @@ -161,8 +177,6 @@ private[spark] object SizeEstimator extends Logging { val shellSize: Long, val pointerFields: List[Field]) {} - def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) - private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = { val state = new SearchState(visited) state.enqueue(obj) From 572b62cafe4bc7b1d464c9dcfb449c9d53456826 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 28 May 2015 17:12:30 -0700 Subject: [PATCH 043/254] [SPARK-7853] [SQL] Fix HiveContext in Spark Shell https://issues.apache.org/jira/browse/SPARK-7853 This fixes the problem introduced by my change in https://github.com/apache/spark/pull/6435, which causes that Hive Context fails to create in spark shell because of the class loader issue. Author: Yin Huai Closes #6459 from yhuai/SPARK-7853 and squashes the following commits: 37ad33e [Yin Huai] Do not use hiveQlTable at all. 47cdb6d [Yin Huai] Move hiveconf.set to the end of setConf. 005649b [Yin Huai] Update comment. 35d86f3 [Yin Huai] Access TTable directly to make sure Hive will not internally use any metastore utility functions. 3737766 [Yin Huai] Recursively find all jars. --- .../apache/spark/sql/hive/HiveContext.scala | 35 ++++++++++--------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +++---- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9ab98fdcce725..2ed71d3d52880 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -189,24 +189,22 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") } - // We recursively add all jars in the class loader chain, - // starting from the given urlClassLoader. - def addJars(urlClassLoader: URLClassLoader): Array[URL] = { - val jarsInParent = urlClassLoader.getParent match { - case parent: URLClassLoader => addJars(parent) - case other => Array.empty[URL] - } - urlClassLoader.getURLs ++ jarsInParent + // We recursively find all jars in the class loader chain, + // starting from the given classLoader. + def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { + case null => Array.empty[URL] + case urlClassLoader: URLClassLoader => + urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) + case other => allJars(other.getParent) } - val jars = Utils.getContextOrSparkClassLoader match { - case urlClassLoader: URLClassLoader => addJars(urlClassLoader) - case other => - throw new IllegalArgumentException( - "Unable to locate hive jars to connect to metastore " + - s"using classloader ${other.getClass.getName}. " + - "Please set spark.sql.hive.metastore.jars") + val classLoader = Utils.getContextOrSparkClassLoader + val jars = allJars(classLoader) + if (jars.length == 0) { + throw new IllegalArgumentException( + "Unable to locate hive jars to connect to metastore. " + + "Please set spark.sql.hive.metastore.jars.") } logInfo( @@ -356,9 +354,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def setConf(key: String, value: String): Unit = { super.setConf(key, value) - hiveconf.set(key, value) executionHive.runSqlHive(s"SET $key=$value") metadataHive.runSqlHive(s"SET $key=$value") + // If users put any Spark SQL setting in the spark conf (e.g. spark-defaults.conf), + // this setConf will be called in the constructor of the SQLContext. + // Also, calling hiveconf will create a default session containing a HiveConf, which + // will interfer with the creation of executionHive (which is a lazy val). So, + // we put hiveconf.set at the end of this method. + hiveconf.set(key, value) } /* A catalyst metadata catalog that points to the Hive Metastore. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 425a4005aa2c3..95117f7a6847e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -707,20 +707,20 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { + implicit class SchemaAttribute(f: HiveColumn) { def toAttribute: AttributeReference = AttributeReference( - f.getName, - HiveMetastoreTypes.toDataType(f.getType), + f.name, + HiveMetastoreTypes.toDataType(f.hiveType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) } - // Must be a stable value since new attributes are born here. - val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) + /** PartitionKey attributes */ + val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = table.schema.map(_.toAttribute) val output = attributes ++ partitionKeys From 1bd63e82fdb6ee57c61051430d63685b801df016 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 28 May 2015 17:30:12 -0700 Subject: [PATCH 044/254] [SPARK-7577] [ML] [DOC] add bucketizer doc CC jkbradley Author: Xusen Yin Closes #6451 from yinxusen/SPARK-7577 and squashes the following commits: e2dc32e [Xusen Yin] rename colums e350e49 [Xusen Yin] add all demos 006ddf1 [Xusen Yin] add java test 3238481 [Xusen Yin] add bucketizer --- docs/ml-features.md | 86 +++++++++++++++++++ .../spark/ml/feature/JavaBucketizerSuite.java | 80 +++++++++++++++++ 2 files changed, 166 insertions(+) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java diff --git a/docs/ml-features.md b/docs/ml-features.md index efe9b3b8edb6e..d7851a55fabfe 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -789,6 +789,92 @@ scaledData = scalerModel.transform(dataFrame) +## Bucketizer + +`Bucketizer` transforms a column of continuous features to a column of feature buckets, where the buckets are specified by users. It takes a parameter: + +* `splits`: Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. Splits should be strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; Otherwise, values outside the splits specified will be treated as errors. Two examples of `splits` are `Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)` and `Array(0.0, 1.0, 2.0)`. + +Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potenial out of Bucketizer bounds exception. + +Note also that the splits that you provided have to be in strictly increasing order, i.e. `s0 < s1 < s2 < ... < sn`. + +More details can be found in the API docs for [Bucketizer](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer). + +The following example demonstrates how to bucketize a column of `Double`s into another index-wised column. + +
+
+{% highlight scala %} +import org.apache.spark.ml.feature.Bucketizer +import org.apache.spark.sql.DataFrame + +val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + +val data = Array(-0.5, -0.3, 0.0, 0.2) +val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + +val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + +// Transform original data into its bucket index. +val bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %} +
+ +
+{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + +JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame dataFrame = jsql.createDataFrame(data, schema); + +Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + +// Transform original data into its bucket index. +DataFrame bucketedData = bucketizer.transform(dataFrame); +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark.ml.feature import Bucketizer + +splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + +data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] +dataFrame = sqlContext.createDataFrame(data, ["features"]) + +bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + +# Transform original data into its bucket index. +bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %} +
+
# Feature Selectors diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java new file mode 100644 index 0000000000000..d5bd230a957a1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaBucketizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void bucketizerTest() { + double[] splits = {-0.5, 0.0, 0.5}; + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[] { + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits); + + Row[] result = bucketizer.transform(dataset).select("result").collect(); + + for (Row r : result) { + double index = r.getDouble(0); + Assert.assertTrue((index >= 0) && (index <= 1)); + } + } +} From 3af0b3136e4b7dea52c413d640653ccddc638574 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 17:55:22 -0700 Subject: [PATCH 045/254] [SPARK-7927] whitespace fixes for streaming. So we can enable a whitespace enforcement rule in the style checker to save code review time. Author: Reynold Xin Closes #6475 from rxin/whitespace-streaming and squashes the following commits: 810dae4 [Reynold Xin] Fixed tests. 89068ad [Reynold Xin] [SPARK-7927] whitespace fixes for streaming. --- .../org/apache/spark/streaming/StreamingContext.scala | 2 +- .../apache/spark/streaming/api/java/JavaPairDStream.scala | 8 ++++---- .../org/apache/spark/streaming/dstream/DStream.scala | 2 +- .../apache/spark/streaming/dstream/FileInputDStream.scala | 8 ++++---- .../spark/streaming/dstream/PairDStreamFunctions.scala | 2 +- .../spark/streaming/dstream/ReducedWindowedDStream.scala | 8 ++++---- .../apache/spark/streaming/dstream/ShuffledDStream.scala | 6 +++--- .../org/apache/spark/streaming/dstream/StateDStream.scala | 2 +- .../apache/spark/streaming/dstream/WindowedDStream.scala | 4 ++-- .../apache/spark/streaming/receiver/BlockGenerator.scala | 2 +- .../org/apache/spark/streaming/receiver/RateLimiter.scala | 3 ++- .../spark/streaming/scheduler/ReceiverTracker.scala | 2 +- .../spark/streaming/util/FileBasedWriteAheadLog.scala | 2 +- .../org/apache/spark/streaming/util/RawTextHelper.scala | 4 ++-- .../org/apache/spark/streaming/BasicOperationsSuite.scala | 6 +++--- .../org/apache/spark/streaming/InputStreamsSuite.scala | 2 +- .../apache/spark/streaming/StreamingContextSuite.scala | 4 +++- .../apache/spark/streaming/StreamingListenerSuite.scala | 6 ++++-- .../streaming/ui/StreamingJobProgressListenerSuite.scala | 2 +- 19 files changed, 40 insertions(+), 35 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 5e58ed714829e..25842d502543e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -461,7 +461,7 @@ class StreamingContext private[streaming] ( val conf = sc_.hadoopConfiguration conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( - directory, FileInputDStream.defaultFilter : Path => Boolean, newFilesOnly=true, conf) + directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) val data = br.map { case (k, v) => val bytes = v.getBytes require(bytes.length == recordLength, "Byte array does not have correct length. " + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 93baad19e3ee1..959ac9c177f81 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -227,7 +227,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param numPartitions Number of partitions of each RDD in the new DStream. */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - :JavaPairDStream[K, JIterable[V]] = { + : JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) .mapValues(asJavaIterable _) } @@ -247,7 +247,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ):JavaPairDStream[K, JIterable[V]] = { + ): JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) .mapValues(asJavaIterable _) } @@ -262,7 +262,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def reduceByKeyAndWindow(reduceFunc: JFunction2[V, V, V], windowDuration: Duration) - :JavaPairDStream[K, V] = { + : JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } @@ -281,7 +281,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration - ):JavaPairDStream[K, V] = { + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index c858647c6406d..6efcc193bfccc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -659,7 +659,7 @@ abstract class DStream[T: ClassTag] ( // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = context.sparkContext.clean(transformFunc, false) - val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { + val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 1) cleanedF(rdds.head.asInstanceOf[RDD[T]], time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index eca69f00188e4..6c1fab56740ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -69,7 +69,7 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * processing semantics are undefined. */ private[streaming] -class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( +class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( @transient ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, @@ -251,7 +251,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file =>{ + val fileRDDs = files.map { file => val rdd = serializableConfOpt.map(_.value) match { case Some(config) => context.sparkContext.newAPIHadoopFile( file, @@ -267,7 +267,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( "Refer to the streaming programming guide for more details.") } rdd - }) + } new UnionRDD(context.sparkContext, fileRDDs) } @@ -294,7 +294,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () + generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index fda22eb6ec42e..358e4c66df7ba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.StreamingContext.rddToFileName /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ -class PairDStreamFunctions[K, V](self: DStream[(K,V)]) +class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) extends Serializable { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index df9f7f140eddc..6a583bf2a3626 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -38,7 +38,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( _windowDuration: Duration, _slideDuration: Duration, partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { + ) extends DStream[(K, V)](parent.ssc) { require(_windowDuration.isMultipleOf(parent.slideDuration), "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " + @@ -58,7 +58,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(reducedStream) @@ -68,7 +68,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( override def parentRememberDuration: Duration = rememberDuration + windowDuration - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + override def persist(storageLevel: StorageLevel): DStream[(K, V)] = { super.persist(storageLevel) reducedStream.persist(storageLevel) this @@ -118,7 +118,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Get the RDD of the reduced value of the previous window val previousWindowRDD = - getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K, V)]())) // Make the list of RDDs that needs to cogrouped together for reducing their reduced values val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 7757ccac09a58..e0ffd5d86b435 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -25,19 +25,19 @@ import scala.reflect.ClassTag private[streaming] class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( - parent: DStream[(K,V)], + parent: DStream[(K, V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true - ) extends DStream[(K,C)] (parent.ssc) { + ) extends DStream[(K, C)] (parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - override def compute(validTime: Time): Option[RDD[(K,C)]] = { + override def compute(validTime: Time): Option[RDD[(K, C)]] = { parent.getOrCompute(validTime) match { case Some(rdd) => Some(rdd.combineByKey[C]( createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index de8718d0a80fe..621d6dff788f4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -51,7 +51,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { val i = iterator.map(t => { val itr = t._2._2.iterator - val headOption = if(itr.hasNext) Some(itr.next) else None + val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) }) updateFuncLocal(i) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 899865a906c27..4efba039f8959 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -44,7 +44,7 @@ class WindowedDStream[T: ClassTag]( // Persist parent level by default, as those RDDs are going to be obviously reused. parent.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(parent) @@ -68,7 +68,7 @@ class WindowedDStream[T: ClassTag]( new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow) } else { logDebug("Using normal union for windowing at " + validTime) - new UnionRDD(ssc.sc,rddsInWindow) + new UnionRDD(ssc.sc, rddsInWindow) } Some(windowRDD) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 4bebcc5aa7ca0..0588517a2de39 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -164,7 +164,7 @@ private[streaming] class BlockGenerator( private def keepPushingBlocks() { logInfo("Started block pushing thread") try { - while(!stopped) { + while (!stopped) { Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 97db9ded83367..8df542b367d27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.receiver +import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} + import org.apache.spark.{Logging, SparkConf} -import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. * diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f73f7e705ee0d..f1504b09c9873 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -230,7 +230,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false class ReceiverLauncher { @transient val env = ssc.env @volatile @transient private var running = false - @transient val thread = new Thread() { + @transient val thread = new Thread() { override def run() { try { SparkEnv.set(env) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 87ba4f84a9ceb..fe6328b1ce727 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -200,7 +200,7 @@ private[streaming] class FileBasedWriteAheadLog( /** Initialize the log directory or recover existing logs inside the directory */ private def initializeOrRecover(): Unit = synchronized { val logDirectoryPath = new Path(logDirectory) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 4d968f8bfa7a8..408936653c790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -27,7 +27,7 @@ object RawTextHelper { * Splits lines and counts the words. */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OpenHashMap[String,Long] + val map = new OpenHashMap[String, Long] var i = 0 var j = 0 while (iter.hasNext) { @@ -98,7 +98,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 1) { + for (i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index f269cb74e0c2b..08faeaa58f419 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -255,7 +255,7 @@ class BasicOperationsSuite extends TestSuiteBase { Seq( ) ) val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) + s1.map(x => (x, 1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) } testOperation(inputData1, inputData2, operation, outputData, true) } @@ -427,9 +427,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("updateStateByKey - object lifecycle") { val inputData = Seq( - Seq("a","b"), + Seq("a", "b"), null, - Seq("a","c","a"), + Seq("a", "c", "a"), Seq("c"), null, null diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 0122514f9374c..b74d67c63a788 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -418,7 +418,7 @@ class TestServer(portToBind: Int = 0) extends Logging { val servingThread = new Thread() { override def run() { try { - while(true) { + while (true) { logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() if (startLatch.getCount == 1) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index f8e8030791df1..e36c7914b130e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -732,7 +732,9 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) def onStop() { // Simulate slow receiver by waiting for all records to be produced - while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + while (!SlowTestReceiver.receivedAllRecords) { + Thread.sleep(100) + } // no clean to be done, the receiving thread should stop on it own } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 312cce408cfe7..1dc8960d60528 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -133,8 +133,10 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for(i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) return false + for (i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) { + return false + } } true } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 2a0f45830e03c..c9175d61b1f49 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -64,7 +64,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) From ee6a0e12fb76e4d5c24175900e5bf6a8cb35e2b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 18:08:56 -0700 Subject: [PATCH 046/254] [SPARK-7927] whitespace fixes for Hive and ThriftServer. So we can enable a whitespace enforcement rule in the style checker to save code review time. Author: Reynold Xin Closes #6478 from rxin/whitespace-hive and squashes the following commits: e01b0e0 [Reynold Xin] Fixed tests. a3bba22 [Reynold Xin] [SPARK-7927] whitespace fixes for Hive and ThriftServer. --- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 8 ++++---- .../sql/hive/thriftserver/ui/ThriftServerPage.scala | 6 +++--- .../hive/thriftserver/ui/ThriftServerSessionPage.scala | 2 +- .../spark/sql/hive/thriftserver/UISeleniumSuite.scala | 4 ++-- .../apache/spark/sql/hive/ExtendedHiveQlParser.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 4 ++-- .../org/apache/spark/sql/hive/HiveInspectors.scala | 10 +++++----- .../apache/spark/sql/hive/HiveMetastoreCatalog.scala | 10 +++++++--- .../main/scala/org/apache/spark/sql/hive/HiveQl.scala | 9 +++++---- .../spark/sql/hive/execution/InsertIntoHiveTable.scala | 7 +++---- .../sql/hive/execution/ScriptTransformation.scala | 2 +- .../scala/org/apache/spark/sql/hive/hiveUdfs.scala | 6 +++--- .../apache/spark/sql/hive/hiveWriterContainers.scala | 2 +- .../org/apache/spark/sql/hive/test/TestHive.scala | 6 +++--- 14 files changed, 43 insertions(+), 39 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index deb1008c468bf..14f6f658d9b75 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') - private var transport:TSocket = _ + private var transport: TSocket = _ installSignalHandler() @@ -276,13 +276,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { driver.init() val out = sessionState.out - val start:Long = System.currentTimeMillis() + val start: Long = System.currentTimeMillis() if (sessionState.getIsVerbose) { out.println(cmd) } val rc = driver.run(cmd) val end = System.currentTimeMillis() - val timeTaken:Double = (end - start) / 1000.0 + val timeTaken: Double = (end - start) / 1000.0 ret = rc.getResponseCode if (ret != 0) { @@ -310,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { res.clear() } } catch { - case e:IOException => + case e: IOException => console.printError( s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} |${org.apache.hadoop.util.StringUtils.stringifyException(e)} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 7c48ff4b35df5..10c83d8b27a2a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -77,7 +77,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} @@ -85,7 +85,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {info.groupId} {formatDate(info.startTimestamp)} - {if(info.finishTimestamp > 0) formatDate(info.finishTimestamp)} + {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} {formatDurationOption(Some(info.totalTime))} {info.statement} {info.state} @@ -150,7 +150,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {session.ip} {session.sessionId} {formatDate(session.startTimestamp)} - {if(session.finishTimestamp > 0) formatDate(session.finishTimestamp)} + {if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)} {formatDurationOption(Some(session.totalTime))} {session.totalExecution.toString} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index d9d66dcd8517e..3b01afa603cea 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -87,7 +87,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index e1466e0423033..4c9fab7ef6136 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -73,7 +73,7 @@ class UISeleniumSuite } ignore("thrift server ui test") { - withJdbcStatement(statement =>{ + withJdbcStatement { statement => val baseURL = s"http://localhost:$uiPort" val queries = Seq( @@ -97,6 +97,6 @@ class UISeleniumSuite findAll(cssSelector("""ul table tbody tr td""")).map(_.text).toList should contain (line) } } - }) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 3f20c6142e59a..7f8449cdc282d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -29,10 +29,10 @@ import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") + protected val JAR = Keyword("JAR") protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2ed71d3d52880..fbf2c7d8cbc06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -530,7 +530,7 @@ private[hive] object HiveContext { val propMap: HashMap[String, String] = HashMap() // We have to mask all properties in hive-site.xml that relates to metastore data source // as we used a local metastore here. - HiveConf.ConfVars.values().foreach { confvar => + HiveConf.ConfVars.values().foreach { confvar => if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { propMap.put(confvar.varname, confvar.defaultVal) } @@ -553,7 +553,7 @@ private[hive] object HiveContext { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 0a694c70e4e5c..24cd335082639 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -335,7 +335,7 @@ private[hive] trait HiveInspectors { val allRefs = si.getAllStructFieldRefs new GenericRow( allRefs.map(r => - unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray) } @@ -561,8 +561,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) } /** @@ -677,8 +677,8 @@ private[hive] trait HiveInspectors { getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name) :_*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) :_*)) + java.util.Arrays.asList(fields.map(_.name) : _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 95117f7a6847e..47b85731587d5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -546,13 +546,17 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = ??? + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + throw new UnsupportedOperationException + } /** * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + throw new UnsupportedOperationException + } override def unregisterAllTables(): Unit = {} } @@ -725,7 +729,7 @@ private[hive] case class MetastoreRelation val output = attributes ++ partitionKeys /** An attribute map that can be used to lookup original attributes based on expression id. */ - val attributeMap = AttributeMap(output.map(o => (o,o))) + val attributeMap = AttributeMap(output.map(o => (o, o))) /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2cbb5ca4d2e0c..3915ee835685f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -665,7 +665,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C HiveColumn(field.getName, field.getType, field.getComment) }) } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil)=> + case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => @@ -775,7 +775,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder + Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => @@ -1151,7 +1151,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Seq(false, false) => Inner }.toBuffer - val joinedTables = tables.reduceLeft(Join(_,_, Inner, None)) + val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) // Must be transform down. val joinedResult = joinedTables transform { @@ -1171,7 +1171,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // worth the number of hacks that will be required to implement it. Namely, we need to add // some sort of mapped star expansion that would expand all child output row to be similarly // named output expressions where some aggregate expression has been applied (i.e. First). - ??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + throw new UnsupportedOperationException case Token(allJoinTokens(joinToken), relation1 :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 7a6ca48b54a24..8613332186f28 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -194,10 +194,9 @@ case class InsertIntoHiveTable( if (partition.nonEmpty) { // loadPartition call orders directories created on the iteration order of the this map - val orderedPartitionSpec = new util.LinkedHashMap[String,String]() - table.hiveQlTable.getPartCols().foreach{ - entry=> - orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + val orderedPartitionSpec = new util.LinkedHashMap[String, String]() + table.hiveQlTable.getPartCols().foreach { entry => + orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index bfd26e0170c70..6f27a8626fc1e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -216,7 +216,7 @@ case class HiveScriptIOSchema ( val columnTypes = attrs.map { case aref: AttributeReference => aref.dataType case e: NamedExpression => e.dataType - case _ => null + case _ => null } (columns, columnTypes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7ec4f7332502e..bb116e3ab7de7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -315,7 +315,7 @@ private[hive] case class HiveWindowFunction( // The object inspector of values returned from the Hive window function. @transient - protected lazy val returnInspector = { + protected lazy val returnInspector = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -410,7 +410,7 @@ private[hive] case class HiveGenericUdaf( protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) @@ -443,7 +443,7 @@ private[hive] case class HiveUdaf( new GenericUDAFBridge(funcWrapper.createFunction()) @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 50b209f7ccbb8..2bb526b14be34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -71,7 +71,7 @@ private[hive] class SparkHiveWriterContainer( @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) @transient private lazy val outputFormat = - conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] def driverSideSetup() { setIDs(0, 0, 0) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 2e06cabfa80c9..7c7afc824d7a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -189,7 +189,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } - case class TestTable(name: String, commands: (()=>Unit)*) + case class TestTable(name: String, commands: (() => Unit)*) protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { @@ -253,8 +253,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { | 'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS - |INPUTFORMAT '${classOf[SequenceFileInputFormat[_,_]].getName}' - |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_,_]].getName}' + |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' + |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' """.stripMargin) runSqlHive( From 66c49ed60dcef48a6b38ae2d2c4c479933f3aa19 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 28 May 2015 19:04:32 -0700 Subject: [PATCH 047/254] [SPARK-7933] Remove Patrick's username/pw from merge script Looks like this was added by accident when pwendell merged a commit back in September: fe2b1d6a209db9fe96b1c6630677955b94bd48c9 Author: Kay Ousterhout Closes #6485 from kayousterhout/SPARK-7933 and squashes the following commits: 7c6164a [Kay Ousterhout] [SPARK-7933] Remove Patrick's username/pw from merge script --- dev/merge_spark_pr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 787c5cc8e892d..cd83b352c1bfb 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -44,9 +44,9 @@ # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") # ASF JIRA username -JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password -JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "35500") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" From 9b692bfdfcc91b32498865d21138cf215a378665 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 28 May 2015 19:05:12 -0700 Subject: [PATCH 048/254] [SPARK-7826] [CORE] Suppress extra calling getCacheLocs. There are too many extra call method `getCacheLocs` for `DAGScheduler`, which includes Akka communication. To improve `DAGScheduler` performance, suppress extra calling the method. In my application with over 1200 stages, the execution time became 3.8 min from 8.5 min with my patch. Author: Takuya UESHIN Closes #6352 from ueshin/issues/SPARK-7826 and squashes the following commits: 3d4d036 [Takuya UESHIN] Modify a test and the documentation. 10b1b22 [Takuya UESHIN] Simplify the unit test. d858b59 [Takuya UESHIN] Move the storageLevel check inside the if (!cacheLocs.contains(rdd.id)) block. 6f3125c [Takuya UESHIN] Fix scalastyle. b9c835c [Takuya UESHIN] Put the condition that checks if the RDD has uncached partition or not into variable for readability. f87f2ec [Takuya UESHIN] Get cached locations from block manager only if the storage level of the RDD is not StorageLevel.NONE. 8248386 [Takuya UESHIN] Revert "Suppress extra calling getCacheLocs." a4d944a [Takuya UESHIN] Add an unit test. 9a80fad [Takuya UESHIN] Suppress extra calling getCacheLocs. --- .../apache/spark/scheduler/DAGScheduler.scala | 15 +++++--- .../spark/scheduler/DAGSchedulerSuite.scala | 35 ++++++++++++++++--- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a083be2448aa3..a2299e907c5ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -193,9 +193,15 @@ class DAGScheduler( def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms => - bms.map(bm => TaskLocation(bm.host, bm.executorId)) + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + Seq.fill(rdd.partitions.size)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } } cacheLocs(rdd.id) = locs } @@ -382,7 +388,8 @@ class DAGScheduler( def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6a8ae29aae675..46642236e454a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -318,7 +318,7 @@ class DAGSchedulerSuite } test("cache location preferences w/ dependency") { - val baseRdd = new MyRDD(sc, 1, Nil) + val baseRdd = new MyRDD(sc, 1, Nil).cache() val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) cacheLocations(baseRdd.id -> 0) = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) @@ -331,7 +331,7 @@ class DAGSchedulerSuite } test("regression test for getCacheLocs") { - val rdd = new MyRDD(sc, 3, Nil) + val rdd = new MyRDD(sc, 3, Nil).cache() cacheLocations(rdd.id -> 0) = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) cacheLocations(rdd.id -> 1) = @@ -342,6 +342,33 @@ class DAGSchedulerSuite assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD"))) } + /** + * This test ensures that if a particular RDD is cached, RDDs earlier in the dependency chain + * are not computed. It constructs the following chain of dependencies: + * +---+ shuffle +---+ +---+ +---+ + * | A |<--------| B |<---| C |<---| D | + * +---+ +---+ +---+ +---+ + * Here, B is derived from A by performing a shuffle, C has a one-to-one dependency on B, + * and D similarly has a one-to-one dependency on C. If none of the RDDs were cached, this + * set of RDDs would result in a two stage job: one ShuffleMapStage, and a ResultStage that + * reads the shuffled data from RDD A. This test ensures that if C is cached, the scheduler + * doesn't perform a shuffle, and instead computes the result using a single ResultStage + * that reads C's cached data. + */ + test("getMissingParentStages should consider all ancestor RDDs' cache statuses") { + val rddA = new MyRDD(sc, 1, Nil) + val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, null))) + val rddC = new MyRDD(sc, 1, List(new OneToOneDependency(rddB))).cache() + val rddD = new MyRDD(sc, 1, List(new OneToOneDependency(rddC))) + cacheLocations(rddC.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + submit(rddD, Array(0)) + assert(scheduler.runningStages.size === 1) + // Make sure that the scheduler is running the final result stage. + // Because C is cached, the shuffle map stage to compute A does not need to be run. + assert(scheduler.runningStages.head.isInstanceOf[ResultStage]) + } + test("avoid exponential blowup when getting preferred locs list") { // Build up a complex dependency graph with repeated zip operations, without preferred locations var rdd: RDD[_] = new MyRDD(sc, 1, Nil) @@ -678,9 +705,9 @@ class DAGSchedulerSuite } test("cached post-shuffle") { - val shuffleOneRdd = new MyRDD(sc, 2, Nil) + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) From 04616b1a2f5244710b07ecbb404384ded893292c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 28 May 2015 20:09:12 -0700 Subject: [PATCH 049/254] [SPARK-7927] [MLLIB] Enforce whitespace for more tokens in style checker rxin Author: Xiangrui Meng Closes #6481 from mengxr/mllib-scalastyle and squashes the following commits: 3ca4d61 [Xiangrui Meng] revert scalastyle config 30961ba [Xiangrui Meng] adjust spaces in mllib/test 571b5c5 [Xiangrui Meng] fix spaces in mllib --- .../spark/ml/classification/OneVsRest.scala | 10 +++++----- .../scala/org/apache/spark/ml/tree/Node.scala | 4 ++-- .../mllib/api/python/PythonMLLibAPI.scala | 8 ++++---- .../mllib/clustering/GaussianMixture.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 6 +++--- .../org/apache/spark/mllib/feature/IDF.scala | 2 +- .../spark/mllib/feature/StandardScaler.scala | 2 +- .../apache/spark/mllib/feature/Word2Vec.scala | 6 +++--- .../mllib/linalg/distributed/RowMatrix.scala | 2 +- .../apache/spark/mllib/random/RandomRDDs.scala | 2 +- .../mllib/regression/IsotonicRegression.scala | 2 +- .../stat/MultivariateOnlineSummarizer.scala | 2 +- .../spark/mllib/stat/test/ChiSqTest.scala | 2 +- .../apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- .../mllib/tree/GradientBoostedTrees.scala | 12 ++++++------ .../apache/spark/mllib/tree/RandomForest.scala | 2 +- .../apache/spark/mllib/tree/model/Node.scala | 8 ++++---- .../spark/mllib/util/MFDataGenerator.scala | 7 +++---- .../spark/ml/feature/Word2VecSuite.scala | 6 +++--- .../spark/ml/tuning/CrossValidatorSuite.scala | 12 +++++++++--- .../mllib/api/python/PythonMLLibAPISuite.scala | 2 +- .../mllib/classification/NaiveBayesSuite.scala | 2 +- .../spark/mllib/classification/SVMSuite.scala | 18 +++++++++--------- .../spark/mllib/clustering/KMeansSuite.scala | 4 ++-- .../PowerIterationClusteringSuite.scala | 2 ++ .../evaluation/RegressionMetricsSuite.scala | 4 ++-- .../mllib/feature/StandardScalerSuite.scala | 2 +- .../linalg/distributed/BlockMatrixSuite.scala | 2 ++ .../optimization/GradientDescentSuite.scala | 2 +- .../spark/mllib/optimization/NNLSSuite.scala | 2 ++ .../spark/mllib/regression/LassoSuite.scala | 10 ++++++---- .../spark/mllib/stat/CorrelationSuite.scala | 6 +++++- .../apache/spark/mllib/util/MLUtilsSuite.scala | 2 +- 33 files changed, 88 insertions(+), 71 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 36735cd834cd4..b8c7f3c5bc3b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -70,7 +70,7 @@ private[ml] trait OneVsRestParams extends PredictorParams { final class OneVsRestModel private[ml] ( override val uid: String, labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_,_]]) + val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { override def transformSchema(schema: StructType): StructType = { @@ -104,17 +104,17 @@ final class OneVsRestModel private[ml] ( // add temporary column to store intermediate scores and update val tmpColName = "mbc$tmp" + UUID.randomUUID().toString - val update: (Map[Int, Double], Vector) => Map[Int, Double] = + val update: (Map[Int, Double], Vector) => Map[Int, Double] = (predictions: Map[Int, Double], prediction: Vector) => { predictions + ((index, prediction(1))) } val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) - val transformedDataset = model.transform(df).select(columns:_*) + val transformedDataset = model.transform(df).select(columns : _*) val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName) + updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) } if (handlePersistence) { @@ -190,7 +190,7 @@ final class OneVsRest(override val uid: String) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier classifier.fit(trainingDataset, classifier.labelCol -> labelColName) - }.toArray[ClassificationModel[_,_]] + }.toArray[ClassificationModel[_, _]] if (handlePersistence) { multiclassLabeled.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 6a84176efb086..4242154be14ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -159,9 +159,9 @@ final class InternalNode private[ml] ( override private[tree] def subtreeToString(indentFactor: Int = 0): String = { val prefix: String = " " * indentFactor - prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" + leftChild.subtreeToString(indentFactor + 1) + - prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" + rightChild.subtreeToString(indentFactor + 1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2fa54df6fc2b2..65f30fdba7393 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -392,7 +392,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[Vector], wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Vector] = { + si: Array[Object]): RDD[Vector] = { val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) @@ -428,7 +428,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -459,7 +459,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -1242,7 +1242,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index c88410ac0ff43..e9a23e40cc790 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -211,7 +211,7 @@ class GaussianMixture private ( private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d))) } // compute cluster contributions for each input point diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 6fa2fe053c6a4..8e5154b902d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -273,7 +273,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Default: 1024, following the original Online LDA paper. */ def setTau0(tau0: Double): this.type = { - require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") + require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") this.tau0 = tau0 this } @@ -339,7 +339,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { override private[clustering] def initialize( docs: RDD[(Long, Vector)], - lda: LDA): OnlineLDAOptimizer = { + lda: LDA): OnlineLDAOptimizer = { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * uses digamma which is accurate but expensive. */ private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { - val rowSum = sum(alpha(breeze.linalg.*, ::)) + val rowSum = sum(alpha(breeze.linalg.*, ::)) val digAlpha = digamma(alpha) val digRowSum = digamma(rowSum) val result = digAlpha(::, breeze.linalg.*) - digRowSum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index a89eea0e21be2..efbfeb4059f5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -144,7 +144,7 @@ private object IDF { * Since arrays are initialized to 0 by default, * we just omit changing those entries. */ - if(df(j) >= minDocFreq) { + if (df(j) >= minDocFreq) { inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) } j += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 6ae6917eae595..c73b8f258060d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -90,7 +90,7 @@ class StandardScalerModel ( @DeveloperApi def setWithMean(withMean: Boolean): this.type = { - require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") this.withMean = withMean this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 9106b73dfcd76..466ae95859b82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -49,7 +49,7 @@ private case class VocabWord( var cn: Int, var point: Array[Int], var code: Array[Int], - var codeLen:Int + var codeLen: Int ) /** @@ -469,7 +469,7 @@ class Word2VecModel private[mllib] ( val norm1 = blas.snrm2(n, v1, 1) val norm2 = blas.snrm2(n, v2, 1) if (norm1 == 0 || norm2 == 0) return 0.0 - blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 } override protected def formatVersion = "1.0" @@ -500,7 +500,7 @@ class Word2VecModel private[mllib] ( */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector,num) + findSynonyms(vector, num) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 9a89a6f3a515f..1626da9c3d2ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -219,7 +219,7 @@ class RowMatrix( val computeMode = mode match { case "auto" => - if(k > 5000) { + if (k > 5000) { logWarning(s"computing svd with k=$k and n=$n, please check necessity") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 8341bb86afd71..7db5a14fd45a5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -52,7 +52,7 @@ object RandomRDDs { numPartitions: Int = 0, seed: Long = Utils.random.nextLong()): RDD[Double] = { val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 3ea63dd8c0acd..96e50faca2b19 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -203,7 +203,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = loadMetadata(sc, path) - val isotonic = (metadata \ "isotonic").extract[Boolean] + val isotonic = (metadata \ "isotonic").extract[Boolean] val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 0b1755613aac4..d321cc554c1cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -70,7 +70,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - val localCurrMean= currMean + val localCurrMean = currMean val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index e597fce2babd1..23c8d7c7c8075 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -196,7 +196,7 @@ private[stat] object ChiSqTest extends Logging { * Pearson's independence test on the input contingency matrix. * TODO: optimize for SparseMatrix when it becomes supported. */ - def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = { val method = methodFromString(methodName) val numRows = counts.numRows val numCols = counts.numCols diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index dfe3a0b6913ef..cecd1fed896d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -169,7 +169,7 @@ object DecisionTree extends Serializable with Logging { numClasses: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).run(input) @@ -768,7 +768,7 @@ object DecisionTree extends Serializable with Logging { */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) val predict = calculatePredict(parentNodeAgg) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1f779584dcffd..e3ddc7053693c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,12 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) + case Regression => + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, - remappedInput, boostingStrategy, validate=false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -93,8 +93,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost( - input, validationInput, boostingStrategy, validate=true) + case Regression => + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -102,7 +102,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate=true) + validate = true) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index b347c450c1aa8..99d0e3cf2fd6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -249,7 +249,7 @@ private class RandomForest ( try { nodeIdCache.get.deleteAllCheckpoints() } catch { - case e:IOException => + case e: IOException => logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 431a839817eac..ee710fc1ed299 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -151,9 +151,9 @@ class Node ( s"(feature ${split.feature} > ${split.threshold})" } case Categorical => if (left) { - s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})" } else { - s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})" } } } @@ -161,9 +161,9 @@ class Node ( if (isLeaf) { prefix + s"Predict: ${predict.predict}\n" } else { - prefix + s"If ${splitToString(split.get, left=true)}\n" + + prefix + s"If ${splitToString(split.get, left = true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + - prefix + s"Else ${splitToString(split.get, left=false)}\n" + + prefix + s"Else ${splitToString(split.get, left = false)}\n" + rightNode.get.subtreeToString(indentFactor + 1) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 0c5b4f9d04a74..bd73a866c8a82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -82,8 +82,7 @@ object MFDataGenerator { BLAS.gemm(z, A, B, 1.0, fullData) val df = rank * (m + n - rank) - val sampSize = scala.math.min(scala.math.round(trainSampFact * df), - scala.math.round(.99 * m * n)).toInt + val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt val rand = new Random() val mn = m * n val shuffled = rand.shuffle((0 until mn).toList) @@ -102,8 +101,8 @@ object MFDataGenerator { // optionally generate testing data if (test) { - val testSampSize = scala.math - .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt + val testSampSize = math.min( + math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 43a09cc418703..df446d0c22015 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -35,9 +35,9 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val codes = Map( - "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), - "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), - "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) ) val expected = doc.map { sentence => diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 65972ec79b9a5..60d8bfe38fb13 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -90,14 +90,20 @@ object CrossValidatorSuite { override def validateParams(): Unit = require($(inputCol).nonEmpty) - override def fit(dataset: DataFrame): MyModel = ??? + override def fit(dataset: DataFrame): MyModel = { + throw new UnsupportedOperationException + } - override def transformSchema(schema: StructType): StructType = ??? + override def transformSchema(schema: StructType): StructType = { + throw new UnsupportedOperationException + } } class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = ??? + override def evaluate(dataset: DataFrame): Double = { + throw new UnsupportedOperationException + } override val uid: String = "eval" } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index a629dba8a426f..3d362b5ee53ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -84,7 +84,7 @@ class PythonMLLibAPISuite extends FunSuite { val smt = new SparseMatrix( 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), - isTransposed=true) + isTransposed = true) val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix] assert(smt.toArray === nsmt.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index c111a78a55806..ea40b41bbbe5e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -163,7 +163,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { val theta = Array( Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 - Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 6de098b383ba3..90f9cec6855bf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -46,7 +46,7 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => @@ -91,7 +91,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -117,7 +117,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -127,8 +127,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -145,7 +145,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val initialB = -1.0 val initialC = -1.0 @@ -159,8 +159,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD, initialWeights) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -177,7 +177,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) val testRDDInvalid = testRDD.map { lp => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0f2b26d462ad2..877e6dc699523 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -75,7 +75,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { val center = Vectors.dense(1.0, 2.0, 3.0) // Make sure code runs. - var model = KMeans.train(data, k=2, maxIterations=1) + var model = KMeans.train(data, k = 2, maxIterations = 1) assert(model.clusterCenters.size === 2) } @@ -87,7 +87,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { 2) // Make sure code runs. - var model = KMeans.train(data, k=3, maxIterations=1) + var model = KMeans.train(data, k = 3, maxIterations = 1) assert(model.clusterCenters.size === 3) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6d6fe6fe46bab..556842f3129a3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -94,11 +94,13 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext */ val similarities = Seq[(Long, Long, Double)]( (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + // scalastyle:off val expected = Array( Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + // scalastyle:on val w = normalize(sc.parallelize(similarities, 2)) w.edges.collect().foreach { case Edge(i, j, x) => assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 670b4c34e6095..3aa732474ec2e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -26,7 +26,7 @@ class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( - Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, "explained variance regression score mismatch") @@ -39,7 +39,7 @@ class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( - Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, "explained variance regression score mismatch") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 7f94564b2a3ae..1eb991869de40 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -360,7 +360,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } withClue("model needs std and mean vectors to be equal size when both are provided") { intercept[IllegalArgumentException] { - val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 949d1c9939570..a58336175899c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -57,11 +57,13 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { val random = new ju.Random() // This should generate a 4x4 grid of 1x2 blocks. val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + // scalastyle:off val expected0 = Array( Array(0, 0, 4, 4, 8, 8, 12), Array(1, 1, 5, 5, 9, 9, 13), Array(2, 2, 6, 6, 10, 10, 14), Array(3, 3, 7, 7, 11, 11, 15)) + // scalastyle:on for (i <- 0 until 4; j <- 0 until 7) { assert(part0.getPartition((i, j)) === expected0(i)(j)) assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 86481c6e66200..e110506d579b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -42,7 +42,7 @@ object GradientDescentSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index 22855e4e8f247..bb723fc471181 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -68,12 +68,14 @@ class NNLSSuite extends FunSuite { test("NNLS: nonnegativity constraint active") { val n = 5 + // scalastyle:off val ata = new DoubleMatrix(Array( Array( 4.377, -3.531, -1.306, -0.139, 3.418), Array(-3.531, 4.344, 0.934, 0.305, -2.140), Array(-1.306, 0.934, 2.644, -0.203, -0.170), Array(-0.139, 0.305, -0.203, 5.883, 1.428), Array( 3.418, -2.140, -0.170, 1.428, 4.684))) + // scalastyle:on val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index c9f5dc069ef2e..71dce50922991 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -67,11 +67,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -110,11 +111,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData,2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index d20a09b4b4925..a7e6fce31ff7e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -96,11 +96,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { val X = sc.parallelize(data) val defaultMat = Statistics.corr(X) val pearsonMat = Statistics.corr(X, "pearson") + // scalastyle:off val expected = BDM( (1.00000000, 0.05564149, Double.NaN, 0.4004714), (0.05564149, 1.00000000, Double.NaN, 0.9135959), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), - (0.40047142, 0.91359586, Double.NaN,1.0000000)) + (0.40047142, 0.91359586, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(defaultMat.toBreeze, expected)) assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) } @@ -108,11 +110,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { test("corr(X) spearman") { val X = sc.parallelize(data) val spearmanMat = Statistics.corr(X, "spearman") + // scalastyle:off val expected = BDM( (1.0000000, 0.1054093, Double.NaN, 0.4000000), (0.1054093, 1.0000000, Double.NaN, 0.9486833), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), (0.4000000, 0.9486833, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 668fc1d43c5d6..cdece2c174be4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -168,7 +168,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { "Each training+validation set combined should contain all of the data.") } // K fold cross validation should only have each element in the validation set exactly once - assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === + assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted === data.collect().sorted) } } From ff44c711abc7ca545dfa1e836279c00fe7539c18 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 20:10:21 -0700 Subject: [PATCH 050/254] [SPARK-7927] whitespace fixes for SQL core. So we can enable a whitespace enforcement rule in the style checker to save code review time. Author: Reynold Xin Closes #6477 from rxin/whitespace-sql-core and squashes the following commits: ce6e369 [Reynold Xin] Fixed tests. 6095fed [Reynold Xin] [SPARK-7927] whitespace fixes for SQL core. --- .../scala/org/apache/spark/sql/Column.scala | 4 +- .../org/apache/spark/sql/DataFrame.scala | 18 ++++---- .../apache/spark/sql/DataFrameHolder.scala | 2 +- .../org/apache/spark/sql/GroupedData.scala | 10 ++--- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/SparkSQLParser.scala | 18 ++++---- .../columnar/InMemoryColumnarTableScan.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 7 +-- .../joins/BroadcastLeftSemiJoinHash.scala | 2 +- .../sql/execution/stat/FrequentItems.scala | 4 +- .../org/apache/spark/sql/functions.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 44 ++++++++++--------- .../apache/spark/sql/json/InferSchema.scala | 2 +- .../spark/sql/json/JacksonGenerator.scala | 10 ++--- .../org/apache/spark/sql/json/JsonRDD.scala | 8 ++-- .../spark/sql/parquet/ParquetConverter.scala | 12 ++--- .../sql/parquet/ParquetTableOperations.scala | 4 +- .../spark/sql/parquet/ParquetTypes.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 8 ++-- .../spark/sql/ColumnExpressionSuite.scala | 14 +++--- .../spark/sql/DataFrameAggregateSuite.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 38 ++++++++-------- .../org/apache/spark/sql/JoinSuite.scala | 8 ++-- .../apache/spark/sql/ListTablesSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 42 +++++++++--------- .../sql/ScalaReflectionRelationSuite.scala | 4 +- .../scala/org/apache/spark/sql/TestData.scala | 4 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 6 +-- .../compression/DictionaryEncodingSuite.scala | 4 +- .../compression/IntegralDeltaSuite.scala | 4 +- .../compression/RunLengthEncodingSuite.scala | 10 ++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 +- .../org/apache/spark/sql/json/JsonSuite.scala | 2 +- .../spark/sql/sources/DDLTestSuite.scala | 5 +-- .../spark/sql/sources/FilteredScanSuite.scala | 2 +- 37 files changed, 160 insertions(+), 158 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6895aa1010956..b49b1d327289f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -349,7 +349,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def when(condition: Column, value: Any):Column = this.expr match { + def when(condition: Column, value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) case _ => @@ -378,7 +378,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def otherwise(value: Any):Column = this.expr match { + def otherwise(value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => if (branches.size % 2 == 0) { CaseWhen(branches :+ lit(value).expr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f968577bc5848..e90109446b642 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -255,7 +255,7 @@ class DataFrame private[sql]( val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } - select(newCols :_*) + select(newCols : _*) } /** @@ -500,7 +500,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): DataFrame = { - sort((sortCol +: sortCols).map(apply) :_*) + sort((sortCol +: sortCols).map(apply) : _*) } /** @@ -531,7 +531,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*) + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) /** * Returns a new [[DataFrame]] sorted by the given expressions. @@ -540,7 +540,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*) + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) /** * Selects column based on the column name and return it as a [[Column]]. @@ -611,7 +611,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -825,7 +825,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs :_*) + groupBy().agg(aggExpr, aggExprs : _*) } /** @@ -863,7 +863,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function @@ -1039,7 +1039,7 @@ class DataFrame private[sql]( val name = field.name if (resolver(name, colName)) col.as(colName) else Column(name) } - select(colNames :_*) + select(colNames : _*) } else { select(Column("*"), col.as(colName)) } @@ -1262,7 +1262,7 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) /** * Returns the number of rows in the [[DataFrame]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index b87efb58d51e5..2f19ec0403017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -28,5 +28,5 @@ private[sql] case class DataFrameHolder(df: DataFrame) { // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = df - def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*) + def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index f730e4ae00e2b..516ba2ac23371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -247,7 +247,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } /** @@ -259,7 +259,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Max) + aggregateNumericColumns(colNames : _*)(Max) } /** @@ -271,7 +271,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } /** @@ -283,7 +283,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Min) + aggregateNumericColumns(colNames : _*)(Min) } /** @@ -295,6 +295,6 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Sum) + aggregateNumericColumns(colNames : _*)(Sum) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 15c30352bee69..a32897c20b474 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -298,7 +298,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args :_*)) + new ColumnName(sc.s(args : _*)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index 6b1ae81972e4e..305b306a79871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -54,15 +54,15 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr } } - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val CLEAR = Keyword("CLEAR") - protected val IN = Keyword("IN") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SHOW = Keyword("SHOW") - protected val TABLE = Keyword("TABLE") - protected val TABLES = Keyword("TABLES") + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val CLEAR = Keyword("CLEAR") + protected val IN = Keyword("IN") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val SHOW = Keyword("SHOW") + protected val TABLE = Keyword("TABLE") + protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index a59d42cdd6028..3db26fad2b92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -236,7 +236,7 @@ private[sql] case class InMemoryColumnarTableScan( case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l - case IsNull(a: Attribute) => statsFor(a).nullCount > 0 + case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3e46596ecf6ac..f25d10fec0411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -296,7 +296,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ .sliding(2) .map { case Seq(a) => true - case Seq(a,b) => a compatibleWith b + case Seq(a, b) => a.compatibleWith(b) }.exists(!_) // Adds Exchange or Sort operators as required diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3f6a0345bc17d..d0a1ad00560d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -243,8 +243,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case (predicate, None) => predicate // Filter needs to be applied above when it contains partitioning // columns - case (predicate, _) if(!predicate.references.map(_.name).toSet - .intersect (partitionColNames).isEmpty) => predicate + case (predicate, _) + if !predicate.references.map(_.name).toSet.intersect(partitionColNames).isEmpty => + predicate } } } else { @@ -270,7 +271,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, identity[Seq[Expression]], // All filters still need to be evaluated. - InMemoryColumnarTableScan(_, filters, mem)) :: Nil + InMemoryColumnarTableScan(_, filters, mem)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 640fc26ba3baa..a32e5fc4f7ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -39,7 +39,7 @@ case class BroadcastLeftSemiJoinHash( override def output: Seq[Attribute] = left.output protected override def doExecute(): RDD[Row] = { - val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator + val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 5ae7e107544f8..fe8a81e3d0434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -89,7 +89,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { @@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging { } ) val justItems = freqItems.map(m => m.baseMap.keys.toSeq) - val resultRow = Row(justItems:_*) + val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9a23cfb89ca12..6dc17bbb2e768 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -187,7 +187,7 @@ object functions { */ @scala.annotation.varargs def countDistinct(columnName: String, columnNames: String*): Column = - countDistinct(Column(columnName), columnNames.map(Column.apply) :_*) + countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) /** * Aggregate function: returns the approximate number of distinct items in a group. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 88f1b02549e21..0bdb68e8ac845 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -52,6 +52,7 @@ private[sql] object JDBCRDD extends Logging { scale: Int, signed: Boolean): DataType = { val answer = sqlType match { + // scalastyle:off case java.sql.Types.ARRAY => null case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType.Unlimited } case java.sql.Types.BINARY => BinaryType @@ -92,7 +93,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType - case _ => null + case _ => null + // scalastyle:on } if (answer == null) throw new SQLException("Unsupported type " + sqlType) @@ -323,19 +325,19 @@ private[sql] class JDBCRDD( */ def getConversions(schema: StructType): Array[JDBCConversion] = { schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion + case BooleanType => BooleanConversion + case DateType => DateConversion case DecimalType.Unlimited => DecimalConversion(None) - case DecimalType.Fixed(d) => DecimalConversion(Some(d)) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => + case DecimalType.Fixed(d) => DecimalConversion(Some(d)) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case _ => throw new IllegalArgumentException(s"Unsupported field $sf") }).toArray } @@ -376,8 +378,8 @@ private[sql] class JDBCRDD( while (i < conversions.length) { val pos = i + 1 conversions(i) match { - case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - case DateConversion => + case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + case DateConversion => // DateUtils.fromJavaDate does not handle null value, so we need to check it. val dateVal = rs.getDate(pos) if (dateVal != null) { @@ -407,14 +409,14 @@ private[sql] class JDBCRDD( } else { mutableRow.update(i, Decimal(decimalVal)) } - case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) - case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) - case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) - case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) + case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) + case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) + case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) - case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) + case StringConversion => mutableRow.setString(i, rs.getString(pos)) + case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) + case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) case BinaryLongConversion => { val bytes = rs.getBytes(pos) var ans = 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 9c58b8e4bb16a..06aa19ef09bd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -124,7 +124,7 @@ private[sql] object InferSchema { case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case ArrayType(struct: StructType, containsNull) => ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) + case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala index 80bf74aa02602..325f54b6808a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala @@ -33,7 +33,7 @@ private[sql] object JacksonGenerator { */ def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = { def valWriter: (DataType, Any) => Unit = { - case (_, null) | (NullType, _) => gen.writeNull() + case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v: String) => gen.writeString(v) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) @@ -48,16 +48,16 @@ private[sql] object JacksonGenerator { case (DateType, v) => gen.writeString(v.toString) case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) - case (ArrayType(ty, _), v: Seq[_] ) => + case (ArrayType(ty, _), v: Seq[_]) => gen.writeStartArray() - v.foreach(valWriter(ty,_)) + v.foreach(valWriter(ty, _)) gen.writeEndArray() - case (MapType(kv,vv, _), v: Map[_,_]) => + case (MapType(kv, vv, _), v: Map[_, _]) => gen.writeStartObject() v.foreach { p => gen.writeFieldName(p._1.toString) - valWriter(vv,p._2) + valWriter(vv, p._2) } gen.writeEndObject() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 037a6d60a2ed6..95eb1174b1dd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -141,7 +141,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case ArrayType(struct: StructType, containsNull) => ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) + case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } StructField(fieldName, newType, nullable) @@ -216,7 +216,7 @@ private[sql] object JsonRDD extends Logging { case map: Map[_, _] => StructType(Nil) // We have an array of arrays. If those element arrays do not have the same // element types, we will return ArrayType[StringType]. - case seq: Seq[_] => typeOfArray(seq) + case seq: Seq[_] => typeOfArray(seq) case value => typeOfPrimitiveValue(value) } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) @@ -406,7 +406,7 @@ private[sql] object JsonRDD extends Logging { } } - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = { if (value == null) { null } else { @@ -434,7 +434,7 @@ private[sql] object JsonRDD extends Logging { } } - private def asRow(json: Map[String,Any], schema: StructType): Row = { + private def asRow(json: Map[String, Any], schema: StructType): Row = { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 36cb5e03bbca7..1b4196ab0be35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -480,7 +480,7 @@ private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverte override def hasDictionarySupport: Boolean = true - override def setDictionary(dictionary: Dictionary):Unit = + override def setDictionary(dictionary: Dictionary): Unit = dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = @@ -591,8 +591,8 @@ private[parquet] class CatalystArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter @@ -601,7 +601,7 @@ private[parquet] class CatalystArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { // fieldIndex is ignored (assumed to be zero but not checked) - if(value == null) { + if (value == null) { throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") } buffer += value @@ -654,8 +654,8 @@ private[parquet] class CatalystNativeArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 90950f924a054..cb7ae246d0d75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -541,7 +541,7 @@ private[parquet] class FilteringParquetRowInputFormat val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val filter: Filter = ParquetInputFormat.getFilter(configuration) var rowGroupsDropped: Long = 0 - var totalRowGroups: Long = 0 + var totalRowGroups: Long = 0 // Ugly hack, stuck with it until PR: // https://github.com/apache/incubator-parquet-mr/pull/17 @@ -664,7 +664,7 @@ private[parquet] object FileSystemHelper { s"ParquetTableOperations: path $path does not exist or is not a directory") } fs.globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .map(_.getPath) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 1dc819b5d7b9b..6698b19c7477d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -489,7 +489,7 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs .globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .filterNot { status => val name = status.getPath.getName (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index ca30b8e74626f..22587f5a1c6f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -130,7 +130,7 @@ private[sql] class DDLParser( } } - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" /* * describe [extended] table avroTable @@ -138,7 +138,7 @@ private[sql] class DDLParser( */ protected lazy val describeTable: Parser[LogicalPlan] = (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => + case e ~ db ~ tbl => val tblIdentifier = db match { case Some(dbName) => Seq(dbName, tbl) @@ -171,7 +171,7 @@ private[sql] class DDLParser( } protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k,v) } + optionName ~ stringLit ^^ { case k ~ v => (k, v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => @@ -239,7 +239,7 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(schema, partitionColumns)) } - val caseInsensitiveOptions= new CaseInsensitiveMap(options) + val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 9bdf201b3be7c..d006b83fc075a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -28,14 +28,14 @@ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ test("single explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select(explode('intList)), Row(1) :: Row(2) :: Row(3) :: Nil) } test("explode and other columns") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select($"a", explode('intList)), @@ -45,13 +45,13 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer( df.select($"*", explode('intList)), - Row(1, Seq(1,2,3), 1) :: - Row(1, Seq(1,2,3), 2) :: - Row(1, Seq(1,2,3), 3) :: Nil) + Row(1, Seq(1, 2, 3), 1) :: + Row(1, Seq(1, 2, 3), 2) :: + Row(1, Seq(1, 2, 3), 3) :: Nil) } test("aliased explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select(explode('intList).as('int)).select('int), @@ -79,7 +79,7 @@ class ColumnExpressionSuite extends QueryTest { } test("self join explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") val exploded = df.select(explode('intList).as('i)) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 35a574f354741..232f05c00918f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -148,12 +148,12 @@ class DataFrameAggregateSuite extends QueryTest { test("null count") { checkAnswer( testData3.groupBy('a).agg(count('b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( testData3.groupBy('a).agg(count('a + 'b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0dcba80ef2a20..a4fd1058afce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -59,7 +59,7 @@ class DataFrameSuite extends QueryTest { } test("rename nested groupby") { - val df = Seq((1,(1,1))).toDF() + val df = Seq((1, (1, 1))).toDF() checkAnswer( df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"), @@ -211,23 +211,23 @@ class DataFrameSuite extends QueryTest { test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( testData2.orderBy(asc("a"), desc("b")), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( arrayData.toDF().orderBy('data.getItem(0).asc), @@ -331,7 +331,7 @@ class DataFrameSuite extends QueryTest { checkAnswer( df, testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key","value")) + assert(df.schema.map(_.name) === Seq("key", "value")) } test("withColumnRenamed") { @@ -364,24 +364,24 @@ class DataFrameSuite extends QueryTest { test("describe") { val describeTestData = Seq( - ("Bob", 16, 176), + ("Bob", 16, 176), ("Alice", 32, 164), ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + ("Amy", 24, 180)).toDF("name", "age", "height") val describeResult = Seq( - Row("count", "4", "4"), - Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), - Row("min", "16", "164"), - Row("max", "60", "192")) + Row("count", "4", "4"), + Row("mean", "33.0", "178.0"), + Row("stddev", "16.583123951777", "10.0"), + Row("min", "16", "164"), + Row("max", "60", "192")) val emptyDescribeResult = Seq( - Row("count", "0", "0"), - Row("mean", null, null), - Row("stddev", null, null), - Row("min", null, null), - Row("max", null, null)) + Row("count", "0", "0"), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 037d392c1f929..407c789657834 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -167,10 +167,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where($"a" === 1).as("y") checkAnswer( x.join(y).where($"x.a" === $"y.a"), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index f9f41eb358bd5..3ce97c3fffdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -28,7 +28,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.test.TestSQLContext.implicits._ val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7c47fe454b6dc..bf18bf854aa4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -53,7 +53,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("self join with aliases") { - Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") checkAnswer( sql( @@ -76,7 +76,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("self join with alias in agg") { - Seq(1,2,3) + Seq(1, 2, 3) .map(i => (i, i.toString)) .toDF("int", "str") .groupBy("str") @@ -113,7 +113,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) + Seq(1, 1, 2, 2, 3, 3).map(Row(_)) ) } @@ -354,7 +354,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq(Row(3,1), Row(3,2)) + Seq(Row(3, 1), Row(3, 2)) ) } @@ -371,16 +371,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1,3), Row(2,3), Row(3,3))) + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } test("literal in agg grouping expressions") { checkAnswer( sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) } test("aggregates with nulls") { @@ -405,19 +405,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), @@ -552,7 +552,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq(Row(2147483645.0,1), Row(2.0,2))) + Seq(Row(2147483645.0, 1), Row(2.0, 2))) } test("count") { @@ -619,10 +619,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = y.a""".stripMargin), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil) } test("inner join, no matches") { @@ -1266,22 +1266,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) + Seq(2, 3, 3, 4, 4, 5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)) ) } test("Supporting relational operator '<=>' in Spark SQL") { - val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil + val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") - val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil + val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + @@ -1290,7 +1290,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("Multi-column COUNT(DISTINCT ...)") { - val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 3fa00fd9d0ccb..52d265b445e14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -80,14 +80,14 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) + new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), - new Timestamp(12345), Seq(1,2,3))) + new Timestamp(12345), Seq(1, 2, 3))) } test("query case class RDD with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 8fbc2d23d47e6..725a18bfae3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -109,8 +109,8 @@ object TestData { case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: - ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) arrayData.toDF().registerTempTable("arrayData") case class MapData(data: scala.collection.Map[Int, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d615542ab50a7..1a9ba66416b21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -38,7 +38,7 @@ class UDFSuite extends QueryTest { } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_:Int)) + udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 1e105e259dce7..061efb37a0ac3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -73,7 +73,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) @@ -167,7 +167,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val serializer = new SparkSqlSerializer(conf).newInstance() val buffer = ByteBuffer.allocate(512) - val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val obj = CustomClass(Int.MaxValue, Long.MaxValue) val serializedObj = serializer.serialize(obj).array() GENERIC.append(serializer.serialize(obj).array(), buffer) @@ -278,7 +278,7 @@ private[columnar] object CustomerSerializer extends Serializer[CustomClass] { override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { val a = input.readInt() val b = input.readLong() - CustomClass(a,b) + CustomClass(a, b) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index 64b70552eb047..cef60ec204faa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends FunSuite { - testDictionaryEncoding(new IntColumnStats, INT) - testDictionaryEncoding(new LongColumnStats, LONG) + testDictionaryEncoding(new IntColumnStats, INT) + testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) def testDictionaryEncoding[T <: AtomicType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index bfd99f143bedc..5514590541dd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends FunSuite { - testIntegralDelta(new IntColumnStats, INT, IntDelta) + testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( @@ -116,7 +116,7 @@ class IntegralDeltaSuite extends FunSuite { test(s"$scheme: simple case") { val input = columnType match { - case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index fde7a4595be0e..6ee48f6291914 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) - testRunLengthEncoding(new ByteColumnStats, BYTE) - testRunLengthEncoding(new ShortColumnStats, SHORT) - testRunLengthEncoding(new IntColumnStats, INT) - testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new ByteColumnStats, BYTE) + testRunLengthEncoding(new ShortColumnStats, SHORT) + testRunLengthEncoding(new IntColumnStats, INT) + testRunLengthEncoding(new LongColumnStats, LONG) + testRunLengthEncoding(new StringColumnStats, STRING) def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 347f28351fd72..30279f528944b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -429,8 +429,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) - assert(agg.getCatalystType(0,"",1,null) == Some(LongType)) - assert(agg.getCatalystType(1,"",1,null) == Some(StringType)) + assert(agg.getCatalystType(0, "", 1, null) == Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) == Some(StringType)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 7e6eeba17752a..f8d62f9e7e02b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -522,7 +522,7 @@ class JsonSuite extends QueryTest { Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: Row(null, """{"field":false}""", null, null, "{}") :: Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: - Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil + Row(Seq(7), "{}", """["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index f5106f67a08df..5c3467158a01b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -43,7 +43,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), StructField("decimalType", DecimalType.Unlimited, nullable = false), - StructField("fixedDecimalType", DecimalType(5,1), nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), StructField("smallIntType", ShortType, nullable = false), @@ -51,8 +51,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("mapType", MapType(StringType, StringType)), StructField("arrayType", ArrayType(StringType)), StructField("structType", - StructType(StructField("f1",StringType) :: - (StructField("f2",IntegerType)) :: Nil + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil ) ) )) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index cce747e7dbf64..db94b1f3e8926 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -154,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2))) + Seq(1, 3, 5).map(i => Row(i, i * 2))) sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE A = 1", From 2881d14cbedc14f1cd8ae5078446dba1a8d39086 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 20:11:04 -0700 Subject: [PATCH 051/254] [SPARK-7929] Remove Bagel examples & whitespace fix for examples. Author: Reynold Xin Closes #6480 from rxin/whitespace-example and squashes the following commits: 8a4a3d4 [Reynold Xin] [SPARK-7929] Remove Bagel examples & whitespace fix for examples. --- .../spark/examples/CassandraCQLTest.scala | 6 +- .../org/apache/spark/examples/LocalLR.scala | 2 +- .../org/apache/spark/examples/SparkALS.scala | 2 +- .../org/apache/spark/examples/SparkLR.scala | 2 +- .../spark/examples/bagel/PageRankUtils.scala | 112 --------- .../examples/bagel/WikipediaPageRank.scala | 106 -------- .../bagel/WikipediaPageRankStandalone.scala | 232 ------------------ .../spark/examples/ml/OneVsRestExample.scala | 6 +- .../examples/mllib/DenseGaussianMixture.scala | 2 +- .../pythonconverters/AvroConverters.scala | 29 ++- .../examples/streaming/ActorWordCount.scala | 6 +- .../streaming/DirectKafkaWordCount.scala | 2 +- .../examples/streaming/KafkaWordCount.scala | 4 +- .../examples/streaming/MQTTWordCount.scala | 2 +- .../clickstream/PageViewGenerator.scala | 3 +- 15 files changed, 31 insertions(+), 485 deletions(-) delete mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 11d5c92c5952d..023bb3ee2d108 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -104,8 +104,8 @@ object CassandraCQLTest { val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[CqlPagingInputFormat], - classOf[java.util.Map[String,ByteBuffer]], - classOf[java.util.Map[String,ByteBuffer]]) + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.Map[String, ByteBuffer]]) println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { @@ -118,7 +118,7 @@ object CassandraCQLTest { case (productId, saleCount) => println(productId + ":" + saleCount) } - val casoutputCF = aggregatedRDD.map { + val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) val outKey: java.util.Map[String, ByteBuffer] = outColFamKey diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index a55e0dc8d36c2..c3fc74a116c0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -39,7 +39,7 @@ object LocalLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 6c0ac8013ce34..30c4261551837 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -117,7 +117,7 @@ object SparkALS { var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users - val Rc = sc.broadcast(R) + val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 8c01a60844620..1e6b4fb0c7514 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -44,7 +44,7 @@ object SparkLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index ab6e63deb3c95..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.bagel._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) { - self.outEdges.map(targetId => new PRMessage(targetId, newValue / self.outEdges.size)) - } else { - Array[PRMessage]() - } - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double) - (self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int) - : (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)" - .format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions: Int = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } - - override def hashCode: Int = numPartitions -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 859abedf2a55e..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ - -import org.apache.spark.bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println( - "Usage: WikipediaPageRank ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.setAppName("WikipediaPageRank") - sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage])) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRank") - val sc = new SparkContext(sparkConf) - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) { - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache() - } else { - vertices = vertices.cache() - } - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect().mkString) - println(top) - - sc.stop() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index 576a3e371b993..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.xml.{XML, NodeSeq} - -import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import scala.reflect.ClassTag - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRankStandalone " + - " ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRankStandalone") - - val sc = new SparkContext(sparkConf) - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) { - input.map(parseArticle _).partitionBy(partitioner).cache() - } else { - input.map(parseArticle _).cache() - } - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, - sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - sc.stop() - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapperIterable, rankWrapperIterable)) => - val linksWrapper = linksWrapperIterable.iterator - val rankWrapper = rankWrapperIterable.iterator - if (linksWrapper.hasNext) { - val linksWrapperHead = linksWrapper.next - if (rankWrapper.hasNext) { - val rankWrapperHead = rankWrapper.next - linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size)) - } else { - linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends org.apache.spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T: ClassTag](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T: ClassTag](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index b99d0a1246011..6927eb8f275cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -73,7 +73,7 @@ object OneVsRestExample { .action((x, c) => c.copy(fracTest = x)) opt[String]("testInput") .text("input path to test dataset. If given, option fracTest is ignored") - .action((x,c) => c.copy(testInput = Some(x))) + .action((x, c) => c.copy(testInput = Some(x))) opt[Int]("maxIter") .text(s"maximum number of iterations for Logistic Regression." + s" default: ${defaultParams.maxIter}") @@ -88,10 +88,10 @@ object OneVsRestExample { .action((x, c) => c.copy(fitIntercept = x)) opt[Double]("regParam") .text(s"the regularization parameter for Logistic Regression.") - .action((x,c) => c.copy(regParam = Some(x))) + .action((x, c) => c.copy(regParam = Some(x))) opt[Double]("elasticNetParam") .text(s"the ElasticNet mixing parameter for Logistic Regression.") - .action((x,c) => c.copy(elasticNetParam = Some(x))) + .action((x, c) => c.copy(elasticNetParam = Some(x))) checkConfig { params => if (params.fracTest < 0 || params.fracTest >= 1) { failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index df76b45e50810..9a1aab036aa0f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -40,7 +40,7 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) + val ctx = new SparkContext(conf) val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index a11890d6f2b1c..3ebb112fc069e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -36,22 +36,21 @@ object AvroConversionUtil extends Serializable { return null } schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException(s"Unknown Avro schema type ${other.getName}") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 92867b44be138..016de4c63d1d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -104,10 +104,8 @@ extends Actor with ActorHelper { object FeederActor { def main(args: Array[String]) { - if(args.length < 2){ - System.err.println( - "Usage: FeederActor \n" - ) + if (args.length < 2){ + System.err.println("Usage: FeederActor \n") System.exit(1) } val Seq(host, port) = args.toSeq diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 11a8cf09533ce..fbe394de4a179 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -51,7 +51,7 @@ object DirectKafkaWordCount { // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 9ae1b045c2c76..60416ee343544 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -49,10 +49,10 @@ object KafkaWordCount { val Array(zkQuorum, group, topics, numThreads) = args val sparkConf = new SparkConf().setAppName("KafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") - val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap + val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1L)) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 85b9a54b40baf..b336751d81616 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -49,7 +49,7 @@ object MQTTPublisher { client.connect() - val msgtopic = client.getTopic(topic) + val msgtopic = client.getTopic(topic) val msgContent = "hello mqtt demo for spark streaming" val message = new MqttMessage(msgContent.getBytes("utf-8")) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 54d996b8ac990..889f052c70263 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -57,8 +57,7 @@ object PageViewGenerator { 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - + val userID = Map((1 to 100).map(_ -> .01) : _*) def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { val rand = new Random().nextDouble() From 8da560d7de9b3c9a3e3ff197eeb10a3d7023f10d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 20:11:57 -0700 Subject: [PATCH 052/254] [SPARK-7927] whitespace fixes for Catalyst module. So we can enable a whitespace enforcement rule in the style checker to save code review time. Author: Reynold Xin Closes #6476 from rxin/whitespace-catalyst and squashes the following commits: 650409d [Reynold Xin] Fixed tests. 51a9e5d [Reynold Xin] [SPARK-7927] whitespace fixes for Catalyst module. --- .../sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../apache/spark/sql/catalyst/SqlParser.scala | 8 +- .../sql/catalyst/analysis/Analyzer.scala | 9 +- .../spark/sql/catalyst/analysis/Catalog.scala | 2 +- .../catalyst/analysis/HiveTypeCoercion.scala | 5 +- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../spark/sql/catalyst/errors/package.scala | 7 -- .../spark/sql/catalyst/expressions/Cast.scala | 84 +++++++++---------- .../catalyst/expressions/ExtractValue.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/GenerateProjection.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 4 +- .../expressions/stringOperations.scala | 4 +- .../expressions/windowExpressions.scala | 2 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../plans/logical/basicOperators.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 4 +- .../spark/sql/catalyst/util/package.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../sql/catalyst/ScalaReflectionSuite.scala | 2 +- .../spark/sql/catalyst/SqlParserSuite.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../ExpressionEvaluationSuite.scala | 56 ++++++------- .../GeneratedEvaluationSuite.scala | 6 +- .../GeneratedMutableEvaluationSuite.scala | 2 +- .../BooleanSimplificationSuite.scala | 4 +- .../optimizer/FilterPushdownSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 6 +- .../optimizer/UnionPushdownSuite.scala | 4 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 4 +- .../spark/sql/types/DataTypeSuite.scala | 4 +- 32 files changed, 121 insertions(+), 130 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 2eb3e167baad5..ef7b3ad9432cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -103,7 +103,7 @@ class SqlLexical extends StdLexical { ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { - case i ~ None => NumericLit(i.mkString) + case i ~ None => NumericLit(i.mkString) case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) } | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index fc36b9f1f20d2..e85312aee7d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -140,7 +140,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { (HAVING ~> expression).? ~ sortType.? ~ (LIMIT ~> expression).? ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g @@ -212,7 +212,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val ordering: Parser[Seq[SortOrder]] = ( rep1sep(expression ~ direction.? , ",") ^^ { - case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) + case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) @@ -242,7 +242,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { case e ~ not ~ el ~ eu => val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - not.fold(betweenExpr)(f=> Not(betweenExpr)) + not.fold(betweenExpr)(f => Not(betweenExpr)) } | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } @@ -365,7 +365,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } + | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } | primary ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c239e83271615..df37889eedcf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -494,7 +494,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => { - val evaluatedCondition = Alias(havingCondition, "havingCondition")() + val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs Project(aggregate.output, @@ -515,16 +515,15 @@ class Analyzer( * - concrete attribute references for their output. * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). * - * Names for the output [[Attributes]] are extracted from [[Alias]] or [[MultiAlias]] expressions + * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an * [[AnalysisException]] is throw. */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Generate if !p.child.resolved || !p.generator.resolved => p - case g: Generate if g.resolved == false => - g.copy( - generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + case g: Generate if !g.resolved => + g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) case p @ Project(projectList, child) => // Holds the resolved generator, if one exists in the project list. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 208021c421326..3e240fd55e250 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -140,7 +140,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... - val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + val overrides = new mutable.HashMap[(Option[String], String), LogicalPlan]() abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index b45b17d856fac..44664f898f762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -561,8 +561,7 @@ trait HiveTypeCoercion { case a @ CreateArray(children) if !a.resolved => val commonType = a.childTypes.reduce( - (a,b) => - findTightestCommonType(a,b).getOrElse(StringType)) + (a, b) => findTightestCommonType(a, b).getOrElse(StringType)) CreateArray( children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) @@ -634,7 +633,7 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => + case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") val commonType = cw.valueTypes.reduce { (v1, v2) => findTightestCommonType(v1, v2).getOrElse(sys.error( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 60ab9fba4885a..51821757967d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -140,7 +140,7 @@ package object dsl { // Note that if we make ExpressionConversions an object rather than a trait, we can // then make this a value class to avoid the small penalty of runtime instantiation. def $(args: Any*): analysis.UnresolvedAttribute = { - analysis.UnresolvedAttribute(sc.s(args :_*)) + analysis.UnresolvedAttribute(sc.s(args : _*)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0fd4f9b374ee0..d2a90a50c89f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -49,11 +49,4 @@ package object errors { case e: Exception => throw new TreeNodeException(tree, msg, e) } } - - /** - * Executes `f` which is expected to throw a - * [[catalyst.errors.TreeNodeException TreeNodeException]]. The first tree encountered in - * the stack of exceptions of type `TreeType` is returned. - */ - def getTree[TreeType <: TreeNode[_]](f: => Unit): TreeType = ??? // TODO: Implement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index df3cdf2cdf992..21adac144112e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -35,48 +35,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true + case (StringType, TimestampType) => true + case (DoubleType, TimestampType) => true + case (FloatType, TimestampType) => true + case (StringType, DateType) => true + case (_: NumericType, DateType) => true + case (BooleanType, DateType) => true + case (DateType, _: NumericType) => true + case (DateType, BooleanType) => true case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true + case (FloatType, _: DecimalType) => true case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null - case _ => false + case _ => false } private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to private[this] def resolve(from: DataType, to: DataType): Boolean = { (from, to) match { - case (from, to) if from == to => true + case (from, to) if from == to => true - case (NullType, _) => true + case (NullType, _) => true - case (_, StringType) => true + case (_, StringType) => true - case (StringType, BinaryType) => true + case (StringType, BinaryType) => true - case (StringType, BooleanType) => true - case (DateType, BooleanType) => true - case (TimestampType, BooleanType) => true - case (_: NumericType, BooleanType) => true + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true - case (StringType, TimestampType) => true - case (BooleanType, TimestampType) => true - case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => true + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true - case (_, DateType) => true + case (_, DateType) => true - case (StringType, _: NumericType) => true - case (BooleanType, _: NumericType) => true - case (DateType, _: NumericType) => true - case (TimestampType, _: NumericType) => true + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true case (_: NumericType, _: NumericType) => true case (ArrayType(from, fn), ArrayType(to, tn)) => @@ -410,21 +410,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def cast(from: DataType, to: DataType): Any => Any = to match { case dt if dt == child.dataType => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index b5f4e16745c1b..a1e0819e8a433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -47,7 +47,7 @@ object ExtractValue { case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => val ordinal = findField(fields, fieldName.toString, resolver) GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) - case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) case (_: MapType, _) => GetMapValue(child, extraction) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 72eff5fe961f0..6c380d3084652 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -111,7 +111,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def update(input: Row): Unit = { if (currentMin.value == null) { currentMin.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMin.value = expr.eval(input) } } @@ -142,7 +142,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def update(input: Row): Unit = { if (currentMax.value == null) { currentMax.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMax.value = expr.eval(input) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ecb4c4b68f904..36964af68dd8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -373,7 +373,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Uh, bad function name... child.castOrNull(c => q"!$c", BooleanType) - case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } case Divide(e1, e2) => @@ -665,7 +665,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 584f938445c8c..31c63a79ebc8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -161,7 +161,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - val hashValues = expressions.zipWithIndex.map { case (e,i) => + val hashValues = expressions.zipWithIndex.map { case (e, i) => val elementName = newTermName(s"c$i") val nonNull = e.dataType match { case BooleanType => q"if ($elementName) 0 else 1" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index cab40feb72d47..634138010fd21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -103,8 +103,8 @@ case class Explode(child: Expression) val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) case MapType(_, _, _) => - val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] - if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } + val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] + if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5da93fe9c6cf9..83a44a12f0682 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -38,14 +38,14 @@ trait StringRegexExpression extends ExpectsInputTypes { case _ => null } - protected def compile(str: String): Pattern = if(str == null) { + protected def compile(str: String): Pattern = if (str == null) { null } else { // Let it raise exception if couldn't compile the regex string Pattern.compile(escape(str)) } - protected def pattern(str: String) = if(cache == null) compile(str) else cache + protected def pattern(str: String) = if (cache == null) compile(str) else cache override def eval(input: Row): Any = { val l = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 2729b34a0833f..82c4d462cc322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -66,7 +66,7 @@ case class WindowSpecDefinition( } } - override def children: Seq[Expression] = partitionSpec ++ orderSpec + override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7967189cacb24..eff5c61644944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -84,7 +84,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map { case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) @@ -117,7 +117,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 01f4b6e9bb77d..33a9e55a47dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -93,7 +93,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved: Boolean = childrenResolved && - left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 28e15566f0961..36d005d0e1684 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -254,7 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -311,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 9d613a940ee86..07054166a5e88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -83,7 +83,7 @@ package object util { } def resourceToString( - resource:String, + resource: String, encoding: String = "UTF-8", classLoader: ClassLoader = Utils.getSparkClassLoader): String = { new String(resourceToBytes(resource, classLoader), encoding) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a0b261649f66f..54604808e133e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -271,7 +271,7 @@ object DataType { protected lazy val structField: Parser[StructField] = ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index bbc0b661a0c0c..7ff51db76b6bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -253,7 +253,7 @@ class ScalaReflectionSuite extends FunSuite { } assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) } test("convert PrimitiveData to catalyst") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 890ea2a84b82e..9eed15952d82b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -28,7 +28,7 @@ private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Comman } private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") + protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") override protected lazy val start: Parser[LogicalPlan] = set @@ -39,7 +39,7 @@ private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { } private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("EXECUTE") + protected val EXECUTE = Keyword("EXECUTE") override protected lazy val start: Parser[LogicalPlan] = set diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 939cefb71b817..fcff24ca31486 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -155,7 +155,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { caseSensitive: Boolean = true): Unit = { test(name) { val error = intercept[AnalysisException] { - if(caseSensitive) { + if (caseSensitive) { caseSensitiveAnalyze(plan) } else { caseInsensitiveAnalyze(plan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5c4a1527c27c9..a14f776b1eaee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -43,8 +43,8 @@ class ExpressionEvaluationBaseSuite extends FunSuite { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } } @@ -126,37 +126,37 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } booleanLogicTest("AND", _ && _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: (false, false, false) :: - (false, null, false) :: - (null, true, null) :: - (null, false, false) :: - (null, null, null) :: Nil) + (false, null, false) :: + (null, true, null) :: + (null, false, false) :: + (null, null, null) :: Nil) booleanLogicTest("OR", _ || _, - (true, true, true) :: - (true, false, true) :: - (true, null, true) :: - (false, true, true) :: + (true, true, true) :: + (true, false, true) :: + (true, null, true) :: + (false, true, true) :: (false, false, false) :: - (false, null, null) :: - (null, true, true) :: - (null, false, null) :: - (null, null, null) :: Nil) + (false, null, null) :: + (null, true, true) :: + (null, false, null) :: + (null, null, null) :: Nil) booleanLogicTest("=", _ === _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: (false, false, true) :: - (false, null, null) :: - (null, true, null) :: - (null, false, null) :: - (null, null, null) :: Nil) + (false, null, null) :: + (null, true, null) :: + (null, false, null) :: + (null, null, null) :: Nil) def booleanLogicTest( name: String, @@ -164,7 +164,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { truthTable: Seq[(Any, Any, Any)]) { test(s"3VL $name") { truthTable.foreach { - case (l,r,answer) => + case (l, r, answer) => val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) checkEvaluation(expr, answer) } @@ -928,7 +928,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index b5ebe4b38e337..d7c437095e395 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -41,9 +41,9 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { """.stripMargin) } - val actual = plan(inputRow).apply(0) - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + val actual = plan(inputRow).apply(0) + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 97af2e0fd0502..a40324b008e16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -53,7 +53,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { """.stripMargin) } if (actual != expectedRow) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 6255578d7fa57..465a5e6914204 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -78,9 +78,9 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { test("(a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ...") { checkCondition('b > 3 || 'c > 5, 'b > 3 || 'c > 5) - checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) + checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) - checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) + checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) val input = ('a === 'b && 'b > 3 && 'c > 2) || ('a === 'b && 'c < 1 && 'a === 5) || diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index be33cb9bb8eaa..ff25470bf0946 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -97,7 +97,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for Project(ne, Limit)") { val originalQuery = testRelation - .select('a,'b) + .select('a, 'b) .limit(2) .select('a) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 3eb399e68e70c..11b0859d3f066 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -46,7 +46,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -61,13 +61,13 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index a3ad200800b02..35f50be46b76f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -33,8 +33,8 @@ class UnionPushdownSuite extends PlanTest { UnionPushdown) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testUnion = Union(testRelation, testRelation2) test("union: filter to each side") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e5f77dcd962a4..9fcfc51c96139 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -91,7 +91,7 @@ class TreeNodeSuite extends FunSuite { test("transform works on nodes with Option children") { val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) - val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } var actual = dummy1 transformDown toZero assert(actual === Dummy(Some(Literal(0)))) @@ -104,7 +104,7 @@ class TreeNodeSuite extends FunSuite { } test("preserves origin") { - CurrentOrigin.setPosition(1,1) + CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) CurrentOrigin.reset() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index a73317c86916b..df119827812f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -250,11 +250,11 @@ class DataTypeSuite extends FunSuite { expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), expected = true) From 7f7505d8db7759ea46e904f767c23130eff1104a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 20:15:52 -0700 Subject: [PATCH 053/254] [SPARK-7927] whitespace fixes for core. So we can enable a whitespace enforcement rule in the style checker to save code review time. Author: Reynold Xin Closes #6473 from rxin/whitespace-core and squashes the following commits: 058195d [Reynold Xin] Fixed tests. fce11e9 [Reynold Xin] [SPARK-7927] whitespace fixes for core. --- .../scala/org/apache/spark/Accumulators.scala | 2 +- .../scala/org/apache/spark/Aggregator.scala | 4 +-- .../scala/org/apache/spark/Partitioner.scala | 8 ++--- .../scala/org/apache/spark/SparkConf.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 10 +++--- .../scala/org/apache/spark/SparkEnv.scala | 4 +-- .../org/apache/spark/SparkHadoopWriter.scala | 10 +++--- .../apache/spark/api/java/JavaRDDLike.scala | 2 +- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../apache/spark/api/r/RBackendHandler.scala | 4 +-- .../scala/org/apache/spark/api/r/RRDD.scala | 2 +- .../spark/broadcast/HttpBroadcast.scala | 4 +-- .../spark/deploy/FaultToleranceTest.scala | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/executor/TaskMetrics.scala | 2 +- .../mapreduce/SparkHadoopMapReduceUtil.scala | 2 +- .../spark/network/nio/BlockMessage.scala | 2 +- .../spark/network/nio/BlockMessageArray.scala | 4 +-- .../spark/network/nio/SecurityMessage.scala | 2 +- .../spark/partial/GroupedCountEvaluator.scala | 6 ++-- .../org/apache/spark/rdd/CheckpointRDD.scala | 2 +- .../org/apache/spark/rdd/CoalescedRDD.scala | 4 +-- .../apache/spark/rdd/PairRDDFunctions.scala | 6 ++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 12 +++---- .../spark/rdd/SequenceFileRDDFunctions.scala | 6 ++-- .../org/apache/spark/rdd/SubtractedRDD.scala | 2 +- .../spark/rdd/ZippedPartitionsRDD.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 +-- .../spark/scheduler/DAGSchedulerSource.scala | 3 +- .../spark/scheduler/SchedulingAlgorithm.scala | 2 +- .../spark/scheduler/SparkListener.scala | 6 ++-- .../spark/scheduler/TaskSetManager.scala | 4 +-- .../cluster/CoarseGrainedClusterMessage.scala | 3 +- .../cluster/YarnSchedulerBackend.scala | 2 +- .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../cluster/mesos/MesosSchedulerBackend.scala | 2 +- .../mesos/MesosSchedulerBackendUtil.scala | 2 +- .../status/api/v1/AllStagesResource.scala | 4 +-- .../spark/status/api/v1/ApiRootResource.scala | 8 ++--- .../spark/status/api/v1/OneRDDResource.scala | 2 +- .../org/apache/spark/status/api/v1/api.scala | 2 +- .../storage/BlockManagerSlaveEndpoint.scala | 2 +- .../spark/storage/BlockManagerSource.scala | 3 +- .../scala/org/apache/spark/ui/SparkUI.scala | 2 +- .../scala/org/apache/spark/ui/UIUtils.scala | 2 +- .../apache/spark/ui/UIWorkloadGenerator.scala | 2 +- .../org/apache/spark/ui/storage/RDDPage.scala | 2 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../spark/util/CompletionIterator.scala | 2 +- .../org/apache/spark/util/Distribution.scala | 4 +-- .../apache/spark/util/MetadataCleaner.scala | 2 +- .../org/apache/spark/util/MutablePair.scala | 2 +- .../org/apache/spark/util/SizeEstimator.scala | 16 ++++----- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../apache/spark/util/collection/BitSet.scala | 2 +- .../util/collection/SortDataFormat.scala | 4 +-- .../util/random/StratifiedSamplingUtils.scala | 2 +- .../org/apache/spark/AccumulatorSuite.scala | 2 +- .../org/apache/spark/CheckpointSuite.scala | 4 +-- .../apache/spark/ContextCleanerSuite.scala | 6 ++-- .../scala/org/apache/spark/FailureSuite.scala | 2 +- .../org/apache/spark/FileServerSuite.scala | 20 +++++------ .../scala/org/apache/spark/FileSuite.scala | 2 +- .../apache/spark/ImplicitOrderingSuite.scala | 4 +-- .../org/apache/spark/SparkConfSuite.scala | 12 +++---- .../org/apache/spark/SparkContextSuite.scala | 14 ++++---- .../spark/broadcast/BroadcastSuite.scala | 2 +- .../deploy/worker/WorkerArgumentsTest.scala | 2 +- .../spark/deploy/worker/WorkerSuite.scala | 2 +- .../metrics/InputOutputMetricsSuite.scala | 6 ++-- .../spark/rdd/PairRDDFunctionsSuite.scala | 10 +++--- .../scala/org/apache/spark/rdd/RDDSuite.scala | 34 +++++++++---------- .../org/apache/spark/rdd/RDDSuiteUtils.scala | 6 ++-- .../org/apache/spark/rdd/SortingSuite.scala | 4 +-- .../org/apache/spark/rpc/RpcEnvSuite.scala | 6 ++-- .../CoarseGrainedSchedulerBackendSuite.scala | 4 +-- .../spark/scheduler/DAGSchedulerSuite.scala | 6 ++-- .../apache/spark/scheduler/PoolSuite.scala | 4 +-- .../serializer/KryoSerializerSuite.scala | 4 +-- .../ProactiveClosureSerializationSuite.scala | 10 +++--- .../spark/serializer/TestSerializer.scala | 11 +++--- .../spark/storage/FlatmapIteratorSuite.scala | 6 ++-- .../org/apache/spark/ui/UISeleniumSuite.scala | 8 ++--- .../spark/ui/storage/StorageTabSuite.scala | 2 +- .../apache/spark/util/AkkaUtilsSuite.scala | 2 +- .../org/apache/spark/util/UtilsSuite.scala | 2 +- .../spark/util/collection/BitSetSuite.scala | 2 +- 88 files changed, 205 insertions(+), 203 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 330df1d59a9b1..5a8d17bd99933 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -228,7 +228,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @tparam T result type */ class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T,T](initialValue, param, name) { + extends Accumulable[T, T](initialValue, param, name) { def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index af9765d313e9e..b8a5f5016860f 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -45,7 +45,7 @@ case class Aggregator[K, V, C] ( def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kv: Product2[K, V] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) @@ -76,7 +76,7 @@ case class Aggregator[K, V, C] ( : Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kc: Product2[K, C] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index b8d244408bc5b..82889bcd30988 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -103,7 +103,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { */ class RangePartitioner[K : Ordering : ClassTag, V]( @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], + @transient rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -185,7 +185,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => + case r: RangePartitioner[_, _] => r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending case _ => false @@ -249,7 +249,7 @@ private[spark] object RangePartitioner { * @param sampleSizePerPartition max sample size per partition * @return (total number of items, an array of (partitionId, number of items, sample)) */ - def sketch[K:ClassTag]( + def sketch[K : ClassTag]( rdd: RDD[K], sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id @@ -272,7 +272,7 @@ private[spark] object RangePartitioner { * @param partitions number of partitions * @return selected bounds */ - def determineBounds[K:Ordering:ClassTag]( + def determineBounds[K : Ordering : ClassTag]( candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = { val ordering = implicitly[Ordering[K]] diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b5e5d6f1465f3..4b5bcb54aa873 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -481,7 +481,7 @@ private[spark] object SparkConf extends Logging { "are no longer accepted. To specify the equivalent now, one may use '64k'.") ) - Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ea6c0dea08e47..a453c9bf4864a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -389,7 +389,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) .toSeq.flatten @@ -438,7 +438,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _ui = if (conf.getBoolean("spark.ui.enabled", true)) { Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, - _env.securityManager,appName, startTime = startTime)) + _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI None @@ -917,7 +917,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], - conf=conf) + conf = conf) val data = br.map { case (k, v) => val bytes = v.getBytes assert(bytes.length == recordLength, "Byte array does not have correct length") @@ -1267,7 +1267,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { - val param = new GrowableAccumulableParam[R,T] + val param = new GrowableAccumulableParam[R, T] val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1316,7 +1316,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val uri = new URI(path) val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString - case _ => path + case _ => path } val hadoopPath = new Path(schemeCorrectedPath) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 327114542880d..a185954089528 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -298,7 +298,7 @@ object SparkEnv extends Logging { } } - val mapOutputTracker = if (isDriver) { + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { new MapOutputTrackerWorker(conf) @@ -348,7 +348,7 @@ object SparkEnv extends Logging { val fileServerPort = conf.getInt("spark.fileserver.port", 0) val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) + conf.set("spark.fileserver.uri", server.serverUri) server } else { null diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 2ec42d3aea169..59ac82ccec53b 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -50,8 +50,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private var jID: SerializableWritable[JobID] = null private var taID: SerializableWritable[TaskAttemptID] = null - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null + @transient private var writer: RecordWriter[AnyRef, AnyRef] = null + @transient private var format: OutputFormat[AnyRef, AnyRef] = null @transient private var committer: OutputCommitter = null @transient private var jobContext: JobContext = null @transient private var taskContext: TaskAttemptContext = null @@ -114,10 +114,10 @@ class SparkHadoopWriter(@transient jobConf: JobConf) // ********* Private Functions ********* - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { + private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { if (format == null) { format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] + .asInstanceOf[OutputFormat[AnyRef, AnyRef]] } format } @@ -138,7 +138,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + taskContext = newTaskAttemptContext(conf.value, taID.value) } taskContext } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 74db7643224f5..b8e15f38a20d2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -96,7 +96,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 2d92f6a42b308..a77bf42ce1d38 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -723,7 +723,7 @@ private[spark] object PythonRDD extends Logging { val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] - converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec) } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 0075d963711f1..026a1b9380357 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -124,7 +124,7 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args:_*) + val ret = methods.head.invoke(obj, args : _*) // Write status bit writeInt(dos, 0) @@ -135,7 +135,7 @@ private[r] class RBackendHandler(server: RBackend) matchMethod(numArgs, args, x.getParameterTypes) }.head - val obj = ctor.newInstance(args:_*) + val obj = ctor.newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 06247f7e8b78c..e020458888e4a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -309,7 +309,7 @@ private class StringRRDD[T: ClassTag]( } private object SpecialLengths { - val TIMING_DATA = -1 + val TIMING_DATA = -1 } private[r] class BufferedStreamThread( diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4457c75e8b0fc..b69af639f7862 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -125,7 +125,7 @@ private[broadcast] object HttpBroadcast extends Logging { securityManager = securityMgr if (isDriver) { createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) + conf.set("spark.httpBroadcast.uri", serverUri) } serverUri = conf.get("spark.httpBroadcast.uri") cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) @@ -187,7 +187,7 @@ private[broadcast] object HttpBroadcast extends Logging { } private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name var uc: URLConnection = null diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c048b78910f38..b4edb6109e839 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -65,7 +65,7 @@ private object FaultToleranceTest extends App with Logging { private val workers = ListBuffer[TestWorkerInfo]() private var sc: SparkContext = _ - private val zk = SparkCuratorUtil.newClient(conf) + private val zk = SparkCuratorUtil.newClient(conf) private var numPassed = 0 private var numFailed = 0 diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 198371b70f14f..92bb5059a0313 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -361,7 +361,7 @@ object SparkSubmit { pyArchives = pythonPath.mkString(",") } - pyArchives = pyArchives.split(",").map { localPath=> + pyArchives = pyArchives.split(",").map { localPath => val localURI = Utils.resolveURI(localPath) if (localURI.getScheme != "local") { args.files = mergeFileLists(args.files, localURI.toString) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c8df024dda355..ebc6cd76c6afd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -554,7 +554,7 @@ private[deploy] object Worker extends Logging { conf = conf, securityManager = securityMgr) val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) + masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 06152f16ae618..d90ae405a0849 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -261,7 +261,7 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index cfd20392d12f1..390d148bc97f9 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -60,7 +60,7 @@ trait SparkHadoopMapReduceUtil { val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if(isMap) "MAP" else "REDUCE") + taskTypeClass, if (isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index b573f1a8a5fcb..1a92a799d004a 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -110,7 +110,7 @@ private[nio] class BlockMessage() { def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data - def getLevel: StorageLevel = level + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 1ba25aa74aa02..7d0806f0c2580 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -114,8 +114,8 @@ private[nio] object BlockMessageArray { val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear + val buffer = ByteBuffer.allocate(100) + buffer.clear() BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, StorageLevel.MEMORY_ONLY_SER)) } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index 747a2088a7258..232c552f9865d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -75,7 +75,7 @@ private[nio] class SecurityMessage extends Logging { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - connectionId = idBuilder.toString() + connectionId = idBuilder.toString() val tokenLength = buffer.getInt() token = new Array[Byte](tokenLength) diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 3ef3cc219dec6..91b07ce3af1b6 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.collection.OpenHashMap * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] { + extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { var outputsMerged = 0 - var sums = new OpenHashMap[T,Long]() // Sum of counts for each key + var sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 0d130dd4c7a60..a4715e3437d94 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -49,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.length + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 0c1b02c07d09f..663eebb8e4191 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -310,11 +310,11 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: def throwBalls() { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { + for ((p, i) <- prev.partitions.zipWithIndex) { groupArr(i).arr += p } } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { + for (i <- 0 until maxPartitions) { val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 8653cdee1adee..004899f27b7a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -467,7 +467,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 val bufs = combineByKey[CompactBuffer[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -1011,7 +1011,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, @@ -1027,7 +1027,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L Utils.tryWithSafeFinally { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d772f03f76651..5fcef255e13af 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -454,7 +454,7 @@ abstract class RDD[T: ClassTag]( withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { - val numStDev = 10.0 + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") @@ -1138,8 +1138,8 @@ abstract class RDD[T: ClassTag]( if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } - val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T,Long] = { (ctx, iter) => - val map = new OpenHashMap[T,Long] + val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T, Long] = { (ctx, iter) => + val map = new OpenHashMap[T, Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } @@ -1585,15 +1585,15 @@ abstract class RDD[T: ClassTag]( case 0 => Seq.empty case 1 => val d = rdd.dependencies.head - debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]], true) case _ => val frontDeps = rdd.dependencies.take(len - 1) val frontDepStrings = frontDeps.flatMap( - d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]])) + d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]])) val lastDep = rdd.dependencies.last val lastDepStrings = - debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) (frontDepStrings ++ lastDepStrings) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 3dfcf67f0eb66..4b5f15dd06b85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -104,13 +104,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag if (!convertKey && !convertValue) { self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (x._1, anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), x._2)).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 633aeba3bbae6..f7cb1791d4ac6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -125,7 +125,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index a96b6c3d23454..81f40ad33aa5d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -123,7 +123,7 @@ private[spark] class ZippedPartitionsRDD3 } private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( + [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a2299e907c5ae..75a567fb31520 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1367,10 +1367,10 @@ class DAGScheduler( private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 - if (!visited.add((rdd,partition))) { + if (!visited.add((rdd, partition))) { // Nil has already been returned for previously visited partitions. return Nil } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 12668b6c0988e..02c67073af6a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 5e62c8468f007..864941d468af9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -56,7 +56,7 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var compare:Int = 0 + var compare: Int = 0 if (s1Needy && !s2Needy) { return true diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 863d0befbc19e..9620915f495ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -270,7 +270,7 @@ class StatsReportListener extends SparkListener with Logging { private[spark] object StatsReportListener extends Logging { // For profiling, the extremes are more interesting - val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) val probabilities = percentiles.map(_ / 100.0) val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" @@ -304,7 +304,7 @@ private[spark] object StatsReportListener extends Logging { dOpt.foreach { d => showDistribution(heading, d, formatNumber)} } - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -318,7 +318,7 @@ private[spark] object StatsReportListener extends Logging { } def showBytesDistribution( - heading:String, + heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long], taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c4487d5b37247..d473e51abab80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -781,10 +781,10 @@ private[spark] class TaskSetManager( // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because dequeueTaskFromList will skip already-running tasks. for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 70364cea62a80..4be1eda2e9291 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -75,7 +75,8 @@ private[spark] object CoarseGrainedClusterMessages { case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) + case class AddWebUIFilter( + filterName: String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage // Messages exchanged between the driver and the cluster manager for executor allocation diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 2a3a5d925d06f..190ff61d689d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -149,7 +149,7 @@ private[spark] abstract class YarnSchedulerBackend( } } - override def onStop(): Unit ={ + override def onStop(): Unit = { askAmThreadPool.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index aff086594c73f..6b8edca5aa485 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -52,7 +52,7 @@ private[spark] class CoarseMesosSchedulerBackend( val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index db0a080b3b0c0..49de85ef48ada 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -146,7 +146,7 @@ private[spark] class MesosSchedulerBackend( private def createExecArg(): Array[Byte] = { if (execArgs == null) { val props = new HashMap[String, String] - for ((key,value) <- sc.conf.getAll) { + for ((key, value) <- sc.conf.getAll) { props(key) = value } // Serialize the map as an array of (String, String) pairs diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 928c5cfed417a..2f2934c249eb0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -108,7 +108,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { image: String, volumes: Option[List[Volume]] = None, network: Option[ContainerInfo.DockerInfo.Network] = None, - portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None):Unit = { + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = { val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 50608588f09ae..390c136df79b3 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -169,7 +169,7 @@ private[v1] object AllStagesResource { val outputMetrics: Option[OutputMetricDistributions] = new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw:InternalTaskMetrics): Option[InternalOutputMetrics] = { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = { raw.outputMetrics } def build: OutputMetricDistributions = new OutputMetricDistributions( @@ -284,7 +284,7 @@ private[v1] object AllStagesResource { * the options (returning None if the metrics are all empty), and extract the quantiles for each * metric. After creating an instance, call metricOption to get the result type. */ -private[v1] abstract class MetricHelper[I,O]( +private[v1] abstract class MetricHelper[I, O]( rawMetrics: Seq[InternalTaskMetrics], quantiles: Array[Double]) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index bf2cc2e72f1fe..f73c742732dec 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -101,7 +101,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications/{appId}/stages") - def getStages(@PathParam("appId") appId: String): AllStagesResource= { + def getStages(@PathParam("appId") appId: String): AllStagesResource = { uiRoot.withSparkUI(appId, None) { ui => new AllStagesResource(ui) } @@ -110,14 +110,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications/{appId}/{attemptId}/stages") def getStages( @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllStagesResource= { + @PathParam("attemptId") attemptId: String): AllStagesResource = { uiRoot.withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") - def getStage(@PathParam("appId") appId: String): OneStageResource= { + def getStage(@PathParam("appId") appId: String): OneStageResource = { uiRoot.withSparkUI(appId, None) { ui => new OneStageResource(ui) } @@ -171,7 +171,7 @@ private[spark] object ApiRootResource { def getServletHandler(uiRoot: UIRoot): ServletContextHandler = { val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) jerseyContext.setContextPath("/api") - val holder:ServletHolder = new ServletHolder(classOf[ServletContainer]) + val holder: ServletHolder = new ServletHolder(classOf[ServletContainer]) holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass", "com.sun.jersey.api.core.PackagesResourceConfig") holder.setInitParameter("com.sun.jersey.config.property.packages", diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index 07b224fac4786..dfdc09c6caf3b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.SparkUI private[v1] class OneRDDResource(ui: SparkUI) { @GET - def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( throw new NotFoundException(s"no rdd found w/ id $rddId") ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index ef3c8570d8186..2bec64f2ef02b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -134,7 +134,7 @@ class StageData private[spark]( val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], - val executorSummary:Option[Map[String,ExecutorStageSummary]]) + val executorSummary: Option[Map[String, ExecutorStageSummary]]) class TaskData private[spark]( val taskId: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 543df4e1350dd..7478ab0fc2f7a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -40,7 +40,7 @@ class BlockManagerSlaveEndpoint( private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 8569c6f3cbbc3..c5ba9af3e2658 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 0b11e914bb251..3788916cf39bb 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -137,7 +137,7 @@ private[spark] object SparkUI { jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, - startTime: Long): SparkUI = { + startTime: Long): SparkUI = { create(Some(sc), conf, listenerBus, securityManager, appName, jobProgressListener = Some(jobProgressListener), startTime = startTime) } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 6194c50ec8c7c..65162f4fdcd62 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -309,7 +309,7 @@ private[spark] object UIUtils extends Logging { started: Int, completed: Int, failed: Int, - skipped:Int, + skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100) diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 5fbcd6bb8ad94..ba03acdb38cc5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -54,7 +54,7 @@ private[spark] object UIWorkloadGenerator { val sc = new SparkContext(conf) def setProperties(s: String): Unit = { - if(schedulingMode == SchedulingMode.FAIR) { + if (schedulingMode == SchedulingMode.FAIR) { sc.setLocalProperty("spark.scheduler.pool", s) } sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index fbce917a0824d..36943978ff594 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -33,7 +33,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val rddId = parameterId.toInt - val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener,includeDetails = true) + val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) .getOrElse { // Rather than crashing, render an "RDD Not Found" page return UIUtils.headerSparkPage("RDD Not Found", Seq[Node](), parent) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 7513b1b795dea..96aa2fe164703 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -63,7 +63,7 @@ private[spark] object AkkaUtils extends Logging { conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = { - val akkaThreads = conf.getInt("spark.akka.threads", 4) + val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout", conf.get("spark.network.timeout", "120s")) diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 9044aaeef2d48..31d230d0fec8e 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -42,7 +42,7 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat private[spark] object CompletionIterator { def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A, I] = { - new CompletionIterator[A,I](sub) { + new CompletionIterator[A, I](sub) { def completion(): Unit = completionFunction } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 9aea8efa38c7a..1bab707235b89 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -35,7 +35,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va java.util.Arrays.sort(data, startIdx, endIdx) val length = endIdx - startIdx - val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0) + val defaultProbabilities = Array(0, 0.25, 0.5, 0.75, 1.0) /** * Get the value of the distribution at the given probabilities. Probabilities should be @@ -44,7 +44,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va */ def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) : IndexedSeq[Double] = { - probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))} + probabilities.toIndexedSeq.map { p: Double => data(closestIndex(p)) } } private def closestIndex(p: Double) = { diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2bbfc988a99a8..a8bbad086849e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -89,7 +89,7 @@ private[spark] object MetadataCleaner { conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) { - conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) + conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } /** diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index dad888548ed10..3d95b7869f494 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -45,5 +45,5 @@ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef override def toString: String = "(" + _1 + "," + _2 + ")" - override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] + override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_, _]] } diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index f38949c3cb846..f1f6b5e1f93d8 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -54,14 +54,14 @@ object SizeEstimator extends Logging { def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) // Sizes of primitive types - private val BYTE_SIZE = 1 + private val BYTE_SIZE = 1 private val BOOLEAN_SIZE = 1 - private val CHAR_SIZE = 2 - private val SHORT_SIZE = 2 - private val INT_SIZE = 4 - private val LONG_SIZE = 8 - private val FLOAT_SIZE = 4 - private val DOUBLE_SIZE = 8 + private val CHAR_SIZE = 2 + private val SHORT_SIZE = 2 + private val INT_SIZE = 4 + private val LONG_SIZE = 8 + private val FLOAT_SIZE = 4 + private val DOUBLE_SIZE = 8 // Fields can be primitive types, sizes are: 1, 2, 4, 8. Or fields can be pointers. The size of // a pointer is 4 or 8 depending on the JVM (32-bit or 64-bit) and UseCompressedOops flag. @@ -96,7 +96,7 @@ object SizeEstimator extends Logging { isCompressedOops = getIsCompressedOops objectSize = if (!is64bit) 8 else { - if(!isCompressedOops) { + if (!isCompressedOops) { 16 } else { 12 diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b7a2473dfe920..763d4db690187 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -882,7 +882,7 @@ private[spark] object Utils extends Logging { // If not, we should change it to LRUCache or something. private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() - def parseHostPort(hostPort: String): (String, Int) = { + def parseHostPort(hostPort: String): (String, Int) = { // Check cache first. val cached = hostPortParseResults.get(hostPort) if (cached != null) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 41cb8cfe2afa3..9c15b1188d91c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -161,7 +161,7 @@ class BitSet(numBits: Int) extends Serializable { override def hasNext: Boolean = ind >= 0 override def next(): Int = { val tmp = ind - ind = nextSetBit(ind + 1) + ind = nextSetBit(ind + 1) tmp } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala index 4f0bf8384afc9..9a7a5a4e74868 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala @@ -90,9 +90,9 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, override def swap(data: Array[T], pos0: Int, pos1: Int) { val tmpKey = data(2 * pos0) val tmpVal = data(2 * pos0 + 1) - data(2 * pos0) = data(2 * pos1) + data(2 * pos0) = data(2 * pos1) data(2 * pos0 + 1) = data(2 * pos1 + 1) - data(2 * pos1) = tmpKey + data(2 * pos1) = tmpKey data(2 * pos1 + 1) = tmpVal } diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 9e29bf9d61f17..effe6fa2adcfa 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -196,7 +196,7 @@ private[spark] object StratifiedSamplingUtils extends Logging { * * The sampling function has a unique seed per partition. */ - def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], + def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], fractions: Map[K, Double], exact: Boolean, seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 75399461f2a5f..746a40a21bf9e 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -103,7 +103,7 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { sc = new SparkContext("local[" + nThreads + "]", "test") val setAcc = sc.accumulableCollection(mutable.HashSet[Int]()) val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]()) - val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]()) + val mapAcc = sc.accumulableCollection(mutable.HashMap[Int, String]()) val d = sc.parallelize((1 to maxI) ++ (1 to maxI)) d.foreach { x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index e1faddeabec79..91d8fdedbe0f3 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -218,10 +218,10 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val pairRDD = generateFatPairRDD() pairRDD.checkpoint() val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD)) - val partitionBeforeCheckpoint = serializeDeserialize( + val partitionBeforeCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) pairRDD.count() - val partitionAfterCheckpoint = serializeDeserialize( + val partitionAfterCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) assert( partitionBeforeCheckpoint.parents.head.getClass != diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 0922a2c3599cc..4a48f6580c78e 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -158,7 +158,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { rdd.count() // Test that GC does not cause RDD cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) @@ -195,7 +195,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { var broadcast = newBroadcast() // Test that GC does not cause broadcast cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) @@ -267,7 +267,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { val shuffleIds = 0 until sc.newShuffleId val broadcastIds = broadcastBuffer.map(_.id) - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 1212d0b43207d..cade1fda2c7be 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -57,7 +57,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { FailureSuiteState.synchronized { assert(FailureSuiteState.tasksRun === 4) } - assert(results.toList === List(1,4,9)) + assert(results.toList === List(1, 4, 9)) FailureSuiteState.clear() } diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index c0439f934813e..bff2d10b9946c 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -81,7 +81,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test("Distributing files locally") { sc = new SparkContext("local[4]", "test", newConf) sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -89,7 +89,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test("Distributing files locally security On") { @@ -100,7 +100,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { sc.addFile(tmpFile.toString) assert(sc.env.securityManager.isAuthenticationEnabled() === true) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -108,14 +108,14 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test("Distributing files locally using URL as input") { // addFile("file:///....") sc = new SparkContext("local[4]", "test", newConf) sc.addFile(new File(tmpFile.toString).toURI.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -123,7 +123,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test ("Dynamically adding JARS locally") { @@ -140,7 +140,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test("Distributing files on a standalone cluster") { sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -148,13 +148,13 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test ("Dynamically adding JARS on a standalone cluster") { sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl) - val testData = Array((1,1)) + val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { throw new SparkException("jar not added") @@ -165,7 +165,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test ("Dynamically adding JARS on a standalone cluster using local: URL") { sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) - val testData = Array((1,1)) + val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { throw new SparkException("jar not added") diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index c8f08eed47c76..d67de8692df62 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -334,7 +334,7 @@ class FileSuite extends FunSuite with LocalSparkContext { } val copyRdd = mappedRdd.flatMap { curData: (String, PortableDataStream) => - for(i <- 1 to numOfCopies) yield (i, curData._2) + for (i <- 1 to numOfCopies) yield (i, curData._2) } val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect() diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index 51348c039b5c9..69314deda1f03 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -44,11 +44,11 @@ private object ImplicitOrderingSuite { class NonOrderedClass {} class ComparableClass extends Comparable[ComparableClass] { - override def compareTo(o: ComparableClass): Int = ??? + override def compareTo(o: ComparableClass): Int = throw new UnsupportedOperationException } class OrderedClass extends Ordered[OrderedClass] { - override def compare(o: OrderedClass): Int = ??? + override def compare(o: OrderedClass): Int = throw new UnsupportedOperationException } def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index fafa4ed606b08..fafc9d47503b7 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -34,18 +34,18 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro val conf = new SparkConf() // Simply exercise the API, we don't need a complete conversion test since that's handled in // UtilsSuite.scala - assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) - assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) - assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) - assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + assert(conf.getSizeAsBytes("fake", "1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake", "1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake", "1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake", "1k") === ByteUnit.KiB.toGiB(1)) } test("Test timeString conversion") { val conf = new SparkConf() // Simply exercise the API, we don't need a complete conversion test since that's handled in // UtilsSuite.scala - assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) - assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + assert(conf.getTimeAsMs("fake", "1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake", "1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) } test("loading from system properties") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 9049db7755358..31ef5cd75bd4a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -222,8 +222,8 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val dir1 = Utils.createTempDir() val dir2 = Utils.createTempDir() - val dirpath1=dir1.getAbsolutePath - val dirpath2=dir2.getAbsolutePath + val dirpath1 = dir1.getAbsolutePath + val dirpath2 = dir2.getAbsolutePath // file1 and file2 are placed inside dir1, they are also used for // textFile, hadoopFile, and newAPIHadoopFile @@ -235,11 +235,11 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val file4 = new File(dir2, "part-00001") val file5 = new File(dir2, "part-00002") - val filepath1=file1.getAbsolutePath - val filepath2=file2.getAbsolutePath - val filepath3=file3.getAbsolutePath - val filepath4=file4.getAbsolutePath - val filepath5=file5.getAbsolutePath + val filepath1 = file1.getAbsolutePath + val filepath2 = file2.getAbsolutePath + val filepath3 = file3.getAbsolutePath + val filepath4 = file4.getAbsolutePath + val filepath5 = file5.getAbsolutePath try { diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 06e5f1cf6b96f..c38e306b6ac40 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -286,7 +286,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 7cc2104281464..e432b8e94654a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -66,7 +66,7 @@ class WorkerArgumentsTest extends FunSuite { } } val conf = new MySparkConf() - val workerArgs = new WorkerArguments(args, conf) + val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 5120) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 450fba21f4b5c..93a779d5ce6f2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.{Matchers, FunSuite} class WorkerSuite extends FunSuite with Matchers { def cmd(javaOpts: String*): Command = { - Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) + Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index ef3e213f1fcce..60dba3b2d6719 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -263,7 +263,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext val tmpRdd = sc.textFile(tmpFilePath, numPartitions) - val firstSize= runAndReturnBytesRead { + val firstSize = runAndReturnBytesRead { aRdd.count() } val secondSize = runAndReturnBytesRead { @@ -433,10 +433,10 @@ class OldCombineTextRecordReaderWrapper( /** * Hadoop 2 has a version of this, but we can't use it for backwards compatibility */ -class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable, Text] { def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) : NewRecordReader[LongWritable, Text] = { - new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + new NewCombineFileRecordReader[LongWritable, Text](split.asInstanceOf[NewCombineFileSplit], context, classOf[NewCombineTextRecordReaderWrapper]) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index ca0d953d306d8..6564232986cfa 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -512,17 +512,17 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("lookup") { - val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) assert(pairs.partitioner === None) assert(pairs.lookup(1) === Seq(2)) - assert(pairs.lookup(5) === Seq(6,7)) + assert(pairs.lookup(5) === Seq(6, 7)) assert(pairs.lookup(-1) === Seq()) } test("lookup with partitioner") { - val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) val p = new Partitioner { def numPartitions: Int = 2 @@ -533,12 +533,12 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { assert(shuffled.partitioner === Some(p)) assert(shuffled.lookup(1) === Seq(2)) - assert(shuffled.lookup(5) === Seq(6,7)) + assert(shuffled.lookup(5) === Seq(6, 7)) assert(shuffled.lookup(-1) === Seq()) } test("lookup with bad partitioner") { - val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) val p = new Partitioner { def numPartitions: Int = 2 diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index afc11bdc4d6ab..8079d5dcaea81 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -338,10 +338,10 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("coalesced RDDs with locality") { - val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) + val data3 = sc.makeRDD(List((1, List("a", "c")), (2, List("a", "b", "c")), (3, List("b")))) val coal3 = data3.coalesce(3) val list3 = coal3.partitions.flatMap(_.asInstanceOf[CoalescedRDDPartition].preferredLocation) - assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") + assert(list3.sorted === Array("a", "b", "c"), "Locality preferences are dropped") // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map{ j => "m" + (j%6)}))) @@ -591,8 +591,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sc.emptyRDD.isEmpty()) assert(sc.parallelize(Seq[Int]()).isEmpty()) assert(!sc.parallelize(Seq(1)).isEmpty()) - assert(sc.parallelize(Seq(1,2,3), 3).filter(_ < 0).isEmpty()) - assert(!sc.parallelize(Seq(1,2,3), 3).filter(_ > 1).isEmpty()) + assert(sc.parallelize(Seq(1, 2, 3), 3).filter(_ < 0).isEmpty()) + assert(!sc.parallelize(Seq(1, 2, 3), 3).filter(_ > 1).isEmpty()) } test("sample preserves partitioner") { @@ -609,49 +609,49 @@ class RDDSuite extends FunSuite with SharedSparkContext { val data = sc.parallelize(1 to n, 2) for (num <- List(5, 20, 100)) { - val sample = data.takeSample(withReplacement=false, num=num) + val sample = data.takeSample(withReplacement = false, num = num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 20, seed) + val sample = data.takeSample(withReplacement = false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 100, seed) + val sample = data.takeSample(withReplacement = false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 20, seed) + val sample = data.takeSample(withReplacement = true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, num=20) + val sample = data.takeSample(withReplacement = true, num = 20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, num=n) + val sample = data.takeSample(withReplacement = true, num = n) assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, n, seed) + val sample = data.takeSample(withReplacement = true, n, seed) assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 2 * n, seed) + val sample = data.takeSample(withReplacement = true, 2 * n, seed) assert(sample.size === 2 * n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") @@ -691,7 +691,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("sortByKey") { - val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + val data = sc.parallelize(Seq("5|50|A", "4|60|C", "6|40|B")) val col1 = Array("4|60|C", "5|50|A", "6|40|B") val col2 = Array("6|40|B", "5|50|A", "4|60|C") @@ -703,7 +703,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("sortByKey ascending parameter") { - val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + val data = sc.parallelize(Seq("5|50|A", "4|60|C", "6|40|B")) val asc = Array("4|60|C", "5|50|A", "6|40|B") val desc = Array("6|40|B", "5|50|A", "4|60|C") @@ -764,9 +764,9 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("intersection strips duplicates in an input") { - val a = sc.parallelize(Seq(1,2,3,3)) - val b = sc.parallelize(Seq(1,1,2,3)) - val intersection = Array(1,2,3) + val a = sc.parallelize(Seq(1, 2, 3, 3)) + val b = sc.parallelize(Seq(1, 1, 2, 3)) + val intersection = Array(1, 2, 3) assert(a.intersection(b).collect().sorted === intersection) assert(b.intersection(a).collect().sorted === intersection) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala index fe695d85e29dd..194dc45d6e399 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -21,11 +21,11 @@ object RDDSuiteUtils { case class Person(first: String, last: String, age: Int) object AgeOrdering extends Ordering[Person] { - def compare(a:Person, b:Person): Int = a.age.compare(b.age) + def compare(a: Person, b: Person): Int = a.age.compare(b.age) } object NameOrdering extends Ordering[Person] { - def compare(a:Person, b:Person): Int = - implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) + def compare(a: Person, b: Person): Int = + implicitly[Ordering[Tuple2[String, String]]].compare((a.last, a.first), (b.last, b.first)) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index 64b1c24c47168..54fc914722b46 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -26,7 +26,7 @@ class SortingSuite extends FunSuite with SharedSparkContext with Matchers with L test("sortByKey") { val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + assert(pairs.sortByKey().collect() === Array((0, 0), (1, 0), (2, 0), (3, 0))) } test("large array") { @@ -136,7 +136,7 @@ class SortingSuite extends FunSuite with SharedSparkContext with Matchers with L test("get a range of elements in an array not partitioned by a range partitioner") { val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) - val pairs = sc.parallelize(pairArr,10) + val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index ae3339d80f9c6..21eb71d9acfbd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -42,7 +42,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } override def afterAll(): Unit = { - if(env != null) { + if (env != null) { env.shutdown() } } @@ -75,7 +75,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -338,7 +338,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("call receive in sequence") { // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it - for(i <- 0 until 100) { + for (i <- 0 until 100) { @volatile var result = 0 val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { override val rpcEnv = env diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index f77661ccbd1c5..3821166386fa6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -26,8 +26,8 @@ class CoarseGrainedSchedulerBackendSuite extends FunSuite with LocalSparkContext test("serialized task larger than akka frame size") { val conf = new SparkConf - conf.set("spark.akka.frameSize","1") - conf.set("spark.default.parallelism","1") + conf.set("spark.akka.frameSize", "1") + conf.set("spark.default.parallelism", "1") sc = new SparkContext("local-cluster[2 , 1 , 512]", "test", conf) val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 46642236e454a..eea7a600841cc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -375,7 +375,7 @@ class DAGSchedulerSuite (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. failAfter(10 seconds) { - val preferredLocs = scheduler.getPreferredLocs(rdd,0) + val preferredLocs = scheduler.getPreferredLocs(rdd, 0) // No preferred locations are returned. assert(preferredLocs.length === 0) } @@ -634,8 +634,8 @@ class DAGSchedulerSuite val listener1 = new FailureRecordingJobListener() val listener2 = new FailureRecordingJobListener() - submit(reduceRdd1, Array(0, 1), listener=listener1) - submit(reduceRdd2, Array(0, 1), listener=listener2) + submit(reduceRdd1, Array(0, 1), listener = listener1) + submit(reduceRdd2, Array(0, 1), listener = listener2) val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index e8f461e2f56c9..456451b676bed 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -97,9 +97,9 @@ class PoolSuite extends FunSuite with LocalSparkContext { assert(rootPool.getSchedulableByName("3").weight === 1) val properties1 = new Properties() - properties1.setProperty("spark.scheduler.pool","1") + properties1.setProperty("spark.scheduler.pool", "1") val properties2 = new Properties() - properties2.setProperty("spark.scheduler.pool","2") + properties2.setProperty("spark.scheduler.pool", "2") val taskSetManager10 = createTaskSetManager(0, 1, taskScheduler) val taskSetManager11 = createTaskSetManager(1, 1, taskScheduler) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index ef50bc9438f95..14c0172fa96ab 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -109,7 +109,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check((1, 1)) check((1, 1L)) check((1L, 1)) - check((1L, 1L)) + check((1L, 1L)) check((1.0, 1)) check((1, 1.0)) check((1.0, 1.0)) @@ -147,7 +147,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one",2->"two",3->"three"))) + mutable.HashMap(1->"one", 2->"two", 3->"three"))) } test("ranges") { diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index 433fd6bb4a11d..673948d84d82b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -66,18 +66,18 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex } private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.map(y=>uc.op(y)) + x.map(y => uc.op(y)) private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.flatMap(y=>Seq(uc.op(y))) + x.flatMap(y => Seq(uc.op(y))) private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.filter(y=>uc.pred(y)) + x.filter(y => uc.pred(y)) private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapPartitions(_.map(y=>uc.op(y))) + x.mapPartitions(_.map(y => uc.op(y))) private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) + x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y))) } diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 86fcf447287f7..c1e0a29a34bb1 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -32,16 +32,19 @@ class TestSerializer extends Serializer { class TestSerializerInstance extends SerializerInstance { - override def serialize[T: ClassTag](t: T): ByteBuffer = ??? + override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - override def serializeStream(s: OutputStream): SerializationStream = ??? + override def serializeStream(s: OutputStream): SerializationStream = + throw new UnsupportedOperationException override def deserializeStream(s: InputStream): TestDeserializationStream = new TestDeserializationStream - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ??? + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException } diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala index bcf138b5ee6d0..47341b74e9c0f 100644 --- a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -59,10 +59,10 @@ class FlatmapIteratorSuite extends FunSuite with LocalSparkContext { .set("spark.serializer.objectStreamReset", "10") sc = new SparkContext(sconf) val expand_size = 500 - val data = sc.parallelize(Seq(1,2)). + val data = sc.parallelize(Seq(1, 2)). flatMap(x => Stream.range(1, expand_size). - map(y => "%d: string test %d".format(y,x))) - var persisted = data.persist(StorageLevel.MEMORY_ONLY_SER) + map(y => "%d: string test %d".format(y, x))) + val persisted = data.persist(StorageLevel.MEMORY_ONLY_SER) assert(persisted.filter(_.startsWith("1:")).count()===2) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index b6f5accef0cef..a727a43f44dfc 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -483,11 +483,11 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before val jobsJson = getJson(sc.ui.get, "jobs") jobsJson.children.size should be (expJobInfo.size) for { - (job @ JObject(_),idx) <- jobsJson.children.zipWithIndex + (job @ JObject(_), idx) <- jobsJson.children.zipWithIndex id = (job \ "jobId").extract[String] name = (job \ "name").extract[String] } { - withClue(s"idx = $idx; id = $id; name = ${name.substring(0,20)}") { + withClue(s"idx = $idx; id = $id; name = ${name.substring(0, 20)}") { id should be (expJobInfo(idx)._1) name should include (expJobInfo(idx)._2) } @@ -540,12 +540,12 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before goToUi(sc, "/stages/stage/?id=12&attempt=0") find("no-info").get.text should be ("No information to display for Stage 12 (Attempt 0)") - val badStage = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get,"stages/12/0")) + val badStage = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get, "stages/12/0")) badStage._1 should be (HttpServletResponse.SC_NOT_FOUND) badStage._2 should be (None) badStage._3 should be (Some("unknown stage: 12")) - val badAttempt = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get,"stages/19/15")) + val badAttempt = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get, "stages/19/15")) badAttempt._1 should be (HttpServletResponse.SC_NOT_FOUND) badAttempt._2 should be (None) badAttempt._3 should be (Some("unknown attempt for stage 19. Found attempts: [0]")) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 7b38e6d9473e1..8778042e34657 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -169,7 +169,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { test("verify StorageTab contains all cached rdds") { val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly, Seq(4)) - val rddInfo1 = new RDDInfo(1, "rdd1", 1 ,memOnly, Seq(4)) + val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly, Seq(4)) val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), Seq.empty, "details") val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details") val taskMetrics0 = new TaskMetrics diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index bec79fc4dc8f7..ccdb3f571429d 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -138,7 +138,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val slaveRpcEnv =RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) slaveTracker.trackerEndpoint = slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 61152c29a681f..afa5cdc819746 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -551,7 +551,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { test("fetch hcfs dir") { val tempDir = Utils.createTempDir() val sourceDir = new File(tempDir, "source-dir") - val innerSourceDir = Utils.createTempDir(root=sourceDir.getPath) + val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") Files.write("some text", sourceFile, UTF_8) diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index b85a409a4b2e9..ffc206991906a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -94,7 +94,7 @@ class BitSetSuite extends FunSuite { test( "xor len(bitsetX) > len(bitsetY)" ) { val setBitsX = Seq( 0, 1, 3, 37, 38, 41, 85) - val setBitsY = Seq( 0, 2, 3, 37, 41 ) + val setBitsY = Seq( 0, 2, 3, 37, 41) val bitsetX = new BitSet(100) setBitsX.foreach( i => bitsetX.set(i)) val bitsetY = new BitSet(60) From b069ad23d9b6cbfb3a8bf245547add4816669075 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 20:17:16 -0700 Subject: [PATCH 054/254] [SPARK-7927] whitespace fixes for GraphX. So we can enable a whitespace enforcement rule in the style checker to save code review time. Author: Reynold Xin Closes #6474 from rxin/whitespace-graphx and squashes the following commits: 4d3cd26 [Reynold Xin] Fixed tests. 869dde4 [Reynold Xin] [SPARK-7927] whitespace fixes for GraphX. --- .../org/apache/spark/graphx/EdgeDirection.scala | 4 ++-- .../org/apache/spark/graphx/EdgeTriplet.scala | 2 +- .../scala/org/apache/spark/graphx/Graph.scala | 2 +- .../scala/org/apache/spark/graphx/GraphOps.scala | 10 +++++----- .../scala/org/apache/spark/graphx/Pregel.scala | 8 ++++---- .../apache/spark/graphx/impl/EdgePartition.scala | 4 ++-- .../org/apache/spark/graphx/lib/PageRank.scala | 8 ++++---- .../org/apache/spark/graphx/lib/SVDPlusPlus.scala | 2 +- .../apache/spark/graphx/lib/TriangleCount.scala | 4 ++-- .../spark/graphx/util/GraphGenerators.scala | 9 +++++---- .../org/apache/spark/graphx/GraphOpsSuite.scala | 6 +++--- .../org/apache/spark/graphx/GraphSuite.scala | 6 +++--- .../graphx/lib/ConnectedComponentsSuite.scala | 15 +++++++++------ .../apache/spark/graphx/lib/PageRankSuite.scala | 14 +++++++------- .../spark/graphx/lib/TriangleCountSuite.scala | 2 +- 15 files changed, 50 insertions(+), 46 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 058c8c8aa1b24..ce1054ed92ba1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -26,8 +26,8 @@ class EdgeDirection private (private val name: String) extends Serializable { * out becomes in and both and either remain the same. */ def reverse: EdgeDirection = this match { - case EdgeDirection.In => EdgeDirection.Out - case EdgeDirection.Out => EdgeDirection.In + case EdgeDirection.In => EdgeDirection.Out + case EdgeDirection.Out => EdgeDirection.In case EdgeDirection.Either => EdgeDirection.Either case EdgeDirection.Both => EdgeDirection.Both } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index c8790cac3d8a0..65f82429d2029 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -37,7 +37,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { /** * Set the edge properties of this triplet. */ - protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = { + protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = { srcId = other.srcId dstId = other.dstId attr = other.attr diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 36dc7b0f86c89..db73a8abc5733 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -316,7 +316,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * satisfy the predicates */ def subgraph( - epred: EdgeTriplet[VD,ED] => Boolean = (x => true), + epred: EdgeTriplet[VD, ED] => Boolean = (x => true), vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 7edd627b20918..9451ff1e5c0e2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -124,18 +124,18 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { val nbrs = edgeDirection match { case EdgeDirection.Either => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => { ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) }, (a, b) => a ++ b, TripletFields.All) case EdgeDirection.In => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), (a, b) => a ++ b, TripletFields.Src) case EdgeDirection.Out => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), (a, b) => a ++ b, TripletFields.Dst) case EdgeDirection.Both => @@ -253,7 +253,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def filter[VD2: ClassTag, ED2: ClassTag]( preprocess: Graph[VD, ED] => Graph[VD2, ED2], epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true, - vpred: (VertexId, VD2) => Boolean = (v:VertexId, d:VD2) => true): Graph[VD, ED] = { + vpred: (VertexId, VD2) => Boolean = (v: VertexId, d: VD2) => true): Graph[VD, ED] = { graph.mask(preprocess(graph).subgraph(epred, vpred)) } @@ -356,7 +356,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)( vprog: (VertexId, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 01b013ff716fc..cfcf7244eaed5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -147,10 +147,10 @@ object Pregel extends Logging { logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs - oldMessages.unpersist(blocking=false) - newVerts.unpersist(blocking=false) - prevG.unpersistVertices(blocking=false) - prevG.edges.unpersist(blocking=false) + oldMessages.unpersist(blocking = false) + newVerts.unpersist(blocking = false) + prevG.unpersistVertices(blocking = false) + prevG.edges.unpersist(blocking = false) // count the iteration i += 1 } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index c561570809253..ab021a252eb8a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -156,8 +156,8 @@ class EdgePartition[ val size = data.size var i = 0 while (i < size) { - edge.srcId = srcIds(i) - edge.dstId = dstIds(i) + edge.srcId = srcIds(i) + edge.dstId = dstIds(i) edge.attr = data(i) newData(i) = f(edge) i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index bc974b2f04e70..8c0a461e99fa4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -116,7 +116,7 @@ object PageRank extends Logging { val personalized = srcId isDefined val src: VertexId = srcId.getOrElse(-1L) - def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 } + def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } var iteration = 0 var prevRankGraph: Graph[Double, Double] = null @@ -133,13 +133,13 @@ object PageRank extends Logging { // edge partitions. prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId ,id: VertexId) => resetProb * delta(src,id) + (src: VertexId , id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -243,7 +243,7 @@ object PageRank extends Logging { // Execute a dynamic version of Pregel. val vp = if (personalized) { - (id: VertexId, attr: (Double, Double),msgSum: Double) => + (id: VertexId, attr: (Double, Double), msgSum: Double) => personalizedVertexProgram(id, attr, msgSum) } else { (id: VertexId, attr: (Double, Double), msgSum: Double) => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 3b0e1628d86b5..9cb24ed080e1c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -210,7 +210,7 @@ object SVDPlusPlus { /** * Forces materialization of a Graph by count()ing its RDDs. */ - private def materialize(g: Graph[_,_]): Unit = { + private def materialize(g: Graph[_, _]): Unit = { g.vertices.count() g.edges.count() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index daf162085e3e4..a5d598053f9ca 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ */ object TriangleCount { - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Remove redundant edges val g = graph.groupEdges((a, b) => a).cache() @@ -49,7 +49,7 @@ object TriangleCount { var i = 0 while (i < nbrs.size) { // prevent self cycle - if(nbrs(i) != vid) { + if (nbrs(i) != vid) { set.add(nbrs(i)) } i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 2d6a825b61726..9591c4e9b8f4e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -243,14 +243,15 @@ object GraphGenerators { * @return A graph containing vertices with the row and column ids * as their attributes and edge values as 1.0. */ - def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = { + def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = { // Convert row column address into vertex ids (row major order) def sub2ind(r: Int, c: Int): VertexId = r * cols + c - val vertices: RDD[(VertexId, (Int,Int))] = - sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) ) + val vertices: RDD[(VertexId, (Int, Int))] = sc.parallelize(0 until rows).flatMap { r => + (0 until cols).map( c => (sub2ind(r, c), (r, c)) ) + } val edges: RDD[Edge[Double]] = - vertices.flatMap{ case (vid, (r,c)) => + vertices.flatMap{ case (vid, (r, c)) => (if (r + 1 < rows) { Seq( (sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++ (if (c + 1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty }) }.map{ case (src, dst) => Edge(src, dst, 1.0) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 9bc8007ce49cd..68fe83739e399 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -59,7 +59,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test ("filter") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val filteredGraph = graph.filter( @@ -67,11 +67,11 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { val degrees: VertexRDD[Int] = graph.outDegrees graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} }, - vpred = (vid: VertexId, deg:Int) => deg > 0 + vpred = (vid: VertexId, deg: Int) => deg > 0 ).cache() val v = filteredGraph.vertices.collect().toSet - assert(v === Set((0,0))) + assert(v === Set((0, 0))) // the map is necessary because of object-reuse in the edge iterator val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index a570e4ed75fc3..2b1d8e47326f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -248,7 +248,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { test("mask") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() @@ -260,11 +260,11 @@ class GraphSuite extends FunSuite with LocalSparkContext { val projectedGraph = graph.mask(subgraph) val v = projectedGraph.vertices.collect().toSet - assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5))) + assert(v === Set((0, 0), (1, 1), (2, 2), (4, 4), (5, 5))) // the map is necessary because of object-reuse in the edge iterator val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet - assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5))) + assert(e === Set(Edge(0, 1, 1), Edge(0, 2, 2), Edge(0, 5, 5))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 4cc30a96408f8..accccfc232cd3 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -52,13 +52,16 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - if(id < 10) { assert(cc === 0) } - else { assert(cc === 10) } + if (id < 10) { + assert(cc === 0) + } else { + assert(cc === 10) + } } val ccMap = vertices.toMap for (id <- 0 until 20) { @@ -75,7 +78,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() @@ -106,9 +109,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { (4L, ("peter", "student")))) // Create an RDD for edges val relationships: RDD[Edge[String]] = - sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), + sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"), - Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) + Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) // Edges are: // 2 ---> 5 ---> 3 // | \ diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 3f3c9dfd7b3dd..39c6ace912b00 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -31,14 +31,14 @@ object GridPageRank { def sub2ind(r: Int, c: Int): Int = r * nCols + c // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { - val ind = sub2ind(r,c) + val ind = sub2ind(r, c) if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r + 1,c)) += ind + inNbrs(sub2ind(r + 1, c)) += ind } if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c + 1)) += ind + inNbrs(sub2ind(r, c + 1)) += ind } } // compute the pagerank @@ -99,8 +99,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val resetProb = 0.15 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb) + val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) .vertices.cache() // Static PageRank should only take 2 iterations to converge @@ -117,7 +117,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } assert(staticErrors.sum === 0) - val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache() + val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ -162,7 +162,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 293c7f3ba4c21..79bf4e6cd18ee 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -58,7 +58,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext { val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(0L -> -1L, -1L -> -2L, -2L -> 0L) - val revTriangles = triangles.map { case (a,b) => (b,a) } + val revTriangles = triangles.map { case (a, b) => (b, a) } val rawEdges = sc.parallelize(triangles ++ revTriangles, 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() From c45d58c143d68cb807186acc9d060daa8549dd5c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 28 May 2015 21:20:54 -0700 Subject: [PATCH 055/254] [SPARK-7926] [PYSPARK] use the official Pyrolite release Switch to the official Pyrolite release from the one published under `org.spark-project`. Thanks irmen for making the releases on Maven Central. We didn't upgrade to 4.6 because we don't have enough time for QA. I excludes `serpent` from its dependencies because we don't use it in Spark. ~~~ [info] +-net.jpountz.lz4:lz4:1.3.0 [info] +-net.razorvine:pyrolite:4.4 [info] +-net.sf.py4j:py4j:0.8.2.1 ~~~ davies Author: Xiangrui Meng Closes #6472 from mengxr/SPARK-7926 and squashes the following commits: 7b3c6bf [Xiangrui Meng] use the official Pyrolite release --- core/pom.xml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/pom.xml b/core/pom.xml index bfa49d0d6dc25..e58efe495e36d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -377,9 +377,15 @@ test - org.spark-project + net.razorvine pyrolite 4.4 + + + net.razorvine + serpent + + net.sf.py4j From 834e699524583a7ebfe9e83b3900ec503150deca Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 28 May 2015 21:26:43 -0700 Subject: [PATCH 056/254] [MINOR] fix RegressionEvaluator doc `make clean html` under `python/doc` returns ~~~ /Users/meng/src/spark/python/pyspark/ml/evaluation.py:docstring of pyspark.ml.evaluation.RegressionEvaluator.setParams:3: WARNING: Definition list ends without a blank line; unexpected unindent. ~~~ harsha2010 Author: Xiangrui Meng Closes #6469 from mengxr/fix-regression-evaluator-doc and squashes the following commits: 91e2dad [Xiangrui Meng] fix RegressionEvaluator doc --- python/pyspark/ml/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 23c37167b3711..d8ddb78c6d639 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -205,7 +205,7 @@ def getMetricName(self): def setParams(self, predictionCol="prediction", labelCol="label", metricName="rmse"): """ - setParams(self, predictionCol="prediction", labelCol="label", + setParams(self, predictionCol="prediction", labelCol="label", \ metricName="rmse") Sets params for regression evaluator. """ From 04ddcd4db7801abefa9c9effe5d88413b29d713b Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 28 May 2015 22:09:49 -0700 Subject: [PATCH 057/254] [SPARK-7932] Fix misleading scheduler delay visualization The existing code rounds down to the nearest percent when computing the proportion of a task's time that was spent on each phase of execution, and then computes the scheduler delay proportion as 100 - sum(all other proportions). As a result, a few extra percent can end up in the scheduler delay. This commit eliminates the rounding so that the time visualizations correspond properly to the real times. sarutak If you could take a look at this, that would be great! Not sure if there's a good reason to round here that I missed. cc shivaram Author: Kay Ousterhout Closes #6484 from kayousterhout/SPARK-7932 and squashes the following commits: 1723cc4 [Kay Ousterhout] [SPARK-7932] Fix misleading scheduler delay visualization --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 31e2e7fba9783..b83a49f79c8a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -527,7 +527,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { minLaunchTime = launchTime.min(minLaunchTime) maxFinishTime = finishTime.max(maxFinishTime) - def toProportion(time: Long) = (time.toDouble / totalExecutionTime * 100).toLong + def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 val metricsOpt = taskUIData.taskMetrics val shuffleReadTime = From cd3d9a5c0c3e77098a72c85dffe4a27737009ae7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 28 May 2015 22:28:13 -0700 Subject: [PATCH 058/254] [SPARK-7930] [CORE] [STREAMING] Fixed shutdown hook priorities Shutdown hook for temp directories had priority 100 while SparkContext was 50. So the local root directory was deleted before SparkContext was shutdown. This leads to scary errors on running jobs, at the time of shutdown. This is especially a problem when running streaming examples, where Ctrl-C is the only way to shutdown. The fix in this PR is to make the temp directory shutdown priority lower than SparkContext, so that the temp dirs are the last thing to get deleted, after the SparkContext has been shut down. Also, the DiskBlockManager shutdown priority is change from default 100 to temp_dir_prio + 1, so that it gets invoked just before all temp dirs are cleared. Author: Tathagata Das Closes #6482 from tdas/SPARK-7930 and squashes the following commits: d7cbeb5 [Tathagata Das] Removed unnecessary line 1514d0b [Tathagata Das] Fixed shutdown hook priorities --- .../org/apache/spark/storage/DiskBlockManager.scala | 4 ++-- .../src/main/scala/org/apache/spark/util/Utils.scala | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2a4447705fa65..d441a4d31b954 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -139,8 +139,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook { () => - logDebug("Shutdown hook called") + Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 763d4db690187..693e1a0a3d5f0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -73,6 +73,13 @@ private[spark] object Utils extends Logging { */ val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. + */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null @@ -189,10 +196,11 @@ private[spark] object Utils extends Logging { private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook { () => - logDebug("Shutdown hook called") + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") shutdownDeletePaths.foreach { dirPath => try { + logInfo("Deleting directory " + dirPath) Utils.deleteRecursively(new File(dirPath)) } catch { case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) From db9513789756da4f16bb1fe8cf1d19500f231f54 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 28 May 2015 22:38:38 -0700 Subject: [PATCH 059/254] [SPARK-7922] [MLLIB] use DataFrames for user/item factors in ALSModel Expose user/item factors in DataFrames. This is to be more consistent with the pipeline API. It also helps maintain consistent APIs across languages. This PR also removed fitting params from `ALSModel`. coderxiang Author: Xiangrui Meng Closes #6468 from mengxr/SPARK-7922 and squashes the following commits: 7bfb1d5 [Xiangrui Meng] update ALSModel in PySpark 1ba5607 [Xiangrui Meng] use DataFrames for user/item factors in ALS --- .../apache/spark/ml/recommendation/ALS.scala | 101 ++++++++++-------- python/pyspark/ml/recommendation.py | 30 +++++- python/pyspark/mllib/common.py | 5 +- 3 files changed, 89 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 900b637ff8ad4..df009d855ecbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -35,21 +35,46 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom +/** + * Common params for ALS and ALSModel. + */ +private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { + /** + * Param for the column name for user ids. + * Default: "user" + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids") + + /** @group getParam */ + def getUserCol: String = $(userCol) + + /** + * Param for the column name for item ids. + * Default: "item" + * @group param + */ + val itemCol = new Param[String](this, "itemCol", "column name for item ids") + + /** @group getParam */ + def getItemCol: String = $(itemCol) +} + /** * Common params for ALS. */ -private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam +private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval with HasSeed { /** @@ -105,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** @group getParam */ def getAlpha: Double = $(alpha) - /** - * Param for the column name for user ids. - * Default: "user" - * @group param - */ - val userCol = new Param[String](this, "userCol", "column name for user ids") - - /** @group getParam */ - def getUserCol: String = $(userCol) - - /** - * Param for the column name for item ids. - * Default: "item" - * @group param - */ - val itemCol = new Param[String](this, "itemCol", "column name for item ids") - - /** @group getParam */ - def getItemCol: String = $(itemCol) - /** * Param for the column name for ratings. * Default: "rating" @@ -156,55 +161,60 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - require(schema($(userCol)).dataType == IntegerType) - require(schema($(itemCol)).dataType== IntegerType) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType require(ratingType == FloatType || ratingType == DoubleType) - val predictionColName = $(predictionCol) - require(!schema.fieldNames.contains(predictionColName), - s"Prediction column $predictionColName already exists.") - val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false) - StructType(newFields) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** * :: Experimental :: * Model fitted by ALS. + * + * @param rank rank of the matrix factorization model + * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features` + * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ @Experimental class ALSModel private[ml] ( override val uid: String, - k: Int, - userFactors: RDD[(Int, Array[Float])], - itemFactors: RDD[(Int, Array[Float])]) - extends Model[ALSModel] with ALSParams { + val rank: Int, + @transient val userFactors: DataFrame, + @transient val itemFactors: DataFrame) + extends Model[ALSModel] with ALSModelParams { + + /** @group setParam */ + def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ + def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) override def transform(dataset: DataFrame): DataFrame = { - import dataset.sqlContext.implicits._ - val users = userFactors.toDF("id", "features") - val items = itemFactors.toDF("id", "features") - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { - blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } dataset - .join(users, dataset($(userCol)) === users("id"), "left") - .join(items, dataset($(itemCol)) === items("id"), "left") - .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol))) + .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") + .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .select(dataset("*"), + predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } @@ -299,6 +309,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def fit(dataset: DataFrame): ALSModel = { + import dataset.sqlContext.implicits._ val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), col($(ratingCol)).cast(FloatType)) @@ -310,7 +321,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), alpha = $(alpha), nonnegative = $(nonnegative), checkpointInterval = $(checkpointInterval), seed = $(seed)) - val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this) + val userDF = userFactors.toDF("id", "features") + val itemDF = itemFactors.toDF("id", "features") + val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) copyValues(model) } diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b3e0dd7abf681..b06099ac0aee6 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha indicated user preferences rather than explicit ratings given to items. + >>> df = sqlContext.createDataFrame( + ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ... ["user", "item", "rating"]) >>> als = ALS(rank=10, maxIter=5) >>> model = als.fit(df) + >>> model.rank + 10 + >>> model.userFactors.orderBy("id").collect() + [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)] >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] @@ -260,6 +267,27 @@ class ALSModel(JavaModel): Model fitted by ALS. """ + @property + def rank(self): + """rank of the matrix factorization model""" + return self._call_java("rank") + + @property + def userFactors(self): + """ + a DataFrame that stores user factors in two columns: `id` and + `features` + """ + return self._call_java("userFactors") + + @property + def itemFactors(self): + """ + a DataFrame that stores item factors in two columns: `id` and + `features` + """ + return self._call_java("itemFactors") + if __name__ == "__main__": import doctest @@ -272,8 +300,6 @@ class ALSModel(JavaModel): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), - (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"]) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() if failure_count: diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index ba6058978880a..855e85f57155e 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -27,7 +27,7 @@ from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer - +from pyspark.sql import DataFrame, SQLContext # Hack for support float('inf') in Py4j _old_smart_decode = py4j.protocol.smart_decode @@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) + if clsName == 'DataFrame': + return DataFrame(r, SQLContext(sc)) + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): From e714ecf277a7412ea8263662977fe3ad1f794975 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 28 May 2015 22:39:21 -0700 Subject: [PATCH 060/254] [SPARK-7931] [STREAMING] Do not restart receiver when stopped Attempts to restart the socket receiver when it is supposed to be stopped causes undesirable error messages. Author: Tathagata Das Closes #6483 from tdas/SPARK-7931 and squashes the following commits: 09aeee1 [Tathagata Das] Do not restart receiver when stopped --- .../spark/streaming/dstream/SocketInputDStream.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 8b72bcf20653d..96e0a9c1a88f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import scala.util.control.NonFatal + import org.apache.spark.streaming.StreamingContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.NextIterator @@ -74,13 +76,16 @@ class SocketReceiver[T: ClassTag]( while(!isStopped && iterator.hasNext) { store(iterator.next) } + if (!isStopped()) { + restart("Socket data stream had no more data") + } logInfo("Stopped receiving") - restart("Retrying connecting to " + host + ":" + port) } catch { case e: java.net.ConnectException => restart("Error connecting to " + host + ":" + port, e) - case t: Throwable => - restart("Error receiving data", t) + case NonFatal(e) => + logWarning("Error receiving data", e) + restart("Error receiving data", e) } finally { if (socket != null) { socket.close() From 36067ce398e2949c2f122625e67fd5497febdee6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 28 May 2015 22:48:02 -0700 Subject: [PATCH 061/254] [HOTFIX] Minor style fix from last commit --- .../apache/spark/streaming/dstream/SocketInputDStream.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 96e0a9c1a88f0..5ce5b7aae6e69 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -78,8 +78,9 @@ class SocketReceiver[T: ClassTag]( } if (!isStopped()) { restart("Socket data stream had no more data") + } else { + logInfo("Stopped receiving") } - logInfo("Stopped receiving") } catch { case e: java.net.ConnectException => restart("Error connecting to " + host + ":" + port, e) From 97a60cf75d1fed654953eccedd04f3442389c5ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 May 2015 23:00:02 -0700 Subject: [PATCH 062/254] [SPARK-7929] Turn whitespace checker on for more token types. This is the last batch of changes to complete SPARK-7929. Previous related PRs: https://github.com/apache/spark/pull/6480 https://github.com/apache/spark/pull/6478 https://github.com/apache/spark/pull/6477 https://github.com/apache/spark/pull/6476 https://github.com/apache/spark/pull/6475 https://github.com/apache/spark/pull/6474 https://github.com/apache/spark/pull/6473 Author: Reynold Xin Closes #6487 from rxin/whitespace-lint and squashes the following commits: b33d43d [Reynold Xin] [SPARK-7929] Turn whitespace checker on for more token types. --- .../flume/sink/TransactionProcessor.scala | 2 +- .../streaming/flume/EventTransformer.scala | 2 +- .../spark/streaming/kafka/KafkaRDDSuite.scala | 2 +- .../streaming/mqtt/MQTTInputDStream.scala | 14 +------------- .../streaming/KinesisWordCountASL.scala | 2 +- .../spark/streaming/kinesis/KinesisUtils.scala | 4 ++-- scalastyle-config.xml | 13 ++++++++++++- .../spark/sql/hive/HiveInspectorSuite.scala | 12 ++++++------ .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../spark/sql/hive/ListTablesSuite.scala | 2 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 6 +++--- .../hive/execution/HiveComparisonTest.scala | 4 ++-- .../hive/execution/HiveResolutionSuite.scala | 6 +++--- .../hive/execution/HiveTableScanSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 4 ++-- .../org/apache/spark/deploy/yarn/Client.scala | 6 +++--- .../yarn/ClientDistributedCacheManager.scala | 18 +++++++++--------- .../apache/spark/deploy/yarn/ClientSuite.scala | 2 +- 19 files changed, 52 insertions(+), 53 deletions(-) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index ea45b14294df9..7ad43b1d7b0a0 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -143,7 +143,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg(msg) } else { // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("",seqNum, events) + eventBatch = new EventBatch("", seqNum, events) } }) } catch { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index dc629df4f4ac2..65c49c131518b 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k,v) <- headers) { + for ((k, v) <- headers) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 39c3fb448ff57..3c875cb766513 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -65,7 +65,7 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) - val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 40f5f18547236..7c2f18cb35bda 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,22 +17,10 @@ package org.apache.spark.streaming.mqtt -import java.io.IOException -import java.util.concurrent.Executors -import java.util.Properties - -import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.MqttTopic import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.storage.StorageLevel @@ -87,7 +75,7 @@ class MQTTReceiver( // Handles Mqtt message override def messageArrived(topic: String, message: MqttMessage) { - store(new String(message.getPayload(),"utf-8")) + store(new String(message.getPayload(), "utf-8")) } override def deliveryComplete(token: IMqttDeliveryToken) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index df77f4be9db1d..97c3476049289 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -208,7 +208,7 @@ object KinesisWordProducerASL { recordsPerSecond: Int, wordsPerRecord: Int): Seq[(String, Int)] = { - val randomWords = List("spark","you","are","my","father") + val randomWords = List("spark", "you", "are", "my", "father") val totals = scala.collection.mutable.Map[String, Int]() // Create the low-level Kinesis Client from the AWS Java SDK. diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2531aebe7813c..e5acab50181e1 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -55,7 +55,7 @@ object KinesisUtils { */ def createStream( ssc: StreamingContext, - kinesisAppName: String, + kinesisAppName: String, streamName: String, endpointUrl: String, regionName: String, @@ -102,7 +102,7 @@ object KinesisUtils { */ def createStream( ssc: StreamingContext, - kinesisAppName: String, + kinesisAppName: String, streamName: String, endpointUrl: String, regionName: String, diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7168d5b2a8e26..68d980b610c00 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -51,8 +51,8 @@ - + @@ -142,4 +142,15 @@ + + + ARROW, EQUALS + + + + + ARROW, EQUALS, COMMA, COLON, IF, WHILE, FOR + + + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2a7374cc172b7..80c2d32bf70d7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -78,10 +78,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(java.sql.Date.valueOf("2014-09-23")) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: - Literal(Array[Byte](1,2,3)) :: - Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal.create(Row(1,2.0d,3.0f), + Literal(Array[Byte](1, 2, 3)) :: + Literal.create(Seq[Int](1, 2, 3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1 -> 2, 2 -> 1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1, 2.0d, 3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -111,8 +111,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index acf2f7da30188..9cc4685499f19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -160,7 +160,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index e12a6c21ccac4..1c15997ea8e6d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -29,7 +29,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { import org.apache.spark.sql.hive.test.TestHive.implicits._ val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 85b6bc93d7122..8245047626d57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -26,9 +26,9 @@ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { test("UDF case insensitive") { - udf.register("random0", () => { Math.random()}) - udf.register("RANDOM1", () => { Math.random()}) - udf.register("strlenScala", (_: String).length + (_:Int)) + udf.register("random0", () => { Math.random() }) + udf.register("RANDOM1", () => { Math.random() }) + udf.register("strlenScala", (_: String).length + (_: Int)) assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 9c056e493bfde..55e5551b63818 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -273,7 +273,7 @@ abstract class HiveComparisonTest } val hiveCacheFiles = queryList.zipWithIndex.map { - case (queryString, i) => + case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" new File(answerCache, cachedAnswerName) } @@ -304,7 +304,7 @@ abstract class HiveComparisonTest // other DDL has not been executed yet. hiveQueries.foreach(_.logical) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { - case ((queryString, i), hiveQuery, cachedAnswerFile)=> + case ((queryString, i), hiveQuery, cachedAnswerFile) => try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 3dfa6e72e1242..b08db6de2d2f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -77,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") @@ -88,14 +88,14 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index ab53c6309e089..0ba4d11478211 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -76,7 +76,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() - === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) + === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 538e66125c5fe..27863a60145d7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -327,7 +327,7 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.ql.io.RCFileInputFormat", "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22","MANAGED_TABLE" + "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) if (HiveShim.version =="0.13.1") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 7851f38fd4056..e62ac909cbd0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -38,7 +38,7 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) -case class StructContainer(intStructField :Int, stringStructField: String) +case class StructContainer(intStructField: Int, stringStructField: String) case class ParquetDataWithComplexTypes( intField: Int, @@ -735,7 +735,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val filePath = new File(tempDir, "testParquet").getCanonicalPath val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath - val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") intercept[Throwable](df2.write.parquet(filePath)) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7e023f2d92578..234051eb7d3bb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1142,9 +1142,9 @@ object Client extends Logging { logDebug("HiveMetaStore configured in localmode") } } catch { - case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } - case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } - case e:Exception => { logError("Unexpected Exception " + e) + case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e: Exception => { logError("Unexpected Exception " + e) throw new RuntimeException("Unexpected exception", e) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index c592ecfdfce06..4ca6c903fcf12 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -95,13 +95,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -112,13 +112,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -160,7 +160,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def ancestorsHaveExecutePermissions( fs: FileSystem, path: Path, - statCache: Map[URI, FileStatus]): Boolean = { + statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { // the subdirs in the path should have execute permissions for others diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 508819e242a26..6da3e82acdb14 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -203,7 +203,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) From 23452be944463dae72a35b58551040556dd3aeb5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 29 May 2015 00:51:12 -0700 Subject: [PATCH 063/254] [SPARK-7912] [SPARK-7921] [MLLIB] Update OneHotEncoder to handle ML attributes and change includeFirst to dropLast This PR contains two major changes to `OneHotEncoder`: 1. more robust handling of ML attributes. If the input attribute is unknown, we look at the values to get the max category index 2. change `includeFirst` to `dropLast` and leave the default to `true`. There are couple benefits: a. consistent with other tutorials of one-hot encoding (or dummy coding) (e.g., http://www.ats.ucla.edu/stat/mult_pkg/faq/general/dummy.htm) b. keep the indices unmodified in the output vector. If we drop the first, all indices will be shifted by 1. c. If users use `StringIndex`, the last element is the least frequent one. Sorry for including two changes in one PR! I'll update the user guide in another PR. jkbradley sryza Author: Xiangrui Meng Closes #6466 from mengxr/SPARK-7912 and squashes the following commits: a280dca [Xiangrui Meng] fix tests d8f234d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7912 171b276 [Xiangrui Meng] mention the difference between our impl vs sklearn's 00dfd96 [Xiangrui Meng] update OneHotEncoder in Python 208ddad [Xiangrui Meng] update OneHotEncoder to handle ML attributes and change includeFirst to dropLast --- .../spark/ml/feature/OneHotEncoder.scala | 160 ++++++++++++------ .../spark/ml/feature/OneHotEncoderSuite.scala | 42 ++++- python/pyspark/ml/feature.py | 58 ++++--- 3 files changed, 176 insertions(+), 84 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index eb6ec49f854be..8f34878c8d329 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,94 +17,152 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} /** * :: Experimental :: - * A one-hot encoder that maps a column of label indices to a column of binary vectors, with - * at most a single one-value. By default, the binary vector has an element for each category, so - * with 5 categories, an input value of 2.0 would map to an output vector of - * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the - * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - * linearly dependent because they sum up to one. + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * @see [[StringIndexer]] for converting categorical values into category indices */ @Experimental -class OneHotEncoder(override val uid: String) - extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol { +class OneHotEncoder(override val uid: String) extends Transformer + with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("oneHot")) /** - * Whether to include a component in the encoded vectors for the first category, defaults to true. + * Whether to drop the last category in the encoded vector (default: true) * @group param */ - final val includeFirst: BooleanParam = - new BooleanParam(this, "includeFirst", "include first category") - setDefault(includeFirst -> true) - - private var categories: Array[String] = _ + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) /** @group setParam */ - def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value) + def setDropLast(value: Boolean): this.type = set(dropLast, value) /** @group setParam */ - override def setInputCol(value: String): this.type = set(inputCol, value) + def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ - override def setOutputCol(value: String): this.type = set(outputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - val inputFields = schema.fields + val is = "_is_" + val inputColName = $(inputCol) val outputColName = $(outputCol) - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val inputColAttr = Attribute.fromStructField(schema($(inputCol))) - categories = inputColAttr match { + SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + val inputFields = schema.fields + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + val inputAttr = Attribute.fromStructField(schema(inputColName)) + val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => - nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray) - case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1")) + if (nominal.values.isDefined) { + nominal.values.map(_.map(v => inputColName + is + v)) + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values.map(_.map(v => inputColName + is + v)) + } else { + Some(Array.tabulate(2)(i => inputColName + is + i)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column $inputColName cannot be numeric.") case _ => - throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal") + None // optimistic about unknown attributes } - val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray - val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues) - val outputFields = inputFields :+ attr.toStructField() + val filteredOutputAttrNames = outputAttrNames.map { names => + if ($(dropLast)) { + require(names.length > 1, + s"The input column $inputColName should have at least two distinct values.") + names.dropRight(1) + } else { + names + } + } + + val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { + val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup($(outputCol), attrs) + } else { + new AttributeGroup($(outputCol)) + } + + val outputFields = inputFields :+ outputAttrGroup.toStructField() StructType(outputFields) } - protected override def createTransformFunc(): (Double) => Vector = { - val first = $(includeFirst) - val vecLen = if (first) categories.length else categories.length - 1 + override def transform(dataset: DataFrame): DataFrame = { + // schema transformation + val is = "_is_" + val inputColName = $(inputCol) + val outputColName = $(outputCol) + val shouldDropLast = $(dropLast) + var outputAttrGroup = AttributeGroup.fromStructField( + transformSchema(dataset.schema)(outputColName)) + if (outputAttrGroup.size < 0) { + // If the number of attributes is unknown, we check the values from the input column. + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + .aggregate(0.0)( + (m, x) => { + assert(x >=0.0 && x == x.toInt, + s"Values from column $inputColName must be indices, but got $x.") + math.max(m, x) + }, + (m0, m1) => { + math.max(m0, m1) + } + ).toInt + 1 + val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames + val outputAttrs: Array[Attribute] = + filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) + outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + } + val metadata = outputAttrGroup.toMetadata() + + // data transformation + val size = outputAttrGroup.size val oneValue = Array(1.0) val emptyValues = Array[Double]() val emptyIndices = Array[Int]() - label: Double => { - val values = if (first || label != 0.0) oneValue else emptyValues - val indices = if (first) { - Array(label.toInt) - } else if (label != 0.0) { - Array(label.toInt - 1) + val encode = udf { label: Double => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) } else { - emptyIndices + Vectors.sparse(size, emptyIndices, emptyValues) } - Vectors.sparse(vecLen, indices, values) } - } - /** - * Returns the data type of the output column. - */ - protected def outputDataType: DataType = new VectorUDT + dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 056b9eda86bba..9018d0024d5f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.ml.feature import org.scalatest.FunSuite +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame - +import org.apache.spark.sql.functions.col class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { @@ -36,15 +37,16 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { indexer.transform(df) } - test("OneHotEncoder includeFirst = true") { + test("OneHotEncoder dropLast = false") { val transformed = stringIndexed() val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") + .setDropLast(false) val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -53,22 +55,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { assert(output === expected) } - test("OneHotEncoder includeFirst = false") { + test("OneHotEncoder dropLast = true") { val transformed = stringIndexed() val encoder = new OneHotEncoder() - .setIncludeFirst(false) .setInputCol("labelIndex") .setOutputCol("labelVec") val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0), - (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)) + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) assert(output === expected) } + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCol("size") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val encoder = new OneHotEncoder() + .setInputCol("index") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + } } diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b0479d9b074db..ddb33f427ac64 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -324,65 +324,73 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ - A one-hot encoder that maps a column of label indices to a column of binary vectors, with - at most a single one-value. By default, the binary vector has an element for each category, so - with 5 categories, an input value of 2.0 would map to an output vector of - (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so - the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - linearly dependent because they sum up to one. - - TODO: This method requires the use of StringIndexer first. Decouple them. + A one-hot encoder that maps a column of category indices to a + column of binary vectors, with at most a single one-value per row + that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to + an output vector of `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via + :py:attr:`dropLast`) because it makes the vector entries sum up to + one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + Note that this is different from scikit-learn's OneHotEncoder, + which keeps all categories. + The output vectors are sparse. + + .. seealso:: + + :py:class:`StringIndexer` for converting categorical values into + category indices >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) - >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features") + >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features") >>> encoder.transform(td).head().features - SparseVector(2, {}) + SparseVector(2, {0: 1.0}) >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs - SparseVector(2, {}) - >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"} + SparseVector(2, {0: 1.0}) + >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) """ # a placeholder to make it appear in the generated doc - includeFirst = Param(Params._dummy(), "includeFirst", "include first category") + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") @keyword_only - def __init__(self, includeFirst=True, inputCol=None, outputCol=None): + def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ __init__(self, includeFirst=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) - self.includeFirst = Param(self, "includeFirst", "include first category") - self._setDefault(includeFirst=True) + self.dropLast = Param(self, "dropLast", "whether to drop the last category") + self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, includeFirst=True, inputCol=None, outputCol=None): + def setParams(self, dropLast=True, inputCol=None, outputCol=None): """ - setParams(self, includeFirst=True, inputCol=None, outputCol=None) + setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def setIncludeFirst(self, value): + def setDropLast(self, value): """ - Sets the value of :py:attr:`includeFirst`. + Sets the value of :py:attr:`dropLast`. """ - self._paramMap[self.includeFirst] = value + self._paramMap[self.dropLast] = value return self - def getIncludeFirst(self): + def getDropLast(self): """ - Gets the value of includeFirst or its default value. + Gets the value of dropLast or its default value. """ - return self.getOrDefault(self.includeFirst) + return self.getOrDefault(self.dropLast) @inherit_doc From bf46580708e41a1d48ac091adbca8d82a4008699 Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Fri, 29 May 2015 05:14:43 -0400 Subject: [PATCH 064/254] [SPARK-7756] [CORE] Use testing cipher suites common to Oracle and IBM security providers Add alias names for supported cipher suites to the sample SSL configuration. The IBM JSSE provider reports its cipher suite with an SSL_ prefix, but accepts TLS_ prefixed suite names as an alias. However, Jetty filters the requested ciphers based on the provider's reported supported suites, so the TLS_ versions are never passed through to JSSE causing an SSL handshake failure. Author: Tim Ellison Closes #6282 from tellison/SSLFailure and squashes the following commits: 8de8a3e [Tim Ellison] Update SecurityManagerSuite with new expected suite names 96158b2 [Tim Ellison] Update the sample configs to use ciphers that are common to both the Oracle and IBM security providers. 705421b [Tim Ellison] Merge branch 'master' of github.com:tellison/spark into SSLFailure 68b9425 [Tim Ellison] Merge branch 'master' of https://github.com/apache/spark into SSLFailure b0c35f6 [Tim Ellison] [CORE] Add aliases used for cipher suites in IBM provider --- core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala | 4 ++-- .../test/scala/org/apache/spark/SecurityManagerSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 308b9ea17708d..1a099da2c6c8e 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -34,7 +34,7 @@ object SSLSampleConfigs { conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") conf.set("spark.ssl.protocol", "TLSv1") conf } @@ -48,7 +48,7 @@ object SSLSampleConfigs { conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") conf.set("spark.ssl.protocol", "TLSv1") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 62cb7649c0284..61571be44252a 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -147,7 +147,7 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -158,7 +158,7 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) } test("ssl off setup") { From 8db40f6711058c3c3bf67ceaaaffffcc25d67d19 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 29 May 2015 05:17:41 -0400 Subject: [PATCH 065/254] [SPARK-7863] [CORE] Create SimpleDateFormat for every SimpleDateParam instance because it's not thread-safe SimpleDateFormat is not thread-safe. This PR creates new `SimpleDateFormat` for each `SimpleDateParam` instance. Author: zsxwing Closes #6406 from zsxwing/SPARK-7863 and squashes the following commits: aeed4c1 [zsxwing] Rewrite SimpleDateParam 8cdd986 [zsxwing] Inline formats 9680a15 [zsxwing] Create SimpleDateFormat for each SimpleDateParam instance because it's not thread-safe --- .../spark/status/api/v1/SimpleDateParam.scala | 49 ++++++++----------- .../status/api/v1/SimpleDateParamSuite.scala | 5 ++ 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index cee29786c3019..0c71cd2382225 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -16,40 +16,33 @@ */ package org.apache.spark.status.api.v1 -import java.text.SimpleDateFormat +import java.text.{ParseException, SimpleDateFormat} import java.util.TimeZone import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status -import scala.util.Try - private[v1] class SimpleDateParam(val originalValue: String) { - val timestamp: Long = { - SimpleDateParam.formats.collectFirst { - case fmt if Try(fmt.parse(originalValue)).isSuccess => - fmt.parse(originalValue).getTime() - }.getOrElse( - throw new WebApplicationException( - Response - .status(Status.BAD_REQUEST) - .entity("Couldn't parse date: " + originalValue) - .build() - ) - ) - } -} -private[v1] object SimpleDateParam { - - val formats: Seq[SimpleDateFormat] = { - - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") - gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) - - Seq( - new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz"), - gmtDay - ) + val timestamp: Long = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + try { + format.parse(originalValue).getTime() + } catch { + case _: ParseException => + val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) + try { + gmtDay.parse(originalValue).getTime() + } catch { + case _: ParseException => + throw new WebApplicationException( + Response + .status(Status.BAD_REQUEST) + .entity("Couldn't parse date: " + originalValue) + .build() + ) + } + } } } diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala index 731d1f557ed33..183043bc05233 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.status.api.v1 +import javax.ws.rs.WebApplicationException + import org.scalatest.{Matchers, FunSuite} class SimpleDateParamSuite extends FunSuite with Matchers { @@ -24,6 +26,9 @@ class SimpleDateParamSuite extends FunSuite with Matchers { new SimpleDateParam("2015-02-20T23:21:17.190GMT").timestamp should be (1424474477190L) new SimpleDateParam("2015-02-20T17:21:17.190EST").timestamp should be (1424470877190L) new SimpleDateParam("2015-02-20").timestamp should be (1424390400000L) // GMT + intercept[WebApplicationException] { + new SimpleDateParam("invalid date") + } } } From a51b133de3c65a991ab105b6f020082080121b4c Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Fri, 29 May 2015 11:06:11 -0500 Subject: [PATCH 066/254] [SPARK-7524] [SPARK-7846] add configs for keytab and principal, pass these two configs with different way in different modes * As spark now supports long running service by updating tokens for namenode, but only accept parameters passed with "--k=v" format which is not very convinient. This patch add spark.* configs in properties file and system property. * --principal and --keytabl options are passed to client but when we started thrift server or spark-shell these two are also passed into the Main class (org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 and org.apache.spark.repl.Main). In these two main class, arguments passed in will be processed with some 3rd libraries, which will lead to some error: "Invalid option: --principal" or "Unrecgnised option: --principal". We should pass these command args in different forms, say system properties. Author: WangTaoTheTonic Closes #6051 from WangTaoTheTonic/SPARK-7524 and squashes the following commits: e65699a [WangTaoTheTonic] change logic to loadEnvironments ebd9ea0 [WangTaoTheTonic] merge master ecfe43a [WangTaoTheTonic] pass keytab and principal seperately in different mode 33a7f40 [WangTaoTheTonic] expand the use of the current configs 08bb4e8 [WangTaoTheTonic] fix wrong cite 73afa64 [WangTaoTheTonic] add configs for keytab and principal, move originals to internal --- .../org/apache/spark/deploy/SparkSubmit.scala | 8 ++++---- .../spark/deploy/SparkSubmitArguments.scala | 2 ++ docs/running-on-yarn.md | 16 ++++++++++++++++ .../deploy/yarn/AMDelegationTokenRenewer.scala | 14 ++++++++------ .../spark/deploy/yarn/ClientArguments.scala | 6 ++++++ 5 files changed, 36 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 92bb5059a0313..d1b32ea0778db 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -428,6 +428,8 @@ object SparkSubmit { OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), @@ -440,10 +442,8 @@ object SparkSubmit { OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), - - // Yarn client or cluster - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c0e4c771908b3..cc6a7bd9f4119 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -169,6 +169,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull + principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 9d55f435e80ad..96cf612c54fdd 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -242,6 +242,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes running against earlier versions, this property will be ignored. + + spark.yarn.keytab + (none) + + The full path to the file that contains the keytab for the principal specified above. + This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. + + + + spark.yarn.principal + (none) + + Principal to be used to login to KDC, while running on secure HDFS. + + # Launching Spark on YARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index aaae6f9734a85..77af46c192cc2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -60,8 +60,11 @@ private[yarn] class AMDelegationTokenRenewer( private val hadoopUtil = YarnSparkHadoopUtil.get - private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) - private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val daysToKeepFiles = + sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = + sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -121,7 +124,7 @@ private[yarn] class AMDelegationTokenRenewer( import scala.concurrent.duration._ try { val remoteFs = FileSystem.get(hadoopConf) - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( remoteFs, credentialsPath.getParent, @@ -160,7 +163,7 @@ private[yarn] class AMDelegationTokenRenewer( val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val dst = credentialsPath.getParent keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials @@ -186,8 +189,7 @@ private[yarn] class AMDelegationTokenRenewer( } val nextSuffix = lastCredentialsFileSuffix + 1 val tokenPathStr = - sparkConf.get("spark.yarn.credentials.file") + - SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix val tokenPath = new Path(tokenPathStr) val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 5653c9f14dc6d..9c7b1b3988082 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) numExecutors = initialNumExecutors } + principal = Option(principal) + .orElse(sparkConf.getOption("spark.yarn.principal")) + .orNull + keytab = Option(keytab) + .orElse(sparkConf.getOption("spark.yarn.keytab")) + .orNull } /** From e7b61775571ce7a06d044bc3a6055ff94c7477d6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 29 May 2015 10:43:34 -0700 Subject: [PATCH 067/254] [SPARK-7950] [SQL] Sets spark.sql.hive.version in HiveThriftServer2.startWithContext() When starting `HiveThriftServer2` via `startWithContext`, property `spark.sql.hive.version` isn't set. This causes Simba ODBC driver 1.0.8.1006 behaves differently and fails simple queries. Hive2 JDBC driver works fine in this case. Also, when starting the server with `start-thriftserver.sh`, both Hive2 JDBC driver and Simba ODBC driver works fine. Please refer to [SPARK-7950] [1] for details. [1]: https://issues.apache.org/jira/browse/SPARK-7950 Author: Cheng Lian Closes #6500 from liancheng/odbc-bugfix and squashes the following commits: 051e3a3 [Cheng Lian] Fixes import order 3a97376 [Cheng Lian] Sets spark.sql.hive.version in HiveThriftServer2.startWithContext() --- .../sql/hive/thriftserver/HiveThriftServer2.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3458b04bfba0f..94687eeda4179 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,23 +17,23 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.SQLConf -import org.apache.spark.{SparkContext, SparkConf, Logging} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener} import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Logging, SparkContext} /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -51,6 +51,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) + sqlContext.setConf("spark.sql.hive.version", HiveShim.version) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) From 4782e130400f16e77c8b7f7fe8791acae1c5f8f1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 29 May 2015 11:11:40 -0700 Subject: [PATCH 068/254] [SQL] [TEST] [MINOR] Uses a temporary log4j.properties in HiveThriftServer2Test to ensure expected logging behavior The `HiveThriftServer2Test` relies on proper logging behavior to assert whether the Thrift server daemon process is started successfully. However, some other jar files listed in the classpath may potentially contain an unexpected Log4J configuration file which overrides the logging behavior. This PR writes a temporary `log4j.properties` and prepend it to driver classpath before starting the testing Thrift server process to ensure proper logging behavior. cc andrewor14 yhuai Author: Cheng Lian Closes #6493 from liancheng/override-log4j and squashes the following commits: c489e0e [Cheng Lian] Fixes minor Scala styling issue b46ef0d [Cheng Lian] Uses a temporary log4j.properties in HiveThriftServer2Test to ensure expected logging behavior --- .../HiveThriftServer2Suites.scala | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1fadea97fd07f..610939c6a9481 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import java.sql.{Date, DriverManager, Statement} import scala.collection.mutable.ArrayBuffer @@ -54,7 +56,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { - // Transport creation logics below mimics HiveConnection.createBinaryTransport + // Transport creation logic below mimics HiveConnection.createBinaryTransport val rawTransport = new TSocket("localhost", serverPort) val user = System.getProperty("user.name") val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) @@ -391,10 +393,10 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { val statements = connections.map(_.createStatement()) try { - statements.zip(fs).map { case (s, f) => f(s) } + statements.zip(fs).foreach { case (s, f) => f(s) } } finally { - statements.map(_.close()) - connections.map(_.close()) + statements.foreach(_.close()) + connections.foreach(_.close()) } } @@ -433,15 +435,32 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT } + val driverClassPath = { + // Writes a temporary log4j.properties and prepend it to driver classpath, so that it + // overrides all other potential log4j configurations contained in other dependency jar files. + val tempLog4jConf = Utils.createTempDir().getCanonicalPath + + Files.write( + Paths.get(s"$tempLog4jConf/log4j.properties"), + """log4j.rootCategory=INFO, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin.getBytes(StandardCharsets.UTF_8)) + + tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + } + s"""$startScript | --master local - | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf $portConf=$port - | --driver-class-path ${sys.props("java.class.path")} + | --driver-class-path $driverClassPath + | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } From 6181937f315480543d28e542d43269cfa591e9d0 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 29 May 2015 11:36:41 -0700 Subject: [PATCH 069/254] [SPARK-7946] [MLLIB] DecayFactor wrongly set in StreamingKMeans Author: MechCoder Closes #6497 from MechCoder/spark-7946 and squashes the following commits: 2fdd0a3 [MechCoder] Add non-regression test 8c988c6 [MechCoder] [SPARK-7946] DecayFactor wrongly set in StreamingKMeans --- .../apache/spark/mllib/clustering/StreamingKMeans.scala | 2 +- .../spark/mllib/clustering/StreamingKMeansSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 812014a041719..c21e4fe7dc9b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -178,7 +178,7 @@ class StreamingKMeans( /** Set the decay factor directly (for forgetful algorithms). */ def setDecayFactor(a: Double): this.type = { - this.decayFactor = decayFactor + this.decayFactor = a this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index f90025d535e45..13f9b17c027a4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -133,6 +133,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { assert(math.abs(c1) ~== 0.8 absTol 0.6) } + test("SPARK-7946 setDecayFactor") { + val kMeans = new StreamingKMeans() + assert(kMeans.decayFactor === 1.0) + kMeans.setDecayFactor(2.0) + assert(kMeans.decayFactor === 2.0) + } + def StreamingKMeansDataGenerator( numPoints: Int, numBatches: Int, From 94f62a4979e4bc5f7bf4f5852d76977e097209e6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 29 May 2015 13:38:37 -0700 Subject: [PATCH 070/254] [SPARK-7940] Enforce whitespace checking for DO, TRY, CATCH, FINALLY, MATCH, LARROW, RARROW in style checker. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … Author: Reynold Xin Closes #6491 from rxin/more-whitespace and squashes the following commits: f6e63dc [Reynold Xin] [SPARK-7940] Enforce whitespace checking for DO, TRY, CATCH, FINALLY, MATCH, LARROW, RARROW in style checker. --- .../scala/org/apache/spark/network/nio/BlockMessage.scala | 2 +- .../main/scala/org/apache/spark/network/nio/Connection.scala | 5 ++--- .../org/apache/spark/network/nio/ConnectionManager.scala | 5 ++--- .../org/apache/spark/rdd/PartitionerAwareUnionRDD.scala | 2 +- .../main/scala/org/apache/spark/mllib/tree/model/Node.scala | 2 +- .../spark/mllib/classification/LogisticRegressionSuite.scala | 4 ++-- scalastyle-config.xml | 4 ++-- .../main/scala/org/apache/spark/sql/types/UTF8String.scala | 2 +- .../src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 2 +- 9 files changed, 13 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index 1a92a799d004a..67a376102994c 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -155,7 +155,7 @@ private[nio] class BlockMessage() { override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 6b898bd4bfc1b..1499da07bb83b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -326,15 +326,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // MUST be called within the selector loop def connect() { - try{ + try { channel.register(selector, SelectionKey.OP_CONNECT) channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { - case e: Exception => { + case e: Exception => logError("Error connecting to " + address, e) callOnExceptionCallbacks(e) - } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 497871ed6d5e5..c0bca2c4bc994 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -635,12 +635,11 @@ private[nio] class ConnectionManager( val message = securityMsgResp.toBufferMessage if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => { + } catch { + case e: Exception => logError("Error handling sasl client authentication", e) waitingConn.close() throw new IOException("Error evaluating sasl response: ", e) - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 7598ff617b399..9e3880714a79f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -86,7 +86,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( } val location = if (locations.isEmpty) { None - } else { + } else { // Find the location that maximum number of parent partitions prefer Some(locations.groupBy(x => x).maxBy(_._2.length)._1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index ee710fc1ed299..a6d1398fc267b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -83,7 +83,7 @@ class Node ( def predict(features: Vector) : Double = { if (isLeaf) { predict.predict - } else{ + } else { if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { leftNode.get.predict(features) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 966811a5a3263..b1014ab7c6203 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -119,7 +119,7 @@ object LogisticRegressionSuite { } // Preventing the overflow when we compute the probability val maxMargin = margins.max - if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin // Computing the probabilities for each class from the margins. val norm = { @@ -130,7 +130,7 @@ object LogisticRegressionSuite { } temp } - for (i <-0 until nClasses) probs(i) /= norm + for (i <- 0 until nClasses) probs(i) /= norm // Compute the cumulative probability so we can generate a random number and assign a label. for (i <- 1 until nClasses) probs(i) += probs(i - 1) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 68d980b610c00..68c8ce3b7e10b 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -144,12 +144,12 @@ - ARROW, EQUALS + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW - ARROW, EQUALS, COMMA, COLON, IF, WHILE, FOR + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index bc9c37bf2d5d2..f5d8fcced362b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -203,7 +203,7 @@ object UTF8String { def apply(s: String): UTF8String = { if (s != null) { new UTF8String().set(s) - } else{ + } else { null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 0bdb68e8ac845..2d8d950038e78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -262,7 +262,7 @@ private[sql] class JDBCRDD( } private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") + if (value == null) null else StringUtils.replace(value, "'", "''") /** * Turns a single Filter into a String representing a SQL expression. From 9eb222c13991c2b4a22db485710dc2e27ccf06dd Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 May 2015 14:03:12 -0700 Subject: [PATCH 071/254] [SPARK-7558] Demarcate tests in unit-tests.log Right now `unit-tests.log` are not of much value because we can't tell where the test boundaries are easily. This patch adds log statements before and after each test to outline the test boundaries, e.g.: ``` ===== TEST OUTPUT FOR o.a.s.serializer.KryoSerializerSuite: 'kryo with parallelize for primitive arrays' ===== 15/05/27 12:36:39.596 pool-1-thread-1-ScalaTest-running-KryoSerializerSuite INFO SparkContext: Starting job: count at KryoSerializerSuite.scala:230 15/05/27 12:36:39.596 dag-scheduler-event-loop INFO DAGScheduler: Got job 3 (count at KryoSerializerSuite.scala:230) with 4 output partitions (allowLocal=false) 15/05/27 12:36:39.596 dag-scheduler-event-loop INFO DAGScheduler: Final stage: ResultStage 3(count at KryoSerializerSuite.scala:230) 15/05/27 12:36:39.596 dag-scheduler-event-loop INFO DAGScheduler: Parents of final stage: List() 15/05/27 12:36:39.597 dag-scheduler-event-loop INFO DAGScheduler: Missing parents: List() 15/05/27 12:36:39.597 dag-scheduler-event-loop INFO DAGScheduler: Submitting ResultStage 3 (ParallelCollectionRDD[5] at parallelize at KryoSerializerSuite.scala:230), which has no missing parents ... 15/05/27 12:36:39.624 pool-1-thread-1-ScalaTest-running-KryoSerializerSuite INFO DAGScheduler: Job 3 finished: count at KryoSerializerSuite.scala:230, took 0.028563 s 15/05/27 12:36:39.625 pool-1-thread-1-ScalaTest-running-KryoSerializerSuite INFO KryoSerializerSuite: ***** FINISHED o.a.s.serializer.KryoSerializerSuite: 'kryo with parallelize for primitive arrays' ***** ... ``` Author: Andrew Or Closes #6441 from andrewor14/demarcate-tests and squashes the following commits: 879b060 [Andrew Or] Fix compile after rebase d622af7 [Andrew Or] Merge branch 'master' of github.com:apache/spark into demarcate-tests 017c8ba [Andrew Or] Merge branch 'master' of github.com:apache/spark into demarcate-tests 7790b6c [Andrew Or] Fix tests after logical merge conflict c7460c0 [Andrew Or] Merge branch 'master' of github.com:apache/spark into demarcate-tests c43ffc4 [Andrew Or] Fix tests? 8882581 [Andrew Or] Fix tests ee22cda [Andrew Or] Fix log message fa9450e [Andrew Or] Merge branch 'master' of github.com:apache/spark into demarcate-tests 12d1e1b [Andrew Or] Various whitespace changes (minor) 69cbb24 [Andrew Or] Make all test suites extend SparkFunSuite instead of FunSuite bbce12e [Andrew Or] Fix manual things that cannot be covered through automation da0b12f [Andrew Or] Add core tests as dependencies in all modules f7d29ce [Andrew Or] Introduce base abstract class for all test suites --- bagel/pom.xml | 7 +++ .../org/apache/spark/bagel/BagelSuite.scala | 4 +- core/pom.xml | 6 +++ .../org/apache/spark/AccumulatorSuite.scala | 3 +- .../org/apache/spark/CacheManagerSuite.scala | 4 +- .../org/apache/spark/CheckpointSuite.scala | 4 +- .../apache/spark/ContextCleanerSuite.scala | 4 +- .../org/apache/spark/DistributedSuite.scala | 3 +- .../scala/org/apache/spark/DriverSuite.scala | 3 +- .../ExecutorAllocationManagerSuite.scala | 8 +++- .../scala/org/apache/spark/FailureSuite.scala | 4 +- .../org/apache/spark/FileServerSuite.scala | 3 +- .../scala/org/apache/spark/FileSuite.scala | 3 +- .../org/apache/spark/FutureActionSuite.scala | 8 +++- .../apache/spark/HeartbeatReceiverSuite.scala | 3 +- .../apache/spark/ImplicitOrderingSuite.scala | 4 +- .../apache/spark/JobCancellationSuite.scala | 4 +- .../apache/spark/MapOutputTrackerSuite.scala | 3 +- .../org/apache/spark/PartitioningSuite.scala | 4 +- .../org/apache/spark/SSLOptionsSuite.scala | 4 +- .../apache/spark/SecurityManagerSuite.scala | 4 +- .../scala/org/apache/spark/ShuffleSuite.scala | 3 +- .../org/apache/spark/SparkConfSuite.scala | 3 +- .../apache/spark/SparkContextInfoSuite.scala | 4 +- .../SparkContextSchedulerCreationSuite.scala | 4 +- .../org/apache/spark/SparkContextSuite.scala | 4 +- .../org/apache/spark/SparkFunSuite.scala | 46 +++++++++++++++++++ .../org/apache/spark/StatusTrackerSuite.scala | 4 +- .../org/apache/spark/ThreadingSuite.scala | 3 +- .../org/apache/spark/UnpersistSuite.scala | 3 +- .../api/python/PythonBroadcastSuite.scala | 6 +-- .../spark/api/python/PythonRDDSuite.scala | 4 +- .../spark/api/python/SerDeUtilSuite.scala | 6 +-- .../spark/broadcast/BroadcastSuite.scala | 6 +-- .../org/apache/spark/deploy/ClientSuite.scala | 5 +- .../spark/deploy/JsonProtocolSuite.scala | 5 +- .../spark/deploy/LogUrlsStandaloneSuite.scala | 6 +-- .../spark/deploy/PythonRunnerSuite.scala | 5 +- .../spark/deploy/SparkSubmitSuite.scala | 8 +++- .../spark/deploy/SparkSubmitUtilsSuite.scala | 5 +- .../history/FsHistoryProviderSuite.scala | 6 +-- .../deploy/history/HistoryServerSuite.scala | 6 +-- .../spark/deploy/master/MasterSuite.scala | 6 +-- .../rest/StandaloneRestSubmitSuite.scala | 4 +- .../deploy/rest/SubmitRestProtocolSuite.scala | 5 +- .../deploy/worker/CommandUtilsSuite.scala | 5 +- .../deploy/worker/DriverRunnerTest.scala | 5 +- .../deploy/worker/ExecutorRunnerTest.scala | 6 +-- .../deploy/worker/WorkerArgumentsTest.scala | 5 +- .../spark/deploy/worker/WorkerSuite.scala | 6 +-- .../deploy/worker/WorkerWatcherSuite.scala | 5 +- .../spark/executor/TaskMetricsSuite.scala | 4 +- .../WholeTextFileRecordReaderSuite.scala | 5 +- .../spark/io/CompressionCodecSuite.scala | 5 +- .../metrics/InputOutputMetricsSuite.scala | 6 +-- .../spark/metrics/MetricsConfigSuite.scala | 6 ++- .../spark/metrics/MetricsSystemSuite.scala | 6 +-- .../NettyBlockTransferSecuritySuite.scala | 6 +-- .../NettyBlockTransferServiceSuite.scala | 8 +++- .../network/nio/ConnectionManagerSuite.scala | 6 +-- .../spark/rdd/AsyncRDDActionsSuite.scala | 6 +-- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 4 +- .../org/apache/spark/rdd/JdbcRDDSuite.scala | 6 +-- .../spark/rdd/PairRDDFunctionsSuite.scala | 6 +-- .../rdd/ParallelCollectionSplitSuite.scala | 5 +- .../spark/rdd/PartitionPruningRDDSuite.scala | 6 +-- .../rdd/PartitionwiseSampledRDDSuite.scala | 6 +-- .../org/apache/spark/rdd/PipedRDDSuite.scala | 3 +- .../spark/rdd/RDDOperationScopeSuite.scala | 6 +-- .../scala/org/apache/spark/rdd/RDDSuite.scala | 4 +- .../org/apache/spark/rdd/SortingSuite.scala | 5 +- .../spark/rdd/ZippedPartitionsSuite.scala | 5 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 6 +-- .../CoarseGrainedSchedulerBackendSuite.scala | 6 +-- .../spark/scheduler/DAGSchedulerSuite.scala | 4 +- .../scheduler/EventLoggingListenerSuite.scala | 4 +- .../spark/scheduler/MapStatusSuite.scala | 5 +- .../OutputCommitCoordinatorSuite.scala | 4 +- .../apache/spark/scheduler/PoolSuite.scala | 6 +-- .../spark/scheduler/ReplayListenerSuite.scala | 6 +-- .../spark/scheduler/SparkListenerSuite.scala | 6 +-- .../SparkListenerWithClusterSuite.scala | 6 +-- .../spark/scheduler/TaskContextSuite.scala | 3 +- .../scheduler/TaskResultGetterSuite.scala | 6 +-- .../scheduler/TaskSchedulerImplSuite.scala | 4 +- .../spark/scheduler/TaskSetManagerSuite.scala | 4 +- .../cluster/mesos/MemoryUtilsSuite.scala | 5 +- .../mesos/MesosSchedulerBackendSuite.scala | 5 +- .../mesos/MesosTaskLaunchDataSuite.scala | 4 +- .../mesos/MesosClusterSchedulerSuite.scala | 5 +- .../serializer/JavaSerializerSuite.scala | 5 +- .../KryoSerializerDistributedSuite.scala | 5 +- .../KryoSerializerResizableOutputSuite.scala | 6 +-- .../serializer/KryoSerializerSuite.scala | 7 ++- .../ProactiveClosureSerializationSuite.scala | 6 +-- .../SerializationDebuggerSuite.scala | 6 ++- .../SerializerPropertiesSuite.scala | 6 +-- .../shuffle/ShuffleMemoryManagerSuite.scala | 5 +- .../hash/HashShuffleManagerSuite.scala | 6 +-- .../unsafe/UnsafeShuffleManagerSuite.scala | 4 +- .../status/api/v1/SimpleDateParamSuite.scala | 6 ++- .../apache/spark/storage/BlockIdSuite.scala | 4 +- .../BlockManagerReplicationSuite.scala | 6 +-- .../spark/storage/BlockManagerSuite.scala | 4 +- .../storage/BlockObjectWriterSuite.scala | 6 +-- .../spark/storage/DiskBlockManagerSuite.scala | 6 +-- .../spark/storage/FlatmapIteratorSuite.scala | 5 +- .../apache/spark/storage/LocalDirsSuite.scala | 6 +-- .../ShuffleBlockFetcherIteratorSuite.scala | 5 +- .../storage/StorageStatusListenerSuite.scala | 5 +- .../apache/spark/storage/StorageSuite.scala | 4 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 2 +- .../scala/org/apache/spark/ui/UISuite.scala | 5 +- .../ui/jobs/JobProgressListenerSuite.scala | 3 +- .../RDDOperationGraphListenerSuite.scala | 6 +-- .../spark/ui/storage/StorageTabSuite.scala | 6 +-- .../apache/spark/util/AkkaUtilsSuite.scala | 3 +- .../spark/util/ClosureCleanerSuite.scala | 6 +-- .../spark/util/ClosureCleanerSuite2.scala | 6 +-- .../spark/util/CompletionIteratorSuite.scala | 4 +- .../apache/spark/util/DistributionSuite.scala | 5 +- .../apache/spark/util/EventLoopSuite.scala | 5 +- .../apache/spark/util/FileAppenderSuite.scala | 6 +-- .../apache/spark/util/JsonProtocolSuite.scala | 3 +- .../util/MutableURLClassLoaderSuite.scala | 6 +-- .../apache/spark/util/NextIteratorSuite.scala | 5 +- .../spark/util/ResetSystemProperties.scala | 4 +- .../spark/util/SizeEstimatorSuite.scala | 9 +++- .../apache/spark/util/ThreadUtilsSuite.scala | 4 +- .../spark/util/TimeStampedHashMapSuite.scala | 4 +- .../org/apache/spark/util/UtilsSuite.scala | 5 +- .../org/apache/spark/util/VectorSuite.scala | 4 +- .../util/collection/AppendOnlyMapSuite.scala | 4 +- .../spark/util/collection/BitSetSuite.scala | 4 +- .../util/collection/ChainedBufferSuite.scala | 5 +- .../util/collection/CompactBufferSuite.scala | 4 +- .../ExternalAppendOnlyMapSuite.scala | 4 +- .../util/collection/ExternalSorterSuite.scala | 4 +- .../util/collection/OpenHashMapSuite.scala | 4 +- .../util/collection/OpenHashSetSuite.scala | 4 +- ...PartitionedSerializedPairBufferSuite.scala | 5 +- .../PrimitiveKeyOpenHashMapSuite.scala | 4 +- .../collection/PrimitiveVectorSuite.scala | 5 +- .../util/collection/SizeTrackerSuite.scala | 5 +- .../spark/util/collection/SorterSuite.scala | 5 +- .../io/ByteArrayChunkOutputStreamSuite.scala | 4 +- .../util/random/RandomSamplerSuite.scala | 6 ++- .../util/random/SamplingUtilsSuite.scala | 5 +- .../util/random/XORShiftRandomSuite.scala | 4 +- external/flume-sink/pom.xml | 7 +++ .../streaming/flume/sink/SparkSinkSuite.scala | 5 +- external/flume/pom.xml | 7 +++ .../flume/FlumePollingStreamSuite.scala | 6 +-- .../streaming/flume/FlumeStreamSuite.scala | 6 +-- external/kafka/pom.xml | 7 +++ .../kafka/DirectKafkaStreamSuite.scala | 6 +-- .../streaming/kafka/KafkaClusterSuite.scala | 6 ++- .../spark/streaming/kafka/KafkaRDDSuite.scala | 4 +- .../streaming/kafka/KafkaStreamSuite.scala | 6 +-- .../kafka/ReliableKafkaStreamSuite.scala | 6 +-- external/mqtt/pom.xml | 7 +++ .../streaming/mqtt/MQTTStreamSuite.scala | 6 +-- external/twitter/pom.xml | 7 +++ .../twitter/TwitterStreamSuite.scala | 6 +-- external/zeromq/pom.xml | 7 +++ .../streaming/zeromq/ZeroMQStreamSuite.scala | 4 +- graphx/pom.xml | 7 +++ .../apache/spark/graphx/EdgeRDDSuite.scala | 5 +- .../org/apache/spark/graphx/EdgeSuite.scala | 4 +- .../apache/spark/graphx/GraphOpsSuite.scala | 5 +- .../org/apache/spark/graphx/GraphSuite.scala | 6 +-- .../org/apache/spark/graphx/PregelSuite.scala | 6 +-- .../apache/spark/graphx/VertexRDDSuite.scala | 6 +-- .../graphx/impl/EdgePartitionSuite.scala | 6 +-- .../graphx/impl/VertexPartitionSuite.scala | 6 +-- .../graphx/lib/ConnectedComponentsSuite.scala | 6 +-- .../graphx/lib/LabelPropagationSuite.scala | 5 +- .../spark/graphx/lib/PageRankSuite.scala | 5 +- .../spark/graphx/lib/SVDPlusPlusSuite.scala | 5 +- .../spark/graphx/lib/ShortestPathsSuite.scala | 6 +-- .../StronglyConnectedComponentsSuite.scala | 6 +-- .../spark/graphx/lib/TriangleCountSuite.scala | 5 +- .../graphx/util/BytecodeUtilsSuite.scala | 4 +- .../graphx/util/GraphGeneratorsSuite.scala | 5 +- .../spark/ml/util/IdentifiableSuite.scala | 4 +- .../org/apache/spark/ml/PipelineSuite.scala | 4 +- .../ml/attribute/AttributeGroupSuite.scala | 4 +- .../spark/ml/attribute/AttributeSuite.scala | 5 +- .../DecisionTreeClassifierSuite.scala | 7 ++- .../classification/GBTClassifierSuite.scala | 5 +- .../LogisticRegressionSuite.scala | 5 +- .../ml/classification/OneVsRestSuite.scala | 5 +- .../RandomForestClassifierSuite.scala | 5 +- .../evaluation/RegressionEvaluatorSuite.scala | 5 +- .../spark/ml/feature/BinarizerSuite.scala | 5 +- .../spark/ml/feature/BucketizerSuite.scala | 8 ++-- .../spark/ml/feature/HashingTFSuite.scala | 5 +- .../apache/spark/ml/feature/IDFSuite.scala | 5 +- .../spark/ml/feature/NormalizerSuite.scala | 5 +- .../spark/ml/feature/OneHotEncoderSuite.scala | 5 +- .../ml/feature/PolynomialExpansionSuite.scala | 4 +- .../spark/ml/feature/StringIndexerSuite.scala | 5 +- .../spark/ml/feature/TokenizerSuite.scala | 7 ++- .../ml/feature/VectorAssemblerSuite.scala | 6 +-- .../spark/ml/feature/VectorIndexerSuite.scala | 6 +-- .../spark/ml/feature/Word2VecSuite.scala | 5 +- .../org/apache/spark/ml/impl/TreeTests.scala | 5 +- .../apache/spark/ml/param/ParamsSuite.scala | 6 +-- .../ml/param/shared/SharedParamsSuite.scala | 5 +- .../spark/ml/recommendation/ALSSuite.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 7 ++- .../ml/regression/GBTRegressorSuite.scala | 5 +- .../ml/regression/LinearRegressionSuite.scala | 5 +- .../RandomForestRegressorSuite.scala | 7 ++- .../spark/ml/tuning/CrossValidatorSuite.scala | 4 +- .../ml/tuning/ParamGridBuilderSuite.scala | 5 +- .../api/python/PythonMLLibAPISuite.scala | 5 +- .../LogisticRegressionSuite.scala | 6 +-- .../classification/NaiveBayesSuite.scala | 7 ++- .../spark/mllib/classification/SVMSuite.scala | 7 ++- .../StreamingLogisticRegressionSuite.scala | 5 +- .../clustering/GaussianMixtureSuite.scala | 5 +- .../spark/mllib/clustering/KMeansSuite.scala | 9 ++-- .../spark/mllib/clustering/LDASuite.scala | 5 +- .../PowerIterationClusteringSuite.scala | 8 ++-- .../clustering/StreamingKMeansSuite.scala | 5 +- .../evaluation/AreaUnderCurveSuite.scala | 5 +- .../BinaryClassificationMetricsSuite.scala | 5 +- .../evaluation/MulticlassMetricsSuite.scala | 5 +- .../evaluation/MultilabelMetricsSuite.scala | 5 +- .../evaluation/RankingMetricsSuite.scala | 5 +- .../evaluation/RegressionMetricsSuite.scala | 5 +- .../mllib/feature/ChiSqSelectorSuite.scala | 5 +- .../feature/ElementwiseProductSuite.scala | 5 +- .../spark/mllib/feature/HashingTFSuite.scala | 5 +- .../apache/spark/mllib/feature/IDFSuite.scala | 5 +- .../spark/mllib/feature/NormalizerSuite.scala | 5 +- .../apache/spark/mllib/feature/PCASuite.scala | 5 +- .../mllib/feature/StandardScalerSuite.scala | 5 +- .../spark/mllib/feature/Word2VecSuite.scala | 5 +- .../spark/mllib/fpm/FPGrowthSuite.scala | 5 +- .../apache/spark/mllib/fpm/FPTreeSuite.scala | 5 +- .../impl/PeriodicGraphCheckpointerSuite.scala | 6 +-- .../apache/spark/mllib/linalg/BLASSuite.scala | 5 +- .../linalg/BreezeMatrixConversionSuite.scala | 6 +-- .../linalg/BreezeVectorConversionSuite.scala | 6 +-- .../spark/mllib/linalg/MatricesSuite.scala | 4 +- .../spark/mllib/linalg/VectorsSuite.scala | 5 +- .../linalg/distributed/BlockMatrixSuite.scala | 5 +- .../distributed/CoordinateMatrixSuite.scala | 5 +- .../distributed/IndexedRowMatrixSuite.scala | 5 +- .../linalg/distributed/RowMatrixSuite.scala | 6 +-- .../optimization/GradientDescentSuite.scala | 7 +-- .../spark/mllib/optimization/LBFGSSuite.scala | 7 +-- .../spark/mllib/optimization/NNLSSuite.scala | 5 +- ...ryClassificationPMMLModelExportSuite.scala | 4 +- ...eneralizedLinearPMMLModelExportSuite.scala | 4 +- .../export/KMeansPMMLModelExportSuite.scala | 4 +- .../export/PMMLModelExportFactorySuite.scala | 5 +- .../random/RandomDataGeneratorSuite.scala | 5 +- .../spark/mllib/random/RandomRDDsSuite.scala | 5 +- .../mllib/rdd/MLPairRDDFunctionsSuite.scala | 5 +- .../spark/mllib/rdd/RDDFunctionsSuite.scala | 5 +- .../spark/mllib/recommendation/ALSSuite.scala | 4 +- .../MatrixFactorizationModelSuite.scala | 5 +- .../regression/IsotonicRegressionSuite.scala | 5 +- .../mllib/regression/LabeledPointSuite.scala | 5 +- .../spark/mllib/regression/LassoSuite.scala | 7 ++- .../regression/LinearRegressionSuite.scala | 7 ++- .../regression/RidgeRegressionSuite.scala | 6 +-- .../StreamingLinearRegressionSuite.scala | 5 +- .../spark/mllib/stat/CorrelationSuite.scala | 5 +- .../mllib/stat/HypothesisTestSuite.scala | 6 +-- .../spark/mllib/stat/KernelDensitySuite.scala | 4 +- .../MultivariateOnlineSummarizerSuite.scala | 5 +- .../MultivariateGaussianSuite.scala | 5 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 7 ++- .../tree/GradientBoostedTreesSuite.scala | 5 +- .../spark/mllib/tree/ImpuritySuite.scala | 5 +- .../spark/mllib/tree/RandomForestSuite.scala | 5 +- .../mllib/tree/impl/BaggedPointSuite.scala | 5 +- .../spark/mllib/util/MLUtilsSuite.scala | 5 +- .../spark/mllib/util/NumericParserSuite.scala | 6 +-- .../spark/mllib/util/TestingUtilsSuite.scala | 4 +- repl/pom.xml | 7 +++ .../org/apache/spark/repl/ReplSuite.scala | 5 +- .../org/apache/spark/repl/ReplSuite.scala | 5 +- .../spark/repl/ExecutorClassLoaderSuite.scala | 3 +- sql/catalyst/pom.xml | 7 +++ .../sql/catalyst/DistributionSuite.scala | 5 +- .../sql/catalyst/ScalaReflectionSuite.scala | 5 +- .../spark/sql/catalyst/SqlParserSuite.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 5 +- .../analysis/DecimalPrecisionSuite.scala | 5 +- .../expressions/AttributeSetSuite.scala | 5 +- .../ExpressionEvaluationSuite.scala | 4 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 8 +++- .../expressions/UnsafeRowConverterSuite.scala | 5 +- .../spark/sql/catalyst/plans/PlanTest.scala | 5 +- .../sql/catalyst/plans/SameResultSuite.scala | 5 +- .../catalyst/trees/RuleExecutorSuite.scala | 5 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 5 +- .../sql/catalyst/util/MetadataSuite.scala | 4 +- .../spark/sql/types/DataTypeParserSuite.scala | 4 +- .../spark/sql/types/DataTypeSuite.scala | 5 +- .../spark/sql/types/UTF8StringSuite.scala | 4 +- .../sql/types/decimal/DecimalSuite.scala | 5 +- .../apache/spark/sql/DataFrameStatSuite.scala | 4 +- .../spark/sql/MathExpressionsSuite.scala | 2 +- .../scala/org/apache/spark/sql/RowSuite.scala | 4 +- .../org/apache/spark/sql/SQLConfSuite.scala | 5 +- .../apache/spark/sql/SQLContextSuite.scala | 5 +- .../sql/ScalaReflectionRelationSuite.scala | 5 +- .../apache/spark/sql/SerializationSuite.scala | 6 +-- .../spark/sql/columnar/ColumnStatsSuite.scala | 5 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 5 +- .../NullableColumnAccessorSuite.scala | 5 +- .../columnar/NullableColumnBuilderSuite.scala | 5 +- .../columnar/PartitionBatchPruningSuite.scala | 5 +- .../compression/BooleanBitSetSuite.scala | 5 +- .../compression/DictionaryEncodingSuite.scala | 5 +- .../compression/IntegralDeltaSuite.scala | 5 +- .../compression/RunLengthEncodingSuite.scala | 5 +- .../spark/sql/execution/PlannerSuite.scala | 5 +- .../execution/SparkSqlSerializer2Suite.scala | 6 +-- .../sql/execution/debug/DebuggingSuite.scala | 5 +- .../execution/joins/HashedRelationSuite.scala | 5 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 5 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 5 +- .../sql/parquet/ParquetSchemaSuite.scala | 4 +- .../sql/sources/ResolvedDataSourceSuite.scala | 4 +- sql/hive-thriftserver/pom.xml | 7 +++ .../sql/hive/thriftserver/CliSuite.scala | 6 +-- .../HiveThriftServer2Suites.scala | 6 +-- sql/hive/pom.xml | 7 +++ .../spark/sql/hive/HiveInspectorSuite.scala | 4 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../apache/spark/sql/hive/HiveQlSuite.scala | 5 +- .../spark/sql/hive/SerializationSuite.scala | 6 +-- .../spark/sql/hive/client/VersionsSuite.scala | 5 +- .../hive/execution/ConcurrentHiveSuite.scala | 6 +-- .../hive/execution/HiveComparisonTest.scala | 6 +-- .../hive/orc/OrcPartitionDiscoverySuite.scala | 5 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 5 +- .../sql/sources/hadoopFsRelationSuites.scala | 5 +- streaming/pom.xml | 7 +++ .../spark/streaming/DStreamClosureSuite.scala | 6 +-- .../spark/streaming/DStreamScopeSuite.scala | 6 +-- .../streaming/ReceivedBlockHandlerSuite.scala | 8 +++- .../streaming/ReceivedBlockTrackerSuite.scala | 6 +-- .../streaming/StreamingContextSuite.scala | 6 +-- .../spark/streaming/TestSuiteBase.scala | 6 +-- .../spark/streaming/UISeleniumSuite.scala | 2 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 6 +-- .../scheduler/InputInfoTrackerSuite.scala | 6 +-- .../spark/streaming/ui/UIUtilsSuite.scala | 5 +- .../util/RateLimitedOutputStreamSuite.scala | 4 +- .../streaming/util/WriteAheadLogSuite.scala | 6 +-- yarn/pom.xml | 7 +++ .../ClientDistributedCacheManagerSuite.scala | 5 +- .../spark/deploy/yarn/ClientSuite.scala | 6 +-- .../deploy/yarn/YarnAllocatorSuite.scala | 6 +-- .../spark/deploy/yarn/YarnClusterSuite.scala | 6 +-- .../yarn/YarnSparkHadoopUtilSuite.scala | 6 +-- 364 files changed, 953 insertions(+), 968 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/SparkFunSuite.scala diff --git a/bagel/pom.xml b/bagel/pom.xml index 1f3dec91314f2..132cd433d78a2 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index ccb262a4ee02a..fb10d734ac74b 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.bagel -import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { +class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ diff --git a/core/pom.xml b/core/pom.xml index e58efe495e36d..5c02be831ce06 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -338,6 +338,12 @@ org.seleniumhq.selenium selenium-java + + + com.google.guava + guava + + test diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 746a40a21bf9e..e942d6579b2fd 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark import scala.collection.mutable import scala.ref.WeakReference -import org.scalatest.FunSuite import org.scalatest.Matchers -class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { +class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 668ddf9f5f0a9..af81e46a657d3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar import org.apache.spark.executor.DataReadMethod @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter +class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with MockitoSugar { var blockManager: BlockManager = _ diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 91d8fdedbe0f3..d1761a48babbc 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,13 +21,11 @@ import java.io.File import scala.reflect.ClassTag -import org.scalatest.FunSuite - import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils -class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { var checkpointDir: File = _ val partitioner = new HashPartitioner(2) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 4a48f6580c78e..501fe186bfd7c 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials import scala.util.Random -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.{PatienceConfiguration, Eventually} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -44,7 +44,7 @@ import org.apache.spark.storage.ShuffleIndexBlockId * config options, in particular, a different shuffle manager class */ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager]) - extends FunSuite with BeforeAndAfter with LocalSparkContext + extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { implicit val defaultTimeout = timeout(10000 millis) val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 96a9c207ad022..9c191ed52206d 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} @@ -28,7 +27,7 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index c42dfbc82ada4..b2262033ca238 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark import java.io.File -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils -class DriverSuite extends FunSuite with Timeouts { +class DriverSuite extends SparkFunSuite with Timeouts { ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 84f787ee3715d..1c2b681f0b843 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -28,7 +28,11 @@ import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { +class ExecutorAllocationManagerSuite + extends SparkFunSuite + with LocalSparkContext + with BeforeAndAfter { + import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index cade1fda2c7be..b18067e68f5a1 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark -import org.scalatest.FunSuite - import org.apache.spark.util.NonSerializable import java.io.NotSerializableException @@ -38,7 +36,7 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with LocalSparkContext { +class FailureSuite extends SparkFunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index bff2d10b9946c..6e65b0a8f6c76 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -24,13 +24,12 @@ import javax.net.ssl.SSLException import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.RandomUtils -import org.scalatest.FunSuite import org.apache.spark.util.Utils import SSLSampleConfigs._ -class FileServerSuite extends FunSuite with LocalSparkContext { +class FileServerSuite extends SparkFunSuite with LocalSparkContext { @transient var tmpDir: File = _ @transient var tmpFile: File = _ diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index d67de8692df62..1d8fade90f398 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.scalatest.FunSuite import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils -class FileSuite extends FunSuite with LocalSparkContext { +class FileSuite extends SparkFunSuite with LocalSparkContext { var tempDir: File = _ override def beforeEach() { diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala index f5cdb01ec9504..1102aea96b548 100644 --- a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala +++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala @@ -20,10 +20,14 @@ package org.apache.spark import scala.concurrent.Await import scala.concurrent.duration.Duration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} -class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext { +class FutureActionSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with LocalSparkContext { before { sc = new SparkContext("local", "FutureActionSuite") diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index b789912e9ebef..911b3bddd1836 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -22,7 +22,6 @@ import scala.language.postfixOps import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId -import org.scalatest.FunSuite import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ @@ -31,7 +30,7 @@ import org.apache.spark.scheduler.TaskScheduler import org.apache.spark.util.RpcUtils import org.scalatest.concurrent.Eventually._ -class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { +class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext { test("HeartbeatReceiver") { sc = spy(new SparkContext("local[2]", "test")) diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index 69314deda1f03..e47173f8a8b03 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark -import org.scalatest.FunSuite - import org.apache.spark.rdd.RDD -class ImplicitOrderingSuite extends FunSuite with LocalSparkContext { +class ImplicitOrderingSuite extends SparkFunSuite with LocalSparkContext { // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should. test("basic inference of Orderings"){ sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index ae17fc60e4a43..340a9e327107e 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -24,7 +24,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.future -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} @@ -34,7 +34,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers * in both FIFO and fair scheduling modes. */ -class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter +class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6ed057a7cab97..1fab69678d040 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} -import org.scalatest.FunSuite import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 47e3bf6e1ac41..3316f561a4949 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer import scala.math.abs -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.PrivateMethodTester import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter -class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester { +class PartitioningSuite extends SparkFunSuite with SharedSparkContext with PrivateMethodTester { test("HashPartitioner equality") { val p2 = new HashPartitioner(2) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 93f46ef11c0e2..376481ba541fa 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -21,9 +21,9 @@ import java.io.File import com.google.common.io.Files import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { +class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as spark conf ") { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 61571be44252a..e9b64aa82a17a 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,11 +19,9 @@ package org.apache.spark import java.io.File -import org.scalatest.FunSuite - import org.apache.spark.util.Utils -class SecurityManagerSuite extends FunSuite { +class SecurityManagerSuite extends SparkFunSuite { test("set security with conf") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d7180516029d5..91f4ab360857e 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass @@ -26,7 +25,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair -abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { +abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { val conf = new SparkConf(loadDefaults = false) diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index fafc9d47503b7..9fbaeb33f97cd 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -23,13 +23,12 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Try, Random} -import org.scalatest.FunSuite import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo -class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { +class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("Test byteString conversion") { val conf = new SparkConf() // Simply exercise the API, we don't need a complete conversion test since that's handled in diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index e6ab538d77bcc..2bdbd70c638a5 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions import org.apache.spark.storage.StorageLevel -class SparkContextInfoSuite extends FunSuite with LocalSparkContext { +class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 9343f4fff89da..f89e3d0a49920 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.PrivateMethodTester import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite - extends FunSuite with LocalSparkContext with PrivateMethodTester with Logging { + extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { def createTaskScheduler(master: String): TaskSchedulerImpl = createTaskScheduler(master, new SparkConf()) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 31ef5cd75bd4a..93426822f704e 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -23,8 +23,6 @@ import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files -import org.scalatest.FunSuite - import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} @@ -33,7 +31,7 @@ import org.apache.spark.util.Utils import scala.concurrent.Await import scala.concurrent.duration.Duration -class SparkContextSuite extends FunSuite with LocalSparkContext { +class SparkContextSuite extends SparkFunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala new file mode 100644 index 0000000000000..0327dfad6ea51 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.{FunSuite, Outcome} + +/** + * Base abstract class for all unit tests in Spark for handling common functionality. + */ +private[spark] abstract class SparkFunSuite extends FunSuite with Logging { + + /** + * Log the suite name and the test name before and after each test. + * + * Subclasses should never override this method. If they wish to run + * custom code before and after each test, they should should mix in + * the {{org.scalatest.BeforeAndAfter}} trait instead. + */ + final protected override def withFixture(test: NoArgTest): Outcome = { + val testName = test.text + val suiteName = this.getClass.getName + val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") + try { + logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") + test() + } finally { + logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 084eb237d70d1..46516e8d25298 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -21,12 +21,12 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark.JobExecutionStatus._ -class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { +class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkContext { test("basic status API usage") { sc = new SparkContext("local", "test", new SparkConf(false)) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 10917c866cc7d..6580139df6c60 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.scheduler._ -import org.scalatest.FunSuite /** * Holds state shared across task threads in some ThreadingSuite tests. @@ -37,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends FunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index 42ff059e018a3..f7a13ab3996d8 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.{Millis, Span} -class UnpersistSuite extends FunSuite with LocalSparkContext { +class UnpersistSuite extends SparkFunSuite with LocalSparkContext { test("unpersist RDD") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala index 8959a843dbd7d..135c56bf5bc9d 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -21,15 +21,15 @@ import scala.io.Source import java.io.{PrintWriter, File} -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers -import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils // This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize // a PythonBroadcast: -class PythonBroadcastSuite extends FunSuite with Matchers with SharedSparkContext { +class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext { test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") { val tempDir = Utils.createTempDir() val broadcastedString = "Hello, world!" diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index c63d834f9048b..41f2a5c972b6b 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class PythonRDDSuite extends FunSuite { +class PythonRDDSuite extends SparkFunSuite { test("Writing large strings to the worker") { val input: List[String] = List("a"*100000) diff --git a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala index f8c39326145e1..267a79fa63782 100644 --- a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.api.python -import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} -import org.apache.spark.SharedSparkContext - -class SerDeUtilSuite extends FunSuite with SharedSparkContext { +class SerDeUtilSuite extends SparkFunSuite with SharedSparkContext { test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c38e306b6ac40..c05e8bb6538ba 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.broadcast import scala.concurrent.duration._ import scala.util.Random -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv} +import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer @@ -45,7 +45,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends FunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") private val torrentConf = broadcastConf("TorrentBroadcastFactory") diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 745f9eeee7536..6a99dbca64f4b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy -import org.scalatest.FunSuite import org.scalatest.Matchers -class ClientSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class ClientSuite extends SparkFunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) ClientArguments.isValidJarUrl("https://someHost:8080/foo.jar") should be (true) diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index e04a79284175c..08529e0ef2806 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -23,14 +23,13 @@ import java.util.Date import com.fasterxml.jackson.core.JsonParseException import org.json4s._ import org.json4s.jackson.JsonMethods -import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf} +import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} -class JsonProtocolSuite extends FunSuite with JsonTestUtils { +class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index c93d16f8a1586..c215b0582889f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -23,13 +23,11 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source -import org.scalatest.FunSuite - import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { +class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { /** Length of time to wait while draining listener events. */ private val WAIT_TIMEOUT_MILLIS = 10000 diff --git a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala index 80f2cc02516fe..473a2d7b2a258 100644 --- a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.deploy -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils -class PythonRunnerSuite extends FunSuite { +class PythonRunnerSuite extends SparkFunSuite { // Test formatting a single path to be added to the PYTHONPATH test("format path") { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index ea9227a7e9af5..46369457f000a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams -import org.scalatest.FunSuite import org.scalatest.Matchers import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -35,7 +34,12 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch // of properties that neeed to be cleared after tests. -class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties with Timeouts { +class SparkSubmitSuite + extends SparkFunSuite + with Matchers + with ResetSystemProperties + with Timeouts { + def beforeAll() { System.setProperty("spark.testing", "true") } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 088ca3cb93b49..8fda5c8b472c9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -20,15 +20,16 @@ package org.apache.spark.deploy import java.io.{File, PrintStream, OutputStream} import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.IBiblioResolver +import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate -class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { +class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { private val noOpOutputStream = new OutputStream { def write(b: Int) = {} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index a0a0afa48833e..0f6933df9e6bc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -25,15 +25,15 @@ import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} -class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { private var testDir: File = null diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index e10dd4cf837aa..14f2d1a5894b8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -22,10 +22,10 @@ import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import org.apache.commons.io.{FileUtils, IOUtils} import org.mockito.Mockito.when -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf} +import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.ui.SparkUI /** @@ -39,7 +39,7 @@ import org.apache.spark.ui.SparkUI * expectations. However, in general this should be done with extreme caution, as the metrics * are considered part of Spark's public api. */ -class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with MockitoSugar +class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with MockitoSugar with JsonTestUtils { private val logDir = new File("src/test/resources/spark-events") diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index f97e5ff6db31d..014e87bb40254 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -27,14 +27,14 @@ import scala.language.postfixOps import akka.actor.Address import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy._ -class MasterSuite extends FunSuite with Matchers with Eventually { +class MasterSuite extends SparkFunSuite with Matchers with Eventually { test("toAkkaUrl") { val conf = new SparkConf(loadDefaults = false) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index f4d548d9e7720..197f68e7ec5ed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.google.common.base.Charsets -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ @@ -38,7 +38,7 @@ import org.apache.spark.deploy.master.DriverState._ /** * Tests for the REST application submission protocol used in standalone cluster mode. */ -class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { +class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { private var actorSystem: Option[ActorSystem] = None private var server: Option[RestSubmissionServer] = None diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 61071ee17256c..115ac0534a1b4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -21,14 +21,13 @@ import java.lang.Boolean import java.lang.Integer import org.json4s.jackson.JsonMethods._ -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} /** * Tests for the REST application submission protocol. */ -class SubmitRestProtocolSuite extends FunSuite { +class SubmitRestProtocolSuite extends SparkFunSuite { test("validate") { val request = new DummyRequest diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 1c27d83cf876c..5b3930c0b0132 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.deploy.worker +import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.Command import org.apache.spark.util.Utils -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers -class CommandUtilsSuite extends FunSuite with Matchers { +class CommandUtilsSuite extends SparkFunSuite with Matchers { test("set libraryPath correctly") { val appId = "12345-worker321-9876" diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 2159fd8c16c6f..6258c18d177fd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -23,13 +23,12 @@ import org.mockito.Mockito._ import org.mockito.Matchers._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.FunSuite -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.util.Clock -class DriverRunnerTest extends FunSuite { +class DriverRunnerTest extends SparkFunSuite { private def createDriverRunner() = { val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index a8b9df227c996..3da992788962b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -21,12 +21,10 @@ import java.io.File import scala.collection.JavaConversions._ -import org.scalatest.FunSuite - import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} -class ExecutorRunnerTest extends FunSuite { +class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index e432b8e94654a..15f7ca4a6dacc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -18,11 +18,10 @@ package org.apache.spark.deploy.worker -import org.apache.spark.SparkConf -import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} -class WorkerArgumentsTest extends FunSuite { +class WorkerArgumentsTest extends SparkFunSuite { test("Memory can't be set to 0 when cmd line args leave off M or G") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 93a779d5ce6f2..0f4d3b28d09df 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.deploy.worker -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers -class WorkerSuite extends FunSuite with Matchers { +class WorkerSuite extends SparkFunSuite with Matchers { def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 6a6f29dd613cd..ac18f04a11475 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.deploy.worker import akka.actor.AddressFromURIString -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} -import org.scalatest.FunSuite -class WorkerWatcherSuite extends FunSuite { +class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 326e203afe136..8275fd87764cd 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.executor -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class TaskMetricsSuite extends FunSuite { +class TaskMetricsSuite extends SparkFunSuite { test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { val taskMetrics = new TaskMetrics() taskMetrics.updateShuffleReadMetrics() diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 2e58c159a2ed8..63947df3d43a2 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -24,11 +24,10 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -37,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. */ -class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index cf6a143537889..cbdb33c89d0fb 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} -class CompressionCodecSuite extends FunSuite { +class CompressionCodecSuite extends SparkFunSuite { val conf = new SparkConf(false) def testCodec(codec: CompressionCodec) { diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 60dba3b2d6719..19f1af0dcd461 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -36,14 +36,14 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombi import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, RecordReader => NewRecordReader} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.util.Utils -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext +class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { @transient var tmpDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 100ac77dec1f7..a901a069d9bfe 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -class MetricsConfigSuite extends FunSuite with BeforeAndAfter { +import org.apache.spark.SparkFunSuite + +class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { var filePath: String = _ before { diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index bbdc9568a6ddb..9c389c76bf3bd 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.metrics.source.Source @@ -27,7 +27,7 @@ import com.codahale.metrics.MetricRegistry import scala.collection.mutable.ArrayBuffer -class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester{ +class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null var securityMgr: SecurityManager = null diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 46d2e5173acae..3940527fb874e 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -31,12 +31,12 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.storage.{BlockId, ShuffleBlockId} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar -import org.scalatest.{FunSuite, ShouldMatchers} +import org.scalatest.ShouldMatchers -class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { +class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { test("security default off") { val conf = new SparkConf() .set("spark.app.id", "app-id") diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index a41f8b7ce5ce0..6f8e8a7ac6033 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -18,11 +18,15 @@ package org.apache.spark.network.netty import org.apache.spark.network.BlockDataManager -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ -class NettyBlockTransferServiceSuite extends FunSuite with BeforeAndAfterEach with ShouldMatchers { +class NettyBlockTransferServiceSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ShouldMatchers { + private var service0: NettyBlockTransferService = _ private var service1: NettyBlockTransferService = _ diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 02424c59d6831..5e364cc0edeb2 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -24,15 +24,13 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, TimeoutException} import scala.language.postfixOps -import org.scalatest.FunSuite - -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.util.Utils /** * Test the ConnectionManager with various security settings. */ -class ConnectionManagerSuite extends FunSuite { +class ConnectionManagerSuite extends SparkFunSuite { test("security default off") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index f2b0ea1063a72..ec99f2a1bad66 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -23,13 +23,13 @@ import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkContext, SparkException, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} -class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts { +class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @transient private var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 01039b9449daf..4e72b89bfcc40 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite - import org.apache.spark._ -class DoubleRDDSuite extends FunSuite with SharedSparkContext { +class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { test("sum") { assert(sc.parallelize(Seq.empty[Double]).sum() === 0.0) assert(sc.parallelize(Seq(1.0)).sum() === 1.0) diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index be8467354b222..a8466ed8c1dc2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.rdd import java.sql._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{LocalSparkContext, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} -class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { Class.forName("org.apache.derby.jdbc.EmbeddedDriver") diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 6564232986cfa..dfa102f432a02 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -28,12 +28,10 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} -import org.apache.spark.{Partitioner, SharedSparkContext} +import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite} import org.apache.spark.util.Utils -import org.scalatest.FunSuite - -class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { +class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("aggregateByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 1880364581c1a..e7cc1617cdf1c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -22,10 +22,11 @@ import scala.collection.immutable.NumericRange import org.scalacheck.Arbitrary._ import org.scalacheck.Gen import org.scalacheck.Prop._ -import org.scalatest.FunSuite import org.scalatest.prop.Checkers -class ParallelCollectionSplitSuite extends FunSuite with Checkers { +import org.apache.spark.SparkFunSuite + +class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { test("one element per slice") { val data = Array(1, 2, 3) val slices = ParallelCollectionRDD.slice(data, 3) diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 465068c6cbb16..b1544a6106110 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite +import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext} -import org.apache.spark.{Partition, SharedSparkContext, TaskContext} - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { +class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext { test("Pruned Partitions inherit locality prefs correctly") { diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 0d1369c19c69e..132a5fa9a80fb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite - -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler} /** a sampler that outputs its seed */ @@ -38,7 +36,7 @@ class MockSampler extends RandomSampler[Long, Long] { override def clone: MockSampler = new MockSampler } -class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { +class PartitionwiseSampledRDDSuite extends SparkFunSuite with SharedSparkContext { test("seed distribution") { val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 85eb2a1d07ba4..32f04d54eff94 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} -import org.scalatest.FunSuite import scala.collection.Map import scala.language.postfixOps @@ -32,7 +31,7 @@ import scala.util.Try import org.apache.spark._ import org.apache.spark.util.Utils -class PipedRDDSuite extends FunSuite with SharedSparkContext { +class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { test("basic pipe") { if (testCommandAvailable("cat")) { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index 4434ed858c60c..f65349e3e3585 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.rdd -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{TaskContext, Partition, SparkContext} +import org.apache.spark.{Partition, SparkContext, SparkFunSuite, TaskContext} /** * Tests whether scopes are passed from the RDD operation to the RDDs correctly. */ -class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter { +class RDDOperationScopeSuite extends SparkFunSuite with BeforeAndAfter { private var sc: SparkContext = null private val scope1 = new RDDOperationScope("scope1") private val scope2 = new RDDOperationScope("scope2", Some(scope1)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8079d5dcaea81..f6da9f98ad253 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -25,14 +25,12 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ import org.apache.spark.util.Utils -class RDDSuite extends FunSuite with SharedSparkContext { +class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index 54fc914722b46..a7de9cabe7cc9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.{Logging, SharedSparkContext} +import org.apache.spark.{Logging, SharedSparkContext, SparkFunSuite} -class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging { +class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers with Logging { test("sortByKey") { val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala index 72596e86865b2..5d7b973fbd9ac 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.SharedSparkContext -import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} object ZippedPartitionsSuite { def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { @@ -26,7 +25,7 @@ object ZippedPartitionsSuite { } } -class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { +class ZippedPartitionsSuite extends SparkFunSuite with SharedSparkContext { test("print sizes") { val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 21eb71d9acfbd..1f0aa759b08da 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -24,15 +24,15 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} /** * Common tests for an RpcEnv implementation. */ -abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { +abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { var env: RpcEnv = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 3821166386fa6..34145691153ce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.scheduler -import org.apache.spark.{LocalSparkContext, SparkConf, SparkException, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.util.{SerializableBuffer, AkkaUtils} -import org.scalatest.FunSuite - -class CoarseGrainedSchedulerBackendSuite extends FunSuite with LocalSparkContext { +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { test("serialized task larger than akka frame size") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index eea7a600841cc..bfcf918e06162 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.{BeforeAndAfter, FunSuiteLike} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -68,7 +68,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception class DAGSchedulerSuite - extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { + extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index b52a8d11d147d..f681f21b6205e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -25,7 +25,7 @@ import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ -import org.scalatest.{FunSuiteLike, BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -39,7 +39,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * logging events, whether the parsing of the file names is correct, and whether the logged events * can be read and deserialized into actual SparkListenerEvents. */ -class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter +class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with Logging { import EventLoggingListenerSuite._ diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 950c6dc58e332..b8e466fab4506 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.scheduler import org.apache.spark.storage.BlockManagerId -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import scala.util.Random -class MapStatusSuite extends FunSuite { +class MapStatusSuite extends SparkFunSuite { test("compressSize") { assert(MapStatus.compressSize(0L) === 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 7078a7a12232a..a9036da9cc93d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -24,7 +24,7 @@ import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} @@ -64,7 +64,7 @@ import scala.language.postfixOps * increments would be captured even though the commit in both tasks was executed * erroneously. */ -class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { +class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { var outputCommitCoordinator: OutputCommitCoordinator = null var tempDir: File = null diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 456451b676bed..467796d7c24b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.scheduler import java.util.Properties -import org.scalatest.FunSuite - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} /** * Tests that pools and the associated scheduling algorithms for FIFO and fair scheduling work * correctly. */ -class PoolSuite extends FunSuite with LocalSparkContext { +class PoolSuite extends SparkFunSuite with LocalSparkContext { def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl) : TaskSetManager = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index dabe4574b6456..ff3fa95ec32ae 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -21,10 +21,10 @@ import java.io.{File, PrintWriter} import java.net.URI import org.json4s.jackson.JsonMethods._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkContext, SPARK_VERSION} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -32,7 +32,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} /** * Test whether ReplayListenerBus replays events from logs correctly. */ -class ReplayListenerSuite extends FunSuite with BeforeAndAfter { +class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { private val fileSystem = Utils.getHadoopFileSystem("/", SparkHadoopUtil.get.newConfiguration(new SparkConf())) private var testDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 825c616c0c3e0..06fb909bf5419 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -22,13 +22,13 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.JavaConversions._ -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers +class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { /** Length of time to wait while draining listener events. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index 623a687c359a2..c7f179e1483a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.{SparkContext, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} -import org.scalatest.{FunSuite, BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import scala.collection.mutable /** * Unit tests for SparkListener that require a local cluster. */ -class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext +class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 83ae8701243e5..7c1adc1aef1b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import org.mockito.Mockito._ import org.mockito.Matchers.any -import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} -class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index e3a3803e6483a..815caa79ff529 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -23,10 +23,10 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.storage.TaskResultBlockId /** @@ -71,7 +71,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule /** * Tests related to handling task results (both direct and indirect). */ -class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small // as we can make it) so the tests don't take too long. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index ffa4381969b68..a6d5232feb8de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import org.scalatest.FunSuite - import org.apache.spark._ class FakeSchedulerBackend extends SchedulerBackend { @@ -28,7 +26,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def defaultParallelism(): Int = 1 } -class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with Logging { test("Scheduler does not always schedule tasks on the same workers") { sc = new SparkContext("local", "TaskSchedulerImplSuite") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6198cea46ddf8..0060f3396dcde 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,8 +22,6 @@ import java.util.Random import scala.collection.mutable.ArrayBuffer import scala.collection.mutable -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{ManualClock, Utils} @@ -146,7 +144,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() } -class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging { import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL} private val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala index 3fa0115e68259..d565132a06789 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.scheduler.cluster.mesos import org.mockito.Mockito._ -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class MemoryUtilsSuite extends FunSuite with MockitoSugar { +class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { test("MesosMemoryUtils should always override memoryOverhead when it's set") { val sparkConf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index ab863f3d8d672..6f4ff0814b8da 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -30,16 +30,15 @@ import org.apache.mesos.SchedulerDriver import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.{ArgumentCaptor, Matchers} -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { +class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { test("check spark-class location correctly") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala index eebcba40f8a1c..5a81bb335fdb7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class MesosTaskLaunchDataSuite extends FunSuite { +class MesosTaskLaunchDataSuite extends SparkFunSuite { test("serialize and deserialize data must be same") { val serializedTask = ByteBuffer.allocate(40) (Range(100, 110).map(serializedTask.putInt(_))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala index f28e29e9b8d8e..f5cef1caaf1ac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -19,16 +19,15 @@ package org.apache.spark.scheduler.mesos import java.util.Date -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.{LocalSparkContext, SparkConf} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} -class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { +class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { private val command = new Command("mainClass", Seq("arg"), null, null, null, null) diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index ed4d8ce632e16..329a2b6dad831 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.serializer -import org.apache.spark.SparkConf -import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} -class JavaSerializerSuite extends FunSuite { +class JavaSerializerSuite extends SparkFunSuite { test("JavaSerializer instances are serializable") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 054a4c64897a9..63a8480c9b57b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.serializer import org.apache.spark.util.Utils import com.esotericsoftware.kryo.Kryo -import org.scalatest.FunSuite -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ -class KryoSerializerDistributedSuite extends FunSuite { +class KryoSerializerDistributedSuite extends SparkFunSuite { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index da98d09184735..a9b209ccfc76e 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.serializer -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext import org.apache.spark.SparkException -class KryoSerializerResizableOutputSuite extends FunSuite { +class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer val x = (1 to 400000).toArray diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 14c0172fa96ab..c32fe232cc27c 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,14 +23,13 @@ import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo -import org.scalatest.FunSuite -import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId -class KryoSerializerSuite extends FunSuite with SharedSparkContext { +class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) @@ -361,7 +360,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } -class KryoSerializerAutoResetDisabledSuite extends FunSuite with SharedSparkContext { +class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", classOf[KryoSerializer].getName) conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) conf.set("spark.kryo.referenceTracking", "true") diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index 673948d84d82b..77d66864f755e 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.serializer -import org.scalatest.FunSuite - -import org.apache.spark.{SharedSparkContext, SparkException} +import org.apache.spark.{SharedSparkContext, SparkException, SparkFunSuite} import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ @@ -29,7 +27,7 @@ class UnserializableClass { def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } -class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { +class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkContext { def fixture: (RDD[String], UnserializableClass) = { (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index e62828c4fbac6..2707bb53bc383 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.serializer import java.io.{ObjectOutput, ObjectInput} -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.BeforeAndAfterEach +import org.apache.spark.SparkFunSuite -class SerializationDebuggerSuite extends FunSuite with BeforeAndAfterEach { + +class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { import SerializationDebugger.find diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index bb34033fe9e7e..4ce3b941bea55 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -21,9 +21,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.util.Random -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** @@ -31,7 +31,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset * describe properties of the serialized stream, such as * [[Serializer.supportsRelocationOfSerializedObjects]]. */ -class SerializerPropertiesSuite extends FunSuite { +class SerializerPropertiesSuite extends SparkFunSuite { import SerializerPropertiesSuite._ diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index e0e646f0a3652..96778c9ebafb1 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.shuffle -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.CountDownLatch -class ShuffleMemoryManagerSuite extends FunSuite with Timeouts { +import org.apache.spark.SparkFunSuite + +class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { /** Launch a thread with the given body block and return it. */ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 0537bf66ad020..491dc3659e184 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -21,16 +21,14 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls -import org.scalatest.FunSuite - -import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockResolver import org.apache.spark.storage.{ShuffleBlockId, FileSegment} -class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { +class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { private val testConf = new SparkConf(false) private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala index 49a04a2a45280..a73e94e05575e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.shuffle.unsafe import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} @@ -29,7 +29,7 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are * performed in other suites. */ -class UnsafeShuffleManagerSuite extends FunSuite with Matchers { +class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { import UnsafeShuffleManager.canUseUnsafeShuffle diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala index 183043bc05233..63b0e77629dde 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.status.api.v1 import javax.ws.rs.WebApplicationException -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers -class SimpleDateParamSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class SimpleDateParamSuite extends SparkFunSuite with Matchers { test("date parsing") { new SimpleDateParam("2015-02-20T23:21:17.190GMT").timestamp should be (1424474477190L) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index b647e8a6728ec..89ed031b6fcd1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.storage -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BlockIdSuite extends FunSuite { +class BlockIdSuite extends SparkFunSuite { def assertSame(id1: BlockId, id2: BlockId) { assert(id1.name === id2.name) assert(id1.hashCode === id2.hashCode) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f647200402ecb..0f5ba46f69c2f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -23,11 +23,11 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} +import org.apache.spark._ import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -36,7 +36,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { +class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) var rpcEnv: RpcEnv = null diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 151955ef7f435..bcee901f5dd5f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} +import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -41,7 +41,7 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ -class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach +class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { private val conf = new SparkConf(false) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index 43ef469c1fd48..ad43a3e5fdc88 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -18,14 +18,12 @@ package org.apache.spark.storage import java.io.File -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -class BlockObjectWriterSuite extends FunSuite { +class BlockObjectWriterSuite extends SparkFunSuite { test("verify write metrics") { val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bc5c74c126b74..688f56f4665f3 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -22,12 +22,12 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { +class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) private var rootDir0: File = _ private var rootDir1: File = _ diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala index 47341b74e9c0f..b21c91f75d5c7 100644 --- a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.storage -import org.scalatest.FunSuite -import org.apache.spark.{SharedSparkContext, SparkConf, LocalSparkContext, SparkContext} +import org.apache.spark._ -class FlatmapIteratorSuite extends FunSuite with LocalSparkContext { +class FlatmapIteratorSuite extends SparkFunSuite with LocalSparkContext { /* Tests the ability of Spark to deal with user provided iterators from flatMap * calls, that may generate more data then available memory. In any * memory based persistance Spark will unroll the iterator into an ArrayBuffer diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index b47157f8331cc..ac6fec56bbf4f 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -20,15 +20,15 @@ package org.apache.spark.storage import java.io.File import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. */ -class LocalDirsSuite extends FunSuite with BeforeAndAfter { +class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { before { Utils.clearLocalRootDirs() diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2080c432d77db..2a7fe67ad8585 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -26,15 +26,14 @@ import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, TaskContextImpl} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends FunSuite { +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 3a45875391e29..1a199beb3558f 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.storage -import org.scalatest.FunSuite -import org.apache.spark.Success +import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ /** * Test the behavior of StorageStatusListener in response to all relevant events. */ -class StorageStatusListenerSuite extends FunSuite { +class StorageStatusListenerSuite extends SparkFunSuite { private val bm1 = BlockManagerId("big", "dog", 1) private val bm2 = BlockManagerId("fat", "duck", 2) private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index 17193ddbfd894..1d5a813a4d336 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.storage -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite /** * Test various functionalities in StorageUtils and StorageStatus. */ -class StorageSuite extends FunSuite { +class StorageSuite extends SparkFunSuite { private val memAndDisk = StorageLevel.MEMORY_AND_DISK // For testing add, update, and remove (for non-RDD blocks) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index a727a43f44dfc..33712f1bfa782 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} /** * Selenium tests for the Spark Web UI. */ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll { +class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll { implicit var webDriver: WebDriver = _ implicit val formats = DefaultFormats diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 77a038dc1720d..8f9502b5673d1 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -23,14 +23,13 @@ import scala.io.Source import scala.util.{Failure, Success, Try} import org.eclipse.jetty.servlet.ServletContextHandler -import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.LocalSparkContext._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class UISuite extends FunSuite { +class UISuite extends SparkFunSuite { /** * Create a test SparkContext with the SparkUI enabled. diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 967dd0821ebd0..56f7b9cf1f358 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.ui.jobs import java.util.Properties -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark._ @@ -28,7 +27,7 @@ import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils -class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { +class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala index c1126f3af52e6..86b078851851f 100644 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.ui.scope -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SparkListenerStageSubmitted import org.apache.spark.scheduler.SparkListenerStageCompleted @@ -28,7 +26,7 @@ import org.apache.spark.scheduler.SparkListenerJobStart /** * Tests that this listener populates and cleans up its data structures properly. */ -class RDDOperationGraphListenerSuite extends FunSuite { +class RDDOperationGraphListenerSuite extends SparkFunSuite { private var jobIdCounter = 0 private var stageIdCounter = 0 private val maxRetainedJobs = 10 diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 8778042e34657..37e2670de9685 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.ui.storage -import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.Success +import org.scalatest.BeforeAndAfter +import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.storage._ @@ -26,7 +26,7 @@ import org.apache.spark.storage._ /** * Test various functionality in the StorageListener that supports the StorageTab. */ -class StorageTabSuite extends FunSuite with BeforeAndAfter { +class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { private var bus: LiveListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index ccdb3f571429d..6c40685484ed4 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException import akka.actor.ActorNotFound -import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.rpc.RpcEnv @@ -32,7 +31,7 @@ import org.apache.spark.SSLSampleConfigs._ /** * Test the AkkaUtils with various security settings. */ -class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { +class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("remote fetch security bad password") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 7b165fe28bdd3..a97a842f434fb 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -20,14 +20,12 @@ package org.apache.spark.util import java.io.NotSerializableException import java.util.Random -import org.scalatest.FunSuite - import org.apache.spark.LocalSparkContext._ -import org.apache.spark.{TaskContext, SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.partial.CountEvaluator import org.apache.spark.rdd.RDD -class ClosureCleanerSuite extends FunSuite { +class ClosureCleanerSuite extends SparkFunSuite { test("closures inside an object") { assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 59456790e89f0..3147c937769d2 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -21,16 +21,16 @@ import java.io.NotSerializableException import scala.collection.mutable -import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite} import org.apache.spark.serializer.SerializerInstance /** * Another test suite for the closure cleaner that is finer-grained. * For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}. */ -class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { +class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with PrivateMethodTester { // Start a SparkContext so that the closure serializer is accessible // We do not actually use this explicitly otherwise diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala index 3755d43e25ea8..688fcd9f9aaba 100644 --- a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class CompletionIteratorSuite extends FunSuite { +class CompletionIteratorSuite extends SparkFunSuite { test("basic test") { var numTimesCompleted = 0 val iter = List(1, 2, 3).iterator diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala index 090d48ec921a1..cdd6555697c23 100644 --- a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.util -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite + /** * */ -class DistributionSuite extends FunSuite with Matchers { +class DistributionSuite extends SparkFunSuite with Matchers { test("summary") { val d = new Distribution((1 to 100).toArray.map{_.toDouble}) val stats = d.statCounter diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 47b535206c949..b207d497f33c2 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -25,9 +25,10 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts -import org.scalatest.FunSuite -class EventLoopSuite extends FunSuite with Timeouts { +import org.apache.spark.SparkFunSuite + +class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index c05317534cddf..2b76ae1f8a24b 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -22,15 +22,15 @@ import java.io._ import scala.collection.mutable.HashSet import scala.reflect._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} -class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { +class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { val testFile = new File(Utils.createTempDir(), "FileAppenderSuite-test").getAbsoluteFile diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 0d9126f23ccc5..e0ef9c70a5fc3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import scala.collection.Map import org.json4s.jackson.JsonMethods._ -import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor._ @@ -33,7 +32,7 @@ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.storage._ -class JsonProtocolSuite extends FunSuite { +class JsonProtocolSuite extends SparkFunSuite { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 87de90bb0dfb0..42125547436cb 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -19,11 +19,9 @@ package org.apache.spark.util import java.net.URLClassLoader -import org.scalatest.FunSuite +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} -import org.apache.spark.{SparkContext, SparkException, TestUtils} - -class MutableURLClassLoaderSuite extends FunSuite { +class MutableURLClassLoaderSuite extends SparkFunSuite { val urls2 = List(TestUtils.createJarWithClasses( classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 403dcb03bd6e5..4b7164d8acbce 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -21,10 +21,11 @@ import java.util.NoSuchElementException import scala.collection.mutable.Buffer -import org.scalatest.FunSuite import org.scalatest.Matchers -class NextIteratorSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class NextIteratorSuite extends SparkFunSuite with Matchers { test("one iteration") { val i = new StubIterator(Buffer(1)) i.hasNext should be (true) diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index bad1aa99952cf..c58db5e606f7c 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -22,12 +22,14 @@ import java.util.Properties import org.apache.commons.lang3.SerializationUtils import org.scalatest.{BeforeAndAfterEach, Suite} +import org.apache.spark.SparkFunSuite + /** * Mixin for automatically resetting system properties that are modified in ScalaTest tests. * This resets the properties after each individual test. * * The order in which fixtures are mixed in affects the order in which they are invoked by tests. - * If we have a suite `MySuite extends FunSuite with Foo with Bar`, then + * If we have a suite `MySuite extends SparkFunSuite with Foo with Bar`, then * Bar's `super` is Foo, so Bar's beforeEach() will and afterEach() methods will be invoked first * by the rest runner. * diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 04f0f3749d6b9..20550178fb1bd 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.util import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, PrivateMethodTester} + +import org.apache.spark.SparkFunSuite class DummyClass1 {} @@ -59,7 +61,10 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester + with ResetSystemProperties { override def beforeEach() { // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 751d3df9cc8f7..8c51e6b14b7fc 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -23,9 +23,9 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ThreadUtilsSuite extends FunSuite { +class ThreadUtilsSuite extends SparkFunSuite { test("newDaemonSingleThreadExecutor") { val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name") diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index 8b72fe665c214..9b3169026cda3 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class TimeStampedHashMapSuite extends FunSuite { +class TimeStampedHashMapSuite extends SparkFunSuite { // Test the testMap function - a Scala HashMap should obviously pass testMap(new mutable.HashMap[String, String]()) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index afa5cdc819746..a867cf83dc3f1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -29,16 +29,15 @@ import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.network.util.ByteUnit -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.SparkConf -class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { +class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("timeConversion") { // Test -1 diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index ce2968728a996..11194cd22a419 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.util import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite /** * Tests org.apache.spark.util.Vector functionality */ @deprecated("suppress compile time deprecation warning", "1.0.0") -class VectorSuite extends FunSuite { +class VectorSuite extends SparkFunSuite { def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) diff --git a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala index cb99d14b27af4..a2a6d703860f2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala @@ -21,9 +21,9 @@ import java.util.Comparator import scala.collection.mutable.HashSet -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class AppendOnlyMapSuite extends FunSuite { +class AppendOnlyMapSuite extends SparkFunSuite { test("initialization") { val goodMap1 = new AppendOnlyMap[Int, Int](1) assert(goodMap1.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index ffc206991906a..69dbfa9cd7141 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BitSetSuite extends FunSuite { +class BitSetSuite extends SparkFunSuite { test("basic set and get") { val setBits = Seq(0, 9, 1, 10, 90, 96) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala index c0c38cd4ac4ad..05306f408847d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.util.collection import java.nio.ByteBuffer -import org.scalatest.FunSuite import org.scalatest.Matchers._ -class ChainedBufferSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class ChainedBufferSuite extends SparkFunSuite { test("write and read at start") { // write from start of source array val buffer = new ChainedBuffer(8) diff --git a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala index 6c956d93dc80d..bc5479991a99d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class CompactBufferSuite extends FunSuite { +class CompactBufferSuite extends SparkFunSuite { test("empty buffer") { val b = new CompactBuffer[Int] assert(b.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index dff8f3ddc816f..79eba61a87251 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.io.CompressionCodec -class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { +class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private def createCombiner[T](i: T) = ArrayBuffer[T](i) private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 7a98723bc6472..9039dbef1fb71 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.PrivateMethodTester import scala.util.Random import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { +class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) if (kryo) { diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index ef890d2ba60f3..94e011799921b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class OpenHashMapSuite extends FunSuite with Matchers { +class OpenHashMapSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive value (int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 68a03e3a0970f..2607a543dd614 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite with Matchers { +class OpenHashSetSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive int") { val loadFactor = 0.7 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index b5a2d9ef720c1..6d2459d48d326 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -21,14 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} import com.google.common.io.ByteStreams -import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.storage.{FileSegment, BlockObjectWriter} -class PartitionedSerializedPairBufferSuite extends FunSuite { +class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { val serializerInstance = new KryoSerializer(new SparkConf()).newInstance diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index caf378fec8b3e..462bc2f29f9f8 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { +class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive key, value (int, int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala index 970dade628fe4..ae0eebc26f01b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class PrimitiveVectorSuite extends FunSuite { +class PrimitiveVectorSuite extends SparkFunSuite { test("primitive value") { val vector = new PrimitiveVector[Int] diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 1f33967249654..5a5919fca2469 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.util.collection import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class SizeTrackerSuite extends FunSuite { +class SizeTrackerSuite extends SparkFunSuite { val NORMAL_ERROR = 0.20 val HIGH_ERROR = 0.30 diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index e0d6cc16bde05..72fd6daba8de0 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends FunSuite { +class SorterSuite extends SparkFunSuite { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala index f855831b8e367..361ec95654f47 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.util.io import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ByteArrayChunkOutputStreamSuite extends FunSuite { +class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { test("empty output") { val o = new ByteArrayChunkOutputStream(1024) diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index 20944b62473c5..2f1e6a39f4554 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -21,9 +21,11 @@ import java.util.Random import scala.collection.mutable.ArrayBuffer import org.apache.commons.math3.distribution.PoissonDistribution -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers -class RandomSamplerSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class RandomSamplerSuite extends SparkFunSuite with Matchers { /** * My statistical testing methodology is to run a Kolmogorov-Smirnov (KS) test * between the random samplers and simple reference samplers (known to work correctly). diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index 73a9d029b0248..667a4db6f7bb6 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.util.random import scala.util.Random import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} -import org.scalatest.FunSuite -class SamplingUtilsSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class SamplingUtilsSuite extends SparkFunSuite { test("reservoirSampleAndCount") { val input = Seq.fill(100)(Random.nextInt()) diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 03f5f2d1b8528..6ca484ccd0c06 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.util.random -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.commons.math3.stat.inference.ChiSquareTest +import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils.times import scala.language.reflectiveCalls -class XORShiftRandomSuite extends FunSuite with Matchers { +class XORShiftRandomSuite extends SparkFunSuite with Matchers { def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { val seed = 1L diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a24..bb2ec96715942 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -35,6 +35,13 @@ http://spark.apache.org/ + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.commons commons-lang3 diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c142..e9fbcb9db6b78 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -31,9 +31,10 @@ import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.scalatest.FunSuite -class SparkSinkSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class SparkSinkSuite extends SparkFunSuite { val eventsPerBatch = 1000 val channelCapacity = 5000 diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 8df7edbdcad33..a345c03582ad6 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming-flume-sink_${scala.binary.version} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 93afe50c2134f..d772b9ca9b570 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -31,16 +31,16 @@ import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} -class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchCount = 5 val eventsPerBatch = 100 diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 39e6754c81dbf..3d9daeb6e4363 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -35,15 +35,15 @@ import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} import org.apache.spark.util.Utils -class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") var ssc: StreamingContext = null diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 243ce6eaca658..5734d55bf4784 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.kafka kafka_${scala.binary.version} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index b6d314dfc7783..47bbfb605850a 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -28,10 +28,10 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.Utils class DirectKafkaStreamSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll with Eventually diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index 7fb841b79cb65..d66830cbacdee 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { +import org.apache.spark.SparkFunSuite + +class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) private var kc: KafkaCluster = null diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 3c875cb766513..054487269a935 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,11 +22,11 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.spark._ -class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 24699dfc33adb..8ee2cc660f849 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -23,14 +23,14 @@ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { +class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 38548dd73b82c..80e2df62de3fe 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -26,15 +26,15 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends FunSuite +class ReliableKafkaStreamSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with Eventually { private val sparkConf = new SparkConf() diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 98f95a9a64fa0..7d102e10ab60f 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.eclipse.paho org.eclipse.paho.client.mqttv3 diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index a19a72c58a705..c4bf5aa7869bb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} @@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 8b6a8959ac4cf..d28e3e1846d70 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.twitter4j twitter4j-stream diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 9ee57d7581d85..d9acb568879fe 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.streaming.twitter -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a50d378b34335..9998c11c85171 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + ${akka.group} akka-zeromq_${scala.binary.version} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index a7566e733d891..35d2e62c68480 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends FunSuite { +class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) diff --git a/graphx/pom.xml b/graphx/pom.xml index d38a3aa8256b7..28b41228feb3d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala index eb1dbe52c2fda..f1ecc9e2219d1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel -class EdgeRDDSuite extends FunSuite with LocalSparkContext { +class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext { test("cache, getStorageLevel") { // test to see if getStorageLevel returns correct value after caching diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 5a2c73b414279..7629128010193 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class EdgeSuite extends FunSuite { +class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 68fe83739e399..57a8b95dd12e9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd._ -import org.scalatest.FunSuite -class GraphOpsSuite extends FunSuite with LocalSparkContext { +class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { test("joinVertices") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 2b1d8e47326f8..1f5e27d5508b8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class GraphSuite extends FunSuite with LocalSparkContext { +class GraphSuite extends SparkFunSuite with LocalSparkContext { def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = { Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v") diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 490b94429ea1f..8afa2d403b53f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd._ -class PregelSuite extends FunSuite with LocalSparkContext { +class PregelSuite extends SparkFunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index d0a7198d691d7..f1aa685a79c98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.{HashPartitioner, SparkContext} +import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class VertexRDDSuite extends FunSuite with LocalSparkContext { +class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 515f3a9cd02eb..7435647c6d9ee 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class EdgePartitionSuite extends FunSuite { +class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { val builder = new EdgePartitionBuilder[A, Int] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index fe8304c1cdc32..1203f8959f506 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.graphx.impl -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class VertexPartitionSuite extends FunSuite { +class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index accccfc232cd3..c965a6eb8df13 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala index 61fd0c4605568..808877f0590f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class LabelPropagationSuite extends FunSuite with LocalSparkContext { +class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext { test("Label Propagation") { withSpark { sc => // Construct a graph with two cliques connected by a single edge diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 39c6ace912b00..45f1e3011035e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators @@ -57,7 +56,7 @@ object GridPageRank { } -class PageRankSuite extends FunSuite with LocalSparkContext { +class PageRankSuite extends SparkFunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 7bd6b7f3c4ab2..2991438f5e57e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index f2c38e79c452c..d7eaa70ce6407 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ShortestPathsSuite extends FunSuite with LocalSparkContext { +class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { test("Shortest Path Computations") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index 1f658c371ffcf..d6b03208180db 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Island Strongly Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 79bf4e6cd18ee..c47552cf3a3bd 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut -class TriangleCountSuite extends FunSuite with LocalSparkContext { +class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index f3b3738db0dad..186d0cc2a977b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BytecodeUtilsSuite extends FunSuite { +class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 8d9c8ddccbb3c..32e0c841c6997 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.LocalSparkContext -class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { +class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext { test("GraphGenerators.generateRandomEdges") { val src = 5 diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala index 67c262d0f9d8d..928301523fba9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class IdentifiableSuite extends FunSuite { +class IdentifiableSuite extends SparkFunSuite { import IdentifiableSuite.Test diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2b04a3034782e..05bf58e63abaf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.ml import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame -class PipelineSuite extends FunSuite { +class PipelineSuite extends SparkFunSuite { abstract class MyModel extends Model[MyModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 17ddd335deb6d..512cffb1acb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class AttributeGroupSuite extends FunSuite { +class AttributeGroupSuite extends SparkFunSuite { test("attribute group") { val attrs = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index ec9b717e41ce8..72b575d022547 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class AttributeSuite extends FunSuite { +class AttributeSuite extends SparkFunSuite { test("default numeric attribute") { val attr: NumericAttribute = NumericAttribute.defaultAttr diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 3fdc66be8a314..40554f6ef94a8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -251,7 +250,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { */ } -private[ml] object DecisionTreeClassifierSuite extends FunSuite { +private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index ea86867f1161a..09327051621e0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTClassifierSuite.compareAPIs diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9f77d5f3efc55..a755cac3ea76e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 770b56890fa45..f439f3261f06f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -30,7 +29,7 @@ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var rdd: RDD[LabeledPoint] = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index cdbbacab8e0e3..f699d0c374d2f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -32,7 +31,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestClassifierSuite.compareAPIs diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 3ea7aad5274f2..9da0618abd23c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.ml.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression Evaluator: default params") { /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 8f6c6b39dc93b..d4631518e0f5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends FunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Double] = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 0391bd8427c2c..507a8a7db24c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.ml.feature import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class BucketizerSuite extends FunSuite with MLlibTestSparkContext { +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Bucket continuous features, without -inf,inf") { // Check a set of valid feature values. @@ -110,7 +108,7 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext { } } -private object BucketizerSuite extends FunSuite { +private object BucketizerSuite extends SparkFunSuite { /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */ def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = { require(feature >= splits.head) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 2e4beb0bfff63..7b2d70e644005 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -26,7 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { val hashingTF = new HashingTF diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index f85e85471617a..d83772e8be755 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9d09f24709e23..9f03470b7f328 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 9018d0024d5f0..2e5036a844562 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col -class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { +class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index aa230ca073d5b..feca866cd711d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite import org.scalatest.exceptions.TestFailedException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { +class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { test("Polynomial expansion with default parameter") { val data = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 89c2fe45573aa..cbf1e8ddcb48a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.mllib.util.MLlibTestSparkContext -class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { +class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index eabda089d0988..ac279cb3215c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ test("RegexTokenizer") { @@ -60,7 +59,7 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { } } -object RegexTokenizerSuite extends FunSuite { +object RegexTokenizerSuite extends SparkFunSuite { def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { t.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 43534e89928b1..489abb5af7130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { +class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index b11b029c6343e..06affc7305cf5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { import VectorIndexerSuite.FeatureData diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index df446d0c22015..94ebc3aebfa37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { test("Word2Vec") { val sqlContext = new SQLContext(sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 1505ad872536b..778abcba22c10 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.impl import scala.collection.JavaConverters._ -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.tree._ @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} -private[ml] object TreeTests extends FunSuite { +private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 04f2af4727ea4..f80e7749098a5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.param -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ParamsSuite extends FunSuite { +class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() @@ -202,7 +202,7 @@ class ParamsSuite extends FunSuite { } } -object ParamsSuite extends FunSuite { +object ParamsSuite extends SparkFunSuite { /** * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala index ca18fa1ad3c15..eb5408d3fee7c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.param.shared -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.Params -class SharedParamsSuite extends FunSuite { +class SharedParamsSuite extends SparkFunSuite { test("outputCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9a35555e52b90..2e5cfe7027eb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.scalatest.FunSuite -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -36,7 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.Utils -class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { private var tempDir: File = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 1196a772dfdd4..1182b89a8e3aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeRegressorSuite.compareAPIs @@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // TODO: test("model save/load") SPARK-6725 } -private[ml] object DecisionTreeRegressorSuite extends FunSuite { +private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 40e7e3273e965..f8a1469fee313 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTRegressorSuite.compareAPIs diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 50a78631fa6d6..732e2c42be144 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 3efffbb763b78..78911560945a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestRegressorSuite.compareAPIs @@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { */ } -private object RandomForestRegressorSuite extends FunSuite { +private object RandomForestRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 60d8bfe38fb13..5ba469c7b10a0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tuning -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression @@ -29,7 +29,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { +class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala index 20aa100112bfe..810b70049ec15 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.{ParamMap, TestParams} -class ParamGridBuilderSuite extends FunSuite { +class ParamGridBuilderSuite extends SparkFunSuite { val solver = new TestParams() import solver.{inputCol, maxIter} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 3d362b5ee53ea..59944416d96a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.api.python -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating -class PythonMLLibAPISuite extends FunSuite { +class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index b1014ab7c6203..e8f3d0c4db20a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.util.Random import scala.util.control.Breaks._ -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -169,7 +169,7 @@ object LogisticRegressionSuite { } -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -541,7 +541,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } -class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index ea40b41bbbe5e..f7fc8730606af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,9 +21,8 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -86,7 +85,7 @@ object NaiveBayesSuite { pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial) } -class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { import NaiveBayes.{Multinomial, Bernoulli} @@ -286,7 +285,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } -class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { +class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 90f9cec6855bf..b1d78cba9e3dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -21,9 +21,8 @@ import scala.collection.JavaConversions._ import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -62,7 +61,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with MLlibTestSparkContext { +class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { } } -class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { +class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 5683b55e8500a..e98b61e13e21f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index f356ffa3e3a26..a3b085e441491 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 877e6dc699523..0dbbd7127444f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class KMeansSuite extends FunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} @@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { } } -object KMeansSuite extends FunSuite { +object KMeansSuite extends SparkFunSuite { def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { val singlePoint = isSparse match { case true => @@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite { } } -class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { +class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d5b7d96335744..406affa25539d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class LDASuite extends FunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 556842f3129a3..3903712879928 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { +class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.PowerIterationClustering._ @@ -130,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext } } -object PowerIterationClusteringSuite extends FunSuite { +object PowerIterationClusteringSuite extends SparkFunSuite { def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { val assignments = sc.parallelize( (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 13f9b17c027a4..ac01622b8a089 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends FunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 79847633ff0dc..87ccc7eda44ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { +class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index e0224f960cc43..99d52fabc5309 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { +class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 7dc4f3cfbc4e4..d55bc8c3ec09f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 2537dd62c92f2..f3b19aeb42f84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 609eed983ff4e..c0924a213a844 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 3aa732474ec2e..9de2bdb6d7246 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 747f5914598ec..889727fb55823 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { /* * Contingency tables diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala index f3a482abda873..ccbf8a91cdd37 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext { +class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { test("elementwise (hadamard) product should properly apply vector to dense data set") { val denseData = Array( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index 0c4dfb7b97c7f..cf279c02334e9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 0a5cad7caf8e4..21163633051e5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 5c4af2b99e68b..34122d6ed2e95 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - import breeze.linalg.{norm => brzNorm} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index 758af588f1c69..e57f49191378f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext -class PCASuite extends FunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { private val data = Array( Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 1eb991869de40..6ab2fa6770123 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { // When the input data is all constant, the variance is zero. The standardization against // zero variance is not well-defined, but we decide to just set it into zero here. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 98a98a7599bcb..b6818369208d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index bd5b9cc3afa10..66ae3543ecc4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.mllib.fpm -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { test("FP-Growth using String type") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala index 04017f67c311d..a56d7b3579213 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm import scala.language.existentials -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPTreeSuite extends FunSuite with MLlibTestSparkContext { +class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") { val tree = new FPTree[String] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index 699f009f0f2ec..d34888af2d73b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.mllib.impl -import org.scalatest.FunSuite - import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { import PeriodicGraphCheckpointerSuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 64ecd12ea7ded..bcc2e657f3fd4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ -class BLASSuite extends FunSuite { +class BLASSuite extends SparkFunSuite { test("copy") { val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 2031032373971..dc04258e41d27 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} -class BreezeMatrixConversionSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index 8abdac72902c6..3772c9235ad3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.SparkFunSuite + /** * Test Breeze vector conversions. */ -class BreezeVectorConversionSuite extends FunSuite { +class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 86119ec38101e..8dbb70f5d1c4c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg import java.util.Random import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class MatricesSuite extends FunSuite { +class MatricesSuite extends SparkFunSuite { test("dense matrix construction") { val m = 3 val n = 2 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 24755e9ff46fc..c6d29dcdb0f2b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends FunSuite { +class VectorsSuite extends SparkFunSuite { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index a58336175899c..93fe04c139b9a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { +class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 04b36a9ef9990..f3728cd036a3f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { +class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 2ab53cc13db71..4a7b99a976f0a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 27bb19f472e1e..b6cb53d0c743e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } -class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext { +class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { var mat: RowMatrix = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index e110506d579b0..a5a59e9fad5ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.mllib.optimization import scala.collection.JavaConversions._ import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -61,7 +62,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 @@ -140,7 +141,7 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc } } -class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext { +class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index c8f2adcf155a7..d07b9d5b89227 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 @@ -229,7 +230,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { } } -class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { +class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index bb723fc471181..d8f9b8c33963d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.FunSuite - import org.jblas.{DoubleMatrix, SimpleBlas} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class NNLSSuite extends FunSuite { +class NNLSSuite extends SparkFunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala index 0b646cf1ce6c4..7a724fc78b1d9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.dmg.pmml.RegressionNormalizationMethodType -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.util.LinearDataGenerator -class BinaryClassificationPMMLModelExportSuite extends FunSuite { +class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { test("logistic regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index f9afbd888dfc5..1d32309481787 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class GeneralizedLinearPMMLModelExportSuite extends FunSuite { +class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite { test("linear regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index b985d0446d7b0..a1a683559a54c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors -class KMeansPMMLModelExportSuite extends FunSuite { +class KMeansPMMLModelExportSuite extends SparkFunSuite { test("KMeansPMMLModelExport generate PMML format") { val clusterCenters = Array( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index f28a4ac8ad01f..0d194005a30b2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.pmml.export -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class PMMLModelExportFactorySuite extends FunSuite { +class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { val clusterCenters = Array( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index b792d819fdabb..a5ca1518f82f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.random import scala.math -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class RandomDataGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends SparkFunSuite { def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 63f2ea916d457..413db2000d6d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} @@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { +class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 57216e8eb4a55..10f5a2be48f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ -class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5)), 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 6d6c0aa5be812..bc64172614830 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index b3798940ddc38..05b87728d6fdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.math.abs import scala.util.Random -import org.scalatest.FunSuite import org.jblas.DoubleMatrix +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel @@ -84,7 +84,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with MLlibTestSparkContext { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c92866f3893d..2c8ed057a516a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.recommendation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { +class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext { val rank = 2 var userFeatures: RDD[(Int, Array[Double])] = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 3b38bdf5ef5eb..ea4f2865757c1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.regression -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { private def round(d: Double) = { math.round(d * 100).toDouble / 100 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index 110c44a7193fd..d8364a06de4da 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -class LabeledPointSuite extends FunSuite { +class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 71dce50922991..08a152ffc7a23 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LassoSuite { val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LassoSuite extends FunSuite with MLlibTestSparkContext { +class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -143,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { } } -class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { +class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 3781931c2f819..f88a1c33c9f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LinearRegressionSuite { val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index d6c93cc0e49cd..7a781fee634c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -33,7 +33,7 @@ private object RidgeRegressionSuite { val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { +class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => @@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 26604dbe6c1ef..9a379406d5061 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index a7e6fce31ff7e..c292ced75e870 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 15418e6035965..b084a5fb4313f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat import java.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { +class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala index a309c942cf8ff..5feccdf33681a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.mllib.stat import org.apache.commons.math3.distribution.NormalDistribution -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { +class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { val rdd = sc.parallelize(Array(5.0)) val evaluationPoints = Array(5.0, 6.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 23b0eec865de6..07efde4f5e6dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateOnlineSummarizerSuite extends FunSuite { +class MultivariateOnlineSummarizerSuite extends SparkFunSuite { test("basic error handing") { val summarizer = new MultivariateOnlineSummarizer diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index fac2498e4dcb3..703b623536315 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.stat.distribution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { +class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index ce983eb27fa35..356d957f15909 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils -class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// // Tests examining individual elements of training @@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } } -object DecisionTreeSuite extends FunSuite { +object DecisionTreeSuite extends SparkFunSuite { def validateClassifier( model: DecisionTreeModel, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 55b0bac7d49fe..84dd3b342d4c0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GradientBoostedTrees]]. */ -class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 92b498580af03..49aff21fe7914 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends FunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 4ed66953cb628..e6df5d974bf36 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -35,7 +34,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[RandomForest]]. */ -class RandomForestSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index b184e936672ca..9d756da410325 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree.impl -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[BaggedPoint]]. */ -class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index cdece2c174be4..87b3661f77944 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -21,19 +21,18 @@ import java.io.File import scala.io.Source -import org.scalatest.FunSuite - import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { +class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index f68fb95eac4e4..8dcb9ba9be108 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.util -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.SparkException - -class NumericParserSuite extends FunSuite { +class NumericParserSuite extends SparkFunSuite { test("parser") { val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 59e6c778806f4..8f475f30249d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.util +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.scalatest.exceptions.TestFailedException -class TestingUtilsSuite extends FunSuite { +class TestingUtilsSuite extends SparkFunSuite { test("Comparing doubles using relative error.") { diff --git a/repl/pom.xml b/repl/pom.xml index 03053b4c3b287..6e5cb7f77e1df 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -48,6 +48,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-bagel_${scala.binary.version} diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 934daaeaafca1..50fd43a418bca 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,13 +22,12 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 14f5e9ed4f25e..9ecc7c229e38a 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.tools.nsc.interpreter.SparkILoop -import org.scalatest.FunSuite import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c709cde740748..a58eda12b1120 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -25,7 +25,6 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar @@ -35,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.util.Utils class ExecutorClassLoaderSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfterAll with MockitoSugar with Logging { diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5c322d032d474..d9e1cdb84bb27 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -50,6 +50,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-unsafe_${scala.binary.version} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index ea82cd2622de9..c046dbf4dc2c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.physical._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -class DistributionSuite extends FunSuite { +class DistributionSuite extends SparkFunSuite { protected def checkSatisfied( inputPartitioning: Partitioning, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 7ff51db76b6bb..9a24b23024e18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ @@ -75,7 +74,7 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } -class ScalaReflectionSuite extends FunSuite { +class ScalaReflectionSuite extends SparkFunSuite { import ScalaReflection._ test("primitive data") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 9eed15952d82b..b93a3abc6ebd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.Command -import org.scalatest.FunSuite private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -49,7 +49,7 @@ private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { } } -class SqlParserSuite extends FunSuite { +class SqlParserSuite extends SparkFunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index fcff24ca31486..e09cd790a7187 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends FunSuite with BeforeAndAfter { +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 565b1cfe019c7..1b8d18ded2257 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.SimpleCatalystConf -class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { +class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) val catalog = new SimpleCatalog(conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index f2f3a84d19380..97cfb5f06dd73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.IntegerType -class AttributeSetSuite extends FunSuite { +class AttributeSetSuite extends SparkFunSuite { val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index a14f776b1eaee..b511aa3a24420 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -22,9 +22,9 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ -class ExpressionEvaluationBaseSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends SparkFunSuite { def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { expression.eval(inputRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7a19e511eb8b5..88a36aa121b55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.JavaConverters._ import scala.util.Random +import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.sql.types._ -class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach { import UnsafeFixedWidthAggregationMap._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 3a60c7fd32675..61722f1ffa462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Arrays -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods -class UnsafeRowConverterSuite extends FunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e7cafcc96de87..765c1e2dda99f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -26,7 +25,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Provides helper methods for comparing plans. */ -class PlanTest extends FunSuite { +class PlanTest extends SparkFunSuite { /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 1273921f6394c..62d5f6ac74885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} @@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Tests for the sameResult function of [[LogicalPlan]]. */ -class SameResultSuite extends FunSuite { +class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 2a641c63f87bb..a7de7b052bdc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.trees -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -class RuleExecutorSuite extends FunSuite { +class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 9fcfc51c96139..67db3d5e6d751 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.trees import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegerType, StringType, NullType} @@ -32,7 +31,7 @@ case class Dummy(optKey: Option[Expression]) extends Expression { override def eval(input: Row): Any = null.asInstanceOf[Any] } -class TreeNodeSuite extends FunSuite { +class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } assert(after === Literal(2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index d7d60efee50fa..4030a1b1df358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{MetadataBuilder, Metadata} -class MetadataSuite extends FunSuite { +class MetadataSuite extends SparkFunSuite { val baseMetadata = new MetadataBuilder() .putString("purpose", "ml") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 3e7cf7cbb5e63..c6171b7b6916d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class DataTypeParserSuite extends FunSuite { +class DataTypeParserSuite extends SparkFunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index df119827812f9..543cdefc5293b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.types -import org.apache.spark.SparkException -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -class DataTypeSuite extends FunSuite { +class DataTypeSuite extends SparkFunSuite { test("construct an ArrayType") { val array = ArrayType(StringType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index a22aa6f244c48..81d7ab010f394 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite // scalastyle:off -class UTF8StringSuite extends FunSuite { +class UTF8StringSuite extends SparkFunSuite { test("basic") { def check(str: String, len: Int) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index de6a2cd448c47..28b373e258311 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.types.decimal +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.Decimal -import org.scalatest.{PrivateMethodTester, FunSuite} +import org.scalatest.PrivateMethodTester import scala.language.postfixOps -class DecimalSuite extends FunSuite with PrivateMethodTester { +class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { /** Check that a Decimal has the given string representation, precision and scale */ def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b1845a9180c..add0fd58e28c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql -import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ -class DataFrameStatSuite extends FunSuite { +class DataFrameStatSuite extends SparkFunSuite { val sqlCtx = TestSQLContext def toLetter(i: Int): String = (i + 97).toChar.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index c4281c4b55c02..dd68965444f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -206,7 +206,7 @@ class MathExpressionsSuite extends QueryTest { } test("log") { - testOneToOneNonNegativeMathFunction(log, math.log) + testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) } test("log10") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index fb3ba4bc1b908..513ac915dcb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer -import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -class RowSuite extends FunSuite { +class RowSuite extends SparkFunSuite { test("create row") { val expected = new GenericMutableRow(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index bf73d0c7074a5..3a5f071e2f7cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql -import org.scalatest.FunSuiteLike - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ -class SQLConfSuite extends QueryTest with FunSuiteLike { +class SQLConfSuite extends QueryTest { val testKey = "test.key.0" val testVal = "test.val.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index f186bc1c18123..797d123b48668 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.TestSQLContext -class SQLContextSuite extends FunSuite with BeforeAndAfterAll { +class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { private val testSqlContext = TestSQLContext private val testSparkContext = TestSQLContext.sparkContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 52d265b445e14..d2ede39f0a5f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext._ @@ -74,7 +73,7 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends FunSuite { +class ScalaReflectionRelationSuite extends SparkFunSuite { import org.apache.spark.sql.test.TestSQLContext.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 6f6d3c9c243d4..1e8cde606b67b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.test.TestSQLContext -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5235] SQLContext should be serializable") { val sqlContext = new SQLContext(TestSQLContext.sparkContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 7cefcf44061ce..339e719f39f16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ -class ColumnStatsSuite extends FunSuite { +class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 061efb37a0ac3..a1e76eaa982cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,15 +23,14 @@ import java.sql.Timestamp import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.serializer.KryoRegistrator -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class ColumnTypeSuite extends FunSuite with Logging { +class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index a0702144f942c..2a6e0c376551a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.DataType @@ -39,7 +38,7 @@ object TestNullableColumnAccessor { } } -class NullableColumnAccessorSuite extends FunSuite { +class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 3a5605d2335d7..cb4e9f1eb7f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ @@ -35,7 +34,7 @@ object TestNullableColumnBuilder { } } -class NullableColumnBuilderSuite extends FunSuite { +class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2a0b701cad7fa..cda1b0992e36f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = conf.columnBatchSize val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 8b518f094174c..20d65a74e3b7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ -class BooleanBitSetSuite extends FunSuite { +class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ def skeleton(count: Int) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index cef60ec204faa..acfab6586c0d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.columnar.compression import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class DictionaryEncodingSuite extends FunSuite { +class DictionaryEncodingSuite extends SparkFunSuite { testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 5514590541dd6..2111e9fbe62cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType -class IntegralDeltaSuite extends FunSuite { +class IntegralDeltaSuite extends SparkFunSuite { testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 6ee48f6291914..67ec08f594a43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class RunLengthEncodingSuite extends FunSuite { +class RunLengthEncodingSuite extends SparkFunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 523be56df65ba..45a7e8fe68f72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -31,7 +30,7 @@ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ -class PlannerSuite extends FunSuite { +class PlannerSuite extends SparkFunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 15337c4045436..6ca5390cde23e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} -import org.scalatest.{FunSuite, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.serializer.Serializer -import org.apache.spark.ShuffleDependency +import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -class SparkSqlSerializer2DataTypeSuite extends FunSuite { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { // Make sure that we will not use serializer2 for unsupported data types. def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 358d8cf06e463..8ec3985e00360 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.debug -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ -class DebuggingSuite extends FunSuite { +class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 2aad01ded1acf..5290c28cfca02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.joins -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Projection, Row} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends FunSuite { +class HashedRelationSuite extends SparkFunSuite { // Key is simply the record itself private val keyProjection = new Projection { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 30279f528944b..af279007c587e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,14 +21,15 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException -import org.scalatest.{FunSuite, BeforeAndAfter} +import org.scalatest.BeforeAndAfter import TestSQLContext._ import TestSQLContext.implicits._ -class JDBCSuite extends FunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 2e4c12f9da80c..3cd987b0b3383 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ -class JDBCWriteSuite extends FunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index c964b6d984557..caec2a6f25489 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.FunSuite import parquet.schema.MessageTypeParser +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends FunSuite with ParquetTest { +class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { val sqlContext = TestSQLContext /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 8331a14c9295c..296b0d6f74a0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.sources -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ResolvedDataSourceSuite extends FunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite { test("builtin sources") { assert(ResolvedDataSource.lookupDataSource("jdbc") === diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 437f697d25bf3..20d3c7d4c5959 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -41,6 +41,13 @@ spark-hive_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index cc07db827d359..3732af7870b93 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,16 +25,16 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary * Hive metastore and warehouse. */ -class CliSuite extends FunSuite with BeforeAndAfter with Logging { +class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 610939c6a9481..da511ebd05ad2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -37,9 +37,9 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils @@ -405,7 +405,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } -abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { +abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging { def mode: ServerMode.Value private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 615b07e74d535..923ffabb9b99e 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 80c2d32bf70d7..df137e7b2b333 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -26,12 +26,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.LongWritable -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Literal, Row} import org.apache.spark.sql.types._ -class HiveInspectorSuite extends FunSuite with HiveInspectors { +class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { val udaf = new UDAFPercentile.PercentileLongEvaluator() udaf.init() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index fa8e11ffec2b4..e9bb32667936c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.test.TestHive -import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends FunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite { test("struct field should accept underscore in sub-column name") { val metastr = "struct" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 941a2941649b8..f765395e148af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class HiveQlSuite extends FunSuite with BeforeAndAfterAll { +class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll() { if (SessionState.get() == null) { SessionState.start(new HiveConf()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 8afe5459d4f1b..a492ecf203d17 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.hive.test.TestHive -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5840] HiveContext should be serializable") { val hiveContext = TestHive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 321dc8d7322b8..446a2f2d646e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.hive.client -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils -import org.scalatest.FunSuite /** * A simple set of tests that call the methods of a hive ClientInterface, loading different version @@ -28,7 +27,7 @@ import org.scalatest.FunSuite * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity * is not fully tested. */ -class VersionsSuite extends FunSuite with Logging { +class VersionsSuite extends SparkFunSuite with Logging { private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index 23ece7e7cf6e9..b0d3dd44daedc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { +class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 55e5551b63818..c9dd4c0935a72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} +import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -40,7 +40,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { /** * When set, any cache files that result in test failures will be deleted. Used when the test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 88c99e35260d9..0e63d84e9824a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} +import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -38,7 +39,7 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { +class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal def withTempDir(f: File => Unit): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index cdd6e705f4a2c..57c23fe77f8b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -21,8 +21,9 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive @@ -50,7 +51,7 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest { +class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { override val sqlContext = TestHive import TestHive.read diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index cf5ae88dc4bee..af36fa6f1faae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive @@ -485,7 +484,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } -class CommitFailureTestRelationSuite extends FunSuite with SQLTestUtils { +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { import TestHive.implicits._ override val sqlContext = TestHive diff --git a/streaming/pom.xml b/streaming/pom.xml index 5ab7f4472c38b..49d035a1e9696 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 6a1dd6949b204..9b5e4dc819a2b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming import java.io.NotSerializableException -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{HashPartitioner, SparkContext, SparkException} +import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.ReturnStatementInClosureException @@ -29,7 +29,7 @@ import org.apache.spark.util.ReturnStatementInClosureException /** * Test that closures passed to DStream operations are actually cleaned. */ -class DStreamClosureSuite extends FunSuite with BeforeAndAfterAll { +class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private var ssc: StreamingContext = null override def beforeAll(): Unit = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index e3fb2ef130859..8844c9d74b933 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.streaming -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils @@ -27,7 +27,7 @@ import org.apache.spark.streaming.ui.UIUtils /** * Tests whether scope information is passed from DStream operations to RDDs correctly. */ -class DStreamScopeSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll { +class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { private var ssc: StreamingContext = null private val batchDuration: Duration = Seconds(1) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 23804237bda80..cca8cedb1d080 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -41,7 +41,11 @@ import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ -class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class ReceivedBlockHandlerSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with Logging { val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b1af8d5eaacfb..6f0ee774cb5cf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ @@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite - extends FunSuite with BeforeAndAfter with Matchers with Logging { + extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() val akkaTimeout = 10 seconds diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index e36c7914b130e..d304c9a7328f3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -25,16 +25,16 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} +import org.scalatest.{Assertions, BeforeAndAfter} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite} -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 554cd30223f44..31b1aebf6a8ec 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler._ @@ -204,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) { * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context def framework: String = this.getClass.getSimpleName diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 441bbf95d0153..021d2c95a4aad 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark._ * Selenium tests for the Spark Web UI. */ class UISeleniumSuite - extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 6859b65c7165f..cb017b798b2a4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,15 +21,15 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite - extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 5478b41845943..2e210397fe7c7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.streaming.{Time, Duration, StreamingContext} -class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { +class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { private var ssc: StreamingContext = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala index e9ab917ab845c..d3ca2b58f36c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.streaming.ui import java.util.TimeZone import java.util.concurrent.TimeUnit -import org.scalatest.FunSuite import org.scalatest.Matchers -class UIUtilsSuite extends FunSuite with Matchers{ +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite with Matchers{ test("shortTimeUnitString") { assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 9ebf7b484f421..78fc344b00177 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class RateLimitedOutputStreamSuite extends FunSuite { +class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 79098bcf4861c..0acf7068ef4a4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -28,12 +28,12 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { +class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ diff --git a/yarn/pom.xml b/yarn/pom.xml index 00d219f836708..e207a46809684 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.hadoop hadoop-yarn-api diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 80b57d1355a3a..43a7334db874c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.mockito.Mockito.when @@ -36,8 +35,10 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} import scala.collection.mutable.HashMap import scala.collection.mutable.Map +import org.apache.spark.SparkFunSuite -class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + +class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 6da3e82acdb14..01d33c9ce9297 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { override def beforeAll(): Unit = { System.setProperty("SPARK_YARN_MODE", "true") diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index b343cbb0c7569..7509000771d94 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.spark.SecurityManager +import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} class MockResolver extends DNSToSwitchMapping { @@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping { def reloadCachedMappings(names: JList[String]) {} } -class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new Configuration() conf.setClass( CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index dcaeb2e43ff41..d8bc2534c1a6a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -30,9 +30,9 @@ import com.google.common.io.ByteStreams import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} +import org.apache.spark._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { +class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the YARN containers, so that their output is collected // by YARN instead of trying to overwrite unit-tests.log. diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index e10b985c3c236..49bee0866dd43 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { val hasBash = try { From 5f48e5c33bafa376be5741e260a037c66103fdcd Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 29 May 2015 14:11:58 -0700 Subject: [PATCH 072/254] [SPARK-6806] [SPARKR] [DOCS] Add a new SparkR programming guide This PR adds a new SparkR programming guide at the top-level. This will be useful for R users as our APIs don't directly match the Scala/Python APIs and as we need to explain SparkR without using RDDs as examples etc. cc rxin davies pwendell cc cafreeman -- Would be great if you could also take a look at this ! Author: Shivaram Venkataraman Closes #6490 from shivaram/sparkr-guide and squashes the following commits: d5ff360 [Shivaram Venkataraman] Add a section on HiveContext, HQL queries 408dce5 [Shivaram Venkataraman] Fix link dbb86e3 [Shivaram Venkataraman] Fix minor typo 9aff5e0 [Shivaram Venkataraman] Address comments, use dplyr-like syntax in example d09703c [Shivaram Venkataraman] Fix default argument in read.df ea816a1 [Shivaram Venkataraman] Add a new SparkR programming guide Also update write.df, read.df to handle defaults better --- R/pkg/R/DataFrame.R | 10 +- R/pkg/R/SQLContext.R | 5 + R/pkg/R/generics.R | 4 +- docs/_layouts/global.html | 1 + docs/index.md | 2 +- docs/sparkr.md | 223 ++++++++++++++++++++++++++++++++++ docs/sql-programming-guide.md | 4 +- 7 files changed, 238 insertions(+), 11 deletions(-) create mode 100644 docs/sparkr.md diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ed8093c80d360..e79d324838fe3 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1314,9 +1314,8 @@ setMethod("except", #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", @@ -1338,9 +1337,8 @@ setMethod("write.df", #' @aliases saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 36cc612875879..88e1a508f37c4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -457,6 +457,11 @@ read.df <- function(sqlContext, path = NULL, source = NULL, ...) { if (!is.null(path)) { options[['path']] <- path } + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } sdf <- callJMethod(sqlContext, "load", source, options) dataFrame(sdf) } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a23d3b217b2fd..1f4fc6adaca8d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -482,11 +482,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") }) +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) #' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) +setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) #' @rdname schema #' @export diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index b92c75f90b11c..eebb3faf90fc0 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -75,6 +75,7 @@
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • +
  • SparkR (R on Spark)
  • diff --git a/docs/index.md b/docs/index.md index 5ef6d983c45a5..fac071da81e60 100644 --- a/docs/index.md +++ b/docs/index.md @@ -54,7 +54,7 @@ Example applications are also provided in Python. For example, ./bin/spark-submit examples/src/main/python/pi.py 10 -Spark also provides an experimental R API since 1.4 (only DataFrames APIs included). +Spark also provides an experimental [R API](sparkr.html) since 1.4 (only DataFrames APIs included). To run Spark interactively in a R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] diff --git a/docs/sparkr.md b/docs/sparkr.md new file mode 100644 index 0000000000000..4d82129921a37 --- /dev/null +++ b/docs/sparkr.md @@ -0,0 +1,223 @@ +--- +layout: global +displayTitle: SparkR (R on Spark) +title: SparkR (R on Spark) +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. +In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that +supports operations like selection, filtering, aggregation etc. (similar to R data frames, +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. + +# SparkR DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +structured data files, tables in Hive, external databases, or existing local R data frames. + +All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. + +## Starting Up: SparkContext, SQLContext + +
    +The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. +You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name +etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the +SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should +already be created for you. + +{% highlight r %} +sc <- sparkR.init() +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} + +
    + +## Creating DataFrames +With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). + +### From local data frames +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. + +
    +{% highlight r %} +df <- createDataFrame(sqlContext, faithful) + +# Displays the content of the DataFrame to stdout +head(df) +## eruptions waiting +##1 3.600 79 +##2 1.800 54 +##3 3.333 74 + +{% endhighlight %} +
    + +### From Data Sources + +SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). + +We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +
    + +{% highlight r %} +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +head(people) +## age name +##1 NA Michael +##2 30 Andy +##3 19 Justin + +# SparkR automatically infers the schema from the JSON file +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +{% endhighlight %} +
    + +The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example +to a Parquet file using `write.df` + +
    +{% highlight r %} +write.df(people, path="people.parquet", source="parquet", mode="overwrite") +{% endhighlight %} +
    + +### From Hive tables + +You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). + +
    +{% highlight r %} +# sc is an existing SparkContext. +hiveContext <- sparkRHive.init(sc) + +sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- hiveContext.sql("FROM src SELECT key, value") + +# results is now a DataFrame +head(results) +## key value +## 1 238 val_238 +## 2 86 val_86 +## 3 311 val_311 + +{% endhighlight %} +
    + +## DataFrame Operations + +SparkR DataFrames support a number of functions to do structured data processing. +Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: + +### Selecting rows, columns + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, faithful) + +# Get basic information about the DataFrame +df +## DataFrame[eruptions:double, waiting:double] + +# Select only the "eruptions" column +head(select(df, df$eruptions)) +## eruptions +##1 3.600 +##2 1.800 +##3 3.333 + +# You can also pass in column name as strings +head(select(df, "eruptions")) + +# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +head(filter(df, df$waiting < 50)) +## eruptions waiting +##1 1.750 47 +##2 1.750 47 +##3 1.867 48 + +{% endhighlight %} + +
    + +### Grouping, Aggregation + +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below + +
    +{% highlight r %} + +# We use the `n` operator to count the number of times each waiting time appears +head(summarize(groupBy(df, df$waiting), count = n(df$waiting))) +## waiting count +##1 81 13 +##2 60 6 +##3 68 1 + +# We can also sort the output from the aggregation to get the most common waiting times +waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting)) +head(arrange(waiting_counts, desc(waiting_counts$count))) + +## waiting count +##1 78 15 +##2 83 14 +##3 81 13 + +{% endhighlight %} +
    + +### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can assign this to a new column in the same DataFrame +df$waiting_secs <- df$waiting * 60 +head(df) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 + +{% endhighlight %} +
    + +## Running SQL Queries from SparkR +A SparkR DataFrame can also be registered as a temporary table in Spark SQL and registering a DataFrame as a table allows you to run SQL queries over its data. +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +{% highlight r %} +# Load a JSON file +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ab646f65bb5eb..7cc0a87fd5c53 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1526,8 +1526,8 @@ adds support for finding tables in the MetaStore and writing queries using HiveQ # sc is an existing SparkContext. sqlContext <- sparkRHive.init(sc) -hql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. results = sqlContext.sql("FROM src SELECT key, value").collect() From 1c5b19827a091b5aba69a967600e7ca35ed3bcfd Mon Sep 17 00:00:00 2001 From: Michael Nazario Date: Fri, 29 May 2015 14:13:44 -0700 Subject: [PATCH 073/254] [SPARK-7899] [PYSPARK] Fix Python 3 pyspark/sql/types module conflict This PR makes the types module in `pyspark/sql/types` work with pylint static analysis by removing the dynamic naming of the `pyspark/sql/_types` module to `pyspark/sql/types`. Tests are now loaded using `$PYSPARK_DRIVER_PYTHON -m module` rather than `$PYSPARK_DRIVER_PYTHON module.py`. The old method adds the location of `module.py` to `sys.path`, so this change prevents accidental use of relative paths in Python. Author: Michael Nazario Closes #6439 from mnazario/feature/SPARK-7899 and squashes the following commits: 366ef30 [Michael Nazario] Remove hack on random.py bb8b04d [Michael Nazario] Make doctests consistent with other tests 6ee4f75 [Michael Nazario] Change test scripts to use "-m" 673528f [Michael Nazario] Move _types back to types --- bin/pyspark | 6 +- python/pyspark/accumulators.py | 4 ++ python/pyspark/mllib/__init__.py | 8 --- python/pyspark/mllib/{rand.py => random.py} | 0 python/pyspark/sql/__init__.py | 12 ---- python/pyspark/sql/{_types.py => types.py} | 0 python/run-tests | 76 ++++++++++----------- 7 files changed, 43 insertions(+), 63 deletions(-) rename python/pyspark/mllib/{rand.py => random.py} (100%) rename python/pyspark/sql/{_types.py => types.py} (100%) diff --git a/bin/pyspark b/bin/pyspark index 8acad6113797d..7cb19c51b43a2 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -90,11 +90,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 0d21a132048a5..adca90ddaf397 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -261,3 +261,7 @@ def _start_update_server(): thread.daemon = True thread.start() return server + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 07507b2ad0d05..b11aed2c3afda 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -28,11 +28,3 @@ __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] - -import sys -from . import rand as random -modname = __name__ + '.random' -random.__name__ = modname -random.RandomRDDs.__module__ = modname -sys.modules[modname] = random -del modname, sys diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py similarity index 100% rename from python/pyspark/mllib/rand.py rename to python/pyspark/mllib/random.py diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 8fee92ae3aed5..726d288d97b2e 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -50,18 +50,6 @@ def deco(f): return f return deco -# fix the module name conflict for Python 3+ -import sys -from . import _types as types -modname = __name__ + '.types' -types.__name__ = modname -# update the __module__ for all objects, make them picklable -for v in types.__dict__.values(): - if hasattr(v, "__module__") and v.__module__.endswith('._types'): - v.__module__ = modname -sys.modules[modname] = types -del modname, sys - from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py similarity index 100% rename from python/pyspark/sql/_types.py rename to python/pyspark/sql/types.py diff --git a/python/run-tests b/python/run-tests index ffde2fb24b369..fcfb49556b7cf 100755 --- a/python/run-tests +++ b/python/run-tests @@ -57,54 +57,54 @@ function run_test() { function run_core_tests() { echo "Run core tests ..." - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_test "pyspark.rdd" + run_test "pyspark.context" + run_test "pyspark.conf" + run_test "pyspark.broadcast" + run_test "pyspark.accumulators" + run_test "pyspark.serializers" + run_test "pyspark.profiler" + run_test "pyspark.shuffle" + run_test "pyspark.tests" } function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql/_types.py" - run_test "pyspark/sql/context.py" - run_test "pyspark/sql/column.py" - run_test "pyspark/sql/dataframe.py" - run_test "pyspark/sql/group.py" - run_test "pyspark/sql/functions.py" - run_test "pyspark/sql/tests.py" + run_test "pyspark.sql.types" + run_test "pyspark.sql.context" + run_test "pyspark.sql.column" + run_test "pyspark.sql.dataframe" + run_test "pyspark.sql.group" + run_test "pyspark.sql.functions" + run_test "pyspark.sql.tests" } function run_mllib_tests() { echo "Run mllib tests ..." - run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/evaluation.py" - run_test "pyspark/mllib/feature.py" - run_test "pyspark/mllib/fpm.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/rand.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat/_statistics.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/tests.py" + run_test "pyspark.mllib.classification" + run_test "pyspark.mllib.clustering" + run_test "pyspark.mllib.evaluation" + run_test "pyspark.mllib.feature" + run_test "pyspark.mllib.fpm" + run_test "pyspark.mllib.linalg" + run_test "pyspark.mllib.random" + run_test "pyspark.mllib.recommendation" + run_test "pyspark.mllib.regression" + run_test "pyspark.mllib.stat._statistics" + run_test "pyspark.mllib.tree" + run_test "pyspark.mllib.util" + run_test "pyspark.mllib.tests" } function run_ml_tests() { echo "Run ml tests ..." - run_test "pyspark/ml/feature.py" - run_test "pyspark/ml/classification.py" - run_test "pyspark/ml/recommendation.py" - run_test "pyspark/ml/regression.py" - run_test "pyspark/ml/tuning.py" - run_test "pyspark/ml/tests.py" - run_test "pyspark/ml/evaluation.py" + run_test "pyspark.ml.feature" + run_test "pyspark.ml.classification" + run_test "pyspark.ml.recommendation" + run_test "pyspark.ml.regression" + run_test "pyspark.ml.tuning" + run_test "pyspark.ml.tests" + run_test "pyspark.ml.evaluation" } function run_streaming_tests() { @@ -124,8 +124,8 @@ function run_streaming_tests() { done export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" + run_test "pyspark.streaming.util" + run_test "pyspark.streaming.tests" } echo "Running PySpark tests. Output is in python/$LOG_FILE." From 82a396c2f594bade276606dcd0c0545a650fb838 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 29 May 2015 14:59:18 -0700 Subject: [PATCH 074/254] [SPARK-7910] [TINY] [JAVAAPI] expose partitioner information in javardd Author: Holden Karau Closes #6464 from holdenk/SPARK-7910-expose-partitioner-information-in-javardd and squashes the following commits: de1e644 [Holden Karau] Fix the test to get the partitioner bdb31cc [Holden Karau] Add Mima exclude for the new method 347ef4c [Holden Karau] Add a quick little test for the partitioner JavaAPI f49dca9 [Holden Karau] Add partitoner information to JavaRDDLike and fix some whitespace --- .../scala/org/apache/spark/api/java/JavaRDDLike.scala | 9 ++++++--- core/src/test/java/org/apache/spark/JavaAPISuite.java | 2 ++ project/MimaExcludes.scala | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index b8e15f38a20d2..c95615a5a9307 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @deprecated("Use partitions() instead.", "1.1.0") def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - + /** Set of partitions in this RDD. */ def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + /** The partitioner of this RDD. */ + def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context @@ -492,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } - def takeSample(withReplacement: Boolean, num: Int): JList[T] = + def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c2089b0e56a1f..dfd86d3e51e7d 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -212,6 +212,8 @@ public int getPartition(Object key) { JavaPairRDD repartitioned = rdd.repartitionAndSortWithinPartitions(partitioner); + Assert.assertTrue(repartitioned.partitioner().isPresent()); + Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), new Tuple2(0, 8), new Tuple2(2, 6))); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 11b439e7875fc..8da72b3fa7cdb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,8 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), // These are needed if checking against the sbt build, since they are part of From 5fb97dca9bcfc29ac33823554c8783997e811b99 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 29 May 2015 15:08:30 -0700 Subject: [PATCH 075/254] [SPARK-7954] [SPARKR] Create SparkContext in sparkRSQL init cc davies Author: Shivaram Venkataraman Closes #6507 from shivaram/sparkr-init and squashes the following commits: 6fdd169 [Shivaram Venkataraman] Create SparkContext in sparkRSQL init --- R/pkg/R/sparkR.R | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 68387f0f5365d..5ced7c688f98a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -225,14 +225,21 @@ sparkR.init <- function( #' sqlContext <- sparkRSQL.init(sc) #'} -sparkRSQL.init <- function(jsc) { +sparkRSQL.init <- function(jsc = NULL) { if (exists(".sparkRSQLsc", envir = .sparkREnv)) { return(get(".sparkRSQLsc", envir = .sparkREnv)) } + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - jsc) + "createSQLContext", + sc) assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) sqlContext } @@ -249,12 +256,19 @@ sparkRSQL.init <- function(jsc) { #' sqlContext <- sparkRHive.init(sc) #'} -sparkRHive.init <- function(jsc) { +sparkRHive.init <- function(jsc = NULL) { if (exists(".sparkRHivesc", envir = .sparkREnv)) { return(get(".sparkRHivesc", envir = .sparkREnv)) } - ssc <- callJMethod(jsc, "sc") + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) }, error = function(err) { From dbf8ff38de0f95f467b874a5b527dcf59439efe8 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Fri, 29 May 2015 15:22:26 -0700 Subject: [PATCH 076/254] [SPARK-6013] [ML] Add more Python ML examples for spark.ml Author: Ram Sriharsha Closes #6443 from harsha2010/SPARK-6013 and squashes the following commits: 732506e [Ram Sriharsha] Code Review Feedback 121c211 [Ram Sriharsha] python style fix 5f9b8c3 [Ram Sriharsha] python style fixes 925ca86 [Ram Sriharsha] Simple Params Example 8b372b1 [Ram Sriharsha] GBT Example 965ec14 [Ram Sriharsha] Random Forest Example --- .../examples/ml/JavaSimpleParamsExample.java | 2 +- .../main/python/ml/gradient_boosted_trees.py | 83 ++++++++++++++++ .../main/python/ml/random_forest_example.py | 87 ++++++++++++++++ .../main/python/ml/simple_params_example.py | 98 +++++++++++++++++++ .../examples/ml/SimpleParamsExample.scala | 2 +- 5 files changed, 270 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/python/ml/gradient_boosted_trees.py create mode 100644 examples/src/main/python/ml/random_forest_example.py create mode 100644 examples/src/main/python/ml/simple_params_example.py diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 29158d5c85651..dac649d1d5ae6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -97,7 +97,7 @@ public static void main(String[] args) { DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. DataFrame results = model2.transform(test); diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py new file mode 100644 index 0000000000000..6446f0fe5eeab --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import GBTRegressor +from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. +Note: GBTClassifier only supports binary classification currently +Run with: + bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py +""" + + +def testClassification(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = BinaryClassificationMetrics(predictionAndLabels) + print("AUC %.3f" % metrics.areaUnderROC) + + +def testRegression(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGBTExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py new file mode 100644 index 0000000000000..c7730e1bfacd9 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import RandomForestRegressor +from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a RandomForest Classification/Regression Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/random_forest_example.py +""" + + +def testClassification(train, test): + # Train a RandomForest model. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + # Note: Use larger numTrees in practice. + + rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + +def testRegression(train, test): + # Train a RandomForest model. + # Note: Use larger numTrees in practice. + + rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: random_forest_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonRandomForestExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py new file mode 100644 index 0000000000000..3933d59b52cd1 --- /dev/null +++ b/examples/src/main/python/ml/simple_params_example.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import pprint +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql import SQLContext + +""" +A simple example demonstrating ways to specify parameters for Estimators and Transformers. +Run with: + bin/spark-submit examples/src/main/python/ml/simple_params_example.py +""" + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: simple_params_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonSimpleParamsExample") + sqlContext = SQLContext(sc) + + # prepare training data. + # We create an RDD of LabeledPoints and convert them into a DataFrame. + # Spark DataFrames can automatically infer the schema from named tuples + # and LabeledPoint implements __reduce__ to behave like a named tuple. + training = sc.parallelize([ + LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), + LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), + LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), + LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + + # Create a LogisticRegression instance with maxIter = 10. + # This instance is an Estimator. + lr = LogisticRegression(maxIter=10) + # Print out the parameters, documentation, and any default values. + print("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + # We may also set parameters using setter methods. + lr.setRegParam(0.01) + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a Transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print("Model 1 was fit using parameters:\n") + pprint.pprint(model1.extractParamMap()) + + # We may alternatively specify parameters using a parameter map. + # paramMap overrides all lr parameters set earlier. + paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + + # Now learn a new model using the new parameters. + model2 = lr.fit(training, paramMap) + print("Model 2 was fit using parameters:\n") + pprint.pprint(model2.extractParamMap()) + + # prepare test data. + test = sc.parallelize([ + LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), + LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), + LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegressionModel.transform will only use the 'features' column. + # Note that model2.transform() outputs a 'myProbability' column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + result = model2.transform(test) \ + .select("features", "label", "myProbability", "prediction") \ + .collect() + + for row in result: + print("features=%s,label=%s -> prob=%s, prediction=%s" + % (row.features, row.label, row.myProbability, row.prediction)) + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a991f50e338..a0561e2573fc9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -87,7 +87,7 @@ object SimpleParamsExample { LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. model2.transform(test.toDF()) From 8c9979337f193c72fd2f1a891909283de53777e3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 May 2015 15:26:49 -0700 Subject: [PATCH 077/254] [HOTFIX] [SQL] Maven test compilation issue Tests compile in SBT but not Maven. --- sql/core/pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ffe95bb49188f..8210c552603ea 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version}
    + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-catalyst_${scala.binary.version} From a4f24123d8857656524c9138c7c067a4b1033a5e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 May 2015 17:19:46 -0700 Subject: [PATCH 078/254] [HOT FIX] [BUILD] Fix maven build failures This patch fixes a build break in maven caused by #6441. Note that this patch reverts the changes in flume-sink because this module does not currently depend on Spark core, but the tests require it. There is not an easy way to make this work because mvn test dependencies are not transitive (MNG-1378). For now, we will leave the one test suite in flume-sink out until we figure out a better solution. This patch is mainly intended to unbreak the maven build. Author: Andrew Or Closes #6511 from andrewor14/fix-build-mvn and squashes the following commits: 3d53643 [Andrew Or] [HOT FIX #6441] Fix maven build failures --- external/flume-sink/pom.xml | 7 ------- .../apache/spark/streaming/flume/sink/SparkSinkSuite.scala | 5 ++--- mllib/pom.xml | 7 +++++++ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index bb2ec96715942..1f3e619d97a24 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -35,13 +35,6 @@ http://spark.apache.org/ - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.commons commons-lang3 diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index e9fbcb9db6b78..650b2fbe1c142 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -31,10 +31,9 @@ import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.scalatest.FunSuite -import org.apache.spark.SparkFunSuite - -class SparkSinkSuite extends SparkFunSuite { +class SparkSinkSuite extends FunSuite { val eventsPerBatch = 1000 val channelCapacity = 5000 diff --git a/mllib/pom.xml b/mllib/pom.xml index 0c07ca1a62fd3..65c647a91d192 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} From 3792d25836e1e521da64c5a62ca1b6cca1bcb6b9 Mon Sep 17 00:00:00 2001 From: Taka Shinagawa Date: Fri, 29 May 2015 20:35:14 -0700 Subject: [PATCH 079/254] [DOCS][Tiny] Added a missing dash(-) in docs/configuration.md The first line had only two dashes (--) instead of three(---). Because of this missing dash(-), 'jekyll build' command was not converting configuration.md to _site/configuration.html Author: Taka Shinagawa Closes #6513 from mrt/docfix3 and squashes the following commits: c470e2c [Taka Shinagawa] Added a missing dash(-) preventing jekyll from converting configuration.md to html format --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 30508a617fdd8..3a48da4592dd9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,4 +1,4 @@ --- +--- layout: global displayTitle: Spark Configuration title: Configuration From 7ed06c39922ac90acab3a78ce0f2f21184ed68a5 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 29 May 2015 22:19:15 -0700 Subject: [PATCH 080/254] [SPARK-7957] Preserve partitioning when using randomSplit cc JoshRosen Thanks for noticing this! Author: Burak Yavuz Closes #6509 from brkyvz/sample-perf-reg and squashes the following commits: 497465d [Burak Yavuz] addressed code review 293f95f [Burak Yavuz] [SPARK-7957] Preserve partitioning when using randomSplit --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5fcef255e13af..10610f4b6f1ff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -434,11 +434,11 @@ abstract class RDD[T: ClassTag]( * @return A random sub-sample of the RDD without replacement. */ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { - this.mapPartitionsWithIndex { case (index, partition) => + this.mapPartitionsWithIndex( { (index, partition) => val sampler = new BernoulliCellSampler[T](lb, ub) sampler.setSeed(seed + index) sampler.sample(partition) - } + }, preservesPartitioning = true) } /** From 609c4923f98c188bce60ae35c1c8a08a8dfd95f1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 May 2015 22:57:46 -0700 Subject: [PATCH 081/254] [SPARK-7558] Guard against direct uses of FunSuite / FunSuiteLike This is a follow-up patch to #6441. Author: Andrew Or Closes #6510 from andrewor14/extends-funsuite-check and squashes the following commits: 6618b46 [Andrew Or] Exempt SparkSinkSuite from the FunSuite check 99d02ac [Andrew Or] Merge branch 'master' of github.com:apache/spark into extends-funsuite-check 48874dd [Andrew Or] Guard against direct uses of FunSuite / FunSuiteLike --- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 2 ++ .../spark/streaming/flume/sink/SparkSinkSuite.scala | 9 +++++++++ scalastyle-config.xml | 7 +++++++ 3 files changed, 18 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 0327dfad6ea51..8cb344332668f 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark +// scalastyle:off import org.scalatest.{FunSuite, Outcome} /** * Base abstract class for all unit tests in Spark for handling common functionality. */ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { +// scalastyle:on /** * Log the suite name and the test name before and after each test. diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c142..605b3fe71017f 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -31,9 +31,18 @@ import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +// Due to MNG-1378, there is not a way to include test dependencies transitively. +// We cannot include Spark core tests as a dependency here because it depends on +// Spark core main, which has too many dependencies to require here manually. +// For this reason, we continue to use FunSuite and ignore the scalastyle checks +// that fail if this is detected. +//scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { +//scalastyle:on + val eventsPerBatch = 1000 val channelCapacity = 5000 diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 68c8ce3b7e10b..890bf3794925b 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -153,4 +153,11 @@ + + + + ^FunSuite[A-Za-z]*$ + + Tests must extend org.apache.spark.SparkFunSuite instead. + From 193dba01c77ef1bb63e3f617213eb257960f8d2f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 May 2015 23:08:47 -0700 Subject: [PATCH 082/254] [TRIVIAL] Typo fix for last commit --- scalastyle-config.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 890bf3794925b..a0098169a0248 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -153,7 +153,7 @@ - + ^FunSuite[A-Za-z]*$ From da2112aef28e63c452f592e0abd007141787877d Mon Sep 17 00:00:00 2001 From: Octavian Geagla Date: Fri, 29 May 2015 23:55:19 -0700 Subject: [PATCH 083/254] [SPARK-7576] [MLLIB] Add spark.ml user guide doc/example for ElementwiseProduct Author: Octavian Geagla Closes #6501 from ogeagla/ml-guide-elemwiseprod and squashes the following commits: 4ad93d5 [Octavian Geagla] [SPARK-7576] [MLLIB] Incorporate code review feedback. f7be7ad [Octavian Geagla] [SPARK-7576] [MLLIB] Add spark.ml user guide doc/example for ElementwiseProduct. --- docs/ml-features.md | 88 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index d7851a55fabfe..81f1b8823a8ce 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -876,5 +876,93 @@ bucketedData = bucketizer.transform(dataFrame) +## ElementwiseProduct + +ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. + +`\[ \begin{pmatrix} +v_1 \\ +\vdots \\ +v_N +\end{pmatrix} \circ \begin{pmatrix} + w_1 \\ + \vdots \\ + w_N + \end{pmatrix} += \begin{pmatrix} + v_1 w_1 \\ + \vdots \\ + v_N w_N + \end{pmatrix} +\]` + +[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) takes the following parameter: + +* `scalingVec`: the transforming vector. + +This example below demonstrates how to transform vectors using a transforming vector value. + +
    +
    +{% highlight scala %} +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors + +// Create some vector data; also works for sparse vectors +val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + +val transformingVector = Vectors.dense(0.0, 1.0, 2.0) +val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + +// Batch transform the vectors to create new column: +val transformedData = transformer.transform(dataFrame) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +// Create some vector data; also works for sparse vectors +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) +)); +List fields = new ArrayList(2); +fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); +fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); +StructType schema = DataTypes.createStructType(fields); +DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); +// Batch transform the vectors to create new column: +DataFrame transformedData = transformer.transform(dataFrame); + +{% endhighlight %} +
    +
    + # Feature Selectors From 78657d53d71b9d3e86b675cc519868f99e2ffa01 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Fri, 29 May 2015 23:56:18 -0700 Subject: [PATCH 084/254] [SPARK-7962] [MESOS] Fix master url parsing in rest submission client. Only parse standalone master url when master url starts with spark:// Author: Timothy Chen Closes #6517 from tnachen/fix_mesos_client and squashes the following commits: 61a1198 [Timothy Chen] Fix master url parsing in rest submission client. --- .../org/apache/spark/deploy/rest/RestSubmissionClient.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 6078f50518ba4..1fe956320a1b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -57,7 +57,11 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private val supportedMasterPrefixes = Seq("spark://", "mesos://") - private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + private val masters: Array[String] = if (master.startsWith("spark://")) { + Utils.parseStandaloneMasterUrls(master) + } else { + Array(master) + } // Set of masters that lost contact with us, used to keep track of // whether there are masters still alive for us to communicate with From e3a43748338b02ef6864ca62de40e218e5677506 Mon Sep 17 00:00:00 2001 From: Octavian Geagla Date: Sat, 30 May 2015 00:00:36 -0700 Subject: [PATCH 085/254] [SPARK-7459] [MLLIB] ElementwiseProduct Java example Author: Octavian Geagla Closes #6008 from ogeagla/elementwise-prod-doc and squashes the following commits: 72e6dc0 [Octavian Geagla] [SPARK-7459] [MLLIB] Java example import. cf2afbd [Octavian Geagla] [SPARK-7459] [MLLIB] Update description of example. b66431b [Octavian Geagla] [SPARK-7459] [MLLIB] Add override annotation to java example, make scala example use same data as java. 6b26b03 [Octavian Geagla] [SPARK-7459] [MLLIB] Fix line which is too long. 79af020 [Octavian Geagla] [SPARK-7459] [MLLIB] Actually don't use Java 8. 9d5b31a [Octavian Geagla] [SPARK-7459] [MLLIB] Don't use Java 8 4f0c92f [Octavian Geagla] [SPARK-7459] [MLLIB] ElementwiseProduct Java example. --- docs/mllib-feature-extraction.md | 40 +++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index f723cd6b9dfab..764985d436ead 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -505,7 +505,7 @@ v_N ### Example -This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value. +This example below demonstrates how to transform vectors using a transforming vector value.
    @@ -514,16 +514,44 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors -// Load and parse the data: -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +// Create some vector data; also works for sparse vectors +val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) val transformingVector = Vectors.dense(0.0, 1.0, 2.0) val transformer = new ElementwiseProduct(transformingVector) // Batch transform and per-row transform give the same results: -val transformedData = transformer.transform(parsedData) -val transformedData2 = parsedData.map(x => transformer.transform(x)) +val transformedData = transformer.transform(data) +val transformedData2 = data.map(x => transformer.transform(x)) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +// Create some vector data; also works for sparse vectors +JavaRDD data = sc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + +// Batch transform and per-row transform give the same results: +JavaRDD transformedData = transformer.transform(data); +JavaRDD transformedData2 = data.map( + new Function() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } +); {% endhighlight %}
    From 0978aec9cd47dc0618e47b74a99e1cc2266be424 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 30 May 2015 00:26:46 -0700 Subject: [PATCH 086/254] [SPARK-7964][SQL] remove unnecessary type coercion rule We have defined these logics in `Cast` already, I think we should remove this rule. Author: Wenchen Fan Closes #6516 from cloud-fan/tmp2 and squashes the following commits: d5035a4 [Wenchen Fan] remove useless rule --- .../catalyst/analysis/HiveTypeCoercion.scala | 27 ------------------- .../analysis/HiveTypeCoercionSuite.scala | 16 ----------- .../ExpressionEvaluationSuite.scala | 2 ++ 3 files changed, 2 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 44664f898f762..195418d6dfb1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -77,7 +77,6 @@ trait HiveTypeCoercion { PromoteStrings :: DecimalPrecision :: BooleanComparisons :: - BooleanCasts :: StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: @@ -510,32 +509,6 @@ trait HiveTypeCoercion { } } - /** - * Casts to/from [[BooleanType]] are transformed into comparisons since - * the JVM does not consider Booleans to be numeric types. - */ - object BooleanCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - // Skip if the type is boolean type already. Note that this extra cast should be removed - // by optimizer.SimplifyCasts. - case Cast(e, BooleanType) if e.dataType == BooleanType => e - // DateType should be null if be cast to boolean. - case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) - // If the data type is not boolean and is being cast boolean, turn it into a comparison - // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. - case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) - // Stringify boolean if casting to StringType. - // TODO Ensure true/false string letter casing is consistent with Hive in all cases. - case Cast(e, StringType) if e.dataType == BooleanType => - If(e, Literal("true"), Literal("false")) - // Turn true into 1, and false into 0 if casting boolean into other types. - case Cast(e, dataType) if e.dataType == BooleanType => - Cast(If(e, Literal(1), Literal(0)), dataType) - } - } - /** * When encountering a cast from a string representing a valid fractional number to an integral * type the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index fcd745f43cfbf..f0101f4a88f86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -104,22 +104,6 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } - test("boolean casts") { - val booleanCasts = new HiveTypeCoercion { }.BooleanCasts - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - // Remove superflous boolean -> boolean casts. - ruleTest(Cast(Literal(true), BooleanType), Literal(true)) - // Stringify boolean when casting to string. - ruleTest( - Cast(Literal(false), StringType), - If(Literal(false), Literal("true"), Literal("false"))) - } - test("coalesce casts") { val fac = new HiveTypeCoercion { }.FunctionArgumentConversion def ruleTest(initial: Expression, transformed: Expression) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index b511aa3a24420..10181366c2fcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -372,6 +372,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) + checkEvaluation(Literal(true) cast StringType, "true") + checkEvaluation(Literal(false) cast StringType, "false") checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) checkEvaluation("23" cast DoubleType, 23d) From 8c8de3ed863985554e84fd07d1cdcaeca7e3375c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 30 May 2015 07:59:27 -0400 Subject: [PATCH 087/254] [SPARK-7890] [DOCS] Document that Spark 2.11 now supports Kafka Remove caveat about Kafka / JDBC not being supported for Scala 2.11 Author: Sean Owen Closes #6470 from srowen/SPARK-7890 and squashes the following commits: 4652634 [Sean Owen] One more rewording 7b7f3c8 [Sean Owen] Restore note about JDBC component 126744d [Sean Owen] Remove caveat about Kafka / JDBC not being supported for Scala 2.11 --- docs/building-spark.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index b2649d1ee2a53..b4cea158dac93 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -130,9 +130,7 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop dev/change-version-to-2.11.sh mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package -Scala 2.11 support in Spark does not support a few features due to dependencies -which are themselves not Scala 2.11 ready. Specifically, Spark's external -Kafka library and JDBC component are not yet supported in Scala 2.11 builds. +Spark does not yet support its JDBC component for Scala 2.11. # Spark Tests in Maven From 9d8aadb72bbc86595e253fe30201cda6a8db877e Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Sat, 30 May 2015 08:04:27 -0400 Subject: [PATCH 088/254] [SPARK-7945] [CORE] Do trim to values in properties file https://issues.apache.org/jira/browse/SPARK-7945 Now applications submited by org.apache.spark.launcher.Main read properties file without doing trim to values in it. If user left a space after a value(say spark.driver.extraClassPath) then it probably affect global functions(like some jar could not be included in the classpath), so we should do it like Utils.getPropertiesFromFile. Author: WangTaoTheTonic Author: Tao Wang Closes #6496 from WangTaoTheTonic/SPARK-7945 and squashes the following commits: bb41b4b [Tao Wang] indent 4 to 2 6dd1cf2 [WangTaoTheTonic] use a simpler way 2c053a1 [WangTaoTheTonic] Do trim to values in properties file --- .../java/org/apache/spark/launcher/AbstractCommandBuilder.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 33fd813f7a86c..33d65d13f0d25 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -296,6 +296,9 @@ Properties loadPropertiesFile() throws IOException { try { fd = new FileInputStream(propsFile); props.load(new InputStreamReader(fd, "UTF-8")); + for (Map.Entry e : props.entrySet()) { + e.setValue(e.getValue().toString().trim()); + } } finally { if (fd != null) { try { From 2b35c99c7e73d22e82aef90b675709ae7f8d3b4a Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Sat, 30 May 2015 08:06:11 -0400 Subject: [PATCH 089/254] [SPARK-7717] [WEBUI] Only showing total memory and cores for alive workers Author: zhichao.li Closes #6317 from zhichao-li/workers and squashes the following commits: d68bf11 [zhichao.li] change prefix 99b6768 [zhichao.li] remove extra space and add 'Alive' prefix 1e8eb06 [zhichao.li] only showing alive workers --- .../apache/spark/deploy/master/ui/MasterPage.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 756927682cd24..6a7c74020bace 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -75,6 +75,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) + val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", @@ -108,12 +109,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { }.getOrElse { Seq.empty } } -
  • Workers: {state.workers.size}
  • -
  • Cores: {state.workers.map(_.cores).sum} Total, - {state.workers.map(_.coresUsed).sum} Used
  • -
  • Memory: - {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total, - {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
  • +
  • Alive Workers: {aliveWorkers.size}
  • +
  • Cores in use: {aliveWorkers.map(_.cores).sum} Total, + {aliveWorkers.map(_.coresUsed).sum} Used
  • +
  • Memory in use: + {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total, + {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
  • Applications: {state.activeApps.size} Running, {state.completedApps.size} Completed
  • From 3ab71eb9d5e3fe21af7720421eafa51f6da9b63f Mon Sep 17 00:00:00 2001 From: Taka Shinagawa Date: Sat, 30 May 2015 08:25:21 -0400 Subject: [PATCH 090/254] [DOCS] [MINOR] Update for the Hadoop versions table with hadoop-2.6 Updated the doc for the hadoop-2.6 profile, which is new to Spark 1.4 Author: Taka Shinagawa Closes #6450 from mrt/docfix2 and squashes the following commits: db1c43b [Taka Shinagawa] Updated the hadoop versions for hadoop-2.6 profile 323710e [Taka Shinagawa] The hadoop-2.6 profile is added to the Hadoop versions table --- docs/building-spark.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/building-spark.md b/docs/building-spark.md index b4cea158dac93..78cb9086f95e8 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -80,6 +80,7 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro 2.2.xhadoop-2.2 2.3.xhadoop-2.3 2.4.xhadoop-2.4 + 2.6.x and later 2.xhadoop-2.6 From d34b43bd5964e1feb03a17937de87a3f718806a5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 30 May 2015 12:06:38 -0700 Subject: [PATCH 091/254] Closes #4685 From 6e3f0c7810a6721698b0ed51cfbd41a0cd07a4a3 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 30 May 2015 12:16:09 -0700 Subject: [PATCH 092/254] [SPARK-7849] [SQL] [Docs] Updates SQL programming guide for 1.4 Author: Cheng Lian Closes #6520 from liancheng/spark-7849 and squashes the following commits: 705264b [Cheng Lian] Updates SQL programming guide for 1.4 --- docs/sql-programming-guide.md | 91 ++++++++++++++++++++++++++++++++--- 1 file changed, 85 insertions(+), 6 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7cc0a87fd5c53..4ec3d83016ac6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,6 +11,7 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. # DataFrames @@ -906,7 +907,7 @@ new data. Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. @@ -1030,7 +1031,7 @@ teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND a teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) for (teenName in collect(teenNames)) { cat(teenName, "\n") -} +} {% endhighlight %}
    @@ -1502,7 +1503,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. +adds support for finding tables in the MetaStore and writing queries using HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext @@ -1537,6 +1538,82 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
    +### Interacting with Different Versions of Hive Metastore + +One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.2.0, Spark SQL can +talk to two versions of Hive metastore, either 0.12.0 or 0.13.1, default to the latter. However, to +switch to desired Hive metastore version, users have to rebuild the assembly jar with proper profile +flags (either `-Phive-0.12.0` or `-Phive-0.13.1`), which is quite inconvenient. + +Starting from 1.4.0, users no longer need to rebuild the assembly jar to switch Hive metastore +version. Instead, configuration properties described in the table below can be used to specify +desired Hive metastore version. Currently, supported versions are still limited to 0.13.1 and +0.12.0, but we are working on a more generalized mechanism to support a wider range of versions. + +Internally, Spark SQL 1.4.0 uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts. + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameMeaning
    spark.sql.hive.metastore.version + The version of the hive client that will be used to communicate with the metastore. Available + options are 0.12.0 and 0.13.1. Defaults to 0.13.1. +
    spark.sql.hive.metastore.jars + The location of the jars that should be used to instantiate the HiveMetastoreClient. This + property can be one of three options: +
      +
    1. builtin
    2. + Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + enabled. When this option is chosen, spark.sql.hive.metastore.version must be + either 0.13.1 or not defined. +
    3. maven
    4. + Use Hive jars of specified version downloaded from Maven repositories. +
    5. A classpath in the standard format for both Hive and Hadoop.
    6. +
    + Defaults to builtin. +
    spark.sql.hive.metastore.sharedPrefixes +

    + A comma separated list of class prefixes that should be loaded using the classloader that is + shared between Spark SQL and a specific version of Hive. An example of classes that should + be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + to be shared are those that interact with classes that are already shared. For example, + custom appenders that are used by log4j. +

    +

    + Defaults to com.mysql.jdbc,org.postgresql,com.microsoft.sqlserver,oracle.jdbc. +

    +
    spark.sql.hive.metastore.barrierPrefixes +

    + A comma separated list of class prefixes that should explicitly be reloaded for each version + of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + prefix that typically would be shared (i.e. org.apache.spark.*). +

    +

    Defaults to empty.

    +
    + ## JDBC To Other Databases Spark SQL also includes a data source that can read data from other databases using JDBC. This @@ -1570,7 +1647,7 @@ the Data Sources API. The following options are supported: dbtable - The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses. @@ -1714,7 +1791,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command - `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. + ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. @@ -1737,7 +1814,9 @@ that these options will be deprecated in future release as more optimizations ar # Distributed SQL Engine -Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. +In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, +without the need to write any code. ## Running the Thrift JDBC/ODBC server From 7716a5a1ec8ff8dc24e0146f8ead2f51da6512ad Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 30 May 2015 14:57:23 -0700 Subject: [PATCH 093/254] Updated SQL programming guide's Hive connectivity section. --- docs/sql-programming-guide.md | 46 +++++++++++++---------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4ec3d83016ac6..2ea7572c6026a 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1541,79 +1541,67 @@ results = sqlContext.sql("FROM src SELECT key, value").collect() ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.2.0, Spark SQL can -talk to two versions of Hive metastore, either 0.12.0 or 0.13.1, default to the latter. However, to -switch to desired Hive metastore version, users have to rebuild the assembly jar with proper profile -flags (either `-Phive-0.12.0` or `-Phive-0.13.1`), which is quite inconvenient. +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. -Starting from 1.4.0, users no longer need to rebuild the assembly jar to switch Hive metastore -version. Instead, configuration properties described in the table below can be used to specify -desired Hive metastore version. Currently, supported versions are still limited to 0.13.1 and -0.12.0, but we are working on a more generalized mechanism to support a wider range of versions. - -Internally, Spark SQL 1.4.0 uses two Hive clients, one for executing native Hive commands like `SET` -and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive -jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the -version specified by users. An isolated classloader is used here to avoid dependency conflicts. +Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts. - + + - + - - + - +
    Property NameMeaning
    Property NameDefaultMeaning
    spark.sql.hive.metastore.version0.13.1 - The version of the hive client that will be used to communicate with the metastore. Available - options are 0.12.0 and 0.13.1. Defaults to 0.13.1. + Version of the Hive metastore. Available + options are 0.12.0 and 0.13.1. Support for more versions is coming in the future.
    spark.sql.hive.metastore.jarsbuiltin - The location of the jars that should be used to instantiate the HiveMetastoreClient. This + Location of the jars that should be used to instantiate the HiveMetastoreClient. This property can be one of three options:
    1. builtin
    2. Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is - enabled. When this option is chosen, spark.sql.hive.metastore.version must be + enabled. When this option is chosen, spark.sql.hive.metastore.version must be either 0.13.1 or not defined.
    3. maven
    4. Use Hive jars of specified version downloaded from Maven repositories.
    5. A classpath in the standard format for both Hive and Hadoop.
    - Defaults to builtin.
    spark.sql.hive.metastore.sharedPrefixescom.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc

    A comma separated list of class prefixes that should be loaded using the classloader that is shared between Spark SQL and a specific version of Hive. An example of classes that should be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need - to be shared are those that interact with classes that are already shared. For example, + to be shared are those that interact with classes that are already shared. For example, custom appenders that are used by log4j.

    -

    - Defaults to com.mysql.jdbc,org.postgresql,com.microsoft.sqlserver,oracle.jdbc. -

    spark.sql.hive.metastore.barrierPrefixes(empty)

    A comma separated list of class prefixes that should explicitly be reloaded for each version - of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a prefix that typically would be shared (i.e. org.apache.spark.*).

    -

    Defaults to empty.

    + ## JDBC To Other Databases Spark SQL also includes a data source that can read data from other databases using JDBC. This From a6430028ecd7a6130f1eb15af9ec00e242c46725 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 30 May 2015 15:27:51 -0700 Subject: [PATCH 094/254] [SPARK-7855] Move bypassMergeSort-handling from ExternalSorter to own component Spark's `ExternalSorter` writes shuffle output files during sort-based shuffle. Sort-shuffle contains a configuration, `spark.shuffle.sort.bypassMergeThreshold`, which causes ExternalSorter to skip sorting and merging and simply write separate files per partition, which are then concatenated together to form the final map output file. The code paths used during this bypass are almost completely separate from ExternalSorter's other code paths, so refactoring them into a separate file can significantly simplify the code. In addition to re-arranging code, this patch deletes a bunch of dead code. The main entry point into ExternalSorter is `insertAll()` and in SPARK-4479 / #3422 this method was modified to completely bypass in-memory buffering of records when `bypassMergeSort` takes effect. As a result, some of the spilling and merging code paths will no longer be called when `bypassMergeSort` is used, so we should be able to safely remove that code. There's an open JIRA ([SPARK-6026](https://issues.apache.org/jira/browse/SPARK-6026)) for removing the `bypassMergeThreshold` parameter and code paths; I have not done that here, but the changes in this patch will make removing that parameter significantly easier if we ever decide to do that. This patch also makes several improvements to shuffle-related tests and adds more defensive checks to certain shuffle classes: - DiskBlockObjectWriter now throws an exception if `fileSegment()` is called before `commitAndClose()` has been called. - DiskBlockObjectWriter's close methods are now idempotent, so calling any of the close methods twice in a row will no longer result in incorrect shuffle write metrics changes. Calling `revertPartialWritesAndClose()` on a closed DiskBlockObjectWriter now has no effect (before, it might mess up the metrics). - The end-to-end shuffle record count metrics tests have been moved from InputOutputMetricsSuite to ShuffleSuite. This means that these tests will now be run against all shuffle implementations rather than just the default shuffle configuration. - The end-to-end metrics tests now include a test of a job which performs aggregation in the shuffle. - Our tests now check that `shuffleBytesWritten == totalShuffleBytesRead`. - FileSegment now throws IllegalArgumentException if it is constructed with a negative length or offset. Author: Josh Rosen Closes #6397 from JoshRosen/external-sorter-bypass-cleanup and squashes the following commits: bf3f3f6 [Josh Rosen] Merge remote-tracking branch 'origin/master' into external-sorter-bypass-cleanup 8b216c4 [Josh Rosen] Guard against negative offsets and lengths in FileSegment 03f35a4 [Josh Rosen] Minor fix to cleanup logic. b5cc35b [Josh Rosen] Move shuffle metrics tests to ShuffleSuite. 8b8fb9e [Josh Rosen] Add more tests + defensive programming to DiskBlockObjectWriter. 16564eb [Josh Rosen] Guard against calling fileSegment() before commitAndClose() has been called. 96811b4 [Josh Rosen] Remove confusing taskMetrics.shuffleWriteMetrics() optional call 8522b6a [Josh Rosen] Do not perform a map-side sort unless we're also doing map-side aggregation 08e40f3 [Josh Rosen] Remove excessively clever (and wrong) implementation of newBuffer() d7f9938 [Josh Rosen] Add missing overrides; fix compilation 71d76ff [Josh Rosen] Update Javadoc bf0d98f [Josh Rosen] Add comment to clarify confusing factory code 5197f73 [Josh Rosen] Add missing private[this] 30ef2c8 [Josh Rosen] Convert BypassMergeSortShuffleWriter to Java bc1a820 [Josh Rosen] Fix bug when aggregator is used but map-side combine is disabled 0d3dcc0 [Josh Rosen] Remove unnecessary overloaded methods 25b964f [Josh Rosen] Rename SortShuffleSorter to SortShuffleFileWriter 0d9848c [Josh Rosen] Make it more clear that curWriteMetrics is now only used for spill metrics 7af7aea [Josh Rosen] Combine spill() and spillToMergeableFile() 6320112 [Josh Rosen] Add missing negation in deletion success check. d267e0d [Josh Rosen] Fix style issue 7f15f7b [Josh Rosen] Back out extra cleanup-handling code, since this is already covered in stop() 25aa3bd [Josh Rosen] Make sure to delete outputFile after errors. 931ca68 [Josh Rosen] Refactor tests. 6a35716 [Josh Rosen] Refactor logic for deciding when to bypass 4b03539 [Josh Rosen] Move conf prior to first use 1265b25 [Josh Rosen] Fix some style errors and comments. 02355ef [Josh Rosen] More simplification d4cb536 [Josh Rosen] Delete more unused code bb96678 [Josh Rosen] Add missing interface file b6cc1eb [Josh Rosen] Realize that bypass never buffers; proceed to delete tons of code 6185ee2 [Josh Rosen] WIP towards moving bypass code into own file. 8d0678c [Josh Rosen] Move diskBytesSpilled getter next to variable 19bccd6 [Josh Rosen] Remove duplicated buffer creation code. 18959bb [Josh Rosen] Move comparator methods closer together. --- .../sort/BypassMergeSortShuffleWriter.java | 184 +++++++++++++ .../shuffle/sort/SortShuffleFileWriter.java | 53 ++++ .../shuffle/sort/SortShuffleWriter.scala | 34 ++- .../spark/storage/BlockObjectWriter.scala | 19 +- .../apache/spark/storage/FileSegment.scala | 2 + .../util/collection/ExternalSorter.scala | 260 +++++------------- .../spark/util/collection/PairIterator.scala | 24 -- .../collection/PartitionedAppendOnlyMap.scala | 4 - .../collection/PartitionedPairBuffer.scala | 4 - .../PartitionedSerializedPairBuffer.scala | 4 - .../WritablePartitionedPairCollection.scala | 36 +-- .../scala/org/apache/spark/ShuffleSuite.scala | 65 +++++ .../metrics/InputOutputMetricsSuite.scala | 28 -- .../BypassMergeSortShuffleWriterSuite.scala | 171 ++++++++++++ .../shuffle/sort/SortShuffleWriterSuite.scala | 46 ++++ .../storage/BlockObjectWriterSuite.scala | 97 ++++++- .../util/collection/ExternalSorterSuite.scala | 130 +-------- 17 files changed, 738 insertions(+), 423 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java new file mode 100644 index 0000000000000..d3d6280284beb --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path + * writes incoming records to separate files, one file per reduce partition, then concatenates these + * per-partition files to form a single output file, regions of which are served to reducers. + * Records are not buffered in memory. This is essentially identical to + * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. + *

    + * This write path is inefficient for shuffles with large numbers of reduce partitions because it + * simultaneously opens separate serializers and file streams for all partitions. As a result, + * {@link SortShuffleManager} only selects this write path when + *

      + *
    • no Ordering is specified,
    • + *
    • no Aggregator is specific, and
    • + *
    • the number of partitions is less than + * spark.shuffle.sort.bypassMergeThreshold.
    • + *
    + * + * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was + * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details. + *

    + * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { + + private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final Serializer serializer; + + /** Array of file writers, one for each partition */ + private BlockObjectWriter[] partitionWriters; + + public BypassMergeSortShuffleWriter( + SparkConf conf, + BlockManager blockManager, + Partitioner partitioner, + ShuffleWriteMetrics writeMetrics, + Serializer serializer) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.numPartitions = partitioner.numPartitions(); + this.blockManager = blockManager; + this.partitioner = partitioner; + this.writeMetrics = writeMetrics; + this.serializer = serializer; + } + + @Override + public void insertAll(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new BlockObjectWriter[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (BlockObjectWriter writer : partitionWriters) { + writer.commitAndClose(); + } + } + + @Override + public long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException { + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public void stop() throws IOException { + if (partitionWriters != null) { + try { + final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); + for (BlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + writer.revertPartialWritesAndClose(); + if (!diskBlockManager.getFile(writer.blockId()).delete()) { + logger.error("Error while deleting file for block {}", writer.blockId()); + } + } + } finally { + partitionWriters = null; + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java new file mode 100644 index 0000000000000..656ea0401a144 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Product2; +import scala.collection.Iterator; + +import org.apache.spark.annotation.Private; +import org.apache.spark.TaskContext; +import org.apache.spark.storage.BlockId; + +/** + * Interface for objects that {@link SortShuffleWriter} uses to write its output files. + */ +@Private +public interface SortShuffleFileWriter { + + void insertAll(Iterator> records) throws IOException; + + /** + * Write all the data added into this shuffle sorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException; + + void stop() throws IOException; +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index c9dd6bfc4c219..5865e7640c1cf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: ExternalSorter[K, V, _] = null + private var sorter: SortShuffleFileWriter[K, V] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { - if (dep.mapSideCombine) { + sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - sorter = new ExternalSorter[K, V, C]( + new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.insertAll(records) + } else if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, + writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) - sorter.insertAll(records) + new ExternalSorter[K, V, V]( + aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } + sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C]( } } +private[spark] object SortShuffleWriter { + def shouldBypassMergeSort( + conf: SparkConf, + numPartitions: Int, + aggregator: Option[Aggregator[_, _, _]], + keyOrdering: Option[Ordering[_]]): Boolean = { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a33f22ef52687..7eeabd1e0489c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter( private var objOut: SerializationStream = null private var initialized = false private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false /** * Cursors used to represent positions in the file. @@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() bs.flush() close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() } - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + commitAndCloseHasBeenCalled = true } // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) - if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter( } override def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } new FileSegment(file, initialPosition, finalPosition - initialPosition) } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 95e2d688d9b17..021a9facfb0b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,6 +24,8 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") override def toString: String = { "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3b9d14f9372b6..ef2dbb7ff0ae0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -23,12 +23,14 @@ import java.util.Comparator import scala.collection.mutable.ArrayBuffer import scala.collection.mutable +import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.storage.{BlockObjectWriter, BlockId} +import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} +import org.apache.spark.storage.{BlockId, BlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -84,35 +86,40 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * each other for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. - * - * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is - * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to - * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can - * then concatenate these files to produce a single sorted file, without having to serialize and - * de-serialize each item twice (as is needed during the merge). This speeds up the map side of - * groupBy, sort, etc operations since they do no partial aggregation. */ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) - extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { + extends Logging + with Spillable[WritablePartitionedPairCollection[K, C]] + with SortShuffleFileWriter[K, V] { + + private val conf = SparkEnv.get.conf private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 + private def getPartition(key: K): Int = { + if (shouldPartition) partitioner.get.getPartition(key) else 0 + } + + // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class. + // As a sanity check, make sure that we're not handling a shuffle which should use that path. + if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) { + throw new IllegalArgumentException("ExternalSorter should not be used to handle " + + " a sort that the BypassMergeSortShuffleWriter should handle") + } private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. // @@ -123,43 +130,28 @@ private[spark] class ExternalSorter[K, V, C]( // grow internal data structures by growing + copying every time the number of objects doubles. private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) - private def getPartition(key: K): Int = { - if (shouldPartition) partitioner.get.getPartition(key) else 0 - } - - private val metaInitialRecords = 256 - private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB private val useSerializedPairBuffer = - !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.supportsRelocationOfSerializedObjects - + ordering.isEmpty && + conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && + ser.supportsRelocationOfSerializedObjects + private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB + private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = { + if (useSerializedPairBuffer) { + new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance) + } else { + new PartitionedPairBuffer[K, C] + } + } // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } + private var buffer = newBuffer() // Total spilling statistics private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled - // Write metrics for current spill - private var curWriteMetrics: ShuffleWriteMetrics = _ - - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need - // local aggregation and sorting, write numPartitions files directly and just concatenate them - // at the end. This avoids doing serialization and deserialization twice to merge together the - // spilled files, which would happen with the normal code path. The downside is having multiple - // files open at a time and thus more memory allocated to buffers. - private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - private val bypassMergeSort = - (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty) - - // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled - private var partitionWriters: Array[BlockObjectWriter] = null // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -174,6 +166,14 @@ private[spark] class ExternalSorter[K, V, C]( } }) + private def comparator: Option[Comparator[K]] = { + if (ordering.isDefined || aggregator.isDefined) { + Some(keyComparator) + } else { + None + } + } + // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. @@ -182,9 +182,10 @@ private[spark] class ExternalSorter[K, V, C]( blockId: BlockId, serializerBatchSizes: Array[Long], elementsPerPartition: Array[Long]) + private val spills = new ArrayBuffer[SpilledFile] - def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -202,15 +203,6 @@ private[spark] class ExternalSorter[K, V, C]( map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) } - } else if (bypassMergeSort) { - // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies - if (records.hasNext) { - spillToPartitionFiles( - WritablePartitionedIterator.fromIterator(records.map { kv => - ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - }) - ) - } } else { // Stick values into our buffer while (records.hasNext) { @@ -238,46 +230,33 @@ private[spark] class ExternalSorter[K, V, C]( } } else { if (maybeSpill(buffer, buffer.estimateSize())) { - buffer = if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } + buffer = newBuffer() } } } /** - * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. - */ - override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { - if (bypassMergeSort) { - spillToPartitionFiles(collection) - } else { - spillToMergeableFile(collection) - } - } - - /** - * Spill our in-memory collection to a sorted file that we can merge later (normal code path). - * We add this file into spilledFiles to find it later. - * - * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles() - * is used to write files for each partition. + * Spill our in-memory collection to a sorted file that we can merge later. + * We add this file into `spilledFiles` to find it later. * * @param collection whichever collection we're using (map or buffer) */ - private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = { - assert(!bypassMergeSort) - + override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. val (blockId, file) = diskBlockManager.createTempShuffleBlock() - curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter( - blockId, file, serInstance, fileBufferSize, curWriteMetrics) - var objectsWritten = 0 // Objects written since the last flush + + // These variables are reset after each flush + var objectsWritten: Long = 0 + var spillMetrics: ShuffleWriteMetrics = null + var writer: BlockObjectWriter = null + def openWriter(): Unit = { + assert (writer == null && spillMetrics == null) + spillMetrics = new ShuffleWriteMetrics + writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics) + } + openWriter() // List of batch sizes (bytes) in the order they are written to disk val batchSizes = new ArrayBuffer[Long] @@ -291,8 +270,9 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) + _diskBytesSpilled += spillMetrics.shuffleBytesWritten + batchSizes.append(spillMetrics.shuffleBytesWritten) + spillMetrics = null objectsWritten = 0 } @@ -307,9 +287,7 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter( - blockId, file, serInstance, fileBufferSize, curWriteMetrics) + openWriter() } } if (objectsWritten > 0) { @@ -336,46 +314,6 @@ private[spark] class ExternalSorter[K, V, C]( spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) } - /** - * Spill our in-memory collection to separate files, one for each partition. This is used when - * there's no aggregator and ordering and the number of partitions is small, because it allows - * writePartitionedFile to just concatenate files without deserializing data. - * - * @param collection whichever collection we're using (map or buffer) - */ - private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = { - spillToPartitionFiles(collection.writablePartitionedIterator()) - } - - private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = { - assert(bypassMergeSort) - - // Create our file writers if we haven't done so yet - if (partitionWriters == null) { - curWriteMetrics = new ShuffleWriteMetrics() - val openStartTime = System.nanoTime - partitionWriters = Array.fill(numPartitions) { - // Because these files may be read during shuffle, their compression must be controlled by - // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use - // createTempShuffleBlock here; see SPARK-3426 for more context. - val (blockId, file) = diskBlockManager.createTempShuffleBlock() - val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, - curWriteMetrics) - writer.open() - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. - curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) - } - - // No need to sort stuff, just write each element out - while (iterator.hasNext) { - val partitionId = iterator.nextPartition() - iterator.writeNext(partitionWriters(partitionId)) - } - } - /** * Merge a sequence of sorted files, giving an iterator over partitions and then over elements * inside each partition. This can be used to either write out a new file or return data to @@ -665,8 +603,6 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Exposed for testing purposes. - * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -676,10 +612,11 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + @VisibleForTesting + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer - if (spills.isEmpty && partitionWriters == null) { + if (spills.isEmpty) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { @@ -689,13 +626,6 @@ private[spark] class ExternalSorter[K, V, C]( // We do need to sort by both partition ID and key groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator))) } - } else if (bypassMergeSort) { - // Read data from each partition file and merge it together with the data in memory; - // note that there's no ordering or aggregator in this case -- we just partition objects - val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None)) - collIter.map { case (partitionId, values) => - (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) - } } else { // Merge spilled and in-memory data merge(spills, collection.partitionedDestructiveSortedIterator(comparator)) @@ -709,14 +639,13 @@ private[spark] class ExternalSorter[K, V, C]( /** * Write all the data added into this ExternalSorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. + * called by the SortShuffleWriter. * * @param blockId block ID to write to. The index file will be blockId.name + ".index". * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - def writePartitionedFile( + override def writePartitionedFile( blockId: BlockId, context: TaskContext, outputFile: File): Array[Long] = { @@ -724,28 +653,7 @@ private[spark] class ExternalSorter[K, V, C]( // Track location of each range in the output file val lengths = new Array[Long](numPartitions) - if (bypassMergeSort && partitionWriters != null) { - // We decided to write separate files for each partition, so just concatenate them. To keep - // this simple we spill out the current in-memory collection so that everything is in files. - spillToPartitionFiles(if (aggregator.isDefined) map else buffer) - partitionWriters.foreach(_.commitAndClose()) - val out = new FileOutputStream(outputFile, true) - val writeStartTime = System.nanoTime - util.Utils.tryWithSafeFinally { - for (i <- 0 until numPartitions) { - val in = new FileInputStream(partitionWriters(i).fileSegment().file) - util.Utils.tryWithSafeFinally { - lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - } { - in.close() - } - } - } { - out.close() - context.taskMetrics.shuffleWriteMetrics.foreach( - _.incShuffleWriteTime(System.nanoTime - writeStartTime)) - } - } else if (spills.isEmpty && partitionWriters == null) { + if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer val it = collection.destructiveSortedWritablePartitionedIterator(comparator) @@ -761,7 +669,7 @@ private[spark] class ExternalSorter[K, V, C]( lengths(partitionId) = segment.length } } else { - // Not bypassing merge-sort; get an iterator by partition and just write everything directly. + // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, @@ -778,41 +686,15 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m => - if (curWriteMetrics != null) { - m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten) - m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime) - m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten) - } - } lengths } - /** - * Read a partition file back as an iterator (used in our iterator method) - */ - private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { - if (writer.isOpen) { - writer.commitAndClose() - } - new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get) - } - def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() - if (partitionWriters != null) { - partitionWriters.foreach { w => - w.revertPartialWritesAndClose() - diskBlockManager.getFile(w.blockId).delete() - } - partitionWriters = null - } } - def diskBytesSpilled: Long = _diskBytesSpilled - /** * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, * group together the pairs for each partition into a sub-iterator. @@ -826,14 +708,6 @@ private[spark] class ExternalSorter[K, V, C]( (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) } - private def comparator: Option[Comparator[K]] = { - if (ordering.isDefined || aggregator.isDefined) { - Some(keyComparator) - } else { - None - } - } - /** * An iterator that reads only the elements for a given partition ID from an underlying buffered * stream, assuming this partition is the next one to be read. Used to make it easier to return diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala deleted file mode 100644 index d75959f480756..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] { - def hasNext: Boolean = iter.hasNext - - def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V]) -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala index e2e2f1faae9d1..d0d25b43d0477 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala @@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V] destructiveSortedIterator(comparator) } - def writablePartitionedIterator(): WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(super.iterator) - } - def insert(partition: Int, key: K, value: V): Unit = { update((partition, key), value) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index e8332e1a87eac..5a6e9a9580e9b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) iterator } - override def writablePartitionedIterator(): WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(iterator) - } - private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] { var pos = 0 diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index 554d88206e221..862408b7a4d21 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -122,10 +122,6 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { sort(keyComparator) - writablePartitionedIterator - } - - override def writablePartitionedIterator(): WritablePartitionedIterator = { new WritablePartitionedIterator { // current position in the meta buffer in ints var pos = 0 diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index f26d1618c9200..7bc59898658e4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { */ def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator)) - } + val it = partitionedDestructiveSortedIterator(keyComparator) + new WritablePartitionedIterator { + private[this] var cur = if (it.hasNext) it.next() else null - /** - * Iterate through the data and write out the elements instead of returning them. - */ - def writablePartitionedIterator(): WritablePartitionedIterator + def writeNext(writer: BlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + } } private[spark] object WritablePartitionedPairCollection { @@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator { def nextPartition(): Int } - -private[spark] object WritablePartitionedIterator { - def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = { - new WritablePartitionedIterator { - var cur = if (it.hasNext) it.next() else null - - def writeNext(writer: BlockObjectWriter): Unit = { - writer.write(cur._1._2, cur._2) - cur = if (it.hasNext) it.next() else null - } - - def hasNext(): Boolean = cur != null - - def nextPartition(): Int = cur._1._1 - } - } -} diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 91f4ab360857e..c3c2b1ffc1efa 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair @@ -281,6 +282,39 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // This count should retry the execution of the previous stage and rerun shuffle. rdd.count() } + + test("metrics for shuffle without aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val numRecords = 10000 + + val metrics = ShuffleSuite.runAndReturnMetrics(sc) { + sc.parallelize(1 to numRecords, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + + assert(metrics.recordsRead === numRecords) + assert(metrics.recordsWritten === numRecords) + assert(metrics.bytesWritten === metrics.byresRead) + assert(metrics.bytesWritten > 0) + } + + test("metrics for shuffle with aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val numRecords = 10000 + + val metrics = ShuffleSuite.runAndReturnMetrics(sc) { + sc.parallelize(1 to numRecords, 4) + .flatMap(key => Array.fill(100)((key, 1))) + .countByKey() + } + + assert(metrics.recordsRead === numRecords) + assert(metrics.recordsWritten === numRecords) + assert(metrics.bytesWritten === metrics.byresRead) + assert(metrics.bytesWritten > 0) + } } object ShuffleSuite { @@ -294,4 +328,35 @@ object ShuffleSuite { value - o.value } } + + case class AggregatedShuffleMetrics( + recordsWritten: Long, + recordsRead: Long, + bytesWritten: Long, + byresRead: Long) + + def runAndReturnMetrics(sc: SparkContext)(job: => Unit): AggregatedShuffleMetrics = { + @volatile var recordsWritten: Long = 0 + @volatile var recordsRead: Long = 0 + @volatile var bytesWritten: Long = 0 + @volatile var bytesRead: Long = 0 + val listener = new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m => + recordsWritten += m.shuffleRecordsWritten + bytesWritten += m.shuffleBytesWritten + } + taskEnd.taskMetrics.shuffleReadMetrics.foreach { m => + recordsRead += m.recordsRead + bytesRead += m.totalBytesRead + } + } + } + sc.addSparkListener(listener) + + job + + sc.listenerBus.waitUntilEmpty(500) + AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead) + } } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 19f1af0dcd461..9e4d34fb7d382 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -193,26 +193,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext assert(records == numRecords) } - test("shuffle records read metrics") { - val recordsRead = runAndReturnShuffleRecordsRead { - sc.textFile(tmpFilePath, 4) - .map(key => (key, 1)) - .groupByKey() - .collect() - } - assert(recordsRead == numRecords) - } - - test("shuffle records written metrics") { - val recordsWritten = runAndReturnShuffleRecordsWritten { - sc.textFile(tmpFilePath, 4) - .map(key => (key, 1)) - .groupByKey() - .collect() - } - assert(recordsWritten == numRecords) - } - /** * Tests the metrics from end to end. * 1) reading a hadoop file @@ -301,14 +281,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten)) } - private def runAndReturnShuffleRecordsRead(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead)) - } - - private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten)) - } - private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Option[Long]): Long = { val taskMetrics = new ArrayBuffer[Long]() diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..c8420db6126c0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.File +import java.util.UUID + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfterEach, FunSuite} + +import org.apache.spark._ +import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} +import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer} +import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +class BypassMergeSortShuffleWriterSuite extends FunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + + private var taskMetrics: TaskMetrics = _ + private var shuffleWriteMetrics: ShuffleWriteMetrics = _ + private var tempDir: File = _ + private var outputFile: File = _ + private val conf: SparkConf = new SparkConf(loadDefaults = false) + private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() + private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] + private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0) + private val serializer: Serializer = new JavaSerializer(conf) + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + outputFile = File.createTempFile("shuffle", null, tempDir) + shuffleWriteMetrics = new ShuffleWriteMetrics + taskMetrics = new TaskMetrics + taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + MockitoAnnotations.initMocks(this) + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.getDiskWriter( + any[BlockId], + any[File], + any[SerializerInstance], + anyInt(), + any[ShuffleWriteMetrics] + )).thenAnswer(new Answer[BlockObjectWriter] { + override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + val args = invocation.getArguments + new DiskBlockObjectWriter( + args(0).asInstanceOf[BlockId], + args(1).asInstanceOf[File], + args(2).asInstanceOf[SerializerInstance], + args(3).asInstanceOf[Int], + compressStream = identity, + syncWrites = false, + args(4).asInstanceOf[ShuffleWriteMetrics] + ) + } + }) + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer[(TempShuffleBlockId, File)] { + override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = File.createTempFile(blockId.toString, null, tempDir) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated.append(file) + (blockId, file) + } + }) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + new Answer[File] { + override def answer(invocation: InvocationOnMock): File = { + blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get + } + }) + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + blockIdToFileMap.clear() + temporaryFilesCreated.clear() + } + + test("write empty iterator") { + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + writer.insertAll(Iterator.empty) + val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) + assert(partitionLengths.sum === 0) + assert(outputFile.exists()) + assert(outputFile.length() === 0) + assert(temporaryFilesCreated.isEmpty) + assert(shuffleWriteMetrics.shuffleBytesWritten === 0) + assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write with some empty partitions") { + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + writer.insertAll(records) + assert(temporaryFilesCreated.nonEmpty) + val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) + assert(partitionLengths.sum === outputFile.length()) + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("cleanup of intermediate files after errors") { + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + intercept[SparkException] { + writer.insertAll((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(temporaryFilesCreated.nonEmpty) + writer.stop() + assert(temporaryFilesCreated.count(_.exists()) === 0) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..c6ada7139c198 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.mockito.Mockito._ +import org.scalatest.FunSuite + +import org.apache.spark.{Aggregator, SparkConf} + +class SortShuffleWriterSuite extends FunSuite { + + import SortShuffleWriter._ + + test("conditions for bypassing merge-sort") { + val conf = new SparkConf(loadDefaults = false) + val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) + val ord = implicitly[Ordering[Int]] + + // Numbers of partitions that are above and below the default bypassMergeThreshold + val FEW_PARTITIONS = 50 + val MANY_PARTITIONS = 10000 + + // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high + assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) + assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) + + // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions + assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) + assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index ad43a3e5fdc88..7bdea724fea58 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -18,14 +18,28 @@ package org.apache.spark.storage import java.io.File +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkConf import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -class BlockObjectWriterSuite extends SparkFunSuite { +class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + var tempDir: File = _ + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + } + test("verify write metrics") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -47,7 +61,7 @@ class BlockObjectWriterSuite extends SparkFunSuite { } test("verify write metrics on revert") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -70,7 +84,7 @@ class BlockObjectWriterSuite extends SparkFunSuite { } test("Reopening a closed block writer") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -81,4 +95,79 @@ class BlockObjectWriterSuite extends SparkFunSuite { writer.open() } } + + test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + } + + test("commitAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.commitAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("revertPartialWritesAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.revertPartialWritesAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 0) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("fileSegment() can only be called after commitAndClose() has been called") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + intercept[IllegalStateException] { + writer.fileSegment() + } + writer.close() + } + + test("commitAndClose() without ever opening or writing") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + writer.commitAndClose() + assert(writer.fileSegment().length === 0) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 9039dbef1fb71..7d7b41bc23284 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -23,10 +23,12 @@ import org.scalatest.PrivateMethodTester import scala.util.Random +import org.scalatest.FunSuite + import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester { +class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) if (kryo) { @@ -37,21 +39,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv conf.set("spark.serializer.objectStreamReset", "1") conf.set("spark.serializer", classOf[JavaSerializer].getName) } + conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") // Ensure that we actually have multiple batches per spill file conf.set("spark.shuffle.spill.batchSize", "10") conf } - private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { - val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) - assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort") - } - - private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { - val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) - assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort") - } - test("empty data stream with kryo ser") { emptyDataStream(createSparkConf(false, true)) } @@ -161,39 +154,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) - sorter.insertAll(elements) - assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled - val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) - assert(iter.next() === (0, Nil)) - assert(iter.next() === (1, List((1, 1)))) - assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) - assert(iter.next() === (3, Nil)) - assert(iter.next() === (4, Nil)) - assert(iter.next() === (5, List((5, 5)))) - assert(iter.next() === (6, Nil)) - sorter.stop() - } - - test("empty partitions with spilling, bypass merge-sort with kryo ser") { - emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, true)) - } - - test("empty partitions with spilling, bypass merge-sort with java ser") { - emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, false)) - } - - def emptyPartitionerWithSpillingBypassMergeSort(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), None, None) - assertBypassedMergeSort(sorter) sorter.insertAll(elements) assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) @@ -376,7 +336,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) sorter.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) sorter.stop() @@ -384,7 +343,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv val sorter2 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter2) sorter2.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet) @@ -392,29 +350,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv assert(diskBlockManager.getAllBlocks().length === 0) } - test("cleanup of intermediate files in sorter, bypass merge-sort") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter) - sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - - val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter2) - sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) - sorter2.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - } - test("cleanup of intermediate files in sorter if there are errors") { val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") @@ -426,7 +361,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) intercept[SparkException] { sorter.insertAll((0 until 120000).iterator.map(i => { if (i == 119990) { @@ -440,28 +374,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv assert(diskBlockManager.getAllBlocks().length === 0) } - test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter) - intercept[SparkException] { - sorter.insertAll((0 until 100000).iterator.map(i => { - if (i == 99990) { - throw new SparkException("Intentional failure") - } - (i, i) - })) - } - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - } - test("cleanup of intermediate files in shuffle") { val conf = createSparkConf(false, false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -776,40 +688,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv } } - test("conditions for bypassing merge-sort") { - val conf = createSparkConf(false, false) - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - - // Numbers of partitions that are above and below the default bypassMergeThreshold - val FEW_PARTITIONS = 50 - val MANY_PARTITIONS = 10000 - - // Sorters with no ordering or aggregator: should bypass unless # of partitions is high - - val sorter1 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None) - assertBypassedMergeSort(sorter1) - - val sorter2 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None) - assertDidNotBypassMergeSort(sorter2) - - // Sorters with an ordering or aggregator: should not bypass even if they have few partitions - - val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None) - assertDidNotBypassMergeSort(sorter3) - - val sorter4 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) - assertDidNotBypassMergeSort(sorter4) - } - test("sort without breaking sorting contracts with kryo ser") { sortWithoutBreakingSortingContracts(createSparkConf(true, true)) } From 1617363fbb9b22a2eb09e7bab98c8d05f9508761 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 30 May 2015 16:24:07 -0700 Subject: [PATCH 095/254] [SPARK-7918] [MLLIB] MLlib Python doc parity check for evaluation and feature Check then make the MLlib Python evaluation and feature doc to be as complete as the Scala doc. Author: Yanbo Liang Closes #6461 from yanboliang/spark-7918 and squashes the following commits: 940e3f1 [Yanbo Liang] truncate too long line and remove extra sparse a80ae58 [Yanbo Liang] MLlib Python doc parity check for evaluation and feature --- python/pyspark/mllib/evaluation.py | 26 ++++++++-------- python/pyspark/mllib/feature.py | 49 ++++++++++++++---------------- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index aab5e5f4b77b5..c5cf3a4e7ff22 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ Evaluator for binary classification. + :param scoreAndLabels: an RDD of (score, label) pairs + >>> scoreAndLabels = sc.parallelize([ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) >>> metrics = BinaryClassificationMetrics(scoreAndLabels) @@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ def __init__(self, scoreAndLabels): - """ - :param scoreAndLabels: an RDD of (score, label) pairs - """ sc = scoreAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ @@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper): """ Evaluator for regression. + :param predictionAndObservations: an RDD of (prediction, + observation) pairs. + >>> predictionAndObservations = sc.parallelize([ ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) @@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper): """ def __init__(self, predictionAndObservations): - """ - :param predictionAndObservations: an RDD of (prediction, observation) pairs. - """ sc = predictionAndObservations.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ @@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. + :param predictionAndLabels an RDD of (prediction, label) pairs. + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) @@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels an RDD of (prediction, label) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ @@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper): """ Evaluator for ranking algorithms. + :param predictionAndLabels: an RDD of (predicted ranking, + ground truth set) pairs. + >>> predictionAndLabels = sc.parallelize([ ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]), ... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]), @@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, @@ -347,6 +345,10 @@ class MultilabelMetrics(JavaModelWrapper): """ Evaluator for multilabel classification. + :param predictionAndLabels: an RDD of (predictions, labels) pairs, + both are non-null Arrays, each with + unique elements. + >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), ... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index aac305db6c19a..da90554f41437 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -68,6 +68,8 @@ class Normalizer(VectorTransformer): For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + :param p: Normalization in L^p^ space, p = 2 by default. + >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) >>> nor.transform(v) @@ -82,9 +84,6 @@ class Normalizer(VectorTransformer): DenseVector([0.0, 0.5, 1.0]) """ def __init__(self, p=2.0): - """ - :param p: Normalization in L^p^ space, p = 2 by default. - """ assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) @@ -94,7 +93,7 @@ def transform(self, vector): :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it - will return the input vector. + will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" @@ -164,6 +163,13 @@ class StandardScaler(object): variance using column summary statistics on the samples in the training set. + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] >>> dataset = sc.parallelize(vs) >>> standardizer = StandardScaler(True, True) @@ -174,14 +180,6 @@ class StandardScaler(object): DenseVector([0.7071, -0.7071, 0.7071]) """ def __init__(self, withMean=False, withStd=True): - """ - :param withMean: False by default. Centers the data with mean - before scaling. It will build a dense output, so this - does not work on sparse input and will raise an - exception. - :param withStd: True by default. Scales the data to unit - standard deviation. - """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") self.withMean = withMean @@ -193,7 +191,7 @@ def fit(self, dataset): for later scaling. :param data: The data used to compute the mean and variance - to build the transformation model. + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -223,6 +221,8 @@ class ChiSqSelector(object): Creates a ChiSquared feature selector. + :param numTopFeatures: number of features that selector will select. + >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), @@ -236,9 +236,6 @@ class ChiSqSelector(object): DenseVector([5.0]) """ def __init__(self, numTopFeatures): - """ - :param numTopFeatures: number of features that selector will select. - """ self.numTopFeatures = int(numTopFeatures) def fit(self, data): @@ -246,9 +243,9 @@ def fit(self, data): Returns a ChiSquared feature selector. :param data: an `RDD[LabeledPoint]` containing the labeled dataset - with categorical features. Real-valued features will be - treated as categorical for each distinct value. - Apply feature discretizer before using this function. + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + Apply feature discretizer before using this function. """ jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data) return ChiSqSelectorModel(jmodel) @@ -263,15 +260,14 @@ class HashingTF(object): Note: the terms must be hashable (can not be dict/set/list...). + :param numFeatures: number of features (default: 2^20) + >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): - """ - :param numFeatures: number of features (default: 2^20) - """ self.numFeatures = numFeatures def indexOf(self, term): @@ -311,7 +307,7 @@ def transform(self, x): Call transform directly on the RDD instead. :param x: an RDD of term frequency vectors or a term frequency - vector + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ if isinstance(x, RDD): @@ -342,6 +338,9 @@ class IDF(object): `minDocFreq`). For terms that are not in at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. + :param minDocFreq: minimum of documents in which a term + should appear for filtering + >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), ... Vectors.dense([0.0, 1.0, 2.0, 3.0]), @@ -362,10 +361,6 @@ class IDF(object): SparseVector(4, {1: 0.0, 3: 0.5754}) """ def __init__(self, minDocFreq=0): - """ - :param minDocFreq: minimum of documents in which a term - should appear for filtering - """ self.minDocFreq = minDocFreq def fit(self, dataset): From 1281a3518802bfa624618236e6b9b59bc0e78585 Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Sat, 30 May 2015 16:50:59 -0700 Subject: [PATCH 096/254] [SPARK-7920] [MLLIB] Make MLlib ChiSqSelector Serializable (& Fix Related Documentation Example). The MLlib ChiSqSelector class is not serializable, and so the example in the ChiSqSelector documentation fails. Also, that example is missing the import of ChiSqSelector. This PR makes ChiSqSelector extend Serializable in MLlib, and adds the ChiSqSelector import statement to the associated example in the documentation. Author: Mike Dusenberry Closes #6462 from dusenberrymw/Make_ChiSqSelector_Serializable_and_Fix_Related_Docs_Example and squashes the following commits: 9cb2f94 [Mike Dusenberry] Make MLlib ChiSqSelector Serializable. d9003bf [Mike Dusenberry] Add missing import in MLlib ChiSqSelector Docs Scala example. --- docs/mllib-feature-extraction.md | 1 + .../scala/org/apache/spark/mllib/feature/ChiSqSelector.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 764985d436ead..1f6ad8b13d730 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -410,6 +410,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 9cc2d0ffcab7d..5f8c1dea237b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -108,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * (ordered by statistic value descending) */ @Experimental -class ChiSqSelector (val numTopFeatures: Int) { +class ChiSqSelector (val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. From 66a53a69643e0004742667e140bad2aa8dae44e4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 30 May 2015 16:52:34 -0700 Subject: [PATCH 097/254] [HOTFIX] Replace FunSuite with SparkFunSuite. This fixes a build break introduced by merging a6430028ecd7a6130f1eb15af9ec00e242c46725, which fails the new style checks that ensure that we use SparkFunSuite instead of FunSuite. --- .../shuffle/sort/BypassMergeSortShuffleWriterSuite.scala | 4 ++-- .../apache/spark/shuffle/sort/SortShuffleWriterSuite.scala | 5 ++--- .../apache/spark/util/collection/ExternalSorterSuite.scala | 4 ---- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index c8420db6126c0..542f8f45125a4 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -29,7 +29,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} @@ -37,7 +37,7 @@ import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializ import org.apache.spark.storage._ import org.apache.spark.util.Utils -class BypassMergeSortShuffleWriterSuite extends FunSuite with BeforeAndAfterEach { +class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index c6ada7139c198..34b4984f12c09 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.shuffle.sort import org.mockito.Mockito._ -import org.scalatest.FunSuite -import org.apache.spark.{Aggregator, SparkConf} +import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} -class SortShuffleWriterSuite extends FunSuite { +class SortShuffleWriterSuite extends SparkFunSuite { import SortShuffleWriter._ diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 7d7b41bc23284..9cefa612f5491 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,12 +19,8 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.PrivateMethodTester - import scala.util.Random -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} From 2b258e1c0784c8ca958bf94cd9e75fa17f104448 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 30 May 2015 17:21:41 -0700 Subject: [PATCH 098/254] [SPARK-5610] [DOC] update genjavadocSettings to use the patched version of genjavadoc This PR updates `genjavadocSettings` to use a patched version of `genjavadoc-plugin` that hides package private classes/methods/interfaces in the generated Java API doc. The patch can be found at: https://github.com/typesafehub/genjavadoc/compare/master...mengxr:spark-1.4. It wasn't merged into the main repo because there exist corner cases where a package private Scala class has to be a Java public class in order to compile. This doesn't seem to apply to the Spark codebase. So we release a patched version under `org.spark-project` and use it in the Spark build. brkyvz is publishing the artifacts to Maven Central. Need more people audit the generated APIs and make sure we don't have false negatives. Current listed classes under `org.apache.spark.rdd`: ![screen shot 2015-05-29 at 12 48 52 pm](https://cloud.githubusercontent.com/assets/829644/7891396/28fb9daa-0601-11e5-8ed8-4e9522d25a71.png) After this PR: ![screen shot 2015-05-29 at 12 48 23 pm](https://cloud.githubusercontent.com/assets/829644/7891408/408e210e-0601-11e5-975c-ff0a02eb5c91.png) cc: pwendell rxin srowen Author: Xiangrui Meng Closes #6506 from mengxr/SPARK-5610 and squashes the following commits: 489c785 [Xiangrui Meng] update genjavadocSettings to use the patched version of genjavadoc --- project/SparkBuild.scala | 10 +++++++--- project/plugins.sbt | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b9515a12bc573..9a849639233bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConversions._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ -import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -118,7 +117,12 @@ object SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( + lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq( + libraryDependencies += compilerPlugin( + "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), + scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) + + lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( javaHome := sys.env.get("JAVA_HOME") .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), @@ -126,7 +130,7 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - unidocGenjavadocVersion := "0.8", + unidocGenjavadocVersion := "0.9-spark0", resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), diff --git a/project/plugins.sbt b/project/plugins.sbt index 7096b0d3ee7de..75bd604a1b857 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -25,7 +25,7 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") From 14b314dc2cad7bbf23976347217c676d338e0a2d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 30 May 2015 19:50:52 -0700 Subject: [PATCH 099/254] [SQL] Tighten up visibility for JavaDoc. I went through all the JavaDocs and tightened up visibility. Author: Reynold Xin Closes #6526 from rxin/sql-1.4-visibility-for-docs and squashes the following commits: bc37d1e [Reynold Xin] Tighten up visibility for JavaDoc. --- .../org/apache/spark/sql/types/Decimal.scala | 6 +++--- .../apache/spark/sql/types/DecimalType.scala | 4 ++-- .../spark/sql/types/SQLUserDefinedType.java | 4 ++-- .../org/apache/spark/sql/GroupedData.scala | 8 ++++---- .../apache/spark/sql/expressions/Window.scala | 17 +++++++++++++++++ .../apache/spark/sql/sources/interfaces.scala | 3 ++- .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../spark/sql/hive/client/ReflectionMagic.scala | 2 +- 8 files changed, 32 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 994c5202c15dc..eb3c58c37f308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -313,7 +313,7 @@ object Decimal { // See scala.math's Numeric.scala for examples for Scala's built-in types. /** Common methods for Decimal evidence parameters */ - trait DecimalIsConflicted extends Numeric[Decimal] { + private[sql] trait DecimalIsConflicted extends Numeric[Decimal] { override def plus(x: Decimal, y: Decimal): Decimal = x + y override def times(x: Decimal, y: Decimal): Decimal = x * y override def minus(x: Decimal, y: Decimal): Decimal = x - y @@ -327,12 +327,12 @@ object Decimal { } /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ - object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + private[sql] object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { override def div(x: Decimal, y: Decimal): Decimal = x / y } /** A [[scala.math.Integral]] evidence parameter for Decimals. */ - object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + private[sql] object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { override def quot(x: Decimal, y: Decimal): Decimal = x / y override def rem(x: Decimal, y: Decimal): Decimal = x % y } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0f8cecd28f7df..407dc27326c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -82,12 +82,12 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT object DecimalType { val Unlimited: DecimalType = DecimalType(None) - object Fixed { + private[sql] object Fixed { def unapply(t: DecimalType): Option[(Int, Int)] = t.precisionInfo.map(p => (p.precision, p.scale)) } - object Expression { + private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java index a64d2bb7cde37..df64a878b6b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -24,11 +24,11 @@ /** * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. - * + *

    * WARNING: This annotation will only work if both Java and Scala reflection return the same class * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class * is enclosed in an object (a singleton). - * + *

    * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 516ba2ac23371..c4ceb0c173887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -40,22 +40,22 @@ private[sql] object GroupedData { /** * The Grouping Type */ - trait GroupType + private[sql] trait GroupType /** * To indicate it's the GroupBy */ - object GroupByType extends GroupType + private[sql] object GroupByType extends GroupType /** * To indicate it's the CUBE */ - object CubeType extends GroupType + private[sql] object CubeType extends GroupType /** * To indicate it's the ROLLUP */ - object RollupType extends GroupType + private[sql] object RollupType extends GroupType } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index d4003b2d9cbf6..e9b60841fc28c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -79,3 +79,20 @@ object Window { } } + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c06026e042d9f..b1b997c030a60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -93,7 +93,7 @@ trait SchemaRelationProvider { } /** - * ::DeveloperApi:: + * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined @@ -115,6 +115,7 @@ trait SchemaRelationProvider { * * @since 1.4.0 */ +@Experimental trait HadoopFsRelationProvider { /** * Returns a new base relation with the given parameters, a user defined schema, and a list of diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3915ee835685f..253bf1125262e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -57,7 +57,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty } -case class CreateTableAsSelect( +private[hive] case class CreateTableAsSelect( tableDesc: HiveTable, child: LogicalPlan, allowExisting: Boolean) extends UnaryNode with Command { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala index c600b158c5460..4d053ae42c2ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala @@ -30,7 +30,7 @@ private[client] object ReflectionException { /** * Provides implicit functions on any object for calling methods reflectively. */ -protected trait ReflectionMagic { +private[client] trait ReflectionMagic { /** code for InstanceMagic println( (1 to 22).map { n => From c63e1a742b3e87e79a4466e9bd0b927a24645756 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 30 May 2015 19:51:53 -0700 Subject: [PATCH 100/254] [SPARK-7971] Add JavaDoc style deprecation for deprecated DataFrame methods Scala deprecated annotation actually doesn't show up in JavaDoc. Author: Reynold Xin Closes #6523 from rxin/df-deprecated-javadoc and squashes the following commits: 26da2b2 [Reynold Xin] [SPARK-7971] Add JavaDoc style deprecation for deprecated DataFrame methods. --- .../org/apache/spark/sql/types/DataType.scala | 3 ++ .../org/apache/spark/sql/DataFrame.scala | 46 ++++++++++++++----- .../org/apache/spark/sql/SQLContext.scala | 33 +++++++++++++ 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 54604808e133e..1ba3a2686639f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -165,6 +165,9 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) + /** + * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` + */ @deprecated("Use DataType.fromJson instead", "1.2.0") def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index e90109446b642..034d887901975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -57,14 +57,11 @@ private[sql] object DataFrame { * :: Experimental :: * A distributed collection of data organized into named columns. * - * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways - * to create a [[DataFrame]]: + * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates + * a [[DataFrame]] by pointing Spark SQL to a Parquet data set. * {{{ - * // Create a DataFrame from Parquet files - * val people = sqlContext.parquetFile("...") - * - * // Create a DataFrame from data sources - * val df = sqlContext.load("...", "json") + * val people = sqlContext.read.parquet("...") // in Scala + * DataFrame people = sqlContext.read().parquet("...") // in Java * }}} * * Once created, it can be manipulated using the various domain-specific-language (DSL) functions @@ -86,8 +83,8 @@ private[sql] object DataFrame { * A more concrete example in Scala: * {{{ * // To create DataFrame using SQLContext - * val people = sqlContext.parquetFile("...") - * val department = sqlContext.parquetFile("...") + * val people = sqlContext.read.parquet("...") + * val department = sqlContext.read.parquet("...") * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) @@ -98,8 +95,8 @@ private[sql] object DataFrame { * and in Java: * {{{ * // To create DataFrame using SQLContext - * DataFrame people = sqlContext.parquetFile("..."); - * DataFrame department = sqlContext.parquetFile("..."); + * DataFrame people = sqlContext.read().parquet("..."); + * DataFrame department = sqlContext.read().parquet("..."); * * people.filter("age".gt(30)) * .join(department, people.col("deptId").equalTo(department("id"))) @@ -1444,7 +1441,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// - /** Left here for backward compatibility. */ + /** + * @deprecated As of 1.3.0, replaced by `toDF()`. + */ @deprecated("use toDF", "1.3.0") def toSchemaRDD: DataFrame = this @@ -1455,6 +1454,7 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output + * @deprecated As of 1.340, replaced by `write().jdbc()`. */ @deprecated("Use write.jdbc()", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { @@ -1473,6 +1473,7 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. */ @deprecated("Use write.jdbc()", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { @@ -1485,6 +1486,7 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output + * @deprecated As of 1.4.0, replaced by `write().parquet()`. */ @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { @@ -1508,6 +1510,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. */ @deprecated("Use write.saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String): Unit = { @@ -1526,6 +1529,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { @@ -1545,6 +1549,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { @@ -1564,6 +1569,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { @@ -1582,6 +1588,8 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1606,6 +1614,8 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1622,6 +1632,7 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().save(path)`. */ @deprecated("Use write.save(path)", "1.4.0") def save(path: String): Unit = { @@ -1632,6 +1643,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, * using the default data source configured by spark.sql.sources.default. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. */ @deprecated("Use write.mode(mode).save(path)", "1.4.0") def save(path: String, mode: SaveMode): Unit = { @@ -1642,6 +1654,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source, * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. */ @deprecated("Use write.format(source).save(path)", "1.4.0") def save(path: String, source: String): Unit = { @@ -1652,6 +1665,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source and * [[SaveMode]] specified by mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. */ @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { @@ -1662,6 +1676,8 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1676,6 +1692,8 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1689,6 +1707,8 @@ class DataFrame private[sql]( /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. */ @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { @@ -1699,6 +1719,8 @@ class DataFrame private[sql]( * Adds the rows from this RDD to the specified table. * Throws an exception if the table already exists. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append).saveAsTable(tableName)`. */ @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a32897c20b474..7384b24c50b16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1021,21 +1021,33 @@ class SQLContext(@transient val sparkContext: SparkContext) //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) @@ -1046,6 +1058,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * [[DataFrame]] if no paths are passed in. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().parquet()`. */ @deprecated("Use read.parquet()", "1.4.0") @scala.annotation.varargs @@ -1065,6 +1078,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String): DataFrame = { @@ -1076,6 +1090,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { @@ -1084,6 +1099,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { @@ -1096,6 +1112,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) @@ -1106,6 +1123,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) @@ -1115,6 +1133,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { @@ -1126,6 +1145,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { @@ -1137,6 +1157,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { @@ -1148,6 +1169,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { @@ -1159,6 +1181,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * using the default data source configured by spark.sql.sources.default. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().load(path)`. */ @deprecated("Use read.load(path)", "1.4.0") def load(path: String): DataFrame = { @@ -1169,6 +1192,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns the dataset stored at path as a DataFrame, using the given data source. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. */ @deprecated("Use read.format(source).load(path)", "1.4.0") def load(path: String, source: String): DataFrame = { @@ -1180,6 +1204,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { @@ -1191,6 +1216,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { @@ -1202,6 +1228,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. */ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = @@ -1214,6 +1242,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. */ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { @@ -1225,6 +1255,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * url named table. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc(url: String, table: String): DataFrame = { @@ -1242,6 +1273,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc( @@ -1261,6 +1293,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * of the [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { From 00a7137900d45188673da85cbcef4f02b7a266c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 30 May 2015 20:10:02 -0700 Subject: [PATCH 101/254] Update documentation for the new DataFrame reader/writer interface. Author: Reynold Xin Closes #6522 from rxin/sql-doc-1.4 and squashes the following commits: c227be7 [Reynold Xin] Updated link. 040b6d7 [Reynold Xin] Update documentation for the new DataFrame reader/writer interface. --- docs/sql-programming-guide.md | 126 ++++++++++++++++++---------------- 1 file changed, 66 insertions(+), 60 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2ea7572c6026a..282ea75e1e785 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -109,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Displays the content of the DataFrame to stdout df.show() @@ -122,7 +122,7 @@ df.show() JavaSparkContext sc = ...; // An existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Displays the content of the DataFrame to stdout df.show(); @@ -135,7 +135,7 @@ df.show(); from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Displays the content of the DataFrame to stdout df.show() @@ -171,7 +171,7 @@ val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Show the content of the DataFrame df.show() @@ -221,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Show the content of the DataFrame df.show(); @@ -277,7 +277,7 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) # Create the DataFrame -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Show the content of the DataFrame df.show() @@ -777,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config

    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +val df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %}
    @@ -787,8 +787,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); -df.select("name", "favorite_color").save("namesAndFavColors.parquet"); +DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); {% endhighlight %} @@ -798,8 +798,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %} @@ -827,8 +827,8 @@ using this syntax.
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") +df.select("name", "age").write.format("json").save("namesAndAges.parquet") {% endhighlight %}
    @@ -837,8 +837,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); -df.select("name", "age").save("namesAndAges.parquet", "parquet"); +DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); +df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); {% endhighlight %} @@ -848,8 +848,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") +df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") {% endhighlight %} @@ -947,11 +947,11 @@ import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. -people.saveAsParquetFile("people.parquet") +people.write.parquet("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. -val parquetFile = sqlContext.parquetFile("people.parquet") +val parquetFile = sqlContext.read.parquet("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile") @@ -969,11 +969,11 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet"); +schemaPeople.write().parquet("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. -DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); +DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -995,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function() schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet") +schemaPeople.read.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.parquetFile("people.parquet") +parquetFile = sqlContext.write.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -1087,9 +1087,9 @@ path {% endhighlight %} -By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will -automatically extract the partitioning information from the paths. Now the schema of the returned -DataFrame becomes: +By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +will automatically extract the partitioning information from the paths. +Now the schema of the returned DataFrame becomes: {% highlight text %} @@ -1122,15 +1122,15 @@ import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") -df1.saveAsParquetFile("data/test_table/key=1") +df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") -df2.saveAsParquetFile("data/test_table/key=2") +df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.parquetFile("data/test_table") +val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1269,12 +1269,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: +This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +or a JSON file. -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. - -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1285,8 +1283,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a DataFrame from the file(s) pointed to by path -val people = sqlContext.jsonFile(path) +val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1304,19 +1301,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +val anotherPeople = sqlContext.read.json(anotherPeopleRDD) {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext` : - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1326,9 +1321,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. -String path = "examples/src/main/resources/people.json"; -// Create a DataFrame from the file(s) pointed to by path -DataFrame people = sqlContext.jsonFile(path); +DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -1347,18 +1340,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json` on a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1369,9 +1359,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. -path = "examples/src/main/resources/people.json" -# Create a DataFrame from the file(s) pointed to by path -people = sqlContext.jsonFile(path) +people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1394,12 +1382,11 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. using +the `jsonFile` function, which loads data from a directory of JSON files where each line of the +files is a JSON object. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1883,6 +1870,25 @@ options. ## Upgrading from Spark SQL 1.3 to 1.4 +#### DataFrame data reader/writer interface + +Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) +and writing data out (`DataFrame.write`), +and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). + +See the API docs for `SQLContext.read` ( + Scala, + Java, + Python +) and `DataFrame.write` ( + Scala, + Java, + Python +) more information. + + +#### DataFrame.groupBy retains grouping columns + Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
    From f7fe9e474417a68635a5ed1aa819d81a9be40895 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 31 May 2015 12:56:41 +0800 Subject: [PATCH 102/254] [SQL] [MINOR] Fixes a minor comment mistake in IsolatedClientLoader Author: Cheng Lian Closes #6521 from liancheng/classloader-comment-fix and squashes the following commits: fc09606 [Cheng Lian] Addresses @srowen's comment 59945c5 [Cheng Lian] Fixes a minor comment mistake in IsolatedClientLoader --- .../spark/sql/hive/client/IsolatedClientLoader.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 196a3d836cab2..16851fdd71a98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -90,14 +90,14 @@ private[hive] object IsolatedClientLoader { * `ClientInterface`, unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures - * that are not compatibile accross versions. + * that are not compatible across versions. * @param execJars A collection of jar files that must include hive and hadoop. * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. Must not know about hive classes. - * @param baseClassLoader The spark classloader that is used to load shared classes. - * + * @param rootClassLoader The system root classloader. + * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know + * about Hive classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -110,7 +110,7 @@ private[hive] class IsolatedClientLoader( val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { - // Check to make sure that the root classloader does not know about Hive. + // Check to make sure that the base classloader does not know about Hive. assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ From 084fef76e90116c6465cd6fad7c0197c3e4d4313 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 30 May 2015 23:36:32 -0700 Subject: [PATCH 103/254] [SPARK-7976] Add style checker to disallow overriding finalize. Author: Reynold Xin Closes #6528 from rxin/style-finalizer and squashes the following commits: a2211ca [Reynold Xin] [SPARK-7976] Enable NoFinalizeChecker. --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 2 ++ scalastyle-config.xml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a77bf42ce1d38..51388f01a31a3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -843,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * An Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. */ +// scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { /** @@ -884,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } +// scalastyle:on no.finalize diff --git a/scalastyle-config.xml b/scalastyle-config.xml index a0098169a0248..072c48062ca75 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -96,7 +96,7 @@ - + From 8764dccebd44292ab6f6834640199aad451459c5 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 30 May 2015 23:49:42 -0700 Subject: [PATCH 104/254] [SQL] [MINOR] Adds @deprecated Scaladoc entry for SchemaRDD Author: Cheng Lian Closes #6529 from liancheng/schemardd-deprecation-fix and squashes the following commits: 49765c2 [Cheng Lian] Adds @deprecated Scaladoc entry for SchemaRDD --- sql/core/src/main/scala/org/apache/spark/sql/package.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 3f97a11ceb97d..4e94fd07a8771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -44,6 +44,7 @@ package object sql { /** * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + * @deprecated As of 1.3.0, replaced by `DataFrame`. */ @deprecated("1.3.0", "use DataFrame") type SchemaRDD = DataFrame From 7896e99b2a0a160bd0b6c5c11cf40b6cbf4a65cf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 00:05:55 -0700 Subject: [PATCH 105/254] [SPARK-7975] Add style checker to disallow overriding equals covariantly. Author: Reynold Xin This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #6527 from rxin/covariant-equals and squashes the following commits: e7d7784 [Reynold Xin] [SPARK-7975] Enforce CovariantEqualsChecker --- scalastyle-config.xml | 2 +- .../main/scala/org/apache/spark/sql/parquet/newParquet.scala | 2 +- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 072c48062ca75..3a984222167b0 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -97,7 +97,7 @@ - + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 8b3e1b2b59bf6..e439a18ac43aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -155,7 +155,7 @@ private[sql] class ParquetRelation2( meta } - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case that: ParquetRelation2 => val schemaEquality = if (shouldMergeSchemas) { this.shouldMergeSchemas == that.shouldMergeSchemas diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 47b85731587d5..ca1f49b546bd7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -596,7 +596,7 @@ private[hive] case class MetastoreRelation self: Product => - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => databaseName == relation.databaseName && tableName == relation.tableName && From 74fdc97c7206c6d715f128ef7c46055e0bb90760 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 00:16:22 -0700 Subject: [PATCH 106/254] [SPARK-3850] Trim trailing spaces for core. Author: Reynold Xin Closes #6533 from rxin/whitespace-2 and squashes the following commits: 038314c [Reynold Xin] [SPARK-3850] Trim trailing spaces for core. --- .../scala/org/apache/spark/Aggregator.scala | 4 +-- .../scala/org/apache/spark/FutureAction.scala | 2 +- .../org/apache/spark/HeartbeatReceiver.scala | 20 ++++++------- .../org/apache/spark/HttpFileServer.scala | 4 +-- .../scala/org/apache/spark/SparkConf.scala | 12 ++++---- .../scala/org/apache/spark/TestUtils.scala | 2 +- .../apache/spark/api/java/JavaDoubleRDD.scala | 2 +- .../org/apache/spark/api/java/JavaRDD.scala | 6 ++-- .../apache/spark/api/python/PythonRDD.scala | 4 +-- .../org/apache/spark/api/r/RBackend.scala | 4 +-- .../apache/spark/api/r/RBackendHandler.scala | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../history/HistoryServerArguments.scala | 2 +- .../master/ZooKeeperPersistenceEngine.scala | 2 +- .../apache/spark/executor/TaskMetrics.scala | 16 +++++------ .../apache/spark/metrics/sink/Slf4jSink.scala | 4 +-- .../apache/spark/metrics/sink/package.scala | 2 +- .../apache/spark/rdd/AsyncRDDActions.scala | 4 +-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 2 +- .../spark/scheduler/ReplayListenerBus.scala | 4 +-- .../org/apache/spark/scheduler/Task.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 4 +-- .../CoarseGrainedSchedulerBackend.scala | 2 +- .../mesos/MesosSchedulerBackendUtil.scala | 4 +-- .../spark/serializer/KryoSerializer.scala | 2 +- .../hash/BlockStoreShuffleFetcher.scala | 2 +- .../status/api/v1/OneStageResource.scala | 2 +- .../storage/BlockManagerMasterEndpoint.scala | 8 +++--- .../spark/storage/DiskBlockManager.scala | 2 +- .../spark/storage/TachyonBlockManager.scala | 4 +-- .../scala/org/apache/spark/ui/WebUI.scala | 4 +-- .../spark/ui/jobs/JobProgressListener.scala | 2 +- .../spark/util/AsynchronousListenerBus.scala | 2 +- .../org/apache/spark/util/SizeEstimator.scala | 4 +-- .../collection/ExternalAppendOnlyMap.scala | 4 +-- .../util/collection/ExternalSorter.scala | 2 +- .../scala/org/apache/spark/FailureSuite.scala | 6 ++-- .../apache/spark/ImplicitOrderingSuite.scala | 28 +++++++++---------- .../org/apache/spark/SparkContextSuite.scala | 10 +++---- .../org/apache/spark/rdd/JdbcRDDSuite.scala | 2 +- .../cluster/mesos/MemoryUtilsSuite.scala | 4 +-- .../mesos/MesosSchedulerBackendSuite.scala | 4 +-- .../serializer/KryoSerializerSuite.scala | 6 ++-- .../ProactiveClosureSerializationSuite.scala | 18 ++++++------ .../spark/util/ClosureCleanerSuite.scala | 2 +- .../util/random/RandomSamplerSuite.scala | 2 +- 47 files changed, 117 insertions(+), 117 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index b8a5f5016860f..ceeb58075d345 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,8 +34,8 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. + // When spilling is enabled sorting will happen externally, but not necessarily with an + // ExternalSorter. private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 91f9ef8ce7185..48792a958130c 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished - + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index f2b024ff6cb67..6909015ff66e6 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -29,7 +29,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. It will also + * components to convey liveness or execution information for in-progress tasks. It will also * expire the hosts that have not heartbeated for more than spark.network.timeout. */ private[spark] case class Heartbeat( @@ -43,8 +43,8 @@ private[spark] case class Heartbeat( */ private[spark] case object TaskSchedulerIsSet -private[spark] case object ExpireDeadHosts - +private[spark] case object ExpireDeadHosts + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** @@ -62,18 +62,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = + private val slaveTimeoutMs = sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") - private val executorTimeoutMs = + private val executorTimeoutMs = sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 - + // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" - private val timeoutIntervalMs = + private val timeoutIntervalMs = sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") - private val checkTimeoutIntervalMs = + private val checkTimeoutIntervalMs = sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 - + private var timeoutCheckingTask: ScheduledFuture[_] = null // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not @@ -140,7 +140,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } } - + override def onStop(): Unit = { if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7e706bcc42f04..7cf7bc0dc6810 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -50,8 +50,8 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() - - // If we only stop sc, but the driver process still run as a services then we need to delete + + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs try { Utils.deleteRecursively(baseDir) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4b5bcb54aa873..46d72841dccce 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsBytes(key: String, defaultValue: String): Long = { Utils.byteStringAsBytes(get(key, defaultValue)) } - + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. @@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsKb(key: String, defaultValue: String): Long = { Utils.byteStringAsKb(get(key, defaultValue)) } - + /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. @@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsMb(key: String, defaultValue: String): Long = { Utils.byteStringAsMb(get(key, defaultValue)) } - + /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. @@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsGb(key: String, defaultValue: String): Long = { Utils.byteStringAsGb(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -480,7 +480,7 @@ private[spark] object SparkConf extends Logging { "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'.") ) - + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) } @@ -508,7 +508,7 @@ private[spark] object SparkConf extends Logging { "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index fe6320b504e15..a1ebbecf93b7b 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -51,7 +51,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() val files1 = for (name <- classNames) yield { - createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) } val files2 = for ((childName, baseName) <- classNamesWithBase) yield { createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 61af867b11b9c..a650df605b92e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) */ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index db4e996feb31c..ed312770ee131 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] @@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 51388f01a31a3..55a37f8c944b2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -797,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) - /** + /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded actor anyway. - */ + */ @transient var socket: Socket = _ def openSocket(): Socket = synchronized { diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 0a91977928cee..d24c650d37bb0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -44,11 +44,11 @@ private[spark] class RBackend { bossGroup = new NioEventLoopGroup(2) val workerGroup = bossGroup val handler = new RBackendHandler(this) - + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(classOf[NioServerSocketChannel]) - + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { def initChannel(ch: SocketChannel): Unit = { ch.pipeline() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 026a1b9380357..2e86984c66b3a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend) val reply = bos.toByteArray ctx.write(reply) } - + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { ctx.flush() } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d1b32ea0778db..8cf4d58847d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -869,7 +869,7 @@ private[spark] object SparkSubmitUtils { md.addDependency(dd) } } - + /** Add exclusion rules for dependencies already included in the spark-assembly */ def addExclusionRules( ivySettings: IvySettings, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index a2a97a7877ce7..4692d22651c93 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 80db6d474b5c1..328d95a7a0c68 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { - + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index d90ae405a0849..38b61d7242fce 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -43,22 +43,22 @@ class TaskMetrics extends Serializable { private var _hostname: String = _ def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value - + /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value - - + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value - + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ @@ -315,7 +315,7 @@ class ShuffleReadMetrics extends Serializable { def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - + /** * Number of local blocks fetched in this shuffle by this task */ @@ -333,7 +333,7 @@ class ShuffleReadMetrics extends Serializable { def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - + /** * Total number of remote bytes read from the shuffle by this task */ @@ -381,7 +381,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @@ -389,7 +389,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** * Total number of records written to the shuffle by this task */ diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index e8b3074e8f1a6..11dfcfe2f04e1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -26,9 +26,9 @@ import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem private[spark] class Slf4jSink( - val property: Properties, + val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) + securityMgr: SecurityManager) extends Sink { val SLF4J_DEFAULT_PERIOD = 10 val SLF4J_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala index 90e3aa70b99ef..670e683663324 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala @@ -20,4 +20,4 @@ package org.apache.spark.metrics /** * Sinks used in Spark's metrics system. */ -package object sink +package object sink diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index bbf1b83af0795..ca1eb1f4e4a9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -85,9 +85,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi numPartsToTry = partsScanned * 4 } else { // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max(1, + numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2ab967f4bb313..84456d6d868dc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -196,7 +196,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 004899f27b7a6..cfd3e26faf2b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -328,7 +328,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) reduceByKeyLocally(func) } - /** + /** * Count the number of elements for each key, collecting the results to a local Map. * * Note that this method should only be used if the resulting map is expected to be small, as diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 86f357abb8723..c6d957b65f3fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. * @param sourceName Filename (or other source identifier) from whence @logData is being read - * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not */ def replay( @@ -62,7 +62,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { if (!maybeTruncated || lines.hasNext) { throw jpe } else { - logWarning(s"Got JsonParseException from log file $sourceName" + + logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 586d1e06204c1..15101c64f0503 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -125,7 +125,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d473e51abab80..673cd0e19eba2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -861,9 +861,9 @@ private[spark] class TaskSetManager( case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" case _ => null } - + if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait) } else { 0L } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c5bc6294a5577..fcad959540f5a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -84,7 +84,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") - + reviveThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReviveOffers)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 2f2934c249eb0..e79c543a9de27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -37,14 +37,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .newBuilder() .setMode(Volume.Mode.RW) spec match { - case Array(container_path) => + case Array(container_path) => Some(vol.setContainerPath(container_path)) case Array(container_path, "rw") => Some(vol.setContainerPath(container_path)) case Array(container_path, "ro") => Some(vol.setContainerPath(container_path) .setMode(Volume.Mode.RO)) - case Array(host_path, container_path) => + case Array(host_path, container_path) => Some(vol.setContainerPath(container_path) .setHostPath(host_path)) case Array(host_path, container_path, "rw") => diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 3f909885dbd66..cd8a82347a1e9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -52,7 +52,7 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") - + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 80374adc44296..597d46a3d2223 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -80,7 +80,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blocksByAddress, serializer, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index fd24aea63a8a1..f9812f06cf527 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -83,7 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { withStageAttempt(stageId, stageAttemptId) { stage => val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) + tasks.slice(offset, offset + length) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 3afb4c3c02e2d..2cd8c5297b741 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -292,16 +292,16 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " + logError("Got two different block manager registrations on same executor - " + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) + removeExecutor(id.executorId) case None => } logInfo("Registering block manager %s with %s RAM, %s".format( id.hostPort, Utils.bytesToString(maxMemSize), id)) - + blockManagerIdByExecutor(id.executorId) = id - + blockManagerInfo(id) = new BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index d441a4d31b954..91ef86389a0c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -151,7 +151,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon try { Utils.removeShutdownHook(shutdownHook) } catch { - case e: Exception => + case e: Exception => logError(s"Exception while removing shutdown hook.", e) } doStop() diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index fb4ba0eac9d9a..b53c86e89a273 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -100,7 +100,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { os.write(bytes.array()) } catch { - case NonFatal(e) => + case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) os.cancel() } finally { @@ -114,7 +114,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { blockManager.dataSerializeStream(blockId, os, values) } catch { - case NonFatal(e) => + case NonFatal(e) => logWarning(s"Failed to put values of block $blockId into Tachyon", e) os.cancel() } finally { diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 594df15e9cc85..2c84e4485996e 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -62,12 +62,12 @@ private[spark] abstract class WebUI( tab.pages.foreach(attachPage) tabs += tab } - + def detachTab(tab: WebUITab) { tab.pages.foreach(detachPage) tabs -= tab } - + def detachPage(page: WebUIPage) { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 246e191d64776..f39e961772c46 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -119,7 +119,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { "failedStages" -> failedStages.size ) } - + // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index ce7887b76ff96..1861d38640102 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -40,7 +40,7 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri self => private var sparkContext: SparkContext = null - + /* Cap the capacity of the event queue so we get an explicit error (rather than * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ private val EVENT_QUEUE_CAPACITY = 10000 diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index f1f6b5e1f93d8..0180399c9dad5 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -236,7 +236,7 @@ object SizeEstimator extends Logging { val s1 = sampleArray(array, state, rand, drawn, length) val s2 = sampleArray(array, state, rand, drawn, length) val size = math.min(s1, s2) - state.size += math.max(s1, s2) + + state.size += math.max(s1, s2) + (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong } } @@ -244,7 +244,7 @@ object SizeEstimator extends Logging { private def sampleArray( array: AnyRef, - state: SearchState, + state: SearchState, rand: Random, drawn: OpenHashSet[Int], length: Int): Long = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index df2d6ad3b41a4..1e4531ef395ae 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,9 +89,9 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - private val fileBufferSize = + private val fileBufferSize = sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ef2dbb7ff0ae0..757dec66c203b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -117,7 +117,7 @@ private[spark] class ExternalSorter[K, V, C]( private val serInstance = ser.newInstance() private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index b18067e68f5a1..a8c8c6f73fb5a 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -117,7 +117,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).map(x => a).count() } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException") || + assert(thrown.getMessage.contains("NotSerializableException") || thrown.getCause.getClass === classOf[NotSerializableException]) // Non-serializable closure in an earlier stage @@ -125,7 +125,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() } assert(thrown1.getClass === classOf[SparkException]) - assert(thrown1.getMessage.contains("NotSerializableException") || + assert(thrown1.getMessage.contains("NotSerializableException") || thrown1.getCause.getClass === classOf[NotSerializableException]) // Non-serializable closure in foreach function @@ -133,7 +133,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).foreach(x => println(a)) } assert(thrown2.getClass === classOf[SparkException]) - assert(thrown2.getMessage.contains("NotSerializableException") || + assert(thrown2.getMessage.contains("NotSerializableException") || thrown2.getCause.getClass === classOf[NotSerializableException]) FailureSuiteState.clear() diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index e47173f8a8b03..4399f25626472 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -27,11 +27,11 @@ class ImplicitOrderingSuite extends SparkFunSuite with LocalSparkContext { // These RDD methods are in the companion object so that the unserializable ScalaTest Engine // won't be reachable from the closure object - + // Infer orderings after basic maps to particular types val basicMapExpectations = ImplicitOrderingSuite.basicMapExpectations(rdd) basicMapExpectations.map({case (met, explain) => assert(met, explain)}) - + // Infer orderings for other RDD methods val otherRDDMethodExpectations = ImplicitOrderingSuite.otherRDDMethodExpectations(rdd) otherRDDMethodExpectations.map({case (met, explain) => assert(met, explain)}) @@ -48,30 +48,30 @@ private object ImplicitOrderingSuite { class OrderedClass extends Ordered[OrderedClass] { override def compare(o: OrderedClass): Int = throw new UnsupportedOperationException } - + def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { - List((rdd.map(x => (x, x)).keyOrdering.isDefined, + List((rdd.map(x => (x, x)).keyOrdering.isDefined, "rdd.map(x => (x, x)).keyOrdering.isDefined"), - (rdd.map(x => (1, x)).keyOrdering.isDefined, + (rdd.map(x => (1, x)).keyOrdering.isDefined, "rdd.map(x => (1, x)).keyOrdering.isDefined"), - (rdd.map(x => (x.toString, x)).keyOrdering.isDefined, + (rdd.map(x => (x.toString, x)).keyOrdering.isDefined, "rdd.map(x => (x.toString, x)).keyOrdering.isDefined"), - (rdd.map(x => (null, x)).keyOrdering.isDefined, + (rdd.map(x => (null, x)).keyOrdering.isDefined, "rdd.map(x => (null, x)).keyOrdering.isDefined"), - (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty, + (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty, "rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"), - (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined, + (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined, "rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"), - (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined, + (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined, "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } - + def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { - List((rdd.groupBy(x => x).keyOrdering.isDefined, + List((rdd.groupBy(x => x).keyOrdering.isDefined, "rdd.groupBy(x => x).keyOrdering.isDefined"), - (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, + (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, "rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"), - (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined, + (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined, "rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"), (rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined, "rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"), diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 93426822f704e..6838b35ab4cc8 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -71,22 +71,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { var sc2: SparkContext = null SparkContext.clearActiveContext() val conf = new SparkConf().setAppName("test").setMaster("local") - + sc = SparkContext.getOrCreate(conf) - + assert(sc.getConf.get("spark.app.name").equals("test")) sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("test2").setMaster("local")) assert(sc2.getConf.get("spark.app.name").equals("test")) assert(sc === sc2) assert(sc eq sc2) - + // Try creating second context to confirm that it's still possible, if desired sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") .set("spark.driver.allowMultipleContexts", "true")) - + sc2.stop() } - + test("BytesWritable implicit conversion is correct") { // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index a8466ed8c1dc2..08215a2bafc09 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -82,7 +82,7 @@ class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkCont assert(rdd.count === 100) assert(rdd.reduce(_ + _) === 10100) } - + test("large id overflow") { sc = new SparkContext("local", "test") val rdd = new JdbcRDD( diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala index d565132a06789..e72285d03d3ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala @@ -28,11 +28,11 @@ class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { val sc = mock[SparkContext] when(sc.conf).thenReturn(sparkConf) - + // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 when(sc.executorMemory).thenReturn(512) assert(MemoryUtils.calculateTotalMemory(sc) === 896) - + // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 when(sc.executorMemory).thenReturn(4096) assert(MemoryUtils.calculateTotalMemory(sc) === 4505) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 6f4ff0814b8da..68df46a41ddc8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -79,11 +79,11 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi .set("spark.mesos.executor.docker.image", "spark/mock") .set("spark.mesos.executor.docker.volumes", "/a,/b:/b,/c:/c:rw,/d:ro,/e:/e:ro") .set("spark.mesos.executor.docker.portmaps", "80:8080,53:53:tcp") - + val listenerBus = mock[LiveListenerBus] listenerBus.post( SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - + val sc = mock[SparkContext] when(sc.executorMemory).thenReturn(100) when(sc.getSparkHome()).thenReturn(Option("/spark-home")) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index c32fe232cc27c..23a1fdb0f5009 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -36,7 +36,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { test("SPARK-7392 configuration limits") { val kryoBufferProperty = "spark.kryoserializer.buffer" val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" - + def newKryoInstance( conf: SparkConf, bufferSize: String = "64k", @@ -46,7 +46,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { kryoConf.set(kryoBufferMaxProperty, maxBufferSize) new KryoSerializer(kryoConf).newInstance() } - + // test default values newKryoInstance(conf, "64k", "64m") // 2048m = 2097152k @@ -69,7 +69,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { // test configuration with mb is supported properly newKryoInstance(conf, "8m", "9m") } - + test("basic types") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index 77d66864f755e..c657414e9e5c3 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { def op[T](x: T): String = x.toString - + def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } @@ -45,7 +45,7 @@ class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkC // iterating over a map from transformation names to functions that perform that // transformation on a given RDD, creating one test case for each - for (transformation <- + for (transformation <- Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _, @@ -58,24 +58,24 @@ class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkC val ex = intercept[SparkException] { xf(data, uc) } - assert(ex.getMessage.contains("Task not serializable"), + assert(ex.getMessage.contains("Task not serializable"), s"RDD.$name doesn't proactively throw NotSerializableException") } } - private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.map(y => uc.op(y)) - private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.flatMap(y => Seq(uc.op(y))) - private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filter(y => uc.pred(y)) - private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitions(_.map(y => uc.op(y))) - private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y))) - + } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index a97a842f434fb..70cd27b04347d 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -201,7 +201,7 @@ object TestObjectWithNestedReturns { def run(): Int = { withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map {x => + nums.map {x => // this return is fine since it will not transfer control outside the closure def foo(): Int = { return 5; 1 } foo() diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index 2f1e6a39f4554..d6af0aebde733 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -78,7 +78,7 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { } // Returns iterator over gap lengths between samples. - // This function assumes input data is integers sampled from the sequence of + // This function assumes input data is integers sampled from the sequence of // increasing integers: {0, 1, 2, ...}. This works because that is how I generate them, // and the samplers preserve their input order def gaps(data: Iterator[Int]): Iterator[Int] = { From 564bc11e9827915c8652bc06f4bd591809dea4b1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 00:47:56 -0700 Subject: [PATCH 107/254] [SPARK-3850] Trim trailing spaces for examples/streaming/yarn. Author: Reynold Xin Closes #6530 from rxin/trim-whitespace-1 and squashes the following commits: 7b7b3a0 [Reynold Xin] Reset again. dc14597 [Reynold Xin] Reset scalastyle. cd556c4 [Reynold Xin] YARN, Kinesis, Flume. 4223fe1 [Reynold Xin] [SPARK-3850] Trim trailing spaces for examples/streaming. --- .../org/apache/spark/examples/LogQuery.scala | 2 +- .../examples/mllib/DenseGaussianMixture.scala | 10 +++---- .../examples/streaming/MQTTWordCount.scala | 10 +++---- .../streaming/flume/FlumeInputDStream.scala | 10 +++---- .../flume/FlumePollingInputDStream.scala | 2 +- .../streaming/flume/FlumeStreamSuite.scala | 2 +- .../spark/streaming/kafka/KafkaCluster.scala | 2 +- .../spark/streaming/kafka/KafkaUtils.scala | 8 +++--- .../streaming/KinesisWordCountASL.scala | 12 ++++---- .../kinesis/KinesisCheckpointState.scala | 8 +++--- .../streaming/kinesis/KinesisReceiver.scala | 8 +++--- .../kinesis/KinesisRecordProcessor.scala | 28 +++++++++---------- .../org/apache/spark/graphx/EdgeSuite.scala | 10 +++---- .../streaming/receiver/BlockGenerator.scala | 2 +- .../receiver/ReceivedBlockHandler.scala | 2 +- .../spark/streaming/UISeleniumSuite.scala | 4 --- .../streaming/util/WriteAheadLogSuite.scala | 4 +-- .../yarn/ClientDistributedCacheManager.scala | 26 ++++++++--------- .../deploy/yarn/YarnSparkHadoopUtil.scala | 4 +-- .../ClientDistributedCacheManagerSuite.scala | 24 ++++++++-------- 20 files changed, 87 insertions(+), 91 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 32e02eab8b031..75c82117cbad2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. - * + * * Usage: LogQuery [logFile] */ object LogQuery { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index 9a1aab036aa0f..f8c71ccabc43b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -41,22 +41,22 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") val ctx = new SparkContext(conf) - + val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) .run(data) - + for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } - + println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index b336751d81616..813c8554f5193 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -40,7 +40,7 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq - + var client: MqttClient = null try { @@ -59,10 +59,10 @@ object MQTTPublisher { println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) + Thread.sleep(10) println("Queue is full, wait for to consume data from the message queue") - } - } + } + } } catch { case e: MqttException => println("Exception Caught: " + e) } finally { @@ -107,7 +107,7 @@ object MQTTWordCount { val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 60e2994431b38..1e32a365a1eee 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -152,9 +152,9 @@ class FlumeReceiver( val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()) val channelPipelineFactory = new CompressionChannelPipelineFactory() - + new NettyServer( - responder, + responder, new InetSocketAddress(host, port), channelFactory, channelPipelineFactory, @@ -188,12 +188,12 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel + * a successful response to indicate it can remove the event/batch + * from the configured channel */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 92fa5b41be89e..583e7dca317ad 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver( } /** - * A wrapper around the transceiver and the Avro IPC API. + * A wrapper around the transceiver and the Avro IPC API. * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 3d9daeb6e4363..c926359987d89 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -138,7 +138,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w val status = client.appendBatch(inputEvents.toList) status should be (avro.Status.OK) } - + eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 6cf254a7b69cb..65d51d87f8486 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { r.flatMap { tm: TopicMetadata => tm.partitionsMetadata.map { pm: PartitionMetadata => TopicAndPartition(tm.topic, pm.partitionId) - } + } } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 8be2707528d93..0b8a391a2c569 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -315,7 +315,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -363,7 +363,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -427,7 +427,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -489,7 +489,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 97c3476049289..be8b62d3cc6ba 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -119,7 +119,7 @@ object KinesisWordCountASL extends Logging { val batchInterval = Milliseconds(2000) // Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information - // on sequence number of records that have been received. Same as batchInterval for this + // on sequence number of records that have been received. Same as batchInterval for this // example. val kinesisCheckpointInterval = batchInterval @@ -145,7 +145,7 @@ object KinesisWordCountASL extends Logging { // Map each word to a (word, 1) tuple so we can reduce by key to count the words val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) - + // Print the first 10 wordCounts wordCounts.print() @@ -210,14 +210,14 @@ object KinesisWordProducerASL { val randomWords = List("spark", "you", "are", "my", "father") val totals = scala.collection.mutable.Map[String, Int]() - + // Create the low-level Kinesis Client from the AWS Java SDK. val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) kinesisClient.setEndpoint(endpoint) println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + s" $recordsPerSecond records per second and $wordsPerRecord words per record") - + // Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord for (i <- 1 to 10) { // Generate recordsPerSec records to put onto the stream @@ -255,8 +255,8 @@ object KinesisWordProducerASL { } } -/** - * Utility functions for Spark Streaming examples. +/** + * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ private[streaming] object StreamingExamples extends Logging { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 1c9b0c218ae18..83a4537559512 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. * - * @param checkpointInterval + * @param checkpointInterval * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) */ private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, + checkpointInterval: Duration, currentClock: Clock = new SystemClock()) extends Logging { - + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** - * Check if it's time to checkpoint based on the current time and the derived time + * Check if it's time to checkpoint based on the current time and the derived time * for the next checkpoint * * @return true if it's time to checkpoint diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 7dd8bfdc2a6db..1a8a4cecc1141 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -44,12 +44,12 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * https://github.com/awslabs/amazon-kinesis-client * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers to run within a + * Instances of this class will get shipped to the Spark Streaming Workers to run within a * Spark Executor. * * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing + * the KCL will throw errors. This usually requires deleting the backing * DynamoDB table with the same name this Kinesis application. * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -87,7 +87,7 @@ private[kinesis] class KinesisReceiver( */ /** - * workerId is used by the KCL should be based on the ip address of the actual Spark Worker + * workerId is used by the KCL should be based on the ip address of the actual Spark Worker * where this code runs (not the driver's IP address.) */ private var workerId: String = null @@ -121,7 +121,7 @@ private[kinesis] class KinesisReceiver( /* * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the * IRecordProcessor.processRecords() method. * We're using our custom KinesisRecordProcessor in this case. */ diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index f65e743c4e2a3..fe9e3a0c793e2 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -35,9 +35,9 @@ import com.amazonaws.services.kinesis.model.Record /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create * multiple Receivers. * * @param receiver Kinesis receiver @@ -69,14 +69,14 @@ private[kinesis] class KinesisRecordProcessor( * and Spark Streaming's Receiver.store(). * * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored + * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { /* - * Notes: + * Notes: * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the * internally-configured Spark serializer (kryo, etc). @@ -84,19 +84,19 @@ private[kinesis] class KinesisRecordProcessor( * ourselves from Spark's internal serialization strategy. * 3) For performance, the BlockGenerator is asynchronously queuing elements within its * memory before creating blocks. This prevents the small block scenario, but requires - * that you register callbacks to know when a block has been generated and stored + * that you register callbacks to know when a block has been generated and stored * (WAL is sufficient for storage) before can checkpoint back to the source. */ batch.foreach(record => receiver.store(record.getData().array())) - + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored + * Checkpoint the sequence number of the last record successfully processed/stored * in the batch. * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). * However, if the worker dies unexpectedly, a checkpoint may not happen. @@ -130,16 +130,16 @@ private[kinesis] class KinesisRecordProcessor( } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } /** * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards + * 1) the stream is resharding by splitting or merging adjacent shards * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason * (ShutdownReason.ZOMBIE) * * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE @@ -153,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor( * Checkpoint to indicate that all records from the shard have been drained and processed. * It's now OK to read from the new shards that resulted from a resharding event. */ - case ShutdownReason.TERMINATE => + case ShutdownReason.TERMINATE => KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) /* diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 7629128010193..094a63472eaab 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -23,15 +23,15 @@ class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( - Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), - Edge(0x2345L, 0x1234L, 1), - Edge(0x1234L, 0x5678L, 1), - Edge(0x1234L, 0x2345L, 1), + Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), + Edge(0x2345L, 0x1234L, 1), + Edge(0x1234L, 0x5678L, 1), + Edge(0x1234L, 0x2345L, 1), Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1) ) // to ascending order val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int]) - + for (i <- 0 until testEdges.length) { assert(sortedEdges(i) == testEdges(testEdges.length - i - 1)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 0588517a2de39..8d73593ab6375 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -191,7 +191,7 @@ private[streaming] class BlockGenerator( logError(message, t) listener.onError(message, t) } - + private def pushBlock(block: Block) { listener.onPushBlock(block.id, block.buffer) logInfo("Pushed block " + block.id) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 651b534ac1900..207d64d9414ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -62,7 +62,7 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI private[streaming] class BlockManagerBasedBlockHandler( blockManager: BlockManager, storageLevel: StorageLevel) extends ReceivedBlockHandler with Logging { - + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 021d2c95a4aad..cbc24aee4fa1e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -28,9 +28,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ - - - /** * Selenium tests for the Spark Web UI. */ @@ -197,4 +194,3 @@ class UISeleniumSuite } } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 0acf7068ef4a4..325ff7c74c39d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null @@ -359,7 +359,7 @@ object WriteAheadLogSuite { ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => manualClock.advance(500) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 4ca6c903fcf12..3d3a966960e9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging { * Add a resource to the list of distributed cache resources. This list can * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. - * Adds the LocalResource to the localResources HashMap passed in and saves + * Adds the LocalResource to the localResources HashMap passed in and saves * the stats of the resources to they can be sent to the executors and verified. * * @param fs FileSystem * @param conf Configuration * @param destPath path to the resource * @param localResources localResource hashMap to insert the resource into - * @param resourceType LocalResourceType + * @param resourceType LocalResourceType * @param link link presented in the distributed cache to the destination - * @param statCache cache to store the file/directory stats + * @param statCache cache to store the file/directory stats * @param appMasterOnly Whether to only add the resource to the app master */ def addResource( fs: FileSystem, conf: Configuration, - destPath: Path, + destPath: Path, localResources: HashMap[String, LocalResource], resourceType: LocalResourceType, link: String, @@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging { amJarRsrc.setSize(destStatus.getLen()) if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { - distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } else { - distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } } @@ -96,11 +96,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = sizes.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -113,11 +113,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = sizes.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { val stat = statCache.get(uri) match { case Some(existstat) => existstat - case None => + case None => val newStat = fs.getFileStatus(new Path(uri)) statCache.put(uri, newStat) newStat diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5e6531895c7ba..68d01c17ef720 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -144,9 +144,9 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } object YarnSparkHadoopUtil { - // Additional memory overhead + // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering - // the common cases. Memory overhead tends to grow with container size. + // the common cases. Memory overhead tends to grow with container size. val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 43a7334db874c..804dfecde7867 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -41,12 +41,12 @@ import org.apache.spark.SparkFunSuite class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { - override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { LocalResourceVisibility.PRIVATE } } - + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] @@ -61,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] val uri = new URI("/tmp/testing") - val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) @@ -78,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -101,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) // add another one and verify both there and order correct - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", statCache, false) val resource2 = localResources("link2") assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -117,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar val env2 = new HashMap[String, String]() distMgr.setDistFilesEnv(env2) val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") @@ -141,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) intercept[Exception] { - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache, false) } assert(localResources.get("link") === None) @@ -155,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, true) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -189,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) From 63a50be13d32b9e5f3aad8d1a6ba5362f17a252f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 00:48:49 -0700 Subject: [PATCH 108/254] [SPARK-3850] Trim trailing spaces for SQL. Author: Reynold Xin Closes #6535 from rxin/whitespace-sql and squashes the following commits: de50316 [Reynold Xin] [SPARK-3850] Trim trailing spaces for SQL. --- .../sql/catalyst/CatalystTypeConverters.scala | 2 +- .../catalyst/analysis/HiveTypeCoercion.scala | 4 ++-- .../sql/catalyst/expressions/SortOrder.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 16 ++++++++-------- .../sql/catalyst/expressions/arithmetic.scala | 2 +- .../catalyst/expressions/complexTypes.scala | 2 +- .../expressions/mathfuncs/binary.scala | 2 +- .../sql/catalyst/expressions/random.scala | 2 +- .../expressions/stringOperations.scala | 6 +++--- .../apache/spark/sql/types/StructType.scala | 2 +- .../ExpressionEvaluationSuite.scala | 4 ++-- .../optimizer/CombiningLimitsSuite.scala | 4 ++-- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 4 ++-- .../catalyst/optimizer/OptimizeInSuite.scala | 2 +- .../apache/spark/sql/types/DataTypeSuite.scala | 6 +++--- .../org/apache/spark/sql/GroupedData.scala | 2 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 4 ++-- .../sql/execution/GeneratedAggregate.scala | 8 ++++---- .../spark/sql/execution/basicOperators.scala | 2 +- .../sql/execution/stat/FrequentItems.scala | 6 +++--- .../sql/execution/stat/StatFunctions.scala | 2 +- .../scala/org/apache/spark/sql/functions.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 2 +- .../apache/spark/sql/jdbc/JDBCRelation.scala | 6 +++--- .../scala/org/apache/spark/sql/jdbc/jdbc.scala | 4 ++-- .../spark/sql/sources/SqlNewHadoopRDD.scala | 2 +- .../apache/spark/sql/DataFrameStatSuite.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/JDBCWriteSuite.scala | 12 ++++++------ .../apache/spark/sql/hive/client/package.scala | 2 +- .../sql/hive/execution/HiveTableScan.scala | 4 ++-- .../hive/execution/ScriptTransformation.scala | 18 +++++++++--------- .../org/apache/spark/sql/hive/hiveUdfs.scala | 10 +++++----- .../spark/sql/hive/CachedTableSuite.scala | 2 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 6 +++--- .../hive/execution/HiveTableScanSuite.scala | 8 ++++---- .../sql/hive/execution/HiveUdfSuite.scala | 2 +- 39 files changed, 88 insertions(+), 88 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 75a493b248f6e..1c0ddb5093d17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -233,7 +233,7 @@ object CatalystTypeConverters { case other => other } - /** + /** * Converts Catalyst types used internally in rows to standard Scala types * This method is slow, and for batch conversion you should be using converter * produced by createToScalaConverter. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 195418d6dfb1f..96d7b96e60ee9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -296,8 +296,8 @@ trait HiveTypeCoercion { object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - + case e if !e.childrenResolved => e + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 195eec8e5cdc4..99340a14c9ecc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -29,7 +29,7 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends Expression +case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 6c380d3084652..0266084a6d174 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -394,13 +394,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ * Combining PartitionLevel InputData * <-- null * Zero <-- Zero <-- null - * + * * <-- null <-- no data - * null <-- null <-- no data + * null <-- null <-- no data */ case class CombineSum(child: Expression) extends AggregateExpression { def this() = this(null) - + override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -616,7 +616,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) override def update(input: Row): Unit = { @@ -634,7 +634,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - + def this() = this(null, null) // Required for serialization. private val calcType = @@ -649,12 +649,12 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - + override def update(input: Row): Unit = { val result = expr.eval(input) - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present if(result != null) { sum.update(addFunction, input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 34c833b260dc0..f2299d5db6e9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -180,7 +180,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot case other => sys.error(s"Type $other does not support numeric operations") } - + override def eval(input: Row): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index e7cd7131a9e56..6398b8f9e4ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - + lazy val childTypes = children.map(_.dataType).distinct override lazy val resolved = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 890efc9f52ca3..01f62ba0442e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._ * @param f The math function. * @param name The short name of the function */ -abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => override def symbol: String = null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index de82c15680607..4f4f67a6e482c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -24,7 +24,7 @@ import org.apache.spark.util.random.XORShiftRandom /** * A Random distribution generating expression. - * TODO: This can be made generic to generate any type of random distribution, or any type of + * TODO: This can be made generic to generate any type of random distribution, or any type of * StructType. * * Since this expression is stateful, it cannot be a case object. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 83a44a12f0682..c4ef9c30907f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -133,7 +133,7 @@ trait CaseConversionExpression extends ExpectsInputTypes { * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" @@ -143,7 +143,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE * A function that converts the characters of a string to lowercase. */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" @@ -223,7 +223,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) @inline def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it + // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index a4f30c825befb..193c08a4d0df7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -265,7 +265,7 @@ object StructType { case _ => throw new SparkException(s"Failed to merge incompatible data types $left and $right") } - + private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { import scala.collection.breakOut fields.map(s => (s.name, s))(breakOut) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 10181366c2fcd..3f5a660f17e1d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1209,7 +1209,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } /** - * Used for testing math functions for DataFrames. + * Used for testing math functions for DataFrames. * @param c The DataFrame function * @param f The functions in scala.math * @param domain The set of values to run the function with @@ -1217,7 +1217,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { * @tparam T Generic type for primitives */ def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( - c: Expression => Expression, + c: Expression => Expression, f: T => T, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), expectNull: Boolean = false): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index a30052b38fc11..06c592f4905a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -71,7 +71,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("limits: combines two limits after ColumnPruning") { val originalQuery = testRelation @@ -79,7 +79,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(2) .select('a) .limit(5) - + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 5697c2272b8e8..ec3b2f1edfa05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -248,7 +248,7 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("Constant folding test: Fold In(v, list) into true or false") { var originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ff25470bf0946..17dc9124749e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -93,7 +93,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("column pruning for Project(ne, Limit)") { val originalQuery = testRelation @@ -109,7 +109,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + // After this line is unimplemented. test("simple push down") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 11b0859d3f066..1d433275fed2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -57,7 +57,7 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 543cdefc5293b..261c4fcad24aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -71,7 +71,7 @@ class DataTypeSuite extends SparkFunSuite { test("fieldsMap returns map of name to StructField") { val struct = StructType( - StructField("a", LongType) :: + StructField("a", LongType) :: StructField("b", FloatType) :: Nil) val mapped = StructType.fieldsMap(struct.fields) @@ -90,7 +90,7 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType(List()) val merged = left.merge(right) - + assert(merged === left) } @@ -133,7 +133,7 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType( StructField("b", LongType) :: Nil) - + intercept[SparkException] { left.merge(right) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index c4ceb0c173887..45b3e1bc627d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -249,7 +249,7 @@ class GroupedData protected[sql]( def mean(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Average) } - + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 423ecdff5804a..604f3124e23ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -106,7 +106,7 @@ private[r] object SQLUtils { dfCols.map { col => colToRBytes(col) - } + } } def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { @@ -121,7 +121,7 @@ private[r] object SQLUtils { val numRows = col.length val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - + SerDe.writeInt(dos, numRows) col.map { item => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 2ec7d4fbc92de..3e27c1bde2dfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -138,15 +138,15 @@ case class GeneratedAggregate( case UnscaledValue(e) => e case _ => expr } - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present val updateFunction = If( IsNotNull(actualExpr), Coalesce( Add( - Coalesce(currentSum :: zero :: Nil), + Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: zero :: Nil), currentSum) - + val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -155,7 +155,7 @@ case class GeneratedAggregate( } AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6cb67b4bbbb65..a30ade86441ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -65,7 +65,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * :: DeveloperApi :: * Sample the dataset. * @param lowerBound Lower-bound of the sampling probability (usually 0.0) - * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index fe8a81e3d0434..c41c21c0eeb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -62,7 +62,7 @@ private[sql] object FrequentItems extends Logging { } /** - * Finding frequent items for columns, possibly with false positives. Using the + * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. @@ -75,7 +75,7 @@ private[sql] object FrequentItems extends Logging { * @return A Local DataFrame with the Array of frequent items for each column. */ private[sql] def singlePassFreqItems( - df: DataFrame, + df: DataFrame, cols: Seq[String], support: Double): DataFrame = { require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") @@ -88,7 +88,7 @@ private[sql] object FrequentItems extends Logging { val index = originalSchema.fieldIndex(name) (name, originalSchema.fields(index).dataType) } - + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d22f5fd2d439c..b1a8204dd5f71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ private[sql] object StatFunctions extends Logging { - + /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6dc17bbb2e768..77327f2b84eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1299,7 +1299,7 @@ object functions { * @since 1.4.0 */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) - + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 2d8d950038e78..40b604d710dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -304,7 +304,7 @@ private[sql] class JDBCRDD( // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that // we don't have to potentially poke around in the Metadata once for every - // row. + // row. // Is there a better way to do this? I'd rather be using a type that // contains only the tags I define. abstract class JDBCConversion diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 09d6865457df6..30f9190d45bf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -54,7 +54,7 @@ private[sql] object JDBCRelation { if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions + val stride: Long = (partitioning.upperBound / numPartitions - partitioning.lowerBound / numPartitions) var i: Int = 0 var currentValue: Long = partitioning.lowerBound @@ -140,10 +140,10 @@ private[sql] case class JDBCRelation( filters, parts) } - + override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) - } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index f21dd29aca37f..dd8aaf6474895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -240,10 +240,10 @@ package object jdbc { } } } - + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala index a74a98631da35..ebad0c1564ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -216,7 +216,7 @@ private[sql] class SqlNewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index add0fd58e28c8..78de89f0b9f39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends SparkFunSuite { - + val sqlCtx = TestSQLContext def toLetter(i: Int): String = (i + 97).toChar.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index af279007c587e..e20c66cb2f1d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -68,7 +68,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - + sql( s""" |CREATE TEMPORARY TABLE fetchtwo @@ -76,7 +76,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', | fetchSize '2') """.stripMargin.replaceAll("\n", " ")) - + sql( s""" |CREATE TEMPORARY TABLE parts @@ -209,7 +209,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(ids(1) === 2) assert(ids(2) === 3) } - + test("SELECT second field when fetchSize is two") { val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 3cd987b0b3383..2de8c1a6098e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -36,12 +36,12 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("user", "testUser") properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() - + conn1 = DriverManager.getConnection(url1, properties) conn1.prepareStatement("create schema test").executeUpdate() conn1.prepareStatement("drop table if exists test.people").executeUpdate() @@ -53,20 +53,20 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { conn1.prepareStatement( "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - + TestSQLContext.sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - + TestSQLContext.sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll("\n", " ")) } after { @@ -152,5 +152,5 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) - } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 7db9200d47440..410d9881ac214 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -29,5 +29,5 @@ package object client { case object v13 extends HiveVersion("0.13.1", false) } // scalastyle:on - + } \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 62dc4167b78dd..11ee5503146b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -63,7 +63,7 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } - // Create a local copy of hiveconf,so that scan specific modifications should not impact + // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient private[this] val hiveExtraConf = new HiveConf(context.hiveconf) @@ -72,7 +72,7 @@ case class HiveTableScan( addColumnMetadataToConf(hiveExtraConf) @transient - private[this] val hadoopReader = + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 6f27a8626fc1e..fd623370cc407 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -62,7 +62,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val reader = new BufferedReader(new InputStreamReader(inputStream)) - + val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { @@ -95,7 +95,7 @@ case class ScriptTransformation( val raw = outputSerde.deserialize(writable) val dataList = outputSoi.getStructFieldsDataAsList(raw) val fieldList = outputSoi.getAllStructFieldRefs() - + var i = 0 dataList.foreach( element => { if (element == null) { @@ -117,7 +117,7 @@ case class ScriptTransformation( if (!hasNext) { throw new NoSuchElementException } - + if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() @@ -192,7 +192,7 @@ case class HiveScriptIOSchema ( val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - + def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) @@ -206,13 +206,13 @@ case class HiveScriptIOSchema ( } def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name case _ => null } - + val columnTypes = attrs.map { case aref: AttributeReference => aref.dataType case e: NamedExpression => e.dataType @@ -221,7 +221,7 @@ case class HiveScriptIOSchema ( (columns, columnTypes) } - + def initSerDe(serdeClassName: String, columns: Seq[String], columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { @@ -240,7 +240,7 @@ case class HiveScriptIOSchema ( (kv._1.split("'")(1), kv._2.split("'")(1)) }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - + val properties = new Properties() properties.putAll(propsMap) serde.initialize(null, properties) @@ -261,7 +261,7 @@ case class HiveScriptIOSchema ( null } } - + def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { if (outputSerde != null) { outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bb116e3ab7de7..64a49c83cbad1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -555,12 +555,12 @@ private[hive] case class HiveUdafFunction( } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - + private val inspectors = exprs.map(toInspector).toArray - - private val function = { + + private val function = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) + resolver.getEvaluator(parameterInfo) } private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @@ -575,7 +575,7 @@ private[hive] case class HiveUdafFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) - + def update(input: Row): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 945596db80326..39d315aaeab57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -57,7 +57,7 @@ class CachedTableSuite extends QueryTest { checkAnswer( sql("SELECT * FROM src s"), preCacheResults) - + uncacheTable("src") assertCached(sql("SELECT * FROM src"), 0) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 9cc4685499f19..aa5dbe2db6903 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -240,7 +240,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { checkAnswer(sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) - + // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 446a2f2d646e1..7eb4842726665 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version - * of hive from maven central. These tests are simple in that they are mostly just testing to make - * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity + * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * of hive from maven central. These tests are simple in that they are mostly just testing to make + * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ class VersionsSuite extends SparkFunSuite with Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 0ba4d11478211..2209fc2f30a3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -61,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() TestHive.sql("drop table tb") } - + test("Spark-4077: timestamp query for null value") { TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") TestHive.sql( @@ -71,11 +71,11 @@ class HiveTableScanSuite extends HiveComparisonTest { FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' """.stripMargin) - val location = + val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 7f49eac490572..ce5985888f540 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -101,7 +101,7 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") TestHive.reset() } - + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) From 4b5f12bac939a2f47a3a61365b5325d849b7b51f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 01:37:56 -0700 Subject: [PATCH 109/254] [SPARK-7979] Enforce structural type checker. Author: Reynold Xin Closes #6536 from rxin/structural-type-checker and squashes the following commits: f833151 [Reynold Xin] Fixed compilation. 633f9a1 [Reynold Xin] Fixed typo. d1fa804 [Reynold Xin] [SPARK-7979] Enforce structural type checker. --- .../org/apache/spark/util/random/XORShiftRandomSuite.scala | 2 +- .../apache/spark/examples/mllib/DecisionTreeRunner.scala | 6 +++++- graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala | 4 +++- .../org/apache/spark/ml/classification/OneVsRest.scala | 2 ++ scalastyle-config.xml | 3 +++ 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 6ca484ccd0c06..d26667bf720cf 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -28,7 +28,7 @@ import scala.language.reflectiveCalls class XORShiftRandomSuite extends SparkFunSuite with Matchers { - def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { + private def fixture = new { val seed = 1L val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b0613632c9946..3381941673db8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,7 +22,6 @@ import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -354,7 +353,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. + * + * This is just for demo purpose. In general, don't copy this code because it is NOT efficient + * due to the use of structural types, which leads to one reflection call per record. */ + // scalastyle:off structural.type private[mllib] def meanSquaredError( model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { @@ -363,4 +366,5 @@ object DecisionTreeRunner { err * err }.mean() } + // scalastyle:on structural.type } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index cc70b396a8dd4..4611a3ace219b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -41,14 +41,16 @@ abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } + // scalastyle:on structural.type override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { - p.next._2.iterator.map(_.copy()) + p.next()._2.iterator.map(_.copy()) } else { Iterator.empty } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index b8c7f3c5bc3b9..7b726da388075 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -37,11 +37,13 @@ import org.apache.spark.storage.StorageLevel */ private[ml] trait OneVsRestParams extends PredictorParams { + // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F type M <: ClassificationModel[F, M] type E <: Classifier[F, E, M] } + // scalastyle:on structural.type /** * param for the base binary classifier that we reduce multiclass classification into. diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 3a984222167b0..75ef1e964b5ac 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -114,6 +114,9 @@ + + + From d1d2def2f5f91e86f340656421170d1097f14854 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 31 May 2015 11:18:12 -0700 Subject: [PATCH 110/254] [MINOR] Add license for dagre-d3 and graphlib-dot Add license for dagre-d3 and graphlib-dot Author: zsxwing Closes #6539 from zsxwing/LICENSE and squashes the following commits: 82b0475 [zsxwing] Add license for dagre-d3 and graphlib-dot --- LICENSE | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/LICENSE b/LICENSE index 9d1b00beff748..d0cd0dcb4bdb7 100644 --- a/LICENSE +++ b/LICENSE @@ -853,6 +853,52 @@ and Vis.js may be distributed under either license. +======================================================================== +For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +======================================================================== +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +======================================================================== +For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): +======================================================================== +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + ======================================================================== BSD-style licenses ======================================================================== From e1067d0ad1c32c678c23d76d7653b51770795831 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 11:35:30 -0700 Subject: [PATCH 111/254] [SPARK-3850] Trim trailing spaces for MLlib. Author: Reynold Xin Closes #6534 from rxin/whitespace-mllib and squashes the following commits: 38926e3 [Reynold Xin] [SPARK-3850] Trim trailing spaces for MLlib. --- .../spark/ml/feature/StandardScaler.scala | 10 +-- .../ml/regression/LinearRegression.scala | 2 +- .../mllib/api/python/PythonMLLibAPI.scala | 4 +- .../mllib/clustering/GaussianMixture.scala | 86 +++++++++---------- .../clustering/GaussianMixtureModel.scala | 22 ++--- .../clustering/PowerIterationClustering.scala | 8 +- .../apache/spark/mllib/feature/Word2Vec.scala | 50 +++++------ .../org/apache/spark/mllib/linalg/BLAS.scala | 8 +- .../linalg/EigenValueDecomposition.scala | 2 +- .../BinaryClassificationPMMLModelExport.scala | 10 +-- .../mllib/pmml/export/PMMLModelExport.scala | 4 +- .../pmml/export/PMMLModelExportFactory.scala | 8 +- .../spark/mllib/random/RandomRDDs.scala | 6 +- .../spark/mllib/recommendation/ALS.scala | 2 +- .../mllib/regression/IsotonicRegression.scala | 10 +-- .../distribution/MultivariateGaussian.scala | 54 ++++++------ .../mllib/tree/GradientBoostedTrees.scala | 2 +- .../spark/mllib/tree/RandomForest.scala | 2 +- .../org/apache/spark/mllib/util/MLUtils.scala | 2 +- .../evaluation/RegressionEvaluatorSuite.scala | 2 +- .../spark/ml/feature/BinarizerSuite.scala | 2 +- .../clustering/GaussianMixtureSuite.scala | 4 +- .../PowerIterationClusteringSuite.scala | 2 +- .../apache/spark/mllib/linalg/BLASSuite.scala | 34 ++++---- .../spark/mllib/linalg/VectorsSuite.scala | 6 +- ...ryClassificationPMMLModelExportSuite.scala | 8 +- .../export/KMeansPMMLModelExportSuite.scala | 2 +- .../export/PMMLModelExportFactorySuite.scala | 10 +-- .../MultivariateGaussianSuite.scala | 14 +-- .../spark/mllib/util/MLUtilsSuite.scala | 2 +- 30 files changed, 189 insertions(+), 189 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index fdd2494fc87a6..b0fd06d84fdb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -35,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** * Centers the data with mean before scaling. - * It will build a dense output, so this does not work on sparse input + * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") - + /** * Scales the data to unit standard deviation. * Default: true @@ -68,13 +68,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + /** @group setParam */ def setWithMean(value: Boolean): this.type = set(withMean, value) - + /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - + override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7c40db1a40040..fe2a71a331694 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -321,7 +321,7 @@ private class LeastSquaresAggregator( } (weightsArray, -sum + labelMean / labelStd, weightsArray.length) } - + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) private val gradientSumArray = Array.ofDim[Double](dim) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 65f30fdba7393..16f3131796709 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -399,7 +399,7 @@ private[python] class PythonMLLibAPI extends Serializable { val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) - } + } val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data).map(Vectors.dense) } @@ -494,7 +494,7 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index e9a23e40cc790..70b0e40948e51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -36,11 +36,11 @@ import org.apache.spark.util.Utils * independent Gaussian distributions with associated "mixing" weights * specifying each's contribution to the composite. * - * Given a set of sample points, this class will maximize the log-likelihood - * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed - * to find a global optimum. + * to find a global optimum. * * Note: For high-dimensional data (with many features), this algorithm may perform poorly. * This is due to high-dimensional data (a) making it difficult to cluster at all (based @@ -53,24 +53,24 @@ import org.apache.spark.util.Utils */ @Experimental class GaussianMixture private ( - private var k: Int, - private var convergenceTol: Double, + private var k: Int, + private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - + /** * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) - + // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 - - // an initializing GMM can be provided rather than using the + + // an initializing GMM can be provided rather than using the // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - + /** Set the initial GMM starting point, bypassing the random initialization. * You must call setK() prior to calling this method, and the condition * (model.k == this.k) must be met; failure will result in an IllegalArgumentException @@ -83,37 +83,37 @@ class GaussianMixture private ( } this } - + /** Return the user supplied initial GMM, if supplied */ def getInitialModel: Option[GaussianMixtureModel] = initialModel - + /** Set the number of Gaussians in the mixture model. Default: 2 */ def setK(k: Int): this.type = { this.k = k this } - + /** Return the number of Gaussians in the mixture model */ def getK: Int = k - + /** Set the maximum number of iterations to run. Default: 100 */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - + /** Return the maximum number of iterations to run */ def getMaxIterations: Int = maxIterations - + /** - * Set the largest change in log-likelihood at which convergence is + * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this } - + /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. @@ -132,41 +132,41 @@ class GaussianMixture private ( /** Perform expectation maximization */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext - + // we will operate on the data as breeze data val breezeData = data.map(_.toBreeze).cache() - + // Get length of the input vectors val d = breezeData.first().length - + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and // diagonal covariance matrices using component variances - // derived from the samples + // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - + case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) - (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) - new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) } } - - var llh = Double.MinValue // current log-likelihood + + var llh = Double.MinValue // current log-likelihood var llhp = 0.0 // previous log-likelihood - + var iter = 0 while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) { // create and broadcast curried cluster contribution function val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) - + // aggregate the cluster contribution for all sample points val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) - + // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum @@ -179,22 +179,22 @@ class GaussianMixture private ( gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) i = i + 1 } - + llhp = llh // current becomes previous llh = sums.logLikelihood // this is the freshly computed log-likelihood iter += 1 - } - + } + new GaussianMixtureModel(weights, gaussians) } - + /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) - v / x.length.toDouble + v / x.length.toDouble } - + /** * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) @@ -210,14 +210,14 @@ class GaussianMixture private ( // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { - new ExpectationSum(0.0, Array.fill(k)(0.0), + new ExpectationSum(0.0, Array.fill(k)(0.0), Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d))) } - + // compute cluster contributions for each input point // (U, T) => U for aggregation def add( - weights: Array[Double], + weights: Array[Double], dists: Array[MultivariateGaussian]) (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { @@ -235,7 +235,7 @@ private object ExpectationSum { i = i + 1 } sums - } + } } // Aggregation class for partial expectation results @@ -244,9 +244,9 @@ private class ExpectationSum( val weights: Array[Double], val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { - + val k = weights.length - + def +=(x: ExpectationSum): ExpectationSum = { var i = 0 while (i < k) { @@ -257,5 +257,5 @@ private class ExpectationSum( } logLikelihood += x.logLikelihood this - } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 86353aed81156..5fc2cb1b62d33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: * - * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points - * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are - * the respective mean and covariance for each Gaussian distribution i=1..k. - * + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is * the weight for Gaussian i, and weights.sum == 1 * @param gaussians Array of MultivariateGaussian where gaussians(i) represents @@ -45,9 +45,9 @@ import org.apache.spark.sql.{SQLContext, Row} */ @Experimental class GaussianMixtureModel( - val weights: Array[Double], + val weights: Array[Double], val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ - + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") override protected def formatVersion = "1.0" @@ -64,20 +64,20 @@ class GaussianMixtureModel( val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - + /** * Given the input vectors, return the membership value of each vector - * to all mixture components. + * to all mixture components. */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) val bcWeights = sc.broadcast(weights) - points.map { x => + points.map { x => computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } - + /** * Compute the partial assignments for each vector */ @@ -89,7 +89,7 @@ class GaussianMixtureModel( val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) } - val pSum = p.sum + val pSum = p.sum for (i <- 0 until k) { p(i) /= pSum } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 1ed01c9d8ba0b..e7a243f854e33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + * initMode: "random"}. */ def this() = this(k = 2, maxIterations = 100, initMode = "random") @@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging { /** * Generates random vertex properties (v0) to start power iteration. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm @@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging { * Generates the degree vector as the vertex properties (v0) to start power iteration. * It is not exactly the node degrees but just the normalized sum similarities. Call it * as degree vector because it is used in the PIC paper. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ @@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging { val v0 = g.vertices.mapValues(_ / sum) GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) } - + /** * Runs power iteration. * @param g input graph with edges representing the normalized affinity matrix (W) and vertices diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 466ae95859b82..51546d41c36a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.sql.{SQLContext, Row} /** - * Entry in vocabulary + * Entry in vocabulary */ private case class VocabWord( var word: String, @@ -56,18 +56,18 @@ private case class VocabWord( * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus - * and then learns vector representation of words in the vocabulary. - * The vector representation can be used as features in + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in * natural language processing and machine learning algorithms. - * - * We used skip-gram model in our implementation and hierarchical softmax + * + * We used skip-gram model in our implementation and hierarchical softmax * method to train the model. The variable names in the implementation * matches the original C implementation. * - * For original C implementation, see https://code.google.com/p/word2vec/ - * For research papers, see + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see * Efficient Estimation of Word Representations in Vector Space - * and + * and * Distributed Representations of Words and Phrases and their Compositionality. */ @Experimental @@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 - + /** * Sets vector size (default: 100). */ @@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging { this } - /** - * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). */ def setMinCount(minCount: Int): this.type = { this.minCount = minCount this } - + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -150,13 +150,13 @@ class Word2Vec extends Serializable with Logging { .map(x => VocabWord( x._1, x._2, - new Array[Int](MAX_CODE_LENGTH), - new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), 0)) .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) - + vocabSize = vocab.length require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " + "the setting of minCount, which could be large enough to remove all your words in sentences.") @@ -198,8 +198,8 @@ class Word2Vec extends Serializable with Logging { } var pos1 = vocabSize - 1 var pos2 = vocabSize - - var min1i = 0 + + var min1i = 0 var min2i = 0 a = 0 @@ -268,15 +268,15 @@ class Word2Vec extends Serializable with Logging { val words = dataset.flatMap(x => x) learnVocab(words) - + createBinaryTree() - + val sc = dataset.context val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => new Iterator[Array[Int]] { def hasNext: Boolean = iter.hasNext @@ -297,7 +297,7 @@ class Word2Vec extends Serializable with Logging { } } } - + val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) @@ -402,7 +402,7 @@ class Word2Vec extends Serializable with Logging { } } newSentences.unpersist() - + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] var i = 0 while (i < vocabSize) { @@ -480,7 +480,7 @@ class Word2VecModel private[mllib] ( /** * Transforms a word to its vector representation - * @param word a word + * @param word a word * @return vector representation of word */ def transform(word: String): Vector = { @@ -495,7 +495,7 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of a word * @param word a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { @@ -506,7 +506,7 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index ec38529cf8fae..557119f7b1cd1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } - + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. @@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 } i += 1 - } + } } private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { @@ -505,7 +505,7 @@ private[spark] object BLAS extends Serializable with Logging { nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } - + /** * y := alpha * A * x + beta * y * For `DenseMatrix` A and `SparseVector` x. @@ -557,7 +557,7 @@ private[spark] object BLAS extends Serializable with Logging { } } } - + /** * y := alpha * A * x + beta * y * For `SparseMatrix` A and `SparseVector` x. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 866936aa4f118..ae3ba3099c878 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition { require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, s"k = $k and/or n = $n are too large to compute an eigendecomposition") - + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 34b447584e521..622b53a252ac5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, + model : GeneralizedLinearModel, description : String, normalizationMethod : RegressionNormalizationMethodType, - threshold: Double) + threshold: Double) extends PMMLModelExport { populateBinaryClassificationPMML() @@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( .withUsageType(FieldUsageType.ACTIVE)) regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } - + // add target field val targetField = FieldName.create("target") dataDictionary @@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport( miningSchema .withMiningFields(new MiningField(targetField) .withUsageType(FieldUsageType.TARGET)) - + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) - + pmml.setDataDictionary(dataDictionary) pmml.withModels(regressionModel) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index ebdeae50bb32f..c5fdecd3ca17f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -25,7 +25,7 @@ import scala.beans.BeanProperty import org.dmg.pmml.{Application, Header, PMML, Timestamp} private[mllib] trait PMMLModelExport { - + /** * Holder of the exported model in PMML format */ @@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport { val pmml: PMML = new PMML setHeader(pmml) - + private def setHeader(pmml: PMML): Unit = { val version = getClass.getPackage.getImplementationVersion val app = new Application().withName("Apache Spark MLlib").withVersion(version) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index c16e83d6a067d..29bd689e1185a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel private[mllib] object PMMLModelExportFactory { - + /** - * Factory object to help creating the necessary PMMLModelExport implementation + * Factory object to help creating the necessary PMMLModelExport implementation * taking as input the machine learning model (for example KMeansModel). */ def createPMMLModelExport(model: Any): PMMLModelExport = { @@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory { new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") case svm: SVMModel => new BinaryClassificationPMMLModelExport( - svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm, "linear SVM", RegressionNormalizationMethodType.NONE, svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => if (logistic.numClasses == 2) { @@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory { "PMML Export not supported for model: " + model.getClass.getName) } } - + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 7db5a14fd45a5..174d5e0f6c9f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -234,7 +234,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution - * @param scale scale parameter (> 0) for the gamma distribution + * @param scale scale parameter (> 0) for the gamma distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -293,7 +293,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param mean mean for the log normal distribution - * @param std standard deviation for the log normal distribution + * @param std standard deviation for the log normal distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -671,7 +671,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution. - * @param scale scale parameter (> 0) for the gamma distribution. + * @param scale scale parameter (> 0) for the gamma distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index dddefe1944e9d..93290e6508529 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -175,7 +175,7 @@ class ALS private ( /** * :: DeveloperApi :: * Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default - * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. + * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. * `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement, * at the cost of speed. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 96e50faca2b19..f3b46c75c05f3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -170,15 +170,15 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { case class Data(boundary: Double, prediction: Double) def save( - sc: SparkContext, - path: String, - boundaries: Array[Double], - predictions: Array[Double], + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], isotonic: Boolean): Unit = { val sqlContext = new SQLContext(sc) val metadata = compact(render( - ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("isotonic" -> isotonic))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cd6add9d60b0d..cf51b24ff777f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]]) - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ @DeveloperApi class MultivariateGaussian ( - val mu: Vector, + val mu: Vector, val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - + private val breezeMu = mu.toBreeze.toDenseVector - + /** * private[mllib] constructor - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) } - + /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t - * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - + /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } - + /** Returns density of this multivariate Gaussian at given point, x */ private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 } - + /** * Calculate distribution dependent components used for the density function: * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. - * - * We here compute distribution-fixed parts + * + * We here compute distribution-fixed parts * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and * D^(-1/2)^ * U, where sigma = U * D * U.t - * + * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite, * we can use the eigendecomposition. We also do not compute the inverse directly; noting - * that - * + * that + * * sigma = U * D * U.t - * inv(Sigma) = U * inv(D) * U.t + * inv(Sigma) = U * inv(D) * U.t * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) - * + * * and thus - * + * * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ - * - * To guard against singular covariance matrices, this method computes both the + * + * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t - + // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length - + try { // log(pseudo-determinant) is sum of the logs of all non-zero singular values val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum - - // calculate the root-pseudo-inverse of the diagonal matrix of singular values + + // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index e3ddc7053693c..a835f96d5d0e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging { logInfo(s"$timer") if (persistedInput) input.unpersist() - + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 99d0e3cf2fd6d..069959976a188 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -474,7 +474,7 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some(SamplingUtils.reservoirSampleAndCount(Range(0, + Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 681f4c618d302..541f3288b6c43 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -265,7 +265,7 @@ object MLUtils { } Vectors.fromBreeze(vector1) } - + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 9da0618abd23c..36a1ac6b7996d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -38,7 +38,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext val dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) - + /** * Using the following R code to load the data, train the model and evaluate metrics. * diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index d4631518e0f5b..7953bd0417191 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -47,7 +47,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binarize continuous features with setter") { val threshold: Double = 0.2 - val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) val dataFrame: DataFrame = sqlContext.createDataFrame( data.zip(thresholdBinarized)).toDF("feature", "expected") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index a3b085e441491..b218d72f1268a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -46,7 +46,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } } - + test("two clusters") { val data = sc.parallelize(GaussianTestData.data) @@ -62,7 +62,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 3903712879928..19e65f1b53ab5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -56,7 +56,7 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) - + val model2 = new PowerIterationClustering() .setK(2) .setInitializationMode("degree") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index bcc2e657f3fd4..b0f3f71113c57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -139,7 +139,7 @@ class BLASSuite extends SparkFunSuite { syr(alpha, x, dA) assert(dA ~== expected absTol 1e-15) - + val dB = new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0)) @@ -148,7 +148,7 @@ class BLASSuite extends SparkFunSuite { syr(alpha, x, dB) } } - + val dC = new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8)) @@ -157,7 +157,7 @@ class BLASSuite extends SparkFunSuite { syr(alpha, x, dC) } } - + val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5)) withClue("Size of vector must match the rank of matrix") { @@ -255,13 +255,13 @@ class BLASSuite extends SparkFunSuite { val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) - + val dA2 = new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true) val sA2 = new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0), true) - + val dx = new DenseVector(Array(1.0, 2.0, 3.0)) val sx = dx.toSparse val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) @@ -270,7 +270,7 @@ class BLASSuite extends SparkFunSuite { assert(sA.multiply(dx) ~== expected absTol 1e-15) assert(dA.multiply(sx) ~== expected absTol 1e-15) assert(sA.multiply(sx) ~== expected absTol 1e-15) - + val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy @@ -287,7 +287,7 @@ class BLASSuite extends SparkFunSuite { val y14 = y1.copy val y15 = y1.copy val y16 = y1.copy - + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -295,42 +295,42 @@ class BLASSuite extends SparkFunSuite { gemv(1.0, sA, dx, 2.0, y2) gemv(1.0, dA, sx, 2.0, y3) gemv(1.0, sA, sx, 2.0, y4) - + gemv(1.0, dA2, dx, 2.0, y5) gemv(1.0, sA2, dx, 2.0, y6) gemv(1.0, dA2, sx, 2.0, y7) gemv(1.0, sA2, sx, 2.0, y8) - + gemv(2.0, dA, dx, 2.0, y9) gemv(2.0, sA, dx, 2.0, y10) gemv(2.0, dA, sx, 2.0, y11) gemv(2.0, sA, sx, 2.0, y12) - + gemv(2.0, dA2, dx, 2.0, y13) gemv(2.0, sA2, dx, 2.0, y14) gemv(2.0, dA2, sx, 2.0, y15) gemv(2.0, sA2, sx, 2.0, y16) - + assert(y1 ~== expected2 absTol 1e-15) assert(y2 ~== expected2 absTol 1e-15) assert(y3 ~== expected2 absTol 1e-15) assert(y4 ~== expected2 absTol 1e-15) - + assert(y5 ~== expected2 absTol 1e-15) assert(y6 ~== expected2 absTol 1e-15) assert(y7 ~== expected2 absTol 1e-15) assert(y8 ~== expected2 absTol 1e-15) - + assert(y9 ~== expected3 absTol 1e-15) assert(y10 ~== expected3 absTol 1e-15) assert(y11 ~== expected3 absTol 1e-15) assert(y12 ~== expected3 absTol 1e-15) - + assert(y13 ~== expected3 absTol 1e-15) assert(y14 ~== expected3 absTol 1e-15) assert(y15 ~== expected3 absTol 1e-15) assert(y16 ~== expected3 absTol 1e-15) - + withClue("columns of A don't match the rows of B") { intercept[Exception] { gemv(1.0, dA.transpose, dx, 2.0, y1) @@ -345,12 +345,12 @@ class BLASSuite extends SparkFunSuite { gemv(1.0, sA.transpose, sx, 2.0, y1) } } - + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - + val dATT = dAT.transpose val sATT = sAT.transpose diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index c6d29dcdb0f2b..c4ae0a16f7c04 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -214,13 +214,13 @@ class VectorsSuite extends SparkFunSuite { val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) - // SparseVector vs. SparseVector - assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) + // SparseVector vs. SparseVector + assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. SparseVector assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. DenseVector assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8) - } + } } test("foreachActive") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala index 7a724fc78b1d9..4c6e76e47419b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { // ensure logistic regression has normalization method set to LOGIT assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) } - + test("linear SVM PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - + // assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) val pmml = svmModelExport.getPmml @@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { // ensure linear SVM has normalization method set to NONE assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index a1a683559a54c..b3f9750afa730 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends SparkFunSuite { val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index 0d194005a30b2..af49450961750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -60,25 +60,25 @@ class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + "when passing a LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) - + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) } - + test("PMMLModelExportFactory throw IllegalArgumentException " + "when passing a Multinomial Logistic Regression") { /** 3 classes, 2 features */ val multiclassLogisticRegressionModel = new LogisticRegressionModel( - weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) - + intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 703b623536315..aa60deb665aeb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -26,39 +26,39 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) - + val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } - + test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } - + test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 87b3661f77944..734b7babec7be 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -62,7 +62,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") - if (m > 10) { + if (m > 10) { val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) val norm4 = Vectors.norm(v4, 2.0) From 0674700303da3e4737d73f5fabd2a925ec712f63 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 31 May 2015 11:51:49 -0700 Subject: [PATCH 112/254] [SPARK-7949] [MLLIB] [DOC] update document with some missing save/load add save load for examples: KMeansModel PowerIterationClusteringModel Word2VecModel IsotonicRegressionModel Author: Yuhao Yang Closes #6498 from hhbyyh/docSaveLoad and squashes the following commits: 7f9f06d [Yuhao Yang] add missing imports c604cad [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into docSaveLoad 1dd77cc [Yuhao Yang] update document with some missing save/load --- docs/mllib-clustering.md | 28 ++++++++++++++++++++++++---- docs/mllib-feature-extraction.md | 6 +++++- docs/mllib-isotonic-regression.md | 10 +++++++++- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index f41ca70952eb7..dac22f736e8cb 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -47,7 +47,7 @@ Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasin optimal *k* is usually one where there is an "elbow" in the WSSSE graph. {% highlight scala %} -import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.clustering.{KMeans, KMeansModel} import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -62,6 +62,10 @@ val clusters = KMeans.train(parsedData, numClusters, numIterations) // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) println("Within Set Sum of Squared Errors = " + WSSSE) + +// Save and load model +clusters.save(sc, "myModelPath") +val sameModel = KMeansModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -110,6 +114,10 @@ public class KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors double WSSSE = clusters.computeCost(parsedData.rdd()); System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + + // Save and load model + clusters.save(sc.sc(), "myModelPath"); + KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -124,7 +132,7 @@ Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by in fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. {% highlight python %} -from pyspark.mllib.clustering import KMeans +from pyspark.mllib.clustering import KMeans, KMeansModel from numpy import array from math import sqrt @@ -143,6 +151,10 @@ def error(point): WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y) print("Within Set Sum of Squared Error = " + str(WSSSE)) + +# Save and load model +clusters.save(sc, "myModelPath") +sameModel = KMeansModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -312,12 +324,12 @@ Calling `PowerIterationClustering.run` returns a which contains the computed clustering assignments. {% highlight scala %} -import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel} import org.apache.spark.mllib.linalg.Vectors val similarities: RDD[(Long, Long, Double)] = ... -val pic = new PowerIteartionClustering() +val pic = new PowerIterationClustering() .setK(3) .setMaxIterations(20) val model = pic.run(similarities) @@ -325,6 +337,10 @@ val model = pic.run(similarities) model.assignments.foreach { a => println(s"${a.id} -> ${a.cluster}") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") {% endhighlight %} A full example that produces the experiment described in the PIC paper can be found under @@ -360,6 +376,10 @@ PowerIterationClusteringModel model = pic.run(similarities); for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { System.out.println(a.id() + " -> " + a.cluster()); } + +// Save and load model +model.save(sc.sc(), "myModelPath"); +PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath"); {% endhighlight %} diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 1f6ad8b13d730..4fe470a8de810 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th import org.apache.spark._ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("text8").map(line => line.split(" ").toSeq) @@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index b521c2f27cd6e..5732bc4c7e79e 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f labels and real labels in the test set. {% highlight scala %} -import org.apache.spark.mllib.regression.IsotonicRegression +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") @@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point => // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() println("Mean Squared Error = " + meanSquaredError) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( ).rdd()).mean(); System.out.println("Mean Squared Error = " + meanSquaredError); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %} From 866652c903d06d1cb4356283e0741119d84dcc21 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 14:23:42 -0700 Subject: [PATCH 113/254] [SPARK-3850] Turn style checker on for trailing whitespaces. Author: Reynold Xin Closes #6541 from rxin/trailing-whitespace-on and squashes the following commits: f72ebe4 [Reynold Xin] [SPARK-3850] Turn style checker on for trailing whitespaces. --- scalastyle-config.xml | 3 +++ .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 2 +- .../org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 75ef1e964b5ac..f52b09551adde 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -50,6 +50,9 @@ */]]>
    + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 5d106c1ac2674..b624eaa201ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -43,7 +43,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in * MLlib's Statistics. * * @param col1 the name of the column diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4af31d482ce42..440b7c87b0da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -57,7 +57,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF sql( """ - |CREATE TEMPORARY FUNCTION udtf_count2 + |CREATE TEMPORARY FUNCTION udtf_count2 |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' """.stripMargin) } From 46576ab303e50c54c3bd464f8939953efe644574 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sun, 31 May 2015 15:01:21 -0700 Subject: [PATCH 114/254] [SPARK-7227] [SPARKR] Support fillna / dropna in R DataFrame. Author: Sun Rui Closes #6183 from sun-rui/SPARK-7227 and squashes the following commits: dd6f5b3 [Sun Rui] Rename readEnv() back to readMap(). Add alias na.omit() for dropna(). 41cf725 [Sun Rui] [SPARK-7227][SPARKR] Support fillna / dropna in R DataFrame. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 125 ++++++++++++++++++ R/pkg/R/generics.R | 18 +++ R/pkg/R/serialize.R | 10 +- R/pkg/inst/tests/test_sparkSQL.R | 109 +++++++++++++++ .../scala/org/apache/spark/api/r/SerDe.scala | 6 +- 6 files changed, 267 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 411126a377950..f9447f6c3288d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -19,9 +19,11 @@ exportMethods("arrange", "count", "describe", "distinct", + "dropna", "dtypes", "except", "explain", + "fillna", "filter", "first", "group_by", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index e79d324838fe3..0af5cb8881e35 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1429,3 +1429,128 @@ setMethod("describe", sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) dataFrame(sdf) }) + +#' dropna +#' +#' Returns a new DataFrame omitting rows with null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param how "any" or "all". +#' if "any", drop a row if it contains any nulls. +#' if "all", drop a row only if all its values are null. +#' if minNonNulls is specified, how is ignored. +#' @param minNonNulls If specified, drop rows that have less than +#' minNonNulls non-null values. +#' This overwrites the how parameter. +#' @param cols Optional list of column names to consider. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dropna(df) +#' } +setMethod("dropna", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + how <- match.arg(how) + if (is.null(cols)) { + cols <- columns(x) + } + if (is.null(minNonNulls)) { + minNonNulls <- if (how == "any") { length(cols) } else { 1 } + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- callJMethod(naFunctions, "drop", + as.integer(minNonNulls), listToSeq(as.list(cols))) + dataFrame(sdf) + }) + +#' @aliases dropna +#' @export +setMethod("na.omit", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(x, how, minNonNulls, cols) + }) + +#' fillna +#' +#' Replace null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param value Value to replace null values with. +#' Should be an integer, numeric, character or named list. +#' If the value is a named list, then cols is ignored and +#' value must be a mapping from column name (character) to +#' replacement value. The replacement value must be an +#' integer, numeric or character. +#' @param cols optional list of column names to consider. +#' Columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' fillna(df, 1) +#' fillna(df, list("age" = 20, "name" = "unknown")) +#' } +setMethod("fillna", + signature(x = "DataFrame"), + function(x, value, cols = NULL) { + if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { + stop("value should be an integer, numeric, charactor or named list.") + } + + if (class(value) == "list") { + # Check column names in the named list + colNames <- names(value) + if (length(colNames) == 0 || !all(colNames != "")) { + stop("value should be an a named list with each name being a column name.") + } + + # Convert to the named list to an environment to be passed to JVM + valueMap <- new.env() + for (col in colNames) { + # Check each item in the named list is of valid type + v <- value[[col]] + if (!(class(v) %in% c("integer", "numeric", "character"))) { + stop("Each item in value should be an integer, numeric or charactor.") + } + valueMap[[col]] <- v + } + + # When value is a named list, caller is expected not to pass in cols + if (!is.null(cols)) { + warning("When value is a named list, cols is ignored!") + cols <- NULL + } + + value <- valueMap + } else if (is.integer(value)) { + # Cast an integer to a numeric + value <- as.numeric(value) + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- if (length(cols) == 0) { + callJMethod(naFunctions, "fill", value) + } else { + callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + } + dataFrame(sdf) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 1f4fc6adaca8d..12e09176c9f92 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -396,6 +396,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname nafunctions +#' @export +setGeneric("dropna", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") + }) + +#' @rdname nafunctions +#' @export +setGeneric("na.omit", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") + }) + #' @rdname schema #' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) @@ -408,6 +422,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @export setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname nafunctions +#' @export +setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index c53d0a961016f..2081786e6f833 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -160,6 +160,14 @@ writeList <- function(con, arr) { } } +# Used to pass arrays where the elements can be of different types +writeGenericList <- function(con, list) { + writeInt(con, length(list)) + for (elem in list) { + writeObject(con, elem) + } +} + # Used to pass in hash maps required on Java side. writeEnv <- function(con, env) { len <- length(env) @@ -168,7 +176,7 @@ writeEnv <- function(con, env) { if (len > 0) { writeList(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeList(con, as.list(vals)) + writeGenericList(con, as.list(vals)) } } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1857e636e8577..d2d82e791e876 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -32,6 +32,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") writeLines(mockLines, jsonPath) +# For test nafunctions, like dropna(), fillna(),... +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -765,5 +774,105 @@ test_that("describe() on a DataFrame", { expect_equal(collect(stats)[5, "age"], "30") }) +test_that("dropna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. + expect_true(identical(expected$age, actual$age)) + expect_true(identical(expected$height, actual$height)) + expect_true(identical(expected$name, actual$name)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_true(identical(expected, actual)) +}) + +test_that("fillna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_true(identical(expected, actual)) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_true(identical(expected, actual)) +}) + unlink(parquetPath) unlink(jsonPath) +unlink(jsonPathNa) diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 371dfe454d1a2..f8e3f1a79082e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -157,9 +157,11 @@ private[spark] object SerDe { val keysLen = readInt(in) val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - val valuesType = readObjectType(in) val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) mapAsJavaMap(keys.zip(values).toMap) } else { new java.util.HashMap[Object, Object]() From 9126ea4d1c5c468f3662e76e0231b4d64c7c9699 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 31 May 2015 15:17:05 -0700 Subject: [PATCH 115/254] [MINOR] Enable PySpark SQL readerwriter and window tests PySpark SQL's `readerwriter` and `window` doctests weren't being run by our test runner script; this patch re-enables them. Author: Josh Rosen Closes #6542 from JoshRosen/enable-more-pyspark-sql-tests and squashes the following commits: 9f46ce4 [Josh Rosen] Enable PySpark SQL readerwriter and window tests. --- python/run-tests | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/run-tests b/python/run-tests index fcfb49556b7cf..17dda3eadac0c 100755 --- a/python/run-tests +++ b/python/run-tests @@ -76,6 +76,8 @@ function run_sql_tests() { run_test "pyspark.sql.dataframe" run_test "pyspark.sql.group" run_test "pyspark.sql.functions" + run_test "pyspark.sql.readwriter" + run_test "pyspark.sql.window" run_test "pyspark.sql.tests" } From 6f006b5f5fca649ac51745212d8fd44b1609b9ae Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 18:04:57 -0700 Subject: [PATCH 116/254] [SPARK-7986] Split scalastyle config into 3 sections. (1) rules that we enforce. (2) rules that we would like to enforce, but haven't cleaned up the codebase to turn on yet (or we need to make the scalastyle rule more configurable). (3) rules that we don't want to enforce. Author: Reynold Xin Closes #6543 from rxin/scalastyle and squashes the following commits: beefaab [Reynold Xin] [SPARK-7986] Split scalastyle config into 3 sections. --- scalastyle-config.xml | 290 +++++++++++++++++++++++++----------------- 1 file changed, 174 insertions(+), 116 deletions(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index f52b09551adde..d6f927b6fa803 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -14,25 +14,41 @@ ~ See the License for the specific language governing permissions and ~ limitations under the License. --> - - - - - - + - Scalastyle standard configuration - - - - - - - - - Scalastyle standard configuration + + + + + + + + + + - - + + + + + + - - - - - - - - true - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + true + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW - + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + - - - ^FunSuite[A-Za-z]*$ - - Tests must extend org.apache.spark.SparkFunSuite instead. + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + + + + + ^println$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + From 91777a1c3ad3b3ec7b65d5a0413209a9baf6b36a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 31 May 2015 19:55:57 -0700 Subject: [PATCH 117/254] [SPARK-7978] [SQL] [PYSPARK] DecimalType should not be singleton Author: Davies Liu Closes #6532 from davies/decimal and squashes the following commits: c7fcbce [Davies Liu] Update tests.py 1425359 [Davies Liu] DecimalType should not be singleton --- python/pyspark/sql/tests.py | 9 +++++++++ python/pyspark/sql/types.py | 18 ++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5c53c3a8ed4f1..76384d31f1bf4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -100,6 +100,15 @@ def test_data_type_eq(self): lt2 = pickle.loads(pickle.dumps(LongType())) self.assertEquals(lt, lt2) + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + class SQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 9e7e9f04bc35d..b6ec6137c9180 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -97,8 +97,6 @@ class AtomicType(DataType): """An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps.""" - __metaclass__ = DataTypeSingleton - class NumericType(AtomicType): """Numeric data types. @@ -109,6 +107,8 @@ class IntegralType(NumericType): """Integral data types. """ + __metaclass__ = DataTypeSingleton + class FractionalType(NumericType): """Fractional data types. @@ -119,26 +119,36 @@ class StringType(AtomicType): """String data type. """ + __metaclass__ = DataTypeSingleton + class BinaryType(AtomicType): """Binary (byte array) data type. """ + __metaclass__ = DataTypeSingleton + class BooleanType(AtomicType): """Boolean data type. """ + __metaclass__ = DataTypeSingleton + class DateType(AtomicType): """Date (datetime.date) data type. """ + __metaclass__ = DataTypeSingleton + class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. """ + __metaclass__ = DataTypeSingleton + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -172,11 +182,15 @@ class DoubleType(FractionalType): """Double data type, representing double precision floats. """ + __metaclass__ = DataTypeSingleton + class FloatType(FractionalType): """Float data type, representing single precision floats. """ + __metaclass__ = DataTypeSingleton + class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte. From a0e46a0d2ad23ce6a64e6ebdf2ccc776208696b6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 May 2015 21:01:46 -0700 Subject: [PATCH 118/254] [SPARK-7952][SPARK-7984][SQL] equality check between boolean type and numeric type is broken. The origin code has several problems: * `true <=> 1` will return false as we didn't set a rule to handle it. * `true = a` where `a` is not `Literal` and its value is 1, will return false as we only handle literal values. Author: Wenchen Fan Closes #6505 from cloud-fan/tmp1 and squashes the following commits: 77f0f39 [Wenchen Fan] minor fix b6401ba [Wenchen Fan] add type coercion for CaseKeyWhen and address comments ebc8c61 [Wenchen Fan] use SQLTestUtils and If 625973c [Wenchen Fan] improve 9ba2130 [Wenchen Fan] address comments fc0d741 [Wenchen Fan] fix style 2846a04 [Wenchen Fan] fix 7952 --- .../catalyst/analysis/HiveTypeCoercion.scala | 101 +++++++++++++----- .../sql/catalyst/expressions/predicates.scala | 5 +- .../analysis/HiveTypeCoercionSuite.scala | 55 ++++++++-- .../ExpressionEvaluationSuite.scala | 8 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 36 +++++-- 5 files changed, 158 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 96d7b96e60ee9..edcc918bfe921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -76,7 +76,7 @@ trait HiveTypeCoercion { WidenTypes :: PromoteStrings :: DecimalPrecision :: - BooleanComparisons :: + BooleanEqualization :: StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: @@ -119,7 +119,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN") + private val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -349,17 +349,17 @@ trait HiveTypeCoercion { import scala.math.{max, min} // Conversion rules for integer types into fixed-precision decimals - val intTypeToFixed: Map[DataType, DecimalType] = Map( + private val intTypeToFixed: Map[DataType, DecimalType] = Map( ByteType -> DecimalType(3, 0), ShortType -> DecimalType(5, 0), IntegerType -> DecimalType(10, 0), LongType -> DecimalType(20, 0) ) - def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType // Conversion rules for float and double into fixed-precision decimals - val floatTypeToFixed: Map[DataType, DecimalType] = Map( + private val floatTypeToFixed: Map[DataType, DecimalType] = Map( FloatType -> DecimalType(7, 7), DoubleType -> DecimalType(15, 15) ) @@ -482,30 +482,66 @@ trait HiveTypeCoercion { } /** - * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. + * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ - object BooleanComparisons extends Rule[LogicalPlan] { - val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_)) - val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_)) + object BooleanEqualization extends Rule[LogicalPlan] { + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1)) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0)) + + private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { + CaseKeyWhen(numericExpr, Seq( + Literal(trueValues.head), booleanExpr, + Literal(falseValues.head), Not(booleanExpr), + Literal(false))) + } + + private def transform(booleanExpr: Expression, numericExpr: Expression) = { + If(Or(IsNull(booleanExpr), IsNull(numericExpr)), + Literal.create(null, BooleanType), + buildCaseKeyWhen(booleanExpr, numericExpr)) + } + + private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { + CaseWhen(Seq( + And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), + Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), + buildCaseKeyWhen(booleanExpr, numericExpr) + )) + } def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // Hive treats (true = 1) as true and (false = 0) as true. - case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l - case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r - case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) - case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) - - // No need to change other EqualTo operators as that actually makes sense for boolean types. - case e: EqualTo => e - // No need to change the EqualNullSafe operators, too - case e: EqualNullSafe => e - // Otherwise turn them to Byte types so that there exists and ordering. - case p: BinaryComparison if p.left.dataType == BooleanType && - p.right.dataType == BooleanType => - p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + // Hive treats (true = 1) as true and (false = 0) as true, + // all other cases are considered as false. + + // We may simplify the expression if one side is literal numeric values + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => l + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(l) + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => r + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => Not(r) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(l), l) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(l), Not(l)) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(r), r) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(r), Not(r)) + + case EqualTo(l @ BooleanType(), r @ NumericType()) => + transform(l , r) + case EqualTo(l @ NumericType(), r @ BooleanType()) => + transform(r, l) + case EqualNullSafe(l @ BooleanType(), r @ NumericType()) => + transformNullSafe(l, r) + case EqualNullSafe(l @ NumericType(), r @ BooleanType()) => + transformNullSafe(r, l) } } @@ -606,7 +642,7 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => + case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual => logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") val commonType = cw.valueTypes.reduce { (v1, v2) => findTightestCommonType(v1, v2).getOrElse(sys.error( @@ -625,6 +661,23 @@ trait HiveTypeCoercion { case CaseKeyWhen(key, _) => CaseKeyWhen(key, transformedBranches) } + + case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved => + val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) => + findTightestCommonType(v1, v2).getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } + val transformedBranches = ckw.branches.sliding(2, 2).map { + case Seq(when, then) if when.dataType != commonType => + Seq(Cast(when, commonType), then) + case s => s + }.reduce(_ ++ _) + val transformedKey = if (ckw.key.dataType != commonType) { + Cast(ckw.key, commonType) + } else { + ckw.key + } + CaseKeyWhen(transformedKey, transformedBranches) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e2d1c8115e051..4f422d69c4382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression { // both then and else val should be considered. def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1 + def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 override def dataType: DataType = { if (!resolved) { @@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual + childrenResolved && valueTypesEqual && + (key +: whenList).map(_.dataType).distinct.size == 1 /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f0101f4a88f86..a0798428db094 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { @@ -104,15 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } + private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + rule(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) + } + test("coalesce casts") { val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - fac(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - ruleTest( + ruleTest(fac, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -121,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest( + ruleTest(fac, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -131,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) :: Nil)) } + + test("type coercion for CaseKeyWhen") { + val cwc = new HiveTypeCoercion {}.CaseWhenCoercion + ruleTest(cwc, + CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) + ) + // Will remove exception expectation in PR#6405 + intercept[RuntimeException] { + ruleTest(cwc, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) + } + } + + test("type coercion simplification for equal to") { + val be = new HiveTypeCoercion {}.BooleanEqualization + ruleTest(be, + EqualTo(Literal(true), Literal(1)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(true), Literal(0)), + Not(Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(1)), + And(IsNotNull(Literal(true)), Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(0)), + And(IsNotNull(Literal(true)), Not(Literal(true))) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 3f5a660f17e1d..b6927485f42bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -862,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val c5 = 'a.string.at(4) val c6 = 'a.string.at(5) - val literalNull = Literal.create(null, BooleanType) + val literalNull = Literal.create(null, IntegerType) val literalInt = Literal(1) val literalString = Literal("a") @@ -871,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) - checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row) + checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row) checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) - checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row) + checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) } test("complex type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bf18bf854aa4a..63f7d314fb699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} import org.apache.spark.sql.types._ @@ -32,12 +32,12 @@ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { +class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - import org.apache.spark.sql.test.TestSQLContext.implicits._ - val sqlCtx = TestSQLContext + val sqlContext = TestSQLContext + import sqlContext.implicits._ test("SPARK-6743: no columns from cache") { Seq( @@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) + val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) + val df2 = createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) + val df3 = createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("SPARK-7952: fix the equality check between boolean and numeric types") { + withTempTable("t") { + // numeric field i, boolean field j, result of i = j, result of i <=> j + Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( + (1, true, true, true), + (0, false, true, true), + (2, true, false, false), + (2, false, false, false), + (null, true, null, false), + (null, false, null, false), + (0, null, null, false), + (1, null, null, false), + (null, null, null, true) + ).toDF("i", "b", "r1", "r2").registerTempTable("t") + + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + } + } } From 3c0156899dc1ec1f7dfe6d7c8af47fa6dc7d00bf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 31 May 2015 23:55:45 -0700 Subject: [PATCH 119/254] Update README to include DataFrames and zinc. Also cut trailing whitespaces. Author: Reynold Xin Closes #6548 from rxin/readme and squashes the following commits: 630efc3 [Reynold Xin] Update README to include DataFrames and zinc. --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9c09d40e2bdae..380422ca00dbe 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a -rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLlib for machine learning, GraphX for graph processing, +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. @@ -22,7 +22,7 @@ This README file only contains basic setup instructions. Spark is built using [Apache Maven](http://maven.apache.org/). To build Spark and its example programs, run: - mvn -DskipTests clean package + build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at @@ -43,7 +43,7 @@ Try the following command, which should return 1000: Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark - + And run the following command, which should also return 1000: >>> sc.parallelize(range(1000)).count() @@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -75,7 +75,7 @@ can be run using: ./dev/run-tests -Please see the guidance on how to +Please see the guidance on how to [run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). ## A Note About Hadoop Versions From e7c7e51f2ec158d12a8429f753225c746f92d513 Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Mon, 1 Jun 2015 21:34:41 +0100 Subject: [PATCH 120/254] [DOC] Minor modification to Streaming docs with regards to parallel data receiving pwendell tdas Author: Nishkam Ravi Author: nishkamravi2 Author: nravi Closes #6544 from nishkamravi2/master_nravi and squashes the following commits: 46e8c03 [Nishkam Ravi] Slight modification to streaming docs --- docs/streaming-programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bd863d48d53e3..42b33947873b0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1946,10 +1946,10 @@ creates a single receiver (running on a worker machine) that receives a single s Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two -Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. These multiple -DStream can be unioned together to create a single DStream. Then the transformations that was -being applied on the single input DStream can applied on the unified stream. This is done as follows. +Kafka input streams, each receiving only one topic. This would run two receivers, +allowing data to be received in parallel, and increasing overall throughput. These multiple +DStreams can be unioned together to create a single DStream. Then the transformations that were +being applied on a single input DStream can be applied on the unified stream. This is done as follows.
    From b7ab0299b03ae833d5811f380e4594837879f8ae Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 1 Jun 2015 14:40:08 -0700 Subject: [PATCH 121/254] [SPARK-7497] [PYSPARK] [STREAMING] fix streaming flaky tests Increase the duration and timeout in streaming python tests. Author: Davies Liu Closes #6239 from davies/flaky_tests and squashes the following commits: d6aee8f [Davies Liu] fix window tests 26317f7 [Davies Liu] Merge branch 'master' of github.com:apache/spark into flaky_tests 7947db6 [Davies Liu] fix streaming flaky tests --- python/pyspark/streaming/tests.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33ea8c9293d74..46cb18b2e8ef9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -41,8 +41,8 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 4 # seconds - duration = .2 + timeout = 10 # seconds + duration = .5 @classmethod def setUpClass(cls): @@ -379,13 +379,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 5 + timeout = 15 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(.6, .2).count() + return dstream.window(1.5, .5).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -394,7 +394,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(.6, .2) + return dstream.countByWindow(1.5, .5) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -403,7 +403,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(1, .2) + return dstream.countByWindow(2.5, .5) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -412,7 +412,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(1, .2) + return dstream.countByValueAndWindow(2.5, .5) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -421,7 +421,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) + return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] From 90c606925e7ec8f65f28e2290a0048f64af8c6a6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 1 Jun 2015 15:05:14 -0700 Subject: [PATCH 122/254] [SPARK-7584] [MLLIB] User guide for VectorAssembler This PR adds a section in the user guide for `VectorAssembler` with code examples in Python/Java/Scala. It also adds a unit test in Java. jkbradley Author: Xiangrui Meng Closes #6556 from mengxr/SPARK-7584 and squashes the following commits: 11313f6 [Xiangrui Meng] simplify Java example 0cd47f3 [Xiangrui Meng] update user guide fd36292 [Xiangrui Meng] update Java unit test ce61ca0 [Xiangrui Meng] add Java unit test for VectorAssembler e399942 [Xiangrui Meng] scala/python example code --- docs/ml-features.md | 114 ++++++++++++++++++ .../ml/feature/JavaVectorAssemblerSuite.java | 78 ++++++++++++ 2 files changed, 192 insertions(+) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java diff --git a/docs/ml-features.md b/docs/ml-features.md index 81f1b8823a8ce..9ee5696122717 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -964,5 +964,119 @@ DataFrame transformedData = transformer.transform(dataFrame);
    +## VectorAssembler + +`VectorAssembler` is a transformer that combines a given list of columns into a single vector +column. +It is useful for combining raw features and features generated by different feature transformers +into a single feature vector, in order to train ML models like logistic regression and decision +trees. +`VectorAssembler` accepts the following input column types: all numeric types, boolean type, +and vector type. +In each row, the values of the input columns will be concatenated into a vector in the specified +order. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`, `mobile`, `userFeatures`, +and `clicked`: + +~~~ + id | hour | mobile | userFeatures | clicked +----|------|--------|------------------|--------- + 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 +~~~ + +`userFeatures` is a vector column that contains three user features. +We want to combine `hour`, `mobile`, and `userFeatures` into a single feature vector +called `features` and use it to predict `clicked` or not. +If we set `VectorAssembler`'s input columns to `hour`, `mobile`, and `userFeatures` and +output column to `features`, after transformation we should get the following DataFrame: + +~~~ + id | hour | mobile | userFeatures | clicked | features +----|------|--------|------------------|---------|----------------------------- + 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 | [18.0, 1.0, 0.0, 10.0, 0.5] +~~~ + +
    +
    + +[`VectorAssembler`](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) takes an array +of input column names and an output column name. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.feature.VectorAssembler + +val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) +).toDF("id", "hour", "mobile", "userFeatures", "clicked") +val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") +val output = assembler.transform(dataset) +println(output.select("features", "clicked").first()) +{% endhighlight %} +
    + +
    + +[`VectorAssembler`](api/java/org/apache/spark/ml/feature/VectorAssembler.html) takes an array +of input column names and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) +}); +Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); +JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + +DataFrame output = assembler.transform(dataset); +System.out.println(output.select("features", "clicked").first()); +{% endhighlight %} +
    + +
    + +[`VectorAssembler`](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) takes a list +of input column names and an output column name. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler + +dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) +assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") +output = assembler.transform(dataset) +print(output.select("features", "clicked").first()) +{% endhighlight %} +
    +
    + # Feature Selectors diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java new file mode 100644 index 0000000000000..b7c564caad3bd --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaVectorAssemblerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testVectorAssembler() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("x", DoubleType, false), + createStructField("y", new VectorUDT(), false), + createStructField("name", StringType, false), + createStructField("z", new VectorUDT(), false), + createStructField("n", LongType, false) + }); + Row row = RowFactory.create( + 0, 0.0, Vectors.dense(1.0, 2.0), "a", + Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"x", "y", "z", "n"}) + .setOutputCol("features"); + DataFrame output = assembler.transform(dataset); + Assert.assertEquals( + Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + output.select("features").first().getAs(0)); + } +} From 2f9c7519d6a3f867100979b5e7ced3f72b7d9adc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 1 Jun 2015 20:04:57 -0700 Subject: [PATCH 123/254] [SPARK-7958] [STREAMING] Handled exception in StreamingContext.start() to prevent leaking of actors StreamingContext.start() can throw exception because DStream.validateAtStart() fails (say, checkpoint directory not set for StateDStream). But by then JobScheduler, JobGenerator, and ReceiverTracker has already started, along with their actors. But those cannot be shutdown because the only way to do that is call StreamingContext.stop() which cannot be called as the context has not been marked as ACTIVE. The solution in this PR is to stop the internal scheduler if start throw exception, and mark the context as STOPPED. Author: Tathagata Das Closes #6559 from tdas/SPARK-7958 and squashes the following commits: 20b2ec1 [Tathagata Das] Added synchronized 790b617 [Tathagata Das] Handled exception in StreamingContext.start() --- .../spark/streaming/StreamingContext.scala | 17 +++++++++++++---- .../streaming/scheduler/JobScheduler.scala | 4 ++++ .../spark/streaming/StreamingContextSuite.scala | 16 ++++++++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 25842d502543e..624a31ddc2b89 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map import scala.collection.mutable.Queue import scala.reflect.ClassTag +import scala.util.control.NonFatal import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration @@ -576,18 +577,26 @@ class StreamingContext private[streaming] ( def start(): Unit = synchronized { state match { case INITIALIZED => - validate() startSite.set(DStream.getCreationSite()) sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() - scheduler.start() - uiTab.foreach(_.attach()) - state = StreamingContextState.ACTIVE + try { + validate() + scheduler.start() + state = StreamingContextState.ACTIVE + } catch { + case NonFatal(e) => + logError("Error starting the context, marking it as stopped", e) + scheduler.stop(false) + state = StreamingContextState.STOPPED + throw e + } StreamingContext.setActiveContext(this) } shutdownHookRef = Utils.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + uiTab.foreach(_.attach()) logInfo("StreamingContext started") case ACTIVE => logWarning("StreamingContext has already been started") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1d1ddaaccf217..4af9b6d3b56ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -126,6 +126,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { eventLoop.post(ErrorReported(msg, e)) } + def isStarted(): Boolean = synchronized { + eventLoop != null + } + private def processEvent(event: JobSchedulerEvent) { try { event match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index d304c9a7328f3..819dd2ccfe915 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -151,6 +151,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(StreamingContext.getActive().isEmpty) } + test("start failure should stop internal components") { + ssc = new StreamingContext(conf, batchDuration) + val inputStream = addInputStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + // Require that the start fails because checkpoint directory was not set + intercept[Exception] { + ssc.start() + } + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(ssc.scheduler.isStarted === false) + } + + test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() From 15d7c90aeb0d51851f7ebb4c75c9b249ad88dfeb Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 1 Jun 2015 19:39:03 -0700 Subject: [PATCH 124/254] [MINOR] [UI] Improve error message on log page Currently if a bad log type if specified, then we get blank. We should provide a more informative error message. --- .../spark/deploy/worker/ui/LogPage.scala | 6 ++ .../spark/deploy/worker/ui/LogPageSuite.scala | 70 +++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 88170d4df3053..dc2bee6f2bdca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -29,6 +29,7 @@ import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir + private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 @@ -129,6 +130,11 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offsetOption: Option[Long], byteLength: Int ): (String, Long, Long, Long) = { + + if (!supportedLogTypes.contains(logType)) { + return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala new file mode 100644 index 0000000000000..572360ddb95d4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker.ui + +import java.io.{File, FileWriter} + +import org.mockito.Mockito.mock +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkFunSuite + +class LogPageSuite extends SparkFunSuite with PrivateMethodTester { + + test("get logs simple") { + val webui = mock(classOf[WorkerWebUI]) + val logPage = new LogPage(webui) + + // Prepare some fake log files to read later + val out = "some stdout here" + val err = "some stderr here" + val tmpDir = new File(sys.props("java.io.tmpdir")) + val tmpOut = new File(tmpDir, "stdout") + val tmpErr = new File(tmpDir, "stderr") + val tmpRand = new File(tmpDir, "random") + write(tmpOut, out) + write(tmpErr, err) + write(tmpRand, "1 6 4 5 2 7 8") + + // Get the logs. All log types other than "stderr" or "stdout" will be rejected + val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog) + val (stdout, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) + val (stderr, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) + val (error1, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "random", None, 100) + val (error2, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "does-not-exist.txt", None, 100) + assert(stdout === out) + assert(stderr === err) + assert(error1.startsWith("Error")) + assert(error2.startsWith("Error")) + } + + /** Write the specified string to the file. */ + private def write(f: File, s: String): Unit = { + val writer = new FileWriter(f) + try { + writer.write(s) + } finally { + writer.close() + } + } + +} From 6b44278ef7cd2a278dfa67e8393ef30775c72726 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 1 Jun 2015 21:01:14 -0700 Subject: [PATCH 125/254] [SPARK-8028] [SPARKR] Use addJar instead of setJars in SparkR This prevents the spark.jars from being cleared while using `--packages` or `--jars` cc pwendell davies brkyvz Author: Shivaram Venkataraman Closes #6568 from shivaram/SPARK-8028 and squashes the following commits: 3a9cf1f [Shivaram Venkataraman] Use addJar instead of setJars in SparkR This prevents the spark.jars from being cleared --- core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index e020458888e4a..4dfa7325934ff 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -355,7 +355,6 @@ private[r] object RRDD { val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) - .setJars(jars) // Override `master` if we have a user-specified value if (master != "") { @@ -373,7 +372,11 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) } - new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(sparkConf) + jars.foreach { jar => + jsc.addJar(jar) + } + jsc } /** From 6396cc0303ceabea53c4df436ffa50b82b7e233f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Jun 2015 21:11:19 -0700 Subject: [PATCH 126/254] [SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear Author: Reynold Xin Closes #6566 from rxin/crosstab and squashes the following commits: e0ace1c [Reynold Xin] [SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear --- .../apache/spark/sql/execution/stat/StatFunctions.scala | 9 ++++++--- .../scala/org/apache/spark/sql/DataFrameStatSuite.scala | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index b1a8204dd5f71..93383e5a62f11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ @@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging { s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => val countsRow = new GenericMutableRow(columnSize + 1) - rows.foreach { row => + rows.foreach { (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(3) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts @@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging { val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)) + new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 78de89f0b9f39..438f479459dfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -74,10 +74,10 @@ class DataFrameStatSuite extends SparkFunSuite { val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) assert(rows(0).get(0).toString === "0") assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === null) + assert(rows(0).get(2) === 0L) assert(rows(1).get(0).toString === "1") assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === null) + assert(rows(1).get(2) === 0L) assert(rows(2).get(0).toString === "2") assert(rows(2).getLong(1) === 2L) assert(rows(2).getLong(2) === 1L) From 89f642a0e8c3a6bc9149a0bb413f1a8939cb0283 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Jun 2015 21:13:15 -0700 Subject: [PATCH 127/254] [SPARK-8026][SQL] Add Column.alias to Scala/Java DataFrame API Author: Reynold Xin Closes #6565 from rxin/alias and squashes the following commits: 286d880 [Reynold Xin] [SPARK-8026][SQL] Add Column.alias to Scala/Java DataFrame API --- .../src/main/scala/org/apache/spark/sql/Column.scala | 12 ++++++++++++ .../org/apache/spark/sql/ColumnExpressionSuite.scala | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b49b1d327289f..d3efa83380d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -716,6 +716,18 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def endsWith(literal: String): Column = this.endsWith(lit(literal)) + /** + * Gives the column an alias. Same as `as`. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".alias("colB")) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def alias(alias: String): Column = as(alias) + /** * Gives the column an alias. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index d006b83fc075a..b8bb1bff9ea72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -27,6 +27,12 @@ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + test("alias") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + assert(df.select(df("a").as("b")).columns.head === "b") + assert(df.select(df("a").alias("b")).columns.head === "b") + } + test("single explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( From cae9306c4f437c722baa57593fe83f4b7d82dbff Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 1 Jun 2015 21:21:45 -0700 Subject: [PATCH 128/254] [SPARK-8027] [SPARKR] Add maven profile to build R package docs Also use that profile in create-release.sh cc pwendell -- Note that this means that we need `knitr` and `roxygen` installed on the machines used for building the release. Let me know if you need help with that. Author: Shivaram Venkataraman Closes #6567 from shivaram/SPARK-8027 and squashes the following commits: 8dc8ecf [Shivaram Venkataraman] Add maven profile to build R package docs Also use that profile in create-release.sh --- core/pom.xml | 23 +++++++++++++++++++++++ dev/create-release/create-release.sh | 16 ++++++++-------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 5c02be831ce06..a02184222e9f0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -481,6 +481,29 @@
    + + sparkr-docs + + + + org.codehaus.mojo + exec-maven-plugin + + + sparkr-pkg-docs + compile + + exec + + + + + ..${path.separator}R${path.separator}create-docs${script.extension} + + + + + diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 54274a83f6d66..0b14a618e755c 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -228,14 +228,14 @@ if [[ ! "$@" =~ --skip-package ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "hadoop1" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Psparkr-docs -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "mapr3" "-Pmapr3 -Psparkr -Psparkr-docs -Phive -Phive-thriftserver" "3035" & + make_binary_release "mapr4" "-Pmapr4 -Psparkr -Psparkr-docs -Pyarn -Phive -Phive-thriftserver" "3036" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Pyarn" "3037" & wait rm -rf spark-$RELEASE_VERSION-bin-*/ From 4c868b9943a2d86107d1f15f8df9830aac36fb75 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Jun 2015 21:29:39 -0700 Subject: [PATCH 129/254] [minor doc] Add exploratory data analysis warning for DataFrame.stat.freqItem API Author: Reynold Xin Closes #6569 from rxin/freqItemsWarning and squashes the following commits: 7eec145 [Reynold Xin] [minor doc] Add exploratory data analysis warning for DataFrame.stat.freqItem API. --- python/pyspark/sql/dataframe.py | 3 +++ .../apache/spark/sql/DataFrameStatFunctions.scala | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 936487519a645..a82b6b87c413e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1170,6 +1170,9 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. + This function is meant for exploratory data analysis, as we make no guarantee about the + backward compatibility of the schema of the resulting DataFrame. + :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. :param support: The frequency with which to consider an item 'frequent'. Default is 1%. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b624eaa201ea4..edb9ed7bba56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -97,6 +97,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @param support The minimum frequency for an item to be considered `frequent`. Should be greater * than 1e-4. @@ -114,6 +117,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -128,6 +134,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -143,6 +152,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * From 91f6be87bc5cff41ca7a9cca9fdcc4678a4e7086 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 1 Jun 2015 21:33:57 -0700 Subject: [PATCH 130/254] [SPARK-8020] Spark SQL in spark-defaults.conf make metadataHive get constructed too early https://issues.apache.org/jira/browse/SPARK-8020 Author: Yin Huai Closes #6563 from yhuai/SPARK-8020 and squashes the following commits: 4e5addc [Yin Huai] style bf766c6 [Yin Huai] Failed test. 0398f5b [Yin Huai] First populate the SQLConf and then construct executionHive and metadataHive. --- .../org/apache/spark/sql/SQLContext.scala | 25 +++++++++-- .../spark/sql/hive/client/VersionsSuite.scala | 45 ++++++++++++++++++- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7384b24c50b16..91e6385dec81b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -182,9 +182,28 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => setConf(key, value) - case _ => + { + // We extract spark sql settings from SparkContext's conf and put them to + // Spark SQL's conf. + // First, we populate the SQLConf (conf). So, we can make sure that other values using + // those settings in their construction can get the correct settings. + // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version + // and spark.sql.hive.metastore.jars to get correctly constructed. + val properties = new Properties + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) + case _ => + } + // We directly put those settings to conf to avoid of calling setConf, which may have + // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive + // get constructed. If we call setConf directly, the constructed metadataHive may have + // wrong settings, or the construction may fail. + conf.setConf(properties) + // After we have populated SQLConf, we call setConf to populate other confs in the subclass + // (e.g. hiveconf in HiveContext). + properties.foreach { + case (key, value) => setConf(key, value) + } } @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 7eb4842726665..deceb67d2b966 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.hive.client -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils @@ -37,6 +38,48 @@ class VersionsSuite extends SparkFunSuite with Logging { "hive.metastore.warehouse.dir" -> warehousePath.toString) } + test("SPARK-8020: successfully create a HiveContext with metastore settings in Spark conf.") { + val sparkConf = + new SparkConf() { + // We are not really clone it. We need to keep the custom getAll. + override def clone: SparkConf = this + + override def getAll: Array[(String, String)] = { + val allSettings = super.getAll + val metastoreVersion = get("spark.sql.hive.metastore.version") + val metastoreJars = get("spark.sql.hive.metastore.jars") + + val others = allSettings.filterNot { case (key, _) => + key == "spark.sql.hive.metastore.version" || key == "spark.sql.hive.metastore.jars" + } + + // Put metastore.version to the first one. It is needed to trigger the exception + // caused by SPARK-8020. Other problems triggered by SPARK-8020 + // (e.g. using Hive 0.13.1's metastore client to connect to the a 0.12 metastore) + // are not easy to test. + Array( + ("spark.sql.hive.metastore.version" -> metastoreVersion), + ("spark.sql.hive.metastore.jars" -> metastoreJars)) ++ others + } + } + sparkConf + .set("spark.sql.hive.metastore.version", "12") + .set("spark.sql.hive.metastore.jars", "maven") + + val hiveContext = new HiveContext( + new SparkContext( + "local[2]", + "TestSQLContextInVersionsSuite", + sparkConf)) { + + protected override def configure(): Map[String, String] = buildConf + + } + + // Make sure all metastore related lazy vals got created. + hiveContext.tables() + } + test("success sanity check") { val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client val db = new HiveDatabase("default", "") From 75dda33f3e037d550c4ab55d438661070804c717 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Jun 2015 21:35:55 -0700 Subject: [PATCH 131/254] Revert "[SPARK-8020] Spark SQL in spark-defaults.conf make metadataHive get constructed too early" This reverts commit 91f6be87bc5cff41ca7a9cca9fdcc4678a4e7086. --- .../org/apache/spark/sql/SQLContext.scala | 25 ++--------- .../spark/sql/hive/client/VersionsSuite.scala | 45 +------------------ 2 files changed, 4 insertions(+), 66 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 91e6385dec81b..7384b24c50b16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -182,28 +182,9 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } - { - // We extract spark sql settings from SparkContext's conf and put them to - // Spark SQL's conf. - // First, we populate the SQLConf (conf). So, we can make sure that other values using - // those settings in their construction can get the correct settings. - // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version - // and spark.sql.hive.metastore.jars to get correctly constructed. - val properties = new Properties - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) - case _ => - } - // We directly put those settings to conf to avoid of calling setConf, which may have - // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive - // get constructed. If we call setConf directly, the constructed metadataHive may have - // wrong settings, or the construction may fail. - conf.setConf(properties) - // After we have populated SQLConf, we call setConf to populate other confs in the subclass - // (e.g. hiveconf in HiveContext). - properties.foreach { - case (key, value) => setConf(key, value) - } + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => setConf(key, value) + case _ => } @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index deceb67d2b966..7eb4842726665 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.hive.client -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils @@ -38,48 +37,6 @@ class VersionsSuite extends SparkFunSuite with Logging { "hive.metastore.warehouse.dir" -> warehousePath.toString) } - test("SPARK-8020: successfully create a HiveContext with metastore settings in Spark conf.") { - val sparkConf = - new SparkConf() { - // We are not really clone it. We need to keep the custom getAll. - override def clone: SparkConf = this - - override def getAll: Array[(String, String)] = { - val allSettings = super.getAll - val metastoreVersion = get("spark.sql.hive.metastore.version") - val metastoreJars = get("spark.sql.hive.metastore.jars") - - val others = allSettings.filterNot { case (key, _) => - key == "spark.sql.hive.metastore.version" || key == "spark.sql.hive.metastore.jars" - } - - // Put metastore.version to the first one. It is needed to trigger the exception - // caused by SPARK-8020. Other problems triggered by SPARK-8020 - // (e.g. using Hive 0.13.1's metastore client to connect to the a 0.12 metastore) - // are not easy to test. - Array( - ("spark.sql.hive.metastore.version" -> metastoreVersion), - ("spark.sql.hive.metastore.jars" -> metastoreJars)) ++ others - } - } - sparkConf - .set("spark.sql.hive.metastore.version", "12") - .set("spark.sql.hive.metastore.jars", "maven") - - val hiveContext = new HiveContext( - new SparkContext( - "local[2]", - "TestSQLContextInVersionsSuite", - sparkConf)) { - - protected override def configure(): Map[String, String] = buildConf - - } - - // Make sure all metastore related lazy vals got created. - hiveContext.tables() - } - test("success sanity check") { val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client val db = new HiveDatabase("default", "") From 7f74bb3bc6d29c53e67af6b6eec336f2d083322a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 1 Jun 2015 21:36:49 -0700 Subject: [PATCH 132/254] [SPARK-8025][Streaming]Add JavaDoc style deprecation for deprecated Streaming methods Scala `deprecated` annotation actually doesn't show up in JavaDoc. Author: zsxwing Closes #6564 from zsxwing/SPARK-8025 and squashes the following commits: 2faa2bb [zsxwing] Add JavaDoc style deprecation for deprecated Streaming methods --- .../org/apache/spark/streaming/StreamingContext.scala | 8 ++++++++ .../spark/streaming/api/java/JavaStreamingContext.scala | 7 +++++++ .../org/apache/spark/streaming/dstream/DStream.scala | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 624a31ddc2b89..9cd9684d36404 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -271,6 +271,8 @@ class StreamingContext private[streaming] ( * Create an input stream with any arbitrary user implemented receiver. * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver + * + * @deprecated As of 1.0.0", replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -617,6 +619,8 @@ class StreamingContext private[streaming] ( * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { @@ -741,6 +745,10 @@ object StreamingContext extends Logging { } } + /** + * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. + * This is kept here only for backward compatibility. + */ @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index b639b94d5ca47..989e3a729ebc2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -148,6 +148,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) + /** + * @deprecated As of 0.9.0, replaced by `sparkContext` + */ @deprecated("use sparkContext", "0.9.0") val sc: JavaSparkContext = sparkContext @@ -619,6 +622,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { @@ -677,6 +681,7 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( @@ -699,6 +704,7 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -724,6 +730,7 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 6efcc193bfccc..192aa6a139bcb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -603,6 +603,8 @@ abstract class DStream[T: ClassTag] ( /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { @@ -612,6 +614,8 @@ abstract class DStream[T: ClassTag] ( /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { From e797dba58e8cafdd30683dd1e0263f00ce30ccc0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 1 Jun 2015 21:40:17 -0700 Subject: [PATCH 133/254] [SPARK-7965] [SPARK-7972] [SQL] Handle expressions containing multiple window expressions and make parser match window frames in case insensitive way JIRAs: https://issues.apache.org/jira/browse/SPARK-7965 https://issues.apache.org/jira/browse/SPARK-7972 Author: Yin Huai Closes #6524 from yhuai/7965-7972 and squashes the following commits: c12c79c [Yin Huai] Add doc for returned value. de64328 [Yin Huai] Address rxin's comments. fc9b1ad [Yin Huai] wip 2996da4 [Yin Huai] scala style 20b65b7 [Yin Huai] Handle expressions containing multiple window expressions. 9568b21 [Yin Huai] case insensitive matches 41f633d [Yin Huai] Failed test case. --- .../sql/catalyst/analysis/Analyzer.scala | 108 +++++++++++++----- .../org/apache/spark/sql/hive/HiveQl.scala | 22 +++- .../sql/hive/execution/SQLQuerySuite.scala | 36 ++++++ 3 files changed, 134 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index df37889eedcf0..8e9fec70704e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -633,10 +633,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = + private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = projectList.exists(hasWindowFunction) - def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: NamedExpression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -644,14 +644,24 @@ class Analyzer( } /** - * From a Seq of [[NamedExpression]]s, extract window expressions and - * other regular expressions. + * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and + * other regular expressions that do not contain any window expression. For example, for + * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract + * `col1`, `col2 + col3`, `col4`, and `col5` out and replace them appearances in + * the window expression as attribute references. So, the first returned value will be + * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be + * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. + * + * @return (seq of expressions containing at lease one window expressions, + * seq of non-window expressions) */ - def extract( + private def extract( expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = { - // First, we simple partition the input expressions to two part, one having - // WindowExpressions and another one without WindowExpressions. - val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction) + // First, we partition the input expressions to two part. For the first part, + // every expression in it contain at least one WindowExpression. + // Expressions in the second part do not have any WindowExpression. + val (expressionsWithWindowFunctions, regularExpressions) = + expressions.partition(hasWindowFunction) // Then, we need to extract those regular expressions used in the WindowExpression. // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5), @@ -660,8 +670,8 @@ class Analyzer( val extractedExprBuffer = new ArrayBuffer[NamedExpression]() def extractExpr(expr: Expression): Expression = expr match { case ne: NamedExpression => - // If a named expression is not in regularExpressions, add extract it and replace it - // with an AttributeReference. + // If a named expression is not in regularExpressions, add it to + // extractedExprBuffer and replace it with an AttributeReference. val missingExpr = AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer) if (missingExpr.nonEmpty) { @@ -678,8 +688,9 @@ class Analyzer( withName.toAttribute } - // Now, we extract expressions from windowExpressions by using extractExpr. - val newWindowExpressions = windowExpressions.map { + // Now, we extract regular expressions from expressionsWithWindowFunctions + // by using extractExpr. + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). @@ -705,37 +716,80 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - (newWindowExpressions, regularExpressions ++ extractedExprBuffer) - } + (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer) + } // end of extract /** * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. */ - def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = { - // First, we group window expressions based on their Window Spec. - val groupedWindowExpression = windowExpressions.groupBy { expr => - val windowSpec = expr.collectFirst { + private def addWindow( + expressionsWithWindowFunctions: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions + // and put those extracted WindowExpressions to extractedWindowExprBuffer. + // This step is needed because it is possible that an expression contains multiple + // WindowExpressions with different Window Specs. + // After extracting WindowExpressions, we need to construct a project list to generate + // expressionsWithWindowFunctions based on extractedWindowExprBuffer. + // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract + // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to + // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)". + // Then, the projectList will be [_we0/_we1]. + val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]() + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { + // We need to use transformDown because we want to trigger + // "case alias @ Alias(window: WindowExpression, _)" first. + _.transformDown { + case alias @ Alias(window: WindowExpression, _) => + // If a WindowExpression has an assigned alias, just use it. + extractedWindowExprBuffer += alias + alias.toAttribute + case window: WindowExpression => + // If there is no alias assigned to the WindowExpressions. We create an + // internal column. + val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")() + extractedWindowExprBuffer += withName + withName.toAttribute + }.asInstanceOf[NamedExpression] + } + + // Second, we group extractedWindowExprBuffer based on their Window Spec. + val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => + val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec + }.distinct + + // We do a final check and see if we only have a single Window Spec defined in an + // expressions. + if (distinctWindowSpec.length == 0 ) { + failAnalysis(s"$expr does not have any WindowExpression.") + } else if (distinctWindowSpec.length > 1) { + // newExpressionsWithWindowFunctions only have expressions with a single + // WindowExpression. If we reach here, we have a bug. + failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + + s"Please file a bug report with this error message, stack trace, and the query.") + } else { + distinctWindowSpec.head } - windowSpec.getOrElse( - failAnalysis(s"$windowExpressions does not have any WindowExpression.")) }.toSeq - // For every Window Spec, we add a Window operator and set currentChild as the child of it. + // Third, for every Window Spec, we add a Window operator and set currentChild as the + // child of it. var currentChild = child var i = 0 - while (i < groupedWindowExpression.size) { - val (windowSpec, windowExpressions) = groupedWindowExpression(i) + while (i < groupedWindowExpressions.size) { + val (windowSpec, windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) - // Move to next WindowExpression. + // Move to next Window Spec. i += 1 } - // We return the top operator. - currentChild - } + // Finally, we create a Project to output currentChild's output + // newExpressionsWithWindowFunctions. + Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + } // end of addWindow // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 253bf1125262e..a5ca3613c5e00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1561,6 +1561,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C """.stripMargin) } + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { case Token(windowName, Nil) :: Nil => // Refer to a window spec defined in the window clause. @@ -1614,11 +1618,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } else { val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token("preceding", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt) - case Token("following", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt) - case Token("current", Nil) => CurrentRow + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow case _ => throw new NotImplementedError( s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 27863a60145d7..aba3becb1bce2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -780,6 +780,42 @@ class SQLQuerySuite extends QueryTest { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: multiple window expressions in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.registerTempTable("nums") + + val expected = + Row(1, 1, 1, 55, 1, 57) :: + Row(0, 2, 3, 55, 2, 60) :: + Row(1, 3, 6, 55, 4, 65) :: + Row(0, 4, 10, 55, 6, 71) :: + Row(1, 5, 15, 55, 9, 79) :: + Row(0, 6, 21, 55, 12, 88) :: + Row(1, 7, 28, 55, 16, 99) :: + Row(0, 8, 36, 55, 20, 111) :: + Row(1, 9, 45, 55, 25, 125) :: + Row(0, 10, 55, 55, 30, 140) :: Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) OVER w1 AS running_sum, + | sum(x) OVER w2 AS total_sum, + | sum(x) OVER w3 AS running_sum_per_y, + | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), + | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), + | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """.stripMargin) + + checkAnswer(actual, expected) + + dropTempTable("nums") + } + test("test case key when") { (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") checkAnswer( From b53a0116473a03607c5be3e4135151b4932acc06 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Jun 2015 21:41:53 -0700 Subject: [PATCH 134/254] Fixed typo in the previous commit. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8e9fec70704e6..bc17169f35a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -647,7 +647,7 @@ class Analyzer( * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and * other regular expressions that do not contain any window expression. For example, for * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract - * `col1`, `col2 + col3`, `col4`, and `col5` out and replace them appearances in + * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in * the window expression as attribute references. So, the first returned value will be * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. From 0221c7f0efe2512f3ae3839b83aa8abb0806d516 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 1 Jun 2015 22:03:29 -0700 Subject: [PATCH 135/254] [SPARK-7582] [MLLIB] user guide for StringIndexer This PR adds a Java unit test and user guide for `StringIndexer`. I put it before `OneHotEncoder` because they are closely related. jkbradley Author: Xiangrui Meng Closes #6561 from mengxr/SPARK-7582 and squashes the following commits: 4bba4f1 [Xiangrui Meng] fix example ba1cd1b [Xiangrui Meng] fix style 7fa18d1 [Xiangrui Meng] add user guide for StringIndexer 136cb93 [Xiangrui Meng] add a Java unit test for StringIndexer --- docs/ml-features.md | 116 ++++++++++++++++++ .../ml/feature/JavaStringIndexerSuite.java | 77 ++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java diff --git a/docs/ml-features.md b/docs/ml-features.md index 9ee5696122717..f88c0248c1a8a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3): +## StringIndexer + +`StringIndexer` encodes a string column of labels to a column of label indices. +The indices are in `[0, numLabels)`, ordered by label frequencies. +So the most frequent label gets index `0`. +If the input column is numeric, we cast it to string and index the string values. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `category`: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | a + 4 | a + 5 | c +~~~~ + +`category` is a string column with three labels: "a", "b", and "c". +Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output +column, we should get the following: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 + 3 | a | 0.0 + 4 | a | 0.0 + 5 | c | 1.0 +~~~~ + +"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with +index `2`. + +
    + +
    + +[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight scala %} +import org.apache.spark.ml.feature.StringIndexer + +val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) +).toDF("id", "category") +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") +val indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
    + +
    +[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column +name and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[] { + createStructField("id", DoubleType, false), + createStructField("category", StringType, false) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); +DataFrame indexed = indexer.fit(df).transform(df); +indexed.show(); +{% endhighlight %} +
    + +
    + +[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight python %} +from pyspark.ml.feature import StringIndexer + +df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) +indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
    +
    + ## OneHotEncoder [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java new file mode 100644 index 0000000000000..35b18c5308f61 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaStringIndexerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + sqlContext = null; + } + + @Test + public void testStringIndexer() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("label", StringType, false) + }); + JavaRDD rdd = jsc.parallelize( + Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + StringIndexer indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex"); + DataFrame output = indexer.fit(dataset).transform(dataset); + + Assert.assertArrayEquals( + new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + output.orderBy("id").select("id", "labelIndex").collect()); + } + + /** An alias for RowFactory.create. */ + private Row c(Object... values) { + return RowFactory.create(values); + } +} From bcb47ad7718b843fbd25cd1e228a7b7e6e5b8686 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 1 Jun 2015 23:12:29 -0700 Subject: [PATCH 136/254] [SPARK-6917] [SQL] DecimalType is not read back when non-native type exists cc yhuai Author: Davies Liu Closes #6558 from davies/decimalType and squashes the following commits: c877ca8 [Davies Liu] Update ParquetConverter.scala 48cc57c [Davies Liu] Update ParquetConverter.scala b43845c [Davies Liu] add test 3b4a94f [Davies Liu] DecimalType is not read back when non-native type exists --- .../apache/spark/sql/parquet/ParquetConverter.scala | 4 +++- .../spark/sql/parquet/ParquetQuerySuite.scala | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 1b4196ab0be35..caa9f045537d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -243,8 +243,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { /** * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in * a long (i.e. precision <= 18) + * + * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { val precision = ctype.precisionInfo.get.precision val scale = ctype.precisionInfo.get.scale val bytes = value.getBytes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b98ba09ccfc2d..304936fb2be8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext @@ -111,6 +112,18 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { List(Row("same", "run_5", 100))) } } + + test("SPARK-6917 DecimalType should work with non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { From 7b7f7b6c6fd903e2ecfc886d29eaa9df58adcfc3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 2 Jun 2015 00:16:56 -0700 Subject: [PATCH 137/254] [SPARK-8020] [SQL] Spark SQL conf in spark-defaults.conf make metadataHive get constructed too early https://issues.apache.org/jira/browse/SPARK-8020 Author: Yin Huai Closes #6571 from yhuai/SPARK-8020-1 and squashes the following commits: 0398f5b [Yin Huai] First populate the SQLConf and then construct executionHive and metadataHive. --- .../org/apache/spark/sql/SQLContext.scala | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7384b24c50b16..91e6385dec81b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -182,9 +182,28 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => setConf(key, value) - case _ => + { + // We extract spark sql settings from SparkContext's conf and put them to + // Spark SQL's conf. + // First, we populate the SQLConf (conf). So, we can make sure that other values using + // those settings in their construction can get the correct settings. + // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version + // and spark.sql.hive.metastore.jars to get correctly constructed. + val properties = new Properties + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) + case _ => + } + // We directly put those settings to conf to avoid of calling setConf, which may have + // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive + // get constructed. If we call setConf directly, the constructed metadataHive may have + // wrong settings, or the construction may fail. + conf.setConf(properties) + // After we have populated SQLConf, we call setConf to populate other confs in the subclass + // (e.g. hiveconf in HiveContext). + properties.foreach { + case (key, value) => setConf(key, value) + } } @transient From 0f80990bfac1e9969644952d1d8edaf7d26fb436 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 2 Jun 2015 00:20:52 -0700 Subject: [PATCH 138/254] [SPARK-8023][SQL] Add "deterministic" attribute to Expression to avoid collapsing nondeterministic projects. This closes #6570. Author: Yin Huai Author: Reynold Xin Closes #6573 from rxin/deterministic and squashes the following commits: 356cd22 [Reynold Xin] Added unit test for the optimizer. da3fde1 [Reynold Xin] Merge pull request #6570 from yhuai/SPARK-8023 da56200 [Yin Huai] Comments. e38f264 [Yin Huai] Comment. f9d6a73 [Yin Huai] Add a deterministic method to Expression. --- .../sql/catalyst/expressions/Expression.scala | 8 ++ .../sql/catalyst/expressions/random.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 11 ++- .../optimizer/ProjectCollapsingSuite.scala | 73 +++++++++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 41 ++++++++++- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 + 6 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d19928784442e..adc6505d69cdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -37,7 +37,15 @@ abstract class Expression extends TreeNode[Expression] { * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false + + /** + * Returns true when the current expression always return the same result for fixed input values. + */ + // TODO: Need to define explicit input values vs implicit input values. + def deterministic: Boolean = true + def nullable: Boolean + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 4f4f67a6e482c..b2647124c4e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -38,6 +38,8 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { */ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) + override def deterministic: Boolean = false + override def nullable: Boolean = false override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c2818d957cc79..b25fb48f55e2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -179,8 +179,17 @@ object ColumnPruning extends Rule[LogicalPlan] { * expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { + + /** Returns true if any expression in projectList is non-deterministic. */ + private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { + projectList.exists(expr => expr.find(!_.deterministic).isDefined) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Project(projectList1, Project(projectList2, child)) => + // We only collapse these two Projects if the child Project's expressions are all + // deterministic. + case Project(projectList1, Project(projectList2, child)) + if !hasNondeterministic(projectList2) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala new file mode 100644 index 0000000000000..151654bffbd66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ProjectCollapsingSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two deterministic, independent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two deterministic, dependent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select(('a_plus_1 + 1).as('a_plus_2), 'b) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation.select( + (('a + 1).as('a_plus_1) + 1).as('a_plus_2), + 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse nondeterministic projects") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b8bb1bff9ea72..bfba379d9a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.scalatest.Matchers._ +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -452,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest { } test("rand") { - val randCol = testData.select('key, rand(5L).as("rand")) + val randCol = testData.select($"key", rand(5L).as("rand")) randCol.columns.length should be (2) val rows = randCol.collect() rows.foreach { row => assert(row.getDouble(1) <= 1.0) assert(row.getDouble(1) >= 0.0) } + + def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { + val projects = df.queryExecution.executedPlan.collect { + case project: Project => project + } + assert(projects.size === expectedNumProjects) + } + + // We first create a plan with two Projects. + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // Because Rand function is not deterministic, the column rand is not deterministic. + // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2] + // and Project [key, Rand 5 AS rand]. The final plan still has two Projects. + val dfWithTwoProjects = + testData + .select($"key", (rand(5L) + 1).as("rand")) + .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2")) + checkNumProjects(dfWithTwoProjects, 2) + + // Now, we add one more project rand1 - rand2 on top of the query plan. + // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated + // rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2. + // So, the plan will be optimized from ... + // Project [(rand1 - rand2) AS (rand1 - rand2)] + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // to ... + // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] + // Project [key, Rand 5 AS rand] + // LogicalRDD [key, value] + val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") + checkNumProjects(dfWithThreeProjects, 2) + dfWithThreeProjects.collect().foreach { row => + assert(row.getDouble(0) === 2.0 +- 0.0001) + } } test("randn") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 64a49c83cbad1..1658bb93b0b79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -78,6 +78,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre type UDFType = UDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient @@ -140,6 +142,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient From 445647a1a36e1e24076a9fe506492fac462c66ad Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Jun 2015 08:37:18 -0700 Subject: [PATCH 139/254] [SPARK-8021] [SQL] [PYSPARK] make Python read/write API consistent with Scala add schema()/format()/options() for reader, add mode()/format()/options()/partitionBy() for writer cc rxin yhuai pwendell Author: Davies Liu Closes #6578 from davies/readwrite and squashes the following commits: 720d293 [Davies Liu] address comments b65dfa2 [Davies Liu] Update readwriter.py 1299ab6 [Davies Liu] make Python API consistent with Scala --- python/pyspark/sql/readwriter.py | 121 ++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b6fd413bec7db..d17d87419fe3d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -43,6 +43,39 @@ def _df(self, jdf): from pyspark.sql.dataframe import DataFrame return DataFrame(jdf, self._sqlContext) + @since(1.4) + def format(self, source): + """ + Specifies the input data source format. + """ + self._jreader = self._jreader.format(source) + return self + + @since(1.4) + def schema(self, schema): + """ + Specifies the input schema. Some data sources (e.g. JSON) can + infer the input schema automatically from data. By specifying + the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + :param schema: a StructType object + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(1.4) + def options(self, **options): + """ + Adds input options for the underlying data source. + """ + for k in options: + self._jreader = self._jreader.option(k, options[k]) + return self + @since(1.4) def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. @@ -52,20 +85,15 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`StructType` for the input schema. :param options: all other string options """ - jreader = self._jreader if format is not None: - jreader = jreader.format(format) + self.format(format) if schema is not None: - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) - jreader = jreader.schema(jschema) - for k in options: - jreader = jreader.option(k, options[k]) + self.schema(schema) + self.options(**options) if path is not None: - return self._df(jreader.load(path)) + return self._df(self._jreader.load(path)) else: - return self._df(jreader.load()) + return self._df(self._jreader.load()) @since(1.4) def json(self, path, schema=None): @@ -105,12 +133,9 @@ def json(self, path, schema=None): | |-- field5: array (nullable = true) | | |-- element: integer (containsNull = true) """ - if schema is None: - jdf = self._jreader.json(path) - else: - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) - jdf = self._jreader.schema(jschema).json(path) - return self._df(jdf) + if schema is not None: + self.schema(schema) + return self._df(self._jreader.json(path)) @since(1.4) def table(self, tableName): @@ -194,6 +219,51 @@ def __init__(self, df): self._sqlContext = df.sql_ctx self._jwrite = df._jdf.write() + @since(1.4) + def mode(self, saveMode): + """ + Specifies the behavior when data or table already exists. Options include: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + """ + self._jwrite = self._jwrite.mode(saveMode) + return self + + @since(1.4) + def format(self, source): + """ + Specifies the underlying output data source. Built-in options include + "parquet", "json", etc. + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(1.4) + def options(self, **options): + """ + Adds output options for the underlying data source. + """ + for k in options: + self._jwrite = self._jwrite.option(k, options[k]) + return self + + @since(1.4) + def partitionBy(self, *cols): + """ + Partitions the output by the given columns on the file system. + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + :param cols: name of columns + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + return self + @since(1.4) def save(self, path=None, format=None, mode="error", **options): """ @@ -216,16 +286,15 @@ def save(self, path=None, format=None, mode="error", **options): :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) :param options: all other string options """ - jwrite = self._jwrite.mode(mode) + self.mode(mode).options(**options) if format is not None: - jwrite = jwrite.format(format) - for k in options: - jwrite = jwrite.option(k, options[k]) + self.format(format) if path is None: - jwrite.save() + self._jwrite.save() else: - jwrite.save(path) + self._jwrite.save(path) + @since(1.4) def insertInto(self, tableName, overwrite=False): """ Inserts the content of the :class:`DataFrame` to the specified table. @@ -256,12 +325,10 @@ def saveAsTable(self, name, format=None, mode="error", **options): :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) :param options: all other string options """ - jwrite = self._jwrite.mode(mode) + self.mode(mode).options(**options) if format is not None: - jwrite = jwrite.format(format) - for k in options: - jwrite = jwrite.option(k, options[k]) - return jwrite.saveAsTable(name) + self.format(format) + return self._jwrite.saveAsTable(name) @since(1.4) def json(self, path, mode="error"): From bd97840d5ccc3f0bfde1e5cfc7abeac9681997ab Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 2 Jun 2015 08:51:00 -0700 Subject: [PATCH 140/254] [SPARK-7432] [MLLIB] fix flaky CrossValidator doctest The new test uses CV to compare `maxIter=0` and `maxIter=1`, and validate on the evaluation result. jkbradley Author: Xiangrui Meng Closes #6572 from mengxr/SPARK-7432 and squashes the following commits: c236bb8 [Xiangrui Meng] fix flacky cv doctest --- python/pyspark/ml/tuning.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 497841b6c8ce6..0bf988fd72f14 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -91,20 +91,19 @@ class CrossValidator(Estimator): >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator >>> from pyspark.mllib.linalg import Vectors >>> dataset = sqlContext.createDataFrame( - ... [(Vectors.dense([0.0, 1.0]), 0.0), - ... (Vectors.dense([1.0, 2.0]), 1.0), - ... (Vectors.dense([0.55, 3.0]), 0.0), - ... (Vectors.dense([0.45, 4.0]), 1.0), - ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10, + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, ... ["features", "label"]) >>> lr = LogisticRegression() - >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - >>> # SPARK-7432: The following test is flaky. - >>> # cvModel = cv.fit(dataset) - >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset) - >>> # cvModel.transform(dataset).collect() == expected.collect() + >>> cvModel = cv.fit(dataset) + >>> evaluator.evaluate(cvModel.transform(dataset)) + 0.8333... """ # a placeholder to make it appear in the generated doc From 1bb5d716c0351cd0b4c11b397fd778f30db39bd9 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 3 Jun 2015 00:59:50 +0800 Subject: [PATCH 141/254] [SPARK-8037] [SQL] Ignores files whose name starts with dot in HadoopFsRelation Author: Cheng Lian Closes #6581 from liancheng/spark-8037 and squashes the following commits: d08e97b [Cheng Lian] Ignores files whose name starts with dot in HadoopFsRelation --- .../spark/sql/sources/PartitioningUtils.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 11 +++++++---- .../ParquetPartitionDiscoverySuite.scala | 19 ++++++++++++++++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index dafdf0f8b4564..c4c99de5a38dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -187,7 +187,7 @@ private[sql] object PartitioningUtils { Seq.empty } else { assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n", "") + val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") s"Conflicting partition column names detected:\n$list" }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b1b997c030a60..c4ffa8de52640 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -379,10 +379,10 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] def refresh(): Unit = { - // We don't filter files/directories whose name start with "_" or "." here, as specific data - // sources may take advantages over them (e.g. Parquet _metadata and _common_metadata files). - // But "_temporary" directories are explicitly ignored since failed tasks/jobs may leave - // partial/corrupted data files there. + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { if (status.getPath.getName.toLowerCase == "_temporary") { Set.empty @@ -400,6 +400,9 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) + }.filterNot { status => + // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories + status.getPath.getName.startsWith(".") } val files = statuses.filterNot(_.isDir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index f231589e9674d..3b29979452ad9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.parquet import java.io.File import java.math.BigInteger -import java.sql.{Timestamp, Date} +import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal @@ -432,4 +433,20 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { checkAnswer(read.load(dir.toString).select(fields: _*), row) } } + + test("SPARK-8037: Ignores files whose name starts with dot") { + withTempPath { dir => + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(dir.getCanonicalPath) + + Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df) + } + } } From 0071bd8d31f13abfe73b9d141a818412d374dce0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 2 Jun 2015 11:20:33 -0700 Subject: [PATCH 142/254] [SPARK-8015] [FLUME] Remove Guava dependency from flume-sink. The minimal change would be to disable shading of Guava in the module, and rely on the transitive dependency from other libraries instead. But since Guava's use is so localized, I think it's better to just not use it instead, so I replaced that code and removed all traces of Guava from the module's build. Author: Marcelo Vanzin Closes #6555 from vanzin/SPARK-8015 and squashes the following commits: c0ceea8 [Marcelo Vanzin] Add comments about dependency management. c38228d [Marcelo Vanzin] Add guava dep in test scope. b7a0349 [Marcelo Vanzin] Add libthrift exclusion. 6e0942d [Marcelo Vanzin] Add comment in pom. 2d79260 [Marcelo Vanzin] [SPARK-8015] [flume] Remove Guava dependency from flume-sink. --- external/flume-sink/pom.xml | 39 +++++++++++++++++++ .../flume/sink/SparkAvroCallbackHandler.scala | 4 +- .../flume/sink/SparkSinkThreadFactory.scala | 35 +++++++++++++++++ .../streaming/flume/sink/SparkSinkSuite.scala | 6 +-- 4 files changed, 77 insertions(+), 7 deletions(-) create mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a24..71f2b6fe18bd1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -42,15 +42,46 @@ org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index fd01807fc3ac4..dc2a4ab138e18 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.flume.Channel import org.apache.commons.lang3.RandomStringUtils @@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala new file mode 100644 index 0000000000000..845fc8debda75 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong + +/** + * Thread factory that generates daemon threads with a specified name format. + */ +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t + } + +} diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 605b3fe71017f..fa43629d49771 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context @@ -194,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) From ad06727fe985ca243ebdaaba55cd7d35a4749d0a Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Tue, 2 Jun 2015 12:38:14 -0700 Subject: [PATCH 143/254] [SPARK-7985] [ML] [MLlib] [Docs] Remove "fittingParamMap" references. Updating ML Doc "Estimator, Transformer, and Param" examples. Updating ML Doc's *"Estimator, Transformer, and Param"* example to use `model.extractParamMap` instead of `model.fittingParamMap`, which no longer exists. mengxr, I believe this addresses (part of) the *update documentation* TODO list item from [PR 5820](https://github.com/apache/spark/pull/5820). Author: Mike Dusenberry Closes #6514 from dusenberrymw/Fix_ML_Doc_Estimator_Transformer_Param_Example and squashes the following commits: 6366e1f [Mike Dusenberry] Updating instances of model.extractParamMap to model.parent.extractParamMap, since the Params of the parent Estimator could possibly differ from thos of the Model. d850e0e [Mike Dusenberry] Removing all references to "fittingParamMap" throughout Spark, since it has been removed. 0480304 [Mike Dusenberry] Updating the ML Doc "Estimator, Transformer, and Param" Java example to use model.extractParamMap() instead of model.fittingParamMap(), which no longer exists. 7d34939 [Mike Dusenberry] Updating ML Doc "Estimator, Transformer, and Param" example to use model.extractParamMap instead of model.fittingParamMap, which no longer exists. --- docs/ml-guide.md | 8 ++++---- .../apache/spark/ml/classification/GBTClassifier.scala | 2 +- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../org/apache/spark/ml/regression/GBTRegressor.scala | 2 +- .../spark/ml/regression/RandomForestRegressor.scala | 2 +- .../ml/classification/DecisionTreeClassifierSuite.scala | 2 +- .../spark/ml/classification/GBTClassifierSuite.scala | 2 +- .../ml/classification/RandomForestClassifierSuite.scala | 2 +- .../spark/ml/regression/DecisionTreeRegressorSuite.scala | 2 +- .../apache/spark/ml/regression/GBTRegressorSuite.scala | 2 +- .../spark/ml/regression/RandomForestRegressorSuite.scala | 2 +- 11 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5f50ed7990f1..4eb622d4b95e8 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -207,7 +207,7 @@ val model1 = lr.fit(training.toDF) // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -println("Model 1 was fit using parameters: " + model1.fittingParamMap) +println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -222,7 +222,7 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training.toDF, paramMapCombined) -println("Model 2 was fit using parameters: " + model2.fittingParamMap) +println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. val test = sc.parallelize(Seq( @@ -289,7 +289,7 @@ LogisticRegressionModel model1 = lr.fit(training); // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap()); +System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); @@ -305,7 +305,7 @@ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); -System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap()); +System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. List localTest = Lists.newArrayList( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d8592eb2d947d..62f4b51f770e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -208,7 +208,7 @@ private[ml] object GBTClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 67600ebd7b38e..852a67e066322 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -170,7 +170,7 @@ private[ml] object RandomForestClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 69f4f5414c8c6..b7e374bb6cb49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -198,7 +198,7 @@ private[ml] object GBTRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index ae767a17329d2..49a1f7ce8c995 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -152,7 +152,7 @@ private[ml] object RandomForestRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } new RandomForestRegressionModel(parent.uid, newTrees) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 40554f6ef94a8..ae40b0b8ff854 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -265,7 +265,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 09327051621e0..1302da3c373ff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -127,7 +127,7 @@ private object GBTClassifierSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index f699d0c374d2f..eee9355a67be3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -157,7 +157,7 @@ private object RandomForestClassifierSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 1182b89a8e3aa..33aa9d0d62343 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -82,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index f8a1469fee313..98fb3d3f5f22c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -128,7 +128,7 @@ private object GBTRegressorSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 78911560945a2..b24ecaa57c89b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -113,7 +113,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) From 686a45f0b9c50ede2a80854ed6a155ee8a9a4f5c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 2 Jun 2015 13:32:13 -0700 Subject: [PATCH 144/254] [SPARK-8014] [SQL] Avoid premature metadata discovery when writing a HadoopFsRelation with a save mode other than Append The current code references the schema of the DataFrame to be written before checking save mode. This triggers expensive metadata discovery prematurely. For save mode other than `Append`, this metadata discovery is useless since we either ignore the result (for `Ignore` and `ErrorIfExists`) or delete existing files (for `Overwrite`) later. This PR fixes this issue by deferring metadata discovery after save mode checking. Author: Cheng Lian Closes #6583 from liancheng/spark-8014 and squashes the following commits: 1aafabd [Cheng Lian] Updates comments 088abaa [Cheng Lian] Avoids schema merging and partition discovery when data schema and partition schema are defined 8fbd93f [Cheng Lian] Fixes SPARK-8014 --- .../apache/spark/sql/parquet/newParquet.scala | 2 +- .../apache/spark/sql/sources/commands.scala | 20 +++++-- .../org/apache/spark/sql/sources/ddl.scala | 16 ++--- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../sql/sources/hadoopFsRelationSuites.scala | 59 ++++++++++++++----- 5 files changed, 67 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e439a18ac43aa..824ae36968c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -190,7 +190,7 @@ private[sql] class ParquetRelation2( } } - override def dataSchema: StructType = metadataCache.dataSchema + override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) override private[sql] def refresh(): Unit = { super.refresh() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 3132067d562f6..71f016b1f14de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -30,9 +30,10 @@ import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} @@ -94,10 +95,19 @@ private[sql] case class InsertIntoHadoopFsRelation( // We create a DataFrame by applying the schema of relation to the data to make sure. // We are writing data based on the expected schema, - val df = sqlContext.createDataFrame( - DataFrame(sqlContext, query).queryExecution.toRdd, - relation.schema, - needsConversion = false) + val df = { + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). We + // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can + // safely apply the schema of r.schema to the data. + val project = Project( + relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) + + sqlContext.createDataFrame( + DataFrame(sqlContext, project).queryExecution.toRdd, + relation.schema, + needsConversion = false) + } val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 22587f5a1c6f1..20afd60cb7767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand @@ -322,19 +322,13 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(data.schema, partitionColumns)), caseInsensitiveOptions) - // For partitioned relation r, r.schema's column ordering is different with the column - // ordering of data.logicalPlan. We need a Project to adjust the ordering. - // So, inside InsertIntoHadoopFsRelation, we can safely apply the schema of r.schema to - // the data. - val project = - Project( - r.schema.map(field => new UnresolvedAttribute(Seq(field.name))), - data.logicalPlan) - + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. sqlContext.executePlan( InsertIntoHadoopFsRelation( r, - project, + data.logicalPlan, mode)).toRdd r case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c4ffa8de52640..f5bd2d2941ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -503,7 +503,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ override lazy val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column => + StructType(dataSchema ++ partitionColumns.filterNot { column => dataSchemaColumnNames.contains(column.name.toLowerCase) }) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index af36fa6f1faae..74095426741e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.sources +import java.io.File + +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkFunSuite} @@ -453,6 +456,20 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } } + + test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + + df.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("c", "a") + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) + } + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -534,20 +551,6 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } - test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { - val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - - df.write - .format("parquet") - .mode(SaveMode.Overwrite) - .partitionBy("c", "a") - .saveAsTable("t") - - withTable("t") { - checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) - } - } - test("SPARK-7868: _temporary directories should be ignored") { withTempPath { dir => val df = Seq("a", "b", "c").zipWithIndex.toDF() @@ -563,4 +566,32 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[RuntimeException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } } From 605ddbb27c8482fc0107b21c19d4e4ae19348f35 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Jun 2015 13:38:06 -0700 Subject: [PATCH 145/254] [SPARK-8038] [SQL] [PYSPARK] fix Column.when() and otherwise() Thanks ogirardot, closes #6580 cc rxin JoshRosen Author: Davies Liu Closes #6590 from davies/when and squashes the following commits: c0f2069 [Davies Liu] fix Column.when() and otherwise() --- python/pyspark/sql/column.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8dc5039f587f0..1ecec5b126505 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this expression is between the given columns. + + >>> df.select(df.name, df.age.between(2, 4)).show() + +-----+--------------------------+ + | name|((age >= 2) && (age <= 4))| + +-----+--------------------------+ + |Alice| true| + | Bob| false| + +-----+--------------------------+ """ return (self >= lowerBound) & (self <= upperBound) @@ -328,12 +336,20 @@ def when(self, condition, value): :param condition: a boolean :class:`Column` expression. :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+--------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| + +-----+--------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+--------------------------------------------------------+ """ - sc = SparkContext._active_spark_context if not isinstance(condition, Column): raise TypeError("condition should be a Column") v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) + jc = self._jc.when(condition._jc, v) return Column(jc) @since(1.4) @@ -345,9 +361,18 @@ def otherwise(self, value): See :func:`pyspark.sql.functions.when` for example usage. :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+---------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0| + +-----+---------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+---------------------------------+ """ v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) + jc = self._jc.otherwise(v) return Column(jc) @since(1.4) From 89f21f66b5549524d1a6e4fb576a4f80d9fef903 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 2 Jun 2015 16:51:17 -0700 Subject: [PATCH 146/254] [SPARK-8049] [MLLIB] drop tmp col from OneVsRest output The temporary column should be dropped after we get the prediction column. harsha2010 Author: Xiangrui Meng Closes #6592 from mengxr/SPARK-8049 and squashes the following commits: 1d89107 [Xiangrui Meng] use SparkFunSuite 6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output --- .../org/apache/spark/ml/classification/OneVsRest.scala | 1 + .../apache/spark/ml/classification/OneVsRestSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7b726da388075..825f9ed1b54b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction val labelUdf = callUDF(label, DoubleType, col(accColName)) aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + .drop(accColName) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index f439f3261f06f..1d04ccb509057 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) ova.fit(datasetWithLabelMetadata) } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { From 5cd6a63d9692d153751747e0293dc030d73a6194 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 2 Jun 2015 17:07:13 -0700 Subject: [PATCH 147/254] [SQL] [TEST] [MINOR] Follow-up of PR #6493, use Guava API to ensure Java 6 friendliness This is a follow-up of PR #6493, which has been reverted in branch-1.4 because it uses Java 7 specific APIs and breaks Java 6 build. This PR replaces those APIs with equivalent Guava ones to ensure Java 6 friendliness. cc andrewor14 pwendell, this should also be back ported to branch-1.4. Author: Cheng Lian Closes #6547 from liancheng/override-log4j and squashes the following commits: c900cfd [Cheng Lian] Addresses Shixiong's comment 72da795 [Cheng Lian] Uses Guava API to ensure Java 6 friendliness --- .../sql/hive/thriftserver/HiveThriftServer2Suites.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index da511ebd05ad2..a93a3dee43511 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} import java.sql.{Date, DriverManager, Statement} import scala.collection.mutable.ArrayBuffer @@ -29,6 +27,8 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper @@ -441,13 +441,14 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl val tempLog4jConf = Utils.createTempDir().getCanonicalPath Files.write( - Paths.get(s"$tempLog4jConf/log4j.properties"), """log4j.rootCategory=INFO, console |log4j.appender.console=org.apache.log4j.ConsoleAppender |log4j.appender.console.target=System.err |log4j.appender.console.layout=org.apache.log4j.PatternLayout |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin.getBytes(StandardCharsets.UTF_8)) + """.stripMargin, + new File(s"$tempLog4jConf/log4j.properties"), + UTF_8) tempLog4jConf + File.pathSeparator + sys.props("java.class.path") } From c3f4c3257194ba34ccd298d13ea1edcfc75f7552 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Tue, 2 Jun 2015 18:53:04 -0700 Subject: [PATCH 148/254] [SPARK-7387] [ML] [DOC] CrossValidator example code in Python Author: Ram Sriharsha Closes #6358 from harsha2010/SPARK-7387 and squashes the following commits: 63efda2 [Ram Sriharsha] more examples for classifier to distinguish mapreduce from spark properly aeb6bb6 [Ram Sriharsha] Python Style Fix 54a500c [Ram Sriharsha] Merge branch 'master' into SPARK-7387 615e91c [Ram Sriharsha] cleanup 204c4e3 [Ram Sriharsha] Merge branch 'master' into SPARK-7387 7246d35 [Ram Sriharsha] [SPARK-7387][ml][doc] CrossValidator example code in Python --- .../src/main/python/ml/cross_validator.py | 96 +++++++++++++++++++ .../main/python/ml/simple_params_example.py | 4 +- 2 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/python/ml/cross_validator.py diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py new file mode 100644 index 0000000000000..f0ca97c724940 --- /dev/null +++ b/examples/src/main/python/ml/cross_validator.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.evaluation import BinaryClassificationEvaluator +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="CrossValidatorExample") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + # This will allow us to jointly choose parameters for all Pipeline stages. + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + paramGrid = ParamGridBuilder() \ + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .build() + + crossval = CrossValidator(estimator=pipeline, + estimatorParamMaps=paramGrid, + evaluator=BinaryClassificationEvaluator(), + numFolds=2) # use 3+ folds in practice + + # Run cross-validation, and choose the best set of parameters. + cvModel = crossval.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents. cvModel uses the best model found (lrModel). + prediction = cvModel.transform(test) + selected = prediction.select("id", "text", "probability", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py index 3933d59b52cd1..a9f29dab2d602 100644 --- a/examples/src/main/python/ml/simple_params_example.py +++ b/examples/src/main/python/ml/simple_params_example.py @@ -41,8 +41,8 @@ # prepare training data. # We create an RDD of LabeledPoints and convert them into a DataFrame. - # Spark DataFrames can automatically infer the schema from named tuples - # and LabeledPoint implements __reduce__ to behave like a named tuple. + # A LabeledPoint is an Object with two fields named label and features + # and Spark SQL identifies these fields and creates the schema appropriately. training = sc.parallelize([ LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), From a86b3e9b9b75f5af4fdbba22e87769058f023204 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 2 Jun 2015 19:12:08 -0700 Subject: [PATCH 149/254] [SPARK-7547] [ML] Scala Example code for ElasticNet This is scala example code for both linear and logistic regression. Python and Java versions are to be added. Author: DB Tsai Closes #6576 from dbtsai/elasticNetExample and squashes the following commits: e7ca406 [DB Tsai] fix test 6bb6d77 [DB Tsai] fix suite and remove duplicated setMaxIter 136e0dd [DB Tsai] address feedback 1ec29d4 [DB Tsai] fix style 9462f5f [DB Tsai] add example --- .../examples/ml/LinearRegressionExample.scala | 142 ++++++++++++++++ .../ml/LogisticRegressionExample.scala | 159 ++++++++++++++++++ .../classification/LogisticRegression.scala | 8 +- .../ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../spark/ml/param/shared/sharedParams.scala | 4 +- .../ml/regression/LinearRegression.scala | 2 +- .../apache/spark/ml/param/ParamsSuite.scala | 6 +- 7 files changed, 314 insertions(+), 9 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala new file mode 100644 index 0000000000000..b54466fd48bc5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.sql.DataFrame + +/** + * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LinearRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \ + * data/mllib/sample_linear_regression_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LinearRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegressionExample") { + head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LinearRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "regression", params.fracTest) + + val lir = new LinearRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + // Train the model + val startTime = System.nanoTime() + val lirModel = lir.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Print the weights and intercept for linear regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala new file mode 100644 index 0000000000000..b12f833ce94c8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.DataFrame + +/** + * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LogisticRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \ + * data/mllib/sample_libsvm_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LogisticRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + fitIntercept: Boolean = true, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LogisticRegressionExample") { + head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Boolean]("fitIntercept") + .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LogisticRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "classification", params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol("indexedLabel") + stages += labelIndexer + + val lor = new LogisticRegression() + .setFeaturesCol("features") + .setLabelCol("indexedLabel") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + stages += lor + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + val lirModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + // Print the weights and intercept for logistic regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d13109d9da4c0..f136bcee9cf2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -74,7 +74,7 @@ class LogisticRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -90,7 +90,11 @@ class LogisticRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) - /** @group setParam */ + /** + * Whether to fit an intercept term. + * Default is true. + * @group setParam + * */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1ffb5eddc36bd..8ffbcf0d8bc71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen { val params = Seq( ParamDesc[Double]("regParam", "regularization parameter (>= 0)", isValid = "ParamValidators.gtEq(0)"), - ParamDesc[Int]("maxIter", "max number of iterations (>= 0)", + ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)", isValid = "ParamValidators.gtEq(0)"), ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index ed08417bd4df8..a0c8ccdac9ad9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params { private[ml] trait HasMaxIter extends Params { /** - * Param for max number of iterations (>= 0). + * Param for maximum number of iterations (>= 0). * @group param */ - final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) + final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getMaxIter: Int = $(maxIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index fe2a71a331694..70cd8e9e87fae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -83,7 +83,7 @@ class LinearRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index f80e7749098a5..96094d7a099aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -27,7 +27,7 @@ class ParamsSuite extends SparkFunSuite { import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") - assert(maxIter.doc === "max number of iterations (>= 0)") + assert(maxIter.doc === "maximum number of iterations (>= 0)") assert(maxIter.parent === uid) assert(maxIter.toString === s"${uid}__maxIter") assert(!maxIter.isValid(-1)) @@ -36,7 +36,7 @@ class ParamsSuite extends SparkFunSuite { solver.setMaxIter(5) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 5)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)") assert(inputCol.toString === s"${uid}__inputCol") @@ -120,7 +120,7 @@ class ParamsSuite extends SparkFunSuite { intercept[NoSuchElementException](solver.getInputCol) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 100)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) From cafd5056e12a15f0ebf8015d52dfab999c4443b8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 2 Jun 2015 22:11:03 -0700 Subject: [PATCH 150/254] [SPARK-7691] [SQL] Refactor CatalystTypeConverter to use type-specific row accessors This patch significantly refactors CatalystTypeConverters to both clean up the code and enable these conversions to work with future Project Tungsten features. At a high level, I've reorganized the code so that all functions dealing with the same type are grouped together into type-specific subclasses of `CatalystTypeConveter`. In addition, I've added new methods that allow the Catalyst Row -> Scala Row conversions to access the Catalyst row's fields through type-specific `getTYPE()` methods rather than the generic `get()` / `Row.apply` methods. This refactoring is a blocker to being able to unit test new operators that I'm developing as part of Project Tungsten, since those operators may output `UnsafeRow` instances which don't support the generic `get()`. The stricter type usage of types here has uncovered some bugs in other parts of Spark SQL: - #6217: DescribeCommand is assigned wrong output attributes in SparkStrategies - #6218: DataFrame.describe() should cast all aggregates to String - #6400: Use output schema, not relation schema, for data source input conversion Spark SQL current has undefined behavior for what happens when you try to create a DataFrame from user-specified rows whose values don't match the declared schema. According to the `createDataFrame()` Scaladoc: > It is important to make sure that the structure of every [[Row]] of the provided RDD matches the provided schema. Otherwise, there will be runtime exception. Given this, it sounds like it's technically not a break of our API contract to fail-fast when the data types don't match. However, there appear to be many cases where we don't fail even though the types don't match. For example, `JavaHashingTFSuite.hasingTF` passes a column of integers values for a "label" column which is supposed to contain floats. This column isn't actually read or modified as part of query processing, so its actual concrete type doesn't seem to matter. In other cases, there could be situations where we have generic numeric aggregates that tolerate being called with different numeric types than the schema specified, but this can be okay due to numeric conversions. In the long run, we will probably want to come up with precise semantics for implicit type conversions / widening when converting Java / Scala rows to Catalyst rows. Until then, though, I think that failing fast with a ClassCastException is a reasonable behavior; this is the approach taken in this patch. Note that certain optimizations in the inbound conversion functions for primitive types mean that we'll probably preserve the old undefined behavior in a majority of cases. Author: Josh Rosen Closes #6222 from JoshRosen/catalyst-converters-refactoring and squashes the following commits: 740341b [Josh Rosen] Optimize method dispatch for primitive type conversions befc613 [Josh Rosen] Add tests to document Option-handling behavior. 5989593 [Josh Rosen] Use new SparkFunSuite base in CatalystTypeConvertersSuite 6edf7f8 [Josh Rosen] Re-add convertToScala(), since a Hive test still needs it 3f7b2d8 [Josh Rosen] Initialize converters lazily so that the attributes are resolved first 6ad0ebb [Josh Rosen] Fix JavaHashingTFSuite ClassCastException 677ff27 [Josh Rosen] Fix null handling bug; add tests. 8033d4c [Josh Rosen] Fix serialization error in UserDefinedGenerator. 85bba9d [Josh Rosen] Fix wrong input data in InMemoryColumnarQuerySuite 9c0e4e1 [Josh Rosen] Remove last use of convertToScala(). ae3278d [Josh Rosen] Throw ClassCastException errors during inbound conversions. 7ca7fcb [Josh Rosen] Comments and cleanup 1e87a45 [Josh Rosen] WIP refactoring of CatalystTypeConverters --- .../spark/ml/feature/JavaHashingTFSuite.java | 6 +- .../sql/catalyst/CatalystTypeConverters.scala | 558 ++++++++++-------- .../sql/catalyst/expressions/generators.scala | 19 +- .../CatalystTypeConvertersSuite.scala | 62 ++ .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- 5 files changed, 382 insertions(+), 265 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index da2218056307e..599e9cfd23ad4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -55,9 +55,9 @@ public void tearDown() { @Test public void hashingTF() { JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 1c0ddb5093d17..2e7b4c236d8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -18,7 +18,10 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.Date import java.util.{Map => JavaMap} +import javax.annotation.Nullable import scala.collection.mutable.HashMap @@ -34,197 +37,338 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + private def isPrimitive(dataType: DataType): Boolean = { + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case _ => false + } + } + + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { + val converter = dataType match { + case udt: UserDefinedType[_] => UDTConverter(udt) + case arrayType: ArrayType => ArrayConverter(arrayType.elementType) + case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) + case structType: StructType => StructConverter(structType) + case StringType => StringConverter + case DateType => DateConverter + case dt: DecimalType => BigDecimalConverter + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType => IntConverter + case LongType => LongConverter + case FloatType => FloatConverter + case DoubleType => DoubleConverter + case _ => IdentityConverter + } + converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] + } + /** - * Converts Scala objects to catalyst rows / types. This method is slow, and for batch - * conversion you should be using converter produced by createToCatalystConverter. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. + * Converts a Scala type to its Catalyst equivalent (and vice versa). + * + * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst. + * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala. + * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type. */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => - udt.serialize(obj) - - case (o: Option[_], _) => - o.map(convertToCatalyst(_, dataType)).orNull - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToCatalyst(_, arrayType.elementType)) - - case (jit: JavaIterable[_], arrayType: ArrayType) => { - val iter = jit.iterator - var listOfItems: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - listOfItems :+= convertToCatalyst(item, arrayType.elementType) + private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType] + extends Serializable { + + /** + * Converts a Scala type to its Catalyst equivalent while automatically handling nulls + * and Options. + */ + final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { + if (maybeScalaValue == null) { + null.asInstanceOf[CatalystType] + } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) { + val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]] + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + } else { + toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType]) } - listOfItems } - case (s: Array[_], arrayType: ArrayType) => - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + */ + final def toScala(row: Row, column: Int): ScalaOutputType = { + if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) + } + + /** + * Convert a Catalyst value to its Scala equivalent. + */ + def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType + + /** + * Converts a Scala value to its Catalyst equivalent. + * @param scalaValue the Scala value, guaranteed not to be null. + * @return the Catalyst value. + */ + protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + * This method will only be called on non-null columns. + */ + protected def toScalaImpl(row: Row, column: Int): ScalaOutputType + } - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } + private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toScalaImpl(row: Row, column: Int): Any = row(column) + } - case (jmap: JavaMap[_, _], mapType: MapType) => - val iter = jmap.entrySet.iterator - var listOfEntries: List[(Any, Any)] = List() - while (iter.hasNext) { - val entry = iter.next() - listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), - convertToCatalyst(entry.getValue, mapType.valueType)) + private case class UDTConverter( + udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column)) + } + + /** Converter for arrays, sequences, and Java iterables. */ + private case class ArrayConverter( + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + + private[this] val elementConverter = getConverterForType(elementType) + + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + scalaValue match { + case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) + case s: Seq[_] => s.map(elementConverter.toCatalyst) + case i: JavaIterable[_] => + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter.toCatalyst(item) + } + convertedIterable } - listOfEntries.toMap - - case (p: Product, structType: StructType) => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) - idx += 1 + } + + override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) } - new GenericRowWithSchema(ar, structType) + } - case (d: String, _) => - UTF8String(d) + override def toScalaImpl(row: Row, column: Int): Seq[Any] = + toScala(row(column).asInstanceOf[Seq[Any]]) + } + + private case class MapConverter( + keyType: DataType, + valueType: DataType) + extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] { - case (d: BigDecimal, _) => - Decimal(d) + private[this] val keyConverter = getConverterForType(keyType) + private[this] val valueConverter = getConverterForType(valueType) - case (d: java.math.BigDecimal, _) => - Decimal(d) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v) + } - case (d: java.sql.Date, _) => - DateUtils.fromJavaDate(d) + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + val key = keyConverter.toCatalyst(entry.getKey) + convertedMap(key) = valueConverter.toCatalyst(entry.getValue) + } + convertedMap + } - case (r: Row, structType: StructType) => - val converters = structType.fields.map { - f => (item: Any) => convertToCatalyst(item, f.dataType) + override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.map { case (k, v) => + keyConverter.toScala(k) -> valueConverter.toScala(v) + } } - convertRowWithConverters(r, structType, converters) + } - case (other, _) => - other + override def toScalaImpl(row: Row, column: Int): Map[Any, Any] = + toScala(row(column).asInstanceOf[Map[Any, Any]]) } - /** - * Creates a converter function that will convert Scala objects to the specified catalyst type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { - def extractOption(item: Any): Any = item match { - case opt: Option[_] => opt.orNull - case other => other - } + private case class StructConverter( + structType: StructType) extends CatalystTypeConverter[Any, Row, Row] { - dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item) => extractOption(item) match { - case null => null - case other => udt.serialize(other) - } + private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } - case arrayType: ArrayType => - val elementConverter = createToCatalystConverter(arrayType.elementType) - (item: Any) => { - extractOption(item) match { - case a: Array[_] => a.toSeq.map(elementConverter) - case s: Seq[_] => s.map(elementConverter) - case i: JavaIterable[_] => { - val iter = i.iterator - var convertedIterable: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - convertedIterable :+= elementConverter(item) - } - convertedIterable - } - case null => null - } + override def toCatalystImpl(scalaValue: Any): Row = scalaValue match { + case row: Row => + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toCatalyst(row(idx)) + idx += 1 } - - case mapType: MapType => - val keyConverter = createToCatalystConverter(mapType.keyType) - val valueConverter = createToCatalystConverter(mapType.valueType) - (item: Any) => { - extractOption(item) match { - case m: Map[_, _] => - m.map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - - case jmap: JavaMap[_, _] => - val iter = jmap.entrySet.iterator - val convertedMap: HashMap[Any, Any] = HashMap() - while (iter.hasNext) { - val entry = iter.next() - convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) - } - convertedMap - - case null => null - } + new GenericRowWithSchema(ar, structType) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(iter.next()) + idx += 1 } + new GenericRowWithSchema(ar, structType) + } - case structType: StructType => - val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) - (item: Any) => { - extractOption(item) match { - case r: Row => - convertRowWithConverters(r, structType, converters) - - case p: Product => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = converters(idx)(iter.next()) - idx += 1 - } - new GenericRowWithSchema(ar, structType) - - case null => - null - } + override def toScala(row: Row): Row = { + if (row == null) { + null + } else { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toScala(row, idx) + idx += 1 } - - case dateType: DateType => (item: Any) => extractOption(item) match { - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case other => other + new GenericRowWithSchema(ar, structType) } + } - case dataType: StringType => (item: Any) => extractOption(item) match { - case s: String => UTF8String(s) - case other => other - } + override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row]) + } + + private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { + case str: String => UTF8String(str) + case utf8: UTF8String => utf8 + } + override def toScala(catalystValue: Any): String = catalystValue match { + case null => null + case str: String => str + case utf8: UTF8String => utf8.toString() + } + override def toScalaImpl(row: Row, column: Int): String = row(column).toString + } + + private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { + override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue) + override def toScala(catalystValue: Any): Date = + if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) + } + + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: Decimal => d + } + override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal + override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match { + case d: JavaBigDecimal => d + case d: Decimal => d.toJavaBigDecimal + } + } + + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { + final override def toScala(catalystValue: Any): Any = catalystValue + final override def toCatalystImpl(scalaValue: T): Any = scalaValue + } + + private object BooleanConverter extends PrimitiveConverter[Boolean] { + override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) + } + + private object ByteConverter extends PrimitiveConverter[Byte] { + override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) + } + + private object ShortConverter extends PrimitiveConverter[Short] { + override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) + } + + private object IntConverter extends PrimitiveConverter[Int] { + override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) + } + + private object LongConverter extends PrimitiveConverter[Long] { + override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) + } + + private object FloatConverter extends PrimitiveConverter[Float] { + override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) + } - case _ => - (item: Any) => extractOption(item) match { - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) - case other => other + private object DoubleConverter extends PrimitiveConverter[Double] { + override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) + } + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toCatalyst(scalaValue) + } + + /** + * Creates a converter function that will convert Scala objects to the specified Catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + // Although the `else` branch here is capable of handling inbound conversion of primitives, + // we add some special-case handling for those types here. The motivation for this relates to + // Java method invocation costs: if we have rows that consist entirely of primitive columns, + // then returning the same conversion function for all of the columns means that the call site + // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in + // a measurable performance impact. Note that this optimization will be unnecessary if we + // use code generation to construct Scala Row -> Catalyst Row converters. + def convert(maybeScalaValue: Any): Any = { + if (maybeScalaValue.isInstanceOf[Option[Any]]) { + maybeScalaValue.asInstanceOf[Option[Any]].orNull + } else { + maybeScalaValue } + } + convert + } else { + getConverterForType(dataType).toCatalyst } } /** - * Converts Scala objects to catalyst rows / types. + * Converts Scala objects to Catalyst rows / types. * * Note: This should be called before do evaluation on Row * (It does not support UDT) * This is used to create an RDD or test results with correct types for Catalyst. */ def convertToCatalyst(a: Any): Any = a match { - case s: String => UTF8String(s) - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) + case s: String => StringConverter.toCatalyst(s) + case d: Date => DateConverter.toCatalyst(d) + case d: BigDecimal => BigDecimalConverter.toCatalyst(d) + case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray @@ -238,33 +382,8 @@ object CatalystTypeConverters { * This method is slow, and for batch conversion you should be using converter * produced by createToScalaConverter. */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => - udt.deserialize(d) - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToScala(_, arrayType.elementType)) - - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - - case (r: Row, s: StructType) => - convertRowToScala(r, s) - - case (d: Decimal, _: DecimalType) => - d.toJavaBigDecimal - - case (i: Int, DateType) => - DateUtils.toJavaDate(i) - - case (s: UTF8String, StringType) => - s.toString() - - case (other, _) => - other + def convertToScala(catalystValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toScala(catalystValue) } /** @@ -272,82 +391,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item: Any) => if (item == null) null else udt.deserialize(item) - - case arrayType: ArrayType => - val elementConverter = createToScalaConverter(arrayType.elementType) - (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) - - case mapType: MapType => - val keyConverter = createToScalaConverter(mapType.keyType) - val valueConverter = createToScalaConverter(mapType.valueType) - (item: Any) => if (item == null) { - null - } else { - item.asInstanceOf[Map[_, _]].map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - } - - case s: StructType => - val converters = s.fields.map(f => createToScalaConverter(f.dataType)) - (item: Any) => { - if (item == null) { - null - } else { - convertRowWithConverters(item.asInstanceOf[Row], s, converters) - } - } - - case _: DecimalType => - (item: Any) => item match { - case d: Decimal => d.toJavaBigDecimal - case other => other - } - - case DateType => - (item: Any) => item match { - case i: Int => DateUtils.toJavaDate(i) - case other => other - } - - case StringType => - (item: Any) => item match { - case s: UTF8String => s.toString() - case other => other - } - - case other => - (item: Any) => item - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - val ar = new Array[Any](r.size) - var idx = 0 - while (idx < r.size) { - ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) - idx += 1 - } - new GenericRowWithSchema(ar, schema) - } - - /** - * Converts a row by applying the provided set of converter functions. It is used for both - * toScala and toCatalyst conversions. - */ - private[sql] def convertRowWithConverters( - row: Row, - schema: StructType, - converters: Array[Any => Any]): Row = { - val ar = new Array[Any](row.size) - var idx = 0 - while (idx < row.size) { - ar(idx) = converters(idx)(row(idx)) - idx += 1 - } - new GenericRowWithSchema(ar, schema) + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + getConverterForType(dataType).toScala } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 634138010fd21..b6191eafba71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -71,12 +71,23 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator { + @transient private[this] var inputRow: InterpretedProjection = _ + @transient private[this] var convertToScala: (Row) => Row = _ + + private def initializeConverters(): Unit = { + inputRow = new InterpretedProjection(children) + convertToScala = { + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + CatalystTypeConverters.createToScalaConverter(inputSchema) + }.asInstanceOf[(Row => Row)] + } + override def eval(input: Row): TraversableOnce[Row] = { - // TODO(davies): improve this + if (inputRow == null) { + initializeConverters() + } // Convert the objects into Scala Type before calling function, we need schema to support UDT - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - val inputRow = new InterpretedProjection(children) - function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) + function(convertToScala(inputRow(input))) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala new file mode 100644 index 0000000000000..df0f04563edcf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class CatalystTypeConvertersSuite extends SparkFunSuite { + + private val simpleTypes: Seq[DataType] = Seq( + StringType, + DateType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + test("null handling in rows") { + val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) + val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) + val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) + + val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) + assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) + } + + test("null handling for individual values") { + for (dataType <- simpleTypes) { + assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) + } + } + + test("option handling in convertToCatalyst") { + // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with + // createToCatalystConverter but it may not actually matter as this is only called internally + // in a handful of places where we don't expect to receive Options. + assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) + } + + test("option handling in createToCatalystConverter") { + assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 56591d9dba29e..055453e688e73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -173,7 +173,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { new Timestamp(i), (1 to i).toSeq, (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, - Row((i - 0.25).toFloat, (1 to i).toSeq)) + Row((i - 0.25).toFloat, Seq(true, false, null))) } createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. From 07c16cb5ba9cb0bfe34e8c0efbf06540a22d4e4e Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 2 Jun 2015 22:56:56 -0700 Subject: [PATCH 151/254] [SPARK-8053] [MLLIB] renamed scalingVector to scalingVec I searched the Spark codebase for all occurrences of "scalingVector" CC: mengxr Author: Joseph K. Bradley Closes #6596 from jkbradley/scalingVec-rename and squashes the following commits: d3812f8 [Joseph K. Bradley] renamed scalingVector to scalingVec --- .../spark/ml/feature/ElementwiseProduct.scala | 2 +- .../spark/mllib/feature/ElementwiseProduct.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 3ae1833390152..1e758cb775de7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -41,7 +41,7 @@ class ElementwiseProduct(override val uid: String) * the vector to multiply with input vectors * @group param */ - val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product") + val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ def setScalingVec(value: Vector): this.type = set(scalingVec, value) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index b0985baf9b278..d67fe6c3ee4f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._ * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. - * @param scalingVector The values used to scale the reference vector's individual components. + * @param scalingVec The values used to scale the reference vector's individual components. */ @Experimental -class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { +class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. @@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { * @return transformed vector. */ override def transform(vector: Vector): Vector = { - require(vector.size == scalingVector.size, - s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}") + require(vector.size == scalingVec.size, + s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") vector match { case dv: DenseVector => val values: Array[Double] = dv.values.clone() - val dim = scalingVector.size + val dim = scalingVec.size var i = 0 while (i < dim) { - values(i) *= scalingVector(i) + values(i) *= scalingVec(i) i += 1 } Vectors.dense(values) @@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { val dim = values.length var i = 0 while (i < dim) { - values(i) *= scalingVector(indices(i)) + values(i) *= scalingVec(indices(i)) i += 1 } Vectors.sparse(size, indices, values) From ccaa823290cbe859cd224ac0f7071dfd0218b669 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Tue, 2 Jun 2015 22:59:48 -0700 Subject: [PATCH 152/254] [MINOR] make the launcher project name consistent with others I found this by chance while building spark and think it is better to keep its name consistent with other sub-projects (Spark Project *). I am not gonna file JIRA as it is a pretty small issue. Author: WangTaoTheTonic Closes #6603 from WangTaoTheTonic/projName and squashes the following commits: 994b3ba [WangTaoTheTonic] make the project name consistent --- launcher/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/pom.xml b/launcher/pom.xml index ebfa7685eaa18..cc177d23dff77 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-launcher_2.10 jar - Spark Launcher Project + Spark Project Launcher http://spark.apache.org/ launcher From 43adbd56114ba80039a23909b0a30d393eaacc62 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 2 Jun 2015 23:15:38 -0700 Subject: [PATCH 153/254] [SPARK-8043] [MLLIB] [DOC] update NaiveBayes and SVM examples in doc jira: https://issues.apache.org/jira/browse/SPARK-8043 I found some issues during testing the save/load examples in markdown Documents, as a part of 1.4 QA plan Author: Yuhao Yang Closes #6584 from hhbyyh/naiveDocExample and squashes the following commits: a01a206 [Yuhao Yang] fix for Gaussian mixture 2fb8b96 [Yuhao Yang] update NaiveBayes and SVM examples in doc --- docs/mllib-clustering.md | 6 +++--- docs/mllib-linear-methods.md | 24 ++++++++++-------------- docs/mllib-naive-bayes.md | 2 +- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index dac22f736e8cb..1b088969ddc25 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -249,11 +249,11 @@ public class GaussianMixtureExample { GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); // Save and load GaussianMixtureModel - gmm.save(sc, "myGMMModel") - GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + gmm.save(sc.sc(), "myGMMModel"); + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel"); // Output the parameters of the mixture model for(int j=0; j
    -The following example shows how to load a sample dataset, build Logistic Regression model, +The following example shows how to load a sample dataset, build SVM model, and make predictions with the resulting model to compute the training error. -Note that the Python API does not yet support model save/load but will in the future. - {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithSGD +from pyspark.mllib.classification import SVMWithSGD, SVMModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model -model = LogisticRegressionWithSGD.train(parsedData) +model = SVMWithSGD.train(parsedData, iterations=100) # Evaluating the model on training data labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index acdcc371487f8..bf6d124fd5d8d 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -53,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial") +val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() From 452eb82dd722e5dfd00ee47bb8b6353933b0016e Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 2 Jun 2015 23:24:47 -0700 Subject: [PATCH 154/254] [SPARK-8032] [PYSPARK] Make version checking for NumPy in MLlib more robust The current checking does version `1.x' is less than `1.4' this will fail if x has greater than 1 digit, since x > 4, however `1.x` < `1.4` It fails in my system since I have version `1.10` :P Author: MechCoder Closes #6579 from MechCoder/np_ver and squashes the following commits: 15430f8 [MechCoder] fix syntax error 893fb7e [MechCoder] remove equal to e35f0d4 [MechCoder] minor e89376c [MechCoder] Better checking 22703dd [MechCoder] [SPARK-8032] Make version checking for NumPy in MLlib more robust --- python/pyspark/mllib/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index b11aed2c3afda..acba3a717d21a 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -23,7 +23,9 @@ # MLlib currently needs NumPy 1.4+, so complain if lower import numpy -if numpy.version.version < '1.4': + +ver = [int(x) for x in numpy.version.version.split('.')[:2]] +if ver < [1, 4]: raise Exception("MLlib requires NumPy 1.4+") __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', From ce320cb2dbf28825f80795ce569735888f98d6e8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Jun 2015 00:23:34 -0700 Subject: [PATCH 155/254] [SPARK-8060] Improve DataFrame Python test coverage and documentation. Author: Reynold Xin Closes #6601 from rxin/python-read-write-test-and-doc and squashes the following commits: baa8ad5 [Reynold Xin] Code review feedback. f081d47 [Reynold Xin] More documentation updates. c9902fa [Reynold Xin] [SPARK-8060] Improve DataFrame Python reader/writer interface doc and testing. --- .rat-excludes | 1 + python/pyspark/sql/__init__.py | 13 +- python/pyspark/sql/context.py | 89 +++---- python/pyspark/sql/dataframe.py | 82 +++---- python/pyspark/sql/readwriter.py | 217 ++++++++---------- python/pyspark/sql/tests.py | 2 + .../sql/parquet_partitioned/_SUCCESS | 0 .../sql/parquet_partitioned/_common_metadata | Bin 0 -> 210 bytes .../sql/parquet_partitioned/_metadata | Bin 0 -> 743 bytes .../day=1/.part-r-00008.gz.parquet.crc | Bin 0 -> 12 bytes .../month=9/day=1/part-r-00008.gz.parquet | Bin 0 -> 322 bytes .../day=25/.part-r-00002.gz.parquet.crc | Bin 0 -> 12 bytes .../day=25/.part-r-00004.gz.parquet.crc | Bin 0 -> 12 bytes .../month=10/day=25/part-r-00002.gz.parquet | Bin 0 -> 343 bytes .../month=10/day=25/part-r-00004.gz.parquet | Bin 0 -> 343 bytes .../day=26/.part-r-00005.gz.parquet.crc | Bin 0 -> 12 bytes .../month=10/day=26/part-r-00005.gz.parquet | Bin 0 -> 333 bytes .../day=1/.part-r-00007.gz.parquet.crc | Bin 0 -> 12 bytes .../month=9/day=1/part-r-00007.gz.parquet | Bin 0 -> 343 bytes python/test_support/sql/people.json | 3 + 20 files changed, 180 insertions(+), 227 deletions(-) create mode 100644 python/test_support/sql/parquet_partitioned/_SUCCESS create mode 100644 python/test_support/sql/parquet_partitioned/_common_metadata create mode 100644 python/test_support/sql/parquet_partitioned/_metadata create mode 100644 python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet create mode 100644 python/test_support/sql/people.json diff --git a/.rat-excludes b/.rat-excludes index c0f81b57fe09d..8f2722cbd001f 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,3 +82,4 @@ local-1426633911242/* local-1430917381534/* DESCRIPTION NAMESPACE +test_support/* diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 726d288d97b2e..ad9c891ba1c04 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -45,11 +45,20 @@ def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + def deco(f): - f.__doc__ = f.__doc__.rstrip() + "\n\n.. versionadded:: %s" % version + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) return f return deco + from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column @@ -58,7 +67,9 @@ def deco(f): from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from pyspark.sql.window import Window, WindowSpec + __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 22f6257dfe02d..9fdf43c3e6eb5 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -124,7 +124,10 @@ def getConf(self, key, defaultValue): @property @since("1.3.1") def udf(self): - """Returns a :class:`UDFRegistration` for UDF registration.""" + """Returns a :class:`UDFRegistration` for UDF registration. + + :return: :class:`UDFRegistration` + """ return UDFRegistration(self) @since(1.4) @@ -138,7 +141,7 @@ def range(self, start, end, step=1, numPartitions=None): :param end: the end value (exclusive) :param step: the incremental step (default: 1) :param numPartitions: the number of partitions of the DataFrame - :return: A new DataFrame + :return: :class:`DataFrame` >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] @@ -195,8 +198,8 @@ def _inferSchema(self, rdd, samplingRatio=None): raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") + warnings.warn("Using RDD of dict to inferSchema is deprecated. " + "Use pyspark.sql.Row instead") if samplingRatio is None: schema = _infer_schema(first) @@ -219,7 +222,7 @@ def inferSchema(self, rdd, samplingRatio=None): """ .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead") + warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -262,6 +265,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :class:`list`, or :class:`pandas.DataFrame`. :param schema: a :class:`StructType` or list of column names. default None. :param samplingRatio: the sample ratio of rows used for inferring + :return: :class:`DataFrame` >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() @@ -359,18 +363,15 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") - @since(1.0) def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. + + >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ + warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") gateway = self._sc._gateway jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) for i in range(0, len(paths)): @@ -378,39 +379,15 @@ def parquetFile(self, *paths): jdf = self._ssql_ctx.parquetFile(jpaths) return DataFrame(jdf, self) - @since(1.0) def jsonFile(self, path, schema=None, samplingRatio=1.0): """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.jsonFile(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.jsonFile(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes + [('age', 'bigint'), ('name', 'string')] """ + warnings.warn("jsonFile is deprecated. Use read.json() instead.") if schema is None: df = self._ssql_ctx.jsonFile(path, samplingRatio) else: @@ -462,21 +439,16 @@ def func(iterator): df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return DataFrame(df, self) - @since(1.3) def load(self, path=None, source=None, schema=None, **options): """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Optionally, a schema can be provided as the schema of the returned DataFrame. + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. """ + warnings.warn("load is deprecated. Use read.load() instead.") return self.read.load(path, source, schema, **options) @since(1.3) - def createExternalTable(self, tableName, path=None, source=None, - schema=None, **options): + def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): """Creates an external table based on the dataset in a data source. It returns the DataFrame associated with the external table. @@ -487,6 +459,8 @@ def createExternalTable(self, tableName, path=None, source=None, Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. + + :return: :class:`DataFrame` """ if path is not None: options["path"] = path @@ -508,6 +482,8 @@ def createExternalTable(self, tableName, path=None, source=None, def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -519,6 +495,8 @@ def sql(self, sqlQuery): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -536,6 +514,9 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + :param dbName: string, name of the database to use. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() @@ -550,7 +531,8 @@ def tables(self, dbName=None): def tableNames(self, dbName=None): """Returns a list of names of tables in the database ``dbName``. - If ``dbName`` is not specified, the current database will be used. + :param dbName: string, name of the database to use. Default to the current database. + :return: list of table names, in string >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() @@ -585,8 +567,7 @@ def read(self): Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. - >>> sqlContext.read - + :return: :class:`DataFrameReader` """ return DataFrameReader(self) @@ -644,10 +625,14 @@ def register(self, name, f, returnType=StringType()): def _test(): + import os import doctest from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a82b6b87c413e..7673153abe0e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -44,7 +44,7 @@ class DataFrame(object): A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: - people = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. @@ -56,8 +56,8 @@ class DataFrame(object): A more concrete example:: # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") + department = sqlContext.read.parquet("...") people.filter(people.age > 30).join(department, people.deptId == department.id)) \ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) @@ -120,21 +120,12 @@ def toJSON(self, use_unicode=True): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - @since(1.3) def saveAsParquetFile(self, path): """Saves the contents as a Parquet file, preserving the schema. - Files that are written out using this method can be read back in as - a :class:`DataFrame` using :func:`SQLContext.parquetFile`. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. """ + warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") self._jdf.saveAsParquetFile(path) @since(1.3) @@ -151,69 +142,45 @@ def registerTempTable(self, name): """ self._jdf.registerTempTable(name) - @since(1.3) def registerAsTable(self, name): - """DEPRECATED: use :func:`registerTempTable` instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + """ + .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. + """ + warnings.warn("Use registerTempTable instead of registerAsTable.") self.registerTempTable(name) - @since(1.3) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this :class:`DataFrame` into the specified table. - Optionally overwriting any existing data. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. """ + warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") self.write.insertInto(tableName, overwrite) - @since(1.3) def saveAsTable(self, tableName, source=None, mode="error", **options): """Saves the contents of this :class:`DataFrame` to a data source as a table. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the saveAsTable operation when - table already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. """ + warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") self.write.saveAsTable(tableName, source, mode, **options) @since(1.3) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. """ + warnings.warn("insertInto is deprecated. Use write.save() instead.") return self.write.save(path, source, mode, **options) @property @since(1.4) def write(self): """ - Interface for saving the content of the :class:`DataFrame` out - into external storage. - - :return :class:`DataFrameWriter` + Interface for saving the content of the :class:`DataFrame` out into external storage. - .. note:: Experimental - - >>> df.write - + :return: :class:`DataFrameWriter` """ return DataFrameWriter(self) @@ -636,6 +603,9 @@ def describe(self, *cols): This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical columns. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + >>> df.describe().show() +-------+---+ |summary|age| @@ -653,9 +623,11 @@ def describe(self, *cols): @ignore_unicode_prefix @since(1.3) def head(self, n=None): - """ - Returns the first ``n`` rows as a list of :class:`Row`, - or the first :class:`Row` if ``n`` is ``None.`` + """Returns the first ``n`` rows. + + :param n: int, default 1. Number of rows to return. + :return: If n is greater than 1, return a list of :class:`Row`. + If n is 1, return a single Row. >>> df.head() Row(age=2, name=u'Alice') @@ -1170,8 +1142,8 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. - This function is meant for exploratory data analysis, as we make no guarantee about the - backward compatibility of the schema of the resulting DataFrame. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index d17d87419fe3d..f036644acc961 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -45,18 +45,24 @@ def _df(self, jdf): @since(1.4) def format(self, source): - """ - Specifies the input data source format. + """Specifies the input data source format. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + """ self._jreader = self._jreader.format(source) return self @since(1.4) def schema(self, schema): - """ - Specifies the input schema. Some data sources (e.g. JSON) can - infer the input schema automatically from data. By specifying - the schema here, the underlying data source can skip the schema + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading. :param schema: a StructType object @@ -69,8 +75,7 @@ def schema(self, schema): @since(1.4) def options(self, **options): - """ - Adds input options for the underlying data source. + """Adds input options for the underlying data source. """ for k in options: self._jreader = self._jreader.option(k, options[k]) @@ -84,6 +89,10 @@ def load(self, path=None, format=None, schema=None, **options): :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options + + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ if format is not None: self.format(format) @@ -107,31 +116,10 @@ def json(self, path, schema=None): :param path: string, path to the JSON dataset. :param schema: an optional :class:`StructType` for the input schema. - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.read.json(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.read.json(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> df = sqlContext.read.json('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + """ if schema is not None: self.schema(schema) @@ -141,10 +129,12 @@ def json(self, path, schema=None): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. - >>> sqlContext.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlContext.read.table("table1") - >>> sorted(df.collect()) == sorted(df2.collect()) - True + :param tableName: string, name of the table. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.registerTempTable('tmpTable') + >>> sqlContext.read.table('tmpTable').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.table(tableName)) @@ -152,13 +142,9 @@ def table(self, tableName): def parquet(self, *path): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.read.parquet(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path))) @@ -221,30 +207,34 @@ def __init__(self, df): @since(1.4) def mode(self, saveMode): - """ - Specifies the behavior when data or table already exists. Options include: + """Specifies the behavior when data or table already exists. + + Options include: * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. * `error`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self._jwrite = self._jwrite.mode(saveMode) return self @since(1.4) def format(self, source): - """ - Specifies the underlying output data source. Built-in options include - "parquet", "json", etc. + """Specifies the underlying output data source. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) """ self._jwrite = self._jwrite.format(source) return self @since(1.4) def options(self, **options): - """ - Adds output options for the underlying data source. + """Adds output options for the underlying data source. """ for k in options: self._jwrite = self._jwrite.option(k, options[k]) @@ -252,12 +242,14 @@ def options(self, **options): @since(1.4) def partitionBy(self, *cols): - """ - Partitions the output by the given columns on the file system. + """Partitions the output by the given columns on the file system. + If specified, the output is laid out on the file system similar to Hive's partitioning scheme. :param cols: name of columns + + >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): cols = cols[0] @@ -266,25 +258,23 @@ def partitionBy(self, *cols): @since(1.4) def save(self, path=None, format=None, mode="error", **options): - """ - Saves the contents of the :class:`DataFrame` to a data source. + """Saves the contents of the :class:`DataFrame` to a data source. The data source is specified by the ``format`` and a set of ``options``. If ``format`` is not specified, the default data source configured by ``spark.sql.sources.default`` will be used. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. - :param path: the path in a Hadoop supported file system :param format: the format used to save - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. :param options: all other string options + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode).options(**options) if format is not None: @@ -296,8 +286,8 @@ def save(self, path=None, format=None, mode="error", **options): @since(1.4) def insertInto(self, tableName, overwrite=False): - """ - Inserts the content of the :class:`DataFrame` to the specified table. + """Inserts the content of the :class:`DataFrame` to the specified table. + It requires that the schema of the class:`DataFrame` is the same as the schema of the table. @@ -307,8 +297,7 @@ def insertInto(self, tableName, overwrite=False): @since(1.4) def saveAsTable(self, name, format=None, mode="error", **options): - """ - Saves the content of the :class:`DataFrame` as the specified table. + """Saves the content of the :class:`DataFrame` as the specified table. In the case the table already exists, behavior of this function depends on the save mode, specified by the `mode` function (default to throwing an exception). @@ -328,67 +317,58 @@ def saveAsTable(self, name, format=None, mode="error", **options): self.mode(mode).options(**options) if format is not None: self.format(format) - return self._jwrite.saveAsTable(name) + self._jwrite.saveAsTable(name) @since(1.4) def json(self, path, mode="error"): - """ - Saves the content of the :class:`DataFrame` in JSON format at the - specified path. + """Saves the content of the :class:`DataFrame` in JSON format at the specified path. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. - :param path: the path in any Hadoop supported file system - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ - return self._jwrite.mode(mode).json(path) + self._jwrite.mode(mode).json(path) @since(1.4) def parquet(self, path, mode="error"): - """ - Saves the content of the :class:`DataFrame` in Parquet format at the - specified path. + """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. - :param path: the path in any Hadoop supported file system - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - return self._jwrite.mode(mode).parquet(path) + self._jwrite.mode(mode).parquet(path) @since(1.4) def jdbc(self, url, table, mode="error", properties={}): - """ - Saves the content of the :class:`DataFrame` to a external database table - via JDBC. - - In the case the table already exists in the external database, - behavior of this function depends on the save mode, specified by the `mode` - function (default to throwing an exception). There are four modes: + """Saves the content of the :class:`DataFrame` to a external database table via JDBC. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Don't create too many partitions in parallel on a large cluster;\ + otherwise Spark might crash your external database systems. - :param url: a JDBC URL of the form `jdbc:subprotocol:subname` + :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` :param table: Name of the table in the external database. - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. :param properties: JDBC database connection arguments, a list of - arbitrary string tag/value. Normally at least a - "user" and "password" property should be included. + arbitrary string tag/value. Normally at least a + "user" and "password" property should be included. """ jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: @@ -398,24 +378,23 @@ def jdbc(self, url, table, mode="error", properties={}): def _test(): import doctest + import os + import tempfile from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.readwriter + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') + + globs['tempfile'] = tempfile + globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings + globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 76384d31f1bf4..6e498f0af0af5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -753,8 +753,10 @@ def setUpClass(cls): try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: + cls.tearDownClass() raise unittest.SkipTest("Hive is not available") except TypeError: + cls.tearDownClass() raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) _scala_HiveContext =\ diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata new file mode 100644 index 0000000000000000000000000000000000000000..7ef2320651dee5f2a4d588ecf8d0b281a2bc3018 GIT binary patch literal 210 zcmYk1!3u&v5QYcw=+#i_AOju(TauuIw{9J!W6@#L&7^f#ch@4s*Xz03qM*|Z%=dr8 zpKo@l?}W+LRZ<$?0pE+Az!kJ%F~9^uFPsH)sVYKST3i^>Emc>dJ5KD<^~?|@@1$Xd zmekN-KcIQE3^UY5^@YI%&o$$v#_TZQTWe3Bk^F(Rs4OUY&gqF;!bVwwKPhIzI37m` brr(!~MnyNKbS*`ck~LYXVg*kC$ZeY!e^o%_ literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata new file mode 100644 index 0000000000000000000000000000000000000000..78a1ca7d38279147e00a826e627f7514d988bd8e GIT binary patch literal 743 zcmb7?&riZI6vyjCA%^RgE^;7EFfke3h76U+ftwc+!pZpKP`3&TWgA`5Oyuq#;D6^o zW7!Xai3#`)eXs57`{ea~hy9VQD!Or7;$bLM1*p}A0!smz(FOq8iT~h>;mEB2-`{-EoU3j+6jrYuY)zEJn-EKp==XmwCF#y_WraHO@felu$%|` z(K_3`IXh{d_L=r}G$4ZdFmoBn%lg_3s`$k}26efUv-!gz5!`pDu$%|Kx;hW}7?X&& z6N+Ow_$iL(tWW^v;TxV&K|CS|yk8=bL=<&VEcn6|$UrYXWnPTB4@g~45h?>03XD` A@&Et; literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..e93f42ed6f35018df1ae76c9746b26bd299c0339 GIT binary patch literal 12 TcmYc;N@ieSU}9KgSR(=e5vKy4 literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet new file mode 100644 index 0000000000000000000000000000000000000000..461c382937ecdf13db1003bd2bf8c9a79c510d3e GIT binary patch literal 322 zcmWG=3^EjD5S0?O(-CC?GT1~pWF(j!b27n%7y}Tga9#tj@mb}E=R8}Ei*MIrC7--x>^b35TF#8(m_&~nU@Y! zm{*#UlbDnPQ~}hQs-pxmRLQEkwl=nwK|&g8rEYGKLRo52ab|v=f}x(7owc~qGsSn?8cQ&kxl#F!*yBxTe%WJGx+ zP1zVYBq1`QEMiPz1!7Ye)i`Y6w!s--Yk|^C43aVun)yZPdWi*z$r-77#RZ8)*?Pr= zIeI`wVQFfKUQvFzUT$hhVoG93qC`}+Qb}b&s*;sbaY<2Wa*2|TQd(wePD-(oRdlry z$VEUYFr|Z}Ff%V5s4%ZICnqr}2dDz5HC0CmW~h=?b!}~IErWzK)JomlB89TlqT4=Dm5m#e6wK014Xam4qu!d8N{n7iq=X0 zb^h^Qf1OXk&E@BCXbc2#aBV9oHG%*Q#?Z5KmhmwFF2p|eCytK>__PNc{No_og>K=# zSrg}?YwN_m*4PkW-1E*!d)FUmof*P@{xTZ=z(~N7DFwMN%hUmKBBqXI) zRjf%s)+rZBNy58^>@G6ao`QeDG~bwDUJ1cg!X(Tn56ItA5;kpn-vaO8xAG`cqbIJ) UROX`@J)_4eJ^_{mz{0%r8w^idvH$=8 literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..ae94a15d08c816a6fb60117395462aa640da5026 GIT binary patch literal 12 TcmYc;N@ieSU}AV;?~?@p63+t^ literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet new file mode 100644 index 0000000000000000000000000000000000000000..6cb8538aa890475f9128c8025e2e9b0f72b1b785 GIT binary patch literal 333 zcmZut%SyvQ6di}O1}V6jFoVICffR%|SVR}Wjjr4X#ib%-noMhO^5|ruNXTAs=gv>? zC;T(PTdm;2vpI)*&V6vFrOxdZ`77WuvSx<%7tTm8rCnUbWmlR*FZwwx&re5BWS( zI<0wh-SX8nV0}~gCzurr2o{aja;6~xtt#ZdLwVG8-A#w+&U)p3ZbtXY)LB`KCgNBe NnB)+B!ULx8$S(}ES498- literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..58d9bb5fc5883f0b65b637a8deea2672b9fee8a2 GIT binary patch literal 12 TcmYc;N@ieSU}C7hE>;2n5}^Yd literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet new file mode 100644 index 0000000000000000000000000000000000000000..9b00805481e7b311763d48ff965966b4d09daa99 GIT binary patch literal 343 zcmYjN%SyvQ6ulWjn?*Mw6EYaE45TQ;!6Lc{g1Asz2A7JEX)>*$d37>TB;*6cow)In z`~?5Ok8oNwcsA$2IrpB+4bQKq7%;_`K1Ny$u;n_#kSm$S%U;-^vHN1JNh6*`Q8Z76 zug0@?@&54%+h==UTiU>g_*bSZON9~Ok%t_!;JNSsY(!k*PAnIX$ngLy^5bCBMs{Vt z858TYZ|lXTR@(@O>+F|u!Fa{vd%^08%O$H<8Pj6b2*qUi$a0~0!WDOJTB@EZK?7PV z*~E(abe@VVscCTA()C5!+K~S*m=+5iESfCivrH%SsPO6EQW~^fch`Zl^ILh4%khJd Vby^nVDLY|@GCl&s00{L Date: Wed, 3 Jun 2015 00:47:52 -0700 Subject: [PATCH 156/254] [SPARK-7562][SPARK-6444][SQL] Improve error reporting for expression data type mismatch It seems hard to find a common pattern of checking types in `Expression`. Sometimes we know what input types we need(like `And`, we know we need two booleans), sometimes we just have some rules(like `Add`, we need 2 numeric types which are equal). So I defined a general interface `checkInputDataTypes` in `Expression` which returns a `TypeCheckResult`. `TypeCheckResult` can tell whether this expression passes the type checking or what the type mismatch is. This PR mainly works on apply input types checking for arithmetic and predicate expressions. TODO: apply type checking interface to more expressions. Author: Wenchen Fan Closes #6405 from cloud-fan/6444 and squashes the following commits: b5ff31b [Wenchen Fan] address comments b917275 [Wenchen Fan] rebase 39929d9 [Wenchen Fan] add todo 0808fd2 [Wenchen Fan] make constrcutor of TypeCheckResult private 3bee157 [Wenchen Fan] and decimal type coercion rule for binary comparison 8883025 [Wenchen Fan] apply type check interface to CaseWhen cffb67c [Wenchen Fan] to have resolved call the data type check function 6eaadff [Wenchen Fan] add equal type constraint to EqualTo 3affbd8 [Wenchen Fan] more fixes 654d46a [Wenchen Fan] improve tests e0a3628 [Wenchen Fan] improve error message 1524ff6 [Wenchen Fan] fix style 69ca3fe [Wenchen Fan] add error message and tests c71d02c [Wenchen Fan] fix hive tests 6491721 [Wenchen Fan] use value class TypeCheckResult 7ae76b9 [Wenchen Fan] address comments cb77e4f [Wenchen Fan] Improve error reporting for expression data type mismatch --- .../org/apache/spark/SparkFunSuite.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 +- .../catalyst/analysis/HiveTypeCoercion.scala | 132 ++++---- .../catalyst/analysis/TypeCheckResult.scala | 45 +++ .../sql/catalyst/expressions/Expression.scala | 28 +- .../sql/catalyst/expressions/arithmetic.scala | 308 +++++++----------- .../expressions/mathfuncs/binary.scala | 17 +- .../sql/catalyst/expressions/predicates.scala | 226 ++++++------- .../sql/catalyst/optimizer/Optimizer.scala | 4 + .../spark/sql/catalyst/util/DateUtils.scala | 2 +- .../spark/sql/catalyst/util/TypeUtils.scala | 56 ++++ .../org/apache/spark/sql/types/DataType.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 6 +- .../analysis/HiveTypeCoercionSuite.scala | 15 +- .../ExpressionTypeCheckingSuite.scala | 143 ++++++++ .../apache/spark/sql/json/InferSchema.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- 17 files changed, 583 insertions(+), 421 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 8cb344332668f..9be9db01c7de9 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -30,8 +30,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { * Log the suite name and the test name before and after each test. * * Subclasses should never override this method. If they wish to run - * custom code before and after each test, they should should mix in - * the {{org.scalatest.BeforeAndAfter}} trait instead. + * custom code before and after each test, they should mix in the + * {{org.scalatest.BeforeAndAfter}} trait instead. */ final protected override def withFixture(test: NoArgTest): Outcome = { val testName = test.text diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 193dc6b6546b5..c0695ae369421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,15 +62,17 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } + case c: Cast if !c.resolved => failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case b: BinaryExpression if !b.resolved => - failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => failAnalysis( s"Could not resolve window function '$name'. " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index edcc918bfe921..b064600e94fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -41,7 +41,7 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -57,6 +57,17 @@ object HiveTypeCoercion { case _ => None } + + /** + * Find the tightest common type of a set of types by continuously applying + * `findTightestCommonTypeOfTwo` on these types. + */ + private def findTightestCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonTypeOfTwo(d, c) + }) + } } /** @@ -180,7 +191,7 @@ trait HiveTypeCoercion { case (l, r) if l.dataType != r.dataType => logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonType(l.dataType, r.dataType).map { widestType => + findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() val newRight = @@ -217,7 +228,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType => val newLeft = if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) val newRight = @@ -441,21 +452,18 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + // When we compare 2 decimal types with different precisions, cast them to the smallest + // common precision. + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = DecimalType(max(p1, p2), max(s1, s2)) + b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) + case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2) + if e2.dataType == DecimalType.Unlimited => + b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2)) + case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _)) + if e1.dataType == DecimalType.Unlimited => + b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles @@ -570,7 +578,7 @@ trait HiveTypeCoercion { case a @ CreateArray(children) if !a.resolved => val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonType(a, b).getOrElse(StringType)) + (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) CreateArray( children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) @@ -599,14 +607,9 @@ trait HiveTypeCoercion { // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => - val dt: Option[DataType] = Some(NullType) val types = es.map(_.dataType) - val rt = types.foldLeft(dt)((r, c) => r match { - case None => None - case Some(d) => findTightestCommonType(d, c) - }) - rt match { - case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + findTightestCommonType(types) match { + case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") } @@ -619,17 +622,13 @@ trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + // Skip nodes who has not been resolved yet, + // as this is an extra rule which should be applied at last. + case e if !e.resolved => e // Decimal and Double remain the same - case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - - case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => - Divide(l, Cast(r, DecimalType.Unlimited)) - case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => - Divide(Cast(l, DecimalType.Unlimited), r) + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } @@ -642,42 +641,33 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual => - logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") - val commonType = cw.valueTypes.reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = cw.branches.sliding(2, 2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case s => s - }.reduce(_ ++ _) - cw match { - case _: CaseWhen => - CaseWhen(transformedBranches) - case CaseKeyWhen(key, _) => - CaseKeyWhen(key, transformedBranches) - } - - case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved => - val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = ckw.branches.sliding(2, 2).map { - case Seq(when, then) if when.dataType != commonType => - Seq(Cast(when, commonType), then) - case s => s - }.reduce(_ ++ _) - val transformedKey = if (ckw.key.dataType != commonType) { - Cast(ckw.key, commonType) - } else { - ckw.key - } - CaseKeyWhen(transformedKey, transformedBranches) + case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => + logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + val maybeCommonType = findTightestCommonType(c.valueTypes) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, value) if value.dataType != commonType => + Seq(when, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) + case other => other + }.reduce(_ ++ _) + c match { + case _: CaseWhen => CaseWhen(castedBranches) + case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) + } + }.getOrElse(c) + + case c: CaseKeyWhen if c.childrenResolved && !c.resolved => + val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType)) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, then) if when.dataType != commonType => + Seq(Cast(when, commonType), then) + case other => other + }.reduce(_ ++ _) + CaseKeyWhen(Cast(c.key, commonType), castedBranches) + }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala new file mode 100644 index 0000000000000..79c3528a522d3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +/** + * Represents the result of `Expression.checkInputDataTypes`. + * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true. + */ +trait TypeCheckResult { + def isFailure: Boolean = !isSuccess + def isSuccess: Boolean +} + +object TypeCheckResult { + + /** + * Represents the successful result of `Expression.checkInputDataTypes`. + */ + object TypeCheckSuccess extends TypeCheckResult { + def isSuccess: Boolean = true + } + + /** + * Represents the failing result of `Expression.checkInputDataTypes`, + * with a error message to show the reason of failure. + */ + case class TypeCheckFailure(message: String) extends TypeCheckResult { + def isSuccess: Boolean = false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index adc6505d69cdf..3cf851aec15ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -53,11 +53,12 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it still contains any unresolved placeholders. Implementations of expressions - * should override this if the resolution of this type of expression involves more than just - * the resolution of its children. + * and input data types checking passed, and `false` if it still contains any unresolved + * placeholders or has data types mismatch. + * Implementations of expressions should override this if the resolution of this type of + * expression involves more than just the resolution of its children and type checking. */ - lazy val resolved: Boolean = childrenResolved + lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** * Returns the [[DataType]] of the result of evaluating this expression. It is @@ -94,12 +95,21 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + + /** + * Checks the input data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `childrenResolved == true` + * TODO: we should remove the default implementation and implement it for all + * expressions with proper error message. + */ + def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => - def symbol: String + def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") override def foldable: Boolean = left.foldable && right.foldable @@ -133,7 +143,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { * so that the proper type conversions can be performed in the analyzer. */ trait ExpectsInputTypes { + self: Expression => def expectedChildTypes: Seq[DataType] + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f2299d5db6e9f..2ac53f8f6613f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,72 +17,89 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -case class UnaryMinus(child: Expression) extends UnaryExpression { +abstract class UnaryArithmetic extends UnaryExpression { + self: Product => - override def dataType: DataType = child.dataType override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable - override def toString: String = s"-$child" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override def dataType: DataType = child.dataType override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - numeric.negate(evalE) + evalInternal(evalE) } } + + protected def evalInternal(evalE: Any): Any = + sys.error(s"UnaryArithmetics must override either eval or evalInternal") } -case class Sqrt(child: Expression) extends UnaryExpression { +case class UnaryMinus(child: Expression) extends UnaryArithmetic { + override def toString: String = s"-$child" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "operator -") + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) +} + +case class Sqrt(child: Expression) extends UnaryArithmetic { override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"SQRT($child)" - lazy val numeric = child.dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support non-negative numeric operations") - } + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } + private lazy val numeric = TypeUtils.getNumeric(child.dataType) + + protected override def evalInternal(evalE: Any) = { + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.sqrt(value) } } +/** + * A function that get the absolute value of the numeric value. + */ +case class Abs(child: Expression) extends UnaryArithmetic { + override def toString: String = s"Abs($child)" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function abs") + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.abs(evalE) +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) } - left.dataType } + protected def checkTypesInternal(t: DataType): TypeCheckResult + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -97,88 +114,65 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - def evalInternal(evalE1: Any, evalE2: Any): Any = - sys.error(s"BinaryExpressions must either override eval or evalInternal") + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryArithmetics must override either eval or evalInternal") } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.plus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.minus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.times(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" - override def nullable: Boolean = true - lazy val div: (Any, Any) => Any = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -198,13 +192,17 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" - override def nullable: Boolean = true - lazy val integral = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -228,7 +226,10 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" - lazy val and: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val and: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -237,10 +238,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise & operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = and(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) } /** @@ -249,7 +249,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" - lazy val or: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val or: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -258,10 +261,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise | operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = or(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) } /** @@ -270,7 +272,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" - lazy val xor: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -279,23 +284,21 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise ^ operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) } /** * A function that calculates bitwise not(~) of a number. */ -case class BitwiseNot(child: Expression) extends UnaryExpression { - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable +case class BitwiseNot(child: Expression) extends UnaryArithmetic { override def toString: String = s"~$child" - lazy val not: (Any) => Any = dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + + private lazy val not: (Any) => Any = dataType match { case ByteType => ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] case ShortType => @@ -304,43 +307,18 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] case LongType => ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - not(evalE) - } - } + protected override def evalInternal(evalE: Any) = not(evalE) } -case class MaxOf(left: Expression, right: Expression) extends Expression { - - override def foldable: Boolean = left.foldable && right.foldable - +case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function maxOf") - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -361,30 +339,13 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } -case class MinOf(left: Expression, right: Expression) extends Expression { - - override def foldable: Boolean = left.foldable && right.foldable - +case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function minOf") - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -404,28 +365,3 @@ case class MinOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MinOf($left, $right)" } - -/** - * A function that get the absolute value of the numeric value. - */ -case class Abs(child: Expression) extends UnaryExpression { - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"Abs($child)" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - numeric.abs(evalE) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 01f62ba0442e9..db853a2b97fad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -29,17 +29,10 @@ import org.apache.spark.sql.types._ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - override def symbol: String = null override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) - override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"$name($left, $right)" - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - override def dataType: DataType = DoubleType override def eval(input: Row): Any = { @@ -58,9 +51,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } -case class Atan2( - left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -80,8 +72,7 @@ case class Atan2( } } -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4f422d69c4382..807021d50e8e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -171,22 +171,51 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => -} -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) + } + } + + protected def checkTypesInternal(t: DataType): TypeCheckResult override def eval(input: Row): Any = { - val l = left.eval(input) - if (l == null) { + val evalE1 = left.eval(input) + if (evalE1 == null) { null } else { - val r = right.eval(input) - if (r == null) null - else if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + evalInternal(evalE1, evalE2) + } } } + + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryComparisons must override either eval or evalInternal") +} + +object BinaryComparison { + def unapply(b: BinaryComparison): Option[(Expression, Expression)] = + Some((b.left, b.right)) +} + +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "=" + + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + + protected override def evalInternal(l: Any, r: Any) = { + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -194,6 +223,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def nullable: Boolean = false + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -210,117 +241,45 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lt(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lteq(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gt(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gteq(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -329,16 +288,20 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable - override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException( - this, - s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}") + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.TypeCheckFailure( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + } else { + TypeCheckResult.TypeCheckSuccess } - trueValue.dataType } + override def dataType: DataType = trueValue.dataType + override def eval(input: Row): Any = { if (true == predicate.eval(input)) { trueValue.eval(input) @@ -364,17 +327,23 @@ trait CaseWhenLike extends Expression { branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - // both then and else val should be considered. + // both then and else expressions should be considered. def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + checkTypesInternal() + } else { + TypeCheckResult.TypeCheckFailure( + "THEN and ELSE expressions should all be same type or coercible to a common type") } - valueTypes.head } + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) @@ -395,10 +364,16 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { override def children: Seq[Expression] = branches - override lazy val resolved: Boolean = - childrenResolved && - whenList.forall(_.dataType == BooleanType) && - valueTypesEqual + override protected def checkTypesInternal(): TypeCheckResult = { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { @@ -441,9 +416,14 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches - override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual && - (key +: whenList).map(_.dataType).distinct.size == 1 + override protected def checkTypesInternal(): TypeCheckResult = { + if ((key +: whenList).map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + "key and WHEN expressions should all be same type or coercible to a common type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b25fb48f55e2b..5c6379b8d44b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -273,6 +273,10 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) + // MaxOf and MinOf can't do null propagation + case e: MaxOf => e + case e: MinOf => e + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index 3f92be4a55d7d..ad649acf536f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -24,7 +24,7 @@ import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast /** - * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + * Helper function to convert between Int value of days since 1970-01-01 and java.sql.Date */ object DateUtils { private val MILLIS_PER_DAY = 86400000 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala new file mode 100644 index 0000000000000..0bb12d2039ffc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types._ + +/** + * Helper function to check for valid data types + */ +object TypeUtils { + def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[NumericType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t") + } + } + + def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[IntegralType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") + } + } + + def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[AtomicType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t") + } + } + + def getNumeric(t: DataType): Numeric[Any] = + t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + + def getOrdering(t: DataType): Ordering[Any] = + t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 1ba3a2686639f..74677ddfcad65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -107,7 +107,7 @@ protected[sql] abstract class AtomicType extends DataType { abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. private[sql] val numeric: Numeric[InternalType] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 1b8d18ded2257..7bac97b7894f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -92,8 +92,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(LessThan(i, d1), DecimalType.Unlimited) - checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index a0798428db094..0df446636ea89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -28,11 +28,11 @@ class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = HiveTypeCoercion.findTightestCommonType(t1, t2) + var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = HiveTypeCoercion.findTightestCommonType(t2, t1) + found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } @@ -140,13 +140,10 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - // Will remove exception expectation in PR#6405 - intercept[RuntimeException] { - ruleTest(cwc, - CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), - CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) - ) - } + ruleTest(cwc, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) } test("type coercion simplification for equal to") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala new file mode 100644 index 0000000000000..dcb3635c5ccae --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + +class ExpressionTypeCheckingSuite extends SparkFunSuite { + + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) + + def assertError(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains( + s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains(errorMessage)) + } + + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze + SimpleAnalyzer.checkAnalysis(analyzed) + } + + def assertErrorForDifferingTypes(expr: Expression): Unit = { + assertError(expr, + s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + } + + test("check types for unary arithmetic") { + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt + assertError(Sqrt('booleanField), "function sqrt accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + } + + test("check types for binary arithmetic") { + // We will cast String to Double for binary arithmetic + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + assertErrorForDifferingTypes(Add('intField, 'booleanField)) + assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) + assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) + assertErrorForDifferingTypes(Divide('intField, 'booleanField)) + assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) + assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) + assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + } + + test("check types for predicates") { + // We will cast String to Double for binary comparison + assertSuccess(EqualTo('intField, 'stringField)) + assertSuccess(EqualNullSafe('intField, 'stringField)) + assertSuccess(LessThan('intField, 'stringField)) + assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + // We will transform EqualTo with numeric and boolean types to CaseKeyWhen + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) + + assertError(EqualTo('intField, 'complexField), "differing types") + assertError(EqualNullSafe('intField, 'complexField), "differing types") + + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) + assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + + assertError( + LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError( + LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError( + GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError( + GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError( + If('intField, 'stringField, 'stringField), + "type of predicate expression in If should be boolean") + assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) + + assertError( + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + "WHEN expressions in CaseWhen should all be boolean type") + + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 06aa19ef09bd2..565d10247f10e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -147,7 +147,7 @@ private[sql] object InferSchema { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { case (other: DataType, NullType) => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 95eb1174b1dd6..7e1e21f5fbb99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -155,7 +155,7 @@ private[sql] object JsonRDD extends Logging { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2) match { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { case Some(commonType) => commonType case None => // t1 or t2 is a StructType, ArrayType, or an unexpected type. From 28dbde3874ccdd44b73675938719b69336d23dac Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 3 Jun 2015 13:15:57 +0200 Subject: [PATCH 157/254] [SPARK-7983] [MLLIB] Add require for one-based indices in loadLibSVMFile jira: https://issues.apache.org/jira/browse/SPARK-7983 Customers frequently use zero-based indices in their LIBSVM files. No warnings or errors from Spark will be reported during their computation afterwards, and usually it will lead to wired result for many algorithms (like GBDT). add a quick check. Author: Yuhao Yang Closes #6538 from hhbyyh/loadSVM and squashes the following commits: 79d9c11 [Yuhao Yang] optimization as respond to comments 4310710 [Yuhao Yang] merge conflict 96460f1 [Yuhao Yang] merge conflict 20a2811 [Yuhao Yang] use require 6e4f8ca [Yuhao Yang] add check for ascending order 9956365 [Yuhao Yang] add ut for 0-based loadlibsvm exception 5bd1f9a [Yuhao Yang] add require for one-based in loadLIBSVM --- .../org/apache/spark/mllib/util/MLUtils.scala | 12 +++++++ .../spark/mllib/util/MLUtilsSuite.scala | 35 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 541f3288b6c43..52d6468a72af7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -82,6 +82,18 @@ object MLUtils { val value = indexAndValue(1).toDouble (index, value) }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, "indices should be one-based and in ascending order" ) + previous = current + i += 1 + } + (label, indices.toArray, values.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 734b7babec7be..70219e9ad9d3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") { + val lines = + """ + |0 + |0 0:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + + test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") { + val lines = + """ + |0 + |0 3:4.0 2:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + test("saveAsLibSVMFile") { val examples = sc.parallelize(Seq( LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), From f1646e1023bd03e27268a8aa2ea11b6cc284075f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 3 Jun 2015 09:26:21 -0700 Subject: [PATCH 158/254] [SPARK-7973] [SQL] Increase the timeout of two CliSuite tests. https://issues.apache.org/jira/browse/SPARK-7973 Author: Yin Huai Closes #6525 from yhuai/SPARK-7973 and squashes the following commits: 763b821 [Yin Huai] Also change the timeout of "Single command with -e" to 2 minutes. e598a08 [Yin Huai] Increase the timeout to 3 minutes. --- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3732af7870b93..13b0c5951dddc 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -133,7 +133,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } test("Single command with --database") { @@ -165,7 +165,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - runCliWithin(1.minute, Seq("--jars", s"$jarFile"))( + runCliWithin(3.minute, Seq("--jars", s"$jarFile"))( """CREATE TABLE t1(key string, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; """.stripMargin From 2c4d550eda0e6f33d2d575825c3faef4c9217067 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 3 Jun 2015 10:11:27 -0700 Subject: [PATCH 159/254] [SPARK-7801] [BUILD] Updating versions to SPARK 1.5.0 Author: Patrick Wendell Closes #6328 from pwendell/spark-1.5-update and squashes the following commits: 2f42d02 [Patrick Wendell] A few more excludes 4bebcf0 [Patrick Wendell] Update to RC4 61aaf46 [Patrick Wendell] Using new release candidate 55f1610 [Patrick Wendell] Another exclude 04b4f04 [Patrick Wendell] More issues with transient 1.4 changes 36f549b [Patrick Wendell] [SPARK-7801] [BUILD] Updating versions to SPARK 1.5.0 --- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- core/src/main/scala/org/apache/spark/package.scala | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-assembly/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/java8-tests/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib/pom.xml | 2 +- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- pom.xml | 14 +++++++++++++- project/MimaBuild.scala | 3 ++- project/MimaExcludes.scala | 14 ++++++++++++++ repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- unsafe/pom.xml | 2 +- yarn/pom.xml | 2 +- 34 files changed, 61 insertions(+), 34 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 626c8577e31fe..e9c6d26ccddc7 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 132cd433d78a2..ed5c37e595a96 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index a02184222e9f0..e35694e9e98b4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 2ab41ba488ff6..8ae76c5f72f2e 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.4.0-SNAPSHOT" + val SPARK_VERSION = "1.5.0-SNAPSHOT" } diff --git a/docs/_config.yml b/docs/_config.yml index b22b627f09007..c0e031a83ba9c 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.4.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.4.0 +SPARK_VERSION: 1.5.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.5.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/examples/pom.xml b/examples/pom.xml index e4efee7b5e647..e6884b09dca94 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 71f2b6fe18bd1..7a7dccc3d0922 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index a345c03582ad6..14f7daaf417e0 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 0b79f47647f6b..8059c443827ef 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 5734d55bf4784..ded863bd985e8 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 7d102e10ab60f..0e41e5781784b 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index d28e3e1846d70..178ae8de13b57 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 9998c11c85171..37bfd10d43663 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4351a8a12fe21..f138251748c9e 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 25847a1b33d9c..4787991572b61 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index e14bbae4a9b6e..478d0019a25f0 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 28b41228feb3d..853dea9a7795e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index cc177d23dff77..48dd0d5f9106b 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 65c647a91d192..b16058ddc203a 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/network/common/pom.xml b/network/common/pom.xml index 0c3147761cfc5..a85e0a66f4a30 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 7dc7c65825e34..4b5bfcb6f04bc 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 1e2e9c80af6cc..a99f7c4392d3d 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index 711edf9efad2b..0b1aaad7566bc 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -269,6 +269,18 @@ false + + + spark-1.4-staging + Spark 1.4 RC4 Staging Repository + https://repository.apache.org/content/repositories/orgapachespark-1112 + + true + + + false + + diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index dde92949fa175..5812b72f0aa78 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,8 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.3.0" + // TODO: Change this once Spark 1.4.0 is released + val previousSparkVersion = "1.4.0-rc4" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8da72b3fa7cdb..34371c9659423 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,20 @@ import com.typesafe.tools.mima.core.ProblemFilters._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.5") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // JavaRDDLike is not meant to be extended by user programs + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Mima false positive (was a private[spark] class) + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PairIterator") + ) case v if v.startsWith("1.4") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/repl/pom.xml b/repl/pom.xml index 6e5cb7f77e1df..85f7bc8ac1024 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index d9e1cdb84bb27..bf0a7327a58a2 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 8210c552603ea..3192f81ffaecd 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 20d3c7d4c5959..73e6ccdb1eaf8 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 923ffabb9b99e..a17546d706248 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 49d035a1e9696..697895e72fe5b 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 1c6f3e83a1819..feffde4c857eb 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 2fd17267ac427..62c6354f1e203 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index e207a46809684..644def7501dc8 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml From d053a31be93d789e3f26cf55d747ecf6ca386c29 Mon Sep 17 00:00:00 2001 From: animesh Date: Wed, 3 Jun 2015 11:28:18 -0700 Subject: [PATCH 160/254] [SPARK-7980] [SQL] Support SQLContext.range(end) 1. range() overloaded in SQLContext.scala 2. range() modified in python sql context.py 3. Tests added accordingly in DataFrameSuite.scala and python sql tests.py Author: animesh Closes #6609 from animeshbaranawal/SPARK-7980 and squashes the following commits: 935899c [animesh] SPARK-7980:python+scala changes --- python/pyspark/sql/context.py | 12 ++++++++++-- python/pyspark/sql/tests.py | 2 ++ .../main/scala/org/apache/spark/sql/SQLContext.scala | 11 +++++++++++ .../scala/org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9fdf43c3e6eb5..1bebfc48376b4 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -131,7 +131,7 @@ def udf(self): return UDFRegistration(self) @since(1.4) - def range(self, start, end, step=1, numPartitions=None): + def range(self, start, end=None, step=1, numPartitions=None): """ Create a :class:`DataFrame` with single LongType column named `id`, containing elements in a range from `start` to `end` (exclusive) with @@ -145,10 +145,18 @@ def range(self, start, end, step=1, numPartitions=None): >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] + + >>> sqlContext.range(3).collect() + [Row(id=0), Row(id=1), Row(id=2)] """ if numPartitions is None: numPartitions = self._sc.defaultParallelism - jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + + if end is None: + jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) + else: + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + return DataFrame(jdf, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6e498f0af0af5..a6fce50c76c2b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -131,6 +131,8 @@ def test_range(self): self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.sqlCtx.range(-2).count(), 0) + self.assertEqual(self.sqlCtx.range(3).count(), 3) def test_explode(self): from pyspark.sql.functions import explode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 91e6385dec81b..f08fb4fafe650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -717,6 +717,17 @@ class SQLContext(@transient val sparkContext: SparkContext) StructType(StructField("id", LongType, nullable = false) :: Nil)) } + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from 0 to `end`(exclusive) with step value 1. + * + * @since 1.4.0 + * @group dataframe + */ + @Experimental + def range(end: Long): DataFrame = range(0, end) + /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a4fd1058afce5..9aaec2b064d76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -576,5 +576,13 @@ class DataFrameSuite extends QueryTest { val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + + // only end provided as argument + val res10 = TestSQLContext.range(10).select("id") + assert(res10.count == 10) + assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res11 = TestSQLContext.range(-1).select("id") + assert(res11.count == 0) } } From d2a86eb8f0fcc02304604da56c589ea58c77587a Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Wed, 3 Jun 2015 13:43:13 -0500 Subject: [PATCH 161/254] [SPARK-7161] [HISTORY SERVER] Provide REST api to download event logs fro... ...m History Server This PR adds a new API that allows the user to download event logs for an application as a zip file. APIs have been added to download all logs for a given application or just for a specific attempt. This also add an additional method to the ApplicationHistoryProvider to get the raw files, zipped. Author: Hari Shreedharan Closes #5792 from harishreedharan/eventlog-download and squashes the following commits: 221cc26 [Hari Shreedharan] Update docs with new API information. a131be6 [Hari Shreedharan] Fix style issues. 5528bd8 [Hari Shreedharan] Merge branch 'master' into eventlog-download 6e8156e [Hari Shreedharan] Simplify tests, use Guava stream copy methods. d8ddede [Hari Shreedharan] Remove unnecessary case in EventLogDownloadResource. ffffb53 [Hari Shreedharan] Changed interface to use zip stream. Added more tests. 1100b40 [Hari Shreedharan] Ensure that `Path` does not appear in interfaces, by rafactoring interfaces. 5a5f3e2 [Hari Shreedharan] Fix test ordering issue. 0b66948 [Hari Shreedharan] Minor formatting/import fixes. 4fc518c [Hari Shreedharan] Fix rat failures. a48b91f [Hari Shreedharan] Refactor to make attemptId optional in the API. Also added tests. 0fc1424 [Hari Shreedharan] File download now works for individual attempts and the entire application. 350d7e8 [Hari Shreedharan] Merge remote-tracking branch 'asf/master' into eventlog-download fd6ab00 [Hari Shreedharan] Fix style issues 32b7662 [Hari Shreedharan] Use UIRoot directly in ApiRootResource. Also, use `Response` class to set headers. 7b362b2 [Hari Shreedharan] Almost working. 3d18ebc [Hari Shreedharan] [WIP] Try getting the event log download to work. --- .rat-excludes | 2 + .../history/ApplicationHistoryProvider.scala | 11 +++ .../deploy/history/FsHistoryProvider.scala | 63 ++++++++++++- .../spark/deploy/history/HistoryServer.scala | 8 ++ .../spark/status/api/v1/ApiRootResource.scala | 20 +++++ .../api/v1/EventLogDownloadResource.scala | 70 +++++++++++++++ .../application_list_json_expectation.json | 16 ++++ .../completed_app_list_json_expectation.json | 16 ++++ .../minDate_app_list_json_expectation.json | 34 +++++-- .../spark-events/local-1430917381535_1 | 5 ++ .../spark-events/local-1430917381535_2 | 5 ++ .../history/FsHistoryProviderSuite.scala | 40 ++++++++- .../deploy/history/HistoryServerSuite.scala | 88 +++++++++++++++++-- docs/monitoring.md | 8 ++ 14 files changed, 367 insertions(+), 19 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala create mode 100644 core/src/test/resources/spark-events/local-1430917381535_1 create mode 100644 core/src/test/resources/spark-events/local-1430917381535_2 diff --git a/.rat-excludes b/.rat-excludes index 8f2722cbd001f..994c7e86f8a91 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -80,6 +80,8 @@ local-1425081759269/* local-1426533911241/* local-1426633911242/* local-1430917381534/* +local-1430917381535_1 +local-1430917381535_2 DESCRIPTION NAMESPACE test_support/* diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 298a8201960d1..5f5e0fe1c34d7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,6 +17,9 @@ package org.apache.spark.deploy.history +import java.util.zip.ZipOutputStream + +import org.apache.spark.SparkException import org.apache.spark.ui.SparkUI private[spark] case class ApplicationAttemptInfo( @@ -62,4 +65,12 @@ private[history] abstract class ApplicationHistoryProvider { */ def getConfig(): Map[String, String] = Map() + /** + * Writes out the event logs to the output stream provided. The logs will be compressed into a + * single zip file and written out. + * @throws SparkException if the logs for the app id cannot be found. + */ + @throws(classOf[SparkException]) + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 45c2be34c8680..52b149b273e4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,16 +17,18 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, IOException, InputStream} +import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable +import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.AccessControlException -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ @@ -59,7 +61,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -219,6 +222,58 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + + /** + * This method compresses the files passed in, and writes the compressed data out into the + * [[OutputStream]] passed in. Each file is written as a new [[ZipEntry]] with its name being + * the name of the file being compressed. + */ + def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = { + val fs = FileSystem.get(hadoopConf) + val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer + try { + outputStream.putNextEntry(new ZipEntry(entryName)) + ByteStreams.copy(inputStream, outputStream) + outputStream.closeEntry() + } finally { + inputStream.close() + } + } + + applications.get(appId) match { + case Some(appInfo) => + try { + // If no attempt is specified, or there is no attemptId for attempts, return all attempts + appInfo.attempts.filter { attempt => + attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get + }.foreach { attempt => + val logPath = new Path(logDir, attempt.logPath) + // If this is a legacy directory, then add the directory to the zipStream and add + // each file to that directory. + if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { + val files = fs.listFiles(logPath, false) + zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) + zipStream.closeEntry() + while (files.hasNext) { + val file = files.next().getPath + zipFileToStream(file, attempt.logPath + Path.SEPARATOR + file.getName, zipStream) + } + } else { + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) + } + } + } finally { + zipStream.close() + } + case None => throw new SparkException(s"Logs for $appId not found.") + } + } + + /** * Replay the log files in the list and merge the list of old applications with new ones */ diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 5a0eb585a9049..10638afb74900 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException +import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import com.google.common.cache._ @@ -173,6 +174,13 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + provider.writeEventLogs(appId, attemptId, zipStream) + } + /** * Returns the provider configuration to show in the listing page. * diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index f73c742732dec..9af90ee5ecd9d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.zip.ZipOutputStream import javax.servlet.ServletContext import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -164,6 +165,18 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { } } + @Path("applications/{appId}/logs") + def getEventLogs( + @PathParam("appId") appId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, None) + } + + @Path("applications/{appId}/{attemptId}/logs") + def getEventLogs( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } private[spark] object ApiRootResource { @@ -193,6 +206,13 @@ private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { + Response.serverError() + .entity("Event logs are only available through the history server.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + /** * Get the spark UI with the given appID, and apply a function * to it. If there is no such app, throw an appropriate exception diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala new file mode 100644 index 0000000000000..d416dba8324d8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.io.OutputStream +import java.util.zip.ZipOutputStream +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.{MediaType, Response, StreamingOutput} + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil + +@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) +private[v1] class EventLogDownloadResource( + val uIRoot: UIRoot, + val appId: String, + val attemptId: Option[String]) extends Logging { + val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) + + @GET + def getEventLogs(): Response = { + try { + val fileName = { + attemptId match { + case Some(id) => s"eventLogs-$appId-$id.zip" + case None => s"eventLogs-$appId.zip" + } + } + + val stream = new StreamingOutput { + override def write(output: OutputStream) = { + val zipStream = new ZipOutputStream(output) + try { + uIRoot.writeEventLogs(appId, attemptId, zipStream) + } finally { + zipStream.close() + } + + } + } + + Response.ok(stream) + .header("Content-Disposition", s"attachment; filename=$fileName") + .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) + .build() + } catch { + case NonFatal(e) => + Response.serverError() + .entity(s"Event logs are not available for app: $appId.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + } +} diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index ce4fe80b66aa5..d575bf2f284b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -7,6 +7,22 @@ "sparkUser" : "irashid", "completed" : true } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "sparkUser" : "irashid", + "completed" : true + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "sparkUser" : "irashid", + "completed" : true + } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index ce4fe80b66aa5..d575bf2f284b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -7,6 +7,22 @@ "sparkUser" : "irashid", "completed" : true } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "sparkUser" : "irashid", + "completed" : true + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "sparkUser" : "irashid", + "completed" : true + } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index dca86fe5f7e6a..15c2de8ef99ea 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -7,6 +7,22 @@ "sparkUser" : "irashid", "completed" : true } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "sparkUser" : "irashid", + "completed" : true + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "sparkUser" : "irashid", + "completed" : true + } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", @@ -24,12 +40,14 @@ "completed" : true } ] }, { - "id" : "local-1425081759269", - "name" : "Spark shell", - "attempts" : [ { - "startTime" : "2015-02-28T00:02:38.277GMT", - "endTime" : "2015-02-28T00:02:46.912GMT", - "sparkUser" : "irashid", - "completed" : true - } ] + "id": "local-1425081759269", + "name": "Spark shell", + "attempts": [ + { + "startTime": "2015-02-28T00:02:38.277GMT", + "endTime": "2015-02-28T00:02:46.912GMT", + "sparkUser": "irashid", + "completed": true + } + ] } ] \ No newline at end of file diff --git a/core/src/test/resources/spark-events/local-1430917381535_1 b/core/src/test/resources/spark-events/local-1430917381535_1 new file mode 100644 index 0000000000000..d5a1303344825 --- /dev/null +++ b/core/src/test/resources/spark-events/local-1430917381535_1 @@ -0,0 +1,5 @@ +{"Event":"SparkListenerLogStart","Spark Version":"1.4.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"localhost","Port":61103},"Maximum Memory":278019440,"Timestamp":1430917380880} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","Java Version":"1.8.0_25 (Oracle Corporation)","Scala Version":"version 2.10.4"},"Spark Properties":{"spark.driver.host":"192.168.1.102","spark.eventLog.enabled":"true","spark.driver.port":"61101","spark.repl.class.uri":"http://192.168.1.102:61100","spark.jars":"","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.executor.id":"driver","spark.master":"local[*]","spark.eventLog.dir":"/Users/irashid/github/kraps/core/src/test/resources/spark-events","spark.fileserver.uri":"http://192.168.1.102:61102","spark.tachyonStore.folderName":"spark-aaaf41b3-d1dd-447f-8951-acf51490758b","spark.app.id":"local-1430917381534"},"System Properties":{"java.io.tmpdir":"/var/folders/36/m29jw1z95qv4ywb1c4n0rz000000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/irashid","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib","user.dir":"/Users/irashid/github/spark","java.library.path":"/Users/irashid/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.25-b02","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_25-b17","java.vm.info":"mixed mode","java.ext.dirs":"/Users/irashid/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.9.5","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"irashid","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --conf spark.eventLog.enabled=true --conf spark.eventLog.dir=/Users/irashid/github/kraps/core/src/test/resources/spark-events --class org.apache.spark.repl.Main spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","java.version":"1.8.0_25","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/etc/hadoop":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/Users/irashid/github/spark/conf/":"System Classpath","/Users/irashid/github/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.5.0.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-core-3.2.10.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"local-1430917381535","Timestamp":1430917380880,"User":"irashid","App Attempt ID":"1"} +{"Event":"SparkListenerApplicationEnd","Timestamp":1430917380890} \ No newline at end of file diff --git a/core/src/test/resources/spark-events/local-1430917381535_2 b/core/src/test/resources/spark-events/local-1430917381535_2 new file mode 100644 index 0000000000000..abb637a22e1e3 --- /dev/null +++ b/core/src/test/resources/spark-events/local-1430917381535_2 @@ -0,0 +1,5 @@ +{"Event":"SparkListenerLogStart","Spark Version":"1.4.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"localhost","Port":61103},"Maximum Memory":278019440,"Timestamp":1430917380893} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","Java Version":"1.8.0_25 (Oracle Corporation)","Scala Version":"version 2.10.4"},"Spark Properties":{"spark.driver.host":"192.168.1.102","spark.eventLog.enabled":"true","spark.driver.port":"61101","spark.repl.class.uri":"http://192.168.1.102:61100","spark.jars":"","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.executor.id":"driver","spark.master":"local[*]","spark.eventLog.dir":"/Users/irashid/github/kraps/core/src/test/resources/spark-events","spark.fileserver.uri":"http://192.168.1.102:61102","spark.tachyonStore.folderName":"spark-aaaf41b3-d1dd-447f-8951-acf51490758b","spark.app.id":"local-1430917381534"},"System Properties":{"java.io.tmpdir":"/var/folders/36/m29jw1z95qv4ywb1c4n0rz000000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/irashid","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib","user.dir":"/Users/irashid/github/spark","java.library.path":"/Users/irashid/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.25-b02","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_25-b17","java.vm.info":"mixed mode","java.ext.dirs":"/Users/irashid/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.9.5","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"irashid","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --conf spark.eventLog.enabled=true --conf spark.eventLog.dir=/Users/irashid/github/kraps/core/src/test/resources/spark-events --class org.apache.spark.repl.Main spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","java.version":"1.8.0_25","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/etc/hadoop":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/Users/irashid/github/spark/conf/":"System Classpath","/Users/irashid/github/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.5.0.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-core-3.2.10.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"local-1430917381535","Timestamp":1430917380893,"User":"irashid","App Attempt ID":"2"} +{"Event":"SparkListenerApplicationEnd","Timestamp":1430917380950} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 0f6933df9e6bc..09075eeb539aa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -17,12 +17,16 @@ package org.apache.spark.deploy.history -import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStreamWriter} +import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, + FileOutputStream, OutputStreamWriter} import java.net.URI import java.util.concurrent.TimeUnit +import java.util.zip.{ZipInputStream, ZipOutputStream} import scala.io.Source +import com.google.common.base.Charsets +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter @@ -335,6 +339,40 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(!log2.exists()) } + test("Event log copy") { + val provider = new FsHistoryProvider(createTestConf()) + val logs = (1 to 2).map { i => + val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) + writeFile(log, true, None, + SparkListenerApplicationStart( + "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), + SparkListenerApplicationEnd(5001 * i) + ) + log + } + provider.checkForLogs() + + (1 to 2).foreach { i => + val underlyingStream = new ByteArrayOutputStream() + val outputStream = new ZipOutputStream(underlyingStream) + provider.writeEventLogs("downloadApp1", Some(s"attempt$i"), outputStream) + outputStream.close() + val inputStream = new ZipInputStream(new ByteArrayInputStream(underlyingStream.toByteArray)) + var totalEntries = 0 + var entry = inputStream.getNextEntry + entry should not be null + while (entry != null) { + val actual = new String(ByteStreams.toByteArray(inputStream), Charsets.UTF_8) + val expected = Files.toString(logs.find(_.getName == entry.getName).get, Charsets.UTF_8) + actual should be (expected) + totalEntries += 1 + entry = inputStream.getNextEntry + } + totalEntries should be (1) + inputStream.close() + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 14f2d1a5894b8..e5b5e1bb65337 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.deploy.history -import java.io.{File, FileInputStream, FileWriter, IOException} +import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} +import java.util.zip.ZipInputStream import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import com.google.common.base.Charsets +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} import org.mockito.Mockito.when import org.scalatest.{BeforeAndAfter, Matchers} @@ -147,6 +150,70 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } + test("download all logs for app with multiple attempts") { + doDownloadTest("local-1430917381535", None) + } + + test("download one log for app with multiple attempts") { + (1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) } + } + + test("download legacy logs - all attempts") { + doDownloadTest("local-1426533911241", None, legacy = true) + } + + test("download legacy logs - single attempts") { + (1 to 2). foreach { + attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true) + } + } + + // Test that the files are downloaded correctly, and validate them. + def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = { + + val url = attemptId match { + case Some(id) => + new URL(s"${generateURL(s"applications/$appId")}/$id/logs") + case None => + new URL(s"${generateURL(s"applications/$appId")}/logs") + } + + val (code, inputStream, error) = HistoryServerSuite.connectAndGetInputStream(url) + code should be (HttpServletResponse.SC_OK) + inputStream should not be None + error should be (None) + + val zipStream = new ZipInputStream(inputStream.get) + var entry = zipStream.getNextEntry + entry should not be null + val totalFiles = { + if (legacy) { + attemptId.map { x => 3 }.getOrElse(6) + } else { + attemptId.map { x => 1 }.getOrElse(2) + } + } + var filesCompared = 0 + while (entry != null) { + if (!entry.isDirectory) { + val expectedFile = { + if (legacy) { + val splits = entry.getName.split("/") + new File(new File(logDir, splits(0)), splits(1)) + } else { + new File(logDir, entry.getName) + } + } + val expected = Files.toString(expectedFile, Charsets.UTF_8) + val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) + actual should be (expected) + filesCompared += 1 + } + entry = zipStream.getNextEntry + } + filesCompared should be (totalFiles) + } + test("response codes on bad paths") { val badAppId = getContentAndCode("applications/foobar") badAppId._1 should be (HttpServletResponse.SC_NOT_FOUND) @@ -202,7 +269,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } def getUrl(path: String): String = { - HistoryServerSuite.getUrl(new URL(s"http://localhost:$port/api/v1/$path")) + HistoryServerSuite.getUrl(generateURL(path)) + } + + def generateURL(path: String): URL = { + new URL(s"http://localhost:$port/api/v1/$path") } def generateExpectation(name: String, path: String): Unit = { @@ -233,13 +304,18 @@ object HistoryServerSuite { } def getContentAndCode(url: URL): (Int, Option[String], Option[String]) = { + val (code, in, errString) = connectAndGetInputStream(url) + val inString = in.map(IOUtils.toString) + (code, inString, errString) + } + + def connectAndGetInputStream(url: URL): (Int, Option[InputStream], Option[String]) = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod("GET") connection.connect() val code = connection.getResponseCode() - val inString = try { - val in = Option(connection.getInputStream()) - in.map(IOUtils.toString) + val inStream = try { + Option(connection.getInputStream()) } catch { case io: IOException => None } @@ -249,7 +325,7 @@ object HistoryServerSuite { } catch { case io: IOException => None } - (code, inString, errString) + (code, inStream, errString) } diff --git a/docs/monitoring.md b/docs/monitoring.md index e75018499003a..31ecddc6dbbb9 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -228,6 +228,14 @@ for a running application, at `http://localhost:4040/api/v1`. /applications/[app-id]/storage/rdd/[rdd-id] Details for the storage status of a given RDD + + /applications/[app-id]/logs + Download the event logs for all attempts of the given application as a zip file + + + /applications/[app-id]/[attempt-id/logs + Download the event logs for the specified attempt of the given application as a zip file + When running on Yarn, each application has multiple attempts, so `[app-id]` is actually From 708c63bbbe9580eb774fe47e23ef61338103afda Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 3 Jun 2015 11:56:35 -0700 Subject: [PATCH 162/254] [SPARK-8063] [SPARKR] Spark master URL conflict between MASTER env variable and --master command line option. Author: Sun Rui Closes #6605 from sun-rui/SPARK-8063 and squashes the following commits: 51ca48b [Sun Rui] [SPARK-8063][SPARKR] Spark master URL conflict between MASTER env variable and --master command line option. --- R/pkg/inst/profile/shell.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index ca94f1d4e7fd5..773b6ecf582d9 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -24,7 +24,7 @@ old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = "")) + sc <- SparkR::sparkR.init() assign("sc", sc, envir=.GlobalEnv) sqlContext <- SparkR::sparkRSQL.init(sc) assign("sqlContext", sqlContext, envir=.GlobalEnv) From 939e4f3d8def16dfe03f0196be8e1c218a9daa32 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Jun 2015 13:57:57 -0700 Subject: [PATCH 163/254] [SPARK-8074] Parquet should throw AnalysisException during setup for data type/name related failures. Author: Reynold Xin Closes #6608 from rxin/parquet-analysis and squashes the following commits: b5dc8e2 [Reynold Xin] Code review feedback. 5617cf6 [Reynold Xin] [SPARK-8074] Parquet should throw AnalysisException during setup for data type/name related failures. --- .../spark/sql/parquet/ParquetTypes.scala | 20 +++++++++---------- .../apache/spark/sql/parquet/newParquet.scala | 14 +++++++------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 6698b19c7477d..f8a5d84549336 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.parquet import java.io.IOException -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ import scala.util.Try import org.apache.hadoop.conf.Configuration @@ -33,12 +33,11 @@ import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeNa import parquet.schema.Type.Repetition import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ -import org.apache.spark.{Logging, SparkException} -// Implicits -import scala.collection.JavaConversions._ /** A class representing Parquet info fields we care about, for passing back to Parquet */ private[parquet] case class ParquetTypeInfo( @@ -73,13 +72,12 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - sys.error("Potential loss of precision: cannot convert INT96") + throw new AnalysisException("Potential loss of precision: cannot convert INT96") case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => // TODO: for now, our reader only supports decimals that fit in a Long DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") + case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType") } } @@ -371,7 +369,7 @@ private[parquet] object ParquetTypesConverter extends Logging { parquetKeyType, parquetValueType) } - case _ => sys.error(s"Unsupported datatype $ctype") + case _ => throw new AnalysisException(s"Unsupported datatype $ctype") } } } @@ -403,7 +401,7 @@ private[parquet] object ParquetTypesConverter extends Logging { def convertFromString(string: String): Seq[Attribute] = { Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes - case other => sys.error(s"Can convert $string to row") + case other => throw new AnalysisException(s"Can convert $string to row") } } @@ -411,8 +409,8 @@ private[parquet] object ParquetTypesConverter extends Logging { // ,;{}()\n\t= and space character are special characters in Parquet schema schema.map(_.name).foreach { name => if (name.matches(".*[ ,;{}()\n\t=].*")) { - sys.error( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". + throw new AnalysisException( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". |Please use alias to rename it. """.stripMargin.split("\n").mkString(" ")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 824ae36968c32..bf55e2383ab56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -39,6 +39,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLConf, SQLContext} @@ -83,7 +84,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext case partFilePattern(id) => id.toInt case name if name.startsWith("_") => 0 case name if name.startsWith(".") => 0 - case name => sys.error( + case name => throw new AnalysisException( s"Trying to write Parquet files to directory $outputPath, " + s"but found items with illegal name '$name'.") }.reduceOption(_ max _).getOrElse(0) @@ -380,11 +381,12 @@ private[sql] class ParquetRelation2( // time-consuming. if (dataSchema == null) { dataSchema = { - val dataSchema0 = - maybeDataSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(sys.error("Failed to get the schema.")) + val dataSchema0 = maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(throw new AnalysisException( + s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + + paths.mkString("\n\t"))) // If this Parquet relation is converted from a Hive Metastore table, must reconcile case // case insensitivity issue and possible schema mismatch (probably caused by schema From 2c5a06cafd2885ff5431fa96485db2564ae1cce3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Jun 2015 14:19:10 -0700 Subject: [PATCH 164/254] Update documentation for [SPARK-7980] [SQL] Support SQLContext.range(end) --- python/pyspark/sql/context.py | 2 ++ .../org/apache/spark/sql/SQLContext.scala | 20 +++++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 1bebfc48376b4..599c9ac5794a2 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -146,6 +146,8 @@ def range(self, start, end=None, step=1, numPartitions=None): >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] + If only one argument is specified, it will be used as the end value. + >>> sqlContext.range(3).collect() [Row(id=0), Row(id=1), Row(id=2)] """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f08fb4fafe650..0aab7fa8709b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -705,33 +705,33 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end`(exclusive) with step value 1. + * in an range from 0 to `end` (exclusive) with step value 1. * - * @since 1.4.0 + * @since 1.4.1 * @group dataframe */ @Experimental - def range(start: Long, end: Long): DataFrame = { - createDataFrame( - sparkContext.range(start, end).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) - } + def range(end: Long): DataFrame = range(0, end) /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from 0 to `end`(exclusive) with step value 1. + * in an range from `start` to `end` (exclusive) with step value 1. * * @since 1.4.0 * @group dataframe */ @Experimental - def range(end: Long): DataFrame = range(0, end) + def range(start: Long, end: Long): DataFrame = { + createDataFrame( + sparkContext.range(start, end).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) + } /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end`(exclusive) with an step value, with partition number + * in an range from `start` to `end` (exclusive) with an step value, with partition number * specified. * * @since 1.4.0 From 20a26b595c74ac41cf7c19e6091d7e675e503321 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 3 Jun 2015 14:34:20 -0700 Subject: [PATCH 165/254] [SPARK-8054] [MLLIB] Added several Java-friendly APIs + unit tests Java-friendly APIs added: * GaussianMixture.run() * GaussianMixtureModel.predict() * DistributedLDAModel.javaTopicDistributions() * StreamingKMeans: trainOn, predictOn, predictOnValues * Statistics.corr * params * added doc to w() since Java docs do not inherit doc * removed non-Java-friendly w() from StringArrayParam and DoubleArrayParam * made DoubleArrayParam Java-friendly w() actually Java-friendly I generated the doc and verified all changes. CC: mengxr Author: Joseph K. Bradley Closes #6562 from jkbradley/java-api-1.4 and squashes the following commits: c16821b [Joseph K. Bradley] Small fixes based on code review. d955581 [Joseph K. Bradley] unit test fixes 29b6b0d [Joseph K. Bradley] small fixes fe6dcfe [Joseph K. Bradley] Added several Java-friendly APIs + unit tests: NaiveBayes, GaussianMixture, LDA, StreamingKMeans, Statistics.corr, params --- .../org/apache/spark/ml/param/params.scala | 20 ++--- .../mllib/clustering/GaussianMixture.scala | 4 + .../clustering/GaussianMixtureModel.scala | 7 +- .../spark/mllib/clustering/LDAModel.scala | 6 ++ .../mllib/clustering/StreamingKMeans.scala | 18 ++++ .../apache/spark/mllib/stat/Statistics.scala | 9 ++ .../spark/ml/param/JavaParamsSuite.java | 1 + .../apache/spark/ml/param/JavaTestParams.java | 29 +++++-- .../JavaStreamingLogisticRegressionSuite.java | 3 +- .../clustering/JavaGaussianMixtureSuite.java | 64 +++++++++++++++ .../spark/mllib/clustering/JavaLDASuite.java | 4 + .../clustering/JavaStreamingKMeansSuite.java | 82 +++++++++++++++++++ .../spark/mllib/stat/JavaStatisticsSuite.java | 56 +++++++++++++ 13 files changed, 284 insertions(+), 19 deletions(-) rename mllib/src/test/java/org/apache/spark/{ml => mllib}/classification/JavaStreamingLogisticRegressionSuite.java (95%) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 473488dce9b0d..ba94d6a3a80a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - /** - * Creates a param pair with the given value (for Java). - */ + /** Creates a param pair with the given value (for Java). */ def w(value: T): ParamPair[T] = this -> value - /** - * Creates a param pair with the given value (for Scala). - */ + /** Creates a param pair with the given value (for Scala). */ def ->(value: T): ParamPair[T] = ParamPair(this, value) override final def toString: String = s"${parent}__$name" @@ -190,6 +186,7 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) } @@ -209,6 +206,7 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) } @@ -228,6 +226,7 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) } @@ -247,6 +246,7 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) } @@ -260,6 +260,7 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -274,8 +275,6 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) - override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value) - /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) } @@ -291,10 +290,9 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) - override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value) - /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ - def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray) + def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = + w(value.asScala.map(_.asInstanceOf[Double]).toArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 70b0e40948e51..fc509d2ba1470 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -188,6 +189,9 @@ class GaussianMixture private ( new GaussianMixtureModel(weights, gaussians) } + /** Java-friendly version of [[run()]] */ + def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) + /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 5fc2cb1b62d33..cb807c8038101 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} @@ -46,7 +47,7 @@ import org.apache.spark.sql.{SQLContext, Row} @Experimental class GaussianMixtureModel( val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") @@ -65,6 +66,10 @@ class GaussianMixtureModel( responsibilityMatrix.map(r => r.indexOf(r.max)) } + /** Java-friendly version of [[predict()]] */ + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + /** * Given the input vectors, return the membership value of each vector * to all mixture components. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6cf26445f20a0..974b26924dfb8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.rdd.RDD @@ -345,6 +346,11 @@ class DistributedLDAModel private ( } } + /** Java-friendly version of [[topicDistributions]] */ + def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { + JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index c21e4fe7dc9b6..d9b34cec64894 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -21,8 +21,10 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -234,6 +236,9 @@ class StreamingKMeans( } } + /** Java-friendly version of `trainOn`. */ + def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) + /** * Use the clustering model to make predictions on batches of data from a DStream. * @@ -245,6 +250,11 @@ class StreamingKMeans( data.map(model.predict) } + /** Java-friendly version of `predictOn`. */ + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) + } + /** * Use the model to make predictions on the values of a DStream and carry over its keys. * @@ -257,6 +267,14 @@ class StreamingKMeans( data.mapValues(model.predict) } + /** Java-friendly version of `predictOnValues`. */ + def predictOnValues[K]( + data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]]) + } + /** Check whether cluster centers have been initialized. */ private[this] def assertInitialized(): Unit = { if (model.clusterCenters == null) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index b3fad0c52d655..900007ec6bc74 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -80,6 +81,10 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) + /** Java-friendly version of [[corr()]] */ + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) + /** * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. @@ -96,6 +101,10 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) + /** Java-friendly version of [[corr()]] */ + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) + /** * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index e7df10dfa63ac..9890155e9f865 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -50,6 +50,7 @@ public void testParams() { testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyStringParam(), "a"); + Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 947ae3a2ce06f..ff5929235ac2c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -51,7 +51,8 @@ public String uid() { public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } public JavaTestParams setMyIntParam(int value) { - set(myIntParam_, value); return this; + set(myIntParam_, value); + return this; } private DoubleParam myDoubleParam_; @@ -60,7 +61,8 @@ public JavaTestParams setMyIntParam(int value) { public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } public JavaTestParams setMyDoubleParam(double value) { - set(myDoubleParam_, value); return this; + set(myDoubleParam_, value); + return this; } private Param myStringParam_; @@ -69,7 +71,18 @@ public JavaTestParams setMyDoubleParam(double value) { public String getMyStringParam() { return getOrDefault(myStringParam_); } public JavaTestParams setMyStringParam(String value) { - set(myStringParam_, value); return this; + set(myStringParam_, value); + return this; + } + + private DoubleArrayParam myDoubleArrayParam_; + public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } + + public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } + + public JavaTestParams setMyDoubleArrayParam(double[] value) { + set(myDoubleArrayParam_, value); + return this; } private void init() { @@ -79,8 +92,14 @@ private void init() { List validStrings = Lists.newArrayList("a", "b"); myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); - setDefault(myIntParam_, 1); - setDefault(myDoubleParam_, 0.5); + myDoubleArrayParam_ = + new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); + + setDefault(myIntParam(), 1); + setDefault(myIntParam().w(1)); + setDefault(myDoubleParam(), 0.5); setDefault(myIntParam().w(1), myDoubleParam().w(0.5)); + setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); + setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java similarity index 95% rename from mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java rename to mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 640d2ec55e4e7..55787f8606d48 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml.classification; +package org.apache.spark.mllib.classification; import java.io.Serializable; import java.util.List; @@ -28,7 +28,6 @@ import org.junit.Test; import org.apache.spark.SparkConf; -import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java new file mode 100644 index 0000000000000..467a7a69e8f30 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaGaussianMixtureSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGaussianMixture"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runGaussianMixture() { + List points = Lists.newArrayList( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0) + ); + + JavaRDD data = sc.parallelize(points, 2); + GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) + .run(data); + assertEquals(model.gaussians().length, 2); + JavaRDD predictions = model.predict(data); + predictions.first(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 96c2da169961f..581c033f08ebe 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -107,6 +107,10 @@ public void distributedLDAModel() { // Check: log probabilities assert(model.logLikelihood() < 0.0); assert(model.logPrior() < 0.0); + + // Check: topic distributions + JavaPairRDD topicDistributions = model.javaTopicDistributions(); + assertEquals(topicDistributions.count(), corpus.count()); } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java new file mode 100644 index 0000000000000..3b0e879eec77f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.apache.spark.streaming.JavaTestUtils.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +public class JavaStreamingKMeansSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + Vectors.dense(1.0), + Vectors.dense(0.0)); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingKMeans skmeans = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0}); + skmeans.trainOn(training); + JavaPairDStream prediction = skmeans.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java new file mode 100644 index 0000000000000..62f7f26b7c98f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaStatisticsSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaStatistics"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testCorr() { + JavaRDD x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); + + Double corr1 = Statistics.corr(x, y); + Double corr2 = Statistics.corr(x, y, "pearson"); + // Check default method + assertEquals(corr1, corr2); + } +} From c6a6dd0d0736d548ff9f255e5ed5df45b29c46c1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Jun 2015 12:10:12 -0700 Subject: [PATCH 166/254] [MINOR] [UI] Improve confusing message on log page It's good practice to check if the input path is in the directory we expect to avoid potentially confusing error messages. --- .../spark/deploy/worker/ui/LogPage.scala | 9 +++ .../scala/org/apache/spark/util/Utils.scala | 16 +++++ .../spark/deploy/worker/ui/LogPageSuite.scala | 36 ++++++---- .../org/apache/spark/util/UtilsSuite.scala | 65 +++++++++++++++++++ 4 files changed, 115 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index dc2bee6f2bdca..53f8f9a46cf8d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.worker.ui +import java.io.File +import java.net.URI import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -135,6 +137,13 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) } + // Verify that the normalized path of the log directory is in the working directory + val normalizedUri = new URI(logDirectory).normalize() + val normalizedLogDir = new File(normalizedUri.getPath) + if (!Utils.isInDirectory(workDir, normalizedLogDir)) { + return ("Error: invalid log directory " + logDirectory, 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 693e1a0a3d5f0..5f132410540fd 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2227,6 +2227,22 @@ private[spark] object Utils extends Logging { } } + /** + * Return whether the specified file is a parent directory of the child file. + */ + def isInDirectory(parent: File, child: File): Boolean = { + if (child == null || parent == null) { + return false + } + if (!child.exists() || !parent.exists() || !parent.isDirectory()) { + return false + } + if (parent.equals(child)) { + return true + } + isInDirectory(parent, child.getParentFile) + } + } private [util] class SparkShutdownHookManager { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala index 572360ddb95d4..72eaffb416981 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker.ui import java.io.{File, FileWriter} -import org.mockito.Mockito.mock +import org.mockito.Mockito.{mock, when} import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite @@ -28,33 +28,47 @@ class LogPageSuite extends SparkFunSuite with PrivateMethodTester { test("get logs simple") { val webui = mock(classOf[WorkerWebUI]) + val tmpDir = new File(sys.props("java.io.tmpdir")) + val workDir = new File(tmpDir, "work-dir") + workDir.mkdir() + when(webui.workDir).thenReturn(workDir) val logPage = new LogPage(webui) // Prepare some fake log files to read later val out = "some stdout here" val err = "some stderr here" - val tmpDir = new File(sys.props("java.io.tmpdir")) - val tmpOut = new File(tmpDir, "stdout") - val tmpErr = new File(tmpDir, "stderr") - val tmpRand = new File(tmpDir, "random") + val tmpOut = new File(workDir, "stdout") + val tmpErr = new File(workDir, "stderr") + val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory + val tmpOutBad = new File(tmpDir, "stdout") + val tmpRand = new File(workDir, "random") write(tmpOut, out) write(tmpErr, err) + write(tmpOutBad, out) + write(tmpErrBad, err) write(tmpRand, "1 6 4 5 2 7 8") // Get the logs. All log types other than "stderr" or "stdout" will be rejected val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog) val (stdout, _, _, _) = - logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) + logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100) val (stderr, _, _, _) = - logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) + logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100) val (error1, _, _, _) = - logPage invokePrivate getLog(tmpDir.getAbsolutePath, "random", None, 100) + logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100) val (error2, _, _, _) = - logPage invokePrivate getLog(tmpDir.getAbsolutePath, "does-not-exist.txt", None, 100) + logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100) + // These files exist, but live outside the working directory + val (error3, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) + val (error4, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) assert(stdout === out) assert(stderr === err) - assert(error1.startsWith("Error")) - assert(error2.startsWith("Error")) + assert(error1.startsWith("Error: Log type must be one of ")) + assert(error2.startsWith("Error: Log type must be one of ")) + assert(error3.startsWith("Error: invalid log directory")) + assert(error4.startsWith("Error: invalid log directory")) } /** Write the specified string to the file. */ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a867cf83dc3f1..a61ea3918f46a 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -608,4 +608,69 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { manager.runAll() assert(output.toList === List(4, 3, 2)) } + + test("isInDirectory") { + val tmpDir = new File(sys.props("java.io.tmpdir")) + val parentDir = new File(tmpDir, "parent-dir") + val childDir1 = new File(parentDir, "child-dir-1") + val childDir1b = new File(parentDir, "child-dir-1b") + val childFile1 = new File(parentDir, "child-file-1.txt") + val childDir2 = new File(childDir1, "child-dir-2") + val childDir2b = new File(childDir1, "child-dir-2b") + val childFile2 = new File(childDir1, "child-file-2.txt") + val childFile3 = new File(childDir2, "child-file-3.txt") + val nullFile: File = null + parentDir.mkdir() + childDir1.mkdir() + childDir1b.mkdir() + childDir2.mkdir() + childDir2b.mkdir() + childFile1.createNewFile() + childFile2.createNewFile() + childFile3.createNewFile() + + // Identity + assert(Utils.isInDirectory(parentDir, parentDir)) + assert(Utils.isInDirectory(childDir1, childDir1)) + assert(Utils.isInDirectory(childDir2, childDir2)) + + // Valid ancestor-descendant pairs + assert(Utils.isInDirectory(parentDir, childDir1)) + assert(Utils.isInDirectory(parentDir, childFile1)) + assert(Utils.isInDirectory(parentDir, childDir2)) + assert(Utils.isInDirectory(parentDir, childFile2)) + assert(Utils.isInDirectory(parentDir, childFile3)) + assert(Utils.isInDirectory(childDir1, childDir2)) + assert(Utils.isInDirectory(childDir1, childFile2)) + assert(Utils.isInDirectory(childDir1, childFile3)) + assert(Utils.isInDirectory(childDir2, childFile3)) + + // Inverted ancestor-descendant pairs should fail + assert(!Utils.isInDirectory(childDir1, parentDir)) + assert(!Utils.isInDirectory(childDir2, parentDir)) + assert(!Utils.isInDirectory(childDir2, childDir1)) + assert(!Utils.isInDirectory(childFile1, parentDir)) + assert(!Utils.isInDirectory(childFile2, parentDir)) + assert(!Utils.isInDirectory(childFile3, parentDir)) + assert(!Utils.isInDirectory(childFile2, childDir1)) + assert(!Utils.isInDirectory(childFile3, childDir1)) + assert(!Utils.isInDirectory(childFile3, childDir2)) + + // Non-existent files or directories should fail + assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one.txt"))) + assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one/two.txt"))) + assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one/two/three.txt"))) + + // Siblings should fail + assert(!Utils.isInDirectory(childDir1, childDir1b)) + assert(!Utils.isInDirectory(childDir1, childFile1)) + assert(!Utils.isInDirectory(childDir2, childDir2b)) + assert(!Utils.isInDirectory(childDir2, childFile2)) + + // Null files should fail without throwing NPE + assert(!Utils.isInDirectory(parentDir, nullFile)) + assert(!Utils.isInDirectory(childFile3, nullFile)) + assert(!Utils.isInDirectory(nullFile, parentDir)) + assert(!Utils.isInDirectory(nullFile, childFile3)) + } } From bfbf12b349e998c7e674649a07b88c4658ae0711 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Wed, 3 Jun 2015 14:57:23 -0700 Subject: [PATCH 167/254] [SPARK-8083] [MESOS] Use the correct base path in mesos driver page. Author: Timothy Chen Closes #6615 from tnachen/mesos_driver_path and squashes the following commits: 4f47b7c [Timothy Chen] Use the correct base path in mesos driver page. --- .../scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index be8560d10fc62..e8ef60bd5428a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

    Driver state information for driver id {driverId}

    - Back to Drivers + Back to Drivers

    Driver state: {driverState.state}

    From aa40c4420717aa06a7964bd30b428fb73548beb2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 3 Jun 2015 14:59:30 -0700 Subject: [PATCH 168/254] [SPARK-8059] [YARN] Wake up allocation thread when new requests arrive. This should help reduce latency for new executor allocations. Author: Marcelo Vanzin Closes #6600 from vanzin/SPARK-8059 and squashes the following commits: 8387a3a [Marcelo Vanzin] [SPARK-8059] [yarn] Wake up allocation thread when new requests arrive. --- .../spark/deploy/yarn/ApplicationMaster.scala | 16 +++++++++++++--- .../apache/spark/deploy/yarn/YarnAllocator.scala | 7 ++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 760e458972d98..002d7b6eaf498 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -67,6 +67,7 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + private val allocatorLock = new Object() // Fields used in client mode. private var rpcEnv: RpcEnv = null @@ -359,7 +360,9 @@ private[spark] class ApplicationMaster( } logDebug(s"Number of pending allocations is $numPendingAllocate. " + s"Sleeping for $sleepInterval.") - Thread.sleep(sleepInterval) + allocatorLock.synchronized { + allocatorLock.wait(sleepInterval) + } } catch { case e: InterruptedException => } @@ -546,8 +549,15 @@ private[spark] class ApplicationMaster( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { - case Some(a) => a.requestTotalExecutors(requestedTotal) - case None => logWarning("Container allocator is not ready to request executors yet.") + case Some(a) => + allocatorLock.synchronized { + if (a.requestTotalExecutors(requestedTotal)) { + allocatorLock.notifyAll() + } + } + + case None => + logWarning("Container allocator is not ready to request executors yet.") } context.reply(true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 21193e7c625e3..940873fbd046c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -146,11 +146,16 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. + * + * @return Whether the new requested total is different than the old value. */ - def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { + def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal + true + } else { + false } } From 1d8669f15c136cd81f494dd487400c62c9498602 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 3 Jun 2015 15:03:07 -0700 Subject: [PATCH 169/254] [SPARK-8001] [CORE] Make AsynchronousListenerBus.waitUntilEmpty throw TimeoutException if timeout Some places forget to call `assert` to check the return value of `AsynchronousListenerBus.waitUntilEmpty`. Instead of adding `assert` in these places, I think it's better to make `AsynchronousListenerBus.waitUntilEmpty` throw `TimeoutException`. Author: zsxwing Closes #6550 from zsxwing/SPARK-8001 and squashes the following commits: 607674a [zsxwing] Make AsynchronousListenerBus.waitUntilEmpty throw TimeoutException if timeout --- .../spark/util/AsynchronousListenerBus.scala | 11 +++++----- .../spark/deploy/LogUrlsStandaloneSuite.scala | 4 ++-- .../spark/scheduler/DAGSchedulerSuite.scala | 18 +++++++-------- .../spark/scheduler/SparkListenerSuite.scala | 22 +++++++++---------- .../SparkListenerWithClusterSuite.scala | 2 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 +- 6 files changed, 30 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index 1861d38640102..61b5a4cecddce 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -120,21 +120,22 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri /** * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Return true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. + * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue + * emptied. */ @VisibleForTesting - def waitUntilEmpty(timeoutMillis: Int): Boolean = { + @throws(classOf[TimeoutException]) + def waitUntilEmpty(timeoutMillis: Long): Unit = { val finishTime = System.currentTimeMillis + timeoutMillis while (!queueIsEmpty) { if (System.currentTimeMillis > finishTime) { - return false + throw new TimeoutException( + s"The event queue is not empty after $timeoutMillis milliseconds") } /* Sleep rather than using wait/notify, because this is used only for testing and * wait/notify add overhead in the general case. */ Thread.sleep(10) } - true } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index c215b0582889f..ddc92814c0acf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -41,7 +41,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) // Browse to each URL to check that it's valid @@ -71,7 +71,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] assert(listeners.size === 1) val listener = listeners(0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bfcf918e06162..47b2868753c0e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -254,7 +254,7 @@ class DAGSchedulerSuite test("[SPARK-3353] parent stage should have lower stage id") { sparkListener.stageByOrderOfExecution.clear() sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.stageByOrderOfExecution.length === 2) assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1)) } @@ -389,7 +389,7 @@ class DAGSchedulerSuite submit(unserializableRdd, Array(0)) assert(failure.getMessage.startsWith( "Job aborted due to stage failure: Task not serializable:")) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -399,7 +399,7 @@ class DAGSchedulerSuite submit(new MyRDD(sc, 1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted due to stage failure: some failure") - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -410,7 +410,7 @@ class DAGSchedulerSuite val jobId = submit(rdd, Array(0)) cancel(jobId) assert(failure.getMessage === s"Job $jobId cancelled ") - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -462,7 +462,7 @@ class DAGSchedulerSuite assert(results === Map(0 -> 42)) assertDataStructuresEmpty() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.isEmpty) assert(sparkListener.successfulStages.contains(0)) } @@ -531,7 +531,7 @@ class DAGSchedulerSuite Map[Long, Any](), createFakeTaskInfo(), null)) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. @@ -543,7 +543,7 @@ class DAGSchedulerSuite createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.size == 1) } @@ -592,7 +592,7 @@ class DAGSchedulerSuite // Listener bus should get told about the map stage failing, but not the reduce stage // (since the reduce stage hasn't been started yet). - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.toSet === Set(0)) assertDataStructuresEmpty() @@ -643,7 +643,7 @@ class DAGSchedulerSuite assert(cancelledStages.toSet === Set(0, 2)) // Make sure the listeners got told about both failed stages. - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.successfulStages.isEmpty) assert(sparkListener.failedStages.toSet === Set(0, 2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 06fb909bf5419..651295b7344c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -47,7 +47,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Starting listener bus should flush all buffered events bus.start(sc) - assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) // After listener bus has stopped, posting events should not increment counter @@ -131,7 +131,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match rdd2.setName("Target RDD") rdd2.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val (stageInfo, taskInfoMetrics) = listener.stageInfos.head @@ -156,7 +156,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match rdd3.setName("Trois") rdd1.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD @@ -165,7 +165,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match listener.stageInfos.clear() rdd2.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get stageInfo2.rddInfos.size should be {3} // ParallelCollectionRDD, FilteredRDD, MappedRDD @@ -174,7 +174,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match listener.stageInfos.clear() rdd3.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {2} // Shuffle map stage + result stage val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get stageInfo3.rddInfos.size should be {1} // ShuffledRDD @@ -190,7 +190,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val rdd2 = rdd1.map(_.toString) sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val (stageInfo, _) = listener.stageInfos.head @@ -214,7 +214,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val d = sc.parallelize(0 to 1e4.toInt, 64).map(w) d.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be (1) val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") @@ -225,7 +225,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match d4.setName("A Cogroup") d4.collectAsMap() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be (4) listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) => /** @@ -281,7 +281,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .reduce { case (x, y) => x } assert(result === 1.to(akkaFrameSize).toArray) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val TASK_INDEX = 0 assert(listener.startedTasks.contains(TASK_INDEX)) assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) @@ -297,7 +297,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } assert(result === 2) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val TASK_INDEX = 0 assert(listener.startedTasks.contains(TASK_INDEX)) assert(listener.startedGettingResultTasks.isEmpty) @@ -352,7 +352,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } - assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) // The exception should be caught, and the event should be propagated to other listeners assert(bus.listenerThreadIsAlive) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index c7f179e1483a5..50273bcc8ce5e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -46,7 +46,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext rdd2.setName("Target RDD") rdd2.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(listener.addedExecutorInfo.size == 2) assert(listener.addedExecutorInfo("0").totalCores == 1) assert(listener.addedExecutorInfo("1").totalCores == 1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index d8bc2534c1a6a..bc42e12dfafd7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -326,7 +326,7 @@ private object YarnClusterDriver extends Logging with Matchers { var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) data should be (Set(1, 2, 3, 4)) result = "success" } finally { From f27134782ebb61c360330e2d6d5bb1aa02be3fb6 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 3 Jun 2015 15:04:20 -0700 Subject: [PATCH 170/254] [SPARK-7989] [CORE] [TESTS] Fix flaky tests in ExternalShuffleServiceSuite and SparkListenerWithClusterSuite The flaky tests in ExternalShuffleServiceSuite and SparkListenerWithClusterSuite will fail if there are not enough executors up before running the jobs. This PR adds `JobProgressListener.waitUntilExecutorsUp`. The tests for the cluster mode can use it to wait until the expected executors are up. Author: zsxwing Closes #6546 from zsxwing/SPARK-7989 and squashes the following commits: 5560e09 [zsxwing] Fix a typo 3b69840 [zsxwing] Fix flaky tests in ExternalShuffleServiceSuite and SparkListenerWithClusterSuite --- .../spark/ui/jobs/JobProgressListener.scala | 30 +++++++++++++++++++ .../spark/ExternalShuffleServiceSuite.scala | 8 +++++ .../spark/broadcast/BroadcastSuite.scala | 10 +------ .../SparkListenerWithClusterSuite.scala | 10 +++++-- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index f39e961772c46..1d31fce4c697b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,8 +17,12 @@ package org.apache.spark.ui.jobs +import java.util.concurrent.TimeoutException + import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import com.google.common.annotations.VisibleForTesting + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -526,4 +530,30 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onApplicationStart(appStarted: SparkListenerApplicationStart) { startTime = appStarted.time } + + /** + * For testing only. Wait until at least `numExecutors` executors are up, or throw + * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. + * + * @param numExecutors the number of executors to wait at least + * @param timeout time to wait in milliseconds + */ + @VisibleForTesting + private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = { + val finishTime = System.currentTimeMillis() + timeout + while (System.currentTimeMillis() < finishTime) { + val numBlockManagers = synchronized { + blockManagerIds.size + } + if (numBlockManagers >= numExecutors + 1) { + // Need to count the block manager in driver + return + } + // Sleep rather than using wait/notify, because this is used only for testing and wait/notify + // add overhead in the general case. + Thread.sleep(10) + } + throw new TimeoutException( + s"Can't find $numExecutors executors before $timeout milliseconds elapsed") + } } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index bac6fdbcdc976..5b127a070c07f 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -55,6 +55,14 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) + // In a slow machine, one slave may register hundreds of milliseconds ahead of the other one. + // If we don't wait for all salves, it's possible that only one executor runs all jobs. Then + // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch + // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. + // In this case, we won't receive FetchFailed. And it will make this test fail. + // Therefore, we should wait until all salves are up + sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) rdd.count() diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c05e8bb6538ba..c054c718075f8 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.broadcast -import scala.concurrent.duration._ import scala.util.Random import org.scalatest.Assertions -import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec @@ -312,13 +310,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val _sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up - eventually(timeout(10.seconds), interval(10.milliseconds)) { - _sc.jobProgressListener.synchronized { - val numBlockManagers = _sc.jobProgressListener.blockManagerIds.size - assert(numBlockManagers == numSlaves + 1, - s"Expect ${numSlaves + 1} block managers, but was ${numBlockManagers}") - } - } + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) _sc } else { new SparkContext("local", "test", broadcastConf) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index 50273bcc8ce5e..d97fba00976d2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.scheduler -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import scala.collection.mutable import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import scala.collection.mutable +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.cluster.ExecutorInfo /** * Unit tests for SparkListener that require a local cluster. @@ -41,6 +41,10 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext val listener = new SaveExecutorInfo sc.addSparkListener(listener) + // This test will check if the number of executors received by "SparkListener" is same as the + // number of all executors, so we need to wait until all executors are up + sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) rdd2.setName("Target RDD") From a8f1f1543e29fb2897e9ae6940581b9e4a3a13fb Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Wed, 3 Jun 2015 15:11:02 -0700 Subject: [PATCH 171/254] [HOTFIX] Fix Hadoop-1 build caused by #5792. Replaced `fs.listFiles` with Hadoop-1 friendly `fs.listStatus` method. Author: Hari Shreedharan Closes #6619 from harishreedharan/evetlog-hadoop-1-fix and squashes the following commits: 6192078 [Hari Shreedharan] [HOTFIX] Fix Hadoop-1 build caused by #5972. --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 52b149b273e4b..5427a88f32ffd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -255,12 +255,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // If this is a legacy directory, then add the directory to the zipStream and add // each file to that directory. if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { - val files = fs.listFiles(logPath, false) + val files = fs.listStatus(logPath) zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) zipStream.closeEntry() - while (files.hasNext) { - val file = files.next().getPath - zipFileToStream(file, attempt.logPath + Path.SEPARATOR + file.getName, zipStream) + files.foreach { file => + val path = file.getPath + zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) } } else { zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) From d3e026f8798f9875b90e8c372056ee3d71489be5 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 3 Jun 2015 15:14:38 -0700 Subject: [PATCH 172/254] [SPARK-3674] [EC2] Clear SPARK_WORKER_INSTANCES when using YARN cc andrewor14 Author: Shivaram Venkataraman Closes #6424 from shivaram/spark-worker-instances-yarn-ec2 and squashes the following commits: db244ae [Shivaram Venkataraman] Make Python Lint happy 0593d1b [Shivaram Venkataraman] Clear SPARK_WORKER_INSTANCES when using YARN --- ec2/spark_ec2.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index ee0904c9e5d54..84629cb9a0ca0 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -219,7 +219,8 @@ def parse_args(): "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: %default)") + help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + + "(Hadoop 2.4.0) (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -271,7 +272,8 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + + "is used as Hadoop major version (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + @@ -761,6 +763,10 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): if opts.ganglia: modules.append('ganglia') + # Clear SPARK_WORKER_INSTANCES if running on YARN + if opts.hadoop_major_version == "yarn": + opts.worker_instances = "" + # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( @@ -998,6 +1004,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] + worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" template_vars = { "master_list": '\n'.join(master_addresses), "active_master": active_master, @@ -1011,7 +1018,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "spark_version": spark_v, "tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": "%d" % opts.worker_instances, + "spark_worker_instances": worker_instances_str, "spark_master_opts": opts.master_opts } From 26c9d7a0f975009e22ec91e5c0b5cfcada79b35e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 3 Jun 2015 15:16:24 -0700 Subject: [PATCH 173/254] [SPARK-8051] [MLLIB] make StringIndexerModel silent if input column does not exist This is just a workaround to a bigger problem. Some pipeline stages may not be effective during prediction, and they should not complain about missing required columns, e.g. `StringIndexerModel`. jkbradley Author: Xiangrui Meng Closes #6595 from mengxr/SPARK-8051 and squashes the following commits: b6a36b9 [Xiangrui Meng] add doc f143fd4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-8051 8ee7c7e [Xiangrui Meng] use SparkFunSuite e112394 [Xiangrui Meng] make StringIndexerModel silent if input column does not exist --- .../apache/spark/ml/feature/StringIndexer.scala | 16 +++++++++++++++- .../spark/ml/feature/StringIndexerSuite.scala | 8 ++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a2dc8a8b960c5..f4e250757560a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod /** * :: Experimental :: * Model fitted by [[StringIndexer]]. + * NOTE: During transformation, if the input column does not exist, + * [[StringIndexerModel.transform]] would return the input dataset unmodified. + * This is a temporary fix for the case when target labels do not exist during prediction. */ @Experimental class StringIndexerModel private[ml] ( @@ -112,6 +115,12 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + if (!dataset.schema.fieldNames.contains($(inputCol))) { + logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + + "Skip StringIndexerModel.") + return dataset + } + val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) @@ -128,6 +137,11 @@ class StringIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + if (schema.fieldNames.contains($(inputCol))) { + validateAndTransformSchema(schema) + } else { + // If the input column does not exist during transformation, we skip StringIndexerModel. + schema + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index cbf1e8ddcb48a..5f557e16e5150 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -60,4 +60,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexerModel should keep silent if the input column does not exist.") { + val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) + .setInputCol("label") + .setOutputCol("labelIndex") + val df = sqlContext.range(0L, 10L) + assert(indexerModel.transform(df).eq(df)) + } } From d8662cd909a41575df6e0ea1630d2386d3711240 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Wed, 3 Jun 2015 15:46:38 -0700 Subject: [PATCH 174/254] [SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting Added stats from cross validation as a val in the cross validation model to save them for user access. Author: leahmcguire Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits: 49b507b [leahmcguire] fixed tyle error 67537b1 [leahmcguire] rebased 85907f0 [leahmcguire] fixed name 59987cc [leahmcguire] changed param name and test according to comments 36e71e3 [leahmcguire] rebasing 4b8223e [leahmcguire] fixed name 4ddffc6 [leahmcguire] changed param name and test according to comments 3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++++--- .../apache/spark/ml/tuning/CrossValidatorSuite.scala | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 6434b64aed15d..cb29392e8bc63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM @Experimental class CrossValidatorModel private[ml] ( override val uid: String, - val bestModel: Model[_]) + val bestModel: Model[_], + val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { override def validateParams(): Unit = { @@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] ( } override def copy(extra: ParamMap): CrossValidatorModel = { - val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) + val copied = new CrossValidatorModel( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + avgMetrics.clone()) copyValues(copied, extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 5ba469c7b10a0..9b3619f0046ea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) } test("validateParams should check estimatorParamMaps") { From bfbdab12dd37587e5518dcbb76507b752759cace Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Jun 2015 16:04:02 -0700 Subject: [PATCH 175/254] [HOTFIX] [TYPO] Fix typo in #6546 --- .../scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 5b127a070c07f..140012226fdbb 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -56,11 +56,11 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) // In a slow machine, one slave may register hundreds of milliseconds ahead of the other one. - // If we don't wait for all salves, it's possible that only one executor runs all jobs. Then + // If we don't wait for all slaves, it's possible that only one executor runs all jobs. Then // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. - // Therefore, we should wait until all salves are up + // Therefore, we should wait until all slaves are up sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) From 566cb5947925c79ef90af72346672ab7d27bf4df Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Wed, 3 Jun 2015 16:53:57 -0700 Subject: [PATCH 176/254] [HOTFIX] History Server API docs error fix. Minor error in the monitoring docs. Also made indentation changes in `ApiRootResource` Author: Hari Shreedharan Closes #6628 from harishreedharan/eventlog-formatting and squashes the following commits: a12553d [Hari Shreedharan] Javadoc updates. ca399b6 [Hari Shreedharan] [HOTFIX] History Server API docs error fix. --- .../apache/spark/status/api/v1/ApiRootResource.scala | 10 +++++++--- .../spark/status/api/v1/EventLogDownloadResource.scala | 2 +- docs/monitoring.md | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 9af90ee5ecd9d..50b6ba67e9931 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -167,14 +167,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications/{appId}/logs") def getEventLogs( - @PathParam("appId") appId: String): EventLogDownloadResource = { + @PathParam("appId") appId: String): EventLogDownloadResource = { new EventLogDownloadResource(uiRoot, appId, None) } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) } } @@ -206,6 +206,10 @@ private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + /** + * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is + * [[None]], event logs for all attempts of this application will be written out. + */ def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { Response.serverError() .entity("Event logs are only available through the history server.") diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala index d416dba8324d8..22e21f0c62a29 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala @@ -44,7 +44,7 @@ private[v1] class EventLogDownloadResource( } val stream = new StreamingOutput { - override def write(output: OutputStream) = { + override def write(output: OutputStream): Unit = { val zipStream = new ZipOutputStream(output) try { uIRoot.writeEventLogs(appId, attemptId, zipStream) diff --git a/docs/monitoring.md b/docs/monitoring.md index 31ecddc6dbbb9..bcf885fe4e681 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -233,7 +233,7 @@ for a running application, at `http://localhost:4040/api/v1`. Download the event logs for all attempts of the given application as a zip file - /applications/[app-id]/[attempt-id/logs + /applications/[app-id]/[attempt-id]/logs Download the event logs for the specified attempt of the given application as a zip file From 51898b5158ac7e7e67b0539bc062c9c16ce9a7ce Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 3 Jun 2015 16:54:46 -0700 Subject: [PATCH 177/254] [SPARK-8088] don't attempt to lower number of executors by 0 Author: Ryan Williams Closes #6624 from ryan-williams/execs and squashes the following commits: b6f71d4 [Ryan Williams] don't attempt to lower number of executors by 0 --- .../org/apache/spark/ExecutorAllocationManager.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 9514604752640..f7323a4d9db72 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -266,10 +266,14 @@ private[spark] class ExecutorAllocationManager( // executors and inform the cluster manager to cancel the extra pending requests val oldNumExecutorsTarget = numExecutorsTarget numExecutorsTarget = math.max(maxNeeded, minNumExecutors) - client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 - logInfo(s"Lowering target number of executors to $numExecutorsTarget because " + - s"not all requests are actually needed (previously $oldNumExecutorsTarget)") + + // If the new target has not changed, avoid sending a message to the cluster manager + if (numExecutorsTarget < oldNumExecutorsTarget) { + client.requestTotalExecutors(numExecutorsTarget) + logInfo(s"Lowering target number of executors to $numExecutorsTarget (previously " + + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") + } numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) From 0576c3c4ff9d9bbff208e915bee1ac0d4956548c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 3 Jun 2015 17:02:16 -0700 Subject: [PATCH 178/254] [SPARK-8084] [SPARKR] Make SparkR scripts fail on error cc shaneknapp pwendell JoshRosen Author: Shivaram Venkataraman Closes #6623 from shivaram/SPARK-8084 and squashes the following commits: 0ec5b26 [Shivaram Venkataraman] Make SparkR scripts fail on error --- R/create-docs.sh | 3 +++ R/install-dev.sh | 2 ++ 2 files changed, 5 insertions(+) diff --git a/R/create-docs.sh b/R/create-docs.sh index 4194172a2e115..af47c0863bdd0 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -23,6 +23,9 @@ # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html +set -o pipefail +set -e + # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" pushd $FWDIR diff --git a/R/install-dev.sh b/R/install-dev.sh index 55ed6f4be1a4a..b9e2527035994 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -26,6 +26,8 @@ # NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory # to load the SparkR package on the worker nodes. +set -o pipefail +set -e FWDIR="$(cd `dirname $0`; pwd)" LIB_DIR="$FWDIR/lib" From e35cd36e08faa43466759c412c420a9d8901d368 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Jun 2015 17:40:14 -0700 Subject: [PATCH 179/254] [BUILD] Increase Jenkins test timeout Currently hive tests alone take 40m. The right thing to do is to reduce the test time. However, that is a bigger project and we currently have PRs blocking on tests not timing out. --- dev/run-tests-jenkins | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 8b2a44fd72ba5..3cbd8666c8d68 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -47,7 +47,9 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" -TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout +# format: http://linux.die.net/man/1/timeout +# must be less than the timeout configured on Jenkins (currently 180m) +TESTS_TIMEOUT="175m" # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. From 9cf740f357fef00b5251618b20501774852f8a28 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Jun 2015 18:08:53 -0700 Subject: [PATCH 180/254] [BUILD] Use right branch when checking against Hive Right now we always run hive tests in branch-1.4 PRs because we compare whether the diff against master involves hive changes. Really we should be comparing against the target branch itself. Author: Andrew Or Closes #6629 from andrewor14/build-check-hive and squashes the following commits: 450fbbd [Andrew Or] [BUILD] Use right branch when checking against Hive --- dev/run-tests | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 7dd8d31fd44e3..d178e2a4601ea 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -80,18 +80,19 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" # Only run Hive tests if there are SQL changes. # Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then - git fetch origin master:master + target_branch="$ghprbTargetBranch" + git fetch origin "$target_branch":"$target_branch" # AMP_JENKINS_PRB indicates if the current build is a pull request build. if [ -n "$AMP_JENKINS_PRB" ]; then # It is a pull request build. sql_diffs=$( - git diff --name-only master \ + git diff --name-only "$target_branch" \ | grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" ) non_sql_diffs=$( - git diff --name-only master \ + git diff --name-only "$target_branch" \ | grep -v -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" ) From 984ad60147c933f2d5a2040c87ae687c14eb1724 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Jun 2015 20:45:31 -0700 Subject: [PATCH 181/254] [BUILD] Fix Maven build for Kinesis A necessary dependency that is transitively referenced is not provided, causing compilation failures in builds that provide the kinesis-asl profile. --- extras/kinesis-asl/pom.xml | 7 +++++++ pom.xml | 2 ++ 2 files changed, 9 insertions(+) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 4787991572b61..c6f60bc907438 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -40,6 +40,13 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/pom.xml b/pom.xml index 0b1aaad7566bc..d03d33bf02468 100644 --- a/pom.xml +++ b/pom.xml @@ -1438,6 +1438,8 @@ 2.3 false + + false From 9982d453c39e50aedae7d01e4c38fab1b2bc6be0 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 3 Jun 2015 23:45:06 -0700 Subject: [PATCH 182/254] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #5976 (close requested by 'JoshRosen') Closes #4576 (close requested by 'pwendell') Closes #3430 (close requested by 'pwendell') Closes #2495 (close requested by 'pwendell') From 10ba1880878d0babcdc5c9b688df5458ea131531 Mon Sep 17 00:00:00 2001 From: Daniel Darabos Date: Thu, 4 Jun 2015 13:46:49 +0200 Subject: [PATCH 183/254] Fix maxTaskFailures comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If maxTaskFailures is 1, the task set is aborted after 1 task failure. Other documentation and the code supports this reading, I think it's just this comment that was off. It's easy to make this mistake — can you please double-check if I'm correct? Thanks! Author: Daniel Darabos Closes #6621 from darabos/patch-2 and squashes the following commits: dfebdec [Daniel Darabos] Fix comment. --- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 673cd0e19eba2..82455b0426a5d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @param sched the TaskSchedulerImpl associated with the TaskSetManager * @param taskSet the TaskSet to manage scheduling for - * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * @param maxTaskFailures if any particular task fails this number of times, the entire * task set will be aborted */ private[spark] class TaskSetManager( From c8709dcfd1237ffa19ee9286e99ddf2718a616d8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 10:28:59 -0700 Subject: [PATCH 184/254] [SPARK-7956] [SQL] Use Janino to compile SQL expressions into bytecode In order to reduce the overhead of codegen, this PR switch to use Janino to compile SQL expressions into bytecode. After this, the time used to compile a SQL expression is decreased from 100ms to 5ms, which is necessary to turn on codegen for general workload, also tests. cc rxin Author: Davies Liu Closes #6479 from davies/janino and squashes the following commits: cc689f5 [Davies Liu] remove globalLock 262d848 [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino eec3a33 [Davies Liu] address comments from Josh f37c8c3 [Davies Liu] fix DecimalType and cast to String 202298b [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino a21e968 [Davies Liu] fix style 0ed3dc6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino 551a851 [Davies Liu] fix tests c3bdffa [Davies Liu] remove print 6089ce5 [Davies Liu] change logging level 7e46ac3 [Davies Liu] fix style d8f0f6c [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino da4926a [Davies Liu] fix tests 03660f3 [Davies Liu] WIP: use Janino to compile Java source f2629cd [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino f7d66cf [Davies Liu] use template based string for codegen --- .../spark/util/collection/OpenHashSet.scala | 12 +- pom.xml | 10 - project/SparkBuild.scala | 11 - sql/catalyst/pom.xml | 16 +- .../sql/catalyst/expressions/UnsafeRow.java | 101 +-- .../org/apache/spark/sql/BaseMutableRow.java | 68 ++ .../scala/org/apache/spark/sql/BaseRow.java | 190 +++++ .../expressions/codegen/CodeGenerator.scala | 797 +++++++++--------- .../codegen/GenerateMutableProjection.scala | 87 +- .../codegen/GenerateOrdering.scala | 146 ++-- .../codegen/GeneratePredicate.scala | 44 +- .../codegen/GenerateProjection.scala | 316 ++++--- .../expressions/codegen/package.scala | 6 - .../ExpressionEvaluationSuite.scala | 15 +- .../GeneratedEvaluationSuite.scala | 5 +- .../GeneratedMutableEvaluationSuite.scala | 7 +- .../org/apache/spark/sql/DataFrameSuite.scala | 11 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 162 ++-- 18 files changed, 1116 insertions(+), 888 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 1501111a06655..64e7102e3654c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -20,6 +20,8 @@ package org.apache.spark.util.collection import scala.reflect._ import com.google.common.hash.Hashing +import org.apache.spark.annotation.Private + /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never * removed. @@ -37,7 +39,7 @@ import com.google.common.hash.Hashing * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ -private[spark] +@Private class OpenHashSet[@specialized(Long, Int) T: ClassTag]( initialCapacity: Int, loadFactor: Double) @@ -110,6 +112,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( rehashIfNeeded(k, grow, move) } + def union(other: OpenHashSet[T]): OpenHashSet[T] = { + val iterator = other.iterator + while (iterator.hasNext) { + add(iterator.next()) + } + this + } + /** * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. * The caller is responsible for calling rehashIfNeeded. diff --git a/pom.xml b/pom.xml index d03d33bf02468..bcb6ef96a1206 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,6 @@ 2.3.4-spark 1.6 spark - 2.0.1 0.21.1 shaded-protobuf 1.7.10 @@ -1217,15 +1216,6 @@ -target ${java.version} - - - - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} - - diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a849639233bc..f65031fe25ac2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -178,9 +178,6 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) - /* Catalyst macro settings */ - enable(Catalyst.settings)(catalyst) - /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -275,14 +272,6 @@ object OldDeps { ) } -object Catalyst { - lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), - // Quasiquotes break compiling scala doc... - // TODO: Investigate fixing this. - sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) -} - object SQL { lazy val settings = Seq( initialCommands in console := diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index bf0a7327a58a2..f4b1cc3a4ffe7 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -36,10 +36,6 @@ - - org.scala-lang - scala-compiler - org.scala-lang scala-reflect @@ -67,6 +63,11 @@ scalacheck_${scala.binary.version} test + + org.codehaus.janino + janino + 2.7.8 + target/scala-${scala.binary.version}/classes @@ -108,13 +109,6 @@ !scala-2.11 - - - org.scalamacros - quasiquotes_${scala.binary.version} - ${scala.macros.version} - - diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb546b3086b33..ec97fe603c44f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,23 +17,25 @@ package org.apache.spark.sql.catalyst.expressions; -import scala.collection.Map; +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import scala.collection.Seq; import scala.collection.mutable.ArraySeq; -import javax.annotation.Nullable; -import java.math.BigDecimal; -import java.sql.Date; -import java.util.*; - import org.apache.spark.sql.Row; +import org.apache.spark.sql.BaseMutableRow; import org.apache.spark.sql.types.DataType; -import static org.apache.spark.sql.types.DataTypes.*; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.UTF8String; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; +import static org.apache.spark.sql.types.DataTypes.*; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -49,7 +51,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow implements MutableRow { +public final class UnsafeRow extends BaseMutableRow { private Object baseObject; private long baseOffset; @@ -227,21 +229,11 @@ public int size() { return numFields; } - @Override - public int length() { - return size(); - } - @Override public StructType schema() { return schema; } - @Override - public Object apply(int i) { - return get(i); - } - @Override public Object get(int i) { assertIndexIsValid(i); @@ -339,60 +331,7 @@ public String getString(int i) { return getUTF8String(i).toString(); } - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } @Override public Row copy() { @@ -412,24 +351,4 @@ public Seq toSeq() { } return values; } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java new file mode 100644 index 0000000000000..acec2bf4520f2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import org.apache.spark.sql.catalyst.expressions.MutableRow; + +public abstract class BaseMutableRow extends BaseRow implements MutableRow { + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setLong(int ordinal, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDouble(int ordinal, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setShort(int ordinal, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setByte(int ordinal, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setFloat(int ordinal, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java new file mode 100644 index 0000000000000..d138b43a3482b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import java.math.BigDecimal; +import java.sql.Date; +import java.util.List; + +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +public abstract class BaseRow implements Row { + + @Override + final public int length() { + return size(); + } + + @Override + public boolean anyNull() { + final int n = size(); + for (int i=0; i < n; i++) { + if (isNullAt(i)) { + return true; + } + } + return false; + } + + @Override + public StructType schema() { throw new UnsupportedOperationException(); } + + @Override + final public Object apply(int i) { + return get(i); + } + + @Override + public int getInt(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Seq getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public List getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.Map getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public java.util.Map getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + final int n = size(); + Object[] arr = new Object[n]; + for (int i = 0; i < n; i++) { + arr[i] = get(i); + } + return new GenericRow(arr); + } + + @Override + public Seq toSeq() { + final int n = size(); + final ArraySeq values = new ArraySeq(n); + for (int i = 0; i < n; i++) { + values.update(i, get(i)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 36964af68dd8d..cd604121b7dd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.google.common.cache.{CacheLoader, CacheBuilder} - +import scala.collection.mutable import scala.language.existentials +import com.google.common.cache.{CacheBuilder, CacheLoader} +import org.codehaus.janino.ClassBodyEvaluator + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -36,23 +38,15 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * expressions. */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - import scala.tools.reflect.ToolBox - - protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() - protected val rowType = typeOf[Row] - protected val mutableRowType = typeOf[MutableRow] - protected val genericRowType = typeOf[GenericRow] - protected val genericMutableRowType = typeOf[GenericMutableRow] - - protected val projectionType = typeOf[Projection] - protected val mutableProjectionType = typeOf[MutableProjection] + protected val rowType = classOf[Row].getName + protected val stringType = classOf[UTF8String].getName + protected val decimalType = classOf[Decimal].getName + protected val exprType = classOf[Expression].getName + protected val mutableRowType = classOf[MutableRow].getName + protected val genericMutableRowType = classOf[GenericMutableRow].getName private val curId = new java.util.concurrent.atomic.AtomicInteger() - private val javaSeparator = "$" /** * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. @@ -74,6 +68,20 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** + * Compile the Java source code into a Java class, using Janino. + * + * It will track the time used to compile + */ + protected def compile(code: String): Class[_] = { + val startTime = System.nanoTime() + val clazz = new ClassBodyEvaluator(code).getClazz() + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") + clazz + } + /** * A cache of generated classes. * @@ -87,7 +95,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .maximumSize(1000) .build( new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = globalLock.synchronized { + override def load(in: InType): OutType = { val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() @@ -110,8 +118,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - protected def freshName(prefix: String): TermName = { - newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + protected def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" } /** @@ -125,32 +133,51 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( - code: Seq[Tree], - nullTerm: TermName, - primitiveTerm: TermName, - objectTerm: TermName) + code: String, + nullTerm: String, + primitiveTerm: String, + objectTerm: String) + + /** + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. + */ + protected class CodeGenContext { + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + } + + /** + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen + */ + def newCodeGenContext(): CodeGenContext = { + new CodeGenContext() + } /** * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that * can be used to determine the result of evaluating the expression on an input row. */ - def expressionEvaluator(e: Expression): EvaluatedExpression = { + def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { val primitiveTerm = freshName("primitiveTerm") val nullTerm = freshName("nullTerm") val objectTerm = freshName("objectTerm") implicit class Evaluate1(e: Expression) { - def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${f(eval.primitiveTerm)} - """.children + def castOrNull(f: String => String, dataType: DataType): String = { + val eval = expressionEvaluator(e, ctx) + eval.code + + s""" + boolean $nullTerm = ${eval.nullTerm}; + ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; + if (!$nullTerm) { + $primitiveTerm = ${f(eval.primitiveTerm)}; + } + """ } } @@ -163,529 +190,505 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * * @param f a function from two primitive term names to a tree that evaluates them. */ - def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = + def evaluate(f: (String, String) => String): String = evaluateAs(expressions._1.dataType)(f) - def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { + def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (expressions._1.dataType != expressions._2.dataType) { log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") } - val eval1 = expressionEvaluator(expressions._1) - val eval2 = expressionEvaluator(expressions._2) + val eval1 = expressionEvaluator(expressions._1, ctx) + val eval2 = expressionEvaluator(expressions._2, ctx) val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - eval1.code ++ eval2.code ++ - q""" - val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} - val $primitiveTerm: ${termForType(resultType)} = - if($nullTerm) { - ${defaultPrimitive(resultType)} - } else { - $resultCode.asInstanceOf[${termForType(resultType)}] - } - """.children : Seq[Tree] + eval1.code + eval2.code + + s""" + boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; + ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; + if(!$nullTerm) { + $primitiveTerm = (${primitiveForType(resultType)})($resultCode); + } + """ } } - val inputTuple = newTermName(s"i") + val inputTuple = "i" // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { + val primitiveEvaluation: PartialFunction[Expression, String] = { case b @ BoundReference(ordinal, dataType, nullable) => - val nullValue = q"$inputTuple.isNullAt($ordinal)" - q""" - val $nullTerm: Boolean = $nullValue - val $primitiveTerm: ${termForType(dataType)} = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${getColumn(inputTuple, dataType, ordinal)} - """.children + s""" + final boolean $nullTerm = $inputTuple.isNullAt($ordinal); + final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? + ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); + """ case expressions.Literal(null, dataType) => - q""" - val $nullTerm = true - val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] - """.children - - case expressions.Literal(value: Boolean, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: UTF8String, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = - org.apache.spark.sql.types.UTF8String(${value.getBytes}) - """.children - - case expressions.Literal(value: Int, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: Long, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case Cast(e @ BinaryType(), StringType) => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) - """.children + s""" + final boolean $nullTerm = true; + ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; + """ + + case expressions.Literal(value: UTF8String, StringType) => + val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" + s""" + final boolean $nullTerm = false; + ${stringType} $primitiveTerm = + new ${stringType}().set(${arr}); + """ + + case expressions.Literal(value, FloatType) => + s""" + final boolean $nullTerm = false; + float $primitiveTerm = ${value}f; + """ + + case expressions.Literal(value, dt @ DecimalType()) => + s""" + final boolean $nullTerm = false; + ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); + """ + + case expressions.Literal(value, dataType) => + s""" + final boolean $nullTerm = false; + ${primitiveForType(dataType)} $primitiveTerm = $value; + """ + + case Cast(child @ BinaryType(), StringType) => + child.castOrNull(c => + s"new ${stringType}().set($c)", + StringType) case Cast(child @ DateType(), StringType) => child.castOrNull(c => - q"""org.apache.spark.sql.types.UTF8String( + s"""new ${stringType}().set( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", StringType) - case Cast(child @ NumericType(), IntegerType) => - child.castOrNull(c => q"$c.toInt", IntegerType) + case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) - case Cast(child @ NumericType(), LongType) => - child.castOrNull(c => q"$c.toLong", LongType) + case Cast(child @ DecimalType(), IntegerType) => + child.castOrNull(c => s"($c).toInt()", IntegerType) - case Cast(child @ NumericType(), DoubleType) => - child.castOrNull(c => q"$c.toDouble", DoubleType) + case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) - case Cast(child @ NumericType(), FloatType) => - child.castOrNull(c => q"$c.toFloat", FloatType) + case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. case Cast(e, StringType) if e.dataType != TimestampType => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) - """.children + e.castOrNull(c => + s"new ${stringType}().set(String.valueOf($c))", + StringType) case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => - q""" - java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], - $eval2.asInstanceOf[Array[Byte]]) - """ + s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" } case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } - - /* TODO: Fix null semantics. - case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => - val eval = expressionEvaluator(e1) - - val checks = list.map { - case expressions.Literal(v: String, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - case expressions.Literal(v: Int, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - } - - val funcName = newTermName(s"isIn${curId.getAndIncrement()}") - - q""" - def $funcName: Boolean = { - ..${eval.code} - if(${eval.nullTerm}) return false - ..$checks - return false - } - val $nullTerm = false - val $primitiveTerm = $funcName - """.children - */ + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } case And(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) { + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + s""" + ${eval1.code} + boolean $nullTerm = false; + boolean $primitiveTerm = false; + + if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { } else { - ..${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) { + ${eval2.code} + if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else { - $nullTerm = true + $nullTerm = true; } } - """.children + """ case Or(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false + s""" + ${eval1.code} + boolean $nullTerm = false; + boolean $primitiveTerm = false; if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else { - ..${eval2.code} + ${eval2.code} if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false + $primitiveTerm = false; } else { - $nullTerm = true + $nullTerm = true; } } - """.children + """ case Not(child) => // Uh, bad function name... - child.castOrNull(c => q"!$c", BooleanType) - - case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } - case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } - case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + child.castOrNull(c => s"!$c", BooleanType) + + case Add(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } + case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } + case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } + case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = null; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); + } + """ + case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); + } + """ + + case Add(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } + case Subtract(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } + case Multiply(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm} + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; } - """.children - + """ case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $nullTerm = false - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm} + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; } - """.children + """ case IsNotNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} - """.children + val eval = expressionEvaluator(e, ctx) + s""" + ${eval.code} + boolean $nullTerm = false; + boolean $primitiveTerm = !${eval.nullTerm}; + """ case IsNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} - """.children - - case c @ Coalesce(children) => - q""" - var $nullTerm = true - var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} - """.children ++ + val eval = expressionEvaluator(e, ctx) + s""" + ${eval.code} + boolean $nullTerm = false; + boolean $primitiveTerm = ${eval.nullTerm}; + """ + + case e @ Coalesce(children) => + s""" + boolean $nullTerm = true; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + """ + children.map { c => - val eval = expressionEvaluator(c) - q""" + val eval = expressionEvaluator(c, ctx) + s""" if($nullTerm) { - ..${eval.code} + ${eval.code} if(!${eval.nullTerm}) { - $nullTerm = false - $primitiveTerm = ${eval.primitiveTerm} + $nullTerm = false; + $primitiveTerm = ${eval.primitiveTerm}; } } """ - } + }.mkString("\n") - case i @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition) - val trueEval = expressionEvaluator(trueValue) - val falseEval = expressionEvaluator(falseValue) + case e @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition, ctx) + val trueEval = expressionEvaluator(trueValue, ctx) + val falseEval = expressionEvaluator(falseValue, ctx) - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} - ..${condEval.code} + s""" + boolean $nullTerm = false; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + ${condEval.code} if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ..${trueEval.code} - $nullTerm = ${trueEval.nullTerm} - $primitiveTerm = ${trueEval.primitiveTerm} + ${trueEval.code} + $nullTerm = ${trueEval.nullTerm}; + $primitiveTerm = ${trueEval.primitiveTerm}; } else { - ..${falseEval.code} - $nullTerm = ${falseEval.nullTerm} - $primitiveTerm = ${falseEval.primitiveTerm} + ${falseEval.code} + $nullTerm = ${falseEval.nullTerm}; + $primitiveTerm = ${falseEval.primitiveTerm}; } - """.children + """ case NewSet(elementType) => - q""" - val $nullTerm = false - val $primitiveTerm = new ${hashSetForType(elementType)}() - """.children + s""" + boolean $nullTerm = false; + ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); + """ case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item) - val setEval = expressionEvaluator(set) + val itemEval = expressionEvaluator(item, ctx) + val setEval = expressionEvaluator(set, ctx) val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = hashSetForType(elementType) - itemEval.code ++ setEval.code ++ - q""" - if (!${itemEval.nullTerm}) { - ${setEval.primitiveTerm} - .asInstanceOf[${hashSetForType(elementType)}] - .add(${itemEval.primitiveTerm}) + itemEval.code + setEval.code + + s""" + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); } - - val $nullTerm = false - val $primitiveTerm = ${setEval.primitiveTerm} - """.children + boolean $nullTerm = false; + ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; + """ case CombineSets(left, right) => - val leftEval = expressionEvaluator(left) - val rightEval = expressionEvaluator(right) + val leftEval = expressionEvaluator(left, ctx) + val rightEval = expressionEvaluator(right, ctx) val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = hashSetForType(elementType) - leftEval.code ++ rightEval.code ++ - q""" - val $nullTerm = false - var $primitiveTerm: ${hashSetForType(elementType)} = null - - { - val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val iterator = rightSet.iterator - while (iterator.hasNext) { - leftSet.add(iterator.next()) - } - $primitiveTerm = leftSet - } - """.children + leftEval.code + rightEval.code + + s""" + boolean $nullTerm = false; + ${htype} $primitiveTerm = + (${htype})${leftEval.primitiveTerm}; + $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); + """ - case MaxOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} + $nullTerm = ${eval2.nullTerm}; + $primitiveTerm = ${eval2.primitiveTerm}; } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} + $nullTerm = ${eval1.nullTerm}; + $primitiveTerm = ${eval1.primitiveTerm}; } else { if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} + $primitiveTerm = ${eval1.primitiveTerm}; } else { - $primitiveTerm = ${eval2.primitiveTerm} + $primitiveTerm = ${eval2.primitiveTerm}; } } - """.children + """ - case MinOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} + $nullTerm = ${eval2.nullTerm}; + $primitiveTerm = ${eval2.primitiveTerm}; } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} + $nullTerm = ${eval1.nullTerm}; + $primitiveTerm = ${eval1.primitiveTerm}; } else { if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} + $primitiveTerm = ${eval1.primitiveTerm}; } else { - $primitiveTerm = ${eval2.primitiveTerm} + $primitiveTerm = ${eval2.primitiveTerm}; } } - """.children + """ case UnscaledValue(child) => - val childEval = expressionEvaluator(child) - - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: Long = if (!$nullTerm) { - ${childEval.primitiveTerm}.toUnscaledLong - } else { - ${defaultPrimitive(LongType)} - } - """.children + val childEval = expressionEvaluator(child, ctx) + + childEval.code + + s""" + boolean $nullTerm = ${childEval.nullTerm}; + long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong(); + """ case MakeDecimal(child, precision, scale) => - val childEval = expressionEvaluator(child) + val eval = expressionEvaluator(child, ctx) - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: org.apache.spark.sql.types.Decimal = - ${defaultPrimitive(DecimalType())} + eval.code + + s""" + boolean $nullTerm = ${eval.nullTerm}; + org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal() - $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) - $nullTerm = $primitiveTerm == null + $primitiveTerm = new org.apache.spark.sql.types.Decimal(); + $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); + $nullTerm = $primitiveTerm == null; } - """.children + """ } // If there was no match in the partial function above, we fall back on calling the interpreted // expression evaluator. - val code: Seq[Tree] = + val code: String = primitiveEvaluation.lift.apply(e).getOrElse { - log.debug(s"No rules to generate $e") - val tree = reify { e } - q""" - val $objectTerm = $tree.eval(i) - val $nullTerm = $objectTerm == null - val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] - """.children - } - - // Only inject debugging code if debugging is turned on. - val debugCode = - if (debugLogging) { - val localLogger = log - val localLoggerTree = reify { localLogger } - q""" - $localLoggerTree.debug( - ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) - """ :: Nil - } else { - Nil + logError(s"No rules to generate $e") + ctx.references += e + s""" + /* expression: ${e} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean $nullTerm = $objectTerm == null; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; + """ } - EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) } - protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { dataType match { - case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" - case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + case StringType => s"(${stringType})$inputRow.apply($ordinal)" + case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" } } protected def setColumn( - destinationRow: TermName, + destinationRow: String, dataType: DataType, ordinal: Int, - value: TermName) = { + value: String): String = { dataType match { - case StringType => q"$destinationRow.update($ordinal, $value)" + case StringType => s"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => - q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => q"$destinationRow.update($ordinal, $value)" + s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => s"$destinationRow.update($ordinal, $value)" } } - protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") - protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + protected def accessorForType(dt: DataType) = dt match { + case IntegerType => "getInt" + case other => s"get${termForType(dt)}" + } + + protected def mutatorForType(dt: DataType) = dt match { + case IntegerType => "setInt" + case other => s"set${termForType(dt)}" + } - protected def hashSetForType(dt: DataType) = dt match { - case IntegerType => typeOf[IntegerHashSet] - case LongType => typeOf[LongHashSet] + protected def hashSetForType(dt: DataType): String = dt match { + case IntegerType => classOf[IntegerHashSet].getName + case LongType => classOf[LongHashSet].getName case unsupportedType => sys.error(s"Code generation not support for hashset of type $unsupportedType") } - protected def primitiveForType(dt: DataType) = dt match { - case IntegerType => "Int" + protected def primitiveForType(dt: DataType): String = dt match { + case IntegerType => "int" + case LongType => "long" + case ShortType => "short" + case ByteType => "byte" + case DoubleType => "double" + case FloatType => "float" + case BooleanType => "boolean" + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "int" + case TimestampType => "java.sql.Timestamp" + case _ => "Object" + } + + protected def defaultPrimitive(dt: DataType): String = dt match { + case BooleanType => "false" + case FloatType => "-1.0f" + case ShortType => "-1" + case LongType => "-1" + case ByteType => "-1" + case DoubleType => "-1.0" + case IntegerType => "-1" + case DateType => "-1" + case dt: DecimalType => "null" + case StringType => "null" + case _ => "null" + } + + protected def termForType(dt: DataType): String = dt match { + case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" case ByteType => "Byte" case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "org.apache.spark.sql.types.UTF8String" - } - - protected def defaultPrimitive(dt: DataType) = dt match { - case BooleanType => ru.Literal(Constant(false)) - case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" - case ShortType => ru.Literal(Constant(-1.toShort)) - case LongType => ru.Literal(Constant(-1L)) - case ByteType => ru.Literal(Constant(-1.toByte)) - case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" - case IntegerType => ru.Literal(Constant(-1)) - case DateType => ru.Literal(Constant(-1)) - case _ => ru.Literal(Constant(null)) - } - - protected def termForType(dt: DataType) = dt match { - case n: AtomicType => n.tag - case _ => typeTag[Any] + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "Integer" + case TimestampType => "java.sql.Timestamp" + case _ => "Object" } /** * List of data types that have special accessors and setters in [[Row]]. */ protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) /** * Returns true if the data type has a special accessor and setter in [[Row]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 840260703ab74..638b53fe0fe2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +// MutableProjection is not accessible in Java +abstract class BaseMutableProjection extends MutableProjection {} + /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new * input [[Row]] for a fixed set of [[Expression Expressions]]. */ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -36,41 +35,61 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu in.map(BindReferences.bindReference(_, inputSchema)) protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { - val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => - val evaluationCode = expressionEvaluator(e) - - evaluationCode.code :+ - q""" - if(${evaluationCode.nullTerm}) - mutableRow.setNullAt($i) - else - ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} - """ - } + val ctx = newCodeGenContext() + val projectionCode = expressions.zipWithIndex.map { case (e, i) => + val evaluationCode = expressionEvaluator(e, ctx) + evaluationCode.code + + s""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i); + else + ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + """ + }.mkString("\n") + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); + } + + class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - val code = - q""" - () => { new $mutableProjectionType { + private $exprType[] expressions = null; + private $mutableRowType mutableRow = null; - private[this] var $mutableRowName: $mutableRowType = - new $genericMutableRowType(${expressions.size}) + public SpecificProjection($exprType[] expr) { + expressions = expr; + mutableRow = new $genericMutableRowType(${expressions.size}); + } - def target(row: $mutableRowType): $mutableProjectionType = { - $mutableRowName = row - this - } + public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + mutableRow = row; + return this; + } - /* Provide immutable access to the last projected row. */ - def currentValue: $rowType = mutableRow + /* Provide immutable access to the last projected row. */ + public Row currentValue() { + return mutableRow; + } - def apply(i: $rowType): $rowType = { - ..$projectionCode - mutableRow - } - } } - """ + public Object apply(Object _i) { + Row i = (Row) _i; + $projectionCode - log.debug(s"code for ${expressions.mkString(",")}:\n$code") - toolBox.eval(code).asInstanceOf[() => MutableProjection] + return mutableRow; + } + } + """ + + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + () => { + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection] + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index b129c0d898bb7..0ff840dab393c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -18,18 +18,29 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging +import org.apache.spark.annotation.Private +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} +import org.apache.spark.sql.types.{BinaryType, NumericType} + +/** + * Inherits some default implementation for Java from `Ordering[Row]` + */ +@Private +class BaseOrdering extends Ordering[Row] { + def compare(a: Row, b: Row): Int = { + throw new UnsupportedOperationException + } +} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. */ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = @@ -38,73 +49,90 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { val a = newTermName("a") val b = newTermName("b") - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child) - val evalB = expressionEvaluator(order.child) + val ctx = newCodeGenContext() + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = expressionEvaluator(order.child, ctx) + val evalB = expressionEvaluator(order.child, ctx) + val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => - q""" - val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} - val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} - var i = 0 - while (i < x.length && i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - i = i+1 - } - return x.length - y.length - """ + s""" + { + byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; + byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; + int j = 0; + while (j < x.length && j < y.length) { + if (x[j] != y[j]) return x[j] - y[j]; + j = j + 1; + } + int d = x.length - y.length; + if (d != 0) { + return d; + } + }""" case _: NumericType => - q""" - val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} - if(comp != 0) { - return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} - } - """ - case StringType => - if (order.direction == Ascending) { - q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + s""" + if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) { + if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { + return ${if (asc) "1" else "-1"}; + } else { + return ${if (asc) "-1" else "1"}; + } + }""" + case _ => + s""" + int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); + if (comp != 0) { + return ${if (asc) "comp" else "-comp"}; + }""" + } + + s""" + i = $a; + ${evalA.code} + i = $b; + ${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) "-1" else "1"}; + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) "1" else "-1"}; } else { - q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + $compare } + """ + }.mkString("\n") + + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificOrdering generate($exprType[] expr) { + return new SpecificOrdering(expr); } - q""" - i = $a - ..${evalA.code} - i = $b - ..${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { - // Nothing - } else if (${evalA.nullTerm}) { - return ${if (order.direction == Ascending) q"-1" else q"1"} - } else if (${evalB.nullTerm}) { - return ${if (order.direction == Ascending) q"1" else q"-1"} - } else { - $compare + class SpecificOrdering extends ${typeOf[BaseOrdering]} { + + private $exprType[] expressions = null; + + public SpecificOrdering($exprType[] expr) { + expressions = expr; } - """ - } - val q"class $orderingName extends $orderingType { ..$body }" = reify { - class SpecificOrdering extends Ordering[Row] { - val o = ordering - } - }.tree.children.head - - val code = q""" - class $orderingName extends $orderingType { - ..$body - def compare(a: $rowType, b: $rowType): Int = { - var i: $rowType = null // Holds current row being evaluated. - ..$comparisons - return 0 + @Override + public int compare(Row a, Row b) { + Row i = null; // Holds current row being evaluated. + $comparisons + return 0; } - } - new $orderingName() - """ + }""" + logDebug(s"Generated Ordering: $code") - toolBox.eval(code).asInstanceOf[Ordering[Row]] + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 40e163024360e..fb18769f00da3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +/** + * Interface for generated predicate + */ +abstract class Predicate { + def eval(r: Row): Boolean +} + /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. */ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) @@ -32,17 +37,34 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { BindReferences.bindReference(in, inputSchema) protected def create(predicate: Expression): ((Row) => Boolean) = { - val cEval = expressionEvaluator(predicate) + val ctx = newCodeGenContext() + val eval = expressionEvaluator(predicate, ctx) + val code = s""" + import org.apache.spark.sql.Row; - val code = - q""" - (i: $rowType) => { - ..${cEval.code} - if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + public SpecificPredicate generate($exprType[] expr) { + return new SpecificPredicate(expr); + } + + class SpecificPredicate extends ${classOf[Predicate].getName} { + private final $exprType[] expressions; + public SpecificPredicate($exprType[] expr) { + expressions = expr; + } + + @Override + public boolean eval(Row i) { + ${eval.code} + return !${eval.nullTerm} && ${eval.primitiveTerm}; } - """ + }""" + + logDebug(s"Generated predicate '$predicate':\n$code") - log.debug(s"Generated predicate '$predicate':\n$code") - toolBox.eval(code).asInstanceOf[Row => Boolean] + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate] + (r: Row) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 31c63a79ebc8c..d5be1fc12e0f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProject extends Projection {} /** * Generates bytecode that produces a new [[Row]] object based on a fixed set of input @@ -27,7 +32,6 @@ import org.apache.spark.sql.types._ * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -38,201 +42,183 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { // Make Mutablility optional... protected def create(expressions: Seq[Expression]): Projection = { - val tupleLength = ru.Literal(Constant(expressions.length)) - val lengthDef = q"final val length = $tupleLength" - - /* TODO: Configurable... - val nullFunctions = - q""" - private final val nullSet = new org.apache.spark.util.collection.BitSet(length) - final def setNullAt(i: Int) = nullSet.set(i) - final def isNullAt(i: Int) = nullSet.get(i) - """ - */ - - val nullFunctions = - q""" - private[this] var nullBits = new Array[Boolean](${expressions.size}) - override def setNullAt(i: Int) = { nullBits(i) = true } - override def isNullAt(i: Int) = nullBits(i) - """.children - - val tupleElements = expressions.zipWithIndex.flatMap { + val ctx = newCodeGenContext() + val columns = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") - val evaluatedExpression = expressionEvaluator(e) - val iLit = ru.Literal(Constant(i)) + s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" + }.mkString("\n ") - q""" - var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + val initColumns = expressions.zipWithIndex.map { + case (e, i) => + val eval = expressionEvaluator(e, ctx) + s""" { - ..${evaluatedExpression.code} - if(${evaluatedExpression.nullTerm}) - setNullAt($iLit) - else { - nullBits($iLit) = false - $elementName = ${evaluatedExpression.primitiveTerm} + // column$i + ${eval.code} + nullBits[$i] = ${eval.nullTerm}; + if(!${eval.nullTerm}) { + c$i = ${eval.primitiveTerm}; } } - """.children : Seq[Tree] - } + """ + }.mkString("\n") - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" - val applyFunction = { - val cases = (0 until expressions.size).map { i => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) + val getCases = (0 until expressions.size).map { i => + s"case $i: return c$i;" + }.mkString("\n ") - q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" - } - q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }" - } - - val updateFunction = { - val cases = expressions.zipWithIndex.map {case (e, i) => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) - - q""" - if(i == $ordinal) { - if(value == null) { - setNullAt(i) - } else { - nullBits(i) = false - $elementName = value.asInstanceOf[${termForType(e.dataType)}] - } - return - }""" - } - q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" - } + val updateCases = expressions.zipWithIndex.map { case (e, i) => + s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" + }.mkString("\n ") val specificAccessorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // getString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) return $elementName" :: Nil - case _ => Nil - } - dataType match { - // Row() need this interface to compile - case StringType => - q""" - override def getString(i: Int): String = { - $accessorFailure - }""" - case other => - q""" - override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: return c$i;" + case _ => "" + }.mkString("\n ") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { + if (isNullAt(i)) { + return ${defaultPrimitive(dataType)}; + } + switch (i) { + $cases + } + return ${defaultPrimitive(dataType)}; + }""" + } else { + "" } - } + }.mkString("\n") val specificMutatorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // setString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil - case _ => Nil - } - dataType match { - case StringType => - // MutableRow() need this interface to compile - q""" - override def setString(i: Int, value: String) { - $accessorFailure - }""" - case other => - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { - ..$ifStatements; - $accessorFailure - }""" + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: { c$i = value; return; }" + case _ => "" + }.mkString("\n") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { + nullBits[i] = false; + switch (i) { + $cases + } + }""" + } else { + "" } - } + }.mkString("\n") val hashValues = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") + val col = newTermName(s"c$i") val nonNull = e.dataType match { - case BooleanType => q"if ($elementName) 0 else 1" - case ByteType | ShortType | IntegerType => q"$elementName.toInt" - case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" - case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case BooleanType => s"$col ? 0 : 1" + case ByteType | ShortType | IntegerType | DateType => s"$col" + case LongType => s"$col ^ ($col >>> 32)" + case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" - case _ => q"$elementName.hashCode" + s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" + case _ => s"$col.hashCode()" } - q"if (isNullAt($i)) 0 else $nonNull" + s"isNullAt($i) ? 0 : ($nonNull)" } - val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + val hashUpdates: String = hashValues.map( v => + s""" + result *= 37; result += $v;""" + ).mkString("\n") - val hashCodeFunction = - q""" - override def hashCode(): Int = { - var result: Int = 37 - ..$hashUpdates - result - } + val columnChecks = expressions.zipWithIndex.map { case (e, i) => + s""" + if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) { + return false; + } """ + }.mkString("\n") - val columnChecks = (0 until expressions.size).map { i => - val elementName = newTermName(s"c$i") - q"if (this.$elementName != specificType.$elementName) return false" + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); } - val equalsFunction = - q""" - override def equals(other: Any): Boolean = other match { - case specificType: SpecificRow => - ..$columnChecks - return true - case other => super.equals(other) - } - """ + class SpecificProjection extends ${typeOf[BaseProject]} { + private $exprType[] expressions = null; + + public SpecificProjection($exprType[] expr) { + expressions = expr; + } - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + @Override + public Object apply(Object r) { + return new SpecificRow(expressions, (Row) r); + } } - val copyFunction = - q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" - - val toSeqFunction = - q"override def toSeq: Seq[Any] = Seq(..$allColumns)" - - val classBody = - nullFunctions ++ ( - lengthDef +: - applyFunction +: - updateFunction +: - equalsFunction +: - hashCodeFunction +: - copyFunction +: - toSeqFunction +: - (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) - - val code = q""" - final class SpecificRow(i: $rowType) extends $mutableRowType { - ..$classBody + final class SpecificRow extends ${typeOf[BaseMutableRow]} { + + $columns + + public SpecificRow($exprType[] expressions, Row i) { + $initColumns + } + + public int size() { return ${expressions.length};} + private boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } + + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; + } + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } + } + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + @Override + public boolean equals(Object other) { + if (other instanceof Row) { + Row row = (Row) other; + if (row.length() != size()) return false; + $columnChecks + return true; + } + return super.equals(other); + } + } """ - log.debug( - s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") - toolBox.eval(code).asInstanceOf[Projection] + logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 528e38a50a740..7f1b12cdd5800 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,12 +27,6 @@ import org.apache.spark.util.Utils */ package object codegen { - /** - * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala - * 2.10. - */ - protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock - /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index b6927485f42bf..5df528770ca6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -344,7 +344,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation("abdef" cast TimestampType, null) checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) - checkEvaluation(Literal(1) cast LongType, 1) + checkEvaluation(Literal(1) cast LongType, 1.toLong) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) @@ -363,13 +363,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) + Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), + 5.toLong) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), + 0.toShort) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), + 0.toShort) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Literal(true) cast StringType, "true") @@ -509,9 +512,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) val tss = new Timestamp(seconds) - checkEvaluation(Cast(ts, ShortType), 15) + checkEvaluation(Cast(ts, ShortType), 15.toShort) checkEvaluation(Cast(ts, IntegerType), 15) - checkEvaluation(Cast(ts, LongType), 15) + checkEvaluation(Cast(ts, LongType), 15.toLong) checkEvaluation(Cast(ts, FloatType), 15.002f) checkEvaluation(Cast(ts, DoubleType), 15.002) checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index d7c437095e395..8cfd853afa35f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -32,11 +32,12 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { case e: Throwable => - val evaluated = GenerateProjection.expressionEvaluator(expression) + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) fail( s""" |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} + |${evaluated.code} |$e """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index a40324b008e16..9ab1f7d7ad0db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -28,7 +28,8 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) val plan = try { GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) @@ -37,7 +38,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { fail( s""" |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} + |${evaluated.code} |$e """.stripMargin) } @@ -49,7 +50,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { s""" |Mismatched hashCodes for values: $actual, $expectedRow |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code.mkString("\n")} + |${evaluated.code} """.stripMargin) } if (actual != expectedRow) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9aaec2b064d76..b41b1b77d049e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -451,10 +451,13 @@ class DataFrameSuite extends QueryTest { test("SPARK-6899") { val originalValue = TestSQLContext.conf.codegenEnabled TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try{ + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + } finally { + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("SPARK-7133: Implement struct, array, and map field accessor") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 63f7d314fb699..55b68d8e2283c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -184,77 +184,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df, expectedResults) } - // Just to group rows. - testCodeGen( - "SELECT key FROM testData3x GROUP BY key", - (1 to 100).map(Row(_))) - // COUNT - testCodeGen( - "SELECT key, count(value) FROM testData3x GROUP BY key", - (1 to 100).map(i => Row(i, 3))) - testCodeGen( - "SELECT count(key) FROM testData3x", - Row(300) :: Nil) - // COUNT DISTINCT ON int - testCodeGen( - "SELECT value, count(distinct key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 1))) - testCodeGen( - "SELECT count(distinct key) FROM testData3x", - Row(100) :: Nil) - // SUM - testCodeGen( - "SELECT value, sum(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 3 * i))) - testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", - Row(5050 * 3, 5050 * 3.0) :: Nil) - // AVERAGE - testCodeGen( - "SELECT value, avg(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT avg(key) FROM testData3x", - Row(50.5) :: Nil) - // MAX - testCodeGen( - "SELECT value, max(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT max(key) FROM testData3x", - Row(100) :: Nil) - // MIN - testCodeGen( - "SELECT value, min(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT min(key) FROM testData3x", - Row(1) :: Nil) - // Some combinations. - testCodeGen( - """ - |SELECT - | value, - | sum(key), - | max(key), - | min(key), - | avg(key), - | count(key), - | count(distinct key) - |FROM testData3x - |GROUP BY value - """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) - testCodeGen( - "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 1, 50.5, 300, 100) :: Nil) - // Aggregate with Code generation handling all null values - testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) - - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try { + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + } finally { + dropTempTable("testData3x") + setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("Add Parser of SQL COALESCE()") { @@ -463,9 +465,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val codegenbefore = conf.codegenEnabled setConf(SQLConf.EXTERNAL_SORT, "false") setConf(SQLConf.CODEGEN_ENABLED, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + try{ + sortTest() + } finally { + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("SPARK-6927 external sorting with codegen on") { @@ -473,9 +478,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val codegenbefore = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") setConf(SQLConf.EXTERNAL_SORT, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + try { + sortTest() + } finally { + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("limit") { From df7da07a86a30c684d5b07d955f1045a66715e3a Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Thu, 4 Jun 2015 11:30:07 -0700 Subject: [PATCH 185/254] [SPARK-7969] [SQL] Added a DataFrame.drop function that accepts a Column reference. Added a `DataFrame.drop` function that accepts a `Column` reference rather than a `String`, and added associated unit tests. Basically iterates through the `DataFrame` to find a column with an expression that is equivalent to that of the `Column` argument supplied to the function. Author: Mike Dusenberry Closes #6585 from dusenberrymw/SPARK-7969_Drop_method_on_Dataframes_should_handle_Column and squashes the following commits: 514727a [Mike Dusenberry] Updating the @since tag of the drop(Column) function doc to reflect version 1.4.1 instead of 1.4.0. 2f1bb4e [Mike Dusenberry] Adding an additional assert statement to the 'drop column after join' unit test in order to make sure the correct column was indeed left over. 6bf7c0e [Mike Dusenberry] Minor code formatting change. e583888 [Mike Dusenberry] Adding more Python doctests for the df.drop with column reference function to test joined datasets that have columns with the same name. 5f74401 [Mike Dusenberry] Updating DataFrame.drop with column reference function to use logicalPlan.output to prevent ambiguities resulting from columns with the same name. Also added associated unit tests for joined datasets with duplicate column names. 4b8bbe8 [Mike Dusenberry] Adding Python support for Dataframe.drop with a Column reference. 986129c [Mike Dusenberry] Added a DataFrame.drop function that accepts a Column reference rather than a String, and added associated unit tests. Basically iterates through the DataFrame to find a column with an expression that is equivalent to one supplied to the function. --- python/pyspark/sql/dataframe.py | 21 +++++++-- .../org/apache/spark/sql/DataFrame.scala | 16 +++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 45 +++++++++++++++++++ 3 files changed, 79 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7673153abe0e2..03b01a1136e45 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1189,15 +1189,30 @@ def withColumnRenamed(self, existing, new): @since(1.4) @ignore_unicode_prefix - def drop(self, colName): + def drop(self, col): """Returns a new :class:`DataFrame` that drops the specified column. - :param colName: string, name of the column to drop. + :param col: a string name of the column to drop, or a + :class:`Column` to drop. >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.drop(df.age).collect() + [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() + [Row(age=5, height=85, name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() + [Row(age=5, name=u'Bob', height=85)] """ - jdf = self._jdf.drop(colName) + if isinstance(col, basestring): + jdf = self._jdf.drop(col) + elif isinstance(col, Column): + jdf = self._jdf.drop(col._jc) + else: + raise TypeError("col should be a string or a Column") return DataFrame(jdf, self.sql_ctx) @since(1.3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 034d887901975..d1a54ada7b191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1082,6 +1082,22 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] with a column dropped. + * This version of drop accepts a Column rather than a name. + * This is a no-op if the DataFrame doesn't have a column + * with an equivalent expression. + * @group dfops + * @since 1.4.1 + */ + def drop(col: Column): DataFrame = { + val attrs = this.logicalPlan.output + val colsAfterDrop = attrs.filter { attr => + attr != col.expr + }.map(attr => Column(attr)) + select(colsAfterDrop : _*) + } + /** * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. * This is an alias for `distinct`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b41b1b77d049e..8e81dacb8660f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -334,6 +334,51 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name) === Seq("key", "value")) } + test("drop column using drop with column reference") { + val col = testData("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) + } + + test("drop unknown column (no-op) with column reference") { + val col = Column("random") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop unknown column with same name (no-op) with column reference") { + val col = Column("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop column after join with duplicate columns using column reference") { + val newSalary = salary.withColumnRenamed("personId", "id") + val col = newSalary("id") + // this join will result in duplicate "id" columns + val joinedDf = person.join(newSalary, + person("id") === newSalary("id"), "inner") + // remove only the "id" column that was associated with newSalary + val df = joinedDf.drop(col) + checkAnswer( + df, + joinedDf.collect().map { + case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) => + Row(id, name, age, salary) + }.toSeq) + assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary")) + assert(df("id") == person("id")) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") From cd3176bd86eafa09a5e11baf3636861c1f46e844 Mon Sep 17 00:00:00 2001 From: Thomas Omans Date: Thu, 4 Jun 2015 11:32:03 -0700 Subject: [PATCH 186/254] [SPARK-7743] [SQL] Parquet 1.7 Resolves [SPARK-7743](https://issues.apache.org/jira/browse/SPARK-7743). Trivial changes of versions, package names, as well as a small issue in `ParquetTableOperations.scala` ```diff - val readContext = getReadSupport(configuration).init( + val readContext = ParquetInputFormat.getReadSupportInstance(configuration).init( ``` Since ParquetInputFormat.getReadSupport was made package private in the latest release. Thanks -- Thomas Omans Author: Thomas Omans Closes #6597 from eggsby/SPARK-7743 and squashes the following commits: 2df0d1b [Thomas Omans] [SPARK-7743] [SQL] Upgrading parquet version to 1.7.0 --- .../src/main/python/parquet_inputformat.py | 2 +- pom.xml | 6 ++-- sql/core/pom.xml | 4 +-- .../DirectParquetOutputCommitter.scala | 6 ++-- .../spark/sql/parquet/ParquetConverter.scala | 6 ++-- .../spark/sql/parquet/ParquetFilters.scala | 10 +++--- .../spark/sql/parquet/ParquetRelation.scala | 10 +++--- .../sql/parquet/ParquetTableOperations.scala | 34 +++++++++---------- .../sql/parquet/ParquetTableSupport.scala | 12 +++---- .../spark/sql/parquet/ParquetTypes.scala | 14 ++++---- .../apache/spark/sql/parquet/newParquet.scala | 8 ++--- .../sql/parquet/timestamp/NanoTime.scala | 4 +-- .../apache/spark/sql/sources/commands.scala | 2 +- sql/core/src/test/resources/log4j.properties | 10 +++--- .../sql/parquet/ParquetFilterSuite.scala | 4 +-- .../spark/sql/parquet/ParquetIOSuite.scala | 18 +++++----- .../sql/parquet/ParquetSchemaSuite.scala | 2 +- 17 files changed, 76 insertions(+), 76 deletions(-) diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 96ddac761d698..e1fd85b082c08 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -51,7 +51,7 @@ parquet_rdd = sc.newAPIHadoopFile( path, - 'parquet.avro.AvroParquetInputFormat', + 'org.apache.parquet.avro.AvroParquetInputFormat', 'java.lang.Void', 'org.apache.avro.generic.IndexedRecord', valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') diff --git a/pom.xml b/pom.xml index bcb6ef96a1206..abb9b55400340 100644 --- a/pom.xml +++ b/pom.xml @@ -136,7 +136,7 @@ 0.13.1 10.10.1.1 - 1.6.0rc3 + 1.7.0 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 @@ -1080,13 +1080,13 @@ - com.twitter + org.apache.parquet parquet-column ${parquet.version} ${parquet.deps.scope} - com.twitter + org.apache.parquet parquet-hadoop ${parquet.version} ${parquet.deps.scope} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 3192f81ffaecd..ed75475a87067 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -61,11 +61,11 @@ test - com.twitter + org.apache.parquet parquet-column - com.twitter + org.apache.parquet parquet-hadoop diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index f5ce2718bec4a..62c4e92ebec68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -21,9 +21,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import parquet.Log -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.Log +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index caa9f045537d0..85c2ce740fe52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -23,9 +23,9 @@ import java.util.{TimeZone, Calendar} import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import jodd.datetime.JDateTime -import parquet.column.Dictionary -import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} -import parquet.schema.MessageType +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f0f4e7d147e75..88ae88e9684c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -21,11 +21,11 @@ import java.nio.ByteBuffer import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration -import parquet.filter2.compat.FilterCompat -import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterApi._ -import parquet.filter2.predicate.{FilterApi, FilterPredicate} -import parquet.io.api.Binary +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.compat.FilterCompat._ +import org.apache.parquet.filter2.predicate.FilterApi._ +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index fcb9513ab66f6..09088ee91106c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -24,9 +24,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction import org.apache.spark.sql.types.{StructType, DataType} -import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} -import parquet.hadoop.metadata.CompressionCodecName -import parquet.schema.MessageType +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.schema.MessageType import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} @@ -107,7 +107,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName) + Class.forName(classOf[org.apache.parquet.Log].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. @@ -127,7 +127,7 @@ private[sql] object ParquetRelation { type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow // The compression type - type CompressionType = parquet.hadoop.metadata.CompressionCodecName + type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName // The parquet compression short names val shortParquetCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index cb7ae246d0d75..1e694f2feabee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -33,13 +33,13 @@ import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import parquet.hadoop._ -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{InitContext, ReadSupport} -import parquet.hadoop.metadata.GlobalMetaData -import parquet.hadoop.util.ContextUtil -import parquet.io.ParquetDecodingException -import parquet.schema.MessageType +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.hadoop.metadata.GlobalMetaData +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.io.ParquetDecodingException +import org.apache.parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil @@ -78,7 +78,7 @@ private[sql] case class ParquetTableScan( }.toArray protected override def doExecute(): RDD[Row] = { - import parquet.filter2.compat.FilterCompat.FilterPredicateCompat + import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -136,7 +136,7 @@ private[sql] case class ParquetTableScan( baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] .getPath .toString .split("/") @@ -378,7 +378,7 @@ private[sql] case class InsertIntoParquetTable( * to imported ones. */ private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends parquet.hadoop.ParquetOutputFormat[Row] { + extends org.apache.parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} var committer: OutputCommitter = null @@ -431,7 +431,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) * RecordFilter we want to use. */ private[parquet] class FilteringParquetRowInputFormat - extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + extends org.apache.parquet.hadoop.ParquetInputFormat[Row] with Logging { private var fileStatuses = Map.empty[Path, FileStatus] @@ -439,7 +439,7 @@ private[parquet] class FilteringParquetRowInputFormat inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { - import parquet.filter2.compat.FilterCompat.NoOpFilter + import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter val readSupport: ReadSupport[Row] = new RowReadSupport() @@ -501,7 +501,7 @@ private[parquet] class FilteringParquetRowInputFormat globalMetaData = new GlobalMetaData(globalMetaData.getSchema, mergedMetadata, globalMetaData.getCreatedBy) - val readContext = getReadSupport(configuration).init( + val readContext = ParquetInputFormat.getReadSupportInstance(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData, globalMetaData.getSchema)) @@ -531,8 +531,8 @@ private[parquet] class FilteringParquetRowInputFormat minSplitSize: JLong, readContext: ReadContext): JList[ParquetInputSplit] = { - import parquet.filter2.compat.FilterCompat.Filter - import parquet.filter2.compat.RowGroupFilter + import org.apache.parquet.filter2.compat.FilterCompat.Filter + import org.apache.parquet.filter2.compat.RowGroupFilter import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache @@ -547,7 +547,7 @@ private[parquet] class FilteringParquetRowInputFormat // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") + Class.forName("org.apache.parquet.hadoop.ClientSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) generateSplits.setAccessible(true) @@ -612,7 +612,7 @@ private[parquet] class FilteringParquetRowInputFormat // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") + Class.forName("org.apache.parquet.hadoop.TaskSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( sys.error( s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 70a220cc43ab9..89db408b1c382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration -import parquet.column.ParquetProperties -import parquet.hadoop.ParquetOutputFormat -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{ReadSupport, WriteSupport} -import parquet.io.api._ -import parquet.schema.MessageType +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport} +import org.apache.parquet.io.api._ +import org.apache.parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index f8a5d84549336..ba2a35b74ef82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -25,13 +25,13 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job -import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition -import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index bf55e2383ab56..5dda440240e60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -29,10 +29,10 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import parquet.filter2.predicate.FilterApi -import parquet.hadoop._ -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.util.ContextUtil +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException} import org.apache.spark.broadcast.Broadcast diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala index 70bcca7526aae..4d5ed211ad0c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.parquet.timestamp import java.nio.{ByteBuffer, ByteOrder} -import parquet.Preconditions -import parquet.io.api.{Binary, RecordConsumer} +import org.apache.parquet.Preconditions +import org.apache.parquet.io.api.{Binary, RecordConsumer} private[parquet] class NanoTime extends Serializable { private var julianDay = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 71f016b1f14de..e9932c09107db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 28e90b9520b2c..12fb128149d32 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,11 +36,11 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. -log4j.additivity.parquet.hadoop.ParquetRecordReader=false -log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF -log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false -log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF @@ -52,5 +52,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.parquet.hadoop=WARN +log4j.logger.org.apache.parquet.hadoop=WARN log4j.logger.org.apache.spark.sql.parquet=INFO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index bdc2ebabc5e9a..4aa5bcb7fdbca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll -import parquet.filter2.predicate.Operators._ -import parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.Operators._ +import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index dd48bb350f26d..7f7c2cc1a6c26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -24,14 +24,14 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.scalatest.BeforeAndAfterAll -import parquet.example.data.simple.SimpleGroup -import parquet.example.data.{Group, GroupWriter} -import parquet.hadoop.api.WriteSupport -import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} -import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} -import parquet.io.api.RecordConsumer -import parquet.schema.{MessageType, MessageTypeParser} +import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} +import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row @@ -400,7 +400,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } finally { configuration.set("spark.sql.parquet.output.committer.class", - "parquet.hadoop.ParquetOutputCommitter") + "org.apache.parquet.hadoop.ParquetOutputCommitter") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index caec2a6f25489..8b1745124b8e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import parquet.schema.MessageTypeParser +import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection From 3dc005282a694e105f40e429b28b0a677743341f Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 4 Jun 2015 12:52:16 -0700 Subject: [PATCH 187/254] [SPARK-8027] [SPARKR] Move man pages creation to install-dev.sh This also helps us get rid of the sparkr-docs maven profile as docs are now built by just using -Psparkr when the roxygen2 package is available Related to discussion in #6567 cc pwendell srowen -- Let me know if this looks better Author: Shivaram Venkataraman Closes #6593 from shivaram/sparkr-pom-cleanup and squashes the following commits: b282241 [Shivaram Venkataraman] Remove sparkr-docs from release script as well 8f100a5 [Shivaram Venkataraman] Move man pages creation to install-dev.sh This also helps us get rid of the sparkr-docs maven profile as docs are now built by just using -Psparkr when the roxygen2 package is available --- R/create-docs.sh | 5 +---- R/install-dev.sh | 9 ++++++++- core/pom.xml | 23 ----------------------- dev/create-release/create-release.sh | 16 ++++++++-------- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/R/create-docs.sh b/R/create-docs.sh index af47c0863bdd0..6a4687b06ecb9 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -30,10 +30,7 @@ set -e export FWDIR="$(cd "`dirname "$0"`"; pwd)" pushd $FWDIR -# Generate Rd file -Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' - -# Install the package +# Install the package (this will also generate the Rd files) ./install-dev.sh # Now create HTML files diff --git a/R/install-dev.sh b/R/install-dev.sh index b9e2527035994..1edd551f8d243 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -34,5 +34,12 @@ LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -# Install R +pushd $FWDIR + +# Generate Rd files if devtools is installed +Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' + +# Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ + +popd diff --git a/core/pom.xml b/core/pom.xml index e35694e9e98b4..40a64beccdc24 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -481,29 +481,6 @@ - - sparkr-docs - - - - org.codehaus.mojo - exec-maven-plugin - - - sparkr-pkg-docs - compile - - exec - - - - - ..${path.separator}R${path.separator}create-docs${script.extension} - - - - - diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 0b14a618e755c..54274a83f6d66 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -228,14 +228,14 @@ if [[ ! "$@" =~ --skip-package ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Psparkr-docs -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Psparkr-docs -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Psparkr-docs -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & + make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & wait rm -rf spark-$RELEASE_VERSION-bin-*/ From 0526fea483066086dfc27d1606f74220fe822f7f Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Thu, 4 Jun 2015 13:27:35 -0700 Subject: [PATCH 188/254] [SPARK-6909][SQL] Remove Hive Shim code This is a follow-up on #6393. I am removing the following files in this PR. ``` ./sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala ./sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala ``` Basically, I re-factored the shim code as follows- * Rewrote code directly with Hive 0.13 methods, or * Converted code into private methods, or * Extracted code into separate classes But for leftover code that didn't fit in any of these cases, I created a HiveShim object. For eg, helper functions which wrap Hive 0.13 methods to work around Hive bugs are placed here. Author: Cheolsoo Park Closes #6604 from piaozhexiu/SPARK-6909 and squashes the following commits: 5dccc20 [Cheolsoo Park] Remove hive shim code --- .../hive/thriftserver/HiveThriftServer2.scala | 10 +- .../SparkExecuteStatementOperation.scala} | 102 +--- .../hive/thriftserver/SparkSQLCLIDriver.scala | 6 +- .../thriftserver/SparkSQLCLIService.scala | 7 +- ...rkSQLDriver.scala => SparkSQLDriver.scala} | 20 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 4 +- .../thriftserver/SparkSQLSessionManager.scala | 75 +++ .../HiveThriftServer2Suites.scala | 8 +- .../execution/HiveCompatibilitySuite.scala | 3 +- .../apache/spark/sql/hive/HiveContext.scala | 23 +- .../spark/sql/hive/HiveInspectors.scala | 187 +++++-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 22 +- .../org/apache/spark/sql/hive/HiveQl.scala | 4 +- .../org/apache/spark/sql/hive/HiveShim.scala | 247 ++++++++++ .../apache/spark/sql/hive/TableReader.scala | 11 +- .../hive/execution/InsertIntoHiveTable.scala | 12 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 1 + .../spark/sql/hive/hiveWriterContainers.scala | 3 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 46 +- .../spark/sql/hive/StatisticsSuite.scala | 4 - .../sql/hive/execution/HiveQuerySuite.scala | 25 +- .../sql/hive/execution/SQLQuerySuite.scala | 58 ++- .../org/apache/spark/sql/hive/Shim13.scala | 457 ------------------ 23 files changed, 619 insertions(+), 716 deletions(-) rename sql/hive-thriftserver/{v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala => src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala} (66%) rename sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/{AbstractSparkSQLDriver.scala => SparkSQLDriver.scala} (86%) create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala delete mode 100644 sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 94687eeda4179..5b391d3dce882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -29,12 +26,15 @@ import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a * `HiveThriftServer2` thrift server. @@ -51,7 +51,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) - sqlContext.setConf("spark.sql.hive.version", HiveShim.version) + sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala similarity index 66% rename from sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index b9d4f1c58c982..c0d1266212cdd 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -18,66 +18,31 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import java.util.{Map => JMap, UUID} import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} +import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.Logging import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.13.1" - - def setServerUserName( - sparkServiceUGI: UserGroupInformation, - sparkCliService:SparkSQLCLIService) = { - setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JList[_]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) - hiveResponse = null - true - } - } -} +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], - runInBackground: Boolean = true)( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) + runInBackground: Boolean = true) + (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution - extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { + extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) + with Logging { private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ @@ -88,7 +53,7 @@ private[hive] class SparkExecuteStatementOperation( logDebug("CLOSING") } - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { dataTypes(ordinal) match { case StringType => to += from.getString(ordinal) @@ -209,48 +174,3 @@ private[hive] class SparkExecuteStatementOperation( HiveThriftServer2.listener.onStatementFinish(statementId) } } - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def openSession( - protocol: TProtocolVersion, - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) - val session = super.getSession(sessionHandle) - HiveThriftServer2.listener.onSessionCreated( - session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - sessionHandle - } - - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() - } -} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 14f6f658d9b75..039cfa40d26b3 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,12 +32,12 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging -import org.apache.spark.sql.hive.{HiveContext, HiveShim} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { @@ -267,7 +267,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 499e077d7294a..41f647d5f8c5a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,8 +21,6 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException -import scala.collection.JavaConversions._ - import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader @@ -34,7 +32,8 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -52,7 +51,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala similarity index 86% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 48ac9062af96a..77272aecf2835 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ +import java.util.{ArrayList => JArrayList, List => JList} import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +27,12 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -private[hive] abstract class AbstractSparkSQLDriver( - val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { +import scala.collection.JavaConversions._ + +private[hive] class SparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver + with Logging { private[hive] var tableSchema: Schema = _ private[hive] var hiveResponse: Seq[String] = _ @@ -71,6 +75,16 @@ private[hive] abstract class AbstractSparkSQLDriver( 0 } + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } + override def getSchema: Schema = tableSchema override def destroy() { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 7c0c505e2d61e..79eda1f5123bf 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils @@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - hiveContext.setConf("spark.sql.hive.version", HiveShim.version) + hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..357b27f7401a3 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.SessionHandle +import org.apache.hive.service.cli.session.SessionManager +import org.apache.hive.service.cli.thrift.TProtocolVersion + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } + + override def openSession(protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + val sessionHandle = super.openSession( + protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle + } + + override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index a93a3dee43511..f57c7083ea504 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -40,7 +40,7 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils object TestData { @@ -111,7 +111,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } @@ -365,7 +366,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0b1917a392901..048f78b4daa8d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,7 +23,6 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive /** @@ -254,7 +253,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // the answer is sensitive for jdk version "udf_java_method" - ) ++ HiveShim.compatibilityBlackList + ) /** * The set of tests that are believed to be working in catalyst. Tests not on whiteList or diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fbf2c7d8cbc06..800f51c5e2e86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -17,37 +17,34 @@ package org.apache.spark.sql.hive -import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp -import java.util.{ArrayList => JArrayList} -import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} +import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -331,7 +328,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableParameters = relation.hiveQlTable.getParameters val oldTotalSize = - Option(tableParameters.get(HiveShim.getStatsSetupConstTotalSize)) + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) .map(_.toLong) .getOrElse(0L) val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) @@ -342,7 +339,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.client.alterTable( relation.table.copy( properties = relation.table.properties + - (HiveShim.getStatsSetupConstTotalSize -> newTotalSize.toString))) + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) } case otherRelation => throw new UnsupportedOperationException( @@ -564,7 +561,7 @@ private[hive] object HiveContext { case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveShim.createDecimal(decimal).toString + HiveDecimal.create(decimal).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 24cd335082639..c466203cd0220 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -350,7 +351,7 @@ private[hive] trait HiveInspectors { new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) case _: JavaDateObjectInspector => (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) @@ -439,31 +440,31 @@ private[hive] trait HiveInspectors { case _ if a == null => null case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. - case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) + case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() - case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) + case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) + case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a) + case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a) + case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a) + case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a) + case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a) + case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] case _: HiveDecimalObjectInspector if x.preferWritable() => - HiveShim.getDecimalWritable(a.asInstanceOf[Decimal]) + getDecimalWritable(a.asInstanceOf[Decimal]) case _: HiveDecimalObjectInspector => - HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) + HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) + case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) + case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) - case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) + case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] } case x: SettableStructObjectInspector => @@ -574,31 +575,31 @@ private[hive] trait HiveInspectors { */ def toInspector(expr: Expression): ObjectInspector = expr match { case Literal(value, StringType) => - HiveShim.getStringWritableConstantObjectInspector(value) + getStringWritableConstantObjectInspector(value) case Literal(value, IntegerType) => - HiveShim.getIntWritableConstantObjectInspector(value) + getIntWritableConstantObjectInspector(value) case Literal(value, DoubleType) => - HiveShim.getDoubleWritableConstantObjectInspector(value) + getDoubleWritableConstantObjectInspector(value) case Literal(value, BooleanType) => - HiveShim.getBooleanWritableConstantObjectInspector(value) + getBooleanWritableConstantObjectInspector(value) case Literal(value, LongType) => - HiveShim.getLongWritableConstantObjectInspector(value) + getLongWritableConstantObjectInspector(value) case Literal(value, FloatType) => - HiveShim.getFloatWritableConstantObjectInspector(value) + getFloatWritableConstantObjectInspector(value) case Literal(value, ShortType) => - HiveShim.getShortWritableConstantObjectInspector(value) + getShortWritableConstantObjectInspector(value) case Literal(value, ByteType) => - HiveShim.getByteWritableConstantObjectInspector(value) + getByteWritableConstantObjectInspector(value) case Literal(value, BinaryType) => - HiveShim.getBinaryWritableConstantObjectInspector(value) + getBinaryWritableConstantObjectInspector(value) case Literal(value, DateType) => - HiveShim.getDateWritableConstantObjectInspector(value) + getDateWritableConstantObjectInspector(value) case Literal(value, TimestampType) => - HiveShim.getTimestampWritableConstantObjectInspector(value) + getTimestampWritableConstantObjectInspector(value) case Literal(value, DecimalType()) => - HiveShim.getDecimalWritableConstantObjectInspector(value) + getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => - HiveShim.getPrimitiveNullWritableConstantObjectInspector + getPrimitiveNullWritableConstantObjectInspector case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) if (value == null) { @@ -658,8 +659,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) - case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) + case w: WritableHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -668,10 +669,136 @@ private[hive] trait HiveInspectors { case _: JavaVoidObjectInspector => NullType } + private def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + private def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) + + private def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.intTypeInfo, getIntWritable(value)) + + private def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) + + private def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) + + private def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.longTypeInfo, getLongWritable(value)) + + private def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) + + private def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) + + private def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) + + private def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) + + private def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) + + private def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) + + private def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) + + private def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.voidTypeInfo, null) + + private def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) + + private def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + private def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + } + + private def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + private def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + private def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + private def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + private def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + private def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + private def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) + + private def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal)) + } + implicit class typeInfoConversions(dt: DataType) { import org.apache.hadoop.hive.serde2.typeinfo._ import TypeInfoFactory._ + private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo( + HiveShim.UNLIMITED_DECIMAL_PRECISION, + HiveShim.UNLIMITED_DECIMAL_SCALE) + } + def toTypeInfo: TypeInfo = dt match { case ArrayType(elemType, _) => getListTypeInfo(elemType.toTypeInfo) @@ -690,7 +817,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case d: DecimalType => HiveShim.decimalTypeInfo(d) + case d: DecimalType => decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ca1f49b546bd7..5a4651a887b7c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} + import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ -import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} @@ -37,7 +39,6 @@ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} -import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -670,8 +671,8 @@ private[hive] case class MetastoreRelation @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) - val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) + val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) + val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. An @@ -697,11 +698,7 @@ private[hive] case class MetastoreRelation } } - val tableDesc = HiveShim.getTableDesc( - Class.forName( - hiveQlTable.getSerializationLib, - true, - Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], + val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to @@ -743,6 +740,11 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($HiveShim.UNLIMITED_DECIMAL_PRECISION,$HiveShim.UNLIMITED_DECIMAL_SCALE)" + } + def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => @@ -759,7 +761,7 @@ private[hive] object HiveMetastoreTypes { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case d: DecimalType => HiveShim.decimalMetastoreString(d) + case d: DecimalType => decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index a5ca3613c5e00..9544d12c9053c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive import java.sql.Date -import scala.collection.mutable.ArrayBuffer - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.ql.{ErrorMsg, Context} @@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ @@ -46,6 +45,7 @@ import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer /** * Used when we need to start parsing the AST before deciding that we are going to pass the command diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 0000000000000..fa5409f602444 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable + +import org.apache.spark.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +/* Implicit conversions */ +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + + /* + * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null && ids.size > 0) { + ColumnProjectionUtils.appendReadColumns(conf, ids) + } + if (names != null && names.size > 0) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. + * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + private var instance: AnyRef = null + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.read(functionInBytes, 0, functionInBytesLength) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + * */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 294fc3bd7d5e9..334bfccc9d200 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,14 +25,13 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.SerializableWritable +import org.apache.spark.{Logging, SerializableWritable} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateUtils @@ -172,7 +171,7 @@ class HadoopTableReader( path.toString + tails } - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); var pathPatternStr = getPathPatternByPath(partNum, partPath) if (!pathPatternSet.contains(pathPatternStr)) { @@ -187,7 +186,7 @@ class HadoopTableReader( val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) val ifc = partDesc.getInputFileFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] @@ -325,7 +324,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] } else { - HiveShim.getConvertedOI( + ObjectInspectorConverters.getConvertedOI( rawDeser.getObjectInspector, tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 8613332186f28..eeb472602be3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import java.util -import scala.collection.JavaConversions._ - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import scala.collection.JavaConversions._ + private[hive] case class InsertIntoHiveTable( table: MetastoreRelation, @@ -126,7 +124,7 @@ case class InsertIntoHiveTable( // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = HiveShim.getExternalTmpPath(hiveContext, tableLocation) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 1658bb93b0b79..01f47352b2313 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ /* Implicit conversions */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 2bb526b14be34..ee440e304ec19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -35,8 +35,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 58e2d1fbfa73e..af586712e3235 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -561,30 +561,28 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } } - if (HiveShim.version == "0.13.1") { - test("scan a parquet table created through a CTAS statement") { - withSQLConf( - "spark.sql.hive.convertMetastoreParquet" -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - - withTempTable("jt") { - (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") - - withTable("test_parquet_ctas") { - sql( - """CREATE TABLE test_parquet_ctas STORED AS PARQUET - |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 - """.stripMargin) - - checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), - Row(3) :: Row(4) :: Nil) - - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") - } + test("scan a parquet table created through a CTAS statement") { + withSQLConf( + "spark.sql.hive.convertMetastoreParquet" -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 00a69de9e4262..e16e530555aee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -79,10 +79,6 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - // TODO: How does it works? needs to add it back for other hive version. - if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes) - } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 440b7c87b0da2..6d8d99ebc8164 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -874,15 +874,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |WITH serdeproperties('s1'='9') """.stripMargin) } - // Now only verify 0.12.0, and ignore other versions due to binary compatibility - // current TestSerDe.jar is from 0.12.0 - if (HiveShim.version == "0.12.0") { - sql(s"ADD JAR $testJar") - sql( - """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' - |WITH serdeproperties('s1'='9') - """.stripMargin) - } sql("DROP TABLE alter1") } @@ -890,15 +881,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // this is a test case from mapjoin_addjar.q val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath - if (HiveShim.version == "0.13.1") { - sql(s"ADD JAR $testJar") - sql( - """CREATE TABLE t1(a string, b string) - |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) - sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") - sql("select * from src join t1 on src.key = t1.a") - sql("DROP TABLE t1") - } + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") } test("ADD FILE command") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index aba3becb1bce2..40a35674e4cb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -330,35 +330,33 @@ class SQLQuerySuite extends QueryTest { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - if (HiveShim.version =="0.13.1") { - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() - - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") - checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) - } + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) } } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala deleted file mode 100644 index dbc5e029e2047..0000000000000 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ /dev/null @@ -1,457 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.rmi.server.UID -import java.util.{Properties, ArrayList => JArrayList} -import java.io.{OutputStream, InputStream} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions -import scala.reflect.ClassTag - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.{io => hadoopIo} - -import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} -import org.apache.spark.util.Utils._ - -/** - * This class provides the UDF creation and also the UDF instance serialization and - * de-serialization cross process boundary. - * - * Detail discussion can be found at https://github.com/apache/spark/pull/3640 - * - * @param functionClassName UDF class name - */ -private[hive] case class HiveFunctionWrapper(var functionClassName: String) - extends java.io.Externalizable { - - // for Serialization - def this() = this(null) - - @transient - def deserializeObjectByKryo[T: ClassTag]( - kryo: Kryo, - in: InputStream, - clazz: Class[_]): T = { - val inp = new Input(in) - val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] - inp.close() - t - } - - @transient - def serializeObjectByKryo( - kryo: Kryo, - plan: Object, - out: OutputStream ) { - val output: Output = new Output(out) - kryo.writeObject(output, plan) - output.close() - } - - def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) - .asInstanceOf[UDFType] - } - - def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) - } - - private var instance: AnyRef = null - - def writeExternal(out: java.io.ObjectOutput) { - // output the function name - out.writeUTF(functionClassName) - - // Write a flag if instance is null or not - out.writeBoolean(instance != null) - if (instance != null) { - // Some of the UDF are serializable, but some others are not - // Hive Utilities can handle both cases - val baos = new java.io.ByteArrayOutputStream() - serializePlan(instance, baos) - val functionInBytes = baos.toByteArray - - // output the function bytes - out.writeInt(functionInBytes.length) - out.write(functionInBytes, 0, functionInBytes.length) - } - } - - def readExternal(in: java.io.ObjectInput) { - // read the function name - functionClassName = in.readUTF() - - if (in.readBoolean()) { - // if the instance is not null - // read the function in bytes - val functionInBytesLength = in.readInt() - val functionInBytes = new Array[Byte](functionInBytesLength) - in.read(functionInBytes, 0, functionInBytesLength) - - // deserialize the function object via Hive Utilities - instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), - getContextOrSparkClassLoader.loadClass(functionClassName)) - } - } - - def createFunction[UDFType <: AnyRef](): UDFType = { - if (instance != null) { - instance.asInstanceOf[UDFType] - } else { - val func = getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - if (!func.isInstanceOf[UDF]) { - // We cache the function if it's no the Simple UDF, - // as we always have to create new instance for Simple UDF - instance = func - } - func - } - } -} - -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[hive] object HiveShim { - val version = "0.13.1" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(inputFormatClass, outputFormatClass, properties) - } - - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.voidTypeInfo, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - } - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - } - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - } - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - } - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - // TODO precise, scale? - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[Object] - - def processResults(results: JArrayList[Object]) = { - results.map { r => - r match { - case s: String => s - case a: Array[Object] => a(0).asInstanceOf[String] - } - } - } - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { - context.runSqlHive("CREATE DATABASE default") - context.runSqlHive("USE default") - } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd, conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - HiveDecimal.create(bd) - } - - /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug - */ - private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { - val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") - val result: StringBuilder = new StringBuilder(old) - var first: Boolean = old.isEmpty - - for (col <- cols) { - if (first) { - first = false - } else { - result.append(',') - } - result.append(col) - } - conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) - } - - /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty - */ - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { - ColumnProjectionUtils.appendReadColumns(conf, ids) - } - if (names != null && names.size > 0) { - appendReadColumnNames(conf, names) - } - } - - def getExternalTmpPath(context: Context, path: Path) = { - context.getExternalTmpPath(path.toUri) - } - - def getDataLocationPath(p: Partition) = p.getDataLocation - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsOf(tbl) - - def compatibilityBlackList = Seq() - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation())) - } - - /* - * Bug introdiced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ - implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) - f.setCompressCodec(w.compressCodec) - f.setCompressType(w.compressType) - f.setTableInfo(w.tableInfo) - f.setDestTableId(w.destTableId) - f - } - - // Precision and scale to pass for unlimited decimals; these are the same as the precision and - // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) - private val UNLIMITED_DECIMAL_PRECISION = 38 - private val UNLIMITED_DECIMAL_SCALE = 18 - - def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { - case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" - case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" - } - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { - case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) - } - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] - DecimalType(info.precision(), info.scale()) - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, - hdoi.precision(), hdoi.scale()) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) - } - } - - def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) - } - - /* - * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that - * is needed to initialize before serialization. - */ - def prepareWritable(w: Writable): Writable = { - w match { - case w: AvroGenericRecordWritable => - w.setRecordReaderID(new UID()) - case _ => - } - w - } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = { - if (crtTbl != null && crtTbl.getNullFormat() != null) { - tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat()) - } - } -} - -/* - * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. - * Fix it through wrapper. - */ -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends Serializable with Logging { - var compressCodec: String = _ - var compressType: String = _ - var destTableId: Int = _ - - def setCompressed(compressed: Boolean) { - this.compressed = compressed - } - - def getDirName = dir - - def setDestTableId(destTableId: Int) { - this.destTableId = destTableId - } - - def setTableInfo(tableInfo: TableDesc) { - this.tableInfo = tableInfo - } - - def setCompressCodec(intermediateCompressorCodec: String) { - compressCodec = intermediateCompressorCodec - } - - def setCompressType(intermediateCompressType: String) { - compressType = intermediateCompressType - } -} From 65938422718383d17f084e577763e2c671726baa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Jun 2015 13:44:47 -0700 Subject: [PATCH 189/254] Fixed style issues for [SPARK-6909][SQL] Remove Hive Shim code. --- .../sql/hive/thriftserver/HiveThriftServer2.scala | 5 +++-- .../hive/thriftserver/SparkSQLSessionManager.scala | 14 ++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 5b391d3dce882..c9da25253e13f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -32,8 +35,6 @@ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext} -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 357b27f7401a3..2d5ee68002286 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) extends SessionManager with ReflectedCompositeService { @@ -50,12 +51,13 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) initCompositeService(hiveConf) } - override def openSession(protocol: TProtocolVersion, - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { + override def openSession( + protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { hiveContext.openSession() val sessionHandle = super.openSession( protocol, username, passwd, sessionConf, withImpersonation, delegationToken) From 2bcdf8c239d2ba79f64fb8878da83d4c2ec28b30 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Jun 2015 13:52:53 -0700 Subject: [PATCH 190/254] [SPARK-7440][SQL] Remove physical Distinct operator in favor of Aggregate This patch replaces Distinct with Aggregate in the optimizer, so Distinct will become more efficient over time as we optimize Aggregate (via Tungsten). Author: Reynold Xin Closes #6637 from rxin/replace-distinct and squashes the following commits: b3cc50e [Reynold Xin] Mima excludes. 93d6117 [Reynold Xin] Code review feedback. 87e4741 [Reynold Xin] [SPARK-7440][SQL] Remove physical Distinct operator in favor of Aggregate. --- project/MimaExcludes.scala | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 14 +++++++ .../plans/logical/basicOperators.scala | 3 ++ .../ReplaceDistinctWithAggregateSuite.scala | 42 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../spark/sql/execution/basicOperators.scala | 31 -------------- 7 files changed, 65 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 34371c9659423..73e4bfd78e577 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -46,7 +46,9 @@ object MimaExcludes { "org.apache.spark.api.java.JavaRDDLike.partitioner"), // Mima false positive (was a private[spark] class) ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PairIterator") + "org.apache.spark.util.collection.PairIterator"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution") ) case v if v.startsWith("1.4") => Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5c6379b8d44b0..0a17b10c521e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,6 +36,8 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Distinct", FixedPoint(100), + ReplaceDistinctWithAggregate) :: Batch("Operator Reordering", FixedPoint(100), UnionPushdown, CombineFilters, @@ -696,3 +698,15 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } } + +/** + * Replaces logical [[Distinct]] operator with an [[Aggregate]] operator. + * {{{ + * SELECT DISTINCT f1, f2 FROM t ==> SELECT f1, f2 FROM t GROUP BY f1, f2 + * }}} + */ +object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Distinct(child) => Aggregate(child.output, child.output, child) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 33a9e55a47dee..e77e5c27b687a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -339,6 +339,9 @@ case class Sample( override def output: Seq[Attribute] = child.output } +/** + * Returns a new logical plan that dedups input rows. + */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala new file mode 100644 index 0000000000000..df29a62ff0e15 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceDistinctWithAggregateSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil + } + + test("replace distinct with aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d1a54ada7b191..4a224153e1a37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1311,7 +1311,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - override def distinct: DataFrame = Distinct(logicalPlan) + override def distinct: DataFrame = dropDuplicates() /** * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d0a1ad00560d3..7a1331a39151a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -284,8 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.Distinct(child) => - execution.Distinct(partial = false, - execution.Distinct(partial = true, planLater(child))) :: Nil + throw new IllegalStateException( + "logical distinct operator should have been replaced by aggregate in the optimizer") case logical.Repartition(numPartitions, shuffle, child) => execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index a30ade86441ca..fb42072f9d5a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -230,37 +230,6 @@ case class ExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } -/** - * :: DeveloperApi :: - * Computes the set of distinct input rows using a HashSet. - * @param partial when true the distinct operation is performed partially, per partition, without - * shuffling the data. - * @param child the input query plan. - */ -@DeveloperApi -case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - - override def requiredChildDistribution: Seq[Distribution] = - if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil - - protected override def doExecute(): RDD[Row] = { - child.execute().mapPartitions { iter => - val hashSet = new scala.collection.mutable.HashSet[Row]() - - var currentRow: Row = null - while (iter.hasNext) { - currentRow = iter.next() - if (!hashSet.contains(currentRow)) { - hashSet.add(currentRow.copy()) - } - } - - hashSet.iterator - } - } -} - /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. From 63bc0c4430680cce230dd7a10d34da0492351446 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Thu, 4 Jun 2015 16:24:50 -0700 Subject: [PATCH 191/254] [SPARK-8098] [WEBUI] Show correct length of bytes on log page The log page should only show desired length of bytes. Currently it shows bytes from the startIndex to the end of the file. The "Next" button on the page is always disabled. Author: Carson Wang Closes #6640 from carsonwang/logpage and squashes the following commits: 58cb3fd [Carson Wang] Show correct length of bytes on log page --- .../main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 53f8f9a46cf8d..5a1d06eb87db9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -159,7 +159,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offset } } - val endIndex = math.min(startIndex + totalLength, totalLength) + val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") val logText = Utils.offsetBytes(files, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") From 74dc2a90bcb05b64c3e7efc02d1451b0cbc2adba Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 4 Jun 2015 17:33:24 -0700 Subject: [PATCH 192/254] [SPARK-8106] [SQL] Set derby.system.durability=test to speed up Hive compatibility tests Derby has a `derby.system.durability` configuration property that can be used to disable I/O synchronization calls for writes. This sacrifices durability but can result in large performance gains, which is appropriate for tests. We should enable this in our test system properties in order to speed up the Hive compatibility tests. I saw 2-3x speedups locally with this change. See https://db.apache.org/derby/docs/10.8/ref/rrefproperdurability.html for more documentation of this property. Author: Josh Rosen Closes #6651 from JoshRosen/hive-compat-suite-speedup and squashes the following commits: b7a08a2 [Josh Rosen] Set derby.system.durability=test in our unit tests. --- pom.xml | 2 ++ project/SparkBuild.scala | 1 + 2 files changed, 3 insertions(+) diff --git a/pom.xml b/pom.xml index abb9b55400340..e28d4b9fc2b17 100644 --- a/pom.xml +++ b/pom.xml @@ -1254,6 +1254,7 @@ ${test.java.home} + test true ${spark.test.home} 1 @@ -1286,6 +1287,7 @@ ${test.java.home} + test true ${spark.test.home} 1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f65031fe25ac2..ef3a175bac209 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -504,6 +504,7 @@ object TestSettings { javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", + javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", From 8f16b94afb39e1641c02d4e0be18d34ef7c211cc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Jun 2015 22:15:58 -0700 Subject: [PATCH 193/254] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ I kept some of the sql import there to avoid changing too many lines. Author: Reynold Xin Closes #6661 from rxin/remove-wildcard-import-sqlcontext and squashes the following commits: c265347 [Reynold Xin] Fixed ListTablesSuite failure. de9d491 [Reynold Xin] Fixed tests. 73b5365 [Reynold Xin] Mima. 8f6b642 [Reynold Xin] Fixed style violation. 443f6e8 [Reynold Xin] [SPARK-8113][SQL] Remove some wildcard import on TestSQLContext._ --- .../sql/catalyst/analysis/Analyzer.scala | 12 +- .../apache/spark/sql/CachedTableSuite.scala | 160 +++++++++--------- .../spark/sql/ColumnExpressionSuite.scala | 15 +- .../spark/sql/DataFrameAggregateSuite.scala | 9 +- .../spark/sql/DataFrameFunctionsSuite.scala | 4 +- .../spark/sql/DataFrameImplicitsSuite.scala | 15 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 9 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 5 +- .../apache/spark/sql/DataFrameStatSuite.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 68 ++++---- .../org/apache/spark/sql/JoinSuite.scala | 65 +++---- .../apache/spark/sql/ListTablesSuite.scala | 35 ++-- .../spark/sql/MathExpressionsSuite.scala | 44 +++-- .../scala/org/apache/spark/sql/RowSuite.scala | 7 +- .../org/apache/spark/sql/SQLConfSuite.scala | 67 ++++---- .../apache/spark/sql/SQLContextSuite.scala | 16 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 125 +++++++------- .../sql/ScalaReflectionRelationSuite.scala | 31 ++-- .../apache/spark/sql/SerializationSuite.scala | 5 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 28 ++- .../spark/sql/UserDefinedTypeSuite.scala | 23 ++- 21 files changed, 373 insertions(+), 378 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bc17169f35a46..5883d938b676d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -235,9 +235,8 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete - * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's - * children. + * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from + * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -455,7 +454,7 @@ class Analyzer( } /** - * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. + * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -846,9 +845,8 @@ class Analyzer( } /** - * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are - * only required to provide scoping information for attributes and can be removed once analysis is - * complete. + * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide + * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0772e5e187425..72e60d9aa75cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,8 +25,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.storage.{RDDBlockId, StorageLevel} case class BigData(s: String) @@ -34,8 +32,12 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan + val executedPlan = ctx.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -45,47 +47,47 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - cacheTable("tempTable") + ctx.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != cacheManager.lookupCachedData(testData)) + assert(None != ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - cacheTable("tempTable1") + ctx.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - uncacheTable("tempTable2") + ctx.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -93,103 +95,103 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) - table("bigData").unpersist(blocking = true) + ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(ctx.table("bigData").count() === 200000L) + ctx.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - table("testData").cache() - assertCached(table("testData")) - table("testData").unpersist(blocking = true) + ctx.table("testData").cache() + assertCached(ctx.table("testData")) + ctx.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - table("testData").cache() - table("testData").count() - table("testData").unpersist(blocking = true) - assertCached(table("testData"), 0) + ctx.table("testData").cache() + ctx.table("testData").count() + ctx.table("testData").unpersist(blocking = true) + assertCached(ctx.table("testData"), 0) } test("isCached") { - cacheTable("testData") + ctx.cacheTable("testData") - assertCached(table("testData")) - assert(table("testData").queryExecution.withCachedData match { + assertCached(ctx.table("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - uncacheTable("testData") - assert(!isCached("testData")) - assert(table("testData").queryExecution.withCachedData match { + ctx.uncacheTable("testData") + assert(!ctx.isCached("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - cacheTable("testData") - assertCached(table("testData")) + ctx.cacheTable("testData") + assertCached(ctx.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - cacheTable("testData") + ctx.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - uncacheTable("testData") + ctx.uncacheTable("testData") } test("read from cached table and uncache") { - cacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData")) + ctx.cacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData")) - uncacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData"), 0) + ctx.uncacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - uncacheTable("testData") + ctx.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - cacheTable("selectStar") + ctx.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - uncacheTable("selectStar") + ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - cacheTable("testData") + ctx.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - uncacheTable("testData") + ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -197,7 +199,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!isCached("testData"), "Table 'testData' should not be cached") + assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -206,14 +208,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -221,14 +223,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -236,7 +238,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -248,7 +250,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - uncacheTable("testData") + ctx.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -256,7 +258,7 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -265,38 +267,38 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - table("t1") - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + ctx.table("t1") + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - cacheTable("t1") + ctx.cacheTable("t1") - assert(isCached("t1")) - assert(isCached("t2")) + assert(ctx.isCached("t1")) + assert(ctx.isCached("t2")) - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) - assert(!isCached("t2")) + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!ctx.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") - clearCache() - assert(cacheManager.isEmpty) + ctx.cacheTable("t1") + ctx.cacheTable("t2") + ctx.clearCache() + assert(ctx.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") sql("Clear CACHE") - assert(cacheManager.isEmpty) + assert(ctx.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { @@ -305,8 +307,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") assert((accsSize + 2) == Accumulators.originals.size) } @@ -317,8 +319,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - uncacheTable("t1") - uncacheTable("t2") + ctx.uncacheTable("t1") + ctx.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bfba379d9a518..4f5484f1368d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,13 +21,14 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -213,7 +214,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -274,7 +275,7 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -287,7 +288,7 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } - val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -413,7 +414,7 @@ class ColumnExpressionSuite extends QueryTest { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -423,7 +424,7 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( df.select(sparkPartitionId()), Row(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 232f05c00918f..790b405c72697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType class DataFrameAggregateSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -67,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false") + ctx.conf.setConf("spark.sql.retainGroupColumns", "false") checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true") + ctx.conf.setConf("spark.sql.retainGroupColumns", "true") } test("agg without groups") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b1e0faa310b68..53c2befb73702 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ /** @@ -27,6 +26,9 @@ import org.apache.spark.sql.types._ */ class DataFrameFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") val row = df.select(array("a", "b")).first() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 2d2367d6e7292..fbb30706a4943 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} -import org.apache.spark.sql.test.TestSQLContext.implicits._ - - class DataFrameImplicitsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("RDD of tuples") { checkAnswer( - sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -37,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - sc.parallelize(1 to 10).toDF("intCol"), + ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - sc.parallelize(1L to 10L).toDF("longCol"), + ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 787f3f175fea2..051d13e9a544f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameJoinSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") @@ -49,7 +49,8 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) } test("join - using aliases after self join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 41b4f02e6a294..495701d4f616c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameNaFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( ("Bob", 16, 176.5), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 438f479459dfe..0d3ff899dad72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends SparkFunSuite { - val sqlCtx = TestSQLContext - def toLetter(i: Int): String = (i + 97).toChar.toString + private val sqlCtx = org.apache.spark.sql.test.TestSQLContext + import sqlCtx.implicits._ + + private def toLetter(i: Int): String = (i + 97).toChar.toString test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8e81dacb8660f..bb8621abe64ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,17 +21,19 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("analysis error should be eagerly reported") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -45,11 +47,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") testData.select('nonExistentName) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("dataframe toString") { @@ -67,12 +69,12 @@ class DataFrameSuite extends QueryTest { } test("invalid plan toString, debug mode") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + val oldSetting = ctx.conf.dataFrameEagerAnalysis + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - TestSQLContext.debug() + ctx.debug() val badPlan = testData.select('badColumn) @@ -81,7 +83,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("access complex data") { @@ -97,8 +99,8 @@ class DataFrameSuite extends QueryTest { } test("empty data frame") { - assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(TestSQLContext.emptyDataFrame.count() === 0) + assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(ctx.emptyDataFrame.count() === 0) } test("head and take") { @@ -311,7 +313,7 @@ class DataFrameSuite extends QueryTest { } test("replace column using withColumn") { - val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -392,7 +394,7 @@ class DataFrameSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -487,21 +489,21 @@ class DataFrameSuite extends QueryTest { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = TestSQLContext.createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.rdd.collect() } test("SPARK-6899") { - val originalValue = TestSQLContext.conf.codegenEnabled - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = ctx.conf.codegenEnabled + ctx.setConf(SQLConf.CODEGEN_ENABLED, "true") try{ checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } finally { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } } @@ -513,14 +515,14 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df = ctx.read.json(ctx.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df2 = ctx.read.json(ctx.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -540,7 +542,7 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7324 dropDuplicates") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -588,49 +590,49 @@ class DataFrameSuite extends QueryTest { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") + val res1 = ctx.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") + val res2 = ctx.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = TestSQLContext.range(1, -2).select("id") + val res3 = ctx.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") + val res4 = ctx.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") + val res5 = ctx.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") + val res6 = ctx.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") + val res7 = ctx.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = TestSQLContext.range(10).select("id") + val res10 = ctx.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = TestSQLContext.range(-1).select("id") + val res11 = ctx.range(-1).select("id") assert(res11.count == 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 407c789657834..ffd26c4f5a7c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.logicalPlanToSparkQuery + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = sql(sqlString) + val df = ctx.sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -61,9 +62,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -94,22 +95,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } } test("broadcasted hash join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), @@ -117,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -126,17 +127,17 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } - sql("UNCACHE TABLE testData") + ctx.sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } @@ -241,7 +242,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -255,7 +256,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -301,7 +302,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -310,7 +311,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -362,7 +363,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -371,7 +372,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -386,7 +387,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -401,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -411,11 +412,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") - val tmp = conf.autoBroadcastJoinThreshold + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -423,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -431,12 +432,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) - sql("UNCACHE TABLE testData") + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + ctx.sql("UNCACHE TABLE testData") } test("left semi join") { - val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 3ce97c3fffdb4..2089660c52bf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,49 +19,47 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} class ListTablesSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ - val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") } after { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - tables().filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -69,19 +67,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(tables(), sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + ctx.sql( + "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - dropTempTable("tables") + ctx.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index dd68965444f5d..0a38af2b4c889 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,36 +17,29 @@ package org.apache.spark.sql -import java.lang.{Double => JavaDouble} - import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ - -private[this] object MathExpressionsTestData { - - case class DoubleData(a: JavaDouble, b: JavaDouble) - val doubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() - - val nnDoubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() - - case class NullDoubles(a: JavaDouble) - val nullDoubles = - TestSQLContext.sparkContext.parallelize( - NullDoubles(1.0) :: - NullDoubles(2.0) :: - NullDoubles(3.0) :: - NullDoubles(null) :: Nil - ).toDF() + + +private object MathExpressionsTestData { + case class DoubleData(a: java.lang.Double, b: java.lang.Double) + case class NullDoubles(a: java.lang.Double) } class MathExpressionsSuite extends QueryTest { import MathExpressionsTestData._ - def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() + + private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF() + + private lazy val nullDoubles = + Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() + + private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( c: Column => Column, f: T => T): Unit = { checkAnswer( @@ -65,7 +58,8 @@ class MathExpressionsSuite extends QueryTest { ) } - def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = + { checkAnswer( nnDoubleData.select(c('a)), (1 to 10).map(n => Row(f(n * 0.1))) @@ -89,7 +83,7 @@ class MathExpressionsSuite extends QueryTest { ) } - def testTwoToOneMathFunction( + private def testTwoToOneMathFunction( c: (Column, Column) => Column, d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 513ac915dcb2a..d84b57af9c882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends SparkFunSuite { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) @@ -56,7 +57,7 @@ class RowSuite extends SparkFunSuite { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 3a5f071e2f7cb..76d0dd1744a41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,67 +17,64 @@ package org.apache.spark.sql -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ class SQLConfSuite extends QueryTest { - val testKey = "test.key.0" - val testVal = "test.val.0" + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + private val testKey = "test.key.0" + private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(TestSQLContext.sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") == "true") + val newContext = new SQLContext(ctx.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - conf.clear() - assert(getAllConfs.size === 0) + ctx.conf.clear() + assert(ctx.getAllConfs.size === 0) - setConf(testKey, testVal) - assert(getConf(testKey) == testVal) - assert(getConf(testKey, testVal + "_") == testVal) - assert(getAllConfs.contains(testKey)) + ctx.setConf(testKey, testVal) + assert(ctx.getConf(testKey) === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(TestSQLContext.getConf(testKey) == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getAllConfs.contains(testKey)) + assert(ctx.getConf(testKey) == testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) - conf.clear() + ctx.conf.clear() } test("parse SQL set commands") { - conf.clear() - sql(s"set $testKey=$testVal") - assert(getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + ctx.conf.clear() + ctx.sql(s"set $testKey=$testVal") + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) - sql("set some.property=20") - assert(getConf("some.property", "0") == "20") - sql("set some.property = 40") - assert(getConf("some.property", "0") == "40") + ctx.sql("set some.property=20") + assert(ctx.getConf("some.property", "0") === "20") + ctx.sql("set some.property = 40") + assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - sql(s"set $key=$vs") - assert(getConf(key, "0") == vs) + ctx.sql(s"set $key=$vs") + assert(ctx.getConf(key, "0") === vs) - sql(s"set $key=") - assert(getConf(key, "0") == "") + ctx.sql(s"set $key=") + assert(ctx.getConf(key, "0") === "") - conf.clear() + ctx.conf.clear() } test("deprecated property") { - conf.clear() - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") + ctx.conf.clear() + ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 797d123b48668..c8d8796568a41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -20,31 +20,29 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - private val testSqlContext = TestSQLContext - private val testSparkContext = TestSQLContext.sparkContext + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(testSqlContext) + SQLContext.setLastInstantiatedContext(ctx) } test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(testSparkContext) + val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(testSparkContext) - assert(SQLContext.getOrCreate(testSparkContext) != null, + val sqlContext = new SQLContext(ctx.sparkContext) + assert(SQLContext.getOrCreate(ctx.sparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 55b68d8e2283c..5babc4332cc77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -36,8 +34,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - val sqlContext = TestSQLContext + val sqlContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ + import sqlContext.sql test("SPARK-6743: no columns from cache") { Seq( @@ -46,7 +45,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { (43, 81, 24) ).toDF("a", "b", "c").registerTempTable("cachedData") - cacheTable("cachedData") + sqlContext.cacheTable("cachedData") checkAnswer( sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) @@ -94,14 +93,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -118,7 +117,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("grouping on nested fields") { - read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.parallelize( + """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -135,8 +135,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6201 IN type conversion") { - read.json( - sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + sqlContext.read.json( + sqlContext.sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -157,12 +158,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("aggregation with codegen") { - val originalValue = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") // Prepare a table that we can group some rows. - table("testData") - .unionAll(table("testData")) - .unionAll(table("testData")) + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -254,8 +255,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { "SELECT sum('a'), avg('a'), count(null) FROM testData", Row(0, null, 0) :: Nil) } finally { - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } } @@ -447,42 +448,42 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("external sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "true") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("SPARK-6927 sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") - setConf(SQLConf.CODEGEN_ENABLED, "true") + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") try{ sortTest() } finally { - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } } test("SPARK-6927 external sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") - setConf(SQLConf.EXTERNAL_SORT, "true") + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") try { sortTest() } finally { - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } } @@ -516,7 +517,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + sql( + "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } @@ -863,7 +865,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SET commands semantics using sql()") { - conf.clear() + sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -895,17 +897,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { sql(s"SET $nonexistentKey"), Row(s"$nonexistentKey=") ) - conf.clear() + sqlContext.conf.clear() } test("SET commands with illegal or inappropriate argument") { - conf.clear() + sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - conf.clear() + sqlContext.conf.clear() } test("apply schema") { @@ -923,7 +925,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -953,7 +955,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = createDataFrame(rowRDD2, schema2) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -978,7 +980,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = createDataFrame(rowRDD3, schema2) + val df3 = sqlContext.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1023,7 +1025,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1038,7 +1040,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3371 Renaming a function expression with group by gives error") { - TestSQLContext.udf.register("len", (s: String) => s.length) + sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1219,9 +1221,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize( + val data = sqlContext.sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - read.json(data).registerTempTable("records") + sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1262,13 +1264,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4322 Grouping field with struct field as sub expression") { - read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - dropTempTable("data") + sqlContext.dropTempTable("data") - read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json( + sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - dropTempTable("data") + sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { @@ -1287,10 +1291,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1299,22 +1303,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { - setConf(SQLConf.CASE_SENSITIVE, "false") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false") val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - setConf(SQLConf.CASE_SENSITIVE, "true") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true") } test("SPARK-6145: ORDER BY test for nested fields") { - read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1326,14 +1331,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6145: special cases") { - read.json(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - read.json(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index d2ede39f0a5f6..ece3d6fdf2af5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( stringField: String, @@ -75,15 +74,15 @@ case class ComplexReflectData( class ScalaReflectionRelationSuite extends SparkFunSuite { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectData") + Seq(data).toDF().registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === + assert(ctx.sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,27 +90,26 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectNullData") + Seq(data).toDF().registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectOptionalData") + Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === + assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. test("query binary data") { - val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.toDF().registerTempTable("reflectBinary") + Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + val result = ctx.sql("SELECT data FROM reflectBinary") + .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -127,10 +125,9 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Map(10 -> 100L, 20 -> 200L), Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectComplexData") - assert(sql("SELECT * FROM reflectComplexData").collect().head === + Seq(data).toDF().registerTempTable("reflectComplexData") + assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 1e8cde606b67b..e55c9e460b791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.test.TestSQLContext class SerializationSuite extends SparkFunSuite { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(TestSQLContext.sparkContext) + val sqlContext = new SQLContext(ctx.sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 1a9ba66416b21..064c040d2b771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,43 +17,41 @@ package org.apache.spark.sql -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ -import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("Simple UDF") { - udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) + ctx.udf.register("strLenScala", (_: String).length) + assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) + ctx.udf.register("random0", () => { Math.random()}) + assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_: Int)) - assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { - udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - sql("SELECT returnStruct('test', 'test2') as ret") + ctx.sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index dc2d43a197f40..45c9f06941c10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import java.io.File - -import org.apache.spark.util.Utils - import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -28,12 +24,11 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -72,11 +67,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD = sparkContext.parallelize(points).toDF() + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val pointsRDD = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -94,10 +91,10 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - sql("SELECT testType(features) from points"), + ctx.sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } From e5054605994b8777e629c02fcbf8a5a6cbd0b0fe Mon Sep 17 00:00:00 2001 From: Ted Blackman Date: Thu, 4 Jun 2015 22:21:11 -0700 Subject: [PATCH 194/254] [SPARK-8116][PYSPARK] Allow sc.range() to take a single argument. Author: Ted Blackman Closes #6656 from belisarius222/branch-1.4 and squashes the following commits: 747cbc2 [Ted Blackman] [SPARK-8116][PYSPARK] Allow sc.range() to take a single argument. (cherry picked from commit f02af7c8f7f43e4cfe3c412d2b5ea4128669ce22) Signed-off-by: Reynold Xin --- python/pyspark/context.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index aeb7ad4f2f83e..44d90f1437bc9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -324,10 +324,12 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None - def range(self, start, end, step=1, numSlices=None): + def range(self, start, end=None, step=1, numSlices=None): """ Create a new RDD of int containing elements from `start` to `end` - (exclusive), increased by `step` every element. + (exclusive), increased by `step` every element. Can be called the same + way as python's built-in range() function. If called with a single argument, + the argument is interpreted as `end`, and `start` is set to 0. :param start: the start value :param end: the end value (exclusive) @@ -335,9 +337,17 @@ def range(self, start, end, step=1, numSlices=None): :param numSlices: the number of partitions of the new RDD :return: An RDD of int + >>> sc.range(5).collect() + [0, 1, 2, 3, 4] + >>> sc.range(2, 4).collect() + [2, 3] >>> sc.range(1, 7, 2).collect() [1, 3, 5] """ + if end is None: + end = start + start = 0 + return self.parallelize(xrange(start, end, step), numSlices) def parallelize(self, c, numSlices=None): From 2777ed3948d26b14e342ba161e145009e31b8829 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 5 Jun 2015 07:45:25 +0200 Subject: [PATCH 195/254] [DOC][Minor]Specify the common sources available for collecting I was wondering what else common sources available until search the source code. Maybe better to make this clear. Author: Yijie Shen Closes #6641 from yijieshen/patch-1 and squashes the following commits: b5b99b4 [Yijie Shen] Make it clear that JvmSource is the only available additional source currently f23140c [Yijie Shen] [DOC][Minor]Specify the common sources available for collecting --- conf/metrics.properties.template | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7de0011a48ca8..7f17bc7eea4f5 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -4,7 +4,7 @@ # divided into instances which correspond to internal components. # Each instance can be configured to report its metrics to one or more sinks. # Accepted values for [instance] are "master", "worker", "executor", "driver", -# and "applications". A wild card "*" can be used as an instance name, in +# and "applications". A wildcard "*" can be used as an instance name, in # which case all instances will inherit the supplied property. # # Within an instance, a "source" specifies a particular set of grouped metrics. @@ -32,7 +32,7 @@ # name (see examples below). # 2. Some sinks involve a polling period. The minimum allowed polling period # is 1 second. -# 3. Wild card properties can be overridden by more specific properties. +# 3. Wildcard properties can be overridden by more specific properties. # For example, master.sink.console.period takes precedence over # *.sink.console.period. # 4. A metrics specific configuration @@ -47,6 +47,13 @@ # instance master and applications. MetricsServlet may not be configured by self. # +## List of available common sources and their properties. + +# org.apache.spark.metrics.source.JvmSource +# Note: Currently, JvmSource is the only available common source +# to add additionaly to an instance, to enable this, +# set the "class" option to its fully qulified class name (see examples below) + ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink From 3a5c4da473a8a497004dfe6eacc0e6646651b227 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 5 Jun 2015 00:32:46 -0700 Subject: [PATCH 196/254] [MINOR] remove unused interpolation var in log message Completely trivial but I noticed this wrinkle in a log message today; `$sender` doesn't refer to anything and isn't interpolated here. Author: Sean Owen Closes #6650 from srowen/Interpolation and squashes the following commits: 518687a [Sean Owen] Actually interpolate log string 7edb866 [Sean Owen] Trivial: remove unused interpolation var in log message --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index fcad959540f5a..7c7f70d8a193b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -103,7 +103,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case None => // Ignoring the update since we don't know about the executor. logWarning(s"Ignored task status update ($taskId state $state) " + - "from unknown executor $sender with ID $executorId") + s"from unknown executor with ID $executorId") } } From da20c8ca37663738112b04657057858ee3e55072 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 Jun 2015 10:32:33 +0200 Subject: [PATCH 197/254] [MINOR] [BUILD] Change link to jenkins builds on github. Link to the tail of the console log, instead of the full log. That's bound to have the info the user is looking for, and at the same time loads way more quickly than the (huge) full log, which is just one click away if needed. Author: Marcelo Vanzin Closes #6664 from vanzin/jenkins-link and squashes the following commits: ba07ed8 [Marcelo Vanzin] [minor] [build] Change link to jenkins builds on github. --- dev/run-tests-jenkins | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3cbd8666c8d68..641b0ff3c4be4 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -193,7 +193,7 @@ done test_result="$?" if [ "$test_result" -eq "124" ]; then - fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \ + fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}console)** \ for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ after a configured wait of \`${TESTS_TIMEOUT}\`." @@ -233,7 +233,7 @@ done # post end message { result_message="\ - [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \ + [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}console) for \ PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" From b16b5434ff44c42e4b3a337f9af147669ba44896 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 Jun 2015 14:11:38 +0200 Subject: [PATCH 198/254] [MINOR] [BUILD] Use custom temp directory during build. Even with all the efforts to cleanup the temp directories created by unit tests, Spark leaves a lot of garbage in /tmp after a test run. This change overrides java.io.tmpdir to place those files under the build directory instead. After an sbt full unit test run, I was left with > 400 MB of temp files. Since they're now under the build dir, it's much easier to clean them up. Also make a slight change to a unit test to make it not pollute the source directory with test data. Author: Marcelo Vanzin Closes #6653 from vanzin/unit-test-tmp and squashes the following commits: 31e2dd5 [Marcelo Vanzin] Fix tests that depend on each other. aa92944 [Marcelo Vanzin] [minor] [build] Use custom temp directory during build. --- .../spark/deploy/SparkSubmitUtilsSuite.scala | 22 ++++++++++--------- pom.xml | 4 +++- project/SparkBuild.scala | 1 + 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8fda5c8b472c9..07d261cc428c4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -28,9 +28,12 @@ import org.apache.ivy.plugins.resolver.IBiblioResolver import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.Utils class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { + private var tempIvyPath: String = _ + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } @@ -47,6 +50,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { super.beforeAll() // We don't want to write logs during testing SparkSubmitUtils.printStream = new BufferPrintStream + tempIvyPath = Utils.createTempDir(namePrefix = "ivy").getAbsolutePath() } test("incorrect maven coordinate throws error") { @@ -90,21 +94,20 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") - var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(tempIvyPath)) for (i <- 0 until 3) { - val index = jPaths.indexOf(ivyPath) + val index = jPaths.indexOf(tempIvyPath) assert(index >= 0) - jPaths = jPaths.substring(index + ivyPath.length) + jPaths = jPaths.substring(index + tempIvyPath.length) } val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") IvyTestUtils.withRepository(main, None, None) { repo => // end to end val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + Option(tempIvyPath), true) + assert(jarPath.indexOf(tempIvyPath) >= 0, "should use non-default ivy path") } } @@ -123,13 +126,12 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mylib") >= 0, "should find artifact") } // Local ivy repository with modified home - val dummyIvyPath = "dummy" + File.separator + "ivy" - val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) + val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator) IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal), true) { repo => val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(dummyIvyPath), true) + Some(tempIvyPath), true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") } } diff --git a/pom.xml b/pom.xml index e28d4b9fc2b17..a848deffe7375 100644 --- a/pom.xml +++ b/pom.xml @@ -179,7 +179,7 @@ compile ${session.executionRootDirectory} @@ -1256,6 +1256,7 @@ test true + ${project.build.directory}/tmp ${spark.test.home} 1 false @@ -1289,6 +1290,7 @@ test true + ${project.build.directory}/tmp ${spark.test.home} 1 false diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ef3a175bac209..921f1599fedef 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -496,6 +496,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), + javaOptions in Test += s"-Djava.io.tmpdir=$sparkHome/target/tmp", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", From 019dc9f558cf7c0b708d3b1f0882b0c19134ffb6 Mon Sep 17 00:00:00 2001 From: Akhil Das Date: Fri, 5 Jun 2015 14:23:23 +0200 Subject: [PATCH 199/254] [STREAMING] Update streaming-kafka-integration.md Fixed the broken links (Examples) in the documentation. Author: Akhil Das Closes #6666 from akhld/patch-2 and squashes the following commits: 2228b83 [Akhil Das] Update streaming-kafka-integration.md --- docs/streaming-kafka-integration.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 64714f0b799fc..d6d5605948a5a 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; @@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
    @@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application. streamingContext, [map of Kafka parameters], [set of topics to consume]) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; @@ -116,7 +116,7 @@ Next, we discuss how to use this approach in your streaming application. [map of Kafka parameters], [set of topics to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
    @@ -153,4 +153,4 @@ Next, we discuss how to use this approach in your streaming application. Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. -3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file +3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. From 700312e12f9588f01a592d6eac7bff7eb366ac8f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 Jun 2015 14:32:00 +0200 Subject: [PATCH 200/254] [SPARK-6324] [CORE] Centralize handling of script usage messages. Reorganize code so that the launcher library handles most of the work of printing usage messages, instead of having an awkward protocol between the library and the scripts for that. This mostly applies to SparkSubmit, since the launcher lib does not do command line parsing for classes invoked in other ways, and thus cannot handle failures for those. Most scripts end up going through SparkSubmit, though, so it all works. The change adds a new, internal command line switch, "--usage-error", which prints the usage message and exits with a non-zero status. Scripts can override the command printed in the usage message by setting an environment variable - this avoids having to grep the output of SparkSubmit to remove references to the "spark-submit" script. The only sub-optimal part of the change is the special handling for the spark-sql usage, which is now done in SparkSubmitArguments. Author: Marcelo Vanzin Closes #5841 from vanzin/SPARK-6324 and squashes the following commits: 2821481 [Marcelo Vanzin] Merge branch 'master' into SPARK-6324 bf139b5 [Marcelo Vanzin] Filter output of Spark SQL CLI help. c6609bf [Marcelo Vanzin] Fix exit code never being used when printing usage messages. 6bc1b41 [Marcelo Vanzin] [SPARK-6324] [core] Centralize handling of script usage messages. --- bin/pyspark | 16 +--- bin/pyspark2.cmd | 1 + bin/spark-class | 13 +-- bin/spark-shell | 15 +--- bin/spark-shell2.cmd | 21 +---- bin/spark-sql | 39 +-------- bin/spark-submit | 12 --- bin/spark-submit2.cmd | 13 +-- bin/sparkR | 18 +--- .../org/apache/spark/deploy/SparkSubmit.scala | 10 +-- .../spark/deploy/SparkSubmitArguments.scala | 76 ++++++++++++++++- .../spark/deploy/SparkSubmitSuite.scala | 2 +- .../java/org/apache/spark/launcher/Main.java | 83 ++++++++++--------- .../launcher/SparkSubmitCommandBuilder.java | 18 +++- .../launcher/SparkSubmitOptionParser.java | 2 + 15 files changed, 147 insertions(+), 192 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 7cb19c51b43a2..f9dbddfa53560 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,24 +17,10 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/pyspark [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi +export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 09b4149c2a439..45e9e3def5121 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -21,6 +21,7 @@ rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. call %SPARK_HOME%\bin\load-spark-env.cmd +set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/spark-class b/bin/spark-class index c49d97ce5cf25..7bb1afe4b44f5 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,18 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -set -e # Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" . "$SPARK_HOME"/bin/load-spark-env.sh -if [ -z "$1" ]; then - echo "Usage: spark-class []" 1>&2 - exit 1 -fi - # Find the java binary if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" @@ -98,9 +92,4 @@ CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") - -if [ "${CMD[0]}" = "usage" ]; then - "${CMD[@]}" -else - exec "${CMD[@]}" -fi +exec "${CMD[@]}" diff --git a/bin/spark-shell b/bin/spark-shell index b3761b5e1375b..a6dc863d83fc6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -29,20 +29,7 @@ esac set -o posix export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -usage() { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-shell [options]" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi +export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, # so we need to add the "-Dscala.usejavacp=true" flag manually. We diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 00fd30fa38d36..251309d67f860 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -18,12 +18,7 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. - -echo "%*" | findstr " \<--help\> \<-h\>" >nul -if %ERRORLEVEL% equ 0 ( - call :usage - exit /b 0 -) +set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] rem SPARK-4161: scala does not assume use of the java classpath, rem so we need to add the "-Dscala.usejavacp=true" flag manually. We @@ -37,16 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -echo "Usage: .\bin\spark-shell.cmd [options]" >&2 -call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2 -goto :eof +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql index ca1729f4cfcb4..4ea7bc6e39c07 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,41 +17,6 @@ # limitations under the License. # -# -# Shell script for starting the Spark SQL CLI - -# Enter posix mode for bash -set -o posix - -# NOTE: This exact class name is matched downstream by SparkSubmit. -# Any changes need to be reflected there. -export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" - -# Figure out where Spark is installed export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -function usage { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-sql [options] [cli option]" - pattern="usage" - pattern+="\|Spark assembly has been built with Hive" - pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" - pattern+="\|Spark Command: " - pattern+="\|--help" - pattern+="\|=======" - - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - echo - echo "CLI options:" - "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi - -exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" +exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 0e0afe71a0f05..255378b0f077c 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,16 +22,4 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 -# Only define a usage function if an upstream script hasn't done so. -if ! type -t usage >/dev/null 2>&1; then - usage() { - if [ -n "$1" ]; then - echo "$1" - fi - "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help - exit "$2" - } - export -f usage -fi - exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index d3fc4a5cc3f6e..651376e526928 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,15 +24,4 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -call %~dp0spark-class2.cmd %CLASS% %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help -goto :eof +%~dp0spark-class2.cmd %CLASS% %* diff --git a/bin/sparkR b/bin/sparkR index 8c918e2b09aef..464c29f369424 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,23 +17,7 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" - source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/sparkR [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi - +export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 8cf4d58847d8e..3aa3f948e865d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -82,13 +82,13 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(1) + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() + exitFn(1) } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to @@ -99,7 +99,7 @@ object SparkSubmit { /_/ """.format(SPARK_VERSION)) printStream.println("Type --help for more information.") - exitFn() + exitFn(0) } def main(args: Array[String]): Unit = { @@ -160,7 +160,7 @@ object SparkSubmit { // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - exitFn() + exitFn(1) } else { throw e } @@ -700,7 +700,7 @@ object SparkSubmit { /** * Return whether the given main class represents a sql shell. */ - private def isSqlShell(mainClass: String): Boolean = { + private[deploy] def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index cc6a7bd9f4119..b7429a901e162 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,12 +17,15 @@ package org.apache.spark.deploy +import java.io.{ByteArrayOutputStream, PrintStream} +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.io.Source import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser @@ -412,6 +415,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case VERSION => SparkSubmit.printVersionAndExit() + case USAGE_ERROR => + printUsageAndExit(1) + case _ => throw new IllegalArgumentException(s"Unexpected argument '$opt'.") } @@ -449,11 +455,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) } - outStream.println( + val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] - |Usage: spark-submit --status [submission ID] --master [spark://...] - | + |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) + outStream.println(command) + + outStream.println( + """ |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -525,6 +534,65 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | delegation tokens periodically. """.stripMargin ) - SparkSubmit.exitFn() + + if (SparkSubmit.isSqlShell(mainClass)) { + outStream.println("CLI options:") + outStream.println(getSqlShellOptions()) + } + + SparkSubmit.exitFn(exitCode) } + + /** + * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter + * the results to remove unwanted lines. + * + * Since the CLI will call `System.exit()`, we install a security manager to prevent that call + * from working, and restore the original one afterwards. + */ + private def getSqlShellOptions(): String = { + val currentOut = System.out + val currentErr = System.err + val currentSm = System.getSecurityManager() + try { + val out = new ByteArrayOutputStream() + val stream = new PrintStream(out) + System.setOut(stream) + System.setErr(stream) + + val sm = new SecurityManager() { + override def checkExit(status: Int): Unit = { + throw new SecurityException() + } + + override def checkPermission(perm: java.security.Permission): Unit = {} + } + System.setSecurityManager(sm) + + try { + Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + .invoke(null, Array(HELP)) + } catch { + case e: InvocationTargetException => + // Ignore SecurityException, since we throw it above. + if (!e.getCause().isInstanceOf[SecurityException]) { + throw e + } + } + + stream.flush() + + // Get the output and discard any unnecessary lines from it. + Source.fromString(new String(out.toByteArray())).getLines + .filter { line => + !line.startsWith("log4j") && !line.startsWith("usage") + } + .mkString("\n") + } finally { + System.setSecurityManager(currentSm) + System.setOut(currentOut) + System.setErr(currentErr) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 46369457f000a..46ea28d0f18f6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -62,7 +62,7 @@ class SparkSubmitSuite SparkSubmit.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = () => exitedCleanly = true + SparkSubmit.exitFn = (_) => exitedCleanly = true val thread = new Thread { override def run() = try { diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 929b29a49ed70..62492f9baf3bb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -53,21 +53,33 @@ public static void main(String[] argsArray) throws Exception { List args = new ArrayList(Arrays.asList(argsArray)); String className = args.remove(0); - boolean printLaunchCommand; - boolean printUsage; + boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); AbstractCommandBuilder builder; - try { - if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + try { builder = new SparkSubmitCommandBuilder(args); - } else { - builder = new SparkClassCommandBuilder(className, args); + } catch (IllegalArgumentException e) { + printLaunchCommand = false; + System.err.println("Error: " + e.getMessage()); + System.err.println(); + + MainClassOptionParser parser = new MainClassOptionParser(); + try { + parser.parse(args); + } catch (Exception ignored) { + // Ignore parsing exceptions. + } + + List help = new ArrayList(); + if (parser.className != null) { + help.add(parser.CLASS); + help.add(parser.className); + } + help.add(parser.USAGE_ERROR); + builder = new SparkSubmitCommandBuilder(help); } - printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - printUsage = false; - } catch (IllegalArgumentException e) { - builder = new UsageCommandBuilder(e.getMessage()); - printLaunchCommand = false; - printUsage = true; + } else { + builder = new SparkClassCommandBuilder(className, args); } Map env = new HashMap(); @@ -78,13 +90,7 @@ public static void main(String[] argsArray) throws Exception { } if (isWindows()) { - // When printing the usage message, we can't use "cmd /v" since that prevents the env - // variable from being seen in the caller script. So do not call prepareWindowsCommand(). - if (printUsage) { - System.out.println(join(" ", cmd)); - } else { - System.out.println(prepareWindowsCommand(cmd, env)); - } + System.out.println(prepareWindowsCommand(cmd, env)); } else { // In bash, use NULL as the arg separator since it cannot be used in an argument. List bashCmd = prepareBashCommand(cmd, env); @@ -135,33 +141,30 @@ private static List prepareBashCommand(List cmd, Map buildCommand(Map env) { - if (isWindows()) { - return Arrays.asList("set", "SPARK_LAUNCHER_USAGE_ERROR=" + message); - } else { - return Arrays.asList("usage", message, "1"); - } + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 7d387d406edae..3e5a2820b6c11 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -77,6 +77,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } private final List sparkArgs; + private final boolean printHelp; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed @@ -87,10 +88,11 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); + this.printHelp = false; } SparkSubmitCommandBuilder(List args) { - this(); + this.sparkArgs = new ArrayList(); List submitArgs = args; if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { this.allowsMixedArguments = true; @@ -104,14 +106,16 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = false; } - new OptionParser().parse(submitArgs); + OptionParser parser = new OptionParser(); + parser.parse(submitArgs); + this.printHelp = parser.helpRequested; } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -311,6 +315,8 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { + boolean helpRequested = false; + @Override protected boolean handle(String opt, String value) { if (opt.equals(MASTER)) { @@ -341,6 +347,9 @@ protected boolean handle(String opt, String value) { allowsMixedArguments = true; appResource = specialClasses.get(value); } + } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { + helpRequested = true; + sparkArgs.add(opt); } else { sparkArgs.add(opt); if (value != null) { @@ -360,6 +369,7 @@ protected boolean handleUnknown(String opt) { appArgs.add(opt); return true; } else { + checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); sparkArgs.add(opt); return false; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 229000087688f..b88bba883ac65 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -61,6 +61,7 @@ class SparkSubmitOptionParser { // Options that do not take arguments. protected final String HELP = "--help"; protected final String SUPERVISE = "--supervise"; + protected final String USAGE_ERROR = "--usage-error"; protected final String VERBOSE = "--verbose"; protected final String VERSION = "--version"; @@ -120,6 +121,7 @@ class SparkSubmitOptionParser { final String[][] switches = { { HELP, "-h" }, { SUPERVISE }, + { USAGE_ERROR }, { VERBOSE, "-v" }, { VERSION }, }; From bc0d76a246cc534234b96a661d70feb94b26538c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 5 Jun 2015 23:06:19 +0800 Subject: [PATCH 201/254] [SQL] Simplifies binary node pattern matching This PR is a simpler version of #2764, and adds `unapply` methods to the following binary nodes for simpler pattern matching: - `BinaryExpression` - `BinaryComparison` - `BinaryArithmetics` This enables nested pattern matching for binary nodes. For example, the following pattern matching ```scala case p: BinaryComparison if p.left.dataType == StringType && p.right.dataType == DateType => p.makeCopy(Array(p.left, Cast(p.right, StringType))) ``` can be simplified to ```scala case p BinaryComparison(l StringType(), r DateType()) => p.makeCopy(Array(l, Cast(r, StringType))) ``` Author: Cheng Lian Closes #6537 from liancheng/binary-node-patmat and squashes the following commits: a3bf5fe [Cheng Lian] Fixes compilation error introduced while rebasing b738986 [Cheng Lian] Renames `l`/`r` to `left`/`right` or `lhs`/`rhs` 14900ae [Cheng Lian] Simplifies binary node pattern matching --- .../catalyst/analysis/HiveTypeCoercion.scala | 215 ++++++++---------- .../sql/catalyst/expressions/Expression.scala | 4 + .../sql/catalyst/expressions/arithmetic.scala | 4 + .../sql/catalyst/expressions/predicates.scala | 5 +- .../sql/catalyst/optimizer/Optimizer.scala | 19 +- 5 files changed, 119 insertions(+), 128 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index b064600e94fac..9b8a08a88dcb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -130,7 +130,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - private val stringNaN = Literal("NaN") + private val StringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -138,20 +138,20 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => - b.makeCopy(Array(b.right, Literal(Double.NaN))) - case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) + case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + b.makeCopy(Array(Literal(Double.NaN), right)) + case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => - b.makeCopy(Array(b.right, Literal(Float.NaN))) - case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) + case b @ BinaryExpression(StringNaN, right @ FloatType()) => + b.makeCopy(Array(Literal(Float.NaN), right)) + case b @ BinaryExpression(left @ FloatType(), StringNaN) => + b.makeCopy(Array(left, Literal(Float.NaN))) + + /* Use float NaN by default to avoid unnecessary type widening */ + case b @ BinaryExpression(left @ StringNaN, StringNaN) => + b.makeCopy(Array(left, Literal(Float.NaN))) } } } @@ -184,21 +184,25 @@ trait HiveTypeCoercion { case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. - case (l, r) if l.dataType == StringType && r.dataType != StringType => - (l, Alias(Cast(r, StringType), r.name)()) - case (l, r) if l.dataType != StringType && r.dataType == StringType => - (Alias(Cast(l, StringType), l.name)(), r) - - case (l, r) if l.dataType != r.dataType => - logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType => + case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => + (lhs, Alias(Cast(rhs, StringType), rhs.name)()) + case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => + (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + + case (lhs, rhs) if lhs.dataType != rhs.dataType => + logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") + findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => val newLeft = - if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() + if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() val newRight = - if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)() + if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() (newLeft, newRight) - }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged. + }.getOrElse { + // If there is no applicable conversion, leave expression unchanged. + (lhs, rhs) + } + case other => other } @@ -227,12 +231,10 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType => - val newLeft = - if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) - val newRight = - if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) b.makeCopy(Array(newLeft, newRight)) }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. } @@ -247,57 +249,42 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a: BinaryArithmetic if a.left.dataType == StringType => - a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) - case a: BinaryArithmetic if a.right.dataType == StringType => - a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + case a @ BinaryArithmetic(left @ StringType(), r) => + a.makeCopy(Array(Cast(left, DoubleType), r)) + case a @ BinaryArithmetic(left, right @ StringType()) => + a.makeCopy(Array(left, Cast(right, DoubleType))) // we should cast all timestamp/date/string compare into string compare - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType == DateType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) - case p: BinaryComparison if p.left.dataType == DateType && - p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) - case p: BinaryComparison if p.left.dataType == TimestampType && - p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) - case p: BinaryComparison if p.left.dataType == TimestampType && - p.right.dataType == DateType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - case p: BinaryComparison if p.left.dataType == DateType && - p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType != StringType => - p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) - case p: BinaryComparison if p.left.dataType != StringType && - p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - - case i @ In(a, b) if a.dataType == DateType && - b.forall(_.dataType == StringType) => + case p @ BinaryComparison(left @ StringType(), right @ DateType()) => + p.makeCopy(Array(left, Cast(right, StringType))) + case p @ BinaryComparison(left @ DateType(), right @ StringType()) => + p.makeCopy(Array(Cast(left, StringType), right)) + case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, TimestampType), right)) + case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(left, Cast(right, TimestampType))) + case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => + p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + + case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType => + p.makeCopy(Array(Cast(left, DoubleType), right)) + case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType => + p.makeCopy(Array(left, Cast(right, DoubleType))) + + case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a, b) if a.dataType == TimestampType && - b.forall(_.dataType == StringType) => + case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) - case i @ In(a, b) if a.dataType == DateType && - b.forall(_.dataType == TimestampType) => + case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case i @ In(a, b) if a.dataType == TimestampType && - b.forall(_.dataType == DateType) => + case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case Sum(e) if e.dataType == StringType => - Sum(Cast(e, DoubleType)) - case Average(e) if e.dataType == StringType => - Average(Cast(e, DoubleType)) - case Sqrt(e) if e.dataType == StringType => - Sqrt(Cast(e, DoubleType)) + case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType)) } } @@ -379,22 +366,22 @@ trait HiveTypeCoercion { // fix decimal precision for union case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { - case (l, r) if l.dataType != r.dataType => - (l.dataType, r.dataType) match { + case (lhs, rhs) if lhs.dataType != rhs.dataType => + (lhs.dataType, rhs.dataType) match { case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) - case _ => (l, r) + (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) + case _ => (lhs, rhs) } case other => other } @@ -467,16 +454,16 @@ trait HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + b.makeCopy(Array(left, Cast(right, intTypeToFixed(t)))) case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + b.makeCopy(Array(left, Cast(right, DoubleType))) case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + b.makeCopy(Array(Cast(left, DoubleType), right)) case _ => b } @@ -525,31 +512,31 @@ trait HiveTypeCoercion { // all other cases are considered as false. // We may simplify the expression if one side is literal numeric values - case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => l - case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => Not(l) - case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) - if trueValues.contains(value) => r - case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) - if falseValues.contains(value) => Not(r) - case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => And(IsNotNull(l), l) - case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => And(IsNotNull(l), Not(l)) - case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) - if trueValues.contains(value) => And(IsNotNull(r), r) - case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) - if falseValues.contains(value) => And(IsNotNull(r), Not(r)) - - case EqualTo(l @ BooleanType(), r @ NumericType()) => - transform(l , r) - case EqualTo(l @ NumericType(), r @ BooleanType()) => - transform(r, l) - case EqualNullSafe(l @ BooleanType(), r @ NumericType()) => - transformNullSafe(l, r) - case EqualNullSafe(l @ NumericType(), r @ BooleanType()) => - transformNullSafe(r, l) + case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => left + case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(left) + case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) + if trueValues.contains(value) => right + case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) + if falseValues.contains(value) => Not(right) + case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(left), left) + case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(left), Not(left)) + case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(right), right) + case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(right), Not(right)) + + case EqualTo(left @ BooleanType(), right @ NumericType()) => + transform(left , right) + case EqualTo(left @ NumericType(), right @ BooleanType()) => + transform(right, left) + case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => + transformNullSafe(left, right) + case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => + transformNullSafe(right, left) } } @@ -630,7 +617,7 @@ trait HiveTypeCoercion { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) + case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3cf851aec15ea..b2b9d1a5e1581 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -118,6 +118,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" } +private[sql] object BinaryExpression { + def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) +} + abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2ac53f8f6613f..a3770f998d94d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -118,6 +118,10 @@ abstract class BinaryArithmetic extends BinaryExpression { sys.error(s"BinaryArithmetics must override either eval or evalInternal") } +private[sql] object BinaryArithmetic { + def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) +} + case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 807021d50e8e0..58273b166fe91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -202,9 +202,8 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { sys.error(s"BinaryComparisons must override either eval or evalInternal") } -object BinaryComparison { - def unapply(b: BinaryComparison): Option[(Expression, Expression)] = - Some((b.left, b.right)) +private[sql] object BinaryComparison { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0a17b10c521e5..c16f08d389955 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -266,7 +266,7 @@ object NullPropagation extends Rule[LogicalPlan] { if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { - newChildren(0) + newChildren.head } else { Coalesce(newChildren) } @@ -280,21 +280,18 @@ object NullPropagation extends Rule[LogicalPlan] { case e: MinOf => e // Put exceptional cases above if any - case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } + case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) + + case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e: StringRegexExpression => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } + case e: StringComparison => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) From 12f5eaeee1235850a076ce5716d069bd2f1205a5 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 5 Jun 2015 10:19:03 -0700 Subject: [PATCH 202/254] [SPARK-8085] [SPARKR] Support user-specified schema in read.df cc davies sun-rui Author: Shivaram Venkataraman Closes #6620 from shivaram/sparkr-read-schema and squashes the following commits: 16a6726 [Shivaram Venkataraman] Fix loadDF to pass schema Also add a unit test a229877 [Shivaram Venkataraman] Use wrapper function to DataFrameReader ee70ba8 [Shivaram Venkataraman] Support user-specified schema in read.df --- R/pkg/R/SQLContext.R | 14 ++++++++++---- R/pkg/inst/tests/test_sparkSQL.R | 13 +++++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 15 +++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 88e1a508f37c4..22a4b5bf86ebd 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -452,7 +452,7 @@ dropTempTable <- function(sqlContext, tableName) { #' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -read.df <- function(sqlContext, path = NULL, source = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path @@ -462,15 +462,21 @@ read.df <- function(sqlContext, path = NULL, source = NULL, ...) { source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - sdf <- callJMethod(sqlContext, "load", source, options) + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + schema$jobj, options) + } else { + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + } dataFrame(sdf) } #' @aliases loadDF #' @export -loadDF <- function(sqlContext, path = NULL, source = NULL, ...) { - read.df(sqlContext, path, source, ...) +loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { + read.df(sqlContext, path, source, schema, ...) } #' Create an external table diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d2d82e791e876..30edfc8a7bd94 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -504,6 +504,19 @@ test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) + + # Check if we can apply a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_true(inherits(df1, "DataFrame")) + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Run the same with loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_true(inherits(df2, "DataFrame")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) test_that("write.df() as parquet file", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 604f3124e23ae..43b62f0e822f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -139,4 +139,19 @@ private[r] object SQLUtils { case "ignore" => SaveMode.Ignore } } + + def loadDF( + sqlContext: SQLContext, + source: String, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).options(options).load() + } + + def loadDF( + sqlContext: SQLContext, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).schema(schema).options(options).load() + } } From 4036d05ceeec77ebfa9c683cbc699250df3e3895 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 5 Jun 2015 10:53:32 -0700 Subject: [PATCH 203/254] Revert "[MINOR] [BUILD] Use custom temp directory during build." This reverts commit b16b5434ff44c42e4b3a337f9af147669ba44896. --- .../spark/deploy/SparkSubmitUtilsSuite.scala | 22 +++++++++---------- pom.xml | 4 +--- project/SparkBuild.scala | 1 - 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 07d261cc428c4..8fda5c8b472c9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -28,12 +28,9 @@ import org.apache.ivy.plugins.resolver.IBiblioResolver import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate -import org.apache.spark.util.Utils class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { - private var tempIvyPath: String = _ - private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } @@ -50,7 +47,6 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { super.beforeAll() // We don't want to write logs during testing SparkSubmitUtils.printStream = new BufferPrintStream - tempIvyPath = Utils.createTempDir(namePrefix = "ivy").getAbsolutePath() } test("incorrect maven coordinate throws error") { @@ -94,20 +90,21 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { + val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") - var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(tempIvyPath)) + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) for (i <- 0 until 3) { - val index = jPaths.indexOf(tempIvyPath) + val index = jPaths.indexOf(ivyPath) assert(index >= 0) - jPaths = jPaths.substring(index + tempIvyPath.length) + jPaths = jPaths.substring(index + ivyPath.length) } val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") IvyTestUtils.withRepository(main, None, None) { repo => // end to end val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(tempIvyPath), true) - assert(jarPath.indexOf(tempIvyPath) >= 0, "should use non-default ivy path") + Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") } } @@ -126,12 +123,13 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mylib") >= 0, "should find artifact") } // Local ivy repository with modified home - val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator) + val dummyIvyPath = "dummy" + File.separator + "ivy" + val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal), true) { repo => val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(tempIvyPath), true) + Some(dummyIvyPath), true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") } } diff --git a/pom.xml b/pom.xml index a848deffe7375..e28d4b9fc2b17 100644 --- a/pom.xml +++ b/pom.xml @@ -179,7 +179,7 @@ compile ${session.executionRootDirectory} @@ -1256,7 +1256,6 @@ test true - ${project.build.directory}/tmp ${spark.test.home} 1 false @@ -1290,7 +1289,6 @@ test true - ${project.build.directory}/tmp ${spark.test.home} 1 false diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 921f1599fedef..ef3a175bac209 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -496,7 +496,6 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), - javaOptions in Test += s"-Djava.io.tmpdir=$sparkHome/target/tmp", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", From 0992a0a77d38081c6c206bb34333013125d85376 Mon Sep 17 00:00:00 2001 From: Xutingjun Date: Fri, 5 Jun 2015 11:41:39 -0700 Subject: [PATCH 204/254] [SPARK-8099] set executor cores into system in yarn-cluster mode Author: Xutingjun Author: xutingjun Closes #6643 from XuTingjun/SPARK-8099 and squashes the following commits: 80b18cd [Xutingjun] change to STANDALONE | YARN ce33148 [Xutingjun] set executor cores into system e51cc9e [Xutingjun] set executor cores into system 0600861 [xutingjun] set executor cores into system --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3aa3f948e865d..a0eae774268ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -425,7 +425,6 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), - OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -446,7 +445,7 @@ object SparkSubmit { OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options - OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), From 3f80bc841ab155925fb0530eef5927990f4a5793 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 5 Jun 2015 12:28:37 -0700 Subject: [PATCH 205/254] [SPARK-7699] [CORE] Lazy start the scheduler for dynamic allocation This patch propose to lazy start the scheduler for dynamic allocation to avoid fast ramp down executor numbers is load is less. This implementation will: 1. immediately start the scheduler is `numExecutorsTarget` is 0, this is the expected behavior. 2. if `numExecutorsTarget` is not zero, start the scheduler until the number is satisfied, if the load is less, this initial started executors will last for at least 60 seconds, user will have a window to submit a job, no need to revamp the executors. 3. if `numExecutorsTarget` is not satisfied until the timeout, this means resource is not enough, the scheduler will start until this timeout, will not wait infinitely. Please help to review, thanks a lot. Author: jerryshao Closes #6430 from jerryshao/SPARK-7699 and squashes the following commits: 02cac8e [jerryshao] Address the comments 7242450 [jerryshao] Remove the useless import ecc0b00 [jerryshao] Address the comments 6f75f00 [jerryshao] Style changes 8b8decc [jerryshao] change the test name fb822ca [jerryshao] Change the solution according to comments 1cc74e5 [jerryshao] Lazy start the scheduler for dynamic allocation --- .../spark/ExecutorAllocationManager.scala | 17 +++- .../ExecutorAllocationManagerSuite.scala | 90 +++++++++++++++---- 2 files changed, 89 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index f7323a4d9db72..9939103bb0903 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -150,6 +150,13 @@ private[spark] class ExecutorAllocationManager( // Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem. val executorAllocationManagerSource = new ExecutorAllocationManagerSource + // Whether we are still waiting for the initial set of executors to be allocated. + // While this is true, we will not cancel outstanding executor requests. This is + // set to false when: + // (1) a stage is submitted, or + // (2) an executor idle timeout has elapsed. + @volatile private var initializing: Boolean = true + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -240,6 +247,7 @@ private[spark] class ExecutorAllocationManager( removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime if (expired) { + initializing = false removeExecutor(executorId) } !expired @@ -261,7 +269,11 @@ private[spark] class ExecutorAllocationManager( private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { val maxNeeded = maxNumExecutorsNeeded - if (maxNeeded < numExecutorsTarget) { + if (initializing) { + // Do not change our target while we are still initializing, + // Otherwise the first job may have to ramp up unnecessarily + 0 + } else if (maxNeeded < numExecutorsTarget) { // The target number exceeds the number we actually need, so stop adding new // executors and inform the cluster manager to cancel the extra pending requests val oldNumExecutorsTarget = numExecutorsTarget @@ -271,7 +283,7 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { client.requestTotalExecutors(numExecutorsTarget) - logInfo(s"Lowering target number of executors to $numExecutorsTarget (previously " + + logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") } numExecutorsTarget - oldNumExecutorsTarget @@ -481,6 +493,7 @@ private[spark] class ExecutorAllocationManager( private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + initializing = false val stageId = stageSubmitted.stageInfo.stageId val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 1c2b681f0b843..803e1831bb269 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -90,7 +90,7 @@ class ExecutorAllocationManagerSuite } test("add executors") { - sc = createSparkContext(1, 10) + sc = createSparkContext(1, 10, 1) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) @@ -135,7 +135,7 @@ class ExecutorAllocationManagerSuite } test("add executors capped by num pending tasks") { - sc = createSparkContext(0, 10) + sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) @@ -186,7 +186,7 @@ class ExecutorAllocationManagerSuite } test("cancel pending executors when no longer needed") { - sc = createSparkContext(0, 10) + sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) @@ -213,7 +213,7 @@ class ExecutorAllocationManagerSuite } test("remove executors") { - sc = createSparkContext(5, 10) + sc = createSparkContext(5, 10, 5) val manager = sc.executorAllocationManager.get (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) } @@ -263,7 +263,7 @@ class ExecutorAllocationManagerSuite } test ("interleaving add and remove") { - sc = createSparkContext(5, 10) + sc = createSparkContext(5, 10, 5) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) @@ -331,7 +331,7 @@ class ExecutorAllocationManagerSuite } test("starting/canceling add timer") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val clock = new ManualClock(8888L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -363,7 +363,7 @@ class ExecutorAllocationManagerSuite } test("starting/canceling remove timers") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val clock = new ManualClock(14444L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -410,7 +410,7 @@ class ExecutorAllocationManagerSuite } test("mock polling loop with no events") { - sc = createSparkContext(0, 20) + sc = createSparkContext(0, 20, 0) val manager = sc.executorAllocationManager.get val clock = new ManualClock(2020L) manager.setClock(clock) @@ -436,7 +436,7 @@ class ExecutorAllocationManagerSuite } test("mock polling loop add behavior") { - sc = createSparkContext(0, 20) + sc = createSparkContext(0, 20, 0) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -486,7 +486,7 @@ class ExecutorAllocationManagerSuite } test("mock polling loop remove behavior") { - sc = createSparkContext(1, 20) + sc = createSparkContext(1, 20, 1) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -547,7 +547,7 @@ class ExecutorAllocationManagerSuite } test("listeners trigger add executors correctly") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(addTime(manager) === NOT_SET) @@ -577,7 +577,7 @@ class ExecutorAllocationManagerSuite } test("listeners trigger remove executors correctly") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(removeTimes(manager).isEmpty) @@ -608,7 +608,7 @@ class ExecutorAllocationManagerSuite } test("listeners trigger add and remove executor callbacks correctly") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) @@ -641,7 +641,7 @@ class ExecutorAllocationManagerSuite } test("SPARK-4951: call onTaskStart before onBlockManagerAdded") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) @@ -677,7 +677,7 @@ class ExecutorAllocationManagerSuite } test("avoid ramp up when target < running executors") { - sc = createSparkContext(0, 100000) + sc = createSparkContext(0, 100000, 0) val manager = sc.executorAllocationManager.get val stage1 = createStageInfo(0, 1000) sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) @@ -701,13 +701,67 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 16) } - private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { + test("avoid ramp down initial executors until first job is submitted") { + sc = createSparkContext(2, 5, 3) + val manager = sc.executorAllocationManager.get + val clock = new ManualClock(10000L) + manager.setClock(clock) + + // Verify the initial number of executors + assert(numExecutorsTarget(manager) === 3) + schedule(manager) + // Verify whether the initial number of executors is kept with no pending tasks + assert(numExecutorsTarget(manager) === 3) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + clock.advance(100L) + + assert(maxNumExecutorsNeeded(manager) === 2) + schedule(manager) + + // Verify that current number of executors should be ramp down when first job is submitted + assert(numExecutorsTarget(manager) === 2) + } + + test("avoid ramp down initial executors until idle executor is timeout") { + sc = createSparkContext(2, 5, 3) + val manager = sc.executorAllocationManager.get + val clock = new ManualClock(10000L) + manager.setClock(clock) + + // Verify the initial number of executors + assert(numExecutorsTarget(manager) === 3) + schedule(manager) + // Verify the initial number of executors is kept when no pending tasks + assert(numExecutorsTarget(manager) === 3) + (0 until 3).foreach { i => + onExecutorAdded(manager, s"executor-$i") + } + + clock.advance(executorIdleTimeout * 1000) + + assert(maxNumExecutorsNeeded(manager) === 0) + schedule(manager) + // Verify executor is timeout but numExecutorsTarget is not recalculated + assert(numExecutorsTarget(manager) === 3) + + // Schedule again to recalculate the numExecutorsTarget after executor is timeout + schedule(manager) + // Verify that current number of executors should be ramp down when executor is timeout + assert(numExecutorsTarget(manager) === 2) + } + + private def createSparkContext( + minExecutors: Int = 1, + maxExecutors: Int = 5, + initialExecutors: Int = 1): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test-executor-allocation-manager") .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) .set("spark.dynamicAllocation.schedulerBacklogTimeout", s"${schedulerBacklogTimeout.toString}s") .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", @@ -791,6 +845,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _schedule() } + private def maxNumExecutorsNeeded(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _maxNumExecutorsNeeded() + } + private def addExecutors(manager: ExecutorAllocationManager): Int = { val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded() manager invokePrivate _addExecutors(maxNumExecutorsNeeded) From 4f16d3fe2e260a716b5b4e4005cb6229386440ed Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 5 Jun 2015 12:46:02 -0700 Subject: [PATCH 206/254] [SPARK-8112] [STREAMING] Fix the negative event count issue Author: zsxwing Closes #6659 from zsxwing/SPARK-8112 and squashes the following commits: a5d7da6 [zsxwing] Address comments d255b6e [zsxwing] Fix the negative event count issue --- .../apache/spark/streaming/dstream/ReceiverInputDStream.scala | 2 +- .../spark/streaming/receiver/ReceiverSupervisorImpl.scala | 4 ++-- .../apache/spark/streaming/scheduler/InputInfoTracker.scala | 4 +++- .../apache/spark/streaming/scheduler/ReceivedBlockInfo.scala | 4 +++- .../apache/spark/streaming/ReceivedBlockTrackerSuite.scala | 2 +- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index e4ff05e12f201..e76e7eb0dea19 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -70,7 +70,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker - val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum) + val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) if (blockInfos.nonEmpty) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 92938379b9c17..8be732b64e3a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -138,8 +138,8 @@ private[streaming] class ReceiverSupervisorImpl( ) { val blockId = blockIdOption.getOrElse(nextBlockId) val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size - case _ => -1 + case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong) + case _ => None } val time = System.currentTimeMillis diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index a72efccf2f994..7c0db8a863c67 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -23,7 +23,9 @@ import org.apache.spark.Logging import org.apache.spark.streaming.{Time, StreamingContext} /** To track the information of input stream at specified batch time. */ -private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) +private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { + require(numRecords >= 0, "numRecords must not be negative") +} /** * This class manages all the input streams as well as their input data statistics. The information diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala index dc11e84f29965..656ac80df8979 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala @@ -24,11 +24,13 @@ import org.apache.spark.streaming.util.WriteAheadLogRecordHandle /** Information about blocks received by the receiver */ private[streaming] case class ReceivedBlockInfo( streamId: Int, - numRecords: Long, + numRecords: Option[Long], metadataOption: Option[Any], blockStoreResult: ReceivedBlockStoreResult ) { + require(numRecords.isEmpty || numRecords.get >= 0, "numRecords must not be negative") + @volatile private var _isBlockIdValid = true def blockId: StreamBlockId = blockStoreResult.blockId diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 6f0ee774cb5cf..be305b5e0dfea 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -224,7 +224,7 @@ class ReceivedBlockTrackerSuite /** Generate blocks infos using random ids */ def generateBlockInfos(): Seq[ReceivedBlockInfo] = { - List.fill(5)(ReceivedBlockInfo(streamId, 0, None, + List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) } From 4060526cd3b7e9ba345ce94f6e081cc1156e53ab Mon Sep 17 00:00:00 2001 From: Luca Martinetti Date: Fri, 5 Jun 2015 13:40:11 -0700 Subject: [PATCH 207/254] [SPARK-7747] [SQL] [DOCS] spark.sql.planner.externalSort Add documentation for spark.sql.planner.externalSort Author: Luca Martinetti Closes #6272 from lucamartinetti/docs-externalsort and squashes the following commits: 985661b [Luca Martinetti] [SPARK-7747] [SQL] [DOCS] Add documentation for spark.sql.planner.externalSort --- docs/sql-programming-guide.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 282ea75e1e785..cde5830c733e0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1785,6 +1785,13 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. + + spark.sql.planner.externalSort + false + + When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. + + # Distributed SQL Engine From 356a4a9b93a1eeedb910c6bccc0abadf59e4877f Mon Sep 17 00:00:00 2001 From: amey Date: Fri, 5 Jun 2015 13:49:33 -0700 Subject: [PATCH 208/254] [SPARK-7991] [PySpark] Adding support for passing lists to describe. This is a minor change. Author: amey Closes #6655 from ameyc/JIRA-7991/support-passing-list-to-describe and squashes the following commits: e8a1dff [amey] Adding support for passing lists to describe. --- python/pyspark/sql/dataframe.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 03b01a1136e45..902504df5b11b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -616,7 +616,19 @@ def describe(self, *cols): | min| 2| | max| 5| +-------+---+ + >>> df.describe(['age', 'name']).show() + +-------+---+-----+ + |summary|age| name| + +-------+---+-----+ + | count| 2| 2| + | mean|3.5| null| + | stddev|1.5| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+---+-----+ """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) From 6ebe419f335fcfb66dd3da74baf35eb5b2fc061d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 5 Jun 2015 13:57:21 -0700 Subject: [PATCH 209/254] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ cont'd. Fixed the following packages: sql.columnar sql.jdbc sql.json sql.parquet Author: Reynold Xin Closes #6667 from rxin/testsqlcontext_wildcard and squashes the following commits: 134a776 [Reynold Xin] Fixed compilation break. 6da7b69 [Reynold Xin] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ cont'd. --- .../columnar/InMemoryColumnarQuerySuite.scala | 40 ++++---- .../columnar/PartitionBatchPruningSuite.scala | 28 +++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 45 +++++---- .../spark/sql/jdbc/JDBCWriteSuite.scala | 75 +++++++-------- .../org/apache/spark/sql/json/JsonSuite.scala | 95 +++++++++---------- .../apache/spark/sql/json/TestJsonData.scala | 62 ++++++------ .../sql/parquet/ParquetFilterSuite.scala | 7 +- .../spark/sql/parquet/ParquetIOSuite.scala | 40 ++++---- .../ParquetPartitionDiscoverySuite.scala | 27 +++--- .../spark/sql/parquet/ParquetQuerySuite.scala | 24 ++--- .../sql/parquet/ParquetSchemaSuite.scala | 3 +- .../spark/sql/parquet/ParquetTest.scala | 6 +- .../apache/spark/sql/test/SQLTestUtils.scala | 14 ++- .../spark/sql/hive/orc/OrcQuerySuite.scala | 5 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 8 +- 15 files changed, 234 insertions(+), 245 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 055453e688e73..fa3b8144c086e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -21,8 +21,6 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -31,8 +29,12 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Make sure the tables are loaded. TestData + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.{logicalPlanToSparkQuery, sql} + test("simple columnar query") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -40,16 +42,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - cacheTable("sizeTst") + ctx.cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - conf.autoBroadcastJoinThreshold) + ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + ctx.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -58,7 +60,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -70,7 +72,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("repeatedData") + ctx.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -82,7 +84,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("nullableRepeatedData") + ctx.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -94,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq.map(Row.fromTuple)) - cacheTable("timestamps") + ctx.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -106,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - cacheTable("withEmptyParts") + ctx.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -155,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Create a RDD for the schema val rdd = - sparkContext.parallelize((1 to 100), 10).map { i => + ctx.sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -175,18 +177,18 @@ class InMemoryColumnarQuerySuite extends QueryTest { (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - isCached("InMemoryCache_different_data_types"), + ctx.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - table("InMemoryCache_different_data_types").collect()) - dropTempTable("InMemoryCache_different_data_types") + ctx.table("InMemoryCache_different_data_types").collect()) + ctx.dropTempTable("InMemoryCache_different_data_types") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index cda1b0992e36f..6545c6b314a4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,40 +21,42 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - val originalColumnBatchSize = conf.columnBatchSize - val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10") - val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") // Enable in-memory table scan accumulators - setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { - setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } before { - cacheTable("pruningData") + ctx.cacheTable("pruningData") } after { - uncacheTable("pruningData") + ctx.uncacheTable("pruningData") } // Comparisons @@ -108,7 +110,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = sql(query) + val df = ctx.sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index e20c66cb2f1d7..7931854db27c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,13 +21,11 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test._ -import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter -import TestSQLContext._ -import TestSQLContext.implicits._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" @@ -37,12 +35,16 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) val testH2Dialect = new JdbcDialect { - def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) } + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test @@ -253,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.read.jdbc( + assert(ctx.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(TestSQLContext.read.jdbc( + assert(ctx.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -328,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types") { - val rows = TestSQLContext.read.jdbc( + val rows = ctx.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -338,9 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types in cache") { - val rows = - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -348,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test types for null value") { - val rows = TestSQLContext.read.jdbc( + val rows = ctx.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -395,10 +396,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) - assert(df.schema.filter( - _.dataType != org.apache.spark.sql.types.StringType - ).isEmpty) + val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) assert(rows(0).get(1).isInstanceOf[String]) @@ -419,7 +418,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Aggregated dialects") { val agg = new AggregatedDialect(List(new JdbcDialect { - def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { @@ -430,8 +429,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) - assert(agg.getCatalystType(0, "", 1, null) == Some(LongType)) - assert(agg.getCatalystType(1, "", 1, null) == Some(StringType)) + assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 2de8c1a6098e0..d949ef42267ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,7 +24,6 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} -import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { @@ -37,6 +36,10 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) @@ -54,14 +57,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - TestSQLContext.sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - TestSQLContext.sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -74,66 +77,64 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { conn1.close() } - val sc = TestSQLContext.sparkContext + private lazy val sc = ctx.sparkContext - val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - val arr1x2 = Array[Row](Row.apply("fred", 3)) - val schema2 = StructType( + private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) + private lazy val schema2 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: Nil) - val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) - val schema3 = StructType( + private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + private lazy val schema3 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 == - TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 == - TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -142,15 +143,15 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index f8d62f9e7e02b..d889c7be17ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,21 +23,19 @@ import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.util.Utils -class JsonSuite extends QueryTest { - import org.apache.spark.sql.json.TestJsonData._ +class JsonSuite extends QueryTest with TestJsonData { - TestJsonData + protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.sql + import ctx.implicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -214,7 +212,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonDF = read.json(jsonNullStruct) + val jsonDF = ctx.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -233,7 +231,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonDF = read.json(primitiveFieldAndType) + val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -261,7 +259,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonDF = read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -360,7 +358,7 @@ class JsonSuite extends QueryTest { } test("GetField operation on complex data type") { - val jsonDF = read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -376,7 +374,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonDF = read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -450,7 +448,7 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -503,7 +501,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonDF = read.json(complexFieldValueTypeConflict) + val jsonDF = ctx.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -527,7 +525,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in array elements") { - val jsonDF = read.json(arrayElementTypeConflict) + val jsonDF = ctx.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -555,7 +553,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonDF = read.json(missingFields) + val jsonDF = ctx.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -574,8 +572,9 @@ class JsonSuite extends QueryTest { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = read.option("samplingRatio", "0.49").json(path) + ctx.sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -590,7 +589,7 @@ class JsonSuite extends QueryTest { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) @@ -602,7 +601,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = read.json(path) + val jsonDF = ctx.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -671,7 +670,7 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = read.schema(schema).json(path) + val jsonDF1 = ctx.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -688,7 +687,7 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonDF2 = read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -709,7 +708,7 @@ class JsonSuite extends QueryTest { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -737,7 +736,7 @@ class JsonSuite extends QueryTest { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -763,7 +762,7 @@ class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -781,7 +780,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonDF = read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -804,7 +803,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = read.json(jsonArray) + val jsonDF = ctx.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -822,10 +821,10 @@ class JsonSuite extends QueryTest { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = read.json(corruptRecords) + val jsonDF = ctx.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -875,11 +874,11 @@ class JsonSuite extends QueryTest { Row("]") :: Nil ) - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } test("SPARK-4068: nulls in arrays") { - val jsonDF = read.json(nullsInArrays) + val jsonDF = ctx.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -925,7 +924,7 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = createDataFrame(rowRDD1, schema1) + val df1 = ctx.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -948,7 +947,7 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = createDataFrame(rowRDD2, schema2) + val df3 = ctx.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -956,8 +955,8 @@ class JsonSuite extends QueryTest { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = read.json(primitiveFieldAndType) - val primTable = read.json(jsonDF.toJSON) + val jsonDF = ctx.read.json(primitiveFieldAndType) + val primTable = ctx.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -969,8 +968,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonDF = read.json(complexFieldAndType1) - val compTable = read.json(complexJsonDF.toJSON) + val complexJsonDF = ctx.read.json(complexFieldAndType1) + val compTable = ctx.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1074,29 +1073,29 @@ class JsonSuite extends QueryTest { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try{ for (useStreaming <- List("true", "false")) { - setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) val temp = Utils.createTempDir().getPath - val df = read.schema(schemaWithSimpleMap).json(mapType1) + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) df.write.mode("overwrite").parquet(temp) // order of MapType is not defined - assert(read.parquet(temp).count() == 5) + assert(ctx.read.parquet(temp).count() == 5) - val df2 = read.json(corruptRecords) + val df2 = ctx.read.json(corruptRecords) df2.write.mode("overwrite").parquet(temp) - checkAnswer(read.parquet(temp), df2.collect()) + checkAnswer(ctx.read.parquet(temp), df2.collect()) } } finally { - setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 47a97a49daabb..b6a6a8dc6a63c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext -object TestJsonData { +trait TestJsonData { - val primitiveFieldAndType = - TestSQLContext.sparkContext.parallelize( + protected def ctx: SQLContext + + def primitiveFieldAndType: RDD[String] = + ctx.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -32,8 +35,8 @@ object TestJsonData { "null":null }""" :: Nil) - val primitiveFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def primitiveFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -43,15 +46,15 @@ object TestJsonData { """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) - val jsonNullStruct = - TestSQLContext.sparkContext.parallelize( + def jsonNullStruct: RDD[String] = + ctx.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) - val complexFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def complexFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -61,23 +64,23 @@ object TestJsonData { """{"num_struct":{}, "str_array":["str1", "str2", 33], "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) - val arrayElementTypeConflict = - TestSQLContext.sparkContext.parallelize( + def arrayElementTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) - val missingFields = - TestSQLContext.sparkContext.parallelize( + def missingFields: RDD[String] = + ctx.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) - val complexFieldAndType1 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType1: RDD[String] = + ctx.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -92,8 +95,8 @@ object TestJsonData { "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) - val complexFieldAndType2 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType2: RDD[String] = + ctx.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -146,16 +149,16 @@ object TestJsonData { ]] }""" :: Nil) - val mapType1 = - TestSQLContext.sparkContext.parallelize( + def mapType1: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: """{"map": {"c": 1, "d": 4}}""" :: """{"map": {"e": null}}""" :: Nil) - val mapType2 = - TestSQLContext.sparkContext.parallelize( + def mapType2: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -163,22 +166,22 @@ object TestJsonData { """{"map": {"e": null}}""" :: """{"map": {"f": {"field1": null}}}""" :: Nil) - val nullsInArrays = - TestSQLContext.sparkContext.parallelize( + def nullsInArrays: RDD[String] = + ctx.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) - val jsonArray = - TestSQLContext.sparkContext.parallelize( + def jsonArray: RDD[String] = + ctx.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) - val corruptRecords = - TestSQLContext.sparkContext.parallelize( + def corruptRecords: RDD[String] = + ctx.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -186,6 +189,5 @@ object TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) - val empty = - TestSQLContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4aa5bcb7fdbca..17f5f9a491e6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} @@ -42,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * data type is nullable. */ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext private def checkFilterPredicate( df: DataFrame, @@ -312,7 +311,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -341,7 +340,7 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7f7c2cc1a6c26..2b6a27032e637 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -36,9 +36,6 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} @@ -66,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS * A test suite that tests basic Parquet I/O. */ class ParquetIOSuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - import sqlContext.implicits.localSeqToDataFrameHolder + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. @@ -104,7 +100,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext + sqlContext.sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) .toDF() @@ -115,7 +111,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) - checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -123,7 +119,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).collect() + sqlContext.read.parquet(dir.getCanonicalPath).collect() } } @@ -131,14 +127,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).collect() + sqlContext.read.parquet(dir.getCanonicalPath).collect() } } } test("date type") { def makeDateRDD(): DataFrame = - sparkContext + sqlContext.sparkContext .parallelize(0 to 1000) .map(i => Tuple1(DateUtils.toJavaDate(i))) .toDF() @@ -147,7 +143,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDateRDD() data.write.parquet(dir.getCanonicalPath) - checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -236,7 +232,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { withParquetFile(data) { path => - assertResult(conf.parquetCompressionCodec.toUpperCase) { + assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { compressionCodecFor(path) } } @@ -244,7 +240,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) + checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) @@ -283,7 +279,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(read.parquet(path.toString), (0 until 10).map { i => + checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i => Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } @@ -312,7 +308,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) - checkAnswer(read.parquet(file), newData.map(Row.fromTuple)) + checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple)) } } @@ -321,7 +317,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) - checkAnswer(read.parquet(file), data.map(Row.fromTuple)) + checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple)) } } @@ -341,7 +337,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) - checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple)) + checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple)) } } @@ -369,11 +365,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sparkContext.hadoopConfiguration, + sqlContext.sparkContext.hadoopConfiguration, path, new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) - assertResult(read.parquet(path.toString).schema) { + assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( StructField("a", BooleanType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: @@ -406,7 +402,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -430,7 +426,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA } class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 3b29979452ad9..8979a0a210a42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.parquet import java.io.File @@ -28,7 +29,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.sources.PartitioningUtils._ import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext} @@ -39,10 +39,10 @@ case class ParquetData(intField: Int, stringField: String) case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - override val sqlContext: SQLContext = TestSQLContext - import sqlContext._ + override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ + import sqlContext.sql val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" @@ -190,8 +190,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { // Introduce _temporary dir to the base dir the robustness of the schema discovery process. new File(base.getCanonicalPath, "_temporary").mkdir() - println("load the partitioned table") - read.parquet(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -238,7 +237,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.parquet(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -286,7 +285,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -326,7 +325,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -358,7 +357,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -371,7 +370,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("SPARK-7749 Non-partitioned table should have empty partition spec") { withTempPath { dir => (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) - val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution + val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation(relation: ParquetRelation2) => assert(relation.partitionSpec === PartitionSpec.emptySpec) @@ -385,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempPath { dir => val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) - checkAnswer(read.parquet(dir.getCanonicalPath), df.collect()) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) } } @@ -425,12 +424,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema) + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - checkAnswer(read.load(dir.toString).select(fields: _*), row) + checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) } } @@ -446,7 +445,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df) + checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 304936fb2be8e..de0107a361815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -22,14 +22,14 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ /** * A test suite that tests various Parquet queries. */ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + import sqlContext.sql test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -40,22 +40,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), data.map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -118,7 +118,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) val df2 = sqlContext.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) @@ -127,7 +127,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -139,7 +139,7 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd } class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 8b1745124b8e1..171a656f0e01e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -24,11 +24,10 @@ import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 516ba373f41d2..eb15a1609f1d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -33,8 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ private[sql] trait ParquetTest extends SQLTestUtils { - import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} - import sqlContext.sparkContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -44,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath) + sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -75,7 +73,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 17a8b0cca09df..ac4a00a6f3dac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -25,11 +25,9 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { - val sqlContext: SQLContext + def sqlContext: SQLContext - import sqlContext.{conf, sparkContext} - - protected def configuration = sparkContext.hadoopConfiguration + protected def configuration = sqlContext.sparkContext.hadoopConfiguration /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -39,12 +37,12 @@ trait SQLTestUtils { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) - (keys, values).zipped.foreach(conf.setConf) + val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConf) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConf(key, value) - case (key, None) => conf.unsetConf(key) + case (key, Some(value)) => sqlContext.conf.setConf(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 57c23fe77f8b5..b384fb39f3d66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -52,9 +52,6 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { - override val sqlContext = TestHive - - import TestHive.read def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) @@ -69,7 +66,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + sqlContext.read.format("orc").load(file), data.toDF().collect()) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 750f0b04aaa87..5daf691aa8c53 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,13 +22,11 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql._ private[sql] trait OrcTest extends SQLTestUtils { - protected def hiveContext = sqlContext.asInstanceOf[HiveContext] + lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive import sqlContext.sparkContext import sqlContext.implicits._ @@ -53,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) + withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path))) } /** @@ -65,7 +63,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => - hiveContext.registerDataFrameAsTable(df, tableName) + sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } From eb19d3f75cbd002f7e72ce02017a8de67f562792 Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Fri, 5 Jun 2015 17:41:12 -0700 Subject: [PATCH 210/254] [SPARK-6964] [SQL] Support Cancellation in the Thrift Server Support runInBackground in SparkExecuteStatementOperation, and add cancellation Author: Dong Wang Closes #6207 from dongwang218/SPARK-6964-jdbc-cancel and squashes the following commits: 687c113 [Dong Wang] fix 100 characters 7bfa2a7 [Dong Wang] fix merge 380480f [Dong Wang] fix for liancheng's comments eb3e385 [Dong Wang] small nit 341885b [Dong Wang] small fix 3d8ebf8 [Dong Wang] add spark.sql.hive.thriftServer.async flag 04142c3 [Dong Wang] set SQLSession for async execution 184ec35 [Dong Wang] keep hive conf 819ae03 [Dong Wang] [SPARK-6964][SQL][WIP] Support Cancellation in the Thrift Server --- .../org/apache/spark/sql/SQLContext.scala | 5 + .../SparkExecuteStatementOperation.scala | 164 ++++++++++++++++-- .../server/SparkSQLOperationManager.scala | 7 +- .../HiveThriftServer2Suites.scala | 42 ++++- .../apache/spark/sql/hive/HiveContext.scala | 6 + 5 files changed, 208 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0aab7fa8709b8..ddb54025baa24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -916,6 +916,11 @@ class SQLContext(@transient val sparkContext: SparkContext) tlSession.remove() } + protected[sql] def setSession(session: SQLSession): Unit = { + detachSession() + tlSession.set(session) + } + protected[sql] class SQLSession { // Note that this is a lazy val so we can override the default value in subclasses. protected[sql] lazy val conf: SQLConf = new SQLConf diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index c0d1266212cdd..e071103df925c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -17,11 +17,23 @@ package org.apache.spark.sql.hive.thriftserver +import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} +import java.util.concurrent.RejectedExecutionException import java.util.{Map => JMap, UUID} +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hive.service.cli._ +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession @@ -31,8 +43,6 @@ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, @@ -40,17 +50,19 @@ private[hive] class SparkExecuteStatementOperation( confOverlay: JMap[String, String], runInBackground: Boolean = true) (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) - // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution - extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) + extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground) with Logging { private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ + private var statementId: String = _ def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") + hiveContext.sparkContext.clearJobGroup() + logDebug(s"CLOSING $statementId") + cleanup(OperationState.CLOSED) } def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { @@ -114,10 +126,10 @@ private[hive] class SparkExecuteStatementOperation( } def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { + if (result == null || result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") val schema = result.queryExecution.analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } @@ -125,9 +137,73 @@ private[hive] class SparkExecuteStatementOperation( } } - def run(): Unit = { - val statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement'") + override def run(): Unit = { + setState(OperationState.PENDING) + setHasResultSet(true) // avoid no resultset for async run + + if (!runInBackground) { + runInternal() + } else { + val parentSessionState = SessionState.get() + val hiveConf = getConfigForOperation() + val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + val sessionHive = getCurrentHive() + val currentSqlSession = hiveContext.currentSession + + // Runnable impl to call runInternal asynchronously, + // from a different thread + val backgroundOperation = new Runnable() { + + override def run(): Unit = { + val doAsAction = new PrivilegedExceptionAction[Object]() { + override def run(): Object = { + + // User information is part of the metastore client member in Hive + hiveContext.setSession(currentSqlSession) + Hive.set(sessionHive) + SessionState.setCurrentSessionState(parentSessionState) + try { + runInternal() + } catch { + case e: HiveSQLException => + setOperationException(e) + log.error("Error running hive query: ", e) + } + return null + } + } + + try { + ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) + } catch { + case e: Exception => + setOperationException(new HiveSQLException(e)) + logError("Error running hive query as user : " + + sparkServiceUGI.getShortUserName(), e) + } + } + } + try { + // This submit blocks if no background threads are available to run this operation + val backgroundHandle = + getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) + setBackgroundHandle(backgroundHandle) + } catch { + case rejected: RejectedExecutionException => + setState(OperationState.ERROR) + throw new HiveSQLException("The background threadpool cannot accept" + + " new task for execution, please retry the operation", rejected) + case NonFatal(e) => + logError(s"Error executing query in background", e) + setState(OperationState.ERROR) + throw e + } + } + } + + private def runInternal(): Unit = { + statementId = UUID.randomUUID().toString + logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) HiveThriftServer2.listener.onStatementStart( statementId, @@ -159,18 +235,82 @@ private[hive] class SparkExecuteStatementOperation( } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) } catch { + case e: HiveSQLException => + if (getStatus().getState() == OperationState.CANCELED) { + return + } else { + setState(OperationState.ERROR); + throw e + } // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => + val currentState = getStatus().getState() + logError(s"Error executing query, currentState $currentState, ", e) setState(OperationState.ERROR) HiveThriftServer2.listener.onStatementError( statementId, e.getMessage, e.getStackTraceString) - logError("Error executing query:", e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) HiveThriftServer2.listener.onStatementFinish(statementId) } + + override def cancel(): Unit = { + logInfo(s"Cancel '$statement' with $statementId") + if (statementId != null) { + hiveContext.sparkContext.cancelJobGroup(statementId) + } + cleanup(OperationState.CANCELED) + } + + private def cleanup(state: OperationState) { + setState(state) + if (runInBackground) { + val backgroundHandle = getBackgroundHandle() + if (backgroundHandle != null) { + backgroundHandle.cancel(true) + } + } + } + + /** + * If there are query specific settings to overlay, then create a copy of config + * There are two cases we need to clone the session config that's being passed to hive driver + * 1. Async query - + * If the client changes a config setting, that shouldn't reflect in the execution + * already underway + * 2. confOverlay - + * The query specific settings should only be applied to the query config and not session + * @return new configuration + * @throws HiveSQLException + */ + private def getConfigForOperation(): HiveConf = { + var sqlOperationConf = getParentSession().getHiveConf() + if (!getConfOverlay().isEmpty() || runInBackground) { + // clone the partent session config for this query + sqlOperationConf = new HiveConf(sqlOperationConf) + + // apply overlay query specific settings, if any + getConfOverlay().foreach { case (k, v) => + try { + sqlOperationConf.verifyAndSet(k, v) + } catch { + case e: IllegalArgumentException => + throw new HiveSQLException("Error applying statement specific settings", e) + } + } + } + return sqlOperationConf + } + + private def getCurrentHive(): Hive = { + try { + return Hive.get() + } catch { + case e: HiveException => + throw new HiveSQLException("Failed to get current Hive object", e); + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9c0bf02391e0e..c8031ed0f3437 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -44,9 +44,12 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( - hiveContext, sessionToActivePool) + val runInBackground = async && hiveContext.hiveThriftServerAsync + val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, + runInBackground)(hiveContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created Operation for $statement with session=$parentSession, " + + s"runInBackground=$runInBackground") operation } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index f57c7083ea504..178bd1f5cb164 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.sql.{Date, DriverManager, Statement} +import java.nio.charset.StandardCharsets +import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise} +import scala.concurrent.{Await, Promise, future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} @@ -338,6 +340,42 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } ) } + + test("test jdbc cancel") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + val largeJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(10)("join test_map").mkString(" ") + val f = future { Thread.sleep(100); statement.cancel(); } + val e = intercept[SQLException] { + statement.executeQuery(largeJoin) + } + assert(e.getMessage contains "cancelled") + Await.result(f, 3.minute) + + // cancel is a noop + statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") + val sf = future { Thread.sleep(100); statement.cancel(); } + val smallJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(4)("join test_map").mkString(" ") + val rs1 = statement.executeQuery(smallJoin) + Await.result(sf, 3.minute) + rs1.next() + assert(rs1.getInt(1) === math.pow(5, 5)) + rs1.close() + + val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") + rs2.next() + assert(rs2.getInt(1) === 5) + rs2.close() + } + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 800f51c5e2e86..b8f294c262af7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -144,6 +144,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { getConf("spark.sql.hive.metastore.barrierPrefixes", "") .split(",").filterNot(_ == "") + /* + * hive thrift server use background spark sql thread pool to execute sql queries + */ + protected[hive] def hiveThriftServerAsync: Boolean = + getConf("spark.sql.hive.thriftServer.async", "true").toBoolean + @transient protected[sql] lazy val substitutor = new VariableSubstitution() From a71be0a36de94b3962c09f871845d745047a78e6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 5 Jun 2015 23:15:10 -0700 Subject: [PATCH 211/254] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ round 3. Author: Reynold Xin Closes #6677 from rxin/test-wildcard and squashes the following commits: 8a17b33 [Reynold Xin] Fixed line length. 6663813 [Reynold Xin] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ round 3. --- .../execution/SparkSqlSerializer2Suite.scala | 43 ++++++++--------- .../sources/CreateTableAsSelectSuite.scala | 14 +++--- .../spark/sql/sources/DDLTestSuite.scala | 24 +++++----- .../spark/sql/sources/DataSourceTest.scala | 12 +++-- .../spark/sql/sources/FilteredScanSuite.scala | 2 +- .../spark/sql/sources/InsertSuite.scala | 29 ++++++------ .../spark/sql/sources/PrunedScanSuite.scala | 5 +- .../spark/sql/sources/SaveLoadSuite.scala | 35 ++++++++------ .../spark/sql/sources/TableScanSuite.scala | 10 ++-- .../spark/sql/hive/QueryPartitionSuite.scala | 16 +++---- .../spark/sql/hive/SerializationSuite.scala | 3 +- .../spark/sql/hive/StatisticsSuite.scala | 37 ++++++++------- .../org/apache/spark/sql/hive/UDFSuite.scala | 18 ++++---- .../sql/sources/hadoopFsRelationSuites.scala | 46 +++++++++---------- 14 files changed, 156 insertions(+), 138 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 6ca5390cde23e..8631e247c6c05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} +import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD @@ -26,7 +27,6 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { @@ -74,11 +74,13 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll var numShufflePartitions: Int = _ var useSerializer2: Boolean = _ + protected lazy val ctx = TestSQLContext + override def beforeAll(): Unit = { - numShufflePartitions = conf.numShufflePartitions - useSerializer2 = conf.useSqlSerializer2 + numShufflePartitions = ctx.conf.numShufflePartitions + useSerializer2 = ctx.conf.useSqlSerializer2 - sql("set spark.sql.useSerializer2=true") + ctx.sql("set spark.sql.useSerializer2=true") val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, @@ -94,7 +96,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll // Create a RDD with all data types supported by SparkSqlSerializer2. val rdd = - sparkContext.parallelize((1 to 1000), 10).map { i => + ctx.sparkContext.parallelize((1 to 1000), 10).map { i => Row( s"str${i}: test serializer2.", s"binary${i}: test serializer2.".getBytes("UTF-8"), @@ -112,15 +114,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll new Timestamp(i)) } - createDataFrame(rdd, schema).registerTempTable("shuffle") + ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") super.beforeAll() } override def afterAll(): Unit = { - dropTempTable("shuffle") - sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - sql(s"set spark.sql.useSerializer2=$useSerializer2") + ctx.dropTempTable("shuffle") + ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") super.afterAll() } @@ -141,16 +143,16 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } test("key schema and value schema are not nulls") { - val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, - table("shuffle").collect()) + ctx.table("shuffle").collect()) } test("key schema is null") { val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = sql(s"SELECT $aggregations FROM shuffle") + val df = ctx.sql(s"SELECT $aggregations FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, @@ -158,15 +160,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } test("value schema is null") { - val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") + val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert( - df.map(r => r.getString(0)).collect().toSeq === - table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + assert(df.map(r => r.getString(0)).collect().toSeq === + ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) } test("no map output field") { - val df = sql(s"SELECT 1 + 1 FROM shuffle") + val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) } } @@ -177,8 +178,8 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { super.beforeAll() // Sort merge will not be triggered. val bypassMergeThreshold = - sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") + ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") } } @@ -189,7 +190,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite super.beforeAll() // To trigger the sort merge. val bypassMergeThreshold = - sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") + ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index d2d1011b8e917..a71088430bfd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,18 +26,20 @@ import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var path: File = null override def beforeAll(): Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - read.json(rdd).registerTempTable("jt") + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - dropTempTable("jt") + caseInsensitiveContext.dropTempTable("jt") } after { @@ -59,7 +61,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { @@ -129,7 +131,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT * FROM jsonTable"), sql("SELECT a * 4 FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") // Explicitly delete the data. if (path.exists()) Utils.deleteRecursively(path) @@ -147,7 +149,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT * FROM jsonTable"), sql("SELECT b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 5c3467158a01b..51d22b6a1378a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -63,19 +63,18 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } class DDLTestSuite extends DataSourceTest { - import caseInsensitiveContext._ before { - sql( - """ - |CREATE TEMPORARY TABLE ddlPeople - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) + caseInsensitiveContext.sql( + """ + |CREATE TEMPORARY TABLE ddlPeople + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) } sqlTest( @@ -100,7 +99,8 @@ class DDLTestSuite extends DataSourceTest { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = sql("describe ddlPeople").queryExecution.executedPlan.output + val attributes = caseInsensitiveContext.sql("describe ddlPeople") + .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 24ed665c67d2e..3f77960d09246 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,14 +17,18 @@ package org.apache.spark.sql.sources +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfter + abstract class DataSourceTest extends QueryTest with BeforeAndAfter { // We want to test some edge cases. - implicit val caseInsensitiveContext = new SQLContext(TestSQLContext.sparkContext) + protected implicit lazy val caseInsensitiveContext = { + val ctx = new SQLContext(TestSQLContext.sparkContext) + ctx.setConf(SQLConf.CASE_SENSITIVE, "false") + ctx + } - caseInsensitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index db94b1f3e8926..81b3a0f0c5b3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -97,7 +97,7 @@ object FiltersPushed { class FilteredScanSuite extends DataSourceTest { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql before { sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 6f375ef36237d..0b7c46c482c88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -26,14 +26,16 @@ import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var path: File = null override def beforeAll: Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - read.json(rdd).registerTempTable("jt") + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) @@ -45,8 +47,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll: Unit = { - dropTempTable("jsonTable") - dropTempTable("jt") + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") Utils.deleteRecursively(path) } @@ -109,7 +111,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) - read.json(rdd1).registerTempTable("jt1") + caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -121,7 +123,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) - read.json(rdd2).registerTempTable("jt2") + caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -140,8 +142,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { (1 to 10).map(i => Row(i * 10, s"str$i")) ) - dropTempTable("jt1") - dropTempTable("jt2") + caseInsensitiveContext.dropTempTable("jt1") + caseInsensitiveContext.dropTempTable("jt2") } test("INSERT INTO not supported for JSONRelation for now") { @@ -154,13 +156,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("save directly to the path of a JSON table") { - table("jt").selectExpr("a * 5 as a", "b").write.mode(SaveMode.Overwrite).json(path.toString) + caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") + .write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 5, s"str$i")) ) - table("jt").write.mode(SaveMode.Overwrite).json(path.toString) + caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) @@ -181,7 +184,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { test("Caching") { // Cached Query Execution - cacheTable("jsonTable") + caseInsensitiveContext.cacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable")) checkAnswer( sql("SELECT * FROM jsonTable"), @@ -220,7 +223,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2, b FROM jt").collect()) // Verify uncaching - uncacheTable("jsonTable") + caseInsensitiveContext.uncacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable"), 0) } @@ -251,6 +254,6 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { "It is not allowed to insert into a table that is not an InsertableRelation." ) - dropTempTable("oneToTen") + caseInsensitiveContext.dropTempTable("oneToTen") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index c2bc52e2120c1..257526feab945 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -52,10 +52,9 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } class PrunedScanSuite extends DataSourceTest { - import caseInsensitiveContext._ before { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -115,7 +114,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = sql(sqlString).queryExecution + val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 274c652dd14d6..b032515a9d28c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -27,7 +27,9 @@ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var originalDefaultSource: String = null @@ -36,60 +38,63 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { var df: DataFrame = null override def beforeAll(): Unit = { - originalDefaultSource = conf.defaultDataSourceName + originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = read.json(rdd) + df = caseInsensitiveContext.read.json(rdd) df.registerTempTable("jsonTable") } override def afterAll(): Unit = { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } after { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } def checkLoad(): Unit = { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(read.load(path.toString), df.collect()) + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + checkAnswer(caseInsensitiveContext.read.load(path.toString), df.collect()) // Test if we can pick up the data source name passed in load. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(read.format("json").load(path.toString), df.collect()) - checkAnswer(read.format("json").load(path.toString), df.collect()) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( - read.format("json").schema(schema).load(path.toString), + caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), sql("SELECT b FROM jsonTable").collect()) } test("save with path and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") df.write.save(path.toString) checkLoad() } test("save with string mode and path, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") path.createNewFile() df.write.mode("overwrite").save(path.toString) checkLoad() } test("save with path and datasource, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") df.write.json(path.toString) checkLoad() } test("save with data source and options, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") df.write.mode(SaveMode.ErrorIfExists).json(path.toString) checkLoad() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 77af04a491742..5d4ecd810862c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -88,9 +88,9 @@ case class AllDataTypesScan( } class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql - var tableWithSchemaExpected = (1 to 10).map { i => + private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", s"str_$i", @@ -215,7 +215,7 @@ class TableScanSuite extends DataSourceTest { Nil ) - assert(expectedSchema == table("tableWithSchema").schema) + assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema) checkAnswer( sql( @@ -270,7 +270,7 @@ class TableScanSuite extends DataSourceTest { test("Caching") { // Cached Query Execution - cacheTable("oneToTen") + caseInsensitiveContext.cacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen")) checkAnswer( sql("SELECT * FROM oneToTen"), @@ -297,7 +297,7 @@ class TableScanSuite extends DataSourceTest { (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching - uncacheTable("oneToTen") + caseInsensitiveContext.uncacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen"), 0) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 4990092df6a99..017bc2adc103b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -20,16 +20,17 @@ package org.apache.spark.sql.hive import com.google.common.io.Files import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest { - import org.apache.spark.sql.hive.test.TestHive.implicits._ + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ + import ctx.sql test("SPARK-5068: query data when path doesn't exist"){ - val testData = TestHive.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") @@ -48,8 +49,8 @@ class QueryPartitionSuite extends QueryTest { // test for the exist path checkAnswer(sql("select key,value from table_with_partition"), - testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect - ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) // delete the path of one partition tmpDir.listFiles @@ -58,8 +59,7 @@ class QueryPartitionSuite extends QueryTest { // test for after delete the path checkAnswer(sql("select key,value from table_with_partition"), - testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect - ++ testData.toSchemaRDD.collect) + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) sql("DROP TABLE table_with_partition") sql("DROP TABLE createAndInsertTest") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index a492ecf203d17..93dcb10f7a296 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.hive import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.hive.test.TestHive class SerializationSuite extends SparkFunSuite { test("[SPARK-5840] HiveContext should be serializable") { - val hiveContext = TestHive + val hiveContext = org.apache.spark.sql.hive.test.TestHive hiveContext.hiveconf val serializer = new JavaSerializer(new SparkConf()).newInstance() val bytes = serializer.serialize(hiveContext) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e16e530555aee..78c94e6490e36 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -23,13 +23,18 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - TestHive.reset() - TestHive.cacheTables = false + + private lazy val ctx: HiveContext = { + val ctx = org.apache.spark.sql.hive.test.TestHive + ctx.reset() + ctx.cacheTables = false + ctx + } + + import ctx.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -72,7 +77,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -106,7 +111,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -117,9 +122,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { - analyze("tempTable") + ctx.analyze("tempTable") } - catalog.unregisterTable(Seq("tempTable")) + ctx.catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { @@ -147,8 +152,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold - && sizes(1) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -159,8 +164,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedAnswer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") df = sql(query) @@ -203,8 +208,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= conf.autoBroadcastJoinThreshold - && sizes(0) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -217,8 +222,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, answer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") df = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 8245047626d57..4056dee777574 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.hive -/* Implicits */ - import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + test("UDF case insensitive") { - udf.register("random0", () => { Math.random() }) - udf.register("RANDOM1", () => { Math.random() }) - udf.register("strlenScala", (_: String).length + (_: Int)) - assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + ctx.udf.register("random0", () => { Math.random() }) + ctx.udf.register("RANDOM1", () => { Math.random() }) + ctx.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 74095426741e3..8787663a98f8f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive + override lazy val sqlContext: SQLContext = TestHive - import sqlContext._ + import sqlContext.sql import sqlContext.implicits._ val dataSourceName = classOf[SimpleTextSource].getCanonicalName @@ -43,19 +43,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { StructField("a", IntegerType, nullable = false), StructField("b", StringType, nullable = false))) - val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") + lazy val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") - val partitionedTestDF1 = (for { + lazy val partitionedTestDF1 = (for { i <- 1 to 3 p2 <- Seq("foo", "bar") } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") - val partitionedTestDF2 = (for { + lazy val partitionedTestDF2 = (for { i <- 1 to 3 p2 <- Seq("foo", "bar") } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") - val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) + lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) def checkQueries(df: DataFrame): Unit = { // Selects everything @@ -103,7 +103,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("path", file.getCanonicalPath) .option("dataSchema", dataSchema.json) .load(), @@ -117,7 +117,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), testDF.unionAll(testDF).orderBy("a").collect()) @@ -151,7 +151,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkQueries( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath)) } @@ -172,7 +172,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -194,7 +194,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.unionAll(partitionedTestDF).collect()) @@ -216,7 +216,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -252,7 +252,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.collect()) + checkAnswer(sqlContext.table("t"), testDF.collect()) } } @@ -261,7 +261,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -280,7 +280,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { withTempTable("t") { testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") - assert(table("t").collect().isEmpty) + assert(sqlContext.table("t").collect().isEmpty) } } @@ -291,7 +291,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkQueries(table("t")) + checkQueries(sqlContext.table("t")) } } @@ -311,7 +311,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -331,7 +331,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } @@ -351,7 +351,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -400,7 +400,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .partitionBy("p1", "p2") .saveAsTable("t") - assert(table("t").collect().isEmpty) + assert(sqlContext.table("t").collect().isEmpty) } } @@ -412,7 +412,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .partitionBy("p1", "p2") .save(file.getCanonicalPath) - val df = read + val df = sqlContext.read .format(dataSourceName) .option("dataSchema", dataSchema.json) .load(s"${file.getCanonicalPath}/p1=*/p2=???") @@ -452,7 +452,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTempTable("t") { - checkAnswer(table("t"), input.collect()) + checkAnswer(sqlContext.table("t"), input.collect()) } } } @@ -467,7 +467,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) + checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) } } } From a8077e5cfc48bdb9f0641d62fe6c01cc8c4f1694 Mon Sep 17 00:00:00 2001 From: Xu Tingjun Date: Sat, 6 Jun 2015 09:53:53 +0100 Subject: [PATCH 212/254] [SPARK-6973] remove skipped stage ID from completed set on the allJobsPage Though totalStages = allStages - skippedStages is understandable. But consider the problem [SPARK-6973], I think totalStages = allStages is more reasonable. Like "2/1 (2 failed) (1 skipped)", this item also shows the skipped num, it also will be understandable. Author: Xu Tingjun Author: Xutingjun Author: meiyoula <1039320815@qq.com> Closes #5550 from XuTingjun/allJobsPage and squashes the following commits: a742541 [Xu Tingjun] delete the loop 40ce94b [Xutingjun] remove stage id from completed set if it retries again 6459238 [meiyoula] delete space 9e23c71 [Xu Tingjun] recover numSkippedStages b987ea7 [Xutingjun] delete skkiped stages from completed set 47525c6 [Xu Tingjun] modify total stages/tasks on the allJobsPage --- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 7 ++++++- core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 1d31fce4c697b..730f9806e518e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -282,7 +282,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { ) { jobData.numActiveStages -= 1 if (stage.failureReason.isEmpty) { - jobData.completedStageIndices.add(stage.stageId) + if (!stage.submissionTime.isEmpty) { + jobData.completedStageIndices.add(stage.stageId) + } } else { jobData.numFailedStages += 1 } @@ -315,6 +317,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { jobData <- jobIdToData.get(jobId) ) { jobData.numActiveStages += 1 + + // If a stage retries again, it should be removed from completedStageIndices set + jobData.completedStageIndices.remove(stage.stageId) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 3d96113aa5fe9..f008d40180611 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -22,6 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.collection.OpenHashSet +import scala.collection.mutable import scala.collection.mutable.HashMap private[spark] object UIData { @@ -63,7 +64,7 @@ private[spark] object UIData { /* Stages */ var numActiveStages: Int = 0, // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: - var completedStageIndices: OpenHashSet[Int] = new OpenHashSet[Int](), + var completedStageIndices: mutable.HashSet[Int] = new mutable.HashSet[Int](), var numSkippedStages: Int = 0, var numFailedStages: Int = 0 ) From 16fc49617e1dfcbe9122b224f7f63b7bfddb36ce Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 6 Jun 2015 17:23:12 +0800 Subject: [PATCH 213/254] [SPARK-8079] [SQL] Makes InsertIntoHadoopFsRelation job/task abortion more robust As described in SPARK-8079, when writing a DataFrame to a `HadoopFsRelation`, if `HadoopFsRelation.prepareForWriteJob` throws exception, an unexpected NPE will be thrown during job abortion. (This issue doesn't bring much damage since the job is failing anyway.) This PR makes the job/task abortion logic in `InsertIntoHadoopFsRelation` more robust to avoid such confusing exceptions. Author: Cheng Lian Closes #6612 from liancheng/spark-8079 and squashes the following commits: 87cd81e [Cheng Lian] Addresses @rxin's comment 1864c75 [Cheng Lian] Addresses review comments 9e6dbb3 [Cheng Lian] Makes InsertIntoHadoopFsRelation job/task abortion more robust --- .../apache/spark/sql/sources/commands.scala | 93 ++++++++++++------- .../sql/sources/hadoopFsRelationSuites.scala | 15 +++ 2 files changed, 76 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index e9932c09107db..bd3aad6631748 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} @@ -127,8 +127,11 @@ private[sql] case class InsertIntoHadoopFsRelation( val needsConversion = relation.needConversion val dataSchema = relation.dataSchema + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) writerContainer.commitJob() relation.refresh() @@ -139,9 +142,10 @@ private[sql] case class InsertIntoHadoopFsRelation( } def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) - + // If anything below fails, we should abort the task. try { + writerContainer.executorSideSetup(taskContext) + if (needsConversion) { val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) while (iterator.hasNext) { @@ -154,6 +158,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.outputWriterForRow(row).write(row) } } + writerContainer.commitTask() } catch { case cause: Throwable => logError("Aborting task.", cause) @@ -191,8 +196,11 @@ private[sql] case class InsertIntoHadoopFsRelation( val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) val codegenEnabled = df.sqlContext.conf.codegenEnabled + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) writerContainer.commitJob() relation.refresh() @@ -203,32 +211,39 @@ private[sql] case class InsertIntoHadoopFsRelation( } def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) - - val partitionProj = newProjection(codegenEnabled, partitionOutput, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = partitionProj(row) - val dataPart = dataProj(row) - val convertedDataPart = converter(dataPart).asInstanceOf[Row] - writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) - } - } else { - val partitionSchema = StructType.fromAttributes(partitionOutput) - val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = converter(partitionProj(row)).asInstanceOf[Row] - val dataPart = dataProj(row) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + // If anything below fails, we should abort the task. + try { + writerContainer.executorSideSetup(taskContext) + + val partitionProj = newProjection(codegenEnabled, partitionOutput, output) + val dataProj = newProjection(codegenEnabled, dataOutput, output) + + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + while (iterator.hasNext) { + val row = iterator.next() + val partitionPart = partitionProj(row) + val dataPart = dataProj(row) + val convertedDataPart = converter(dataPart).asInstanceOf[Row] + writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) + } + } else { + val partitionSchema = StructType.fromAttributes(partitionOutput) + val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) + while (iterator.hasNext) { + val row = iterator.next() + val partitionPart = converter(partitionProj(row)).asInstanceOf[Row] + val dataPart = dataProj(row) + writerContainer.outputWriterForRow(partitionPart).write(dataPart) + } } - } - writerContainer.commitTask() + writerContainer.commitTask() + } catch { case cause: Throwable => + logError("Aborting task.", cause) + writerContainer.abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } } } @@ -283,7 +298,12 @@ private[sql] abstract class BaseWriterContainer( setupIDs(0, 0, 0) setupConf() taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + + // This preparation must happen before initializing output format and output committer, since + // their initialization involves the job configuration, which can be potentially decorated in + // `relation.prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) + outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupJob(jobContext) @@ -359,7 +379,9 @@ private[sql] abstract class BaseWriterContainer( } def abortTask(): Unit = { - outputCommitter.abortTask(taskAttemptContext) + if (outputCommitter != null) { + outputCommitter.abortTask(taskAttemptContext) + } logError(s"Task attempt $taskAttemptId aborted.") } @@ -369,7 +391,9 @@ private[sql] abstract class BaseWriterContainer( } def abortJob(): Unit = { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + if (outputCommitter != null) { + outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + } logError(s"Job $jobId aborted.") } } @@ -390,6 +414,7 @@ private[sql] class DefaultWriterContainer( override def commitTask(): Unit = { try { + assert(writer != null, "OutputWriter instance should have been initialized") writer.close() super.commitTask() } catch { @@ -401,7 +426,9 @@ private[sql] class DefaultWriterContainer( override def abortTask(): Unit = { try { - writer.close() + if (writer != null) { + writer.close() + } } finally { super.abortTask() } @@ -445,6 +472,7 @@ private[sql] class DynamicPartitionWriterContainer( override def commitTask(): Unit = { try { outputWriters.values.foreach(_.close()) + outputWriters.clear() super.commitTask() } catch { case cause: Throwable => super.abortTask() @@ -455,6 +483,7 @@ private[sql] class DynamicPartitionWriterContainer( override def abortTask(): Unit = { try { outputWriters.values.foreach(_.close()) + outputWriters.clear() } finally { super.abortTask() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8787663a98f8f..76469d7a3d6a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -594,4 +594,19 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { checkAnswer(read.format("parquet").load(path), df) } } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. + range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } } From 5aa804f3c6485670937a658ce8207c2317c6a506 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Sat, 6 Jun 2015 14:52:14 -0700 Subject: [PATCH 214/254] [SPARK-7639] [PYSPARK] [MLLIB] Python API for KernelDensity Python API for KernelDensity Author: MechCoder Closes #6387 from MechCoder/spark-7639 and squashes the following commits: 17abc62 [MechCoder] add tests 2de6540 [MechCoder] style tests bf4acc0 [MechCoder] Added doctests 84359d5 [MechCoder] [SPARK-7639] Python API for KernelDensity --- .../mllib/api/python/PythonMLLibAPI.scala | 12 +++- python/pyspark/mllib/stat/KernelDensity.py | 61 +++++++++++++++++++ python/pyspark/mllib/stat/__init__.py | 3 +- python/run-tests | 1 + 4 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 python/pyspark/mllib/stat/KernelDensity.py diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 16f3131796709..8f66bc808a007 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -43,7 +43,8 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses @@ -945,6 +946,15 @@ private[python] class PythonMLLibAPI extends Serializable { r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) } + /** + * Java stub for the estimate method of KernelDensity + */ + def estimateKernelDensity( + sample: JavaRDD[Double], + bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = { + return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( + points.asScala.toArray) + } } diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py new file mode 100644 index 0000000000000..7da921976d4d2 --- /dev/null +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +if sys.version > '3': + xrange = range + +import numpy as np + +from pyspark.mllib.common import callMLlibFunc +from pyspark.rdd import RDD + + +class KernelDensity(object): + """ + .. note:: Experimental + + Estimate probability density at required points given a RDD of samples + from the population. + + >>> kd = KernelDensity() + >>> sample = sc.parallelize([0.0, 1.0]) + >>> kd.setSample(sample) + >>> kd.estimate([0.0, 1.0]) + array([ 0.12938758, 0.12938758]) + """ + def __init__(self): + self._bandwidth = 1.0 + self._sample = None + + def setBandwidth(self, bandwidth): + """Set bandwidth of each sample. Defaults to 1.0""" + self._bandwidth = bandwidth + + def setSample(self, sample): + """Set sample points from the population. Should be a RDD""" + if not isinstance(sample, RDD): + raise TypeError("samples should be a RDD, received %s" % type(sample)) + self._sample = sample + + def estimate(self, points): + """Estimate the probability density at points""" + points = list(points) + densities = callMLlibFunc( + "estimateKernelDensity", self._sample, self._bandwidth, points) + return np.asarray(densities) diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py index e3e128513e0d7..c8a721d3fe41c 100644 --- a/python/pyspark/mllib/stat/__init__.py +++ b/python/pyspark/mllib/stat/__init__.py @@ -22,6 +22,7 @@ from pyspark.mllib.stat._statistics import * from pyspark.mllib.stat.distribution import MultivariateGaussian from pyspark.mllib.stat.test import ChiSqTestResult +from pyspark.mllib.stat.KernelDensity import KernelDensity __all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", - "MultivariateGaussian"] + "MultivariateGaussian", "KernelDensity"] diff --git a/python/run-tests b/python/run-tests index 17dda3eadac0c..4468fdb3f267e 100755 --- a/python/run-tests +++ b/python/run-tests @@ -93,6 +93,7 @@ function run_mllib_tests() { run_test "pyspark.mllib.recommendation" run_test "pyspark.mllib.regression" run_test "pyspark.mllib.stat._statistics" + run_test "pyspark.mllib.stat.KernelDensity" run_test "pyspark.mllib.tree" run_test "pyspark.mllib.util" run_test "pyspark.mllib.tests" From 18c4fcebbeecc3b26476a728bc9db62f5c0a6f87 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 6 Jun 2015 21:08:36 -0700 Subject: [PATCH 215/254] [SPARK-7169] [CORE] Allow metrics system to be configured through SparkConf. Author: Marcelo Vanzin Author: Jacek Lewandowski Closes #6560 from vanzin/SPARK-7169 and squashes the following commits: 737266f [Marcelo Vanzin] Feedback. 702d5a3 [Marcelo Vanzin] Scalastyle. ce66e7e [Marcelo Vanzin] Remove metrics config handling from SparkConf. 439938a [Jacek Lewandowski] SPARK-7169: Metrics can be additionally configured from Spark configuration --- .../apache/spark/metrics/MetricsConfig.scala | 55 ++++++++----- .../apache/spark/metrics/MetricsSystem.scala | 3 +- .../spark/metrics/MetricsConfigSuite.scala | 82 ++++++++++++++++++- 3 files changed, 115 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 8edf493780687..d7495551ad233 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -23,10 +23,10 @@ import java.util.Properties import scala.collection.mutable import scala.util.matching.Regex -import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} -private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { +private[spark] class MetricsConfig(conf: SparkConf) extends Logging { private val DEFAULT_PREFIX = "*" private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r @@ -46,23 +46,14 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi // Add default properties in case there's no properties file setDefaultProperties(properties) - // If spark.metrics.conf is not set, try to get file in class path - val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse { - try { - Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)) - } catch { - case e: Exception => - logError("Error loading default configuration file", e) - None - } - } + loadPropertiesFromFile(conf.getOption("spark.metrics.conf")) - isOpt.foreach { is => - try { - properties.load(is) - } finally { - is.close() - } + // Also look for the properties in provided Spark configuration + val prefix = "spark.metrics.conf." + conf.getAll.foreach { + case (k, v) if k.startsWith(prefix) => + properties.setProperty(k.substring(prefix.length()), v) + case _ => } propertyCategories = subProperties(properties, INSTANCE_REGEX) @@ -97,5 +88,31 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) } } -} + /** + * Loads configuration from a config file. If no config file is provided, try to get file + * in class path. + */ + private[this] def loadPropertiesFromFile(path: Option[String]): Unit = { + var is: InputStream = null + try { + is = path match { + case Some(f) => new FileInputStream(f) + case None => Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME) + } + + if (is != null) { + properties.load(is) + } + } catch { + case e: Exception => + val file = path.getOrElse(DEFAULT_METRICS_CONF_FILENAME) + logError(s"Error loading configuration file $file", e) + } finally { + if (is != null) { + is.close() + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 9150ad35712a1..ed5131c79fdc5 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -70,8 +70,7 @@ private[spark] class MetricsSystem private ( securityMgr: SecurityManager) extends Logging { - private[this] val confFile = conf.get("spark.metrics.conf", null) - private[this] val metricsConfig = new MetricsConfig(Option(confFile)) + private[this] val metricsConfig = new MetricsConfig(conf) private val sinks = new mutable.ArrayBuffer[Sink] private val sources = new mutable.ArrayBuffer[Source] diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index a901a069d9bfe..41f2ff725a17b 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.metrics +import org.apache.spark.SparkConf + import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite @@ -29,7 +31,9 @@ class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { } test("MetricsConfig with default properties") { - val conf = new MetricsConfig(None) + val sparkConf = new SparkConf(loadDefaults = false) + sparkConf.set("spark.metrics.conf", "dummy-file") + val conf = new MetricsConfig(sparkConf) conf.initialize() assert(conf.properties.size() === 4) @@ -42,8 +46,41 @@ class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { assert(property.getProperty("sink.servlet.path") === "/metrics/json") } - test("MetricsConfig with properties set") { - val conf = new MetricsConfig(Option(filePath)) + test("MetricsConfig with properties set from a file") { + val sparkConf = new SparkConf(loadDefaults = false) + sparkConf.set("spark.metrics.conf", filePath) + val conf = new MetricsConfig(sparkConf) + conf.initialize() + + val masterProp = conf.getInstance("master") + assert(masterProp.size() === 5) + assert(masterProp.getProperty("sink.console.period") === "20") + assert(masterProp.getProperty("sink.console.unit") === "minutes") + assert(masterProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") + + val workerProp = conf.getInstance("worker") + assert(workerProp.size() === 5) + assert(workerProp.getProperty("sink.console.period") === "10") + assert(workerProp.getProperty("sink.console.unit") === "seconds") + assert(workerProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") + } + + test("MetricsConfig with properties set from a Spark configuration") { + val sparkConf = new SparkConf(loadDefaults = false) + setMetricsProperty(sparkConf, "*.sink.console.period", "10") + setMetricsProperty(sparkConf, "*.sink.console.unit", "seconds") + setMetricsProperty(sparkConf, "*.source.jvm.class", "org.apache.spark.metrics.source.JvmSource") + setMetricsProperty(sparkConf, "master.sink.console.period", "20") + setMetricsProperty(sparkConf, "master.sink.console.unit", "minutes") + val conf = new MetricsConfig(sparkConf) conf.initialize() val masterProp = conf.getInstance("master") @@ -67,8 +104,40 @@ class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") } + test("MetricsConfig with properties set from a file and a Spark configuration") { + val sparkConf = new SparkConf(loadDefaults = false) + setMetricsProperty(sparkConf, "*.sink.console.period", "10") + setMetricsProperty(sparkConf, "*.sink.console.unit", "seconds") + setMetricsProperty(sparkConf, "*.source.jvm.class", "org.apache.spark.SomeOtherSource") + setMetricsProperty(sparkConf, "master.sink.console.period", "50") + setMetricsProperty(sparkConf, "master.sink.console.unit", "seconds") + sparkConf.set("spark.metrics.conf", filePath) + val conf = new MetricsConfig(sparkConf) + conf.initialize() + + val masterProp = conf.getInstance("master") + assert(masterProp.size() === 5) + assert(masterProp.getProperty("sink.console.period") === "50") + assert(masterProp.getProperty("sink.console.unit") === "seconds") + assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.SomeOtherSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") + + val workerProp = conf.getInstance("worker") + assert(workerProp.size() === 5) + assert(workerProp.getProperty("sink.console.period") === "10") + assert(workerProp.getProperty("sink.console.unit") === "seconds") + assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.SomeOtherSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") + } + test("MetricsConfig with subProperties") { - val conf = new MetricsConfig(Option(filePath)) + val sparkConf = new SparkConf(loadDefaults = false) + sparkConf.set("spark.metrics.conf", filePath) + val conf = new MetricsConfig(sparkConf) conf.initialize() val propCategories = conf.propertyCategories @@ -90,4 +159,9 @@ class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { val servletProps = sinkProps("servlet") assert(servletProps.size() === 2) } + + private def setMetricsProperty(conf: SparkConf, name: String, value: String): Unit = { + conf.set(s"spark.metrics.conf.$name", value) + } + } From ed2cc3ee890694ca0c1fa0bbc7186c8b80da3fab Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Sat, 6 Jun 2015 21:09:56 -0700 Subject: [PATCH 216/254] [SPARK-8136] [YARN] Fix flakiness in YarnClusterSuite. Instead of actually downloading the logs, just verify that the logs link is actually a URL and is in the expected format. Author: Hari Shreedharan Closes #6680 from harishreedharan/simplify-am-log-tests and squashes the following commits: 3183aeb [Hari Shreedharan] Remove check for hostname which can fail on machines with several hostnames. Removed some unused imports. 50d69a7 [Hari Shreedharan] [SPARK-8136][YARN] Fix flakiness in YarnClusterSuite. --- .../spark/deploy/yarn/YarnClusterSuite.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index bc42e12dfafd7..93d587d0cb36a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.net.URL import java.util.Properties import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable -import scala.io.Source import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams @@ -344,18 +344,20 @@ private object YarnClusterDriver extends Logging with Matchers { assert(info.logUrlMap.nonEmpty) } - // If we are running in yarn-cluster mode, verify that driver logs are downloadable. + // If we are running in yarn-cluster mode, verify that driver logs links and present and are + // in the expected format. if (conf.get("spark.master") == "yarn-cluster") { assert(listener.driverLogs.nonEmpty) val driverLogs = listener.driverLogs.get assert(driverLogs.size === 2) assert(driverLogs.containsKey("stderr")) assert(driverLogs.containsKey("stdout")) - val stderr = driverLogs("stderr") // YARN puts everything in stderr. - val lines = Source.fromURL(stderr).getLines() - // Look for a line that contains YarnClusterSchedulerBackend, since that is guaranteed in - // cluster mode. - assert(lines.exists(_.contains("YarnClusterSchedulerBackend"))) + val urlStr = driverLogs("stderr") + // Ensure that this is a valid URL, else this will throw an exception + new URL(urlStr) + val containerId = YarnSparkHadoopUtil.get.getContainerId + val user = Utils.getCurrentUserName() + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0")) } } From 3285a51121397bfd2e62dbee8e1f0fa7c72512a7 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Sat, 6 Jun 2015 21:13:26 -0700 Subject: [PATCH 217/254] =?UTF-8?q?[SPARK-7955]=20[CORE]=20Ensure=20execut?= =?UTF-8?q?ors=20with=20cached=20RDD=20blocks=20are=20not=20re=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …moved if dynamic allocation is enabled. This is a work in progress. This patch ensures that an executor that has cached RDD blocks are not removed, but makes no attempt to find another executor to remove. This is meant to get some feedback on the current approach, and if it makes sense then I will look at choosing another executor to remove. No testing has been done either. Author: Hari Shreedharan Closes #6508 from harishreedharan/dymanic-caching and squashes the following commits: dddf1eb [Hari Shreedharan] Minor configuration description update. 10130e2 [Hari Shreedharan] Fix compile issue. 5417b53 [Hari Shreedharan] Add documentation for new config. Remove block from cachedBlocks when it is dropped. 875916a [Hari Shreedharan] Make some code more readable. 39940ca [Hari Shreedharan] Handle the case where the executor has not yet registered. 90ad711 [Hari Shreedharan] Remove unused imports and unused methods. 063985c [Hari Shreedharan] Send correct message instead of recursively calling same method. ec2fd7e [Hari Shreedharan] Add file missed in last commit 5d10fad [Hari Shreedharan] Update cached blocks status using local info, rather than doing an RPC. 193af4c [Hari Shreedharan] WIP. Use local state rather than via RPC. ae932ff [Hari Shreedharan] Fix config param name. 272969d [Hari Shreedharan] Fix seconds to millis bug. 5a1993f [Hari Shreedharan] Add timeout for cache executors. Ignore broadcast blocks while checking if there are cached blocks. 57fefc2 [Hari Shreedharan] [SPARK-7955][Core] Ensure executors with cached RDD blocks are not removed if dynamic allocation is enabled. --- .../spark/ExecutorAllocationManager.scala | 21 ++++++++++-- .../spark/storage/BlockManagerMaster.scala | 8 +++++ .../storage/BlockManagerMasterEndpoint.scala | 33 +++++++++++++++++-- .../spark/storage/BlockManagerMessages.scala | 3 +- docs/configuration.md | 9 +++++ 5 files changed, 68 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 9939103bb0903..49329423dca76 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -101,6 +101,9 @@ private[spark] class ExecutorAllocationManager( private val executorIdleTimeoutS = conf.getTimeAsSeconds( "spark.dynamicAllocation.executorIdleTimeout", "60s") + private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s") + // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -459,9 +462,23 @@ private[spark] class ExecutorAllocationManager( private def onExecutorIdle(executorId: String): Unit = synchronized { if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + // Note that it is not necessary to query the executors since all the cached + // blocks we are concerned with are reported to the driver. Note that this + // does not include broadcast blocks. + val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val now = clock.getTimeMillis() + val timeout = { + if (hasCachedBlocks) { + // Use a different timeout if the executor has cached blocks. + now + cachedExecutorIdleTimeoutS * 1000 + } else { + now + executorIdleTimeoutS * 1000 + } + } + val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow + removeTimes(executorId) = realTimeout logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000 + s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)") } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index abcad9438bf28..7cdae22b0e253 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -202,6 +202,14 @@ class BlockManagerMaster( Await.result(future, timeout) } + /** + * Find out if the executor has cached blocks. This method does not consider broadcast blocks, + * since they are not reported the master. + */ + def hasCachedBlocks(executorId: String): Boolean = { + driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId)) + } + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { if (driverEndpoint != null && isDriver) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 2cd8c5297b741..68ed9096731c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} +import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} @@ -112,6 +113,17 @@ class BlockManagerMasterEndpoint( case BlockManagerHeartbeat(blockManagerId) => context.reply(heartbeatReceived(blockManagerId)) + case HasCachedBlocks(executorId) => + blockManagerIdByExecutor.get(executorId) match { + case Some(bm) => + if (blockManagerInfo.contains(bm)) { + val bmInfo = blockManagerInfo(bm) + context.reply(bmInfo.cachedBlocks.nonEmpty) + } else { + context.reply(false) + } + case None => context.reply(false) + } } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -418,6 +430,9 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] + // Cached blocks held by this BlockManager. This does not include broadcast blocks. + private val _cachedBlocks = new mutable.HashSet[BlockId] + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { @@ -451,27 +466,35 @@ private[spark] class BlockManagerInfo( * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ + var blockStatus: BlockStatus = null if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) + blockStatus = BlockStatus(storageLevel, memSize, 0, 0) + _blocks.put(blockId, blockStatus) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) + blockStatus = BlockStatus(storageLevel, 0, diskSize, 0) + _blocks.put(blockId, blockStatus) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } if (storageLevel.useOffHeap) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, externalBlockStoreSize)) + blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize) + _blocks.put(blockId, blockStatus) logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize))) } + if (!blockId.isBroadcast && blockStatus.isCached) { + _cachedBlocks += blockId + } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) + _cachedBlocks -= blockId if (blockStatus.storageLevel.useMemory) { logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), @@ -494,6 +517,7 @@ private[spark] class BlockManagerInfo( _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) } + _cachedBlocks -= blockId } def remainingMem: Long = _remainingMem @@ -502,6 +526,9 @@ private[spark] class BlockManagerInfo( def blocks: JHashMap[BlockId, BlockStatus] = _blocks + // This does not include broadcast blocks. + def cachedBlocks: collection.Set[BlockId] = _cachedBlocks + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem def clear() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1683576067fe8..376e9eb48843d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,7 +42,6 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -108,4 +107,6 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerMaster case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + + case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster } diff --git a/docs/configuration.md b/docs/configuration.md index 3a48da4592dd9..9667cebe0b87c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1201,6 +1201,15 @@ Apart from these, the following properties are also available, and may be useful description. + + spark.dynamicAllocation.cachedExecutorIdleTimeout + 2 * executorIdleTimeout + + If dynamic allocation is enabled and an executor which has cached data blocks has been idle for more than this duration, + the executor will be removed. For more details, see this + description. + + spark.dynamicAllocation.initialExecutors spark.dynamicAllocation.minExecutors From 901a552c5e973262fddbf70ee2d4078c948bc668 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 6 Jun 2015 22:59:31 -0700 Subject: [PATCH 218/254] [SPARK-8004][SQL] Enclose column names by JDBC Dialect JIRA: https://issues.apache.org/jira/browse/SPARK-8004 Author: Liang-Chi Hsieh Closes #6577 from viirya/enclose_jdbc_columns and squashes the following commits: 614606a [Liang-Chi Hsieh] For comment. bc50182 [Liang-Chi Hsieh] Enclose column names by JDBC Dialect. --- .../scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 4 +++- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 13 +++++++++++++ .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 11 +++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 40b604d710dce..2930f7bb4cae1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -211,12 +211,14 @@ private[sql] object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition]): RDD[Row] = { + val dialect = JdbcDialects.get(url) + val enclosedColumns = requiredColumns.map(dialect.columnEnclosing(_)) new JDBCRDD( sc, getConnector(driver, url, properties), pruneSchema(schema, requiredColumns), fqTable, - requiredColumns, + enclosedColumns, filters, parts, properties) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 6a169e106b968..04052f80f5e78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -80,6 +80,15 @@ abstract class JdbcDialect { * @return The new JdbcType if there is an override for this DataType */ def getJDBCType(dt: DataType): Option[JdbcType] = None + + /** + * Enclose column name + * @param colName The coulmn name + * @return Enclosed column name + */ + def columnEnclosing(colName: String): String = { + s""""$colName"""" + } } /** @@ -208,4 +217,8 @@ case object MySQLDialect extends JdbcDialect { Some(BooleanType) } else None } + + override def columnEnclosing(colName: String): String = { + s"`$colName`" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7931854db27c1..a228543953536 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -410,6 +410,17 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(JdbcDialects.get("test.invalid") == NoopDialect) } + test("Enclosing column names by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + + val columns = Seq("abc", "key") + val MySQLColumns = columns.map(MySQL.columnEnclosing(_)) + val PostgresColumns = columns.map(Postgres.columnEnclosing(_)) + assert(MySQLColumns === Seq("`abc`", "`key`")) + assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + } + test("Dialect unregister") { JdbcDialects.registerDialect(testH2Dialect) JdbcDialects.unregisterDialect(testH2Dialect) From 081db9479abc559b26d115298fbcdc109858cad3 Mon Sep 17 00:00:00 2001 From: 979969786 Date: Sat, 6 Jun 2015 23:15:27 -0700 Subject: [PATCH 219/254] [SPARK-8145] [WEBUI] Trigger a double click on the span to show full job description. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When using the Spark SQL, Jobs tab and Stages tab display only part of SQL. I change it to display full SQL by double-click on the description span before: ![before](https://cloud.githubusercontent.com/assets/5399861/8022257/9f8e0a22-0cf8-11e5-98c8-da4d7a615e7e.png) after double click on the description span: ![after](https://cloud.githubusercontent.com/assets/5399861/8022261/dac08d4a-0cf8-11e5-8fe7-74c96c6ce933.png) Author: 979969786 Closes #6646 from 979969786/master and squashes the following commits: b5ba20e [979969786] Trigger a double click on the span to show full job description. --- .../org/apache/spark/ui/static/additional-metrics.js | 5 +++++ .../main/resources/org/apache/spark/ui/static/webui.css | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 013db8df9b363..0b450dc76bc38 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -50,4 +50,9 @@ $(function() { $("span.additional-metric-title").click(function() { $(this).parent().find('input[type="checkbox"]').trigger('click'); }); + + // Trigger a double click on the span to show full job description. + $(".description-input").dblclick(function() { + $(this).removeClass("description-input").addClass("description-input-full"); + }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index e7c1d475d4e52..b1cef47042247 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -135,6 +135,14 @@ pre { display: block; } +.description-input-full { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: normal; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; From 26d07f1ece4174788b0bcdc338a14d0bbc0e3602 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 7 Jun 2015 15:33:48 +0800 Subject: [PATCH 220/254] [SPARK-8141] [SQL] Precompute datatypes for partition columns and reuse it JIRA: https://issues.apache.org/jira/browse/SPARK-8141 Author: Liang-Chi Hsieh Closes #6687 from viirya/reuse_partition_column_types and squashes the following commits: dab0688 [Liang-Chi Hsieh] Reuse partitionColumnTypes. --- .../main/scala/org/apache/spark/sql/sources/interfaces.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f5bd2d2941ca0..25887ba9a15b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -435,8 +435,9 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio // partition values. userDefinedPartitionColumns.map { partitionSchema => val spec = discoverPartitions() + val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(spec.partitionColumns.map(_.dataType)).map { + val literals = values.toSeq.zip(partitionColumnTypes).map { case (value, dataType) => Literal.create(value, dataType) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => From 0ac47083f7ef5fca9847bca2f0490719e1ccf50a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 7 Jun 2015 01:21:02 -0700 Subject: [PATCH 221/254] [SPARK-8146] DataFrame Python API: Alias replace in df.na Author: Reynold Xin Closes #6688 from rxin/df-alias-replace and squashes the following commits: 774c19c [Reynold Xin] [SPARK-8146] DataFrame Python API: Alias replace in DataFrameNaFunctions. --- python/pyspark/sql/dataframe.py | 47 +++++++++++++++------------------ python/pyspark/sql/window.py | 1 - 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 902504df5b11b..2d8c59518b35a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -909,8 +909,7 @@ def dropDuplicates(self, subset=None): @since("1.3.1") def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. - - This is an alias for ``na.drop()``. + :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other. :param how: 'any' or 'all'. If 'any', drop a row if it contains any nulls. @@ -920,13 +919,6 @@ def dropna(self, how='any', thresh=None, subset=None): This overwrites the `how` parameter. :param subset: optional list of column names to consider. - >>> df4.dropna().show() - +---+------+-----+ - |age|height| name| - +---+------+-----+ - | 10| 80|Alice| - +---+------+-----+ - >>> df4.na.drop().show() +---+------+-----+ |age|height| name| @@ -952,6 +944,7 @@ def dropna(self, how='any', thresh=None, subset=None): @since("1.3.1") def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. + :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. :param value: int, long, float, string, or dict. Value to replace null values with. @@ -963,7 +956,7 @@ def fillna(self, value, subset=None): For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. - >>> df4.fillna(50).show() + >>> df4.na.fill(50).show() +---+------+-----+ |age|height| name| +---+------+-----+ @@ -973,16 +966,6 @@ def fillna(self, value, subset=None): | 50| 50| null| +---+------+-----+ - >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() - +---+------+-------+ - |age|height| name| - +---+------+-------+ - | 10| 80| Alice| - | 5| null| Bob| - | 50| null| Tom| - | 50| null|unknown| - +---+------+-------+ - >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() +---+------+-------+ |age|height| name| @@ -1014,6 +997,8 @@ def fillna(self, value, subset=None): @since(1.4) def replace(self, to_replace, value, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. + :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are + aliases of each other. :param to_replace: int, long, float, string, or list. Value to be replaced. @@ -1029,7 +1014,7 @@ def replace(self, to_replace, value, subset=None): For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. - >>> df4.replace(10, 20).show() + >>> df4.na.replace(10, 20).show() +----+------+-----+ | age|height| name| +----+------+-----+ @@ -1039,7 +1024,7 @@ def replace(self, to_replace, value, subset=None): |null| null| null| +----+------+-----+ - >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| +----+------+----+ @@ -1090,9 +1075,9 @@ def replace(self, to_replace, value, subset=None): @since(1.4) def corr(self, col1, col2, method=None): """ - Calculates the correlation of two columns of a DataFrame as a double value. Currently only - supports the Pearson Correlation Coefficient. - :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases. + Calculates the correlation of two columns of a DataFrame as a double value. + Currently only supports the Pearson Correlation Coefficient. + :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other. :param col1: The name of the first column :param col2: The name of the second column @@ -1241,7 +1226,10 @@ def toPandas(self): import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) + ########################################################################################## # Pandas compatibility + ########################################################################################## + groupby = groupBy drop_duplicates = dropDuplicates @@ -1261,6 +1249,8 @@ def _to_scala_map(sc, jm): class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. + + .. versionadded:: 1.4 """ def __init__(self, df): @@ -1276,9 +1266,16 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ + def replace(self, to_replace, value, subset=None): + return self.df.replace(to_replace, value, subset) + + replace.__doc__ = DataFrame.replace.__doc__ + class DataFrameStatFunctions(object): """Functionality for statistic functions with :class:`DataFrame`. + + .. versionadded:: 1.4 """ def __init__(self, df): diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 0a0e006bdf83a..c74745c726a0c 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -32,7 +32,6 @@ def _to_java_cols(cols): class Window(object): - """ Utility functions for defining window in DataFrames. From 8c321d66d79716f03b6f9c8ad5cedd75e8bfe631 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 7 Jun 2015 16:59:55 +0800 Subject: [PATCH 222/254] [SPARK-8118] [SQL] Mutes noisy Parquet log output reappeared after upgrading Parquet to 1.7.0 Author: Cheng Lian Closes #6670 from liancheng/spark-8118 and squashes the following commits: b6e85a6 [Cheng Lian] Suppresses unnecesary ParquetRecordReader log message (PARQUET-220) 385603c [Cheng Lian] Mutes noisy Parquet log output reappeared after upgrading Parquet to 1.7.0 --- .../spark/sql/parquet/ParquetRelation.scala | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 09088ee91106c..704cf56f38265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -18,20 +18,21 @@ package org.apache.spark.sql.parquet import java.io.IOException -import java.util.logging.Level +import java.util.logging.{Level, Logger => JLogger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import org.apache.spark.sql.types.{StructType, DataType} -import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat, ParquetRecordReader} import org.apache.parquet.schema.MessageType +import org.apache.parquet.{Log => ParquetLog} -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext} /** * Relation that consists of data stored in a Parquet columnar format. @@ -94,33 +95,37 @@ private[sql] case class ParquetRelation( private[sql] object ParquetRelation { def enableLogForwarding() { - // Note: the parquet.Log class has a static initializer that - // sets the java.util.logging Logger for "parquet". This + // Note: the org.apache.parquet.Log class has a static initializer that + // sets the java.util.logging Logger for "org.apache.parquet". This // checks first to see if there's any handlers already set // and if not it creates them. If this method executes prior // to that class being loaded then: // 1) there's no handlers installed so there's none to // remove. But when it IS finally loaded the desired affect // of removing them is circumvented. - // 2) The parquet.Log static initializer calls setUseParentHanders(false) + // 2) The parquet.Log static initializer calls setUseParentHandlers(false) // undoing the attempt to override the logging here. // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[org.apache.parquet.Log].getName) + Class.forName(classOf[ParquetLog].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. - val parquetLogger = java.util.logging.Logger.getLogger("parquet") + val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName) parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) - // TODO(witgo): Need to set the log level ? - // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) - if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) + parquetLogger.setUseParentHandlers(true) - // Disables WARN log message in ParquetOutputCommitter. + // Disables a WARN log message in ParquetOutputCommitter. We first ensure that + // ParquetOutputCommitter is loaded and the static LOG field gets initialized. // See https://issues.apache.org/jira/browse/SPARK-5968 for details Class.forName(classOf[ParquetOutputCommitter].getName) - java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) + JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) + + // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. + // See https://issues.apache.org/jira/browse/PARQUET-220 for details + Class.forName(classOf[ParquetRecordReader[_]].getName) + JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) } // The element type for the RDDs that this relation maps to. From ca8dafcc9fa661f05da9c98104d987716aa5f5eb Mon Sep 17 00:00:00 2001 From: Konstantin Shaposhnikov Date: Sun, 7 Jun 2015 13:41:00 +0100 Subject: [PATCH 223/254] [SPARK-7042] [BUILD] use the standard akka artifacts with hadoop-2.x Both akka 2.3.x and hadoop-2.x use protobuf 2.5 so only hadoop-1 build needs custom 2.3.4-spark akka version that shades protobuf-2.5 This change also updates akka version (for hadoop-2.x profiles only) to the latest 2.3.11 as akka-zeromq_2.11 is not available for akka 2.3.4. This partially fixes SPARK-7042 (for hadoop-2.x builds) Author: Konstantin Shaposhnikov Closes #6492 from kostya-sh/SPARK-7042 and squashes the following commits: dc195b0 [Konstantin Shaposhnikov] [SPARK-7042] [BUILD] use the standard akka artifacts with hadoop-2.x --- pom.xml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index e28d4b9fc2b17..e65448e4b2325 100644 --- a/pom.xml +++ b/pom.xml @@ -114,8 +114,8 @@ UTF-8 UTF-8 - org.spark-project.akka - 2.3.4-spark + com.typesafe.akka + 2.3.11 1.6 spark 0.21.1 @@ -1670,6 +1670,8 @@ 0.98.7-hadoop1 hadoop1 1.8.8 + org.spark-project.akka + 2.3.4-spark From 835f1380d95a345208b682492f0735155e61a824 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 7 Jun 2015 15:30:37 +0100 Subject: [PATCH 224/254] [DOC] [TYPO] Fix typo in standalone deploy scripts description Author: Yijie Shen Closes #6691 from yijieshen/patch-2 and squashes the following commits: b40a4b0 [Yijie Shen] [DOC][TYPO] Fix typo in standalone deploy scripts description --- docs/spark-standalone.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 0eed9adacf123..12d7d6e159bea 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -77,7 +77,7 @@ Note, the master machine accesses each of the worker machines via ssh. By defaul If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. -Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: +Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`: - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. From d6d601a07b17069d41eb4114bd5f7ab2c106720d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 7 Jun 2015 10:52:02 -0700 Subject: [PATCH 225/254] [SPARK-8004][SQL] Quote identifier in JDBC data source. This is a follow-up patch to #6577 to replace columnEnclosing to quoteIdentifier. I also did some minor cleanup to the JdbcDialect file. Author: Reynold Xin Closes #6689 from rxin/jdbc-quote and squashes the following commits: bad365f [Reynold Xin] Fixed test compilation... e39e14e [Reynold Xin] Fixed compilation. db9a8e0 [Reynold Xin] [SPARK-8004][SQL] Quote identifier in JDBC data source. --- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 4 +-- .../apache/spark/sql/jdbc/JdbcDialects.scala | 34 +++++++++---------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 ++-- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 2930f7bb4cae1..db68b9c86db1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -212,13 +212,13 @@ private[sql] object JDBCRDD extends Logging { filters: Array[Filter], parts: Array[Partition]): RDD[Row] = { val dialect = JdbcDialects.get(url) - val enclosedColumns = requiredColumns.map(dialect.columnEnclosing(_)) + val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, getConnector(driver, url, properties), pruneSchema(schema, requiredColumns), fqTable, - enclosedColumns, + quotedColumns, filters, parts, properties) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 04052f80f5e78..8849fc2f1f0ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.jdbc +import java.sql.Types + import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi -import java.sql.Types - /** * :: DeveloperApi :: * A database type definition coupled with the jdbc type needed to send null @@ -82,11 +82,10 @@ abstract class JdbcDialect { def getJDBCType(dt: DataType): Option[JdbcType] = None /** - * Enclose column name - * @param colName The coulmn name - * @return Enclosed column name + * Quotes the identifier. This is used to put quotes around the identifier in case the column + * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). */ - def columnEnclosing(colName: String): String = { + def quoteIdentifier(colName: String): String = { s""""$colName"""" } } @@ -150,18 +149,19 @@ object JdbcDialects { @DeveloperApi class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { - require(!dialects.isEmpty) + require(dialects.nonEmpty) - def canHandle(url : String): Boolean = + override def canHandle(url : String): Boolean = dialects.map(_.canHandle(url)).reduce(_ && _) override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = - dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption - - override def getJDBCType(dt: DataType): Option[JdbcType] = - dialects.map(_.getJDBCType(dt)).flatten.headOption + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } } /** @@ -170,7 +170,7 @@ class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { */ @DeveloperApi case object NoopDialect extends JdbcDialect { - def canHandle(url : String): Boolean = true + override def canHandle(url : String): Boolean = true } /** @@ -179,7 +179,7 @@ case object NoopDialect extends JdbcDialect { */ @DeveloperApi case object PostgresDialect extends JdbcDialect { - def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { @@ -205,7 +205,7 @@ case object PostgresDialect extends JdbcDialect { */ @DeveloperApi case object MySQLDialect extends JdbcDialect { - def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { @@ -218,7 +218,7 @@ case object MySQLDialect extends JdbcDialect { } else None } - override def columnEnclosing(colName: String): String = { + override def quoteIdentifier(colName: String): String = { s"`$colName`" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index a228543953536..49d348c3ed21b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -410,13 +410,13 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(JdbcDialects.get("test.invalid") == NoopDialect) } - test("Enclosing column names by jdbc dialect") { + test("quote column names by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val columns = Seq("abc", "key") - val MySQLColumns = columns.map(MySQL.columnEnclosing(_)) - val PostgresColumns = columns.map(Postgres.columnEnclosing(_)) + val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) + val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) assert(MySQLColumns === Seq("`abc`", "`key`")) assert(PostgresColumns === Seq(""""abc"""", """"key"""")) } From db81b9d89f62f18a3c4c2d9bced8486bfdea54a2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 7 Jun 2015 11:07:19 -0700 Subject: [PATCH 226/254] [SPARK-7952][SQL] use internal Decimal instead of java.math.BigDecimal This PR fixes a bug introduced in https://github.com/apache/spark/pull/6505. Decimal literal's value is not `java.math.BigDecimal`, but Spark SQL internal type: `Decimal`. Author: Wenchen Fan Closes #6574 from cloud-fan/fix and squashes the following commits: b0e3549 [Wenchen Fan] rename to BooleanEquality 1987b37 [Wenchen Fan] use Decimal instead of java.math.BigDecimal f93c420 [Wenchen Fan] compare literal --- .../catalyst/analysis/HiveTypeCoercion.scala | 40 +++++++++---------- .../analysis/HiveTypeCoercionSuite.scala | 24 ++++++++++- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 9b8a08a88dcb0..a42ffce0d26fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -87,7 +87,7 @@ trait HiveTypeCoercion { WidenTypes :: PromoteStrings :: DecimalPrecision :: - BooleanEqualization :: + BooleanEquality :: StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: @@ -479,9 +479,9 @@ trait HiveTypeCoercion { /** * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ - object BooleanEqualization extends Rule[LogicalPlan] { - private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1)) - private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0)) + object BooleanEquality extends Rule[LogicalPlan] { + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1)) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0)) private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { CaseKeyWhen(numericExpr, Seq( @@ -512,22 +512,22 @@ trait HiveTypeCoercion { // all other cases are considered as false. // We may simplify the expression if one side is literal numeric values - case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => left - case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => Not(left) - case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) - if trueValues.contains(value) => right - case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) - if falseValues.contains(value) => Not(right) - case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => And(IsNotNull(left), left) - case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => And(IsNotNull(left), Not(left)) - case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) - if trueValues.contains(value) => And(IsNotNull(right), right) - case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) - if falseValues.contains(value) => And(IsNotNull(right), Not(right)) + case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => bool + case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(bool) + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => bool + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => Not(bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) case EqualTo(left @ BooleanType(), right @ NumericType()) => transform(left , right) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 0df446636ea89..9977f7af00f6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -147,7 +147,8 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion simplification for equal to") { - val be = new HiveTypeCoercion {}.BooleanEqualization + val be = new HiveTypeCoercion {}.BooleanEquality + ruleTest(be, EqualTo(Literal(true), Literal(1)), Literal(true) @@ -164,5 +165,26 @@ class HiveTypeCoercionSuite extends PlanTest { EqualNullSafe(Literal(true), Literal(0)), And(IsNotNull(Literal(true)), Not(Literal(true))) ) + + ruleTest(be, + EqualTo(Literal(true), Literal(1L)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(BigDecimal(0)), Literal(true)), + Not(Literal(true)) + ) + ruleTest(be, + EqualTo(Literal(Decimal(1)), Literal(true)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)), + Literal(true) + ) } } From e84815dc333a69368a48e0152f02934980768a14 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 7 Jun 2015 20:18:13 +0100 Subject: [PATCH 227/254] [SPARK-7733] [CORE] [BUILD] Update build, code to use Java 7 for 1.5.0+ Update build to use Java 7, and remove some comments and special-case support for Java 6. Author: Sean Owen Closes #6265 from srowen/SPARK-7733 and squashes the following commits: 59bda4e [Sean Owen] Update build to use Java 7, and remove some comments and special-case support for Java 6 --- bin/spark-class | 18 ------------------ .../spark/util/MutableURLClassLoader.scala | 4 ++-- .../scala/org/apache/spark/util/Utils.scala | 3 +-- .../spark/util/collection/SorterSuite.scala | 3 --- docs/building-spark.md | 6 +----- docs/index.md | 2 +- docs/programming-guide.md | 2 +- make-distribution.sh | 16 ---------------- pom.xml | 2 +- .../apache/spark/unsafe/PlatformDependent.java | 3 +-- 10 files changed, 8 insertions(+), 51 deletions(-) diff --git a/bin/spark-class b/bin/spark-class index 7bb1afe4b44f5..2b59e5df5736f 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -58,24 +58,6 @@ fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" -# Verify that versions of java used to build the jars and run Spark are compatible -if [ -n "$JAVA_HOME" ]; then - JAR_CMD="$JAVA_HOME/bin/jar" -else - JAR_CMD="jar" -fi - -if [ $(command -v "$JAR_CMD") ] ; then - jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1) - if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then - echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 - echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 - echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 - echo "or build Spark with Java 6." 1>&2 - exit 1 - fi -fi - LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" # Add the launcher build dir to the classpath if requested. diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 1e0ba5c28754a..169489df6c1ea 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -52,8 +52,8 @@ private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoa * Used to implement fine-grained class loading locks similar to what is done by Java 7. This * prevents deadlock issues when using non-hierarchical class loaders. * - * Note that due to Java 6 compatibility (and some issues with implementing class loaders in - * Scala), Java 7's `ClassLoader.registerAsParallelCapable` method is not called. + * Note that due to some issues with implementing class loaders in + * Scala, Java 7's `ClassLoader.registerAsParallelCapable` method is not called. */ private val locks = new ConcurrentHashMap[String, Object]() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5f132410540fd..153ece6224a6d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1295,8 +1295,7 @@ private[spark] object Utils extends Logging { } catch { case t: Throwable => if (originalThrowable != null) { - // We could do originalThrowable.addSuppressed(t), but it's - // not available in JDK 1.6. + originalThrowable.addSuppressed(t) logWarning(s"Suppressing exception in finally: " + t.getMessage, t) throw originalThrowable } else { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 72fd6daba8de0..b2f5d9009ee5d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -103,9 +103,6 @@ class SorterSuite extends SparkFunSuite { * has the keys and values alternating. The basic Java sorts work only on the keys, so the * real Java solution is to make Tuple2s to store the keys and values and sort an array of * those, while the Sorter approach can work directly on the input data format. - * - * Note that the Java implementation varies tremendously between Java 6 and Java 7, when - * the Java sort changed from merge sort to TimSort. */ ignore("Sorter benchmark for key-value pairs") { val numElements = 25000000 // 25 mil diff --git a/docs/building-spark.md b/docs/building-spark.md index 78cb9086f95e8..2128fdffecc05 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -7,11 +7,7 @@ redirect_from: "building-with-maven.html" * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven requires Maven 3.0.4 or newer and Java 6+. - -**Note:** Building Spark with Java 7 or later can create JAR files that may not be -readable with early versions of Java 6, due to the large number of files in the JAR -archive. Build with Java 6 if this is an issue for your deployment. +Building Spark using Maven requires Maven 3.0.4 or newer and Java 7+. # Building with `build/mvn` diff --git a/docs/index.md b/docs/index.md index fac071da81e60..7939657915fc9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 6+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses +Spark runs on Java 7+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 10f474f237bfa..d5ff416fe89a4 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -54,7 +54,7 @@ import org.apache.spark.SparkConf
    -Spark {{site.SPARK_VERSION}} works with Java 6 and higher. If you are using Java 8, Spark supports +Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java 8, Spark supports [lambda expressions](http://docs.oracle.com/javase/tutorial/java/javaOO/lambdaexpressions.html) for concisely writing functions, otherwise you can use the classes in the [org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package. diff --git a/make-distribution.sh b/make-distribution.sh index a2b0c431fb4d0..9f063da3a16c0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -141,22 +141,6 @@ SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hi # because we use "set -o pipefail" echo -n) -JAVA_CMD="$JAVA_HOME"/bin/java -JAVA_VERSION=$("$JAVA_CMD" -version 2>&1) -if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then - echo "***NOTE***: JAVA_HOME is not set to a JDK 6 installation. The resulting" - echo " distribution may not work well with PySpark and will not run" - echo " with Java 6 (See SPARK-1703 and SPARK-1911)." - echo " This test can be disabled by adding --skip-java-test." - echo "Output from 'java -version' was:" - echo "$JAVA_VERSION" - read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then - echo "Okay, exiting." - exit 1 - fi -fi - if [ "$NAME" == "none" ]; then NAME=$SPARK_HADOOP_VERSION fi diff --git a/pom.xml b/pom.xml index e65448e4b2325..67b6375f576d3 100644 --- a/pom.xml +++ b/pom.xml @@ -116,7 +116,7 @@ UTF-8 com.typesafe.akka 2.3.11 - 1.6 + 1.7 spark 0.21.1 shaded-protobuf diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index 24b2892098059..192c6714b2406 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -25,8 +25,7 @@ public final class PlatformDependent { /** * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us aovid accidental use of deprecated methods or methods that - * aren't present in Java 6. + * this package. This also lets us avoid accidental use of deprecated methods. */ public static final class UNSAFE { From b127ff8a0c5fb704da574d101a2d0e27ac5f463a Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Sun, 7 Jun 2015 21:42:45 +0100 Subject: [PATCH 228/254] [SPARK-2808] [STREAMING] [KAFKA] cleanup tests from see if requiring producer acks eliminates the need for waitUntilLeaderOffset calls in tests Author: cody koeninger Closes #5921 from koeninger/kafka-0.8.2-test-cleanup and squashes the following commits: 1e89dc8 [cody koeninger] Merge branch 'master' into kafka-0.8.2-test-cleanup 4662828 [cody koeninger] [Streaming][Kafka] filter mima issue for removal of method from private test class af1e083 [cody koeninger] Merge branch 'master' into kafka-0.8.2-test-cleanup 4298ac2 [cody koeninger] [Streaming][Kafka] update comment to trigger jenkins attempt 1274afb [cody koeninger] [Streaming][Kafka] see if requiring producer acks eliminates the need for waitUntilLeaderOffset calls in tests --- .../spark/streaming/kafka/KafkaTestUtils.scala | 17 ++--------------- .../streaming/kafka/JavaKafkaRDDSuite.java | 3 --- .../spark/streaming/kafka/KafkaRDDSuite.scala | 4 ---- project/MimaExcludes.scala | 3 +++ python/pyspark/streaming/tests.py | 5 ----- 5 files changed, 5 insertions(+), 27 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 6dc4e9517d5a4..b608b75952721 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging { val props = new Properties() props.put("metadata.broker.list", brokerAddress) props.put("serializer.class", classOf[StringEncoder].getName) + // wait for all in-sync replicas to ack sends + props.put("request.required.acks", "-1") props } @@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging { tryAgain(1) } - /** Wait until the leader offset for the given topic/partition equals the specified offset */ - def waitUntilLeaderOffset( - topic: String, - partition: Int, - offset: Long): Unit = { - eventually(Time(10000), Time(100)) { - val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress)) - val tp = TopicAndPartition(topic, partition) - val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset - assert( - llo == offset, - s"$topic $partition $offset not reached after timeout") - } - } - private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index 5cf379635354f..a9dc6e50613ca 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -72,9 +72,6 @@ public void testKafkaRDD() throws InterruptedException { HashMap kafkaParams = new HashMap(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); - kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length); - kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length); - OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), OffsetRange.create(topic2, 0, 0, 1) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 054487269a935..d5baf5fd89994 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -61,8 +61,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt}") - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size) - val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( @@ -86,7 +84,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { // this is the "lots of messages" case kafkaTestUtils.sendMessages(topic, sent) val sentCount = sent.values.sum - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount) // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) @@ -113,7 +110,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { val sentOnlyOne = Map("d" -> 1) kafkaTestUtils.sendMessages(topic, sentOnlyOne) - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1) assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 73e4bfd78e577..8a93ca2999510 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -47,6 +47,9 @@ object MimaExcludes { // Mima false positive (was a private[spark] class) ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.util.collection.PairIterator"), + // Removing a testing method from a private class + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution") ) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 46cb18b2e8ef9..57049beea4dba 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -615,7 +615,6 @@ def test_kafka_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, @@ -631,7 +630,6 @@ def test_kafka_direct_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) self._validateStreamResult(sendData, stream) @@ -646,7 +644,6 @@ def test_kafka_direct_stream_from_offset(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) self._validateStreamResult(sendData, stream) @@ -661,7 +658,6 @@ def test_kafka_rdd(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self._validateRddResult(sendData, rdd) @@ -677,7 +673,6 @@ def test_kafka_rdd_with_leaders(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) From 5e7b6b67bed9cd0d8c7d4e78df666b807e8f7ef2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 7 Jun 2015 14:11:20 -0700 Subject: [PATCH 229/254] [SPARK-8117] [SQL] Push codegen implementation into each Expression This PR move codegen implementation of expressions into Expression class itself, make it easy to manage. It introduces two APIs in Expression: ``` def gen(ctx: CodeGenContext): GeneratedExpressionCode def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code ``` gen(ctx) will call genSource(ctx, ev) to generate Java source code for the current expression. A expression needs to override genSource(). Here are the types: ``` type Term String type Code String /** * Java source for evaluating an [[Expression]] given a [[Row]] of input. */ case class GeneratedExpressionCode(var code: Code, nullTerm: Term, primitiveTerm: Term, objectTerm: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported * by codegen, then they are evaluated directly. The unsupported expression is appended at the * end of `references`, the position of it is kept in the code, used to access and evaluate it. */ class CodeGenContext { /** * Holding all the expressions those do not support codegen, will be evaluated directly. */ val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]() } ``` This is basically #6660, but fixed style violation and compilation failure. Author: Davies Liu Author: Reynold Xin Closes #6690 from rxin/codegen and squashes the following commits: e1368c2 [Reynold Xin] Fixed tests. 73db80e [Reynold Xin] Fixed compilation failure. 19d6435 [Reynold Xin] Fixed style violation. 9adaeaf [Davies Liu] address comments f42c732 [Davies Liu] improve coverage and tests bad6828 [Davies Liu] address comments e03edaa [Davies Liu] consts fold 86fac2c [Davies Liu] fix style 02262c9 [Davies Liu] address comments b5d3617 [Davies Liu] Merge pull request #5 from rxin/codegen 48c454f [Reynold Xin] Some code gen update. 2344bc0 [Davies Liu] fix test 12ff88a [Davies Liu] fix build c5fb514 [Davies Liu] rename 8c6d82d [Davies Liu] update docs b145047 [Davies Liu] fix style e57959d [Davies Liu] add type alias 3ff25f8 [Davies Liu] refactor 593d617 [Davies Liu] pushing codegen into Expression --- .../catalyst/expressions/BoundAttribute.scala | 9 + .../spark/sql/catalyst/expressions/Cast.scala | 42 + .../sql/catalyst/expressions/Expression.scala | 100 +++ .../sql/catalyst/expressions/arithmetic.scala | 161 +++- .../expressions/codegen/CodeGenerator.scala | 750 ++++-------------- .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateOrdering.scala | 20 +- .../codegen/GeneratePredicate.scala | 4 +- .../codegen/GenerateProjection.scala | 26 +- .../expressions/codegen/package.scala | 3 + .../expressions/decimalFunctions.scala | 19 + .../sql/catalyst/expressions/literals.scala | 54 ++ .../expressions/mathfuncs/binary.scala | 24 +- .../expressions/mathfuncs/unary.scala | 30 +- .../expressions/namedExpressions.scala | 6 +- .../catalyst/expressions/nullFunctions.scala | 55 ++ .../sql/catalyst/expressions/predicates.scala | 192 ++++- .../spark/sql/catalyst/expressions/sets.scala | 54 +- .../expressions/stringOperations.scala | 18 + .../ExpressionEvaluationSuite.scala | 87 +- .../GeneratedEvaluationSuite.scala | 27 +- .../GeneratedMutableEvaluationSuite.scala | 61 -- .../ParquetPartitionDiscoverySuite.scala | 6 +- 23 files changed, 1036 insertions(+), 718 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1ffc95c676f6f..005de3166095f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + s""" + boolean ${ev.isNull} = i.isNullAt($ordinal); + ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? + ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + """ + } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 21adac144112e..5f76a512679a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // TODO(cg): Add support for more data types. + (child.dataType, dataType) match { + + case (BinaryType, StringType) => + defineCodeGen (ctx, ev, c => + s"new ${ctx.stringType}().set($c)") + case (DateType, StringType) => + defineCodeGen(ctx, ev, c => + s"""new ${ctx.stringType}().set( + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case (TimestampType, StringType) => + super.genCode(ctx, ev) + case (_, StringType) => + defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + + // fallback for DecimalType, this must be before other numeric types + case (_, dt: DecimalType) => + super.genCode(ctx, ev) + + case (BooleanType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + case (dt: DecimalType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c.isZero()") + case (dt: NumericType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c != 0") + + case (_: DecimalType, IntegerType) => + defineCodeGen(ctx, ev, c => s"($c).toInt()") + case (_: DecimalType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + case (_: NumericType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") + + case other => + super.genCode(ctx, ev) + } + } } object Cast { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b2b9d1a5e1581..0ed576b3d5870 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -51,6 +52,44 @@ abstract class Expression extends TreeNode[Expression] { /** Returns the result of evaluating this expression on a given input Row */ def eval(input: Row = null): Any + /** + * Returns an [[GeneratedExpressionCode]], which contains Java source code that + * can be used to generate the result of evaluating the expression on an input row. + * + * @param ctx a [[CodeGenContext]] + * @return [[GeneratedExpressionCode]] + */ + def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + ve + } + + /** + * Returns Java source code that can be compiled to evaluate this expression. + * The default behavior is to call the eval method of the expression. Concrete expression + * implementations should override this to do actual code generation. + * + * @param ctx a [[CodeGenContext]] + * @param ev an [[GeneratedExpressionCode]] with unique terms. + * @return Java source code + */ + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + ctx.references += this + val objectTerm = ctx.freshName("obj") + s""" + /* expression: ${this} */ + Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = ${objectTerm} == null; + ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm}; + } + """ + } + /** * Returns `true` if this expression and all its children have been resolved to a specific schema * and input data types checking passed, and `false` if it still contains any unresolved @@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"($left $symbol $right)" + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (Term, Term) => Code): String = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (left.dataType != right.dataType) { + // log.warn(s"${left.dataType} != ${right.dataType}") + } + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(eval1.primitive, eval2.primitive) + + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if(!${eval2.isNull}) { + ${ev.primitive} = $resultCode; + } else { + ${ev.isNull} = true; + } + } + """ + } } private[sql] object BinaryExpression { @@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + * + * As an example, the following does a boolean inversion (i.e. NOT). + * {{{ + * defineCodeGen(ctx, ev, c => s"!($c)") + * }}} + * + * @param f function that accepts a variable name and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: Term => Code): Code = { + val eval = child.gen(ctx) + // reuse the previous isNull + ev.isNull = eval.isNull + eval.code + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${f(eval.primitive)}; + } + """ + } } // TODO Semantically we probably not need GroupExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a3770f998d94d..3ac7c92dcd009 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -49,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") + } + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) } @@ -67,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { if (value < 0) null else math.sqrt(value) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} < 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); + } + } + """ + } } /** @@ -86,6 +107,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = "" + override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -114,6 +138,17 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + defineCodeGen(ctx, ev, (eval1, eval2) => + s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryArithmetics must override either eval or evalInternal") } @@ -124,6 +159,7 @@ private[sql] object BinaryArithmetic { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -138,6 +174,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" + override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -152,6 +189,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" + override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -166,6 +204,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" + override def decimalMethod: String = "$divide" + override def nullable: Boolean = true override lazy val resolved = @@ -192,10 +232,40 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } } + + /** + * Special case handling due to division by 0 => null. + */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" + } else { + s"${eval2.primitive} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + } + """ + } } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" + override def decimalMethod: String = "reminder" + override def nullable: Boolean = true override lazy val resolved = @@ -222,6 +292,34 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } } + + /** + * Special case handling for x % 0 ==> null. + */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" + } else { + s"${eval2.primitive} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + } + """ + } } /** @@ -271,7 +369,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } /** - * A function that calculates bitwise xor(^) of two numbers. + * A function that calculates bitwise xor of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" @@ -313,6 +411,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)") + } + protected override def evalInternal(evalE: Any) = not(evalE) } @@ -340,6 +442,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + if (ctx.isNativeType(left.dataType)) { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if (${eval1.primitive} > ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ + } else { + super.genCode(ctx, ev) + } + } override def toString: String = s"MaxOf($left, $right)" } @@ -367,5 +496,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + if (ctx.isNativeType(left.dataType)) { + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if (${eval1.primitive} < ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ + } else { + super.genCode(ctx, ev) + } + } + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cd604121b7dd9..c8d0aaf79f5f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -32,6 +31,157 @@ import org.apache.spark.sql.types._ class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] +/** + * Java source for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param isNull A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitive A term for a possible primitive value of the result of the evaluation. Not + * valid if `isNull` is set to `true`. + */ +case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term) + +/** + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. + */ +class CodeGenContext { + + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + + val stringType: String = classOf[UTF8String].getName + val decimalType: String = classOf[Decimal].getName + + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) + */ + def freshName(prefix: String): Term = { + s"$prefix${curId.getAndIncrement}" + } + + /** + * Return the code to access a column for given DataType + */ + def getColumn(dataType: DataType, ordinal: Int): Code = { + if (isNativeType(dataType)) { + s"i.${accessorForType(dataType)}($ordinal)" + } else { + s"(${boxedType(dataType)})i.apply($ordinal)" + } + } + + /** + * Return the code to update a column in Row for given DataType + */ + def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = { + if (isNativeType(dataType)) { + s"${mutatorForType(dataType)}($ordinal, $value)" + } else { + s"update($ordinal, $value)" + } + } + + /** + * Return the name of accessor in Row for a DataType + */ + def accessorForType(dt: DataType): Term = dt match { + case IntegerType => "getInt" + case other => s"get${boxedType(dt)}" + } + + /** + * Return the name of mutator in Row for a DataType + */ + def mutatorForType(dt: DataType): Term = dt match { + case IntegerType => "setInt" + case other => s"set${boxedType(dt)}" + } + + /** + * Return the Java type for a DataType + */ + def javaType(dt: DataType): Term = dt match { + case IntegerType => "int" + case LongType => "long" + case ShortType => "short" + case ByteType => "byte" + case DoubleType => "double" + case FloatType => "float" + case BooleanType => "boolean" + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "int" + case TimestampType => "java.sql.Timestamp" + case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName + case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case _ => "Object" + } + + /** + * Return the boxed type in Java + */ + def boxedType(dt: DataType): Term = dt match { + case IntegerType => "Integer" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case DateType => "Integer" + case _ => javaType(dt) + } + + /** + * Return the representation of default value for given DataType + */ + def defaultValue(dt: DataType): Term = dt match { + case BooleanType => "false" + case FloatType => "-1.0f" + case ShortType => "(short)-1" + case LongType => "-1L" + case ByteType => "(byte)-1" + case DoubleType => "-1.0" + case IntegerType => "-1" + case DateType => "-1" + case _ => "null" + } + + /** + * Returns a function to generate equal expression in Java + */ + def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { + case BinaryType => { case (eval1, eval2) => + s"java.util.Arrays.equals($eval1, $eval2)" } + case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => + { case (eval1, eval2) => s"$eval1 == $eval2" } + case other => + { case (eval1, eval2) => s"$eval1.equals($eval2)" } + } + + /** + * List of data types that have special accessors and setters in [[Row]]. + */ + val nativeTypes = + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) + + /** + * Returns true if the data type has a special accessor and setter in [[Row]]. + */ + def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) +} + /** * A base class for generators of byte code to perform expression evaluation. Includes a set of * helpers for referring to Catalyst types and building trees that perform evaluation of individual @@ -39,14 +189,9 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val rowType = classOf[Row].getName - protected val stringType = classOf[UTF8String].getName - protected val decimalType = classOf[Decimal].getName - protected val exprType = classOf[Expression].getName - protected val mutableRowType = classOf[MutableRow].getName - protected val genericMutableRowType = classOf[GenericMutableRow].getName - - private val curId = new java.util.concurrent.atomic.AtomicInteger() + protected val exprType: String = classOf[Expression].getName + protected val mutableRowType: String = classOf[MutableRow].getName + protected val genericMutableRowType: String = classOf[GenericMutableRow].getName /** * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. @@ -75,10 +220,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ protected def compile(code: String): Class[_] = { val startTime = System.nanoTime() - val clazz = new ClassBodyEvaluator(code).getClazz() + val clazz = try { + new ClassBodyEvaluator(code).getClazz() + } catch { + case e: Exception => + logError(s"failed to compile:\n $code", e) + throw e + } val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") + logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") clazz } @@ -112,586 +263,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Generates the requested evaluator given already bound expression(s). */ def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) - /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) - */ - protected def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" - } - - /** - * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. - * - * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated - * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `true`. - * @param objectTerm A possibly boxed version of the result of evaluating this expression. - */ - protected case class EvaluatedExpression( - code: String, - nullTerm: String, - primitiveTerm: String, - objectTerm: String) - - /** - * A context for codegen, which is used to bookkeeping the expressions those are not supported - * by codegen, then they are evaluated directly. The unsupported expression is appended at the - * end of `references`, the position of it is kept in the code, used to access and evaluate it. - */ - protected class CodeGenContext { - /** - * Holding all the expressions those do not support codegen, will be evaluated directly. - */ - val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - } - /** * Create a new codegen context for expression evaluator, used to store those * expressions that don't support codegen */ def newCodeGenContext(): CodeGenContext = { - new CodeGenContext() - } - - /** - * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that - * can be used to determine the result of evaluating the expression on an input row. - */ - def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { - val primitiveTerm = freshName("primitiveTerm") - val nullTerm = freshName("nullTerm") - val objectTerm = freshName("objectTerm") - - implicit class Evaluate1(e: Expression) { - def castOrNull(f: String => String, dataType: DataType): String = { - val eval = expressionEvaluator(e, ctx) - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - if (!$nullTerm) { - $primitiveTerm = ${f(eval.primitiveTerm)}; - } - """ - } - } - - implicit class Evaluate2(expressions: (Expression, Expression)) { - - /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f a function from two primitive term names to a tree that evaluates them. - */ - def evaluate(f: (String, String) => String): String = - evaluateAs(expressions._1.dataType)(f) - - def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (expressions._1.dataType != expressions._2.dataType) { - log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") - } - - val eval1 = expressionEvaluator(expressions._1, ctx) - val eval2 = expressionEvaluator(expressions._2, ctx) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; - ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; - if(!$nullTerm) { - $primitiveTerm = (${primitiveForType(resultType)})($resultCode); - } - """ - } - } - - val inputTuple = "i" - - // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, String] = { - case b @ BoundReference(ordinal, dataType, nullable) => - s""" - final boolean $nullTerm = $inputTuple.isNullAt($ordinal); - final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? - ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); - """ - - case expressions.Literal(null, dataType) => - s""" - final boolean $nullTerm = true; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - """ - - case expressions.Literal(value: UTF8String, StringType) => - val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" - s""" - final boolean $nullTerm = false; - ${stringType} $primitiveTerm = - new ${stringType}().set(${arr}); - """ - - case expressions.Literal(value, FloatType) => - s""" - final boolean $nullTerm = false; - float $primitiveTerm = ${value}f; - """ - - case expressions.Literal(value, dt @ DecimalType()) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); - """ - - case expressions.Literal(value, dataType) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dataType)} $primitiveTerm = $value; - """ - - case Cast(child @ BinaryType(), StringType) => - child.castOrNull(c => - s"new ${stringType}().set($c)", - StringType) - - case Cast(child @ DateType(), StringType) => - child.castOrNull(c => - s"""new ${stringType}().set( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) - - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) - - case Cast(child @ DecimalType(), IntegerType) => - child.castOrNull(c => s"($c).toInt()", IntegerType) - - case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) - - case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) - - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. - case Cast(e, StringType) if e.dataType != TimestampType => - e.castOrNull(c => - s"new ${stringType}().set(String.valueOf($c))", - StringType) - - case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => - (e1, e2).evaluateAs (BooleanType) { - case (eval1, eval2) => - s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" - } - - case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } - - case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } - case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } - case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } - case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } - - case And(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { - } else { - ${eval2.code} - if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true; - } else { - $nullTerm = true; - } - } - """ - - case Or(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true; - } else { - ${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true; - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false; - } else { - $nullTerm = true; - } - } - """ - - case Not(child) => - // Uh, bad function name... - child.castOrNull(c => s"!$c", BooleanType) - - case Add(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } - case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } - case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } - case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = null; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); - } - """ - case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); - } - """ - - case Add(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } - case Subtract(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } - case Multiply(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } - case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; - } - """ - case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; - } - """ - - case IsNotNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = !${eval.nullTerm}; - """ - - case IsNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = ${eval.nullTerm}; - """ - - case e @ Coalesce(children) => - s""" - boolean $nullTerm = true; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - """ + - children.map { c => - val eval = expressionEvaluator(c, ctx) - s""" - if($nullTerm) { - ${eval.code} - if(!${eval.nullTerm}) { - $nullTerm = false; - $primitiveTerm = ${eval.primitiveTerm}; - } - } - """ - }.mkString("\n") - - case e @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition, ctx) - val trueEval = expressionEvaluator(trueValue, ctx) - val falseEval = expressionEvaluator(falseValue, ctx) - - s""" - boolean $nullTerm = false; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - ${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ${trueEval.code} - $nullTerm = ${trueEval.nullTerm}; - $primitiveTerm = ${trueEval.primitiveTerm}; - } else { - ${falseEval.code} - $nullTerm = ${falseEval.nullTerm}; - $primitiveTerm = ${falseEval.primitiveTerm}; - } - """ - - case NewSet(elementType) => - s""" - boolean $nullTerm = false; - ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); - """ - - case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item, ctx) - val setEval = expressionEvaluator(set, ctx) - - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - itemEval.code + setEval.code + - s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); - } - boolean $nullTerm = false; - ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; - """ - - case CombineSets(left, right) => - val leftEval = expressionEvaluator(left, ctx) - val rightEval = expressionEvaluator(right, ctx) - - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - leftEval.code + rightEval.code + - s""" - boolean $nullTerm = false; - ${htype} $primitiveTerm = - (${htype})${leftEval.primitiveTerm}; - $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); - """ - - case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case UnscaledValue(child) => - val childEval = expressionEvaluator(child, ctx) - - childEval.code + - s""" - boolean $nullTerm = ${childEval.nullTerm}; - long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong(); - """ - - case MakeDecimal(child, precision, scale) => - val eval = expressionEvaluator(child, ctx) - - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; - - if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal(); - $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); - $nullTerm = $primitiveTerm == null; - } - """ - } - - // If there was no match in the partial function above, we fall back on calling the interpreted - // expression evaluator. - val code: String = - primitiveEvaluation.lift.apply(e).getOrElse { - logError(s"No rules to generate $e") - ctx.references += e - s""" - /* expression: ${e} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean $nullTerm = $objectTerm == null; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; - """ - } - - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) - } - - protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { - dataType match { - case StringType => s"(${stringType})$inputRow.apply($ordinal)" - case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" - } - } - - protected def setColumn( - destinationRow: String, - dataType: DataType, - ordinal: Int, - value: String): String = { - dataType match { - case StringType => s"$destinationRow.update($ordinal, $value)" - case dt: DataType if isNativeType(dt) => - s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => s"$destinationRow.update($ordinal, $value)" - } - } - - protected def accessorForType(dt: DataType) = dt match { - case IntegerType => "getInt" - case other => s"get${termForType(dt)}" - } - - protected def mutatorForType(dt: DataType) = dt match { - case IntegerType => "setInt" - case other => s"set${termForType(dt)}" - } - - protected def hashSetForType(dt: DataType): String = dt match { - case IntegerType => classOf[IntegerHashSet].getName - case LongType => classOf[LongHashSet].getName - case unsupportedType => - sys.error(s"Code generation not support for hashset of type $unsupportedType") + new CodeGenContext } - - protected def primitiveForType(dt: DataType): String = dt match { - case IntegerType => "int" - case LongType => "long" - case ShortType => "short" - case ByteType => "byte" - case DoubleType => "double" - case FloatType => "float" - case BooleanType => "boolean" - case dt: DecimalType => decimalType - case BinaryType => "byte[]" - case StringType => stringType - case DateType => "int" - case TimestampType => "java.sql.Timestamp" - case _ => "Object" - } - - protected def defaultPrimitive(dt: DataType): String = dt match { - case BooleanType => "false" - case FloatType => "-1.0f" - case ShortType => "-1" - case LongType => "-1" - case ByteType => "-1" - case DoubleType => "-1.0" - case IntegerType => "-1" - case DateType => "-1" - case dt: DecimalType => "null" - case StringType => "null" - case _ => "null" - } - - protected def termForType(dt: DataType): String = dt match { - case IntegerType => "Integer" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case dt: DecimalType => decimalType - case BinaryType => "byte[]" - case StringType => stringType - case DateType => "Integer" - case TimestampType => "java.sql.Timestamp" - case _ => "Object" - } - - /** - * List of data types that have special accessors and setters in [[Row]]. - */ - protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) - - /** - * Returns true if the data type has a special accessor and setter in [[Row]]. - */ - protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 638b53fe0fe2f..e5ee2accd8a84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -37,13 +37,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = expressionEvaluator(e, ctx) + val evaluationCode = e.gen(ctx) evaluationCode.code + s""" - if(${evaluationCode.nullTerm}) + if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0ff840dab393c..36e155d164a40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -52,15 +52,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val ctx = newCodeGenContext() val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child, ctx) - val evalB = expressionEvaluator(order.child, ctx) + val evalA = order.child.gen(ctx) + val evalB = order.child.gen(ctx) val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => s""" { - byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; - byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; + byte[] x = ${if (asc) evalA.primitive else evalB.primitive}; + byte[] y = ${if (!asc) evalB.primitive else evalA.primitive}; int j = 0; while (j < x.length && j < y.length) { if (x[j] != y[j]) return x[j] - y[j]; @@ -73,8 +73,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _: NumericType => s""" - if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) { - if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { + if (${evalA.primitive} != ${evalB.primitive}) { + if (${evalA.primitive} > ${evalB.primitive}) { return ${if (asc) "1" else "-1"}; } else { return ${if (asc) "-1" else "1"}; @@ -82,7 +82,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _ => s""" - int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); + int comp = ${evalA.primitive}.compare(${evalB.primitive}); if (comp != 0) { return ${if (asc) "comp" else "-comp"}; }""" @@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit ${evalA.code} i = $b; ${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { + if (${evalA.isNull} && ${evalB.isNull}) { // Nothing - } else if (${evalA.nullTerm}) { + } else if (${evalA.isNull}) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.nullTerm}) { + } else if (${evalB.isNull}) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { $compare diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index fb18769f00da3..4a547b5ce9543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,7 +38,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { protected def create(predicate: Expression): ((Row) => Boolean) = { val ctx = newCodeGenContext() - val eval = expressionEvaluator(predicate, ctx) + val eval = predicate.gen(ctx) val code = s""" import org.apache.spark.sql.Row; @@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { @Override public boolean eval(Row i) { ${eval.code} - return !${eval.nullTerm} && ${eval.primitiveTerm}; + return !${eval.isNull} && ${eval.primitive}; } }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index d5be1fc12e0f0..7caf4aaab88bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -45,19 +45,19 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val ctx = newCodeGenContext() val columns = expressions.zipWithIndex.map { case (e, i) => - s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" + s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" }.mkString("\n ") val initColumns = expressions.zipWithIndex.map { case (e, i) => - val eval = expressionEvaluator(e, ctx) + val eval = e.gen(ctx) s""" { // column$i ${eval.code} - nullBits[$i] = ${eval.nullTerm}; - if(!${eval.nullTerm}) { - c$i = ${eval.primitiveTerm}; + nullBits[$i] = ${eval.isNull}; + if (!${eval.isNull}) { + c$i = ${eval.primitive}; } } """ @@ -68,10 +68,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n ") val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" + s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" }.mkString("\n ") - val specificAccessorFunctions = nativeTypes.map { dataType => + val specificAccessorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: return c$i;" @@ -80,21 +80,21 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { + public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) { if (isNullAt(i)) { - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; } switch (i) { $cases } - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; }""" } else { "" } }.mkString("\n") - val specificMutatorFunctions = nativeTypes.map { dataType => + val specificMutatorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: { c$i = value; return; }" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { + public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) { nullBits[i] = false; switch (i) { $cases @@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case LongType => s"$col ^ ($col >>> 32)" case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" + s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 7f1b12cdd5800..6f9589d20445e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,6 +27,9 @@ import org.apache.spark.util.Utils */ package object codegen { + type Term = String + type Code = String + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 65ba18924afe1..ddfadf314f838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -35,6 +36,10 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { childResult.asInstanceOf[Decimal].toUnscaledLong } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + } } /** Create a Decimal from an unscaled Long value */ @@ -53,4 +58,18 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.decimalType} ${ev.primitive} = null; + + if (!${ev.isNull}) { + ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull( + ${eval.primitive}, $precision, $scale); + ${ev.isNull} = ${ev.primitive} == null; + } + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index d3ca3d9a4b18b..3a9271678bc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -78,7 +79,60 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" + override def equals(other: Any): Boolean = other match { + case o: Literal => + dataType.equals(o.dataType) && + (value == null && null == o.value || value != null && value.equals(o.value)) + case _ => false + } + override def eval(input: Row): Any = value + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // change the isNull and primitive to consts, to inline them + if (value == null) { + ev.isNull = "true" + ev.primitive = ctx.defaultValue(dataType) + "" + } else { + dataType match { + case BooleanType => + ev.isNull = "false" + ev.primitive = value.toString + "" + case FloatType => // This must go before NumericType + val v = value.asInstanceOf[Float] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}f" + "" + } + case DoubleType => // This must go before NumericType + val v = value.asInstanceOf[Double] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}" + "" + } + + case ByteType | ShortType => // This must go before NumericType + ev.isNull = "false" + ev.primitive = s"(${ctx.javaType(dataType)})$value" + "" + case dt: NumericType if !dt.isInstanceOf[DecimalType] => + ev.isNull = "false" + ev.primitive = value.toString + "" + // eval() version may be faster for non-primitive types + case other => + super.genCode(ctx, ev) + } + } + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index db853a2b97fad..88211acd7713c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} import org.apache.spark.sql.types._ @@ -49,6 +50,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + } } case class Atan2(left: Expression, right: Expression) @@ -70,9 +75,26 @@ case class Atan2(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } } case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 41b422346a02d..5563cd94bf86d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} import org.apache.spark.sql.types._ @@ -44,6 +45,23 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) if (result.isNaN) null else result } } + + // name of function in java.lang.Math + def funcName: String = name.toLowerCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } } case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") @@ -72,7 +90,9 @@ case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") -case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { + override def funcName: String = "rint" +} case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @@ -84,6 +104,10 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") -case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { + override def funcName: String = "toDegrees" +} -case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { + override def funcName: String = "toRadians" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00565ec651a59..2e4b9ba678433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ object NamedExpression { @@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: Row): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable override def metadata: Metadata = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5070570b4740d..9ecfb3ccc262f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType @@ -51,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } result } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + + children.map { e => + val eval = e.gen(ctx) + s""" + if (${ev.isNull}) { + ${eval.code} + if (!${eval.isNull}) { + ${ev.isNull} = false; + ${ev.primitive} = ${eval.primitive}; + } + } + """ + }.mkString("\n") + } } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { @@ -61,6 +81,13 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = eval.isNull + eval.code + } + override def toString: String = s"IS NULL $child" } @@ -72,6 +99,13 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def eval(input: Row): Any = { child.eval(input) != null } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = s"(!(${eval.isNull}))" + eval.code + } } /** @@ -95,4 +129,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } numNonNulls >= n } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val nonnull = ctx.freshName("nonnull") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull}) { + $nonnull += 1; + } + } + """ + }.mkString("\n") + s""" + int $nonnull = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $nonnull >= $n; + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 58273b166fe91..1d0f19a400d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -82,6 +83,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex case b: Boolean => !b } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"!($c)") + } } /** @@ -141,6 +146,29 @@ case class And(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + // The result should be `false`, if any of them is `false` whenever the other is null or not. + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + + if (!${eval1.isNull} && !${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && !${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.isNull} = true; + } + } + """ + } } case class Or(left: Expression, right: Expression) @@ -167,6 +195,29 @@ case class Or(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + // The result should be `true`, if any of them is `true` whenever the other is null or not. + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = true; + + if (!${eval1.isNull} && ${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && ${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = false; + } else { + ${ev.isNull} = true; + } + } + """ + } } abstract class BinaryComparison extends BinaryExpression with Predicate { @@ -198,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + left.dataType match { + case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" + }) + case TimestampType => + // java.sql.Timestamp does not have compare() + super.genCode(ctx, ev) + case other => defineCodeGen (ctx, ev, { + (c1, c2) => s"$c1.compare($c2) $symbol 0" + }) + } + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryComparisons must override either eval or evalInternal") } @@ -215,6 +280,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -235,6 +303,17 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp l == r } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) + ev.isNull = "false" + eval1.code + eval2.code + s""" + boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || + (!${eval1.isNull} && $equalCode); + """ + } } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { @@ -309,6 +388,27 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; + } + """ + } + override def toString: String = s"if ($predicate) $trueValue else $falseValue" } @@ -393,6 +493,48 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { return res } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" @@ -444,6 +586,52 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW return res } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val keyEval = key.gen(ctx) + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (${keyEval.isNull} && ${cond.isNull} || + !${keyEval.isNull} && !${cond.isNull} + && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${keyEval.code} + $cases + $other + """ + } + private def equalNullSafe(l: Any, r: Any) = { if (l == null && r == null) { true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index b65bf165f21db..b39349b988389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -60,6 +61,17 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + elementType match { + case IntegerType | LongType => + ev.isNull = "false" + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}(); + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"new Set($dataType)" } @@ -91,6 +103,25 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = "false" + ev.primitive = setEval.primitive + itemEval.code + setEval.code + s""" + if (!${itemEval.isNull} && !${setEval.isNull}) { + (($htype)${setEval.primitive}).add(${itemEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"$set += $item" } @@ -116,14 +147,31 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightValue = iterator.next() leftEval.add(rightValue) } - leftEval - } else { - null } + leftEval } else { null } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = leftEval.isNull + ev.primitive = leftEval.primitive + leftEval.code + rightEval.code + s""" + if (!${leftEval.isNull} && !${rightEval.isNull}) { + ${leftEval.primitive}.union((${htype})${rightEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index c4ef9c30907f1..78adb509b470b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ trait StringRegexExpression extends ExpectsInputTypes { @@ -137,6 +138,10 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") + } } /** @@ -147,6 +152,10 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") + } } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -181,6 +190,9 @@ trait StringComparison extends ExpectsInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") + } } /** @@ -189,6 +201,9 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") + } } /** @@ -197,6 +212,9 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5df528770ca6e..eea2edc323eea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -35,11 +36,20 @@ import org.apache.spark.sql.types._ class ExpressionEvaluationBaseSuite extends SparkFunSuite { + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) + checkEvaluationWithGeneratedProjection(expression, expected, inputRow) + } + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { expression.eval(inputRow) } - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + def checkEvaluationWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -49,6 +59,68 @@ class ExpressionEvaluationBaseSuite extends SparkFunSuite { } } + def checkEvaluationWithGeneratedMutableProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkEvaluationWithGeneratedProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |Expressions: ${expression} + |Code: ${evaluated} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + def checkDoubleEvaluation( expression: Expression, expected: Spread[Double], @@ -69,8 +141,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("literals") { checkEvaluation(Literal(1), 1) checkEvaluation(Literal(true), true) + checkEvaluation(Literal(false), false) checkEvaluation(Literal(0L), 0L) + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { + d => { + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) + } + } checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal(1) + Literal(1), 2) } @@ -1367,6 +1447,11 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { // TODO: Make the tests work with codegen. class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { + override def checkEvaluation( + expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + } + test("CreateStruct") { val row = Row(1, 2, 3) val c1 = 'a.int.at(0).as("a") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 8cfd853afa35f..371a73181dad7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -21,34 +21,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ /** - * Overrides our expression evaluation tests to use code generation for evaluation. + * Additional tests for code generation. */ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow).apply(0) - if (actual != expected) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - test("multithreaded eval") { import scala.concurrent._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala deleted file mode 100644 index 9ab1f7d7ad0db..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen._ - -/** - * Overrides our expression evaluation tests to use generated code on mutable rows. - */ -class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) - - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code} - """.stripMargin) - } - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 8979a0a210a42..d9a010a9815a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -53,7 +53,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { check("10", Literal.create(10, IntegerType)) check("1000000000000000", Literal.create(1000000000000000L, LongType)) - check("1.5", Literal.create(1.5, FloatType)) + check("1.5", Literal.create(1.5f, FloatType)) check("hello", Literal.create("hello", StringType)) check(defaultPartitionName, Literal.create(null, NullType)) } @@ -83,13 +83,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { ArrayBuffer( Literal.create(10, IntegerType), Literal.create("hello", StringType), - Literal.create(1.5, FloatType))) + Literal.create(1.5f, FloatType))) }) check("file://path/a=10/b_hello/c=1.5", Some { PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal.create(1.5, FloatType))) + ArrayBuffer(Literal.create(1.5f, FloatType))) }) check("file:///", None) From f74be744d41586690e73ec57e5551c1fbabc1d6f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 7 Jun 2015 18:45:24 -0700 Subject: [PATCH 230/254] [SPARK-8149][SQL] Break ExpressionEvaluationSuite down to multiple files Also moved a few files in expressions package around to match test suites. Author: Reynold Xin Closes #6693 from rxin/expr-refactoring and squashes the following commits: 857599f [Reynold Xin] Fixed style violation. c0eb74b [Reynold Xin] Fixed compilation. b3a40f8 [Reynold Xin] Refactored expression test suites. --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 6 +- .../sql/catalyst/expressions/arithmetic.scala | 96 -- .../sql/catalyst/expressions/bitwise.scala | 120 ++ .../catalyst/expressions/conditionals.scala | 313 ++++ .../{mathfuncs/unary.scala => math.scala} | 99 +- .../expressions/mathfuncs/binary.scala | 100 -- .../sql/catalyst/expressions/predicates.scala | 290 ---- .../ArithmeticExpressionSuite.scala | 144 ++ .../expressions/BitwiseFunctionsSuite.scala | 80 + .../sql/catalyst/expressions/CastSuite.scala | 532 ++++++ ...nSuite.scala => CodeGenerationSuite.scala} | 3 +- .../expressions/ComplexTypeSuite.scala | 122 ++ .../ConditionalExpressionSuite.scala | 96 ++ .../expressions/ExpressionEvalHelper.scala | 134 ++ .../ExpressionEvaluationSuite.scala | 1461 ----------------- .../expressions/LiteralExpressionSuite.scala | 55 + .../expressions/MathFunctionsSuite.scala | 179 ++ .../expressions/NullFunctionsSuite.scala | 65 + .../catalyst/expressions/PredicateSuite.scala | 179 ++ .../expressions/StringFunctionsSuite.scala | 218 +++ .../ExpressionOptimizationSuite.scala | 3 +- .../org/apache/spark/sql/functions.scala | 1 - 23 files changed, 2340 insertions(+), 1958 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{mathfuncs/unary.scala => math.scala} (54%) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/{GeneratedEvaluationSuite.scala => CodeGenerationSuite.scala} (94%) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5f76a512679a4..2a1f96409daf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -161,7 +161,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0))) + buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0)) case LongType => buildCast[Long](_, l => new Timestamp(l)) case IntegerType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0ed576b3d5870..432d65eee54fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -81,11 +81,11 @@ abstract class Expression extends TreeNode[Expression] { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); - boolean ${ev.isNull} = ${objectTerm} == null; + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm}; + ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 3ac7c92dcd009..d4efda2e04c29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -322,102 +322,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } -/** - * A function that calculates bitwise and(&) of two numbers. - */ -case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "&" - - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) - - private lazy val and: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - } - - protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) -} - -/** - * A function that calculates bitwise or(|) of two numbers. - */ -case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "|" - - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) - - private lazy val or: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - } - - protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) -} - -/** - * A function that calculates bitwise xor of two numbers. - */ -case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "^" - - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) - - private lazy val xor: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - } - - protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) -} - -/** - * A function that calculates bitwise not(~) of a number. - */ -case class BitwiseNot(child: Expression) extends UnaryArithmetic { - override def toString: String = s"~$child" - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") - - private lazy val not: (Any) => Any = dataType match { - case ByteType => - ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] - case ShortType => - ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] - case IntegerType => - ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] - case LongType => - ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)") - } - - protected override def evalInternal(evalE: Any) = not(evalE) -} - case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala new file mode 100644 index 0000000000000..ef34586261e70 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + + +/** + * A function that calculates bitwise and(&) of two numbers. + */ +case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "&" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val and: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) +} + +/** + * A function that calculates bitwise or(|) of two numbers. + */ +case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "|" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val or: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) +} + +/** + * A function that calculates bitwise xor of two numbers. + */ +case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "^" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val xor: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) +} + +/** + * A function that calculates bitwise not(~) of a number. + */ +case class BitwiseNot(child: Expression) extends UnaryArithmetic { + override def toString: String = s"~$child" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + + private lazy val not: (Any) => Any = dataType match { + case ByteType => + ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] + case ShortType => + ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] + case IntegerType => + ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] + case LongType => + ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)") + } + + protected override def evalInternal(evalE: Any) = not(evalE) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala new file mode 100644 index 0000000000000..3aa86edd7ab20 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{BooleanType, DataType} + + +case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) + extends Expression { + + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable + + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.TypeCheckFailure( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def dataType: DataType = trueValue.dataType + + override def eval(input: Row): Any = { + if (true == predicate.eval(input)) { + trueValue.eval(input) + } else { + falseValue.eval(input) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; + } + """ + } + + override def toString: String = s"if ($predicate) $trueValue else $falseValue" +} + +trait CaseWhenLike extends Expression { + self: Product => + + // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + // element is the value for the default catch-all case (if provided). + // Hence, `branches` consists of at least two elements, and can have an odd or even length. + def branches: Seq[Expression] + + @transient lazy val whenList = + branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq + @transient lazy val thenList = + branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq + val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) + + // both then and else expressions should be considered. + def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) + def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 + + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + checkTypesInternal() + } else { + TypeCheckResult.TypeCheckFailure( + "THEN and ELSE expressions should all be same type or coercible to a common type") + } + } + + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType + + override def nullable: Boolean = { + // If no value is nullable and no elseValue is provided, the whole statement defaults to null. + thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: Row): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + return res + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + + override def toString: String = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = key +: branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if ((key +: whenList).map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + "key and WHEN expressions should all be same type or coercible to a common type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: Row): Any = { + val evaluatedKey = key.eval(input) + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + while (i < len - 1) { + if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + return res + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val keyEval = key.gen(ctx) + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (${keyEval.isNull} && ${cond.isNull} || + !${keyEval.isNull} && !${cond.isNull} + && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${keyEval.code} + $cases + $other + """ + } + + private def equalNullSafe(l: Any, r: Any) = { + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false + } else { + l == r + } + } + + override def toString: String = { + s"CASE $key" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala similarity index 54% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 5563cd94bf86d..a18067e4a58f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.expressions.mathfuncs +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{DataType, DoubleType} /** * A unary expression specifically for math functions. Math Functions expect a specific type of @@ -64,6 +63,47 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) } } +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def toString: String = s"$name($left, $right)" + + override def dataType: DataType = DoubleType + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Unary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") @@ -111,3 +151,54 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { override def funcName: String = "toRadians" } + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Binary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala deleted file mode 100644 index 88211acd7713c..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.mathfuncs - -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} -import org.apache.spark.sql.types._ - -/** - * A binary expression specifically for math functions that take two `Double`s as input and returns - * a `Double`. - * @param f The math function. - * @param name The short name of the function - */ -abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) - - override def toString: String = s"$name($left, $right)" - - override def dataType: DataType = DoubleType - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) - if (result.isNaN) null else result - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") - } -} - -case class Atan2(left: Expression, right: Expression) - extends BinaryMathExpression(math.atan2, "ATAN2") { - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, - evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} - -case class Hypot(left: Expression, right: Expression) - extends BinaryMathExpression(math.hypot, "HYPOT") - -case class Pow(left: Expression, right: Expression) - extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1d0f19a400d63..5edcf3bd77d20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -359,293 +359,3 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) } - -case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { - - override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil - override def nullable: Boolean = trueValue.nullable || falseValue.nullable - - override def checkInputDataTypes(): TypeCheckResult = { - if (predicate.dataType != BooleanType) { - TypeCheckResult.TypeCheckFailure( - s"type of predicate expression in If should be boolean, not ${predicate.dataType}") - } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") - } else { - TypeCheckResult.TypeCheckSuccess - } - } - - override def dataType: DataType = trueValue.dataType - - override def eval(input: Row): Any = { - if (true == predicate.eval(input)) { - trueValue.eval(input) - } else { - falseValue.eval(input) - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - val condEval = predicate.gen(ctx) - val trueEval = trueValue.gen(ctx) - val falseEval = falseValue.gen(ctx) - - s""" - ${condEval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.primitive}) { - ${trueEval.code} - ${ev.isNull} = ${trueEval.isNull}; - ${ev.primitive} = ${trueEval.primitive}; - } else { - ${falseEval.code} - ${ev.isNull} = ${falseEval.isNull}; - ${ev.primitive} = ${falseEval.primitive}; - } - """ - } - - override def toString: String = s"if ($predicate) $trueValue else $falseValue" -} - -trait CaseWhenLike extends Expression { - self: Product => - - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last - // element is the value for the default catch-all case (if provided). - // Hence, `branches` consists of at least two elements, and can have an odd or even length. - def branches: Seq[Expression] - - @transient lazy val whenList = - branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq - @transient lazy val thenList = - branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq - val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - - // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 - - override def checkInputDataTypes(): TypeCheckResult = { - if (valueTypesEqual) { - checkTypesInternal() - } else { - TypeCheckResult.TypeCheckFailure( - "THEN and ELSE expressions should all be same type or coercible to a common type") - } - } - - protected def checkTypesInternal(): TypeCheckResult - - override def dataType: DataType = thenList.head.dataType - - override def nullable: Boolean = { - // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) - } -} - -// scalastyle:off -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = branches - - override protected def checkTypesInternal(): TypeCheckResult = { - if (whenList.forall(_.dataType == BooleanType)) { - TypeCheckResult.TypeCheckSuccess - } else { - val index = whenList.indexWhere(_.dataType != BooleanType) - TypeCheckResult.TypeCheckFailure( - s"WHEN expressions in CaseWhen should all be boolean type, " + - s"but the ${index + 1}th when expression's type is ${whenList(index)}") - } - } - - /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (branchesArr(i).eval(input) == true) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) - } - return res - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - val len = branchesArr.length - val got = ctx.freshName("got") - - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) - s""" - if (!$got) { - ${cond.code} - if (!${cond.isNull} && ${cond.primitive}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - } - """ - }.mkString("\n") - - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" - if (!$got) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - """ - } else { - "" - } - - s""" - boolean $got = false; - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - $cases - $other - """ - } - - override def toString: String = { - "CASE" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } -} - -// scalastyle:off -/** - * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = key +: branches - - override protected def checkTypesInternal(): TypeCheckResult = { - if ((key +: whenList).map(_.dataType).distinct.size > 1) { - TypeCheckResult.TypeCheckFailure( - "key and WHEN expressions should all be same type or coercible to a common type") - } else { - TypeCheckResult.TypeCheckSuccess - } - } - - /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { - val evaluatedKey = key.eval(input) - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) - } - return res - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - val keyEval = key.gen(ctx) - val len = branchesArr.length - val got = ctx.freshName("got") - - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) - s""" - if (!$got) { - ${cond.code} - if (${keyEval.isNull} && ${cond.isNull} || - !${keyEval.isNull} && !${cond.isNull} - && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - } - """ - }.mkString("\n") - - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" - if (!$got) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - """ - } else { - "" - } - - s""" - boolean $got = false; - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${keyEval.code} - $cases - $other - """ - } - - private def equalNullSafe(l: Any, r: Any) = { - if (l == null && r == null) { - true - } else if (l == null || r == null) { - false - } else { - l == r - } - } - - override def toString: String = { - s"CASE $key" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala new file mode 100644 index 0000000000000..e1afa81a7a82f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{DoubleType, IntegerType} + + +class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("arithmetic") { + val row = create_row(1, 2, 3, null) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + + checkEvaluation(UnaryMinus(c1), -1, row) + checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) + + checkEvaluation(Add(c1, c4), null, row) + checkEvaluation(Add(c1, c2), 3, row) + checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(-c1, -1, row) + checkEvaluation(c1 + c2, 3, row) + checkEvaluation(c1 - c2, -1, row) + checkEvaluation(c1 * c2, 2, row) + checkEvaluation(c1 / c2, 0, row) + checkEvaluation(c1 % c2, 1, row) + } + + test("fractional arithmetic") { + val row = create_row(1.1, 2.0, 3.1, null) + val c1 = 'a.double.at(0) + val c2 = 'a.double.at(1) + val c3 = 'a.double.at(2) + val c4 = 'a.double.at(3) + + checkEvaluation(UnaryMinus(c1), -1.1, row) + checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) + checkEvaluation(Add(c1, c4), null, row) + checkEvaluation(Add(c1, c2), 3.1, row) + checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) + + checkEvaluation(-c1, -1.1, row) + checkEvaluation(c1 + c2, 3.1, row) + checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) + checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) + checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) + checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) + } + + test("Divide") { + checkEvaluation(Divide(Literal(2), Literal(1)), 2) + checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) + checkEvaluation(Divide(Literal(1), Literal(2)), 0) + checkEvaluation(Divide(Literal(1), Literal(0)), null) + checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) + checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) + checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) + } + + test("Remainder") { + checkEvaluation(Remainder(Literal(2), Literal(1)), 0) + checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) + checkEvaluation(Remainder(Literal(1), Literal(2)), 1) + checkEvaluation(Remainder(Literal(1), Literal(0)), null) + checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) + checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) + checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) + } + + test("MaxOf") { + checkEvaluation(MaxOf(1, 2), 2) + checkEvaluation(MaxOf(2, 1), 2) + checkEvaluation(MaxOf(1L, 2L), 2L) + checkEvaluation(MaxOf(2L, 1L), 2L) + + checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + } + + test("MinOf") { + checkEvaluation(MinOf(1, 2), 1) + checkEvaluation(MinOf(2, 1), 1) + checkEvaluation(MinOf(1L, 2L), 1L) + checkEvaluation(MinOf(2L, 1L), 1L) + + checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) + checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + } + + test("SQRT") { + val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) + val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) + val d = 'a.double.at(0) + + for ((row, expected) <- rowSequence zip expectedResults) { + checkEvaluation(Sqrt(d), expected, row) + } + + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(Sqrt(-1), null, EmptyRow) + checkEvaluation(Sqrt(-1.5), null, EmptyRow) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala new file mode 100644 index 0000000000000..c9bbc7a8b8c14 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ + + +class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Bitwise operations") { + val row = create_row(1, 2, 3, null) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + + checkEvaluation(BitwiseAnd(c1, c4), null, row) + checkEvaluation(BitwiseAnd(c1, c2), 0, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseOr(c1, c4), null, row) + checkEvaluation(BitwiseOr(c1, c2), 3, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseXor(c1, c4), null, row) + checkEvaluation(BitwiseXor(c1, c2), 3, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseNot(c4), null, row) + checkEvaluation(BitwiseNot(c1), -2, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) + + checkEvaluation(c1 & c2, 0, row) + checkEvaluation(c1 | c2, 3, row) + checkEvaluation(c1 ^ c2, 3, row) + checkEvaluation(~c1, -2, row) + } + + test("unary BitwiseNOT") { + checkEvaluation(BitwiseNot(1), -2) + assert(BitwiseNot(1).dataType === IntegerType) + assert(BitwiseNot(1).eval(EmptyRow).isInstanceOf[Int]) + + checkEvaluation(BitwiseNot(1.toLong), -2.toLong) + assert(BitwiseNot(1.toLong).dataType === LongType) + assert(BitwiseNot(1.toLong).eval(EmptyRow).isInstanceOf[Long]) + + checkEvaluation(BitwiseNot(1.toShort), -2.toShort) + assert(BitwiseNot(1.toShort).dataType === ShortType) + assert(BitwiseNot(1.toShort).eval(EmptyRow).isInstanceOf[Short]) + + checkEvaluation(BitwiseNot(1.toByte), -2.toByte) + assert(BitwiseNot(1.toByte).dataType === ByteType) + assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala new file mode 100644 index 0000000000000..5bc7c30eee1b6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -0,0 +1,532 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Timestamp, Date} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +/** + * Test suite for data type casting expression [[Cast]]. + */ +class CastSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def cast(v: Any, targetType: DataType): Cast = { + v match { + case lit: Expression => Cast(lit, targetType) + case _ => Cast(Literal(v), targetType) + } + } + + // expected cannot be null + private def checkCast(v: Any, expected: Any): Unit = { + checkEvaluation(cast(v, Literal(expected).dataType), expected) + } + + test("cast from int") { + checkCast(0, false) + checkCast(1, true) + checkCast(5, true) + checkCast(1, 1.toByte) + checkCast(1, 1.toShort) + checkCast(1, 1) + checkCast(1, 1.toLong) + checkCast(1, 1.0f) + checkCast(1, 1.0) + checkCast(123, "123") + + checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 1)), null) + checkEvaluation(cast(123, DecimalType(2, 0)), null) + } + + test("cast from long") { + checkCast(0L, false) + checkCast(1L, true) + checkCast(5L, true) + checkCast(1L, 1.toByte) + checkCast(1L, 1.toShort) + checkCast(1L, 1) + checkCast(1L, 1.toLong) + checkCast(1L, 1.0f) + checkCast(1L, 1.0) + checkCast(123L, "123") + + checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + + // TODO: Fix the following bug and re-enable it. + // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + } + + test("cast from boolean") { + checkEvaluation(cast(true, IntegerType), 1) + checkEvaluation(cast(false, IntegerType), 0) + checkEvaluation(cast(true, StringType), "true") + checkEvaluation(cast(false, StringType), "false") + checkEvaluation(cast(cast(1, BooleanType), IntegerType), 1) + checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0) + } + + test("cast from int 2") { + checkEvaluation(cast(1, LongType), 1.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + + checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 1)), null) + checkEvaluation(cast(123, DecimalType(2, 0)), null) + } + + test("cast from float") { + + } + + test("cast from double") { + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + } + + test("cast from string") { + assert(cast("abcdef", StringType).nullable === false) + assert(cast("abcdef", BinaryType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", TimestampType).nullable === true) + assert(cast("abcdef", LongType).nullable === true) + assert(cast("abcdef", IntegerType).nullable === true) + assert(cast("abcdef", ShortType).nullable === true) + assert(cast("abcdef", ByteType).nullable === true) + assert(cast("abcdef", DecimalType.Unlimited).nullable === true) + assert(cast("abcdef", DecimalType(4, 2)).nullable === true) + assert(cast("abcdef", DoubleType).nullable === true) + assert(cast("abcdef", FloatType).nullable === true) + } + + test("data type casting") { + val sd = "1970-01-01" + val d = Date.valueOf(sd) + val zts = sd + " 00:00:00" + val sts = sd + " 00:00:02" + val nts = sts + ".1" + val ts = Timestamp.valueOf(nts) + + checkEvaluation(cast("abdef", StringType), "abdef") + checkEvaluation(cast("abdef", DecimalType.Unlimited), null) + checkEvaluation(cast("abdef", TimestampType), null) + checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65)) + + checkEvaluation(cast(cast(sd, DateType), StringType), sd) + checkEvaluation(cast(cast(d, StringType), DateType), 0) + checkEvaluation(cast(cast(nts, TimestampType), StringType), nts) + checkEvaluation(cast(cast(ts, StringType), TimestampType), ts) + + // all convert to string type to check + checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd) + checkEvaluation(cast(cast(cast(ts, DateType), TimestampType), StringType), zts) + + checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef") + + checkEvaluation(cast(cast(cast(cast( + cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType), + 5.toLong) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), + DecimalType.Unlimited), LongType), StringType), ShortType), + 0.toShort) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), + DecimalType.Unlimited), LongType), StringType), ShortType), + null) + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited), + ByteType), TimestampType), LongType), StringType), ShortType), + 0.toShort) + + checkEvaluation(cast("23", DoubleType), 23d) + checkEvaluation(cast("23", IntegerType), 23) + checkEvaluation(cast("23", FloatType), 23f) + checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23)) + checkEvaluation(cast("23", ByteType), 23.toByte) + checkEvaluation(cast("23", ShortType), 23.toShort) + checkEvaluation(cast("2012-12-11", DoubleType), null) + checkEvaluation(cast(123, IntegerType), 123) + + + checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null) + } + + test("cast and add") { + checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d) + checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24) + checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f) + checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24)) + checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte) + checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) + } + + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(cast(123, DecimalType.Unlimited).nullable === false) + assert(cast(10.03f, DecimalType.Unlimited).nullable === true) + assert(cast(10.03, DecimalType.Unlimited).nullable === true) + assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false) + + assert(cast(123, DecimalType(2, 1)).nullable === true) + assert(cast(10.03f, DecimalType(2, 1)).nullable === true) + assert(cast(10.03, DecimalType(2, 1)).nullable === true) + assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) + + + checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.03, DecimalType(1, 0)), null) + checkEvaluation(cast(10.03, DecimalType(2, 1)), null) + checkEvaluation(cast(10.03, DecimalType(3, 2)), null) + checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null) + + checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.05, DecimalType(1, 0)), null) + checkEvaluation(cast(10.05, DecimalType(2, 1)), null) + checkEvaluation(cast(10.05, DecimalType(3, 2)), null) + checkEvaluation(cast(Decimal(10.05), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(cast(Decimal(10.05), DecimalType(3, 2)), null) + + checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(9.95, DecimalType(2, 1)), null) + checkEvaluation(cast(9.95, DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(Decimal(9.95), DecimalType(1, 0)), null) + + checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(cast(-9.95, DecimalType(2, 1)), null) + checkEvaluation(cast(-9.95, DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) + + checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null) + checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null) + + checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) + checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null) + } + + test("cast from date") { + val d = Date.valueOf("1970-01-01") + checkEvaluation(cast(d, ShortType), null) + checkEvaluation(cast(d, IntegerType), null) + checkEvaluation(cast(d, LongType), null) + checkEvaluation(cast(d, FloatType), null) + checkEvaluation(cast(d, DoubleType), null) + checkEvaluation(cast(d, DecimalType.Unlimited), null) + checkEvaluation(cast(d, DecimalType(10, 2)), null) + checkEvaluation(cast(d, StringType), "1970-01-01") + checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + } + + test("cast from timestamp") { + val millis = 15 * 1000 + 2 + val seconds = millis * 1000 + 2 + val ts = new Timestamp(millis) + val tss = new Timestamp(seconds) + checkEvaluation(cast(ts, ShortType), 15.toShort) + checkEvaluation(cast(ts, IntegerType), 15) + checkEvaluation(cast(ts, LongType), 15.toLong) + checkEvaluation(cast(ts, FloatType), 15.002f) + checkEvaluation(cast(ts, DoubleType), 15.002) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), ts) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), ts) + checkEvaluation(cast(cast(tss, LongType), TimestampType), ts) + checkEvaluation( + cast(cast(millis.toFloat / 1000, TimestampType), FloatType), + millis.toFloat / 1000) + checkEvaluation( + cast(cast(millis.toDouble / 1000, TimestampType), DoubleType), + millis.toDouble / 1000) + checkEvaluation( + cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited), + Decimal(1)) + + // A test for higher precision than millis + checkEvaluation(cast(cast(0.00000001, TimestampType), DoubleType), 0.00000001) + + checkEvaluation(cast(Double.NaN, TimestampType), null) + checkEvaluation(cast(1.0 / 0.0, TimestampType), null) + checkEvaluation(cast(Float.NaN, TimestampType), null) + checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) + } + + test("cast from array") { + val array = Literal.create(Seq("123", "abc", "", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), + ArrayType(StringType, containsNull = false)) + + { + val ret = cast(array, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(123, null, null, null)) + } + { + val ret = cast(array, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(array, ArrayType(BooleanType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false, null)) + } + { + val ret = cast(array, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved === false) + } + + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(123, null, null)) + } + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false)) + } + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false)) + } + + { + val ret = cast(array, IntegerType) + assert(ret.resolved === false) + } + } + + test("cast from map") { + val map = Literal.create( + Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val map_notNull = Literal.create( + Map("a" -> "123", "b" -> "abc", "c" -> ""), + MapType(StringType, StringType, valueContainsNull = false)) + + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) + } + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + } + { + val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } + + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null)) + } + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } + + { + val ret = cast(map, IntegerType) + assert(ret.resolved === false) + } + } + + test("cast from struct") { + val struct = Literal.create( + Row("123", "abc", "", null), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true), + StructField("d", StringType, nullable = true)))) + val struct_notNull = Literal.create( + Row("123", "abc", ""), + StructType(Seq( + StructField("a", StringType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", StringType, nullable = false)))) + + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(123, null, null, null)) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true), + StructField("d", BooleanType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(true, true, false, null)) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false), + StructField("d", BooleanType, nullable = true)))) + assert(ret.resolved === false) + } + + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(123, null, null)) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(true, true, false)) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(true, true, false)) + } + + { + val ret = cast(struct, StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct, IntegerType) + assert(ret.resolved === false) + } + } + + test("complex casting") { + val complex = Literal.create( + Row( + Seq("123", "abc", ""), + Map("a" -> "123", "b" -> "abc", "c" -> ""), + Row(0)), + StructType(Seq( + StructField("a", + ArrayType(StringType, containsNull = false), nullable = true), + StructField("m", + MapType(StringType, StringType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("i", IntegerType, nullable = true))))))) + + val ret = cast(complex, StructType(Seq( + StructField("a", + ArrayType(IntegerType, containsNull = true), nullable = true), + StructField("m", + MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("l", LongType, nullable = true))))))) + + assert(ret.resolved === true) + checkEvaluation(ret, Row( + Seq(123, null, null), + Map("a" -> true, "b" -> true, "c" -> false), + Row(0L))) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala similarity index 94% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 371a73181dad7..481b335d15dfd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ /** * Additional tests for code generation. */ -class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { +class CodeGenerationSuite extends SparkFunSuite { test("multithreaded eval") { import scala.concurrent._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala new file mode 100644 index 0000000000000..f151dd2a47f78 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ + + +class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("CreateStruct") { + val row = Row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + } + + test("complex type") { + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 + ) + + val typeS = StructType( + StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil + ) + val typeMap = MapType(StringType, StringType) + val typeArray = ArrayType(StringType) + + checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), + Literal("aa")), "bb", row) + checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation( + GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) + checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), + Literal.create(null, StringType)), null, row) + + checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), + Literal(1)), "bb", row) + checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation( + GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) + checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), + Literal.create(null, IntegerType)), null, row) + + def getStructField(expr: Expression, fieldName: String): ExtractValue = { + expr.dataType match { + case StructType(fields) => + val field = fields.find(_.name == fieldName).get + GetStructField(expr, field, fields.indexOf(field)) + } + } + + def quickResolve(u: UnresolvedExtractValue): ExtractValue = { + ExtractValue(u.child, u.extraction, _ == _) + } + + checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) + checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) + + val typeS_notNullable = StructType( + StructField("a", StringType, nullable = false) + :: StructField("b", StringType, nullable = false) :: Nil + ) + + assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) + assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable + === false) + + assert(getStructField(Literal.create(null, typeS), "a").nullable === true) + assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) + + checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) + checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) + checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) + } + + test("error message of ExtractValue") { + val structType = StructType(StructField("a", StringType, true) :: Nil) + val arrayStructType = ArrayType(structType) + val arrayType = ArrayType(StringType) + val otherType = StringType + + def checkErrorMessage( + childDataType: DataType, + fieldDataType: DataType, + errorMesage: String): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + ExtractValue( + Literal.create(null, childDataType), + Literal.create(null, fieldDataType), + _ == _) + } + assert(e.getMessage().contains(errorMesage)) + } + + checkErrorMessage(structType, IntegerType, "Field name should be String Literal") + checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal") + checkErrorMessage(arrayType, StringType, "Array index should be integral type") + checkErrorMessage(otherType, StringType, "Can't extract value from") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala new file mode 100644 index 0000000000000..152c4e4111244 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{IntegerType, BooleanType} + + +class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("case when") { + val row = create_row(null, false, true, "a", "b", "c") + val c1 = 'a.boolean.at(0) + val c2 = 'a.boolean.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) + + checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + + assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + + val c4_notNull = 'a.boolean.notNull.at(3) + val c5_notNull = 'a.boolean.notNull.at(4) + val c6_notNull = 'a.boolean.notNull.at(5) + + assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + } + + test("case key when") { + val row = create_row(null, 1, 2, "a", "b", "c") + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + val literalNull = Literal.create(null, IntegerType) + val literalInt = Literal(1) + val literalString = Literal("a") + + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) + checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) + checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row) + + checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) + checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala new file mode 100644 index 0000000000000..87a92b87962f8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} + +/** + * A few helper functions for expression evaluation testing. Mixin this trait to use them. + */ +trait ExpressionEvalHelper { + self: SparkFunSuite => + + protected def create_row(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } + + protected def checkEvaluation( + expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) + checkEvaluationWithGeneratedProjection(expression, expected, inputRow) + } + + protected def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.eval(inputRow) + } + + protected def checkEvaluationWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: $expected$input") + } + } + + protected def checkEvaluationWithGeneratedMutableProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + protected def checkEvaluationWithGeneratedProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |Expressions: $expression + |Code: $evaluated + """.stripMargin) + } + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + protected def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala deleted file mode 100644 index eea2edc323eea..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ /dev/null @@ -1,1461 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import java.sql.{Date, Timestamp} - -import scala.collection.immutable.HashSet - -import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.Matchers._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} -import org.apache.spark.sql.catalyst.expressions.mathfuncs._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types._ - - -class ExpressionEvaluationBaseSuite extends SparkFunSuite { - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - checkEvaluationWithoutCodegen(expression, expected, inputRow) - checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) - checkEvaluationWithGeneratedProjection(expression, expected, inputRow) - } - - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluationWithoutCodegen( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if (actual != expected) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkEvaluationWithGeneratedMutableProjection( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = expression.gen(ctx) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow).apply(0) - if (actual != expected) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkEvaluationWithGeneratedProjection( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = expression.gen(ctx) - - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |Expressions: ${expression} - |Code: ${evaluated} - """.stripMargin) - } - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation( - expression: Expression, - expected: Spread[Double], - inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } -} - -class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { - - def create_row(values: Any*): Row = { - new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) - } - - test("literals") { - checkEvaluation(Literal(1), 1) - checkEvaluation(Literal(true), true) - checkEvaluation(Literal(false), false) - checkEvaluation(Literal(0L), 0L) - List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { - d => { - checkEvaluation(Literal(d), d) - checkEvaluation(Literal(d.toFloat), d.toFloat) - } - } - checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal.create(null, StringType), null) - checkEvaluation(Literal(1) + Literal(1), 2) - } - - test("unary BitwiseNOT") { - checkEvaluation(BitwiseNot(1), -2) - assert(BitwiseNot(1).dataType === IntegerType) - assert(BitwiseNot(1).eval(EmptyRow).isInstanceOf[Int]) - checkEvaluation(BitwiseNot(1.toLong), -2.toLong) - assert(BitwiseNot(1.toLong).dataType === LongType) - assert(BitwiseNot(1.toLong).eval(EmptyRow).isInstanceOf[Long]) - checkEvaluation(BitwiseNot(1.toShort), -2.toShort) - assert(BitwiseNot(1.toShort).dataType === ShortType) - assert(BitwiseNot(1.toShort).eval(EmptyRow).isInstanceOf[Short]) - checkEvaluation(BitwiseNot(1.toByte), -2.toByte) - assert(BitwiseNot(1.toByte).dataType === ByteType) - assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) - } - - // scalastyle:off - /** - * Checks for three-valued-logic. Based on: - * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * I.e. in flat cpo "False -> Unknown -> True", - * OR is lowest upper bound, - * AND is greatest lower bound. - * p q p OR q p AND q p = q - * True True True True True - * True False True False False - * True Unknown True Unknown Unknown - * False True True False False - * False False False False True - * False Unknown Unknown False Unknown - * Unknown True True Unknown Unknown - * Unknown False Unknown False Unknown - * Unknown Unknown Unknown Unknown Unknown - * - * p NOT p - * True False - * False True - * Unknown Unknown - */ - // scalastyle:on - val notTrueTable = - (true, false) :: - (false, true) :: - (null, null) :: Nil - - test("3VL Not") { - notTrueTable.foreach { - case (v, answer) => - checkEvaluation(!Literal.create(v, BooleanType), answer) - } - } - - booleanLogicTest("AND", _ && _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: - (false, false, false) :: - (false, null, false) :: - (null, true, null) :: - (null, false, false) :: - (null, null, null) :: Nil) - - booleanLogicTest("OR", _ || _, - (true, true, true) :: - (true, false, true) :: - (true, null, true) :: - (false, true, true) :: - (false, false, false) :: - (false, null, null) :: - (null, true, true) :: - (null, false, null) :: - (null, null, null) :: Nil) - - booleanLogicTest("=", _ === _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: - (false, false, true) :: - (false, null, null) :: - (null, true, null) :: - (null, false, null) :: - (null, null, null) :: Nil) - - def booleanLogicTest( - name: String, - op: (Expression, Expression) => Expression, - truthTable: Seq[(Any, Any, Any)]) { - test(s"3VL $name") { - truthTable.foreach { - case (l, r, answer) => - val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) - checkEvaluation(expr, answer) - } - } - } - - test("IN") { - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation( - In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), - true) - } - - test("Divide") { - checkEvaluation(Divide(Literal(2), Literal(1)), 2) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1), Literal(0)), null) - checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("Remainder") { - checkEvaluation(Remainder(Literal(2), Literal(1)), 0) - checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) - checkEvaluation(Remainder(Literal(1), Literal(2)), 1) - checkEvaluation(Remainder(Literal(1), Literal(0)), null) - checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("INSET") { - val hS = HashSet[Any]() + 1 + 2 - val nS = HashSet[Any]() + 1 + 2 + null - val one = Literal(1) - val two = Literal(2) - val three = Literal(3) - val nl = Literal(null) - val s = Seq(one, two) - val nullS = Seq(one, two, null) - checkEvaluation(InSet(one, hS), true) - checkEvaluation(InSet(two, hS), true) - checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(nl, nS), true) - checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), false) - checkEvaluation(InSet(one, hS) && InSet(two, hS), true) - } - - test("MaxOf") { - checkEvaluation(MaxOf(1, 2), 2) - checkEvaluation(MaxOf(2, 1), 2) - checkEvaluation(MaxOf(1L, 2L), 2L) - checkEvaluation(MaxOf(2L, 1L), 2L) - - checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) - } - - test("MinOf") { - checkEvaluation(MinOf(1, 2), 1) - checkEvaluation(MinOf(2, 1), 1) - checkEvaluation(MinOf(1L, 2L), 1L) - checkEvaluation(MinOf(2L, 1L), 1L) - - checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) - checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) - } - - test("LIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType).like("a"), null) - checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) - checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) - checkEvaluation("abdef" like "abdef", true) - checkEvaluation("a_%b" like "a\\__b", true) - checkEvaluation("addb" like "a_%b", true) - checkEvaluation("addb" like "a\\__b", false) - checkEvaluation("addb" like "a%\\%b", false) - checkEvaluation("a_%b" like "a%\\%b", true) - checkEvaluation("addb" like "a%", true) - checkEvaluation("addb" like "**", false) - checkEvaluation("abc" like "a%", true) - checkEvaluation("abc" like "b%", false) - checkEvaluation("abc" like "bc%", false) - checkEvaluation("a\nb" like "a_b", true) - checkEvaluation("ab" like "a%b", true) - checkEvaluation("a\nb" like "a%b", true) - } - - test("LIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(null)) - checkEvaluation("abdef" like regEx, true, create_row("abdef")) - checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) - checkEvaluation("addb" like regEx, true, create_row("a_%b")) - checkEvaluation("addb" like regEx, false, create_row("a\\__b")) - checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) - checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) - checkEvaluation("addb" like regEx, true, create_row("a%")) - checkEvaluation("addb" like regEx, false, create_row("**")) - checkEvaluation("abc" like regEx, true, create_row("a%")) - checkEvaluation("abc" like regEx, false, create_row("b%")) - checkEvaluation("abc" like regEx, false, create_row("bc%")) - checkEvaluation("a\nb" like regEx, true, create_row("a_b")) - checkEvaluation("ab" like regEx, true, create_row("a%b")) - checkEvaluation("a\nb" like regEx, true, create_row("a%b")) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) - } - - test("RLIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal.create(null, StringType), null) - checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) - checkEvaluation("abdef" rlike "abdef", true) - checkEvaluation("abbbbc" rlike "a.*c", true) - - checkEvaluation("fofo" rlike "^fo", true) - checkEvaluation("fo\no" rlike "^fo\no$", true) - checkEvaluation("Bn" rlike "^Ba*n", true) - checkEvaluation("afofo" rlike "fo", true) - checkEvaluation("afofo" rlike "^fo", false) - checkEvaluation("Baan" rlike "^Ba?n", false) - checkEvaluation("axe" rlike "pi|apa", false) - checkEvaluation("pip" rlike "^(pi)*$", false) - - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") - } - } - - test("RLIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) - checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) - checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) - checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) - checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row("**")) - } - } - - test("data type casting") { - - val sd = "1970-01-01" - val d = Date.valueOf(sd) - val zts = sd + " 00:00:00" - val sts = sd + " 00:00:02" - val nts = sts + ".1" - val ts = Timestamp.valueOf(nts) - - checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType.Unlimited, null) - checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) - - checkEvaluation(Literal(1) cast LongType, 1.toLong) - checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) - checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) - checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - - checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) - checkEvaluation(Cast(Literal(d) cast StringType, DateType), 0) - checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) - checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) - // all convert to string type to check - checkEvaluation( - Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) - checkEvaluation( - Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), zts) - - checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") - - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), - 5.toLong) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), - 0.toShort) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) - checkEvaluation(Literal(true) cast IntegerType, 1) - checkEvaluation(Literal(false) cast IntegerType, 0) - checkEvaluation(Literal(true) cast StringType, "true") - checkEvaluation(Literal(false) cast StringType, "false") - checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) - checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23d) - checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) - checkEvaluation("23" cast ByteType, 23.toByte) - checkEvaluation("23" cast ShortType, 23.toShort) - checkEvaluation("2012-12-11" cast DoubleType, null) - checkEvaluation(Literal(123) cast IntegerType, 123) - - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) - checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) - - intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} - - assert(("abcdef" cast StringType).nullable === false) - assert(("abcdef" cast BinaryType).nullable === false) - assert(("abcdef" cast BooleanType).nullable === false) - assert(("abcdef" cast TimestampType).nullable === true) - assert(("abcdef" cast LongType).nullable === true) - assert(("abcdef" cast IntegerType).nullable === true) - assert(("abcdef" cast ShortType).nullable === true) - assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType.Unlimited).nullable === true) - assert(("abcdef" cast DecimalType(4, 2)).nullable === true) - assert(("abcdef" cast DoubleType).nullable === true) - assert(("abcdef" cast FloatType).nullable === true) - - checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) - } - - test("date") { - val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) - val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) - checkEvaluation(Literal(d1) < Literal(d2), true) - } - - test("casting to fixed-precision decimals") { - // Overflow and rounding for casting to fixed-precision decimals: - // - Values should round with HALF_UP mode by default when you lower scale - // - Values that would overflow the target precision should turn into null - // - Because of this, casts to fixed-precision decimals should be nullable - - assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === true) - assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) - - assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) - - checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) - checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) - checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) - checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) - - checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) - checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) - checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) - checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) - - checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) - checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) - checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) - checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) - checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) - checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) - - checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) - checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) - - checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) - checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null) - - checkEvaluation(Cast(Literal(Double.NaN), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(Float.NaN), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), DecimalType.Unlimited), null) - - checkEvaluation(Cast(Literal(Double.NaN), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(Float.NaN), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), DecimalType(2, 1)), null) - } - - test("timestamp") { - val ts1 = new Timestamp(12) - val ts2 = new Timestamp(123) - checkEvaluation(Literal("ab") < Literal("abc"), true) - checkEvaluation(Literal(ts1) < Literal(ts2), true) - } - - test("date casting") { - val d = Date.valueOf("1970-01-01") - checkEvaluation(Cast(Literal(d), ShortType), null) - checkEvaluation(Cast(Literal(d), IntegerType), null) - checkEvaluation(Cast(Literal(d), LongType), null) - checkEvaluation(Cast(Literal(d), FloatType), null) - checkEvaluation(Cast(Literal(d), DoubleType), null) - checkEvaluation(Cast(Literal(d), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(d), DecimalType(10, 2)), null) - checkEvaluation(Cast(Literal(d), StringType), "1970-01-01") - checkEvaluation(Cast(Cast(Literal(d), TimestampType), StringType), "1970-01-01 00:00:00") - } - - test("timestamp casting") { - val millis = 15 * 1000 + 2 - val seconds = millis * 1000 + 2 - val ts = new Timestamp(millis) - val tss = new Timestamp(seconds) - checkEvaluation(Cast(ts, ShortType), 15.toShort) - checkEvaluation(Cast(ts, IntegerType), 15) - checkEvaluation(Cast(ts, LongType), 15.toLong) - checkEvaluation(Cast(ts, FloatType), 15.002f) - checkEvaluation(Cast(ts, DoubleType), 15.002) - checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) - checkEvaluation(Cast(Cast(tss, IntegerType), TimestampType), ts) - checkEvaluation(Cast(Cast(tss, LongType), TimestampType), ts) - checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType), - millis.toFloat / 1000) - checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), - millis.toDouble / 1000) - checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) - - // A test for higher precision than millis - checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) - - checkEvaluation(Cast(Literal(Double.NaN), TimestampType), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), TimestampType), null) - checkEvaluation(Cast(Literal(Float.NaN), TimestampType), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), TimestampType), null) - } - - test("array casting") { - val array = Literal.create(Seq("123", "abc", "", null), - ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), - ArrayType(StringType, containsNull = false)) - - { - val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(123, null, null, null)) - } - { - val cast = Cast(array, ArrayType(IntegerType, containsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(array, ArrayType(BooleanType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false, null)) - } - { - val cast = Cast(array, ArrayType(BooleanType, containsNull = false)) - assert(cast.resolved === false) - } - - { - val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(123, null, null)) - } - { - val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false)) - } - { - val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false)) - } - - { - val cast = Cast(array, IntegerType) - assert(cast.resolved === false) - } - } - - test("map casting") { - val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), - MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), - MapType(StringType, StringType, valueContainsNull = false)) - - { - val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) - } - { - val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) - } - { - val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(cast.resolved === false) - } - - { - val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null)) - } - { - val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) - } - { - val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) - } - { - val cast = Cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(cast.resolved === false) - } - - { - val cast = Cast(map, IntegerType) - assert(cast.resolved === false) - } - } - - test("struct casting") { - val struct = Literal.create( - Row("123", "abc", "", null), - StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", StringType, nullable = true), - StructField("c", StringType, nullable = true), - StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal.create( - Row("123", "abc", ""), - StructType(Seq( - StructField("a", StringType, nullable = false), - StructField("b", StringType, nullable = false), - StructField("c", StringType, nullable = false)))) - - { - val cast = Cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true), - StructField("d", IntegerType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(123, null, null, null)) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false), - StructField("d", IntegerType, nullable = true)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = true), - StructField("d", BooleanType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false, null)) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = false), - StructField("d", BooleanType, nullable = true)))) - assert(cast.resolved === false) - } - - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(123, null, null)) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false)) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = false)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false)) - } - - { - val cast = Cast(struct, StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", StringType, nullable = true), - StructField("c", StringType, nullable = true)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct, IntegerType) - assert(cast.resolved === false) - } - } - - test("complex casting") { - val complex = Literal.create( - Row( - Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), - Row(0)), - StructType(Seq( - StructField("a", - ArrayType(StringType, containsNull = false), nullable = true), - StructField("m", - MapType(StringType, StringType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("i", IntegerType, nullable = true))))))) - - val cast = Cast(complex, StructType(Seq( - StructField("a", - ArrayType(IntegerType, containsNull = true), nullable = true), - StructField("m", - MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("l", LongType, nullable = true))))))) - - assert(cast.resolved === true) - checkEvaluation(cast, Row( - Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), - Row(0L))) - } - - test("null checking") { - val row = create_row("^Ba*n", null, true, null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.boolean.at(3) - - checkEvaluation(c1.isNull, false, row) - checkEvaluation(c1.isNotNull, true, row) - - checkEvaluation(c2.isNull, true, row) - checkEvaluation(c2.isNotNull, false, row) - - checkEvaluation(Literal.create(1, ShortType).isNull, false) - checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - - checkEvaluation(Literal.create(null, ShortType).isNull, true) - checkEvaluation(Literal.create(null, ShortType).isNotNull, false) - - checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - - checkEvaluation( - If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) - checkEvaluation(If(c3, c1, c2), "^Ba*n", row) - checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), - Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) - - checkEvaluation(c1 in (c1, c2), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) - } - - test("case when") { - val row = create_row(null, false, true, "a", "b", "c") - val c1 = 'a.boolean.at(0) - val c2 = 'a.boolean.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) - - checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) - - checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) - - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) - - val c4_notNull = 'a.boolean.notNull.at(3) - val c5_notNull = 'a.boolean.notNull.at(4) - val c6_notNull = 'a.boolean.notNull.at(5) - - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) - - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) - - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) - } - - test("case key when") { - val row = create_row(null, 1, 2, "a", "b", "c") - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) - - val literalNull = Literal.create(null, IntegerType) - val literalInt = Literal(1) - val literalString = Literal("a") - - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) - checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) - checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) - checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row) - - checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) - checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) - checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) - } - - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val typeS = StructType( - StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - - def getStructField(expr: Expression, fieldName: String): ExtractValue = { - expr.dataType match { - case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) - } - } - - def quickResolve(u: UnresolvedExtractValue): ExtractValue = { - ExtractValue(u.child, u.extraction, _ == _) - } - - checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) - - val typeS_notNullable = StructType( - StructField("a", StringType, nullable = false) - :: StructField("b", StringType, nullable = false) :: Nil - ) - - assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) - assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable - === false) - - assert(getStructField(Literal.create(null, typeS), "a").nullable === true) - assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) - - checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) - checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) - checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) - } - - test("error message of ExtractValue") { - val structType = StructType(StructField("a", StringType, true) :: Nil) - val arrayStructType = ArrayType(structType) - val arrayType = ArrayType(StringType) - val otherType = StringType - - def checkErrorMessage( - childDataType: DataType, - fieldDataType: DataType, - errorMesage: String): Unit = { - val e = intercept[org.apache.spark.sql.AnalysisException] { - ExtractValue( - Literal.create(null, childDataType), - Literal.create(null, fieldDataType), - _ == _) - } - assert(e.getMessage().contains(errorMesage)) - } - - checkErrorMessage(structType, IntegerType, "Field name should be String Literal") - checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal") - checkErrorMessage(arrayType, StringType, "Array index should be integral type") - checkErrorMessage(otherType, StringType, "Can't extract value from") - } - - test("arithmetic") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) - - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(-c1, -1, row) - checkEvaluation(c1 + c2, 3, row) - checkEvaluation(c1 - c2, -1, row) - checkEvaluation(c1 * c2, 2, row) - checkEvaluation(c1 / c2, 0, row) - checkEvaluation(c1 % c2, 1, row) - } - - test("fractional arithmetic") { - val row = create_row(1.1, 2.0, 3.1, null) - val c1 = 'a.double.at(0) - val c2 = 'a.double.at(1) - val c3 = 'a.double.at(2) - val c4 = 'a.double.at(3) - - checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) - checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) - - checkEvaluation(-c1, -1.1, row) - checkEvaluation(c1 + c2, 3.1, row) - checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) - checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) - checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) - checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) - } - - test("BinaryComparison") { - val row = create_row(1, 2, 3, null, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - val c5 = 'a.int.at(4) - val c6 = 'a.int.at(5) - - checkEvaluation(LessThan(c1, c4), null, row) - checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 < c2, true, row) - checkEvaluation(c1 <= c2, true, row) - checkEvaluation(c1 > c2, false, row) - checkEvaluation(c1 >= c2, false, row) - checkEvaluation(c1 === c2, false, row) - checkEvaluation(c1 !== c2, true, row) - checkEvaluation(c4 <=> c1, false, row) - checkEvaluation(c1 <=> c4, false, row) - checkEvaluation(c4 <=> c6, true, row) - checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) - checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) - } - - test("StringComparison") { - val row = create_row("abc", null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - - checkEvaluation(c1 contains "b", true, row) - checkEvaluation(c1 contains "x", false, row) - checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal.create(null, StringType), null, row) - - checkEvaluation(c1 startsWith "a", true, row) - checkEvaluation(c1 startsWith "b", false, row) - checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) - - checkEvaluation(c1 endsWith "c", true, row) - checkEvaluation(c1 endsWith "b", false, row) - checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) - } - - test("Substring") { - val row = create_row("example", "example".toArray.map(_.toByte)) - - val s = 'a.string.at(0) - - // substring from zero position with less-than-full length - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) - - // substring from zero position with full length - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) - - // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), - "example", row) - checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), - "example", row) - - // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), - "xa", row) - - // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), - "xample", row) - - // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), - "xample", row) - - // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), - "", row) - - // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - "", row) - - // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - null, create_row(null)) - - // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), - null, row) - - // substring(_, _, null) -> null - checkEvaluation( - Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), - null, - row) - - // 2-arg substring from zero position - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "example", - row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "example", - row) - - // 2-arg substring from nonzero position - checkEvaluation( - Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "xample", - row) - - val s_notNull = 'a.string.notNull.at(0) - - assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable - === true) - assert( - Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable - === false) - assert(Substring(s_notNull, - Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, - Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) - - checkEvaluation(s.substr(0, 2), "ex", row) - checkEvaluation(s.substr(0), "example", row) - checkEvaluation(s.substring(0, 2), "ex", row) - checkEvaluation(s.substring(0), "example", row) - } - - test("SQRT") { - val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) - val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => create_row(l.toDouble)) - val d = 'a.double.at(0) - - for ((row, expected) <- rowSequence zip expectedResults) { - checkEvaluation(Sqrt(d), expected, row) - } - - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(-1), null, EmptyRow) - checkEvaluation(Sqrt(-1.5), null, EmptyRow) - } - - test("Bitwise operations") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(BitwiseAnd(c1, c4), null, row) - checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseOr(c1, c4), null, row) - checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseXor(c1, c4), null, row) - checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseNot(c4), null, row) - checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 & c2, 0, row) - checkEvaluation(c1 | c2, 3, row) - checkEvaluation(c1 ^ c2, 3, row) - checkEvaluation(~c1, -2, row) - } - - /** - * Used for testing math functions for DataFrames. - * @param c The DataFrame function - * @param f The functions in scala.math - * @param domain The set of values to run the function with - * @param expectNull Whether the given values should return null or not - * @tparam T Generic type for primitives - */ - def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( - c: Expression => Expression, - f: T => T, - domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false): Unit = { - if (expectNull) { - domain.foreach { value => - checkEvaluation(c(Literal(value)), null, EmptyRow) - } - } else { - domain.foreach { value => - checkEvaluation(c(Literal(value)), f(value), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) - } - - test("sin") { - unaryMathFunctionEvaluation(Sin, math.sin) - } - - test("asin") { - unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) - } - - test("sinh") { - unaryMathFunctionEvaluation(Sinh, math.sinh) - } - - test("cos") { - unaryMathFunctionEvaluation(Cos, math.cos) - } - - test("acos") { - unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) - } - - test("cosh") { - unaryMathFunctionEvaluation(Cosh, math.cosh) - } - - test("tan") { - unaryMathFunctionEvaluation(Tan, math.tan) - } - - test("atan") { - unaryMathFunctionEvaluation(Atan, math.atan) - } - - test("tanh") { - unaryMathFunctionEvaluation(Tanh, math.tanh) - } - - test("toDegrees") { - unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) - } - - test("toRadians") { - unaryMathFunctionEvaluation(ToRadians, math.toRadians) - } - - test("cbrt") { - unaryMathFunctionEvaluation(Cbrt, math.cbrt) - } - - test("ceil") { - unaryMathFunctionEvaluation(Ceil, math.ceil) - } - - test("floor") { - unaryMathFunctionEvaluation(Floor, math.floor) - } - - test("rint") { - unaryMathFunctionEvaluation(Rint, math.rint) - } - - test("exp") { - unaryMathFunctionEvaluation(Exp, math.exp) - } - - test("expm1") { - unaryMathFunctionEvaluation(Expm1, math.expm1) - } - - test("signum") { - unaryMathFunctionEvaluation[Double](Signum, math.signum) - } - - test("log") { - unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) - } - - test("log10") { - unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) - } - - test("log1p") { - unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) - } - - /** - * Used for testing math functions for DataFrames. - * @param c The DataFrame function - * @param f The functions in scala.math - * @param domain The set of values to run the function with - */ - def binaryMathFunctionEvaluation( - c: (Expression, Expression) => Expression, - f: (Double, Double) => Double, - domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false): Unit = { - if (expectNull) { - domain.foreach { case (v1, v2) => - checkEvaluation(c(v1, v2), null, create_row(null)) - } - } else { - domain.foreach { case (v1, v2) => - checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) - checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) - checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) - } - - test("pow") { - binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) - } - - test("hypot") { - binaryMathFunctionEvaluation(Hypot, math.hypot) - } - - test("atan2") { - binaryMathFunctionEvaluation(Atan2, math.atan2) - } -} - -// TODO: Make the tests work with codegen. -class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { - - override def checkEvaluation( - expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - checkEvaluationWithoutCodegen(expression, expected, inputRow) - } - - test("CreateStruct") { - val row = Row(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala new file mode 100644 index 0000000000000..f44f55dfb92d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.StringType + + +class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + // TODO: Add tests for all data types. + + test("boolean literals") { + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(false), false) + } + + test("int literals") { + checkEvaluation(Literal(1), 1) + checkEvaluation(Literal(0L), 0L) + } + + test("double literals") { + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { + d => { + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) + } + } + } + + test("string literals") { + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal.create(null, StringType), null) + } + + test("sum two literals") { + checkEvaluation(Add(Literal(1), Literal(1)), 2) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala new file mode 100644 index 0000000000000..25ebc70d095d8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.DoubleType + +class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + /** + * Used for testing unary math expressions. + * + * @param c expression + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + private def testUnary[T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + /** + * Used for testing binary math expressions. + * + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + private def testBinary( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null)) + checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + testUnary(Sin, math.sin) + } + + test("asin") { + testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true) + } + + test("sinh") { + testUnary(Sinh, math.sinh) + } + + test("cos") { + testUnary(Cos, math.cos) + } + + test("acos") { + testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true) + } + + test("cosh") { + testUnary(Cosh, math.cosh) + } + + test("tan") { + testUnary(Tan, math.tan) + } + + test("atan") { + testUnary(Atan, math.atan) + } + + test("tanh") { + testUnary(Tanh, math.tanh) + } + + test("toDegrees") { + testUnary(ToDegrees, math.toDegrees) + } + + test("toRadians") { + testUnary(ToRadians, math.toRadians) + } + + test("cbrt") { + testUnary(Cbrt, math.cbrt) + } + + test("ceil") { + testUnary(Ceil, math.ceil) + } + + test("floor") { + testUnary(Floor, math.floor) + } + + test("rint") { + testUnary(Rint, math.rint) + } + + test("exp") { + testUnary(Exp, math.exp) + } + + test("expm1") { + testUnary(Expm1, math.expm1) + } + + test("signum") { + testUnary[Double](Signum, math.signum) + } + + test("log") { + testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true) + } + + test("log10") { + testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true) + } + + test("log1p") { + testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) + } + + test("pow") { + testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) + } + + test("hypot") { + testBinary(Hypot, math.hypot) + } + + test("atan2") { + testBinary(Atan2, math.atan2) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala new file mode 100644 index 0000000000000..ccdada8b56f83 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{BooleanType, StringType, ShortType} + +class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("null checking") { + val row = create_row("^Ba*n", null, true, null) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.boolean.at(3) + + checkEvaluation(c1.isNull, false, row) + checkEvaluation(c1.isNotNull, true, row) + + checkEvaluation(c2.isNull, true, row) + checkEvaluation(c2.isNotNull, false, row) + + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) + + checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) + + checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + + checkEvaluation( + If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) + checkEvaluation(If(c3, c1, c2), "^Ba*n", row) + checkEvaluation(If(c4, c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) + + checkEvaluation(c1 in (c1, c2), true, row) + checkEvaluation( + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) + checkEvaluation( + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala new file mode 100644 index 0000000000000..b6261bfba0786 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} + +import scala.collection.immutable.HashSet + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.types.{IntegerType, BooleanType} + + +class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def booleanLogicTest( + name: String, + op: (Expression, Expression) => Expression, + truthTable: Seq[(Any, Any, Any)]) { + test(s"3VL $name") { + truthTable.foreach { + case (l, r, answer) => + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) + checkEvaluation(expr, answer) + } + } + } + + // scalastyle:off + /** + * Checks for three-valued-logic. Based on: + * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 + * I.e. in flat cpo "False -> Unknown -> True", + * OR is lowest upper bound, + * AND is greatest lower bound. + * p q p OR q p AND q p = q + * True True True True True + * True False True False False + * True Unknown True Unknown Unknown + * False True True False False + * False False False False True + * False Unknown Unknown False Unknown + * Unknown True True Unknown Unknown + * Unknown False Unknown False Unknown + * Unknown Unknown Unknown Unknown Unknown + * + * p NOT p + * True False + * False True + * Unknown Unknown + */ + // scalastyle:on + val notTrueTable = + (true, false) :: + (false, true) :: + (null, null) :: Nil + + test("3VL Not") { + notTrueTable.foreach { case (v, answer) => + checkEvaluation(Not(Literal.create(v, BooleanType)), answer) + } + } + + booleanLogicTest("AND", And, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, false) :: + (false, null, false) :: + (null, true, null) :: + (null, false, false) :: + (null, null, null) :: Nil) + + booleanLogicTest("OR", Or, + (true, true, true) :: + (true, false, true) :: + (true, null, true) :: + (false, true, true) :: + (false, false, false) :: + (false, null, null) :: + (null, true, true) :: + (null, false, null) :: + (null, null, null) :: Nil) + + booleanLogicTest("=", EqualTo, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, true) :: + (false, null, null) :: + (null, true, null) :: + (null, false, null) :: + (null, null, null) :: Nil) + + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation( + And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), + true) + } + + test("INSET") { + val hS = HashSet[Any]() + 1 + 2 + val nS = HashSet[Any]() + 1 + 2 + null + val one = Literal(1) + val two = Literal(2) + val three = Literal(3) + val nl = Literal(null) + val s = Seq(one, two) + val nullS = Seq(one, two, null) + checkEvaluation(InSet(one, hS), true) + checkEvaluation(InSet(two, hS), true) + checkEvaluation(InSet(two, nS), true) + checkEvaluation(InSet(nl, nS), true) + checkEvaluation(InSet(three, hS), false) + checkEvaluation(InSet(three, nS), false) + checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + } + + + test("BinaryComparison") { + val row = create_row(1, 2, 3, null, 3, null) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + val c5 = 'a.int.at(4) + val c6 = 'a.int.at(5) + + checkEvaluation(LessThan(c1, c4), null, row) + checkEvaluation(LessThan(c1, c2), true, row) + checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(c1 < c2, true, row) + checkEvaluation(c1 <= c2, true, row) + checkEvaluation(c1 > c2, false, row) + checkEvaluation(c1 >= c2, false, row) + checkEvaluation(c1 === c2, false, row) + checkEvaluation(c1 !== c2, true, row) + checkEvaluation(c4 <=> c1, false, row) + checkEvaluation(c1 <=> c4, false, row) + checkEvaluation(c4 <=> c6, true, row) + checkEvaluation(c3 <=> c5, true, row) + checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) + checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) + + val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) + val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) + checkEvaluation(Literal(d1) < Literal(d2), true) + + val ts1 = new Timestamp(12) + val ts2 = new Timestamp(123) + checkEvaluation(Literal("ab") < Literal("abc"), true) + checkEvaluation(Literal(ts1) < Literal(ts2), true) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala new file mode 100644 index 0000000000000..2e81296c4e623 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{IntegerType, StringType} + + +class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("StringComparison") { + val row = create_row("abc", null) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + + checkEvaluation(c1 contains "b", true, row) + checkEvaluation(c1 contains "x", false, row) + checkEvaluation(c2 contains "b", null, row) + checkEvaluation(c1 contains Literal.create(null, StringType), null, row) + + checkEvaluation(c1 startsWith "a", true, row) + checkEvaluation(c1 startsWith "b", false, row) + checkEvaluation(c2 startsWith "a", null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) + + checkEvaluation(c1 endsWith "c", true, row) + checkEvaluation(c1 endsWith "b", false, row) + checkEvaluation(c2 endsWith "b", null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) + } + + test("Substring") { + val row = create_row("example", "example".toArray.map(_.toByte)) + + val s = 'a.string.at(0) + + // substring from zero position with less-than-full length + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) + + // substring from zero position with full length + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) + + // substring from zero position with greater-than-full length + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), + "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), + "example", row) + + // substring from nonzero position with less-than-full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), + "xa", row) + + // substring from nonzero position with full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), + "xample", row) + + // substring from nonzero position with greater-than-full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), + "xample", row) + + // zero-length substring (within string bounds) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), + "", row) + + // zero-length substring (beyond string bounds) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + "", row) + + // substring(null, _, _) -> null + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + null, create_row(null)) + + // substring(_, null, _) -> null + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), + null, row) + + // substring(_, _, null) -> null + checkEvaluation( + Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), + null, + row) + + // 2-arg substring from zero position + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + + // 2-arg substring from nonzero position + checkEvaluation( + Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "xample", + row) + + val s_notNull = 'a.string.notNull.at(0) + + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === true) + assert( + Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === false) + assert(Substring(s_notNull, + Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, + Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) + + checkEvaluation(s.substr(0, 2), "ex", row) + checkEvaluation(s.substr(0), "example", row) + checkEvaluation(s.substring(0, 2), "ex", row) + checkEvaluation(s.substring(0), "example", row) + } + + test("LIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) + checkEvaluation("a_%b" like "a\\__b", true) + checkEvaluation("addb" like "a_%b", true) + checkEvaluation("addb" like "a\\__b", false) + checkEvaluation("addb" like "a%\\%b", false) + checkEvaluation("a_%b" like "a%\\%b", true) + checkEvaluation("addb" like "a%", true) + checkEvaluation("addb" like "**", false) + checkEvaluation("abc" like "a%", true) + checkEvaluation("abc" like "b%", false) + checkEvaluation("abc" like "bc%", false) + checkEvaluation("a\nb" like "a_b", true) + checkEvaluation("ab" like "a%b", true) + checkEvaluation("a\nb" like "a%b", true) + } + + test("LIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) + } + + test("RLIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike "abdef", true) + checkEvaluation("abbbbc" rlike "a.*c", true) + + checkEvaluation("fofo" rlike "^fo", true) + checkEvaluation("fo\no" rlike "^fo\no$", true) + checkEvaluation("Bn" rlike "^Ba*n", true) + checkEvaluation("afofo" rlike "fo", true) + checkEvaluation("afofo" rlike "^fo", false) + checkEvaluation("Baan" rlike "^Ba?n", false) + checkEvaluation("axe" rlike "pi|apa", false) + checkEvaluation("pip" rlike "^(pi)*$", false) + + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike "**") + } + } + + test("RLIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike regEx, create_row("**")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index a4a3a66b8b229..f33a18d53b1a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ * Overrides our expression evaluation tests and reruns them after optimization has occured. This * is to ensure that constant folding and other optimizations do not break anything. */ -class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { +class ExpressionOptimizationSuite extends SparkFunSuite with ExpressionEvalHelper { override def checkEvaluation( expression: Expression, expected: Any, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 77327f2b84eaa..454af47913bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,7 +24,6 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils From 72ba0fc4fd441f4bf25f19bed59ba0a39dd04b65 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 7 Jun 2015 23:16:19 -0700 Subject: [PATCH 231/254] [SPARK-8154][SQL] Remove Term/Code type aliases in code generation. From my perspective as a code reviewer, I find them more confusing than using String directly. Author: Reynold Xin Closes #6694 from rxin/SPARK-8154 and squashes the following commits: 4e5056c [Reynold Xin] [SPARK-8154][SQL] Remove Term/Code type aliases in code generation. --- .../catalyst/expressions/BoundAttribute.scala | 4 ++-- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../sql/catalyst/expressions/Expression.scala | 10 +++++----- .../sql/catalyst/expressions/arithmetic.scala | 16 +++++++-------- .../sql/catalyst/expressions/bitwise.scala | 10 ++++++++-- .../expressions/codegen/CodeGenerator.scala | 20 +++++++++---------- .../expressions/codegen/package.scala | 3 --- .../catalyst/expressions/conditionals.scala | 6 +++--- .../expressions/decimalFunctions.scala | 6 +++--- .../sql/catalyst/expressions/literals.scala | 4 ++-- .../spark/sql/catalyst/expressions/math.scala | 8 ++++---- .../catalyst/expressions/nullFunctions.scala | 10 +++++----- .../sql/catalyst/expressions/predicates.scala | 16 +++++++-------- .../spark/sql/catalyst/expressions/sets.scala | 8 ++++---- .../expressions/stringOperations.scala | 10 +++++----- 15 files changed, 69 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 005de3166095f..fcadf9595e768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = i.isNullAt($ordinal); ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2a1f96409daf4..18102d1acb5b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (evaluated == null) null else cast(evaluated) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { // TODO(cg): Add support for more data types. (child.dataType, dataType) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 432d65eee54fb..a9a9c0cfb7027 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -76,7 +76,7 @@ abstract class Expression extends TreeNode[Expression] { * @param ev an [[GeneratedExpressionCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ctx.references += this val objectTerm = ctx.freshName("obj") s""" @@ -166,7 +166,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express protected def defineCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: (Term, Term) => Code): String = { + f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (left.dataType != right.dataType) { // log.warn(s"${left.dataType} != ${right.dataType}") @@ -182,7 +182,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${eval2.code} - if(!${eval2.isNull}) { + if (!${eval2.isNull}) { ${ev.primitive} = $resultCode; } else { ${ev.isNull} = true; @@ -217,7 +217,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio protected def defineCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: Term => Code): Code = { + f: String => String): String = { val eval = child.gen(ctx) // reuse the previous isNull ev.isNull = eval.isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4efda2e04c29..124274c94203c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -50,7 +50,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") } @@ -74,7 +74,7 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { else math.sqrt(value) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; @@ -138,7 +138,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide @@ -236,7 +236,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic /** * Special case handling due to division by 0 => null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -296,7 +296,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet /** * Special case handling for x % 0 ==> null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -346,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -400,7 +400,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index ef34586261e70..9002dda7bf4d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.types._ /** * A function that calculates bitwise and(&) of two numbers. + * + * Code generation inherited from BinaryArithmetic. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" @@ -48,6 +50,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise or(|) of two numbers. + * + * Code generation inherited from BinaryArithmetic. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" @@ -71,6 +75,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet /** * A function that calculates bitwise xor of two numbers. + * + * Code generation inherited from BinaryArithmetic. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" @@ -112,8 +118,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } protected override def evalInternal(evalE: Any) = not(evalE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c8d0aaf79f5f2..e95682f952a7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * @param primitive A term for a possible primitive value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term) +case class GeneratedExpressionCode(var code: String, var isNull: String, var primitive: String) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported @@ -65,14 +65,14 @@ class CodeGenContext { * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(prefix: String): Term = { + def freshName(prefix: String): String = { s"$prefix${curId.getAndIncrement}" } /** * Return the code to access a column for given DataType */ - def getColumn(dataType: DataType, ordinal: Int): Code = { + def getColumn(dataType: DataType, ordinal: Int): String = { if (isNativeType(dataType)) { s"i.${accessorForType(dataType)}($ordinal)" } else { @@ -83,7 +83,7 @@ class CodeGenContext { /** * Return the code to update a column in Row for given DataType */ - def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = { + def setColumn(dataType: DataType, ordinal: Int, value: String): String = { if (isNativeType(dataType)) { s"${mutatorForType(dataType)}($ordinal, $value)" } else { @@ -94,7 +94,7 @@ class CodeGenContext { /** * Return the name of accessor in Row for a DataType */ - def accessorForType(dt: DataType): Term = dt match { + def accessorForType(dt: DataType): String = dt match { case IntegerType => "getInt" case other => s"get${boxedType(dt)}" } @@ -102,7 +102,7 @@ class CodeGenContext { /** * Return the name of mutator in Row for a DataType */ - def mutatorForType(dt: DataType): Term = dt match { + def mutatorForType(dt: DataType): String = dt match { case IntegerType => "setInt" case other => s"set${boxedType(dt)}" } @@ -110,7 +110,7 @@ class CodeGenContext { /** * Return the Java type for a DataType */ - def javaType(dt: DataType): Term = dt match { + def javaType(dt: DataType): String = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -131,7 +131,7 @@ class CodeGenContext { /** * Return the boxed type in Java */ - def boxedType(dt: DataType): Term = dt match { + def boxedType(dt: DataType): String = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -146,7 +146,7 @@ class CodeGenContext { /** * Return the representation of default value for given DataType */ - def defaultValue(dt: DataType): Term = dt match { + def defaultValue(dt: DataType): String = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "(short)-1" @@ -161,7 +161,7 @@ class CodeGenContext { /** * Returns a function to generate equal expression in Java */ - def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { + def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 6f9589d20445e..7f1b12cdd5800 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,9 +27,6 @@ import org.apache.spark.util.Utils */ package object codegen { - type Term = String - type Code = String - /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 3aa86edd7ab20..1a5cde26c9b13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -50,7 +50,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) @@ -155,7 +155,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { return res } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val len = branchesArr.length val got = ctx.freshName("got") @@ -248,7 +248,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW return res } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val keyEval = key.gen(ctx) val len = branchesArr.length val got = ctx.freshName("got") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index ddfadf314f838..8ab6d977dd3a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } @@ -59,7 +59,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3a9271678bc9c..297b35b4da94c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -88,7 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def eval(input: Row): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a18067e4a58f1..7dacb6a9b47b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -48,7 +48,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) // name of function in java.lang.Math def funcName: String = name.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; @@ -93,7 +93,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") } } @@ -180,7 +180,7 @@ case class Atan2(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; @@ -194,7 +194,7 @@ case class Hypot(left: Expression, right: Expression) case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 9ecfb3ccc262f..c2d1a4eadae29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType @@ -53,7 +53,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; @@ -81,7 +81,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.primitive = eval.isNull @@ -100,7 +100,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.primitive = s"(!(${eval.isNull}))" @@ -130,7 +130,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5edcf3bd77d20..3cbdfdfb13847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -84,7 +84,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } } @@ -147,7 +147,7 @@ case class And(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -155,7 +155,7 @@ case class And(left: Expression, right: Expression) s""" ${eval1.code} boolean ${ev.isNull} = false; - boolean ${ev.primitive} = false; + boolean ${ev.primitive} = false; if (!${eval1.isNull} && !${eval1.primitive}) { } else { @@ -196,7 +196,7 @@ case class Or(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -249,7 +249,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { left.dataType match { case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" @@ -280,7 +280,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -304,7 +304,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index b39349b988389..2bcb960e9177e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -61,7 +61,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { elementType match { case IntegerType | LongType => ev.isNull = "false" @@ -103,7 +103,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -154,7 +154,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 78adb509b470b..aae122a981e47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -139,7 +139,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def toString: String = s"Upper($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } } @@ -153,7 +153,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def toString: String = s"Lower($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } } @@ -190,7 +190,7 @@ trait StringComparison extends ExpectsInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } } @@ -201,7 +201,7 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } } @@ -212,7 +212,7 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } } From 10fc2f6f51819f263eec941bdc1db22c554f9118 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 8 Jun 2015 01:07:50 -0700 Subject: [PATCH 232/254] [SPARK-4761] [DOC] [SQL] kryo default setting in SQL Thrift server this is a follow up of #3621 /cc liancheng pwendell Author: Daoyuan Wang Closes #6639 from adrian-wang/kryodoc and squashes the following commits: 3c4b1cf [Daoyuan Wang] [DOC] kryo default setting in SQL Thrift server --- docs/configuration.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 9667cebe0b87c..3960e7e78bde1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -618,7 +618,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.referenceTracking - true + true (false when using Spark SQL Thrift Server) Whether to track references to the same object when serializing data with Kryo, which is necessary if your object graphs have loops and useful for efficiency if they contain multiple @@ -679,7 +679,10 @@ Apart from these, the following properties are also available, and may be useful spark.serializer - org.apache.spark.serializer.
    JavaSerializer + + org.apache.spark.serializer.
    JavaSerializer (org.apache.spark.serializer.
    + KryoSerializer when using Spark SQL Thrift Server) + Class to use for serializing objects that will be sent over the network or need to be cached in serialized form. The default of Java serialization works with any Serializable Java object From eacd4a929bf5d697c33b1b705dcf958651cd20f4 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Mon, 8 Jun 2015 09:34:16 +0100 Subject: [PATCH 233/254] [SPARK-7705] [YARN] Cleanup of .sparkStaging directory fails if application is killed As I have tested, if we cancel or kill the app then the final status may be undefined, killed or succeeded, so clean up staging directory when appMaster exit at any final application status. Author: linweizhong Closes #6409 from Sephiroth-Lin/SPARK-7705 and squashes the following commits: 3a5a0a5 [linweizhong] Update 83dc274 [linweizhong] Update 923d44d [linweizhong] Update 0dd7c2d [linweizhong] Update b76a102 [linweizhong] Update code style 7846b69 [linweizhong] Update bd6cf0d [linweizhong] Refactor aed9f18 [linweizhong] Clean up stagingDir when launch app on yarn 95595c3 [linweizhong] Cleanup of .sparkStaging directory when AppMaster exit at any final application status --- .../org/apache/spark/deploy/yarn/Client.scala | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 234051eb7d3bb..f4d43214b08ca 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -121,24 +121,31 @@ private[spark] class Client( } catch { case e: Throwable => if (appId != null) { - val appStagingDir = getAppStagingDir(appId) - try { - val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) - val stagingDirPath = new Path(appStagingDir) - val fs = FileSystem.get(hadoopConf) - if (!preserveFiles && fs.exists(stagingDirPath)) { - logInfo("Deleting staging directory " + stagingDirPath) - fs.delete(stagingDirPath, true) - } - } catch { - case ioe: IOException => - logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) - } + cleanupStagingDir(appId) } throw e } } + /** + * Cleanup application staging directory. + */ + private def cleanupStagingDir(appId: ApplicationId): Unit = { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } + } + /** * Set up the context for submitting our ApplicationMaster. * This uses the YarnClientApplication not available in the Yarn alpha API. @@ -782,6 +789,7 @@ private[spark] class Client( if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + cleanupStagingDir(appId) return (state, report.getFinalApplicationStatus) } From 03ef6be9ce61a13dcd9d8c71298fb4be39119411 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 8 Jun 2015 17:50:38 +0800 Subject: [PATCH 234/254] [SPARK-7939] [SQL] Add conf to enable/disable partition column type inference JIRA: https://issues.apache.org/jira/browse/SPARK-7939 Author: Liang-Chi Hsieh Closes #6503 from viirya/disable_partition_type_inference and squashes the following commits: 3e90470 [Liang-Chi Hsieh] Default to enable type inference and update docs. 455edb1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into disable_partition_type_inference 9a57933 [Liang-Chi Hsieh] Add conf to enable/disable partition column type inference. --- docs/sql-programming-guide.md | 6 +- .../scala/org/apache/spark/sql/SQLConf.scala | 6 ++ .../spark/sql/sources/PartitioningUtils.scala | 48 ++++++----- .../apache/spark/sql/sources/interfaces.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 79 ++++++++++++++++++- 5 files changed, 119 insertions(+), 24 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cde5830c733e0..40e33f757d693 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1102,7 +1102,11 @@ root {% endhighlight %} Notice that the data types of the partitioning columns are automatically inferred. Currently, -numeric data types and string type are supported. +numeric data types and string type are supported. Sometimes users may not want to automatically +infer the data types of the partitioning columns. For these use cases, the automatic type inference +can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to +`true`. When type inference is disabled, string type will be used for the partitioning columns. + ### Schema merging diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 77c6af27d1007..c778889045d02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -71,6 +71,9 @@ private[spark] object SQLConf { // Whether to perform partition discovery when loading external data sources. Default to true. val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled" + // Whether to perform partition column type inference. Default to true. + val PARTITION_COLUMN_TYPE_INFERENCE = "spark.sql.sources.partitionColumnTypeInference.enabled" + // The output committer class used by FSBasedRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass" @@ -250,6 +253,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def partitionDiscoveryEnabled() = getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean + private[spark] def partitionColumnTypeInferenceEnabled() = + getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE, "true").toBoolean + // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. private[spark] def schemaStringLengthThreshold: Int = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index c4c99de5a38dc..9f6ec2ed8fc8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -72,10 +72,11 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartitions( paths: Seq[Path], - defaultPartitionName: String): PartitionSpec = { + defaultPartitionName: String, + typeInference: Boolean): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. val pathsWithPartitionValues = paths.flatMap { path => - parsePartition(path, defaultPartitionName).map(path -> _) + parsePartition(path, defaultPartitionName, typeInference).map(path -> _) } if (pathsWithPartitionValues.isEmpty) { @@ -124,7 +125,8 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartition( path: Path, - defaultPartitionName: String): Option[PartitionValues] = { + defaultPartitionName: String, + typeInference: Boolean): Option[PartitionValues] = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null @@ -137,7 +139,7 @@ private[sql] object PartitioningUtils { return None } - val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) maybeColumn.foreach(columns += _) chopped = chopped.getParent finished = maybeColumn.isEmpty || chopped.getParent == null @@ -153,7 +155,8 @@ private[sql] object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - defaultPartitionName: String): Option[(String, Literal)] = { + defaultPartitionName: String, + typeInference: Boolean): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { None @@ -164,7 +167,7 @@ private[sql] object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference) Some(columnName -> literal) } } @@ -211,19 +214,28 @@ private[sql] object PartitioningUtils { */ private[sql] def inferPartitionColumnValue( raw: String, - defaultPartitionName: String): Literal = { - // First tries integral types - Try(Literal.create(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) - // Then falls back to fractional types - .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) - // Then falls back to string - .getOrElse { - if (raw == defaultPartitionName) Literal.create(null, NullType) - else Literal.create(unescapePathName(raw), StringType) + defaultPartitionName: String, + typeInference: Boolean): Literal = { + if (typeInference) { + // First tries integral types + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) Literal.create(null, NullType) + else Literal.create(unescapePathName(raw), StringType) + } + } else { + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) } + } } private val upCastingOrder: Seq[DataType] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 25887ba9a15b0..d1547fb1e4abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -491,9 +491,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } private def discoverPartitions(): PartitionSpec = { + val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() // We use leaf dirs containing data files to discover the schema. val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME) + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index d9a010a9815a1..c2f1cc8ffd1fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -48,7 +48,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("column type inference") { def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) } check("10", Literal.create(10, IntegerType)) @@ -60,12 +60,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - assert(expected === parsePartition(new Path(path), defaultPartitionName)) + assert(expected === parsePartition(new Path(path), defaultPartitionName, true)) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName).get + parsePartition(new Path(path), defaultPartitionName, true).get }.getMessage assert(message.contains(expected)) @@ -105,7 +105,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("parse partitions") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) } check(Seq( @@ -174,6 +174,77 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { PartitionSpec.emptySpec) } + test("parse partitions with type inference disabled") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq(Partition(Row("10", "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row("10.5", "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row("10.5", "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(Row("10.5", null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) + } + test("read partitioned table - normal case") { withTempDir { base => for { From a1d9e5cc60d317ecf8fe390b66b623ae39c4534d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 8 Jun 2015 15:37:28 +0100 Subject: [PATCH 235/254] [SPARK-8126] [BUILD] Use custom temp directory during build. Even with all the efforts to cleanup the temp directories created by unit tests, Spark leaves a lot of garbage in /tmp after a test run. This change overrides java.io.tmpdir to place those files under the build directory instead. After an sbt full unit test run, I was left with > 400 MB of temp files. Since they're now under the build dir, it's much easier to clean them up. Also make a slight change to a unit test to make it not pollute the source directory with test data. Author: Marcelo Vanzin Closes #6674 from vanzin/SPARK-8126 and squashes the following commits: 0f8ad41 [Marcelo Vanzin] Make sure tmp dir exists when tests run. 643e916 [Marcelo Vanzin] [MINOR] [BUILD] Use custom temp directory during build. --- .../spark/deploy/SparkSubmitUtilsSuite.scala | 22 +++++++++-------- pom.xml | 24 ++++++++++++++++++- project/SparkBuild.scala | 6 +++++ 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8fda5c8b472c9..07d261cc428c4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -28,9 +28,12 @@ import org.apache.ivy.plugins.resolver.IBiblioResolver import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.Utils class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { + private var tempIvyPath: String = _ + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } @@ -47,6 +50,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { super.beforeAll() // We don't want to write logs during testing SparkSubmitUtils.printStream = new BufferPrintStream + tempIvyPath = Utils.createTempDir(namePrefix = "ivy").getAbsolutePath() } test("incorrect maven coordinate throws error") { @@ -90,21 +94,20 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") - var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(tempIvyPath)) for (i <- 0 until 3) { - val index = jPaths.indexOf(ivyPath) + val index = jPaths.indexOf(tempIvyPath) assert(index >= 0) - jPaths = jPaths.substring(index + ivyPath.length) + jPaths = jPaths.substring(index + tempIvyPath.length) } val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") IvyTestUtils.withRepository(main, None, None) { repo => // end to end val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + Option(tempIvyPath), true) + assert(jarPath.indexOf(tempIvyPath) >= 0, "should use non-default ivy path") } } @@ -123,13 +126,12 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mylib") >= 0, "should find artifact") } // Local ivy repository with modified home - val dummyIvyPath = "dummy" + File.separator + "ivy" - val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) + val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator) IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal), true) { repo => val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(dummyIvyPath), true) + Some(tempIvyPath), true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") } } diff --git a/pom.xml b/pom.xml index 67b6375f576d3..5a5d183e3dcca 100644 --- a/pom.xml +++ b/pom.xml @@ -179,7 +179,7 @@ compile ${session.executionRootDirectory} @@ -1256,6 +1256,7 @@ test true + ${project.build.directory}/tmp ${spark.test.home} 1 false @@ -1289,6 +1290,7 @@ test true + ${project.build.directory}/tmp ${spark.test.home} 1 false @@ -1548,6 +1550,26 @@ + + + org.apache.maven.plugins + maven-antrun-plugin + + + create-tmp-dir + generate-test-resources + + run + + + + + + + + + + org.apache.maven.plugins diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ef3a175bac209..d7e374558c5e2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -51,6 +51,11 @@ object BuildCommons { // Root project. val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation + + val testTempDir = s"$sparkHome/target/tmp" + if (!new File(testTempDir).isDirectory()) { + require(new File(testTempDir).mkdirs()) + } } object SparkBuild extends PomBuild { @@ -496,6 +501,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), + javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", From e3e9c70384028cc0c322ccea14f19d3b6d6b39eb Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 8 Jun 2015 15:45:12 +0100 Subject: [PATCH 236/254] [SPARK-8140] [MLLIB] Remove empty model check in StreamingLinearAlgorithm 1. Prevent creating a map of data to find numFeatures 2. If model is empty, then initialize with a zero vector of numFeature Author: MechCoder Closes #6684 from MechCoder/spark-8140 and squashes the following commits: 7fbf5f9 [MechCoder] [SPARK-8140] Remove empty model check in StreamingLinearAlgorithm And other minor cosmits --- .../apache/spark/mllib/optimization/GradientDescent.scala | 2 +- .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 6 +++--- .../spark/mllib/regression/StreamingLinearAlgorithm.scala | 3 --- .../mllib/regression/StreamingLinearRegressionWithSGD.scala | 2 +- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 4b7d0589c973b..06e45e10c5bf4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -179,7 +179,7 @@ object GradientDescent extends Logging { * if it's L2 updater; for L1 updater, the same logic is followed. */ var regVal = updater.compute( - weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 for (i <- 1 to numIterations) { val bcWeights = data.context.broadcast(weights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 26be30ff9d6fd..6709bd79bc820 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ val initialWeights = { if (numOfLinearPredictor == 1) { - Vectors.dense(new Array[Double](numFeatures)) + Vectors.zeros(numFeatures) } else if (addIntercept) { - Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) } else { - Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + Vectors.zeros(numFeatures * numOfLinearPredictor) } } run(input, initialWeights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index cea8f3f47307b..39308e5ae1dde 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -87,9 +87,6 @@ abstract class StreamingLinearAlgorithm[ model match { case Some(m) => m.weights - case None => - val numFeatures = rdd.first().features.size - Vectors.dense(numFeatures) } model = Some(algorithm.run(rdd, initialWeights)) logInfo("Model updated at time %s".format(time.toString)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index a49153bf73c0d..235e043c7754b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( this } - /** Set the initial weights. Default: [0.0, 0.0]. */ + /** Set the initial weights. */ def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this From 149d1b28e899177ed170292fd2af30aad5a610e0 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Mon, 8 Jun 2015 16:23:43 +0100 Subject: [PATCH 237/254] [SMALL FIX] Return null if catch EOFException Return null if catch EOFException, just like function "asKeyValueIterator" in this class Author: Mingfei Closes #6703 from shimingfei/returnNull and squashes the following commits: 205deec [Mingfei] return null if catch EOFException --- core/src/main/scala/org/apache/spark/serializer/Serializer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f1bdff96d3df1..bd2704dc81871 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -182,6 +182,7 @@ abstract class DeserializationStream { } catch { case eof: EOFException => finished = true + null } } From 49f19b954b32d57d03ca0e25ea4205d01e794d48 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 8 Jun 2015 09:41:06 -0700 Subject: [PATCH 238/254] [MINOR] change new Exception to IllegalArgumentException Author: Daoyuan Wang Closes #6434 from adrian-wang/joinerr and squashes the following commits: ee1b64f [Daoyuan Wang] break line f7c53e9 [Daoyuan Wang] to IllegalArgumentException f8dea2d [Daoyuan Wang] sys.err to IllegalStateException be82259 [Daoyuan Wang] change new exception to sys.err --- .../apache/spark/sql/execution/joins/HashOuterJoin.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 45574392996ca..c21a453115292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -48,7 +48,8 @@ case class HashOuterJoin( case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } override def requiredChildDistribution: Seq[ClusteredDistribution] = @@ -63,7 +64,7 @@ case class HashOuterJoin( case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } @@ -216,7 +217,8 @@ case class HashOuterJoin( rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } } From ed5c2dccd0397c4c4b0008c437e6845dd583c9c2 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 8 Jun 2015 11:06:27 -0700 Subject: [PATCH 239/254] [SPARK-8158] [SQL] several fix for HiveShim 1. explicitly import implicit conversion support. 2. use .nonEmpty instead of .size > 0 3. use val instead of var 4. comment indention Author: Daoyuan Wang Closes #6700 from adrian-wang/shimsimprove and squashes the following commits: d22e108 [Daoyuan Wang] several fix for HiveShim --- .../org/apache/spark/sql/hive/HiveShim.scala | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index fa5409f602444..d08c594151654 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,6 +20,11 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID +/* Implicit conversions */ +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.hadoop.conf.Configuration @@ -35,10 +40,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.types.Decimal import org.apache.spark.util.Utils -/* Implicit conversions */ -import scala.collection.JavaConversions._ -import scala.reflect.ClassTag - private[hive] object HiveShim { // Precision and scale to pass for unlimited decimals; these are the same as the precision and // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) @@ -68,10 +69,10 @@ private[hive] object HiveShim { * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { + if (ids != null && ids.nonEmpty) { ColumnProjectionUtils.appendReadColumns(conf, ids) } - if (names != null && names.size > 0) { + if (names != null && names.nonEmpty) { appendReadColumnNames(conf, names) } } @@ -197,11 +198,11 @@ private[hive] object HiveShim { } /* - * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) f.setCompressCodec(w.compressCodec) f.setCompressType(w.compressType) f.setTableInfo(w.tableInfo) From bbdfc0a40fb39760c122e7b9ce80aa1e340e55ee Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 8 Jun 2015 11:34:18 -0700 Subject: [PATCH 240/254] [SPARK-8121] [SQL] Fixes InsertIntoHadoopFsRelation job initialization for Hadoop 1.x For Hadoop 1.x, `TaskAttemptContext` constructor clones the `Configuration` argument, thus configurations done in `HadoopFsRelation.prepareForWriteJob()` are not populated to *driver* side `TaskAttemptContext` (executor side configurations are properly populated). Currently this should only affect Parquet output committer class configuration. Author: Cheng Lian Closes #6669 from liancheng/spark-8121 and squashes the following commits: 73819e8 [Cheng Lian] Minor logging fix fce089c [Cheng Lian] Adds more logging b6f78a6 [Cheng Lian] Fixes compilation error introduced while rebasing 963a1aa [Cheng Lian] Addresses @yhuai's comment c3a0b1a [Cheng Lian] Fixes InsertIntoHadoopFsRelation job initialization --- .../scala/org/apache/spark/sql/SQLConf.scala | 1 + .../apache/spark/sql/parquet/newParquet.scala | 7 +++ .../apache/spark/sql/sources/commands.scala | 18 +++++-- .../spark/sql/parquet/ParquetIOSuite.scala | 52 ++++++++++++++++--- 4 files changed, 65 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index c778889045d02..be786f9b7f49e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -76,6 +76,7 @@ private[spark] object SQLConf { // The output committer class used by FSBasedRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. + // NOTE: This property should be set in Hadoop `Configuration` rather than Spark `SQLConf` val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass" // Whether to perform eager analysis when constructing a dataframe. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 5dda440240e60..7af4eb1ca4716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -212,6 +212,13 @@ private[sql] class ParquetRelation2( classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + if (conf.get("spark.sql.parquet.output.committer.class") == null) { + logInfo("Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + conf.setClass( SQLConf.OUTPUT_COMMITTER_CLASS, committerClass, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index bd3aad6631748..c94199bfcd233 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -297,12 +297,16 @@ private[sql] abstract class BaseWriterContainer( def driverSideSetup(): Unit = { setupIDs(0, 0, 0) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - // This preparation must happen before initializing output format and output committer, since - // their initialization involves the job configuration, which can be potentially decorated in - // `relation.prepareJobForWrite`. + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor + // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, + // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. + // + // Also, the `prepareJobForWrite` call must happen before initializing output format and output + // committer, since their initialization involve the job configuration, which can be potentially + // decorated in `prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) @@ -331,6 +335,8 @@ private[sql] abstract class BaseWriterContainer( SQLConf.OUTPUT_COMMITTER_CLASS, null, classOf[OutputCommitter]) Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat // has an associated output committer. To override this output committer, // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. @@ -350,7 +356,9 @@ private[sql] abstract class BaseWriterContainer( }.getOrElse { // If output committer class is not set, we will use the one associated with the // file output format. - outputFormatClass.newInstance().getOutputCommitter(context) + val outputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + logInfo(s"Using output committer class ${outputCommitter.getClass.getCanonicalName}") + outputCommitter } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 2b6a27032e637..46b25859d9a68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -23,16 +23,18 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.scalatest.BeforeAndAfterAll +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} -import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} +import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.util.DateUtils @@ -196,7 +198,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNulls :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) } } @@ -209,7 +211,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNones :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows.head === Row(Seq.fill(3)(null): _*)) } } @@ -379,6 +381,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("SPARK-6352 DirectParquetOutputCommitter") { + val clonedConf = new Configuration(configuration) + // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { @@ -393,14 +397,46 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val fs = path.getFileSystem(configuration) assert(!fs.exists(path)) } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } - finally { - configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.parquet.hadoop.ParquetOutputCommitter") + } + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + withTempPath { dir => + val clonedConf = new Configuration(configuration) + + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS, classOf[ParquetOutputCommitter].getCanonicalName) + + configuration.set( + "spark.sql.parquet.output.committer.class", + classOf[BogusParquetOutputCommitter].getCanonicalName) + + try { + val message = intercept[SparkException] { + sqlContext.range(0, 1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(message === "Intentional exception for testing purposes") + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } } } } +class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} + class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi From fe7669d3072b72954ad0c3f2f8846a0fde839ead Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 8 Jun 2015 11:52:02 -0700 Subject: [PATCH 241/254] [SQL][minor] remove duplicated cases in `DecimalPrecision` We already have a rule to do type coercion for fixed decimal and unlimited decimal in `WidenTypes`, so we don't need to handle them in `DecimalPrecision`. Author: Wenchen Fan Closes #6698 from cloud-fan/fix and squashes the following commits: 413ad4a [Wenchen Fan] remove duplicated cases --- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a42ffce0d26fa..737905c3582ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -445,12 +445,6 @@ trait HiveTypeCoercion { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = DecimalType(max(p1, p2), max(s1, s2)) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2) - if e2.dataType == DecimalType.Unlimited => - b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2)) - case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _)) - if e1.dataType == DecimalType.Unlimited => - b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles From 51853891686f353dc9decc31066b0de01ed8b49e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 8 Jun 2015 13:15:44 -0700 Subject: [PATCH 242/254] [SPARK-8148] Do not use FloatType in partition column inference. Use DoubleType instead to be more stable and robust. Author: Reynold Xin Closes #6692 from rxin/SPARK-8148 and squashes the following commits: 6742ecc [Reynold Xin] [SPARK-8148] Do not use FloatType in partition column inference. --- .../spark/sql/sources/PartitioningUtils.scala | 16 +++++++++------- .../parquet/ParquetPartitionDiscoverySuite.scala | 12 ++++++------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index 9f6ec2ed8fc8d..7a2b5b949dd4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.lang.{Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import scala.collection.mutable.ArrayBuffer @@ -178,7 +178,7 @@ private[sql] object PartitioningUtils { * {{{ * NullType -> * IntegerType -> LongType -> - * FloatType -> DoubleType -> DecimalType.Unlimited -> + * DoubleType -> DecimalType.Unlimited -> * StringType * }}} */ @@ -208,8 +208,8 @@ private[sql] object PartitioningUtils { } /** - * Converts a string to a `Literal` with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * Converts a string to a [[Literal]] with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and * [[StringType]]. */ private[sql] def inferPartitionColumnValue( @@ -221,13 +221,15 @@ private[sql] object PartitioningUtils { Try(Literal.create(Integer.parseInt(raw), IntegerType)) .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types - .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) Literal.create(null, NullType) - else Literal.create(unescapePathName(raw), StringType) + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) + } } } else { if (raw == defaultPartitionName) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index c2f1cc8ffd1fb..3240079483545 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -53,7 +53,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { check("10", Literal.create(10, IntegerType)) check("1000000000000000", Literal.create(1000000000000000L, LongType)) - check("1.5", Literal.create(1.5f, FloatType)) + check("1.5", Literal.create(1.5, DoubleType)) check("hello", Literal.create("hello", StringType)) check(defaultPartitionName, Literal.create(null, NullType)) } @@ -83,13 +83,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { ArrayBuffer( Literal.create(10, IntegerType), Literal.create("hello", StringType), - Literal.create(1.5f, FloatType))) + Literal.create(1.5, DoubleType))) }) check("file://path/a=10/b_hello/c=1.5", Some { PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal.create(1.5f, FloatType))) + ArrayBuffer(Literal.create(1.5, DoubleType))) }) check("file:///", None) @@ -121,7 +121,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "hdfs://host:9000/path/a=10.5/b=hello"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), @@ -140,7 +140,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), @@ -162,7 +162,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), From f3eec92ce7e13cc461d2f0404f26730259210f12 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 8 Jun 2015 18:09:21 -0700 Subject: [PATCH 243/254] [SPARK-8162] [HOTFIX] Fix NPE in spark-shell This was caused by this commit: f271347 This patch does not attempt to fix the root cause of why the `VisibleForTesting` annotation causes a NPE in the shell. We should find a way to fix that separately. Author: Andrew Or Closes #6711 from andrewor14/fix-spark-shell and squashes the following commits: bf62ecc [Andrew Or] Prevent NPE in spark-shell --- .../scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 730f9806e518e..0c854f04890b6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -539,11 +539,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { /** * For testing only. Wait until at least `numExecutors` executors are up, or throw * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. + * Exposed for testing. * * @param numExecutors the number of executors to wait at least * @param timeout time to wait in milliseconds */ - @VisibleForTesting private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = { val finishTime = System.currentTimeMillis() + timeout while (System.currentTimeMillis() < finishTime) { From 82870d507dfaeeaf315d6766ca1496205c6216d3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 8 Jun 2015 21:33:47 -0700 Subject: [PATCH 244/254] [SPARK-8168] [MLLIB] Add Python friendly constructor to PipelineModel This makes the constructor callable in Python. dbtsai Author: Xiangrui Meng Closes #6709 from mengxr/SPARK-8168 and squashes the following commits: f871de4 [Xiangrui Meng] Add Python friendly constructor to PipelineModel --- .../scala/org/apache/spark/ml/Pipeline.scala | 8 ++++++++ .../org/apache/spark/ml/PipelineSuite.scala | 17 +++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 11a4722722ea1..a9bd28df71ee1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -17,6 +17,9 @@ package org.apache.spark.ml +import java.{util => ju} + +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging @@ -175,6 +178,11 @@ class PipelineModel private[ml] ( val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { + /** A Java/Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, stages: ju.List[Transformer]) = { + this(uid, stages.asScala.toArray) + } + override def validateParams(): Unit = { super.validateParams() stages.foreach(_.validateParams()) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 05bf58e63abaf..29394fefcbc43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml +import scala.collection.JavaConverters._ + import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock @@ -81,4 +83,19 @@ class PipelineSuite extends SparkFunSuite { pipeline.fit(dataset) } } + + test("pipeline model constructors") { + val transform0 = mock[Transformer] + val model1 = mock[MyModel] + + val stages = Array(transform0, model1) + val pipelineModel0 = new PipelineModel("pipeline0", stages) + assert(pipelineModel0.uid === "pipeline0") + assert(pipelineModel0.stages === stages) + + val stagesAsList = stages.toList.asJava + val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList) + assert(pipelineModel1.uid === "pipeline1") + assert(pipelineModel1.stages === stages) + } } From a5c52c1a3488b69bec19e460d2d1fdb0c9ada58d Mon Sep 17 00:00:00 2001 From: hqzizania Date: Mon, 8 Jun 2015 21:40:12 -0700 Subject: [PATCH 245/254] [SPARK-6820] [SPARKR] Convert NAs to null type in SparkR DataFrames Author: hqzizania Closes #6190 from hqzizania/R and squashes the following commits: 1641f9e [hqzizania] fixes and add test units bb3411a [hqzizania] Convert NAs to null type in SparkR DataFrames --- R/pkg/R/serialize.R | 8 +++++++ R/pkg/inst/tests/test_sparkSQL.R | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 2081786e6f833..3169d7968f8fe 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -37,6 +37,14 @@ writeObject <- function(con, object, writeType = TRUE) { # passing in vectors as arrays and instead require arrays to be passed # as lists. type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # lists and pairlists + if (type %in% c("integer", "character", "logical", "double", "numeric")) { + if (is.na(object)) { + object <- NULL + type <- "NULL" + } + } if (writeType) { writeType(con, type) } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 30edfc8a7bd94..8946348ef801c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -101,6 +101,43 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) From 7658eb28a2ea28c06e3b5a26f7734a7dc36edc19 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 8 Jun 2015 23:27:05 -0700 Subject: [PATCH 246/254] [SPARK-7990][SQL] Add methods to facilitate equi-join on multiple joining keys JIRA: https://issues.apache.org/jira/browse/SPARK-7990 Author: Liang-Chi Hsieh Closes #6616 from viirya/multi_keys_equi_join and squashes the following commits: cd5c888 [Liang-Chi Hsieh] Import reduce in python3. c43722c [Liang-Chi Hsieh] For comments. 0400e89 [Liang-Chi Hsieh] Fix scala style. cc90015 [Liang-Chi Hsieh] Add methods to facilitate equi-join on multiple joining keys. --- python/pyspark/sql/dataframe.py | 45 +++++++++++++------ .../org/apache/spark/sql/DataFrame.scala | 40 ++++++++++++++--- .../apache/spark/sql/DataFrameJoinSuite.scala | 9 ++++ 3 files changed, 75 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d8c59518b35a..e9dd05e2d0c7a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -22,6 +22,7 @@ if sys.version >= '3': basestring = unicode = str long = int + from functools import reduce else: from itertools import imap as map @@ -503,36 +504,52 @@ def alias(self, alias): @ignore_unicode_prefix @since(1.3) - def join(self, other, joinExprs=None, joinType=None): + def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join - :param joinExprs: a string for join column name, or a join expression (Column). - If joinExprs is a string indicating the name of the join column, - the column must exist on both sides, and this performs an inner equi-join. - :param joinType: str, default 'inner'. + :param on: a string for join column name, a list of column names, + , a join expression (Column) or a list of Columns. + If `on` is a string or a list of string indicating the name of the join column(s), + the column(s) must exist on both sides, and this performs an inner equi-join. + :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + >>> cond = [df.name == df3.name, df.age == df3.age] + >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() + [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] + + >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() + [Row(name=u'Bob', age=5)] """ - if joinExprs is None: + if on is not None and not isinstance(on, list): + on = [on] + + if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - elif isinstance(joinExprs, basestring): - jdf = self._jdf.join(other._jdf, joinExprs) + + if isinstance(on[0], basestring): + jdf = self._jdf.join(other._jdf, self._jseq(on)) else: - assert isinstance(joinExprs, Column), "joinExprs should be Column" - if joinType is None: - jdf = self._jdf.join(other._jdf, joinExprs._jc) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + if how is None: + jdf = self._jdf.join(other._jdf, on._jc, "inner") else: - assert isinstance(joinType, basestring), "joinType should be basestring" - jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -1315,6 +1332,8 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2), + Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4a224153e1a37..59f64dd4bc648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -395,22 +395,50 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumn: String): DataFrame = { + join(right, Seq(usingColumn)) + } + + /** + * Inner equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the columns "user_id" and "user_name" + * df1.join(df2, Seq("user_id", "user_name")) + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @group dfops + * @since 1.4.0 + */ + def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] - // Project only one of the join column. - val joinedCol = joined.right.resolve(usingColumn) + // Project only one of the join columns. + val joinedCols = usingColumns.map(col => joined.right.resolve(col)) + val condition = usingColumns.map { col => + catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col)) + }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => + catalyst.expressions.And(cond, eqTo) + } + Project( - joined.output.filterNot(_ == joinedCol), + joined.output.filterNot(joinedCols.contains(_)), Join( joined.left, joined.right, joinType = Inner, - Some(catalyst.expressions.EqualTo( - joined.left.resolve(usingColumn), - joined.right.resolve(usingColumn)))) + condition) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 051d13e9a544f..6165764632c29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -34,6 +34,15 @@ class DataFrameJoinSuite extends QueryTest { Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) } + test("join - join using multiple columns") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "int2")), + Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") From 0902a11940e550e85a53e110b490fe90e16ddaf4 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 9 Jun 2015 08:00:04 +0100 Subject: [PATCH 247/254] [SPARK-8101] [CORE] Upgrade netty to avoid memory leak accord to netty #3837 issues Update to Netty 4.0.28-Final Author: Sean Owen Closes #6701 from srowen/SPARK-8101 and squashes the following commits: f3b6369 [Sean Owen] Update to Netty 4.0.28-Final --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 5a5d183e3dcca..e9700a5d7b149 100644 --- a/pom.xml +++ b/pom.xml @@ -587,7 +587,7 @@ io.netty netty-all - 4.0.23.Final + 4.0.28.Final org.apache.derby From 1b499993ad185b04dd5065facb565cbe7e249521 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Jun 2015 16:24:38 +0800 Subject: [PATCH 248/254] [SPARK-7886] Add built-in expressions to FunctionRegistry. This patch switches to using FunctionRegistry for built-in expressions. It is based on #6463, but with some work to simplify it along with unit tests. TODOs for future pull requests: - Use static registration so we don't need to register all functions every time we start a new SQLContext - Switch to using this in HiveContext Author: Reynold Xin Author: Santiago M. Mola Closes #6710 from rxin/udf-registry and squashes the following commits: 6930822 [Reynold Xin] Fixed Python test. b802c9a [Reynold Xin] Made UDF case insensitive. e60d815 [Reynold Xin] Made UDF case insensitive. 852f9c0 [Reynold Xin] Fixed style violation. e76a3c1 [Reynold Xin] Fixed parser. 52ddaba [Reynold Xin] Fixed compilation. ee7854f [Reynold Xin] Improved error reporting. ff906f2 [Reynold Xin] More robust constructor calling. 77b46f1 [Reynold Xin] Simplified the code. 2a2a149 [Reynold Xin] Merge pull request #6463 from smola/SPARK-7886 8616924 [Santiago M. Mola] [SPARK-7886] Add built-in expressions to FunctionRegistry. --- python/pyspark/sql/dataframe.py | 2 +- .../apache/spark/sql/catalyst/SqlParser.scala | 75 +++++------ .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/FunctionRegistry.scala | 127 +++++++++++++----- .../sql/catalyst/expressions/Expression.scala | 9 ++ .../sql/catalyst/expressions/random.scala | 23 +++- .../expressions/stringOperations.scala | 7 + .../sql/catalyst/util/StringKeyHashMap.scala | 44 ++++++ .../org/apache/spark/sql/SQLContext.scala | 6 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 42 ++++++ .../apache/spark/sql/hive/HiveContext.scala | 9 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 14 +- 12 files changed, 269 insertions(+), 93 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e9dd05e2d0c7a..9615e576497cd 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -746,7 +746,7 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index e85312aee7d16..f74c17d583359 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -48,26 +49,21 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AS = Keyword("AS") protected val ASC = Keyword("ASC") - protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") - protected val COALESCE = Keyword("COALESCE") - protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") - protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") @@ -80,13 +76,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") - protected val LAST = Keyword("LAST") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") - protected val LOWER = Keyword("LOWER") - protected val MAX = Keyword("MAX") - protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") @@ -100,15 +92,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val RLIKE = Keyword("RLIKE") protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") - protected val SQRT = Keyword("SQRT") - protected val SUBSTR = Keyword("SUBSTR") - protected val SUBSTRING = Keyword("SUBSTRING") - protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") - protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") protected val WITH = Keyword("WITH") @@ -277,25 +264,36 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val function: Parser[Expression] = - ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } - | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } - | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } - | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } - | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^ - { case exps => CountDistinct(exps) } - | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ - { case exp => ApproxCountDistinct(exp) } - | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } - | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } - | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } - | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } - | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } - | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } - | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } - | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } - | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case c ~ t ~ f => If(c, t, f) } + ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => + if (lexical.normalizeKeyword(udfName) == "count") { + Count(Literal(1)) + } else { + throw new AnalysisException(s"invalid expression $udfName(*)") + } + } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => + lexical.normalizeKeyword(udfName) match { + case "sum" => SumDistinct(exprs.head) + case "count" => CountDistinct(exprs) + } + } + | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp) + } else { + throw new AnalysisException(s"invalid function approximate $udfName") + } + } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp, s.toDouble) + } else { + throw new AnalysisException(s"invalid function approximate($floatLit) $udfName") + } + } | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { case casePart ~ altPart ~ elsePart => @@ -304,16 +302,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } ++ elsePart casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p ~ l => Substring(s, p, l) } - | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } - | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } - | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } - | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } - ) + ) protected lazy val cast: Parser[Expression] = CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5883d938b676d..02b10c444d1a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -461,7 +461,9 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u @ UnresolvedFunction(name, children) if u.childrenResolved => - registry.lookupFunction(name, children) + withPosition(u) { + registry.lookupFunction(name, children) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0849faa9bfa7b..406f6fad8413b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,24 +17,27 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.Expression -import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.StringKeyHashMap + /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { - type FunctionBuilder = Seq[Expression] => Expression def registerFunction(name: String, builder: FunctionBuilder): Unit + @throws[AnalysisException]("If function does not exist") def lookupFunction(name: String, children: Seq[Expression]): Expression - - def conf: CatalystConf } trait OverrideFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) @@ -45,16 +48,19 @@ trait OverrideFunctionRegistry extends FunctionRegistry { } } -class SimpleFunctionRegistry(val conf: CatalystConf) extends FunctionRegistry { +class SimpleFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders(name)(children) + val func = functionBuilders.get(name).getOrElse { + throw new AnalysisException(s"undefined function $name") + } + func(children) } } @@ -70,30 +76,89 @@ object EmptyFunctionRegistry extends FunctionRegistry { override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - - override def conf: CatalystConf = throw new UnsupportedOperationException } -/** - * Build a map with String type of key, and it also supports either key case - * sensitive or insensitive. - * TODO move this into util folder? - */ -object StringKeyHashMap { - def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { - case false => new StringKeyHashMap[T](_.toLowerCase) - case true => new StringKeyHashMap[T](identity) - } -} -class StringKeyHashMap[T](normalizer: (String) => String) { - private val base = new collection.mutable.HashMap[String, T]() +object FunctionRegistry { - def apply(key: String): T = base(normalizer(key)) + type FunctionBuilder = Seq[Expression] => Expression - def get(key: String): Option[T] = base.get(normalizer(key)) - def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) - def remove(key: String): Option[T] = base.remove(normalizer(key)) - def iterator: Iterator[(String, T)] = base.toIterator + val expressions: Map[String, FunctionBuilder] = Map( + // Non aggregate functions + expression[Abs]("abs"), + expression[CreateArray]("array"), + expression[Coalesce]("coalesce"), + expression[Explode]("explode"), + expression[Lower]("lower"), + expression[Substring]("substr"), + expression[Substring]("substring"), + expression[Rand]("rand"), + expression[Randn]("randn"), + expression[CreateStruct]("struct"), + expression[Sqrt]("sqrt"), + expression[Upper]("upper"), + + // Math functions + expression[Acos]("acos"), + expression[Asin]("asin"), + expression[Atan]("atan"), + expression[Atan2]("atan2"), + expression[Cbrt]("cbrt"), + expression[Ceil]("ceil"), + expression[Cos]("cos"), + expression[Exp]("exp"), + expression[Expm1]("expm1"), + expression[Floor]("floor"), + expression[Hypot]("hypot"), + expression[Log]("log"), + expression[Log10]("log10"), + expression[Log1p]("log1p"), + expression[Pow]("pow"), + expression[Rint]("rint"), + expression[Signum]("signum"), + expression[Sin]("sin"), + expression[Sinh]("sinh"), + expression[Tan]("tan"), + expression[Tanh]("tanh"), + expression[ToDegrees]("todegrees"), + expression[ToRadians]("toradians"), + + // aggregate functions + expression[Average]("avg"), + expression[Count]("count"), + expression[First]("first"), + expression[Last]("last"), + expression[Max]("max"), + expression[Min]("min"), + expression[Sum]("sum") + ) + + /** See usage above. */ + private def expression[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { + // Use the companion class to find apply methods. + val objectClass = Class.forName(tag.runtimeClass.getName + "$") + val companionObj = objectClass.getDeclaredField("MODULE$").get(null) + + // See if we can find an apply that accepts Seq[Expression] + val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption + + val builder = (expressions: Seq[Expression]) => { + if (varargApply.isDefined) { + // If there is an apply method that accepts Seq[Expression], use that one. + varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression] + } else { + // Otherwise, find an apply method that matches the number of arguments, and use that. + val params = Seq.fill(expressions.size)(classOf[Expression]) + val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match { + case Success(e) => + e + case Failure(e) => + throw new AnalysisException(s"Invalid number of arguments for function $name") + } + f.invoke(companionObj, expressions : _*).asInstanceOf[Expression] + } + } + (name, builder) + } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a9a9c0cfb7027..f2ed1f0929987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -23,6 +23,15 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ + +/** + * For Catalyst to work correctly, concrete implementations of [[Expression]]s must be case classes + * whose constructor arguments are all Expressions types. In addition, if we want to support more + * than one constructor, define those constructors explicitly as apply methods in the companion + * object. + * + * See [[Substring]] for an example. + */ abstract class Expression extends TreeNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index b2647124c4e49..6e4e9cb1be090 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -46,11 +47,29 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ -case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) { +case class Rand(seed: Long) extends RDG(seed) { override def eval(input: Row): Double = rng.nextDouble() } +object Rand { + def apply(): Rand = apply(Utils.random.nextLong()) + + def apply(seed: Expression): Rand = apply(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) +} + /** Generate a random column with i.i.d. gaussian random distribution. */ -case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) { +case class Randn(seed: Long) extends RDG(seed) { override def eval(input: Row): Double = rng.nextGaussian() } + +object Randn { + def apply(): Randn = apply(Utils.random.nextLong()) + + def apply(seed: Expression): Randn = apply(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index aae122a981e47..856f56488c7a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -227,6 +227,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") @@ -287,3 +288,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) case _ => s"SUBSTR($str, $pos, $len)" } } + +object Substring { + def apply(str: Expression, pos: Expression): Substring = { + apply(str, pos, Literal(Integer.MAX_VALUE)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala new file mode 100644 index 0000000000000..191d5e6399fc9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + + def remove(key: String): Option[T] = base.remove(normalizer(key)) + + def iterator: Iterator[(String, T)] = base.toIterator +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ddb54025baa24..8cad3885b7d46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -120,7 +120,11 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf) + protected[sql] lazy val functionRegistry: FunctionRegistry = { + val fr = new SimpleFunctionRegistry + FunctionRegistry.expressions.foreach { case (name, func) => fr.registerFunction(name, func) } + fr + } @transient protected[sql] lazy val analyzer: Analyzer = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 064c040d2b771..703a34c47ec20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -25,6 +25,48 @@ class UDFSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + test("built-in fixed arity expressions") { + val df = ctx.emptyDataFrame + df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") + } + + test("built-in vararg expressions") { + val df = Seq((1, 2)).toDF("a", "b") + df.selectExpr("array(a, b)") + df.selectExpr("struct(a, b)") + } + + test("built-in expressions with multiple constructors") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect() + } + + test("count") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(a)") + } + + test("count distinct") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(distinct a)") + } + + test("error reporting for incorrect number of arguments") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("substr('abcd', 2, 3, 4)") + } + assert(e.getMessage.contains("arguments")) + } + + test("error reporting for undefined functions") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("a_function_that_does_not_exist()") + } + assert(e.getMessage.contains("undefined function")) + } + test("Simple UDF") { ctx.udf.register("strLenScala", (_: String).length) assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b8f294c262af7..3b8cafb4a6c37 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -39,13 +39,12 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.sources.DataSourceStrategy -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -374,10 +373,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry { - override def conf: CatalystConf = currentSession().conf - } + override protected[sql] lazy val functionRegistry: FunctionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry /* An analyzer that uses the Hive metastore. */ @transient diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 01f47352b2313..6e6ac987b668a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.spark.sql.AnalysisException - import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions @@ -30,8 +27,11 @@ import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ @@ -40,20 +40,18 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ -/* Implicit conversions */ -import scala.collection.JavaConversions._ private[hive] abstract class HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $name")) + throw new AnalysisException(s"undefined function $name")) val functionClassName = functionInfo.getFunctionClass.getName From e6fb6cedf3ecbde6f01d4753d7d05d0c52827fce Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 9 Jun 2015 12:19:01 +0100 Subject: [PATCH 249/254] [STREAMING] [DOC] Remove duplicated description about WAL I noticed there is a duplicated description about WAL. ``` To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). ``` Let's remove the duplication. I don't file this issue in JIRA because it's minor. Author: Kousuke Saruta Closes #6719 from sarutak/remove-multiple-description and squashes the following commits: cc9bb21 [Kousuke Saruta] Removed duplicated description about WAL --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index d6d5605948a5a..998c8c994e4b4 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -7,7 +7,7 @@ title: Spark Streaming + Kafka Integration Guide ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. Next, we discuss how to use this approach in your streaming application. From 6c1723abeb4e0580efec05a655343f46521fc265 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 9 Jun 2015 15:00:35 +0100 Subject: [PATCH 250/254] [SPARK-8140] [MLLIB] Remove construct to get weights in StreamingLinearAlgorithm Author: MechCoder Closes #6720 from MechCoder/empty_model_check and squashes the following commits: 3a07de5 [MechCoder] Remove construct to get weights in StreamingLinearAlgorithm --- .../spark/mllib/regression/StreamingLinearAlgorithm.scala | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 39308e5ae1dde..aee51bf22d8d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -83,12 +83,7 @@ abstract class StreamingLinearAlgorithm[ throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - val initialWeights = - model match { - case Some(m) => - m.weights - } - model = Some(algorithm.run(rdd, initialWeights)) + model = Some(algorithm.run(rdd, model.get.weights)) logInfo("Model updated at time %s".format(time.toString)) val display = model.get.weights.size match { case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") From 490d5a72ec1e5105f030fd7110acf62534e05f5a Mon Sep 17 00:00:00 2001 From: FavioVazquez Date: Tue, 9 Jun 2015 15:02:18 +0100 Subject: [PATCH 251/254] [SPARK-8274] [DOCUMENTATION-MLLIB] Fix wrong URLs in MLlib Frequent Pattern Mining Documentation There is a mistake in the URLs of the Scala section of FP-Growth in the MLlib Frequent Pattern Mining documentation. The URL points to https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/fpm/FPGrowth.html which is the Java's API, the link should point to the Scala API https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth There's another mistake in the FP-GrowthModel in the same section, the link points, again, to the Java's API https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html, the link should point to the Scala API https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel Author: FavioVazquez Closes #6722 from FavioVazquez/fix-wrog-urls-mllib-fpgrowth and squashes the following commits: e1ca54d [FavioVazquez] - Fixed wrong URLs in MLlib Frequent Pattern Mining, FP-Growth Scala section ad882a3 [FavioVazquez] Merge remote-tracking branch 'upstream/master' f27a20b [FavioVazquez] Merge remote-tracking branch 'upstream/master' 9af7074 [FavioVazquez] Merge remote-tracking branch 'upstream/master' edab1ef [FavioVazquez] Merge remote-tracking branch 'upstream/master' b2e2f8c [FavioVazquez] Merge remote-tracking branch 'upstream/master' --- docs/mllib-frequent-pattern-mining.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 9fd9be0dd01b1..bcc066a185526 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -39,11 +39,11 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters:
    -[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an -[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. {% highlight scala %} From 0d5892dc723d203e7d892d3beacbaa97aedb1a24 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 9 Jun 2015 15:44:02 -0700 Subject: [PATCH 252/254] [MINOR] [UI] DAG visualization: trim whitespace from input Just as a safeguard against DOM rewriting. Author: Andrew Or Closes #6732 from andrewor14/dag-viz-trim and squashes the following commits: 7e9bacb [Andrew Or] [MINOR] [UI] DAG visualization: trim whitespace from input --- .../main/resources/org/apache/spark/ui/static/spark-dag-viz.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index e96af8768daa0..7a0dec2a3eaec 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -235,7 +235,7 @@ function renderDagVizForJob(svgContainer) { // them separately later. Note that we cannot draw them now because we need to // put these edges in a separate container that is on top of all stage graphs. metadata.selectAll(".incoming-edge").each(function(v) { - var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4] + var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4] crossStageEdges.push(edge); }); }); From 6e4fb0c9e8f03cf068c422777cfce82a89e8e738 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 9 Jun 2015 16:14:21 -0700 Subject: [PATCH 253/254] [SPARK-6511] [DOCUMENTATION] Explain how to use Hadoop provided builds This provides preliminary documentation pointing out how to use the Hadoop free builds. I am hoping over time this list can grow to include most of the popular Hadoop distributions. Getting more people using these builds will help us long term reduce the number of binaries we build. Author: Patrick Wendell Closes #6729 from pwendell/hadoop-provided and squashes the following commits: 1113b76 [Patrick Wendell] [SPARK-6511] [Documentation] Explain how to use Hadoop provided builds --- docs/hadoop-provided.md | 26 ++++++++++++++++++++++++++ docs/index.md | 10 +++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 docs/hadoop-provided.md diff --git a/docs/hadoop-provided.md b/docs/hadoop-provided.md new file mode 100644 index 0000000000000..0ba5a58051abc --- /dev/null +++ b/docs/hadoop-provided.md @@ -0,0 +1,26 @@ +--- +layout: global +displayTitle: Using Spark's "Hadoop Free" Build +title: Using Spark's "Hadoop Free" Build +--- + +Spark uses Hadoop client libraries for HDFS and YARN. Starting in version Spark 1.4, the project packages "Hadoop free" builds that lets you more easily connect a single Spark binary to any Hadoop version. To use these builds, you need to modify `SPARK_DIST_CLASSPATH` to include Hadoop's package jars. The most convenient place to do this is by adding an entry in `conf/spark-env.sh`. + +This page describes how to connect Spark to Hadoop for different types of distributions. + +# Apache Hadoop +For Apache distributions, you can use Hadoop's 'classpath' command. For instance: + +{% highlight bash %} +### in conf/spark-env.sh ### + +# If 'hadoop' binary is on your PATH +export SPARK_DIST_CLASSPATH=$(hadoop classpath) + +# With explicit path to 'hadoop' binary +export SPARK_DIST_CLASSPATH=$(/path/to/hadoop/bin/hadoop classpath) + +# Passing a Hadoop configuration directory +export SPARK_DIST_CLASSPATH=$(hadoop classpath --config /path/to/configs) + +{% endhighlight %} diff --git a/docs/index.md b/docs/index.md index 7939657915fc9..d85cf12defefd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,9 +12,13 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog # Downloading -Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. The downloads page -contains Spark packages for many popular HDFS versions. If you'd like to build Spark from -scratch, visit [Building Spark](building-spark.html). +Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. +Users can also download a "Hadoop free" binary and run Spark with any Hadoop version +[by augmenting Spark's classpath](hadoop-provided.html). + +If you'd like to build Spark from +source, visit [Building Spark](building-spark.html). + Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy to run locally on one machine --- all you need is to have `java` installed on your system `PATH`, From 778f3ca81f8d90faec0775509632fe68f1399dc4 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Tue, 9 Jun 2015 19:33:00 -0700 Subject: [PATCH 254/254] [SPARK-7792] [SQL] HiveContext registerTempTable not thread safe Just replaced mutable.HashMap to ConcurrentHashMap Author: navis.ryu Closes #6699 from navis/SPARK-7792 and squashes the following commits: f03654a [navis.ryu] [SPARK-7792] [SQL] HiveContext registerTempTable not thread safe --- .../spark/sql/catalyst/analysis/Catalog.scala | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 3e240fd55e250..1541491608b24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.EmptyConf @@ -81,18 +85,18 @@ trait Catalog { } class SimpleCatalog(val conf: CatalystConf) extends Catalog { - val tables = new mutable.HashMap[String, LogicalPlan]() + val tables = new ConcurrentHashMap[String, LogicalPlan] override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables += ((getDbTableName(tableIdent), plan)) + tables.put(getDbTableName(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables -= getDbTableName(tableIdent) + tables.remove(getDbTableName(tableIdent)) } override def unregisterAllTables(): Unit = { @@ -101,10 +105,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - tables.get(getDbTableName(tableIdent)) match { - case Some(_) => true - case None => false - } + tables.containsKey(getDbTableName(tableIdent)) } override def lookupRelation( @@ -112,7 +113,10 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val tableFullName = getDbTableName(tableIdent) - val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName")) + val table = tables.get(tableFullName) + if (table == null) { + sys.error(s"Table Not Found: $tableFullName") + } val tableWithQualifiers = Subquery(tableIdent.last, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are @@ -121,9 +125,11 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - tables.map { - case (name, _) => (name, true) - }.toSeq + val result = ArrayBuffer.empty[(String, Boolean)] + for (name <- tables.keySet()) { + result += ((name, true)) + } + result } override def refreshTable(databaseName: String, tableName: String): Unit = {